# Start from a clean state: remove every object in the workspace (including
# hidden dot-prefixed names) and close all open graphics devices.
# `all.names` is spelled out rather than relying on partial matching of `all`.
rm(list = ls(all.names = TRUE))
graphics.off()
# ADHD Treatment
# Greg Francis
# PSY 626
# 12 October 2020

# Dependent-measures (within-subjects) ANOVA with shrinkage
# Focus on shrinkage of subject scores


# Load rethinking before anything that uses it: coerce_index() below is a
# rethinking function, so calling it before library() fails in a fresh session.
library(rethinking)

# Load data file
ATdata <- read.csv(file = "ADHDTreatment.csv", header = TRUE,
                   stringsAsFactors = TRUE)

# Set up 0/1 dummy variables, one per dosage level.
# (The models below use only Dosage15/30/60; D0 is their baseline.)
ATdata$Dosage0 <- ifelse(ATdata$Dosage == "D0", 1, 0)
ATdata$Dosage15 <- ifelse(ATdata$Dosage == "D15", 1, 0)
ATdata$Dosage30 <- ifelse(ATdata$Dosage == "D30", 1, 0)
ATdata$Dosage60 <- ifelse(ATdata$Dosage == "D60", 1, 0)

# Make an integer index for each participant; the models use it to select
# the participant-specific intercept a[Participant].
ATdata$Participant <- coerce_index(ATdata$SubjectID)


#---------------------------------------
# Investigate shrinkage on subjects


# Pull out individual subjects and plot each one's raw scores across dosages
plot(ATdata$DosageNumber, ATdata$CorrectResponses)
# Loop over the participant indices actually present instead of assuming 24
for (i in sort(unique(ATdata$Participant))) {
  thisSet <- subset(ATdata, ATdata$Participant == i)
  # lines() connects points in row order, so order by dosage to avoid zig-zags
  thisSet <- thisSet[order(thisSet$DosageNumber), ]
  lines(thisSet$DosageNumber, thisSet$CorrectResponses, col = i, lty = i)
}


# ATmodel1: one intercept per subject plus a fixed deviation for each
# non-zero dosage. The wide prior on a[Participant] (sd 50) means the
# subject intercepts are shrunk only slightly.
#   a[Participant] -> that subject's mean at Dosage0 (the baseline)
#   b2, b3, b4     -> change from the baseline at D15, D30 and D60
ATmodel1 <- map2stan(
  alist(
    CorrectResponses ~ dnorm(mu, sigma),
    mu <- a[Participant] + b2 * Dosage15 + b3 * Dosage30 + b4 * Dosage60,
    a[Participant] ~ dnorm(50, 50),
    b2 ~ dnorm(0, 20),
    b3 ~ dnorm(0, 20),
    b4 ~ dnorm(0, 20),
    sigma ~ dunif(0, 100)
  ),
  data = ATdata
)
cat("Finished ATmodel1\n")

print(summary(ATmodel1))

dev.new()
# Plot each subject's posterior-mean profile across dosages for ATmodel1
post <- extract.samples(ATmodel1, n = 2000)

plot(ATdata$DosageNumber, ATdata$CorrectResponses, main = "ATmodel1")
# x positions of the four dosage levels (D0, D15, D30, D60)
xLabels <- c(0, 15, 30, 60)
# post$a has one column per participant; use its width rather than a
# hard-coded 24 so the plot works for any number of subjects.
for (i in seq_len(ncol(post$a))) {
  newMeans <- c(mean(post$a[, i]),
                mean(post$a[, i] + post$b2),
                mean(post$a[, i] + post$b3),
                mean(post$a[, i] + post$b4))
  lines(xLabels, newMeans, col = i, lty = i)
}

# ATmodel2: identical structure to ATmodel1, but the subject intercepts get a
# much tighter prior (sd 10 around 50), which pulls them harder toward 50.
ATmodel2 <- map2stan(
  alist(
    CorrectResponses ~ dnorm(mu, sigma),
    mu <- a[Participant] + b2 * Dosage15 + b3 * Dosage30 + b4 * Dosage60,
    a[Participant] ~ dnorm(50, 10),
    b2 ~ dnorm(0, 20),
    b3 ~ dnorm(0, 20),
    b4 ~ dnorm(0, 20),
    sigma ~ dunif(0, 100)
  ),
  data = ATdata
)
cat("Finished ATmodel2\n")

print(summary(ATmodel2))

dev.new()
# Plot each subject's posterior-mean profile across dosages for ATmodel2
post <- extract.samples(ATmodel2, n = 2000)

plot(ATdata$DosageNumber, ATdata$CorrectResponses, main = "ATmodel2")
# x positions of the four dosage levels (D0, D15, D30, D60)
xLabels <- c(0, 15, 30, 60)
# post$a has one column per participant; use its width rather than a
# hard-coded 24 so the plot works for any number of subjects.
for (i in seq_len(ncol(post$a))) {
  newMeans <- c(mean(post$a[, i]),
                mean(post$a[, i] + post$b2),
                mean(post$a[, i] + post$b3),
                mean(post$a[, i] + post$b4))
  lines(xLabels, newMeans, col = i, lty = i)
}


# ATmodel3 (hierarchical): the subject intercepts are no longer given a fixed
# prior; they are drawn from a common normal distribution whose mean (Gmean)
# and spread (Gsd) are estimated from the data, so the amount of shrinkage is
# learned rather than imposed.
ATmodel3 <- map2stan(
  alist(
    CorrectResponses ~ dnorm(mu, sigma),
    mu <- a[Participant] + b2 * Dosage15 + b3 * Dosage30 + b4 * Dosage60,
    a[Participant] ~ dnorm(Gmean, Gsd),
    Gmean ~ dnorm(50, 50),
    Gsd ~ dunif(0, 50),
    b2 ~ dnorm(0, 20),
    b3 ~ dnorm(0, 20),
    b4 ~ dnorm(0, 20),
    sigma ~ dunif(0, 100)
  ),
  data = ATdata
)
cat("Finished ATmodel3\n")

print(summary(ATmodel3))

dev.new()
# Plot each subject's posterior-mean profile across dosages for ATmodel3
post <- extract.samples(ATmodel3, n = 2000)

plot(ATdata$DosageNumber, ATdata$CorrectResponses, main = "ATmodel3")
# x positions of the four dosage levels (D0, D15, D30, D60)
xLabels <- c(0, 15, 30, 60)
# post$a has one column per participant; use its width rather than a
# hard-coded 24 so the plot works for any number of subjects.
for (i in seq_len(ncol(post$a))) {
  newMeans <- c(mean(post$a[, i]),
                mean(post$a[, i] + post$b2),
                mean(post$a[, i] + post$b3),
                mean(post$a[, i] + post$b4))
  lines(xLabels, newMeans, col = i, lty = i)
}

# ATmodel4 (hierarchical): keeps ATmodel3's shrinkage of the subject
# intercepts toward Gmean, and additionally shrinks the dosage effects
# b2, b3, b4 toward a shared mean Gbmean with estimated spread Gbsd.
ATmodel4 <- map2stan(
  alist(
    CorrectResponses ~ dnorm(mu, sigma),
    mu <- a[Participant] + b2 * Dosage15 + b3 * Dosage30 + b4 * Dosage60,
    a[Participant] ~ dnorm(Gmean, Gsd),
    Gmean ~ dnorm(50, 50),
    Gsd ~ dunif(0, 50),
    b2 ~ dnorm(Gbmean, Gbsd),
    b3 ~ dnorm(Gbmean, Gbsd),
    b4 ~ dnorm(Gbmean, Gbsd),
    Gbmean ~ dnorm(0, 20),
    Gbsd ~ dunif(0, 30),
    sigma ~ dunif(0, 100)
  ),
  data = ATdata
)
cat("Finished ATmodel4\n")

print(summary(ATmodel4))

# Compare the observed dosage means with the model-implied means of ATmodel4
post <- extract.samples(ATmodel4, n = 2000)

# Observed mean CorrectResponses for each dosage level
Means <- aggregate(CorrectResponses ~ Dosage, FUN = mean, data = ATdata)
# newMeans below is built in D0, D15, D30, D60 order; align the rows of Means
# to that order explicitly instead of relying on the factor's level order.
Means <- Means[match(c("D0", "D15", "D30", "D60"), Means$Dosage), ]

# Model-implied means: average intercept plus each dosage's deviation
newMeans <- c(mean(post$a),
              mean(post$a) + mean(post$b2),
              mean(post$a) + mean(post$b3),
              mean(post$a) + mean(post$b4))

# Open a single device for this figure (previously two were opened, one left blank)
dev.new()
# y_range (not `range`) so base::range is not shadowed
y_range <- range(c(Means$CorrectResponses, newMeans))
# With a factor x, plot() draws one (degenerate) box per dosage level
plot(Means$Dosage, Means$CorrectResponses, main = "ATmodel4", ylim = y_range)
points(Means$Dosage, newMeans, pch = 19)
# Red dashed line: average intercept plus the shared mean dosage effect
abline(h = mean(post$a) + mean(post$Gbmean), col = "red", lwd = 3, lty = 2)





# Impact on estimated means


# Morris-Efron (James-Stein type) shrinkage of the observed dosage means
# toward their grand mean.
Means <- aggregate(CorrectResponses ~ Dosage, FUN = mean, data = ATdata)
counts <- aggregate(CorrectResponses ~ Dosage, FUN = length, data = ATdata)
Vars <- aggregate(CorrectResponses ~ Dosage, FUN = var, data = ATdata)
GrandMean <- mean(Means$CorrectResponses)

# Per-group shrinkage factor:
#   1 - (k - 3) * var(mean_j) / sum((mean - GrandMean)^2)
# where var(mean_j) = Vars_j / counts_j is the sampling variance of group j's
# mean and k is the number of dosage groups.
# NOTE(review): no positive-part truncation is applied, so the factor can go
# negative when the means are very close together.
nDosages <- length(Means$CorrectResponses)
shrinkFactor <- 1 - (nDosages - 3) *
  (Vars$CorrectResponses / counts$CorrectResponses) /
  sum((Means$CorrectResponses - GrandMean)^2)
newMeans <- shrinkFactor * (Means$CorrectResponses - GrandMean) + GrandMean

# par() only affects the current device, so open the new device first;
# otherwise the light-blue background is set on the previous plot instead.
dev.new()
par(bg = "lightblue")
# y_range (not `range`) so base::range is not shadowed
y_range <- range(c(Means$CorrectResponses, newMeans))
plot(Means$Dosage, Means$CorrectResponses, main = "Morris-Efron", ylim = y_range)
points(Means$Dosage, newMeans, pch = 19)
abline(h = GrandMean, col = "red", lwd = 3, lty = 2)

