InferringIntention/watch_and_help/stan/split_user.py

88 lines
3.1 KiB
Python
Raw Permalink Normal View History

2024-03-24 23:42:27 +01:00
import numpy as np
import pathlib
import argparse
# Seed NumPy's global RNG at import time with a fixed value so runs are
# repeatable; sample_user() re-seeds with the same value on every call.
np.random.seed(seed=100)
def sample_user(data, num_users, split_inx):
    """Partition prediction rows into per-user subsets by sampling ids.

    Column 1 of ``data`` is treated as the id to sample and column -2 as a
    numeric task code (0, 1 or 2).  Three id pools are built from the sorted
    unique ids: the first ``split_inx[0]`` ids (used for task code 0), the
    first ``split_inx[1]`` ids (task code 1) and all ids (task code 2).  Each
    user draws ``len(pool) // num_users`` ids from every pool without
    replacement *within that draw*; draws for different users are
    independent, so the same id may be given to more than one user.

    Args:
        data: 2-D NumPy array of prediction rows.
        num_users: Number of per-user subsets to produce.
        split_inx: Two-element sequence with the sizes of the first two
            id pools.

    Returns:
        A list with ``num_users`` 2-D arrays.  Rows are ordered by task code
        and, within a task, by the order in which the ids were drawn.  A user
        that receives no ids gets an empty array of shape
        ``(0, data.shape[1])``.
    """
    # Re-seed on every call so each input file is split with the same draws.
    np.random.seed(seed=100)

    all_ids = np.unique(data[:, 1])
    # Index in this tuple == task code the pool is used for.
    pools = (all_ids[0:split_inx[0]], all_ids[0:split_inx[1]], all_ids)

    # Draw pool by pool (all users for pool 0, then pool 1, then pool 2):
    # this order of RNG calls determines the result for the fixed seed.
    draws = [
        [
            np.random.choice(pool, len(pool) // num_users, replace=False)
            for _ in range(num_users)
        ]
        for pool in pools
    ]

    id_column = data[:, 1]
    task_column = data[:, -2]

    user_data = []
    for user in range(num_users):
        chunks = []
        for task_code, per_user_ids in enumerate(draws):
            for sampled_id in per_user_ids[user]:
                mask = (id_column == int(sampled_id)) & (task_column == task_code)
                chunks.append(data[mask])
        if chunks:
            user_data.append(np.vstack(chunks))
        else:
            # np.vstack([]) raises ValueError; return an empty block with the
            # same columns instead so callers can still save/concatenate it.
            user_data.append(data[:0].copy())
    return user_data
def main():
    """Split each model's prediction CSV into one CSV file per sampled user.

    For every model name in ``pref`` the script reads
    ``prediction/<TASK>/<MODEL_TYPE>/model_<name>_strategy_put_fridge.csv``,
    replaces the task-name column (second to last) with a numeric code,
    partitions the rows with ``sample_user`` and writes the partitions to
    ``prediction/<TASK>/user<N>/<LOSS>/<name>/``.

    Command-line options: ``--LOSS``, ``--MODEL_TYPE``, ``--EPOCHS`` (parsed
    but not used here) and ``--TASK`` (``test_task`` or ``new_test_task``).
    """
    # Per-task (number of users, sizes of the first two id pools).
    task_settings = {
        'new_test_task': (9, [9, 45]),  # 9 for 1 user 1 action
        'test_task': (92, [92, 229]),
    }
    # Numeric codes written into the task-name column.
    task_codes = {'put_fridge': 0.0, 'put_dishwasher': 1.0, 'read_book': 2.0}

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--LOSS', type=str, default='ce')
    parser.add_argument('--MODEL_TYPE', type=str, default="lstmlast_cross_entropy_bs_32_iter_2000_train_task_prob")
    parser.add_argument('--EPOCHS', type=int, default=50)
    parser.add_argument('--TASK', type=str, default='test_task')
    args = parser.parse_args()

    # Fail with a usage message instead of an unbound-variable error later on.
    if args.TASK not in task_settings:
        parser.error(
            "unsupported --TASK %r (expected one of: %s)"
            % (args.TASK, ", ".join(sorted(task_settings)))
        )
    num_user, split_inx = task_settings[args.TASK]

    pref = ['put_fridge', 'put_dishwasher', 'read_book']

    # Column names: unnamed index, action id, 79 per-action columns, task, gt.
    head = ['', 'action_id'] + ['act' + str(j + 1) for j in range(79)] + ['task_name', 'gt']
    header_line = ','.join(head)

    task_root = pathlib.Path("prediction") / args.TASK

    for model in pref:
        path = task_root / args.MODEL_TYPE / ("model_" + model + "_strategy_put_fridge.csv")

        # Numeric load: the textual task-name column comes back as NaN here
        # and is overwritten below.
        data = np.genfromtxt(path, skip_header=1, delimiter=',')

        # Read the task names as text (str, independent of NumPy's default
        # bytes/str handling) and translate them to numeric codes.  Values
        # that are not a known task name must already be numeric, otherwise
        # float() raises ValueError.
        task_names = np.genfromtxt(
            path, skip_header=1, delimiter=',', usecols=-2, dtype=str, encoding='utf-8'
        )
        data[:, -2] = [
            task_codes[name] if name in task_codes else float(name)
            for name in task_names.tolist()
        ]
        print("data length: ", len(data))

        users_data = sample_user(data, num_user, split_inx)

        out_dir = task_root / ("user" + str(num_user)) / args.LOSS / model
        out_dir.mkdir(parents=True, exist_ok=True)

        length = 0
        for j, rows in enumerate(users_data):
            save_path = out_dir / (
                "loss_weight_" + args.MODEL_TYPE + "_prediction_" + model + "_user" + str(j) + ".csv"
            )
            length = length + len(rows)
            np.savetxt(save_path, rows, delimiter=',', header=header_line)
        # Total rows written for this model; can differ from len(data)
        # because users sample ids independently.
        print("user data length: ", length)
# Run the split only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()