65 lines
1.9 KiB
Python
65 lines
1.9 KiB
Python
|
import numpy as np
|
||
|
from numpy import genfromtxt
|
||
|
import csv
|
||
|
import pandas
|
||
|
import argparse
|
||
|
|
||
|
def sample_predciton(path, rate):
    """Keep the leading ``rate`` percent of each row group of a prediction CSV.

    The CSV at ``path`` is loaded with pandas and converted to a NumPy array.
    Rows are grouped by the pair (value in column 1, value in the
    second-to-last column). Only the values 0, 1 and 2 are considered for the
    second-to-last column, so rows holding anything else there are dropped.
    Within each group the trailing ``int(len(group) * (100 - rate) / 100)``
    rows are removed and the remaining leading rows are kept in file order.

    NOTE(review): column 1 is presumably an action id and the second-to-last
    column a task label encoded as 0/1/2 -- confirm against the CSV producer.

    Args:
        path: Path of the CSV file to read.
        rate: Percentage (expected in 0..100) of each group to keep.

    Returns:
        A 2-D array with the kept rows of all groups stacked, ordered by
        task value first and by column-1 value second.

    Raises:
        ValueError: Raised by ``np.vstack`` when there is nothing to stack,
            e.g. when the CSV contains no data rows.
    """
    data = pandas.read_csv(path).values
    task_values = (0, 1, 2)
    group_keys = np.unique(data[:, 1])

    samples = []
    for task_value in task_values:
        for key in group_keys:
            rows = np.where((data[:, 1] == key) & (data[:, -2] == task_value))
            group = data[rows]

            n_drop = int(len(group) * (100 - rate) / 100)
            # Clamp so an out-of-range rate can neither wrap the slice around
            # nor drop more rows than the group has.
            n_drop = min(max(n_drop, 0), len(group))

            # Slice by the number of rows to keep: ``group[:-n_drop]`` would
            # become ``group[:0]`` (empty) whenever ``n_drop`` is 0.
            samples.append(group[:len(group) - n_drop])

    return np.vstack(samples)
|
||
|
|
||
|
def main():
    """Write rate-truncated copies of every per-user prediction CSV.

    For each task in ('put_fridge', 'put_dishwasher', 'read_book'), each rate
    in 10, 20, ..., 90 and each user index below the user count of the
    selected ``--TASK`` set, the file

        prediction/<TASK>/user<count>/ce/<task>/
            loss_weight_<MODEL_TYPE>_prediction_<task>_user<index>.csv

    is passed through ``sample_predciton`` and the result, minus its first
    column, is written next to the input with a ``_rate_<rate>`` suffix.

    Command-line options:
        --TASK: 'test_task' (92 users) or 'new_test_task' (9 users). Any
            other value is rejected by argparse instead of failing later
            with an unbound ``user_num``.
        --MODEL_TYPE: Model tag embedded in the file names.
        --LOSS, --EPOCHS: Parsed but not used here; the 'ce' directory in
            the paths is hard-coded.
    """
    # Number of per-user prediction files available for each evaluation set.
    user_counts = {'test_task': 92, 'new_test_task': 9}

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--LOSS', type=str, default='ce')
    parser.add_argument('--MODEL_TYPE', type=str, default="lstmlast_cross_entropy_bs_32_iter_2000_train_task_prob")
    parser.add_argument('--EPOCHS', type=int, default=50)
    parser.add_argument('--TASK', type=str, default='test_task',
                        choices=sorted(user_counts))
    args = parser.parse_args()

    task = ['put_fridge', 'put_dishwasher', 'read_book']
    sets = [args.TASK]
    rate = [10, 20, 30, 40, 50, 60, 70, 80, 90]

    # Output column names; identical for every file, so built once.
    head = ['action_id']
    head.extend('act' + str(r + 1) for r in range(79))
    head.append('task_name')
    head.append('gt')

    for task_name in task:
        for keep_rate in rate:
            for set_name in sets:
                user_num = user_counts[set_name]
                for user in range(user_num):
                    # Shared stem of the input and output file names.
                    stem = (
                        "prediction/" + set_name + "/" + "user" + str(user_num)
                        + "/ce/" + task_name + "/" + "loss_weight_"
                        + args.MODEL_TYPE + "_prediction_" + task_name
                        + "_user" + str(user)
                    )
                    pred_path = stem + ".csv"
                    save_path = stem + "_rate_" + str(keep_rate) + ".csv"

                    data = sample_predciton(pred_path, keep_rate)
                    # Column 0 of the loaded data is dropped before saving;
                    # to_csv writes a fresh index column in its place.
                    pandas.DataFrame(data[:, 1:]).to_csv(save_path, header=head)
|
||
|
|
||
|
# Script entry point: run the sampling job only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
|