# Packages ----
# tidyverse already attaches dplyr; the explicit library(dplyr) is kept so the
# script's attach order is unchanged.
library(tidyverse)
library(cmdstanr)
library(dplyr)

# Model configuration ----
# Hyper-parameters are stored as strings: pasting the numeric 0.0001 would
# render as "1e-04", which would not match the directory name built below.
model_type <- "lstmlast"
batch_size <- "8"
lr <- "0.0001"
hidden_size <- "128"

# Full model identifier, used as a directory name when reading predictions.
model_type <- sprintf(
  "%s_bs_%s_lr_%s_hidden_size_%s",
  model_type, batch_size, lr, hidden_size
)
print(model_type)

set.seed(9736734)

# Users and strategies ----
# Users are numbered 0 .. user_num - 1, matching the "user<k>_pred.csv" files.
user_num <- 5
user <- seq_len(user_num) - 1L

# One assumed strategy per task directory ("task0" .. "task6").
strategies <- 0:6 # 7 tasks
print(strategies)
print(length(strategies))

# Number of action series sampled per intention.
N <- 1

# Per-user strategy inference ----
#
# For each user: read the prediction CSVs produced under every assumed
# strategy, keep one action series, build the action-probability array and
# fit the Stan model once per true strategy. The stacked fit summaries are
# written to result/N<N>/<model_type>_N<N>_result_user<id>.csv.

# Read one user's predictions for every assumed strategy (one CSV per task
# directory) and tag each row with the strategy it was predicted under.
# Returns a list of data frames in the order of `strategies`.
read_user_predictions <- function(user_id, strategies, model_type) {
  lapply(strategies, function(strategy) {
    path <- paste0(
      "../prediction/task", strategy, "/", model_type,
      "/user", user_id, "_pred", ".csv"
    )
    pred <- read.csv(path)
    pred$assumed_strategy <- strategy
    # The first CSV column serves as the per-sample identifier.
    pred$id <- pred[[1]]
    pred
  })
}

# Build the n_obs x I x K array P_q_S. Slice [n, , k] holds the act*
# columns of the k-th row with `index == n` (k-th assumed strategy in
# row-bind order), each column normalised to sum to 1.
build_p_q_s <- function(dat, n_obs, I, K) {
  P_q_S <- array(dim = c(n_obs, I, K))
  for (n in seq_len(n_obs)) {
    probs <- dat %>%
      filter(index == .env$n) %>%
      select(matches("^act[[:digit:]]+$")) %>%
      as.matrix() %>%
      t()
    # Without this check a short slice would be silently recycled into the
    # array (e.g. 7 values filling a 7 x 7 slice).
    if (nrow(probs) != I || ncol(probs) != K) {
      stop(sprintf(
        "sample %d: expected %d act columns x %d assumed strategies, got %d x %d",
        n, I, K, nrow(probs), ncol(probs)
      ), call. = FALSE)
    }
    P_q_S[n, , ] <- sweep(probs, 2, colSums(probs), "/")
  }
  P_q_S
}

# Fit the Stan model on the samples whose true strategy is `strategy`.
# Prints the selected sample positions and a mean/sd summary; returns the fit.
fit_true_strategy <- function(mod, strategy, true_strategy, q, P_q_S, K, I,
                              output_dir) {
  sub <- which(true_strategy == strategy)
  print(c("sub", sub))
  sdata <- list(
    N = length(sub), K = K, I = I,
    # as.array() / drop = FALSE keep array shape when exactly one sample
    # matches: plain q[sub] / P_q_S[sub, , ] would collapse to a scalar /
    # an I x K matrix, which no longer matches the N-indexed data shape.
    q = as.array(q[sub]),
    P_q_S = P_q_S[sub, , , drop = FALSE]
  )
  fit <- mod$sample(data = sdata, refresh = 0, output_dir = output_dir)
  print(fit$summary(NULL, c("mean", "sd")))
  fit
}

# Compile (or reuse) the model once instead of once per user.
mod <- cmdstan_model("strategy_inference_model.stan")

stan_output_dir <- file.path(getwd(), "temp")
dir.create(stan_output_dir, showWarnings = FALSE)

# N (set in the setup section) is the number of series sampled per intention;
# it is kept separate from the per-user observation count n_obs below.
n_series <- N
result_dir <- file.path("result", paste0("N", n_series))
dir.create(result_dir, recursive = TRUE, showWarnings = FALSE)

I <- 7 # expected number of act* probability columns

# Observations (gt, true strategy) are read from the rows predicted under this
# assumed strategy. Presumably gt/task_name agree across all task files —
# NOTE(review): confirm. The last strategy is used, as before.
reference_strategy <- strategies[[length(strategies)]]

# The action series is drawn once, from user 0's data, and reused for all users.
selected_action <- NULL

for (u in seq_along(user)) {
  user_id <- user[[u]]
  print(paste0("user: ", u))

  dat <- read_user_predictions(user_id, strategies, model_type)

  if (user_id == 0) {
    # sample_n() is kept (not slice_sample()) so the RNG stream, and thus
    # results under set.seed(), stay unchanged.
    sel <- dat[[1]] %>%
      group_by(task_name) %>%
      sample_n(n_series) %>%
      ungroup()
    # NOTE(review): only the first sampled series is used, even when
    # n_series > 1 — confirm this is intended.
    selected_action <- sel$action_id[1]
  }
  if (is.null(selected_action)) {
    stop("no action series selected yet: user 0 must be processed first",
         call. = FALSE)
  }

  # Keep only the selected action series under every assumed strategy.
  dat <- lapply(dat, function(pred) subset(pred, action_id == selected_action))

  save_path <- file.path(
    result_dir,
    paste0(model_type, "_N", n_series, "_result_user", user_id, ".csv")
  )

  dat <- do.call(rbind, dat) %>%
    mutate(index = as.numeric(as.factor(id))) %>%
    rename(true_strategy = task_name) %>%
    mutate(
      true_strategy = factor(
        true_strategy, levels = strategies,
        labels = strategies
      ),
      # NOTE(review): case_when() stops at the first match, so only the first
      # two branches can fire (gt in 3:5 -> 0, any other gt in 1:7 -> 1).
      # q_type is not used further in this script — confirm intended mapping.
      q_type = case_when(
        gt %in% c(3, 4, 5) ~ 0,
        gt %in% c(1, 2, 3, 4, 5, 6, 7) ~ 1,
        gt %in% c(1, 2, 3, 4) ~ 2,
        gt %in% c(1, 4, 5, 6, 7) ~ 3,
        gt %in% c(1, 2, 3, 6, 7) ~ 4,
        gt %in% c(2, 3, 4, 5, 6, 7) ~ 5,
        gt %in% c(1, 2, 3, 4, 5, 6, 7) ~ 6
      )
    )

  # One observation per sample; ordered by `index` so that q[n] lines up with
  # P_q_S[n, , ] (which is built by filtering on index == n).
  dat_obs <- dat %>%
    filter(assumed_strategy == reference_strategy) %>%
    arrange(index)
  n_obs <- nrow(dat_obs)
  print(c("N: ", n_obs))
  q <- dat_obs$gt
  true_strategy <- dat_obs$true_strategy

  K <- length(unique(dat$assumed_strategy))
  print(c("K: ", K))

  P_q_S <- build_p_q_s(dat, n_obs, I, K)
  print(c("dim P_q_S", dim(P_q_S)))

  # One fit per true strategy, in the order of `strategies`.
  fits <- lapply(strategies, function(strategy) {
    fit_true_strategy(mod, strategy, true_strategy, q, P_q_S, K, I,
                      stan_output_dir)
  })

  # Summaries are stacked in strategy order (0, 1, ..., 6).
  df <- do.call(rbind, lapply(fits, function(fit) fit$summary()))
  write.csv(df, file = save_path, quote = FALSE)
}