238 lines
6.8 KiB
R
238 lines
6.8 KiB
R
library(tidyverse)
|
|
library(cmdstanr)
|
|
library(dplyr)
|
|
|
|
# Experiment configuration ----

# Base name of the prediction model. The hyperparameter values below are
# appended to it, and the combined name is later used as a path component
# when the prediction CSVs are read and as a prefix of the result files.
model_type <- "lstmlast"
batch_size <- "8"
lr <- "0.0001"
hidden_size <- "128"
model_type <- paste0(
  model_type,
  "_bs_", batch_size,
  "_lr_", lr,
  "_hidden_size_", hidden_size
)

# Rate suffixes iterated over by the main loop; each keeps its leading
# underscore because it is pasted directly into file names.
rates <- c(
  "_0", "_10", "_20", "_30", "_40", "_50",
  "_60", "_70", "_80", "_90", "_100"
)

# Zero-based user ids: 0, 1, ..., user_num - 1.
user_num <- 5
user <- 0:(user_num - 1)

# index order of the strategies assumed throughout
strategies <- 0:6 # 7 tasks
print(strategies)
print(length(strategies))

# Number of rows sampled per task_name when an action series is selected.
N <- 1

# Fix the RNG state so the random selection below is reproducible.
set.seed(9736754)
|
|
|
|
# Read predictions from CSV and run strategy inference ----
#
# For every user and every rate this section
#   1. reads one prediction CSV per assumed strategy,
#   2. keeps only the rows of the action series selected at rate "_0",
#   3. builds the N x I x K array `P_q_S` of normalised action probabilities,
#   4. samples the Stan model once per true strategy, and
#   5. writes the stacked posterior summaries to one CSV per (user, rate).

# Path of the prediction CSV for one (task, user, rate) combination.
#
# Rate "_0" has no file of its own and reads the "_rate_10" file; rate "_100"
# reads the file without any rate tag; every other rate reads "_rate<rate>".
prediction_file <- function(task, user_id, rate, model_type) {
  rate_tag <- if (rate == "_0") {
    "_rate_10"
  } else if (rate == "_100") {
    ""
  } else {
    paste0("_rate", rate)
  }
  paste0(
    "../prediction/task", task, "/", model_type,
    "/user", user_id, rate_tag, "_pred", ".csv"
  )
}

# Sample the Stan model using only the observations at positions `rows`.
#
# `drop = FALSE` keeps `P_q_S` three-dimensional for zero or one selected
# row, so no special case is needed. `as.array()` keeps `q` a length-1 array
# instead of a scalar when exactly one row is selected.
# NOTE(review): assumes the Stan model declares `q` as an array of size N --
# confirm against strategy_inference_model.stan.
fit_strategy_subset <- function(mod, rows, q, P_q_S, K, I, output_dir) {
  sdata <- list(
    N = length(rows),
    K = K,
    I = I,
    q = as.array(q[rows]),
    P_q_S = P_q_S[rows, , , drop = FALSE]
  )
  mod$sample(data = sdata, refresh = 0, output_dir = output_dir)
}

# Number of rows sampled per task_name when selecting the action series; it
# is also part of the result directory and file names ("N1").
n_sel <- 1

# Loop-invariant set-up: the Stan output directory and the Stan model.
stan_output_dir <- paste0(getwd(), "/temp")
dir.create(stan_output_dir, showWarnings = FALSE, recursive = TRUE)
mod <- cmdstan_model(paste0(getwd(), "/strategy_inference_model.stan"))

# Only sel[[1]] is used: it holds the rows sampled at rate "_0" and is reused
# for all later rates of the same user.
sel <- vector("list", length(strategies))

for (u in seq_along(user)) {
  print("user")
  print(u)

  for (rate in rates) {
    # One data frame per assumed strategy, in the order of `strategies`.
    dat <- lapply(strategies, function(s) {
      d <- read.csv(prediction_file(s, user[[u]], rate, model_type))
      # strategy assumed for prediction
      d$assumed_strategy <- s
      d$index <- d$action_id
      d$id <- d[, 1]
      d
    })

    # At rate "_0", draw `n_sel` rows per task_name from the first strategy's
    # predictions. sample_n() is kept (not slice_sample()) so the random
    # draws made after set.seed() stay the same.
    if (rate == "_0") {
      sel[[1]] <- dat[[1]] %>%
        group_by(task_name) %>%
        sample_n(n_sel) %>%
        data.frame()
    }
    print(sel[[1]]$action_id)
    print(sel[[1]]$task_name)
    print(dat[[1]]$task_name)

    # Keep only the rows of the first selected action series. Rows where the
    # comparison is NA are dropped, exactly as subset() would do.
    target_action <- sel[[1]]$action_id[1]
    dat <- lapply(dat, function(d) {
      d[which(d$action_id == target_action), , drop = FALSE]
    })
    print(c("action id", dat[[1]]$action_id))
    print(c("action id", dat[[2]]$action_id))
    print(c("action id", dat[[3]]$action_id))

    # recursive = TRUE so a missing "result" parent directory is created too.
    dir.create(
      file.path(paste0("result/", "N", n_sel)),
      showWarnings = FALSE,
      recursive = TRUE
    )
    # `rate` already starts with "_", so the file name contains "_rate__10"
    # etc.; this is kept unchanged so existing result files keep their names.
    save_path <- paste0(
      "result/", "N", n_sel, "/", model_type, "_N", n_sel, "_", "result",
      "_user", user[[u]], "_rate_", rate, ".csv"
    )

    # Stack all assumed strategies and renumber the observations 1..N via id.
    # NOTE(review): in `q_type` every branch after the second is unreachable
    # for gt in 1..7, and `q_type` is not used later in this loop -- confirm
    # the intended conditions.
    dat_act <- do.call(rbind, dat) %>%
      mutate(index = as.numeric(as.factor(id))) %>%
      rename(true_strategy = task_name) %>%
      mutate(
        true_strategy = factor(
          true_strategy,
          levels = 0:6,
          labels = strategies
        ),
        q_type = case_when(
          gt %in% c(3, 4, 5) ~ 0,
          gt %in% c(1, 2, 3, 4, 5, 6, 7) ~ 1,
          gt %in% c(1, 2, 3, 4) ~ 2,
          gt %in% c(1, 4, 5, 6, 7) ~ 3,
          gt %in% c(1, 2, 3, 6, 7) ~ 4,
          gt %in% c(2, 3, 4, 5, 6, 7) ~ 5,
          gt %in% c(1, 2, 3, 4, 5, 6, 7) ~ 6
        )
      )

    # One row per observation: take the rows of the last assumed strategy.
    # (The original relied on a leftover loop index that pointed at it.)
    last_strategy <- strategies[[length(strategies)]]
    dat_obs <- dat_act %>% filter(assumed_strategy == last_strategy)
    N <- nrow(dat_obs)
    print(c("N: ", N))
    print(c("dim dat_act: ", dim(dat_act)))

    q <- dat_obs$gt
    true_strategy <- dat_obs$true_strategy

    K <- length(unique(dat_act$assumed_strategy))
    I <- 7

    # P_q_S[n, , k]: the act<digits> columns of observation n under assumed
    # strategy k, normalised to sum to one.
    P_q_S <- array(dim = c(N, I, K))
    for (n in seq_len(N)) {
      print(n)
      P_q_S[n, , ] <- dat_act %>%
        filter(index == n) %>%
        select(matches("^act[[:digit:]]+$")) %>%
        as.matrix() %>%
        t()
      for (k in seq_len(K)) {
        # normalize probabilities
        P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
      }
    }
    print(c("dim(P_q_S)", dim(P_q_S)))

    # One fit per true strategy, in the order of `strategies`. At rate "_0"
    # no observation is passed to the model (N = 0).
    fits <- lapply(strategies, function(s) {
      rows <- if (rate == "_0") {
        integer(0)
      } else {
        which(true_strategy == s)
      }
      fit <- fit_strategy_subset(mod, rows, q, P_q_S, K, I, stan_output_dir)
      print(fit$summary(NULL, c("mean", "sd")))
      fit
    })

    # save csv
    # The list is left unnamed so rbind() keeps plain 1..n row names.
    df <- do.call(rbind, lapply(fits, function(fit) fit$summary()))
    write.csv(df, file = save_path, quote = FALSE)
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|