library(tidyverse)
library(cmdstanr)
library(dplyr)

# Strategy inference script ----
#
# For every user and every observation rate this script
#   1. reads one prediction CSV per assumed strategy (task0 .. task6),
#   2. keeps the rows belonging to a single action id chosen at rate "_0",
#   3. builds the array P_q_S[n, i, k] from the `act<digits>` columns
#      (observation n, action i, assumed strategy k), normalised over i,
#   4. runs the Stan model once per true strategy on the matching observations,
#   5. writes the stacked posterior summaries to
#      result/N<n_sample>/<model_type>_N<n_sample>_result_user<u>_rate_<rate>.csv

# Configuration ----

# index order of the strategies assumed throughout
model_type <- "lstmlast"
batch_size <- "8"
lr <- "0.0001"
hidden_size <- "128"
model_type <- paste0(
  model_type, "_bs_", batch_size, "_lr_", lr, "_hidden_size_", hidden_size
)
rates <- c(
  "_0", "_10", "_20", "_30", "_40", "_50",
  "_60", "_70", "_80", "_90", "_100"
)
user_num <- 5
user <- c(0:(user_num - 1))
strategies <- c(0:6) # 7 tasks
print(strategies)
print(length(strategies))

# Number of action series sampled per task when the selection is drawn.
# (The original reused a single variable `N` for this and for the number of
# observations, resetting it on every iteration; the two are now separate.)
n_sample <- 1
set.seed(9736754)

# Helpers ----

# Build the path of the prediction CSV for one assumed strategy / user / rate.
# Rate "_0" reads the "_rate_10" file and rate "_100" reads the file without a
# rate suffix; every other rate reads "_rate<rate>".
prediction_path <- function(strategy, user_id, rate, model_type) {
  rate_tag <- if (rate == "_0") {
    "_rate_10"
  } else if (rate == "_100") {
    ""
  } else {
    paste0("_rate", rate)
  }
  paste0(
    "../prediction/task", strategy, "/", model_type,
    "/user", user_id, rate_tag, "_pred", ".csv"
  )
}

# Loop-invariant setup ----

# The model does not depend on user or rate, so it is created once up front.
mod <- cmdstan_model(paste0(getwd(), "/strategy_inference_model.stan"))

# Directory handed to mod$sample() as output_dir; create it if it is missing.
stan_output_dir <- paste0(getwd(), "/temp")
dir.create(stan_output_dir, showWarnings = FALSE, recursive = TRUE)

# recursive = TRUE so that a missing "result/" parent does not make the later
# write.csv() fail.
result_dir <- paste0("result/", "N", n_sample)
dir.create(result_dir, showWarnings = FALSE, recursive = TRUE)

# Main loop ----

# Action series picked at rate "_0"; reused for all later rates of that user.
selected <- NULL

for (u in seq_along(user)) {
  print("user")
  print(u)

  for (rate in rates) {
    # read data from csv: one data frame per assumed strategy
    dat <- lapply(strategies, function(s) {
      d <- read.csv(prediction_path(s, user[[u]], rate, model_type))
      # strategy assumed for prediction
      d$assumed_strategy <- s
      d$index <- d$action_id
      d$id <- d[, 1]
      d
    })

    # select all action series and infer every one
    # sample_n() is kept (rather than slice_sample()) so the seeded draw is
    # unchanged.
    if (rate == "_0") {
      selected <- dat[[1]] %>%
        group_by(task_name) %>%
        sample_n(n_sample)
      selected <- data.frame(selected)
    }
    print(selected$action_id)
    print(selected$task_name)
    print(dat[[1]]$task_name)

    # Keep only the rows of the first selected action id (NA comparisons are
    # dropped, as subset() did).
    target_action <- selected$action_id[1]
    dat <- lapply(dat, function(d) {
      d[which(d$action_id == target_action), , drop = FALSE]
    })

    print(c("action id", dat[[1]]$action_id))
    print(c("action id", dat[[2]]$action_id))
    print(c("action id", dat[[3]]$action_id))

    # NOTE: `rate` already starts with "_", so the file name contains
    # "_rate__10" etc.; kept as-is so existing output names do not change.
    save_path <- paste0(
      result_dir, "/", model_type, "_N", n_sample, "_", "result",
      "_user", user[[u]], "_rate_", rate, ".csv"
    )

    dat_act <- do.call(rbind, dat) %>%
      mutate(index = as.numeric(as.factor(id))) %>%
      rename(true_strategy = task_name) %>%
      mutate(
        true_strategy = factor(
          true_strategy,
          levels = strategies,
          labels = strategies
        ),
        # NOTE(review): case_when() takes the first matching branch, so with
        # these overlapping sets only 0 and 1 can ever be produced; q_type is
        # also not read anywhere below. Left unchanged - confirm intent.
        q_type = case_when(
          gt %in% c(3, 4, 5) ~ 0,
          gt %in% c(1, 2, 3, 4, 5, 6, 7) ~ 1,
          gt %in% c(1, 2, 3, 4) ~ 2,
          gt %in% c(1, 4, 5, 6, 7) ~ 3,
          gt %in% c(1, 2, 3, 6, 7) ~ 4,
          gt %in% c(2, 3, 4, 5, 6, 7) ~ 5,
          gt %in% c(1, 2, 3, 4, 5, 6, 7) ~ 6
        )
      )

    # One row per observation, taken from a single assumed strategy. The
    # original used the index left over from a previous for loop, which equals
    # length(strategies); that choice is now explicit. Rows are ordered by
    # `index` so q / true_strategy line up with P_q_S[n, , ], which is filled
    # by index below.
    dat_obs <- dat_act %>%
      filter(assumed_strategy == strategies[[length(strategies)]]) %>%
      arrange(index)
    n_obs <- nrow(dat_obs)
    print(c("N: ", n_obs))
    print(c("dim dat_act: ", dim(dat_act)))

    q <- dat_obs$gt
    true_strategy <- dat_obs$true_strategy

    n_strategies <- length(unique(dat_act$assumed_strategy))
    n_actions <- 7

    P_q_S <- array(dim = c(n_obs, n_actions, n_strategies))
    for (n in seq_len(n_obs)) {
      print(n)
      # rows = assumed strategies, columns = act<digits>; transpose to
      # actions x strategies
      P_q_S[n, , ] <- dat_act %>%
        filter(index == n) %>%
        select(matches("^act[[:digit:]]+$")) %>%
        as.matrix() %>%
        t()
      for (k in seq_len(n_strategies)) {
        # normalize probabilities
        P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
      }
    }
    print(c("dim(P_q_S)", dim(P_q_S)))

    # One fit per true strategy. The list is deliberately left unnamed so that
    # rbind() produces the same row names as the original fit_0 .. fit_6 call.
    fits <- vector("list", length(strategies))
    for (j in seq_along(strategies)) {
      # At rate "_0" no observation is passed, so the model is run with N = 0.
      sub <- if (rate == "_0") {
        integer(0)
      } else {
        which(true_strategy == strategies[[j]])
      }

      # drop = FALSE keeps the 3-d shape for 0, 1 or many selected rows.
      sdata <- list(
        N = length(sub),
        K = n_strategies,
        I = n_actions,
        q = q[sub],
        P_q_S = P_q_S[sub, , , drop = FALSE]
      )

      fits[[j]] <- mod$sample(
        data = sdata,
        refresh = 0,
        output_dir = stan_output_dir
      )
      print(fits[[j]]$summary(NULL, c("mean", "sd")))
    }

    # save csv
    df <- do.call(rbind, lapply(fits, function(fit) fit$summary()))
    write.csv(df, file = save_path, quote = FALSE)
  }
}