# Source: InferringIntention/watch_and_help/stan/plot_user_length.py
# (133 lines, 6.5 KiB, Python; history timestamp 2024-03-24 23:42:27 +01:00)
import numpy as np
from numpy import genfromtxt
import matplotlib.pyplot as plt
import argparse
import pathlib
# Number of per-user result CSV files expected for each supported --task_type.
_USERS_PER_TASK = {'new_test_task': 9, 'test_task': 92}


def _load_user_tables(args, user, N, rate, user_indices):
    """Read one result CSV per user index and return the selected sub-tables.

    Each CSV is read with its header row skipped; rows 1-3, 5-7 and 9-11 and
    columns 2, 4, 6 and 7 are kept, giving a 9x4 array per user.
    """
    tables = []
    for i in user_indices:
        path = (f"result/{args.task_type}/user{user}/{args.loss_type}/N{N}/"
                f"{args.model_type}_N{N}_result_{rate}_user{i}.csv")
        data = genfromtxt(path, delimiter=',', skip_header=1)
        data = data[[1, 2, 3, 5, 6, 7, 9, 10, 11], :][:, [2, 4, 6, 7]]
        tables.append(data)
    return tables


def main(args):
    """Plot per-user results from the result CSVs and save the figure.

    Three vertically stacked subplots are drawn, one per entry of ``legend``
    (labelled as the "True Intention"). Each subplot shows three series, again
    one per entry of ``legend``. For every series, column 0 of the selected
    table is the plotted value and columns 2 and 3 are used as the lower and
    upper bounds of the band / error bars.
    (Presumably these are mean probability and a confidence interval -
    confirm against the code that writes the CSVs.)

    Args:
        args: parsed command-line namespace with ``task_type``
            ('test_task' or 'new_test_task'), ``loss_type``, ``model_type``,
            ``plot_type`` ('bar' or 'line') and ``plot_user_list`` (when true,
            only the users listed in ``user_table`` are loaded).

    Raises:
        ValueError: if ``args.task_type`` is not a supported task type.
    """
    # Fail early with a clear message instead of hitting an unbound `user` later.
    if args.task_type not in _USERS_PER_TASK:
        raise ValueError(
            f"unsupported task_type {args.task_type!r}; "
            f"expected one of {sorted(_USERS_PER_TASK)}"
        )
    user = _USERS_PER_TASK[args.task_type]
    N = 1
    rate = 100

    # 1-based user ids plotted when --plot_user_list is given.
    user_table = [6, 13, 15, 19, 20, 23, 27, 30, 33, 44, 46, 49, 50, 51, 52, 53, 54, 56, 65, 71, 84]

    # read data
    if args.plot_user_list:
        wanted = set(user_table)
        # File indices are 0-based while user_table holds 1-based ids.
        user_indices = [i for i in range(user) if i + 1 in wanted]
    else:
        user_indices = range(user)
    model_data_list = _load_user_tables(args, user, N, rate, user_indices)

    # Horizontal offsets that place the three series side by side at each tick.
    if args.task_type == 'new_test_task':
        widths = [-0.25, 0, 0.25]
    else:
        widths = [-0.1, 0, 0.1]

    color = ['royalblue', 'lightgreen', 'tomato']
    legend = ['put fridge', 'put\n dishwasher', 'read book']
    fig, axs = plt.subplots(3, sharex=True, sharey=True)
    fig.set_figheight(10)  # all sample rate: 10; 3 row: 8
    fig.set_figwidth(20)

    for ax in range(3):
        y_total = []
        y_low_total = []
        y_high_total = []
        for j in range(3):
            # Rows are grouped in threes: subplot `ax` uses rows 3*ax .. 3*ax+2.
            row = j + ax * 3
            y = [table[row][0] for table in model_data_list]
            y_low = [table[row][2] for table in model_data_list]
            y_high = [table[row][3] for table in model_data_list]
            y_total.append(y)
            y_low_total.append(y_low)
            y_high_total.append(y_high)
            print()
            print("user mean of mean prob: ", np.mean(y))
            print("user mean of sd prob: ", np.std(y))

        for i in range(3):
            if args.plot_type == 'line':
                axs[ax].plot(range(user), y_total[i], color=color[i], label=legend[i])
                axs[ax].fill_between(range(user), y_low_total[i], y_high_total[i], color=color[i], alpha=0.3)

            if args.plot_type == 'bar':
                # Asymmetric error bars: distance below and above the plotted value.
                yerror = [np.array(y_total[i]) - np.array(y_low_total[i]),
                          np.array(y_high_total[i]) - np.array(y_total[i])]

                if args.task_type == 'new_test_task':
                    # NOTE: this heading is redrawn on axs[0] in every (ax, i) iteration.
                    axs[0].text(-0.19, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize=36)
                    axs[ax].bar(np.arange(user) + widths[i], y_total[i], width=0.2, yerr=yerror, color=color[i], label=legend[i])
                    axs[ax].tick_params(axis='x', which='both', pad=15, length=0)
                    plt.xticks(range(user), range(1, user + 1))
                    axs[ax].set_ylabel('prob', fontsize=36)  # was 22
                    axs[ax].text(-0.19, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize=36, color=color[ax])
                    plt.xlabel('user', fontsize=40)  # was 22
                    # Write each bar's value just above the top of its error bar.
                    for k, x in enumerate(np.arange(user) + widths[i]):
                        label_y = y_total[i][k] + yerror[1][k]
                        axs[ax].annotate(f'{y_total[i][k]:.2f}', (x, label_y), textcoords='offset points', xytext=(-15, 3), fontsize=14)

                if args.task_type == 'test_task':
                    if not args.plot_user_list:
                        # All users: label every fifth tick with its 1-based user id.
                        x_positions = np.arange(user) + widths[i]
                        tick_positions = range(user)[::5]
                        tick_labels = range(1, user + 1)[::5]
                    else:
                        # Selected users only: one tick per loaded user, labelled from user_table.
                        x_positions = np.arange(len(model_data_list)) + widths[i]
                        tick_positions = range(len(model_data_list))
                        tick_labels = user_table
                    # NOTE: this heading is redrawn on axs[0] in every (ax, i) iteration.
                    axs[0].text(-0.19, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize=36)
                    axs[ax].errorbar(x_positions, y_total[i], yerr=yerror, markerfacecolor=color[i], ecolor=color[i], markeredgecolor=color[i], label=legend[i], fmt='.k')
                    axs[ax].tick_params(axis='x', which='both', pad=15, length=0)
                    plt.xticks(tick_positions, tick_labels)
                    axs[ax].set_ylabel('prob', fontsize=36)
                    axs[ax].text(-0.19, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize=36, color=color[ax])
                    plt.xlabel('user', fontsize=40)

        axs[ax].tick_params(axis='both', which='major', labelsize=30)

    plt.ylim([0, 1.08])
    plt.tight_layout()

    figure_dir = f"result/{args.task_type}/user{user}/{args.loss_type}/figure/"
    pathlib.Path(figure_dir).mkdir(parents=True, exist_ok=True)
    if args.task_type == 'test_task':
        if not args.plot_user_list:
            suffix = "_test_set_1.png"
        else:
            suffix = "_test_set_1_user_analysis.png"
    else:
        # 'new_test_task'; the original read the misspelled `args.ask_type` here.
        suffix = "_test_set_2.png"
    plt.savefig(f"{figure_dir}N{N}_{args.model_type}_rate_{rate}_{args.plot_type}{suffix}", bbox_inches='tight')

    plt.show()
if __name__ == '__main__':
    # Build the command-line interface and hand the parsed options to main().
    cli = argparse.ArgumentParser(description='')
    # String-valued options and their defaults (--plot_type is 'bar' or 'line').
    string_options = (
        ('--loss_type', 'ce'),
        ('--model_type', 'lstmlast'),
        ('--plot_type', 'bar'),
        ('--task_type', 'test_task'),
    )
    for flag, default_value in string_options:
        cli.add_argument(flag, type=str, default=default_value)
    # Switch: plot only the users listed in user_table instead of all users.
    cli.add_argument('--plot_user_list', action='store_true')
    main(cli.parse_args())