# Strategy-inference driver.
# For each user: load the per-task prediction CSVs (one per assumed strategy),
# keep a single sampled action series, build the normalised probability array
# P_q_S, fit the Stan strategy-inference model once per true strategy, and
# write the stacked posterior summaries to
# result/N<n_series>/<model_type>_N<n_series>_result_user<id>.csv.

library(tidyverse)
library(cmdstanr)
library(dplyr)

# Prediction-model configuration; only used to build the prediction path.
model_type <- "lstmlast"
batch_size <- "8"
lr <- "0.0001"
hidden_size <- "128"
model_type <- paste0(
  model_type, "_bs_", batch_size, "_lr_", lr, "_hidden_size_", hidden_size
)
print(model_type)

set.seed(9736734)

user_num <- 5
user <- c(0:(user_num - 1))
strategies <- c(0:6) # 7 tasks
print(strategies)
print(length(strategies))

# Number of action series sampled per intention. Kept separate from `N`
# (the observation count used for the Stan data) instead of reusing one name.
n_series <- 1

# Read the prediction CSV of every assumed strategy for one user.
# Returns a list (one data.frame per element of `strategies`) where each row is
# tagged with its assumed strategy and an `id` copied from the CSV's first
# column.
read_user_predictions <- function(user_id, strategies, model_type) {
  lapply(strategies, function(s) {
    d <- read.csv(paste0(
      "../prediction/task", s, "/", model_type,
      "/user", user_id, "_pred", ".csv"
    ))
    d$assumed_strategy <- s
    d$id <- d[, 1]
    d
  })
}

# Build the N x I x K array P_q_S. Slice n is the transposed act* columns of
# the rows with index == n (one row per assumed strategy), and each column k
# is normalised to sum to 1.
build_p_q_s <- function(dat, N, I, K) {
  P_q_S <- array(dim = c(N, I, K))
  for (n in seq_len(N)) {
    probs <- dat %>%
      filter(index == n) %>%
      select(matches("^act[[:digit:]]+$")) %>%
      as.matrix() %>%
      t()
    # Fail loudly instead of letting R silently recycle a mis-sized slice.
    stopifnot(nrow(probs) == I, ncol(probs) == K)
    P_q_S[n, , ] <- sweep(probs, 2, colSums(probs), "/")
  }
  P_q_S
}

# Fit the Stan model on the observations whose true strategy equals `s`,
# print the posterior mean/sd, and return the fit object.
fit_strategy <- function(mod, s, true_strategy, q, P_q_S, K, I, output_dir) {
  sub <- which(true_strategy == s)
  print(c("sub", sub))
  sdata <- list(
    N = length(sub), K = K, I = I, q = q[sub],
    # drop = FALSE keeps the N x I x K shape when only one row matches;
    # without it the slice collapses to an I x K matrix.
    P_q_S = P_q_S[sub, , , drop = FALSE]
  )
  fit <- mod$sample(data = sdata, refresh = 0, output_dir = output_dir)
  print(fit$summary(NULL, c("mean", "sd")))
  fit
}

# Loop-invariant setup, done once rather than once per user.
dir.create(file.path("result"), showWarnings = FALSE)
dir.create(file.path(paste0("result/", "N", n_series)), showWarnings = FALSE)
temp_dir <- paste0(getwd(), "/temp")
dir.create(temp_dir, showWarnings = FALSE)
mod <- cmdstan_model("strategy_inference_model.stan")

# Selected action series; sampled from user 0's data and reused for all users.
selected <- NULL

for (u in seq_along(user)) {
  print(paste0("user: ", u))
  dat <- read_user_predictions(user[[u]], strategies, model_type)

  # Select one action series per intention (task_name) from the first
  # assumed strategy's predictions. Only the first selected action_id is used.
  if (user[[u]] == 0) {
    selected <- dat[[1]] %>%
      group_by(task_name) %>%
      sample_n(n_series) %>%
      data.frame()
  }
  if (is.null(selected)) {
    stop("No action series selected: user 0 must be processed first.",
         call. = FALSE)
  }
  sel_action <- selected$action_id[1]

  # Keep only the selected action series in every assumed-strategy table.
  dat <- lapply(dat, function(d) subset(d, d$action_id == sel_action))

  save_path <- paste0(
    "result/", "N", n_series, "/", model_type, "_N", n_series, "_",
    "result", "_user", user[[u]], ".csv"
  )

  dat <- do.call(rbind, dat) %>%
    mutate(index = as.numeric(as.factor(id))) %>%
    rename(true_strategy = task_name) %>%
    mutate(
      true_strategy = factor(
        true_strategy,
        levels = strategies,
        labels = strategies
      ),
      # NOTE(review): case_when stops at the first match, and the second
      # branch already matches every gt in 1:7, so the branches yielding
      # 2-6 are unreachable. q_type is not used below; confirm the intended
      # mapping before relying on it.
      q_type = case_when(
        gt %in% c(3, 4, 5) ~ 0,
        gt %in% c(1, 2, 3, 4, 5, 6, 7) ~ 1,
        gt %in% c(1, 2, 3, 4) ~ 2,
        gt %in% c(1, 4, 5, 6, 7) ~ 3,
        gt %in% c(1, 2, 3, 6, 7) ~ 4,
        gt %in% c(2, 3, 4, 5, 6, 7) ~ 5,
        gt %in% c(1, 2, 3, 4, 5, 6, 7) ~ 6
      )
    )

  # Observations come from one assumed-strategy copy. The previous code used
  # the last strategy via a loop variable leaking from an earlier for-loop;
  # that choice is kept but made explicit. Rows are ordered by `index` so
  # q[n] and true_strategy[n] line up with P_q_S[n, , ] (built by index == n).
  dat_obs <- dat %>%
    filter(assumed_strategy == strategies[[length(strategies)]]) %>%
    arrange(index)
  N <- nrow(dat_obs)
  print(c("N: ", N))
  q <- dat_obs$gt
  true_strategy <- dat_obs$true_strategy
  K <- length(unique(dat$assumed_strategy))
  print(c("K: ", K))
  I <- 7

  P_q_S <- build_p_q_s(dat, N, I, K)
  print(c("dim P_q_S", dim(P_q_S)))

  # One fit per true strategy, in order 0..6 (the order matters for RNG-driven
  # sampler seeds).
  fits <- lapply(strategies, function(s) {
    fit_strategy(mod, s, true_strategy, q, P_q_S, K, I, temp_dir)
  })

  # Save the stacked posterior summaries of all fits.
  df <- do.call(rbind, lapply(fits, function(fit) fit$summary()))
  write.csv(df, file = save_path, quote = FALSE)
}