# InferringIntention/watch_and_help/stan/save_act_series.R
#
# Metadata carried over from the repository web view:
#   77 lines, 2.7 KiB, R
#   2024-03-24 23:42:27 +01:00
library(tidyverse)
library(cmdstanr)
library(dplyr)
# ---- Configuration ----------------------------------------------------------
# Strategy names; each one is used as a sub-directory and as part of the
# prediction file name when the per-user CSVs are loaded.
strategies <- c("put_fridge", "put_dishwasher", "read_book")
# Model identifier embedded in the prediction file names.
model_type <- "lstmlast_cross_entropy_bs_32_iter_2000_train_task_prob"
# Rate tag that selects which prediction file variant is read.
rate <- "_0"
# Supported values: "new_test_task", "test_task".
task_type <- "new_test_task"
loss_type <- "ce"
# Kept from the original script; no random draws are visible in this section.
set.seed(9746234)

# Number of users per task type. An unsupported task_type now fails
# immediately with a clear message instead of leaving `user_num` undefined.
user_num <- switch(task_type,
  test_task = 92,
  new_test_task = 9,
  stop(
    "Unsupported task_type '", task_type,
    "'; expected \"test_task\" or \"new_test_task\".",
    call. = FALSE
  )
)
# User ids are zero-based: 0, 1, ..., user_num - 1.
user <- c(0:(user_num - 1))
N <- 1

# ---- Preallocated containers ------------------------------------------------
# One slot per user. (The earlier `length(user_num)` was always 1 because
# `user_num` is a scalar, so the lists were not actually preallocated.)
total_user_act1 <- vector("list", user_num)
total_user_act2 <- vector("list", user_num)
sel <- vector("list", length(strategies))
act_series <- vector("list", user_num)
# ---- Per-user extraction of action series -----------------------------------
# For every user: load the prediction CSV of each strategy, then take the
# unique `action_id` values of task 1 and task 2 from the first strategy's
# predictions, write them to one CSV per user and task, and keep them in
# `total_user_act1` / `total_user_act2`.

# Directory receiving the per-user files. It is created up front so that
# write.csv() does not fail when the result tree does not exist yet.
act_dir <- paste0("result/", task_type, "/user", user_num, "/", loss_type, "/act")
dir.create(act_dir, recursive = TRUE, showWarnings = FALSE)

# File-name suffix selected by `rate`: "_0" reads the "_rate_90" files,
# "_100" reads the files without any rate suffix, and every other value is
# appended after "_rate".
rate_suffix <- if (rate == "_0") {
  paste0("_rate_", "90")
} else if (rate == "_100") {
  ""
} else {
  paste0("_rate", rate)
}

# Rows of `predictions` whose `task_name` equals `task_id`, as a plain
# data.frame. `.env$` makes sure `task_id` refers to the function argument and
# never to a column of the same name.
select_task_rows <- function(predictions, task_id) {
  data.frame(filter(predictions, task_name == .env$task_id))
}

for (u in seq_along(user)) {
  print("user")
  print(u)

  # Load the predictions of every strategy for this user.
  # NOTE(review): only dat[[1]] is used below; the other strategies are still
  # read so that a missing file keeps failing the run as before.
  dat <- vector("list", length(strategies))
  for (i in seq_along(strategies)) {
    dat[[i]] <- read.csv(paste0(
      "prediction/", task_type, "/user", user_num, "/", loss_type, "/",
      strategies[[i]], "/loss_weight_", model_type, "_prediction_",
      strategies[[i]], "_user", user[[u]], rate_suffix, ".csv"
    ))
    dat[[i]]$assumed_strategy <- strategies[[i]]
    dat[[i]]$index <- dat[[i]]$action_id # sample based on intention
    dat[[i]]$id <- dat[[i]][, 1] # sample based on intention
  }

  # Select all action series and infer every one.
  # NOTE(review): output files are numbered with the 1-based loop index `u`,
  # while the input files use the 0-based user id `user[[u]]` -- confirm that
  # downstream readers expect this.

  # Task 1, saved under the "_put_dishwasher" label.
  sel[[1]] <- select_task_rows(dat[[1]], 1)
  task1_action_ids <- unique(sel[[1]]$action_id)
  write.csv(
    task1_action_ids,
    paste0(act_dir, "/", "action_series_", "user_", u, "_put_dishwasher", ".csv")
  )
  total_user_act1[[u]] <- task1_action_ids

  # Task 2, saved under the "_read_book" label.
  sel[[1]] <- select_task_rows(dat[[1]], 2)
  task2_action_ids <- unique(sel[[1]]$action_id)
  write.csv(
    task2_action_ids,
    paste0(act_dir, "/", "action_series_", "user_", u, "_read_book", ".csv")
  )
  total_user_act2[[u]] <- task2_action_ids
}
# ---- Export of the combined per-user action series ---------------------------
# `total_user_act1` / `total_user_act2` hold one vector of action ids per
# user. write.csv() turns a list into a data frame with one column per user,
# which fails when the vectors differ in length (or silently recycles the
# shorter ones when the lengths happen to divide evenly).

# Pad the shorter vectors with NA so that all columns have the same length.
# A list whose elements already share one length is returned untouched, so
# the written file is unchanged in that case.
pad_action_series <- function(series) {
  sizes <- lengths(series)
  if (length(unique(sizes)) <= 1L) {
    return(series)
  }
  longest <- max(sizes)
  lapply(series, function(ids) {
    c(ids, rep(NA, longest - length(ids)))
  })
}

totals_dir <- paste0("result/", task_type, "/user", user_num, "/", loss_type, "/act")

# NOTE(review): the first file name contains a double underscore
# ("action_series__put_dishwasher_total.csv") while the second does not; the
# names are kept as they are because other scripts may read these files.
write.csv(
  pad_action_series(total_user_act1),
  paste0(totals_dir, "/", "action_series_", "_put_dishwasher_total", ".csv")
)
write.csv(
  pad_action_series(total_user_act2),
  paste0(totals_dir, "/", "action_series_", "read_book_total", ".csv")
)