import numpy as np
from numpy import genfromtxt
import matplotlib.pyplot as plt
import argparse
import pathlib

# Per task: (number of users with result files, N used in the result paths).
TASK_SETTINGS = {
    'new_test_task': (9, 1),
    'test_task': (92, 1),
}
# Sample rate that appears in the result and figure file names.
RATE = 100
# 1-based ids of the users plotted when --plot_user_list is given.
USER_TABLE = [6, 13, 15, 19, 20, 23, 27, 30, 33, 44, 46, 49, 50, 51, 52, 53,
              54, 56, 65, 71, 84]
# Rows (after the CSV header) and columns kept from each per-user result file.
# Kept row ``3 * ax + j`` is drawn in subplot ``ax`` as series ``j``. Of the
# kept columns, 0 is the plotted value and 2 / 3 are the lower / upper band
# (column 1 is not used by this script).
RESULT_ROWS = [1, 2, 3, 5, 6, 7, 9, 10, 11]
RESULT_COLS = [2, 4, 6, 7]
# One colour / label per intention; used both for the subplot row titles and
# for the three series drawn inside each subplot.
COLOR = ['royalblue', 'lightgreen', 'tomato']
LEGEND = ['put fridge', 'put\n dishwasher', 'read book']
PLOT_TYPES = ('bar', 'line')


def _result_dir(args, user):
    """Return the directory prefix shared by the result CSVs and the figures."""
    return "result/" + args.task_type + "/user" + str(user) + "/" + args.loss_type


def _load_user_result(args, user, N, index):
    """Load the selected rows/columns of the result CSV for 0-based user ``index``."""
    path = (_result_dir(args, user) + "/N" + str(N) + "/" + args.model_type
            + "_N" + str(N) + "_result_" + str(RATE) + "_user" + str(index) + ".csv")
    data = genfromtxt(path, delimiter=',', skip_header=1)
    return data[RESULT_ROWS, :][:, RESULT_COLS]


def _series(model_data_list, row):
    """Return (values, lower, upper) lists across all loaded users for one kept row."""
    values = [data[row][0] for data in model_data_list]
    lower = [data[row][2] for data in model_data_list]
    upper = [data[row][3] for data in model_data_list]
    return values, lower, upper


def _label_axis(axs, ax):
    """Add the y label and the coloured intention title to subplot ``ax``."""
    axs[ax].tick_params(axis='x', which='both', pad=15, length=0)
    axs[ax].set_ylabel('prob', fontsize=36)
    axs[ax].text(-0.19, 0.5, LEGEND[ax], horizontalalignment='center',
                 verticalalignment='center', transform=axs[ax].transAxes,
                 fontsize=36, color=COLOR[ax])


def main(args):
    """Plot per-user probabilities for one task and save the figure as a PNG.

    For every user (or only the users in ``USER_TABLE`` when
    ``args.plot_user_list`` is set) this reads
    ``result/<task_type>/user<U>/<loss_type>/N<N>/<model_type>_N<N>_result_<rate>_user<i>.csv``,
    draws three stacked subplots with three series each, writes the figure
    under ``result/<task_type>/user<U>/<loss_type>/figure/`` and shows it.

    Args:
        args: namespace with ``task_type`` ('test_task' or 'new_test_task'),
            ``loss_type``, ``model_type``, ``plot_type`` ('bar' or 'line') and
            the boolean ``plot_user_list``.

    Raises:
        ValueError: if ``task_type`` or ``plot_type`` is unknown, or if no
            user is selected for plotting.
    """
    if args.task_type not in TASK_SETTINGS:
        raise ValueError(f"unknown task_type {args.task_type!r}; "
                         f"expected one of {sorted(TASK_SETTINGS)}")
    if args.plot_type not in PLOT_TYPES:
        raise ValueError(f"unknown plot_type {args.plot_type!r}; "
                         f"expected one of {list(PLOT_TYPES)}")
    user, N = TASK_SETTINGS[args.task_type]
    rate = RATE

    # 1-based ids of the users to plot, in ascending order.
    if args.plot_user_list:
        user_ids = [t for t in sorted(USER_TABLE) if 1 <= t <= user]
    else:
        user_ids = list(range(1, user + 1))
    if not user_ids:
        raise ValueError("no users selected for plotting")
    # Result files are numbered from 0 while user ids start at 1.
    model_data_list = [_load_user_result(args, user, N, uid - 1) for uid in user_ids]

    # x positions follow the number of users actually loaded, so a
    # --plot_user_list subset lines up with its data for every plot type.
    x = np.arange(len(model_data_list))
    # Horizontal offsets of the three series around each user position.
    widths = [-0.25, 0, 0.25] if args.task_type == 'new_test_task' else [-0.1, 0, 0.1]
    # With all 92 test_task users only every 5th user id is labelled.
    tick_step = 5 if (args.task_type == 'test_task' and not args.plot_user_list) else 1

    fig, axs = plt.subplots(3, sharex=True, sharey=True)
    fig.set_figheight(10)  # all sample rate: 10; 3 row: 8
    fig.set_figwidth(20)
    for ax in range(3):
        series = []
        for j in range(3):
            y, y_low, y_high = _series(model_data_list, ax * 3 + j)
            series.append((y, y_low, y_high))
            print()
            print("user mean of mean prob: ", np.mean(y))
            # NOTE: despite the label, this is the std of the plotted values across users.
            print("user mean of sd prob: ", np.std(y))
        for i, (y, y_low, y_high) in enumerate(series):
            if args.plot_type == 'line':
                axs[ax].plot(x, y, color=COLOR[i], label=LEGEND[i])
                axs[ax].fill_between(x, y_low, y_high, color=COLOR[i], alpha=0.3)
                continue
            # Asymmetric error bars: distance down to the lower and up to the upper bound.
            yerror = [np.array(y) - np.array(y_low), np.array(y_high) - np.array(y)]
            xs = x + widths[i]
            if args.task_type == 'new_test_task':
                axs[ax].bar(xs, y, width=0.2, yerr=yerror, color=COLOR[i], label=LEGEND[i])
                # Print each value just above the top of its error bar.
                for k, xpos in enumerate(xs):
                    axs[ax].annotate(f'{y[k]:.2f}', (xpos, y_high[k]),
                                     textcoords='offset points', xytext=(-15, 3), fontsize=14)
            else:
                axs[ax].errorbar(xs, y, yerr=yerror, markerfacecolor=COLOR[i],
                                 ecolor=COLOR[i], markeredgecolor=COLOR[i],
                                 label=LEGEND[i], fmt='.k')
        if args.plot_type == 'bar':
            # Labels are added once per subplot rather than once per series.
            _label_axis(axs, ax)
            plt.xticks(x[::tick_step], user_ids[::tick_step])
            plt.xlabel('user', fontsize=40)
        axs[ax].tick_params(axis='both', which='major', labelsize=30)
    if args.plot_type == 'bar':
        axs[0].text(-0.19, 0.9, 'True Intention:', horizontalalignment='center',
                    verticalalignment='center', transform=axs[0].transAxes, fontsize=36)

    plt.ylim([0, 1.08])
    plt.tight_layout()
    fig_dir = _result_dir(args, user) + "/figure/"
    pathlib.Path(fig_dir).mkdir(parents=True, exist_ok=True)
    if args.task_type == 'test_task':
        suffix = "_test_set_1_user_analysis.png" if args.plot_user_list else "_test_set_1.png"
    else:
        suffix = "_test_set_2.png"
    plt.savefig(fig_dir + "N" + str(N) + "_" + args.model_type + "_rate_" + str(rate)
                + "_" + args.plot_type + suffix, bbox_inches='tight')
    plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--loss_type', type=str, default='ce')
    parser.add_argument('--model_type', type=str, default="lstmlast")
    parser.add_argument('--plot_type', type=str, default='bar', choices=list(PLOT_TYPES))
    parser.add_argument('--task_type', type=str, default='test_task',
                        choices=sorted(TASK_SETTINGS))
    parser.add_argument('--plot_user_list', action='store_true')  # plot only USER_TABLE users
    args = parser.parse_args()
    main(args)