190 lines
7.2 KiB
R
190 lines
7.2 KiB
R
library(tidyverse)
|
|
library(cmdstanr)
|
|
library(dplyr)
|
|
|
|
# ---- Configuration ----

# Index order of the strategies assumed throughout the script.
strategies <- c("put_fridge", "put_dishwasher", "read_book")

# Model identifier; it is embedded in input and output file names.
model_type <- "lstmlast"

# Rate suffixes iterated over; each one selects a prediction file.
rates <- c(
  "_0", "_10", "_20", "_30", "_40", "_50",
  "_60", "_70", "_80", "_90", "_100"
)

# Which task set to evaluate: "test_task" or "new_test_task".
task_type <- "test_task"

# Loss identifier; it is used as a directory name in the data paths.
loss_type <- "ce"

# Fix the R RNG state so that the action-series sampling is reproducible.
set.seed(9746234)

# Per-task settings:
#   user_num - number embedded in the "user<user_num>" directory name
#   user     - user ids to iterate over
#   N        - number of action series sampled per intention
if (task_type == "test_task") {
  user_num <- 92
  user <- 38:(user_num - 1)
  N <- 1
} else if (task_type == "new_test_task") {
  user_num <- 9
  user <- 0:(user_num - 1)
  N <- 1
} else {
  # Fail here with a clear message instead of a later "object not found".
  stop(
    "Unsupported task_type: '", task_type,
    "' (expected 'test_task' or 'new_test_task')",
    call. = FALSE
  )
}
|
|
|
|
# ---- Strategy inference over all users and rates ----
#
# For every user and every rate this section
#   1. reads one prediction CSV per assumed strategy,
#   2. keeps only the action series sampled for that user (drawn once, at the
#      first rate "_0"),
#   3. builds the N x I x K probability array P_q_S,
#   4. fits the Stan model separately for each true strategy, and
#   5. writes the three stacked fit summaries to a CSV file.
# It relies on the script-level settings `strategies`, `rates`, `task_type`,
# `loss_type`, `model_type`, `user` and `user_num`.

# `sel[[1]]` holds the rows sampled for the current user; `act_series[[u]]`
# records the sampled action ids per user.
sel <- vector("list", length(strategies))
act_series <- vector("list", user_num)

# Number of action series sampled per intention. It is kept separate from `N`
# (the number of observations passed to Stan), which is recomputed in every
# iteration.
n_series <- 1

# Size of the second dimension of P_q_S. Every prediction file must provide
# exactly this many columns named act<digits>.
I <- 79

# `gt` values used to label `q_type`, per task set. The conditions are checked
# in this order, so the first matching strategy wins.
gt_sets <- switch(
  task_type,
  test_task = list(
    put_fridge = c(
      1:16, 19, 20, 22, 23, 29:35, 37:40, 42:44, 58, 59, 64, 65, 68:74
    ),
    put_dishwasher = c(1:15, 25, 29:34, 37, 38, 46:57),
    read_book = 1:45
  ),
  new_test_task = list(
    put_fridge = c(
      1, 5:10, 12:16, 19, 20, 22, 23, 25, 29:35, 40, 42:44, 46, 47, 52, 53,
      55, 56, 58:60, 64, 65, 68:75, 77, 78
    ),
    put_dishwasher = c(1:16, 18:25, 29:35, 37:74),
    read_book = c(1:57, 60, 75:78)
  ),
  stop("Unsupported task_type: '", task_type, "'", call. = FALSE)
)

# Build the path of the prediction CSV for one strategy, user id and rate.
# Rate "_0" reads the rate-10 file (no observations are conditioned on for
# that rate further below); rate "_100" reads the file without a rate suffix.
prediction_path <- function(strategy, user_id, rate) {
  rate_suffix <- switch(
    rate,
    "_0" = "_rate_10",
    "_100" = "",
    paste0("_rate", rate)
  )
  paste0(
    "prediction/", task_type, "/user", user_num, "/", loss_type, "/",
    strategy, "/loss_weight_", model_type, "_prediction_", strategy,
    "_user", user_id, rate_suffix, ".csv"
  )
}

# Fit the Stan model on the observations whose true strategy equals `target`,
# print the mean/sd summary and return the fit. For rate "_0" no observation
# is used.
fit_strategy_subset <- function(mod, target, rate, true_strategy, q, P_q_S,
                                K, I) {
  if (rate == "_0") {
    sub <- integer(0)
  } else {
    sub <- which(true_strategy == target)
  }
  sdata <- list(
    N = length(sub),
    K = K,
    I = I,
    q = q[sub],
    # drop = FALSE keeps the array three-dimensional even when exactly one
    # observation is selected.
    P_q_S = P_q_S[sub, , , drop = FALSE]
  )
  fit <- mod$sample(
    data = sdata,
    refresh = 0,
    output_dir = paste0(getwd(), "/temp")
  )
  print(fit$summary(NULL, c("mean", "sd")))
  fit
}

# Loop-invariant setup: output directories and the Stan model.
result_dir <- paste0(
  "result/", task_type, "/user", user_num, "/", loss_type, "/N", n_series
)
dir.create(file.path(result_dir), showWarnings = FALSE, recursive = TRUE)
dir.create(file.path("temp"), showWarnings = FALSE)
mod <- cmdstan_model(paste0(getwd(), "/strategy_inference_model.stan"))

for (u in seq_along(user)) {
  for (rate in rates) {
    # read data from csv: one data frame per assumed strategy
    dat <- lapply(strategies, function(strategy) {
      d <- read.csv(prediction_path(strategy, user[[u]], rate))
      # strategy assumed for prediction
      d$assumed_strategy <- strategy
      d$index <- d$action_id
      d$id <- d[, 1]
      d
    })

    # select action series once per user (at the first rate), sampling
    # n_series rows per task from the first strategy's predictions
    if (rate == "_0") {
      sel[[1]] <- dat[[1]] %>%
        group_by(task_name) %>%
        sample_n(n_series) %>%
        data.frame()
      act_series[[u]] <- sel[[1]]$action_id
    }
    selected_ids <- sel[[1]]$action_id
    print(c("unique action id", selected_ids))

    # filter data from the selected action series
    dat <- lapply(dat, function(d) {
      d[d$action_id %in% selected_ids, , drop = FALSE]
    })
    for (s in seq_along(dat)) {
      print(c(paste("task name", s), dat[[s]]$task_name))
    }
    for (s in seq_along(dat)) {
      print(c(paste("action id", s), dat[[s]]$action_id))
    }

    save_path <- paste0(
      result_dir, "/", model_type, "_N", n_series, "_", "result", rate,
      "_user", user[[u]], ".csv"
    )

    # Stack the per-strategy frames. `index` becomes a dense 1..n id of the
    # observation; `task_name` codes 0..2 are mapped onto `strategies`.
    dat <- do.call(rbind, dat) %>%
      mutate(index = as.numeric(as.factor(id))) %>%
      rename(true_strategy = task_name) %>%
      mutate(
        true_strategy = factor(
          true_strategy,
          levels = 0:2,
          labels = strategies
        ),
        q_type = case_when(
          gt %in% gt_sets$put_fridge ~ "put_fridge",
          gt %in% gt_sets$put_dishwasher ~ "put_dishwasher",
          gt %in% gt_sets$read_book ~ "read_book"
        )
      )

    # One row per observation: take the rows of a single assumed strategy
    # (the last one) as the reference set.
    reference_strategy <- strategies[[length(strategies)]]
    dat_obs <- dat %>% filter(assumed_strategy == reference_strategy)
    N <- nrow(dat_obs)
    print(c("N: ", N))
    q <- dat_obs$gt
    true_strategy <- dat_obs$true_strategy

    K <- length(unique(dat$assumed_strategy))

    # P_q_S[n, , k]: normalized act<digits> values of observation n under
    # assumed strategy k.
    P_q_S <- array(dim = c(N, I, K))
    for (n in seq_len(N)) {
      P_q_S[n, , ] <- dat %>%
        filter(index == n) %>%
        select(matches("^act[[:digit:]]+$")) %>%
        as.matrix() %>%
        t()
      for (k in seq_len(K)) {
        # normalize probabilities
        P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
      }
    }

    # One fit per true strategy, always in this order.
    fit_put_fridge <- fit_strategy_subset(
      mod, "put_fridge", rate, true_strategy, q, P_q_S, K, I
    )
    fit_put_dishwasher <- fit_strategy_subset(
      mod, "put_dishwasher", rate, true_strategy, q, P_q_S, K, I
    )
    # read_book strategy (should favor index 3)
    fit_read_book <- fit_strategy_subset(
      mod, "read_book", rate, true_strategy, q, P_q_S, K, I
    )

    # save csv
    df <- rbind(
      fit_put_fridge$summary(),
      fit_put_dishwasher$summary(),
      fit_read_book$summary()
    )
    write.csv(df, file = save_path, quote = FALSE)
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|