88 lines
3.3 KiB
Python
88 lines
3.3 KiB
Python
import numpy as np
|
|
from numpy import genfromtxt
|
|
import matplotlib.pyplot as plt
|
|
import argparse
|
|
import pathlib
|
|
|
|
# Per-task settings: how many users' result files are averaged, the N value
# that appears in the result paths, and the suffix of the saved figure name.
_TASK_CONFIG = {
    'test_task': {'user': 92, 'N': 1, 'figure_suffix': '1'},
    'new_test_task': {'user': 9, 'N': 1, 'figure_suffix': '2'},
}

# Observation rates used as x-axis values; one result file exists per rate
# and per user.
_RATES = range(0, 101, 10)

# Rows (counted after the skipped header line) and columns kept from every
# result CSV.  In the reduced 9x4 table, row ``true * 3 + pred`` is plotted:
# column 0 as the curve, columns 2 and 3 as the lower/upper band.
_ROWS = [1, 2, 3, 5, 6, 7, 9, 10, 11]
_COLS = [2, 4, 6, 7]

_COLORS = ['royalblue', 'lightgreen', 'tomato']
_LEGEND = ['put fridge', 'put\n dishwasher', 'read book']


def _parse_args(argv=None):
    """Parse the command line (``sys.argv[1:]`` when *argv* is None)."""
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--loss_type', type=str, default='ce')
    parser.add_argument('--model_type', type=str, default="lstmlast")
    # Restricting the choices turns an unknown task into a clear usage error
    # instead of an UnboundLocalError once the per-task settings are needed.
    parser.add_argument('--task_type', type=str, default='test_task',
                        choices=list(_TASK_CONFIG))
    return parser.parse_args(argv)


def _load_rate_data(result_dir, model_type, n, user):
    """Average the reduced result tables over users, for every rate.

    Args:
        result_dir: ``pathlib.Path`` of ``result/<task>/user<U>/<loss>``.
        model_type: model name used as the result-file prefix.
        n: the N setting embedded in the directory and file names.
        user: number of users; files ``..._user0`` .. ``..._user{user-1}``
            are read for each rate.

    Returns:
        Array of shape ``(len(_RATES), len(_ROWS), len(_COLS))`` whose entry
        ``[k]`` is the element-wise mean over users of the reduced table read
        for rate ``_RATES[k]``.
    """
    per_rate = []
    for r in _RATES:
        print(r)  # progress: rate currently being read
        tables = []
        for i in range(user):
            path = (result_dir / f"N{n}"
                    / f"{model_type}_N{n}_result_{r}_user{i}.csv")
            data = genfromtxt(path, delimiter=',', skip_header=1)
            tables.append(data[_ROWS, :][:, _COLS])
        per_rate.append(np.mean(np.stack(tables), axis=0))
    return np.stack(per_rate)


def _plot_rates(rate_data):
    """Draw one subplot per true intention and return the figure.

    Each subplot holds one curve (with a shaded low/high band) per predicted
    intention, plotted against the observation rate.
    """
    rates = list(_RATES)
    fig, axs = plt.subplots(3, sharex=True, sharey=True)
    fig.set_figheight(10)  # all sample rate: 10; 3 row: 8
    fig.set_figwidth(20)
    axs[0].text(-0.145, 0.9, 'True Intention:', horizontalalignment='center',
                verticalalignment='center', transform=axs[0].transAxes,
                fontsize=25)  # all: -0.3,0.5 3rows: -0.5,0.5

    for true_idx, subplot in enumerate(axs):
        for pred_idx in range(3):
            series = rate_data[:, true_idx * 3 + pred_idx, :]
            y = series[:, 0]
            print()
            print("user mean of mean prob: ", np.mean(y))
            # NOTE(review): labelled "sd", but this is the spread of the mean
            # curve across rates; column 1 of the reduced table is loaded yet
            # never used -- confirm which statistic was intended.
            print("user mean of sd prob: ", np.std(y))
            subplot.plot(rates, y, color=_COLORS[pred_idx],
                         label=_LEGEND[pred_idx])
            subplot.fill_between(rates, series[:, 2], series[:, 3],
                                 color=_COLORS[pred_idx], alpha=0.3)
        # Per-subplot settings only need to be applied once, not per curve.
        subplot.set_xticks(rates)
        subplot.set_ylabel('probability', fontsize=22)
        subplot.text(-0.145, 0.5, _LEGEND[true_idx],
                     horizontalalignment='center', verticalalignment='center',
                     transform=subplot.transAxes, fontsize=25,
                     color=_COLORS[true_idx])
        subplot.tick_params(axis='both', which='major', labelsize=18)

    # Axes are shared (sharex/sharey), so limits set on the bottom subplot
    # apply to all three; the x label belongs under the bottom subplot.
    bottom = axs[-1]
    bottom.set_xlabel('Percentage of observed actions in one action sequence',
                      fontsize=22)
    bottom.set_xlim([0, 101])
    bottom.set_ylim([0, 1])
    return fig


def main(argv=None):
    """Plot intention probabilities against the observed-action rate.

    Reads per-user result CSVs under
    ``result/<task_type>/user<U>/<loss_type>/N<N>/`` for every rate in
    ``_RATES``, averages them over users, plots one subplot per true
    intention, saves the figure under ``.../<loss_type>/figure/`` and shows it.

    Args:
        argv: optional argument list to parse instead of ``sys.argv[1:]``.

    Raises:
        SystemExit: on invalid command-line arguments (e.g. unknown task).
    """
    args = _parse_args(argv)
    config = _TASK_CONFIG[args.task_type]
    user, n = config['user'], config['N']
    result_dir = pathlib.Path('result', args.task_type, f"user{user}",
                              args.loss_type)

    rate_data = _load_rate_data(result_dir, args.model_type, n, user)
    fig = _plot_rates(rate_data)

    figure_dir = result_dir / 'figure'
    figure_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(figure_dir / (f"N{n}_{args.model_type}_rate_full_test_set_"
                              f"{config['figure_suffix']}.png"),
                bbox_inches='tight')

    plt.show()
|
|
|
|
if __name__ == '__main__':
    # Run the plotting script only when executed directly, not on import.
    main()
|
|
|