76 lines
2.7 KiB
R
76 lines
2.7 KiB
R
library(tidyverse)
|
|
library(cmdstanr)
|
|
library(dplyr)
|
|
|
|
# Run configuration ----

# Assumed strategies; one prediction file per strategy is read for each user.
strategies <- c("put_fridge", "put_dishwasher", "read_book")

# Model and loss identifiers, used verbatim in input/output paths.
model_type <- "lstmlast_cross_entropy_bs_32_iter_2000_train_task_prob"
loss_type <- "ce"

# Evaluation split; the script handles "new_test_task" and "test_task".
task_type <- "new_test_task"

# Rate tag selecting the prediction-file suffix ("_0", "_100", or "_<n>").
rate <- "_0"

# Fixed seed for reproducibility.
# NOTE(review): no random draws are visible in this script.
set.seed(9746234)
# User setup ----
# Number of users evaluated for the chosen split. User ids are 0-based:
# 0, 1, ..., user_num - 1. An unknown split fails immediately, instead of
# leaving user_num undefined and erroring later with "object not found".
user_num <- switch(task_type,
  test_task = 92,
  new_test_task = 9,
  stop("Unknown task_type: ", task_type, call. = FALSE)
)
user <- seq_len(user_num) - 1L

# NOTE(review): N is fixed at 1 and is not read elsewhere in the visible script.
N <- 1
# Result accumulators ----
# Entry u holds the distinct action ids found for user index u:
# task 1 in total_user_act1 and task 2 in total_user_act2.
# Sized by the number of users. length(user_num) would always be 1, because
# user_num is a scalar count.
total_user_act1 <- vector("list", length(user))
total_user_act2 <- vector("list", length(user))

# Scratch slots for task-filtered prediction rows, one per strategy.
sel <- vector("list", length(strategies))

# NOTE(review): act_series is not referenced anywhere else in the visible
# script. It is kept in case code sourcing this file expects it.
act_series <- vector("list", user_num)
# Per-user extraction ----
# For each user:
#   1. Read one prediction CSV per assumed strategy.
#   2. From the first strategy's file, record the distinct action ids of
#      task 1 and task 2.
#   3. Save those ids to per-user CSVs and to total_user_act1 / total_user_act2.

# Rate-dependent part of the prediction file name. It is loop-invariant, so it
# is computed once.
# NOTE(review): rate "_0" reads the "_rate_90" files; confirm this is intended.
rate_suffix <- if (rate == "_0") {
  "_rate_90"
} else if (rate == "_100") {
  ""
} else {
  paste0("_rate", rate)
}

act_dir <- paste0("result/", task_type, "/user", user_num, "/", loss_type, "/act", "/")

# Keep the rows of `df` whose task_name equals `task_id`, write their distinct
# action ids to `path` with write.csv(), and return those ids.
# .env$task_id avoids capture by a column that happens to be named task_id.
save_task_actions <- function(df, task_id, path) {
  task_rows <- data.frame(dplyr::filter(df, .data$task_name == .env$task_id))
  ids <- unique(task_rows$action_id)
  write.csv(ids, path)
  ids
}

for (u in seq_along(user)) {
  print("user")
  print(u)

  dat <- vector("list", length(strategies))
  for (i in seq_along(strategies)) {
    strategy <- strategies[[i]]
    pred_path <- paste0(
      "prediction/", task_type, "/user", user_num, "/", loss_type, "/",
      strategy, "/loss_weight_", model_type, "_prediction_", strategy,
      "_user", user[[u]], rate_suffix, ".csv"
    )
    dat[[i]] <- read.csv(pred_path)
    dat[[i]]$assumed_strategy <- strategy
    # index mirrors action_id; id copies the first CSV column.
    # NOTE(review): neither column is read later in this script.
    dat[[i]]$index <- dat[[i]]$action_id
    dat[[i]]$id <- dat[[i]][, 1]
  }

  # NOTE(review): result files are numbered by the 1-based loop index u,
  # while the prediction files above use the 0-based user id user[[u]].
  total_user_act1[[u]] <- save_task_actions(
    dat[[1]], 1,
    paste0(act_dir, "action_series_user_", u, "_put_dishwasher.csv")
  )
  total_user_act2[[u]] <- save_task_actions(
    dat[[1]], 2,
    paste0(act_dir, "action_series_user_", u, "_read_book.csv")
  )
}
# Combined outputs ----
# Stack a per-user list of action-id vectors into one long table with a row
# per (user_index, action_id) pair. user_index is the list position, which
# matches the 1-based u used in the per-user result file names.
# Passing the raw list to write.csv() instead would coerce it to one column
# per user. That errors whenever users have different numbers of distinct
# action ids, and otherwise yields deparsed, unreadable column names.
stack_user_actions <- function(act_list) {
  ids <- unlist(act_list, use.names = FALSE)
  if (is.null(ids)) {
    ids <- integer(0)
  }
  data.frame(
    user_index = rep(seq_along(act_list), lengths(act_list)),
    action_id = ids
  )
}

act_dir <- paste0("result/", task_type, "/user", user_num, "/", loss_type, "/act", "/")

# NOTE(review): the dishwasher file name contains a double underscore
# ("action_series__put_dishwasher_total.csv"). It is kept as-is so existing
# readers still find it.
write.csv(
  stack_user_actions(total_user_act1),
  paste0(act_dir, "action_series_", "_put_dishwasher_total", ".csv")
)
write.csv(
  stack_user_actions(total_user_act2),
  paste0(act_dir, "action_series_", "read_book_total", ".csv")
)