InferringIntention/keyboard_and_mouse/sampler_single_act.py

57 lines
2 KiB
Python
Raw Permalink Normal View History

2024-03-24 23:42:27 +01:00
import numpy as np
from numpy import genfromtxt
import csv
import pandas
from pathlib import Path
import argparse
def sample_single_act(pred_path, save_path, j, act_ids=range(1, 6), num_acts=7):
    """Split one prediction CSV into one CSV file per action id.

    The file at ``pred_path`` is read with ``pandas.read_csv`` and turned into a
    plain array. Column 1 of that array is used as the action id to group by.
    Column 0 is dropped on output; presumably it is the unnamed index column
    written by an earlier ``DataFrame.to_csv`` call (TODO confirm against the
    producer of these files).

    For every ``u`` in ``act_ids``, the rows whose action id equals ``u`` are
    written to ``<save_path>/rate_<j>_act_<u>_pred.csv``. The header is
    ``action_id, act1 .. act<num_acts>, task_name, gt``, preceded by a fresh
    0-based, unnamed index column. An action id with no matching rows still
    produces a header-only file.

    Args:
        pred_path: Path of the prediction CSV to split.
        save_path: Output directory; it is created if it does not exist.
        j: Rate label embedded in the output file names.
        act_ids: Action ids to extract (default: 1..5).
        num_acts: Number of ``act*`` columns between ``action_id`` and
            ``task_name``. Together with the three fixed names, it must match
            the number of input columns after the dropped first one.

    Returns:
        None. Results are only written to disk.

    Raises:
        ValueError: from pandas, if the header length does not match the
            number of data columns.
    """
    data = pandas.read_csv(pred_path).values
    # The header does not depend on the action id, so build it once.
    head = (['action_id']
            + ['act' + str(r + 1) for r in range(num_acts)]
            + ['task_name', 'gt'])
    out_dir = Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    for u in act_ids:
        act_data = data[data[:, 1] == u]
        final_save_path = out_dir / f"rate_{j}_act_{int(u)}_pred.csv"
        # Drop column 0 (the old index); to_csv writes a new 0-based index.
        pandas.DataFrame(act_data[:, 1:]).to_csv(final_save_path, header=head)
def main(argv=None):
    """Split every per-user prediction file into per-action CSV files.

    Command-line options select which trained model's predictions are read:
    ``--batch_size`` (default 8), ``--lr`` (default 1e-4), ``--hidden_size``
    (default 64) and ``--model_type`` (default 'lstmlast'). They are only used
    to build the model directory name
    ``<model_type>_bs_<batch_size>_lr_<lr>_hidden_size_<hidden_size>``.

    The loop covers tasks 0..6, rates 10..100 in steps of 10, and users 0..4.
    For each combination it reads

        prediction/task<i>/<model dir>/user<l>_rate_<j>_pred.csv
        (or prediction/task<i>/<model dir>/user<l>_pred.csv when j == 100)

    It then passes that file to ``sample_single_act``, which writes into

        prediction/single_act/task<i>/<model dir>/user<l>/

    Args:
        argv: Optional argument list for ``argparse``. ``None`` (the default)
            parses ``sys.argv[1:]``, as the script always did.

    Raises:
        Whatever ``sample_single_act`` raises. No missing prediction file is
        skipped.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--hidden_size', type=int, default=64, help='hidden_size')
    parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
    args = parser.parse_args(argv)

    num_tasks = 7
    user_num = 5
    rates = range(10, 101, 10)  # 10, 20, ..., 100

    # Model directory name shared by the input and output trees.
    model_dir = (f"{args.model_type}_bs_{args.batch_size}"
                 f"_lr_{args.lr}_hidden_size_{args.hidden_size}")

    for i in range(num_tasks):
        pred_dir = f"prediction/task{i}/{model_dir}"
        out_root = f"prediction/single_act/task{i}/{model_dir}"
        for j in rates:
            for l in range(user_num):
                if j == 100:
                    # The full-sequence prediction file has no rate suffix.
                    pred_path = f"{pred_dir}/user{l}_pred.csv"
                else:
                    pred_path = f"{pred_dir}/user{l}_rate_{j}_pred.csv"
                save_path = f"{out_root}/user{l}"
                Path(save_path).mkdir(parents=True, exist_ok=True)
                sample_single_act(pred_path, save_path, j)
if __name__ == '__main__':
    # Split each prediction file by action id, for every observation rate from
    # 10% to 100% (100% reads the full-sequence prediction file).
    main()