# InferringIntention/keyboard_and_mouse/stan/strategy_inference_test.R
#
# 158 lines
# 4.6 KiB
# R
# Raw Normal View History
#
# 2024-03-24 23:42:27 +01:00
library(tidyverse)
library(cmdstanr)
library(dplyr)
# Run configuration ----
# Predictor hyper-parameters; they are folded into `model_type`, which names
# the prediction folder that is read and the result file that is written.
model_type <- "lstmlast"
batch_size <- "8"
lr <- "0.0001"
hidden_size <- "128"
model_type <- paste0(
  model_type,
  "_bs_", batch_size,
  "_lr_", lr,
  "_hidden_size_", hidden_size
)
print(model_type)

set.seed(9736734)

# User ids start at 0.
user_num <- 5
user <- seq_len(user_num) - 1L

# Strategy (task) ids: 7 tasks, numbered 0..6.
strategies <- 0:6
print(strategies)
print(length(strategies))

# Number of rows sampled per task_name when the action series is selected.
N <- 1

# Slot 1 receives the selected action series; the other slots stay NULL.
sel <- rep(list(NULL), length(strategies))
# Strategy inference for every user ----
#
# For each user id in `user` this section
#   1. reads one prediction CSV per assumed strategy
#      (../prediction/task<s>/<model_type>/user<id>_pred.csv),
#   2. keeps only the rows of one action series, chosen once while
#      processing user 0 and reused for every later user,
#   3. builds P_q_S[n, i, k] from the `act<digit>` columns, each (n, k)
#      slice rescaled to sum to 1,
#   4. fits strategy_inference_model.stan separately on the observations
#      of each true strategy, and
#   5. writes the stacked fit summaries to
#      result/N<N>/<model_type>_N<N>_result_user<id>.csv.
#
# Uses `model_type`, `user`, `strategies`, `N` and `sel` defined above.

# Read one user's prediction files: a list with one data frame per assumed
# strategy, in the order of `strategies`.
read_user_predictions <- function(user_id, strategies, model_type) {
  lapply(strategies, function(s) {
    pred <- read.csv(paste0(
      "../prediction/task", s, "/", model_type,
      "/user", user_id, "_pred", ".csv"
    ))
    pred$assumed_strategy <- s
    # the first CSV column is used as the per-row id (see `index` below)
    pred$id <- pred[, 1]
    pred
  })
}

# Build P_q_S with dim (n_obs, n_actions, n_strategies): slice [n, , k] holds
# the `act<digit>` values of the index-n row of the k-th assumed strategy,
# rescaled to sum to 1. Expects exactly n_strategies rows per index value.
build_likelihood_array <- function(dat, n_obs, n_actions, n_strategies) {
  p_q_s <- array(dim = c(n_obs, n_actions, n_strategies))
  for (n in seq_len(n_obs)) {
    # one row per assumed strategy (rbind order) -> transpose to I x K
    scores <- dat %>%
      filter(index == n) %>%
      select(matches("^act[[:digit:]]+$")) %>%
      as.matrix() %>%
      t()
    # normalize probabilities column-wise (per assumed strategy)
    p_q_s[n, , ] <- sweep(scores, 2, colSums(scores), "/")
  }
  p_q_s
}

# Loop-invariant: compile/load the model and create the CmdStan output dir once.
mod <- cmdstan_model("strategy_inference_model.stan")
temp_dir <- file.path(getwd(), "temp")
dir.create(temp_dir, showWarnings = FALSE)

for (user_id in user) {
  print(paste0("user: ", user_id))
  dat <- read_user_predictions(user_id, strategies, model_type)

  # Select the action series once (user 0, first strategy file); the same
  # action_id is then used for every later user.
  if (user_id == 0) {
    sel[[1]] <- dat[[1]] %>%
      group_by(task_name) %>%
      sample_n(N)
    sel[[1]] <- data.frame(sel[[1]])
  }
  if (is.null(sel[[1]])) {
    stop("no action series selected: user 0 must be processed first",
         call. = FALSE)
  }
  # NOTE(review): only the first sampled row's action_id is kept, so a single
  # action series is used whatever the value of N.
  target_action <- sel[[1]]$action_id[1]
  dat <- lapply(dat, function(d) subset(d, d$action_id == target_action))

  # create save path
  result_dir <- file.path("result", paste0("N", N))
  dir.create(result_dir, recursive = TRUE, showWarnings = FALSE)
  save_path <- file.path(
    result_dir,
    paste0(model_type, "_N", N, "_", "result", "_user", user_id, ".csv")
  )

  dat <- do.call(rbind, dat) %>%
    mutate(index = as.numeric(as.factor(id))) %>%
    rename(true_strategy = task_name) %>%
    mutate(true_strategy = factor(true_strategy, levels = strategies))

  # Every assumed-strategy copy is expected to hold the same observation rows
  # (build_likelihood_array relies on it), so one copy suffices. Ordering by
  # index makes q[n] / true_strategy[n] line up with P_q_S[n, , ].
  dat_obs <- dat %>%
    filter(assumed_strategy == strategies[[length(strategies)]]) %>%
    arrange(index)
  n_obs <- nrow(dat_obs)
  print(c("N: ", n_obs))
  q <- dat_obs$gt
  true_strategy <- dat_obs$true_strategy
  K <- length(unique(dat$assumed_strategy))
  print(c("K: ", K))
  # number of action classes, derived from the act<digit> columns rather
  # than hard-coded to 7
  I <- length(grep("^act[[:digit:]]+$", names(dat)))
  P_q_S <- build_likelihood_array(dat, n_obs, I, K)
  print(c("dim P_q_S", dim(P_q_S)))

  # One independent fit per true strategy, on that strategy's observations.
  fits <- lapply(strategies, function(s) {
    sub <- which(true_strategy == s)
    print(c("sub", sub))
    sdata <- list(
      N = length(sub), K = K, I = I, q = q[sub],
      # drop = FALSE keeps the 3-D shape when exactly one observation matches
      P_q_S = P_q_S[sub, , , drop = FALSE]
    )
    fit <- mod$sample(data = sdata, refresh = 0, output_dir = temp_dir)
    print(fit$summary(NULL, c("mean", "sd")))
    fit
  })

  # save csv: summaries stacked in the order of `strategies`
  df <- do.call(rbind, lapply(fits, function(fit) fit$summary()))
  write.csv(df, file = save_path, quote = FALSE)
}