# InferringIntention/keyboard_and_mouse/stan/plot_user_all_individual_chiw.py
# 2024-03-24 23:42:27 +01:00
#
# 93 lines
# 4.4 KiB
# Python

import os

import matplotlib.pyplot as plt
import numpy as np
from numpy import genfromtxt
# ---- Experiment configuration: these values are spliced into the result-CSV
# ---- and output-figure file names, so they must match the files on disk.
model_type: str = "lstmlast_"  # model-name prefix used in the file names
batch_size: int = 8
lr: float = 1e-4
hidden_size: int = 128
N: int = 1
user: int = 5  # number of users; result files are indexed user0 .. user{user-1}
plot_type: str = 'bar'  # either 'line' or 'bar'
act_series: int = 5  # act series are numbered 1 .. act_series
# read data, then draw one figure per act series
if plot_type not in ('line', 'bar'):
    # Fail fast: any other value used to render empty panels and save nothing.
    raise ValueError(f"plot_type must be 'line' or 'bar', got {plot_type!r}")

n_rules = 7  # one subplot per rule, and one bar/line per rule inside each subplot
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
# Horizontal shift of each rule's bar inside a user's bar group.
bar_offsets = [-0.36, -0.24, -0.12, 0, 0.12, 0.24, 0.36]
# Rows kept from each result CSV (after its header line): within every block of
# n_rules + 1 rows the first row is skipped, leaving n_rules * n_rules rows.
result_rows = [(n_rules + 1) * j + r for j in range(n_rules) for r in range(1, n_rules + 1)]
# Columns kept. After selection, index 0 is plotted as the value and indices 2/3
# as the lower/upper error bounds; index 1 is loaded but not used in this script.
result_columns = [2, 4, 6, 7]
user_positions = np.arange(user)
user_labels = [str(u + 1) for u in range(user)]  # 1-based tick labels for any `user`
os.makedirs("figure", exist_ok=True)  # savefig does not create missing directories

for act in range(1, act_series + 1):
    # Load one (n_rules * n_rules, 4) array per user.
    user_data_list = []
    for i in range(user):
        path = (f"result/N{N}/{model_type}bs_{batch_size}_lr_{lr}_hidden_size_{hidden_size}"
                f"_N{N}_result_user{i}_rate__100_act_{act}.csv")
        data = genfromtxt(path, delimiter=',', skip_header=1)
        model_data = data[result_rows][:, result_columns]
        print(model_data.shape)
        user_data_list.append(model_data)
    results = np.stack(user_data_list)  # (user, n_rules * n_rules, 4)

    fig, axs = plt.subplots(n_rules, sharex=True, sharey=True)
    fig.set_size_inches(25, 14)
    for true_idx in range(n_rules):
        panel = axs[true_idx]
        # Kept rows true_idx*n_rules .. true_idx*n_rules + n_rules - 1 belong to this panel.
        block = results[:, true_idx * n_rules:(true_idx + 1) * n_rules, :]
        means = block[:, :, 0].T  # (n_rules, user)
        lows = block[:, :, 2].T
        highs = block[:, :, 3].T

        for y in means:
            print()
            print(legend[true_idx])
            print("user mean of mean prob: ", np.mean(y))
            # NOTE(review): this is the std across users of the mean prob, not a
            # mean of sds as the label suggests; label kept so logs stay comparable.
            print("user mean of sd prob: ", np.std(y))

        for j in range(n_rules):
            y, y_low, y_high = means[j], lows[j], highs[j]
            if plot_type == 'line':
                panel.plot(range(user), y, color=color[j], label=legend[j])
                panel.fill_between(range(user), y_low, y_high, color=color[j], alpha=0.3)
            else:
                x = user_positions + bar_offsets[j]
                panel.bar(x, y, width=0.08, yerr=[y - y_low, y_high - y],
                          label=legend[j], color=color[j])
                # Write each bar's value just above the top of its error bar.
                for xk, yk, top in zip(x, y, y_high):
                    panel.annotate(f'{yk:.2f}', (xk, top), textcoords='offset points',
                                   xytext=(-18, 3), fontsize=16)

        if plot_type == 'bar':
            panel.tick_params(axis='x', which='both', length=0)  # no tick marks under bar groups
        panel.set_ylabel('prob', fontsize=26)
        panel.set_title(legend[true_idx], color=color[true_idx], fontsize=26)
        panel.tick_params(axis='both', which='major', labelsize=18)
        for tick in panel.xaxis.get_major_ticks():
            tick.set_pad(20)

    # With sharex/sharey, the tick positions/labels and y-limits set here are
    # shared by every panel; the x label is drawn on the bottom panel only.
    bottom = axs[-1]
    bottom.set_xticks(user_positions)
    bottom.set_xticklabels(user_labels)
    bottom.set_xlabel('user', fontsize=26)
    bottom.set_ylim(0, 1.2)
    fig.tight_layout()

    # NOTE(review): "N{N}_ " contains a space; kept so output file names stay unchanged.
    out_stem = (f"figure/N{N}_ {model_type}_bs_{batch_size}_lr_{lr}_hidden_size_{hidden_size}"
                f"_N{N}_act_series{act}_{plot_type}_all_individual_chiw")
    fig.savefig(out_stem + ".png", bbox_inches='tight')
    fig.savefig(out_stem + ".eps", bbox_inches='tight', format='eps')
    plt.close(fig)  # release the figure; otherwise one figure per act series piles up