InferringIntention/keyboard_and_mouse/stan/strategy_inference_test_full_length.R
2024-03-24 23:42:27 +01:00

238 lines
6.8 KiB
R

library(tidyverse)
library(cmdstanr)
library(dplyr)
# ---- Configuration ----
# Hyper-parameters of the trained predictor; together they form the
# sub-directory name under which its prediction CSVs are stored.
batch_size <- "8"
lr <- "0.0001"
hidden_size <- "128"
model_type <- sprintf(
  "%s_bs_%s_lr_%s_hidden_size_%s",
  "lstmlast", batch_size, lr, hidden_size
)

# Observation rates used as file-name suffixes: "_0", "_10", ..., "_100".
rates <- paste0("_", seq(0, 100, by = 10))

# Users are numbered 0 .. user_num - 1.
user_num <- 5
user <- seq_len(user_num) - 1L

# Index order of the strategies assumed throughout: 7 tasks, ids 0 .. 6.
strategies <- seq_len(7L) - 1L
print(strategies)
print(length(strategies))

N <- 1
set.seed(9736754)
# One slot per strategy for the sampled action series (only the first slot
# is filled and read by the inference loop).
sel <- rep(list(NULL), length(strategies))
# ---- Strategy inference ----
# For every user and observation rate: read the per-strategy prediction CSVs,
# restrict them to one sampled action series, build the probability array
# P_q_S (observations x actions x assumed strategies) and fit the Stan model
# once per true strategy. One summary CSV is written per (user, rate).

# Number of action series drawn per task_name group at rate "_0". It is also
# part of the result directory / file names ("N1").
n_sample <- 1
result_dir <- paste0("result/", "N", n_sample)
dir.create(result_dir, showWarnings = FALSE, recursive = TRUE)
stan_output_dir <- file.path(getwd(), "temp")
dir.create(stan_output_dir, showWarnings = FALSE, recursive = TRUE)

# The model file is the same for every iteration, so compile it only once.
mod <- cmdstan_model(file.path(getwd(), "strategy_inference_model.stan"))

# Path of the prediction CSV for one assumed strategy / user / rate.
# "_0" reuses the 10% file (no observations from it are passed to Stan) and
# "_100" is the file without a rate suffix.
prediction_path <- function(model_type, strategy, user_id, rate) {
  rate_suffix <- switch(rate,
    "_0" = "_rate_10",
    "_100" = "",
    paste0("_rate", rate)
  )
  paste0(
    "../prediction/task", strategy, "/", model_type, "/user", user_id,
    rate_suffix, "_pred", ".csv"
  )
}

# Sample the Stan model on the observations `rows` (indices into q / P_q_S).
# drop = FALSE keeps P_q_S three-dimensional when 0 or 1 rows are selected,
# which previously needed a manual dim() repair.
fit_strategy_subset <- function(mod, rows, q, P_q_S, K, I, output_dir) {
  sdata <- list(
    N = length(rows), K = K, I = I,
    q = q[rows],
    P_q_S = P_q_S[rows, , , drop = FALSE]
  )
  mod$sample(data = sdata, refresh = 0, output_dir = output_dir)
}

for (u in seq_along(user)) {
  print("user")
  print(user[[u]])
  # Action series drawn at rate "_0" and reused for every later rate.
  selected_action_id <- NULL

  for (rate in rates) {
    dat <- lapply(strategies, function(s) {
      d <- read.csv(prediction_path(model_type, s, user[[u]], rate))
      d$assumed_strategy <- s # strategy assumed for this prediction file
      d$id <- d[, 1] # first CSV column is used as the observation key
      d
    })

    if (rate == "_0") {
      # Draw n_sample series per task_name group; only the first drawn
      # series (first group) is used for this user.
      sampled <- dat[[1]] %>%
        group_by(task_name) %>%
        sample_n(n_sample) %>%
        ungroup()
      selected_action_id <- sampled$action_id[1]
    }
    if (is.null(selected_action_id)) {
      stop("`rates` must start with \"_0\" so that an action series is selected",
        call. = FALSE
      )
    }
    print(c("action id", selected_action_id))
    dat <- lapply(dat, function(d) subset(d, action_id == selected_action_id))

    # `rate` already starts with "_", so names contain "_rate__<x>"; the
    # pattern is kept unchanged for compatibility with existing results.
    save_path <- paste0(
      result_dir, "/", model_type, "_N", n_sample, "_", "result",
      "_user", user[[u]], "_rate_", rate, ".csv"
    )

    dat_act <- do.call(rbind, dat) %>%
      mutate(index = as.numeric(as.factor(id))) %>%
      rename(true_strategy = task_name) %>%
      mutate(true_strategy = factor(true_strategy, levels = strategies))

    # One row per observation: the rows of a single assumed strategy, ordered
    # by index so that q / true_strategy line up with P_q_S[n, , ].
    dat_obs <- dat_act %>%
      filter(assumed_strategy == strategies[[length(strategies)]]) %>%
      arrange(index)
    n_obs <- nrow(dat_obs)
    print(c("N: ", n_obs))
    print(c("dim dat_act: ", dim(dat_act)))
    q <- dat_obs$gt
    true_strategy <- dat_obs$true_strategy
    K <- length(unique(dat_act$assumed_strategy))
    # Number of action-probability columns (act0, act1, ...); was fixed at 7.
    act_cols <- grep("^act[[:digit:]]+$", names(dat_act), value = TRUE)
    I <- length(act_cols)

    P_q_S <- array(dim = c(n_obs, I, K))
    for (n in seq_len(n_obs)) {
      # I x K matrix with one column per assumed strategy (rbind order);
      # assumes each index occurs once per assumed strategy.
      probs <- t(as.matrix(dat_act[dat_act$index == n, act_cols, drop = FALSE]))
      # Normalise each assumed strategy's probabilities to sum to 1.
      P_q_S[n, , ] <- sweep(probs, 2, colSums(probs), "/")
    }
    print(c("dim(P_q_S)", dim(P_q_S)))

    # One fit per true strategy: at rate "_0" every fit gets N = 0 (no
    # observations); otherwise the observations whose true strategy is s.
    fits <- lapply(strategies, function(s) {
      rows <- if (rate == "_0") integer(0) else which(true_strategy == s)
      fit <- fit_strategy_subset(mod, rows, q, P_q_S, K, I, stan_output_dir)
      print(fit$summary(NULL, c("mean", "sd")))
      fit
    })

    df <- do.call(rbind, lapply(fits, function(fit) fit$summary()))
    write.csv(df, file = save_path, quote = FALSE)
  }
}