# Source listing: 88 lines, 3.1 KiB, Python.
import numpy as np
|
||
|
import pathlib
|
||
|
import argparse
|
||
|
|
||
|
# Seed NumPy's global RNG once at import time so that any draws made through
# np.random in this module are reproducible from run to run.
np.random.seed(100)
def sample_user(data, num_users, split_inx, seed=100):
    """Sample per-user subsets of ``data`` for task codes 0, 1 and 2.

    Ids are read from ``data[:, 1]`` and task codes from ``data[:, -2]``.
    For each task a pool of candidate ids is built from the sorted unique ids:

    * task 0: the first ``split_inx[0]`` ids,
    * task 1: the first ``split_inx[1]`` ids,
    * task 2: all ids.

    Each user independently draws ``len(pool) // num_users`` distinct ids from
    every pool. Draws are independent across users, so two users may share an
    id. The user then receives every row whose id and task code match one of
    its draws.

    Args:
        data: 2-D numeric array. Column 1 holds ids and column -2 holds task
            codes. Presumably ids are integer-valued, since each drawn id is
            truncated with ``int`` before matching. TODO confirm.
        num_users: number of users to generate.
        split_inx: two slice bounds that limit the id pools for tasks 0 and 1.
        seed: value passed to ``np.random.seed`` before sampling, which makes
            the result reproducible. This reseeds NumPy's *global* RNG, kept
            for backward compatibility. Pass ``None`` to skip reseeding and
            use the current global state.

    Returns:
        A list of ``num_users`` 2-D arrays with the same columns as ``data``.
        Rows are grouped by task (0, then 1, then 2). Within a task they follow
        the order in which ids were drawn. A user who drew no ids at all gets
        an empty ``(0, n_columns)`` array rather than an error.
    """
    if seed is not None:
        np.random.seed(seed=seed)

    all_ids = np.unique(data[:, 1])
    # (task code, pool of ids that may be drawn for that task)
    pools = (
        (0, all_ids[0:split_inx[0]]),
        (1, all_ids[0:split_inx[1]]),
        (2, all_ids),
    )

    # Draw every user's ids for task 0, then task 1, then task 2. This keeps the
    # RNG call order of the original per-task list comprehensions, so seeded
    # results are unchanged.
    draws = [
        [np.random.choice(pool, len(pool) // num_users, replace=False)
         for _ in range(num_users)]
        for _, pool in pools
    ]

    ids_col = data[:, 1]
    # The task masks do not depend on the user, so compute them once.
    task_masks = [data[:, -2] == task_code for task_code, _ in pools]

    user_data = []
    for user in range(num_users):
        rows = []
        for task_mask, per_user in zip(task_masks, draws):
            for item in per_user[user]:
                rows.append(data[(ids_col == int(item)) & task_mask])
        if rows:
            user_data.append(np.vstack(rows))
        else:
            # np.vstack([]) raises, e.g. when num_users exceeds every pool size.
            user_data.append(np.empty((0, data.shape[1]), dtype=data.dtype))

    return user_data
def _encode_task_names(names):
    """Map task-name strings to the numeric codes that ``sample_user`` filters on.

    ``names`` is the task-name column read by ``np.genfromtxt(..., dtype=None)``.
    Depending on the NumPy version and encoding, its elements are ``bytes`` or
    ``str``, so both are accepted. Names other than the three known tasks are
    parsed with ``float``, mirroring the earlier ``astype`` conversion. That
    means numeric codes pass through and unknown words raise ``ValueError``.
    """
    codes = {'put_fridge': 0.0, 'put_dishwasher': 1.0, 'read_book': 2.0}
    encoded = []
    # atleast_1d: a one-row file makes genfromtxt return a 0-d array.
    for name in np.atleast_1d(names):
        if isinstance(name, bytes):
            name = name.decode('latin-1')
        name = str(name)
        encoded.append(codes[name] if name in codes else float(name))
    return np.array(encoded, dtype=float)


def main():
    """Split each preference model's predictions into per-user CSV files.

    For every preference ``i`` in ``pref``, this function:

    * reads ``prediction/<TASK>/<MODEL_TYPE>/model_<i>_strategy_put_fridge.csv``;
    * replaces the task-name column (second to last) with numeric codes;
    * samples per-user subsets with ``sample_user``;
    * writes one CSV per user to
      ``prediction/<TASK>/user<NUM_USER>/<LOSS>/<i>/loss_weight_<MODEL_TYPE>_prediction_<i>_user<j>.csv``.

    Command-line options:
        --LOSS: output sub-directory name (default ``ce``).
        --MODEL_TYPE: model directory and file-name component.
        --EPOCHS: accepted for compatibility; not used by this function.
        --TASK: ``test_task`` (92 users) or ``new_test_task`` (9 users).
    """
    # TASK -> (number of users, id-pool bound for task 1). The task-0 pool
    # bound equals the number of users.
    task_settings = {
        'new_test_task': (9, 45),  # 9 for 1 user 1 action
        'test_task': (92, 229),
    }

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--LOSS', type=str, default='ce')
    parser.add_argument('--MODEL_TYPE', type=str,
                        default="lstmlast_cross_entropy_bs_32_iter_2000_train_task_prob")
    parser.add_argument('--EPOCHS', type=int, default=50)
    # Restrict --TASK to the known settings. Any other value would leave the
    # sampling parameters undefined and fail much later.
    parser.add_argument('--TASK', type=str, default='test_task',
                        choices=sorted(task_settings))
    args = parser.parse_args()

    pref = ['put_fridge', 'put_dishwasher', 'read_book']

    num_users, task1_bound = task_settings[args.TASK]
    split_inx = [num_users, task1_bound]

    # CSV header: unnamed index, action_id, act1..act79, task_name, gt.
    head = (['', 'action_id']
            + ['act' + str(j + 1) for j in range(79)]
            + ['task_name', 'gt'])
    header = ','.join(head)

    for i in pref:
        path = ("prediction/" + args.TASK + "/" + args.MODEL_TYPE
                + "/model_" + i + "_strategy_put_fridge" + ".csv")
        # atleast_2d keeps a single-row file two-dimensional for data[:, -2].
        data = np.atleast_2d(np.genfromtxt(path, skip_header=1, delimiter=','))
        # The float parse leaves the task-name column as NaN; re-read it as text.
        task_names = np.genfromtxt(path, skip_header=1, delimiter=',',
                                   usecols=-2, dtype=None)
        # Replaces astype(np.float): that alias was removed in NumPy 1.24.
        data[:, -2] = _encode_task_names(task_names)
        print("data length: ", len(data))

        users_data = sample_user(data, num_users, split_inx)

        out_dir = ("prediction/" + args.TASK + "/user" + str(num_users)
                   + "/" + args.LOSS + "/" + i)
        pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True)

        length = 0
        for j, user_rows in enumerate(users_data):
            save_path = (out_dir + "/loss_weight_" + args.MODEL_TYPE
                         + "_prediction_" + i + "_user" + str(j) + ".csv")
            length += len(user_rows)
            np.savetxt(save_path, user_rows, delimiter=',', header=header)
        print("user data length: ", length)
# Run the per-user split only when executed as a script, not on import.
if __name__ == "__main__":
    main()