"""Plot per-user rule-inference probabilities for each activity series.

For every activity series ``act`` (1..act_series) this script reads one CSV
result file per user, then writes a figure with one panel per
"True Intention" (labelled ``rule 1`` .. ``rule 7``). Inside each panel, every
rule's per-user probability is drawn as bars with error bars (or as lines with
a shaded band). Each figure is saved under ``figure/`` as both PNG and EPS.
"""
import argparse

import matplotlib.pyplot as plt
import numpy as np
from numpy import genfromtxt

RULES = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
COLORS = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
NUM_RULES = len(RULES)

# CSV columns kept from every selected row. After selection:
#   0 -> plotted value (printed as "mean prob")
#   1 -> loaded but not used by this script
#   2 -> lower bound of the error band / error bar
#   3 -> upper bound of the error band / error bar
CSV_COLUMNS = [2, 4, 6, 7]

# The CSV body (after its header line) is read as NUM_RULES blocks of
# NUM_RULES + 1 rows each. The first row of every block is skipped.
# NOTE(review): presumably that row is a per-block header/separator - confirm
# against the code that writes these CSVs.
_ROW_INDEX = [
    block * (NUM_RULES + 1) + row
    for block in range(NUM_RULES)
    for row in range(1, NUM_RULES + 1)
]

# Horizontal offsets that place the NUM_RULES bars side by side around each user tick.
BAR_OFFSETS = [-0.36, -0.24, -0.12, 0, 0.12, 0.24, 0.36]
BAR_WIDTH = 0.08


def _parse_args():
    """Parse and validate the command-line arguments."""
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
    parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
    parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
    parser.add_argument('--user', type=int, default=1, help='number of users')
    parser.add_argument('--plot_type', type=str, default='bar', choices=('bar', 'line'),
                        help='figure style')
    parser.add_argument('--act_series', type=int, default=5,
                        help='number of activity series to plot (1..act_series)')
    args = parser.parse_args()
    if args.user < 1:
        # With no users there is nothing to plot; fail loudly instead of
        # producing empty panels and NaN statistics.
        parser.error('--user must be at least 1')
    return args


def _result_path(args, user, act):
    """Return the CSV path holding ``user``'s results for activity series ``act``."""
    return (
        f"result/N{args.N}/{args.model_type}bs_{args.batch_size}_lr_{args.lr}"
        f"_hidden_size_{args.hidden_size}_N{args.N}_result_user{user}"
        f"_rate__100_act_{act}.csv"
    )


def _figure_path(args, act, plot_type, ext):
    """Return the output path of the figure for activity series ``act``."""
    # NOTE(review): the "_ " (underscore + space) after N{N} is kept from the
    # original naming so existing references to these files stay valid; it
    # looks accidental.
    return (
        f"figure/N{args.N}_ {args.model_type}_bs_{args.batch_size}_lr_{args.lr}"
        f"_hidden_size_{args.hidden_size}_N{args.N}_act_series{act}"
        f"_{plot_type}_all_individual.{ext}"
    )


def _load_user_stats(path):
    """Load one result CSV and return its selected rows/columns.

    Returns an array of shape (NUM_RULES * NUM_RULES, len(CSV_COLUMNS)).
    Row ``panel * NUM_RULES + rule`` holds the statistics for ``rule`` under
    true intention ``panel``. Raises IndexError if the file has fewer rows
    than expected.
    """
    data = genfromtxt(path, delimiter=',', skip_header=1)
    return data[_ROW_INDEX][:, CSV_COLUMNS]


def _plot_act(args, act, stats, plot_type):
    """Draw and save the figure for one activity series.

    Args:
        args: parsed CLI arguments (used for output file names).
        act: activity-series number, embedded in the output file names.
        stats: array of shape (n_users, NUM_RULES * NUM_RULES, len(CSV_COLUMNS)),
            i.e. ``_load_user_stats`` results stacked over users.
        plot_type: ``'bar'`` or ``'line'``.

    Also prints the mean and the standard deviation across users of each
    rule's plotted value.
    """
    n_users = stats.shape[0]
    x = np.arange(n_users)

    fig, axs = plt.subplots(NUM_RULES, sharex=True, sharey=True)
    try:
        fig.set_figheight(14)
        fig.set_figwidth(25)
        # Column heading above the per-panel labels; drawn once for the figure.
        axs[0].text(-0.17, 1.2, 'True Intention:', horizontalalignment='center',
                    verticalalignment='center', transform=axs[0].transAxes, fontsize=46)

        for panel, ax in enumerate(axs):
            for rule in range(NUM_RULES):
                row = panel * NUM_RULES + rule
                mean, low, high = stats[:, row, 0], stats[:, row, 2], stats[:, row, 3]

                # The second line is the spread *across users* of the plotted
                # value (np.std), despite its "mean of sd" wording.
                print()
                print("user mean of mean prob: ", np.mean(mean))
                print("user mean of sd prob: ", np.std(mean))

                if plot_type == 'line':
                    ax.plot(x, mean, color=COLORS[rule], label=RULES[rule])
                    ax.fill_between(x, low, high, color=COLORS[rule], alpha=0.3)
                else:
                    bar_x = x + BAR_OFFSETS[rule]
                    ax.bar(bar_x, mean, width=BAR_WIDTH, yerr=[mean - low, high - mean],
                           label=RULES[rule], color=COLORS[rule])
                    # Print the value just above the top of each error bar.
                    for bx, value, top in zip(bar_x, mean, high):
                        ax.annotate(f'{value:.2f}', (bx, top), textcoords='offset points',
                                    xytext=(-18, 3), fontsize=16)

            if plot_type == 'bar':
                ax.tick_params(axis='x', which='both', length=0)
                ax.set_ylabel('prob', fontsize=36)
            ax.text(-0.17, 0.5, RULES[panel], horizontalalignment='center',
                    verticalalignment='center', transform=ax.transAxes, fontsize=46,
                    color=COLORS[panel])
            ax.tick_params(axis='both', which='major', labelsize=42)
            ax.tick_params(axis='x', which='major', pad=20)

        # The axes share x and y, so configuring the bottom panel sets ticks,
        # tick labels and limits for all of them. One label per user
        # (previously hard-coded to exactly five).
        bottom = axs[-1]
        bottom.set_xticks(x)
        bottom.set_xticklabels([str(u + 1) for u in range(n_users)])
        bottom.set_xlabel('user', fontsize=42)
        bottom.set_ylim(0, 1)
        fig.tight_layout()

        for ext in ('png', 'eps'):
            fig.savefig(_figure_path(args, act, plot_type, ext), bbox_inches='tight', format=ext)
    finally:
        # Release the figure; otherwise one figure per activity series
        # accumulates in pyplot's registry.
        plt.close(fig)


def main():
    """Load every user's results for each activity series and plot them."""
    args = _parse_args()
    for act in range(1, args.act_series + 1):
        user_stats = []
        for user in range(args.user):
            stats = _load_user_stats(_result_path(args, user, act))
            print(stats.shape)
            user_stats.append(stats)
        _plot_act(args, act, np.stack(user_stats), args.plot_type)


if __name__ == '__main__':
    main()