InferringIntention/watch_and_help/stan/strategy_inference_test.R

191 lines
7.2 KiB
R
Raw Permalink Normal View History

2024-03-24 23:42:27 +01:00
library(tidyverse)
library(cmdstanr)
library(dplyr)
# ---- Experiment configuration ----
# Strategy names; their positions in this vector define the index order that
# the rest of the script relies on.
strategies <- c("put_fridge", "put_dishwasher", "read_book")
model_type <- "lstmlast"
# Rate suffixes "_0", "_10", ..., "_100" (steps of 10).
rates <- paste0("_", seq(0, 100, by = 10))
task_type <- "test_task" # new_test_task test_task
loss_type <- "ce"
set.seed(9746234)

# Per-task settings: `user_num` is used in the data/result paths, `user` holds
# the user ids to process, and `N` is the number of action series sampled per
# task.
if (task_type == "test_task") {
  user_num <- 92
  user <- 38:(user_num - 1)
  N <- 1
} else if (task_type == "new_test_task") {
  user_num <- 9
  user <- 0:(user_num - 1)
  N <- 1
}
# ---- Strategy inference over all users and observation rates ----
# For every user and every rate this section
#   1. reads one prediction CSV per assumed strategy,
#   2. keeps only the action series that were sampled at rate "_0",
#   3. builds the (observations x I x K) probability array P_q_S, and
#   4. fits the Stan model once per true strategy and writes the summaries.

# Ground-truth (gt) sets used to label `q_type`, per task type and strategy.
# case_when() checks them in the order put_fridge, put_dishwasher, read_book,
# so for overlapping values the first match wins.
q_type_sets <- list(
  test_task = list(
    put_fridge = c(
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 19, 20, 22, 23,
      29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 42, 43, 44, 58, 59, 64, 65,
      68, 69, 70, 71, 72, 73, 74
    ),
    put_dishwasher = c(
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 25, 29, 30, 31, 32,
      33, 34, 37, 38, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57
    ),
    read_book = c(
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
      21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
      39, 40, 41, 42, 43, 44, 45
    )
  ),
  new_test_task = list(
    put_fridge = c(
      1, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 19, 20, 22, 23, 25, 29, 30,
      31, 32, 33, 34, 35, 40, 42, 43, 44, 46, 47, 52, 53, 55, 56, 58, 59, 60,
      64, 65, 68, 69, 70, 71, 72, 73, 74, 75, 77, 78
    ),
    put_dishwasher = c(
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21,
      22, 23, 24, 25, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43,
      44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
      62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74
    ),
    read_book = c(
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
      21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
      39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56,
      57, 60, 75, 76, 77, 78
    )
  )
)
# Fail early on a task type this script has no settings for.
stopifnot(task_type %in% names(q_type_sets))
gt_sets <- q_type_sets[[task_type]]

# Path of the prediction CSV for one assumed strategy, user id and rate.
# Rate "_0" reuses the "_10" file: it is only needed to pick the action series,
# because the fits at "_0" receive no observations. Rate "_100" files carry no
# rate suffix.
prediction_path <- function(strategy, user_id, rate) {
  rate_suffix <- if (rate == "_0") {
    "_rate_10"
  } else if (rate == "_100") {
    ""
  } else {
    paste0("_rate", rate)
  }
  paste0(
    "prediction/", task_type, "/user", user_num, "/", loss_type, "/",
    strategy, "/loss_weight_", model_type, "_prediction_", strategy,
    "_user", user_id, rate_suffix, ".csv"
  )
}

# Number of action series sampled per task. Kept separate from the per-fit
# observation count so it never has to be reset inside the loop.
n_series <- N
# Second dimension of P_q_S; must equal the number of act<digits> columns in
# the prediction CSVs, otherwise filling the array below fails.
I <- 79

sel <- vector("list", length(strategies))
act_series <- vector("list", user_num)

# Loop-invariant setup: output directories and the Stan model.
result_dir <- paste0(
  "result/", task_type, "/user", user_num, "/", loss_type, "/N", n_series
)
dir.create(file.path(result_dir), showWarnings = FALSE, recursive = TRUE)
dir.create(file.path("temp"), showWarnings = FALSE)
stan_output_dir <- paste0(getwd(), "/temp")
mod <- cmdstan_model(paste0(getwd(), "/strategy_inference_model.stan"))

for (u in seq_along(user)) {
  for (rate in rates) {
    # read data from csv: one data frame per assumed strategy
    dat <- lapply(strategies, function(strategy) {
      d <- read.csv(prediction_path(strategy, user[[u]], rate))
      # strategy assumed for prediction
      d$assumed_strategy <- strategy
      # placeholder; recomputed from `id` once the frames are combined
      d$index <- d$action_id
      # the first CSV column identifies the observation
      d$id <- d[, 1]
      d
    })

    # select action series per task once per user; `rates` starts with "_0",
    # so the selection is reused for all later rates of the same user
    if (rate == "_0") {
      sel[[1]] <- dat[[1]] %>%
        group_by(task_name) %>%
        sample_n(n_series)
      sel[[1]] <- data.frame(sel[[1]])
      act_series[[u]] <- sel[[1]]$action_id
    }
    print(c('unique action id', sel[[1]]$action_id))

    # keep only the rows that belong to the selected action series
    selected_ids <- sel[[1]]$action_id
    dat <- lapply(dat, function(d) {
      d <- d[d$action_id %in% selected_ids, , drop = FALSE]
      rownames(d) <- NULL
      d
    })
    for (j in seq_along(dat)) {
      print(c(paste("task name", j), dat[[j]]$task_name))
    }
    for (j in seq_along(dat)) {
      print(c(paste("action id", j), dat[[j]]$action_id))
    }

    save_path <- paste0(
      result_dir, "/", model_type, "_N", n_series, "_", "result", rate,
      "_user", user[[u]], ".csv"
    )

    # combine the per-strategy frames; `index` numbers the observations 1..n
    # in the sorted order of `id`
    dat <- do.call(rbind, dat) %>%
      mutate(index = as.numeric(as.factor(id))) %>%
      rename(true_strategy = task_name) %>%
      mutate(
        true_strategy = factor(
          true_strategy,
          levels = 0:2,
          labels = strategies
        ),
        q_type = case_when(
          gt %in% gt_sets[["put_fridge"]] ~ "put_fridge",
          gt %in% gt_sets[["put_dishwasher"]] ~ "put_dishwasher",
          gt %in% gt_sets[["read_book"]] ~ "read_book"
        )
      )

    # One row per observation, taken from the last strategy's predictions
    # (the strategy the original loop index ended on). Ordered by `index` so
    # that q[n] and true_strategy[n] line up with P_q_S[n, , ] below.
    ref_strategy <- strategies[[length(strategies)]]
    dat_obs <- dat %>%
      filter(assumed_strategy == ref_strategy) %>%
      arrange(index)
    n_obs <- nrow(dat_obs)
    print(c("N: ", n_obs))
    q <- dat_obs$gt
    true_strategy <- dat_obs$true_strategy
    K <- length(unique(dat$assumed_strategy))

    # P_q_S[n, , k]: probabilities for observation n under assumed strategy k
    P_q_S <- array(dim = c(n_obs, I, K))
    for (n in seq_len(n_obs)) {
      P_q_S[n, , ] <- dat %>%
        filter(index == n) %>%
        select(matches("^act[[:digit:]]+$")) %>%
        as.matrix() %>%
        t()
      for (k in seq_len(K)) {
        # normalize probabilities
        P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
      }
    }

    # One fit per true strategy, in the order of `strategies` (per the index
    # order above, e.g. the read_book fit should favor index 3).
    fits <- vector("list", length(strategies))
    for (s in seq_along(strategies)) {
      sub <- if (rate == "_0") {
        # rate "_0": fit without any observations
        integer(0)
      } else {
        which(true_strategy == strategies[[s]])
      }
      sdata <- list(
        N = length(sub),
        K = K,
        I = I,
        # as.array() keeps a single observation from being written as a scalar
        q = as.array(q[sub]),
        # drop = FALSE keeps the array 3-d when exactly one observation matches
        P_q_S = P_q_S[sub, , , drop = FALSE]
      )
      fits[[s]] <- mod$sample(
        data = sdata,
        refresh = 0,
        output_dir = stan_output_dir
      )
      print(fits[[s]]$summary(NULL, c("mean", "sd")))
    }

    # save csv: summaries of all fits stacked in strategy order
    df <- do.call(rbind, lapply(fits, function(fit) fit$summary()))
    write.csv(df, file = save_path, quote = FALSE)
  }
}