89 lines
3.6 KiB
Python
89 lines
3.6 KiB
Python
|
import numpy as np
|
||
|
from numpy import genfromtxt
|
||
|
import matplotlib.pyplot as plt
|
||
|
import argparse
|
||
|
|
||
|
# Sample rates plotted on the x axis; one result CSV exists per rate (``_rate__<r>``).
_RATES = list(range(0, 101, 10))
# Number of rules / intentions; the figure has one subplot row per intention.
_N_RULES = 7
_COLORS = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
_LEGEND = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
# Columns of each CSV row that are kept. After selection, index 0 is plotted as the
# curve, index 2 / 3 as the lower / upper edge of the shaded band (index 1 is unused).
_KEEP_COLS = [2, 4, 6, 7]


def _parse_args():
    """Parse the command-line options that select which result files to plot."""
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
    parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
    parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
    parser.add_argument('--user', type=int, default=1, help='number of users')
    args = parser.parse_args()
    if args.user < 1:
        # Averaging over users needs at least one user's file (np.stack rejects []).
        parser.error('--user must be at least 1')
    return args


def _run_tag(args):
    """Return the hyper-parameter suffix shared by the input CSV and output PNG names."""
    return (f'bs_{args.batch_size}_lr_{args.lr}'
            f'_hidden_size_{args.hidden_size}_N{args.N}')


def _load_user_rows(path):
    """Load one result CSV and return the kept columns of every rule row, in file order.

    Returns an array of shape (_N_RULES * _N_RULES, len(_KEEP_COLS)).
    """
    data = genfromtxt(path, delimiter=',', skip_header=1)
    block = _N_RULES + 1
    # Rows come in blocks of _N_RULES + 1; the first row of each block is skipped.
    # NOTE(review): presumably that row is a per-intention label row — confirm
    # against the script that writes these CSVs.
    rows = [b * block + k for b in range(_N_RULES) for k in range(1, block)]
    return data[rows][:, _KEEP_COLS]


def _load_rate_means(args):
    """Average each rate's result rows over all users.

    Returns an array of shape (len(_RATES), _N_RULES * _N_RULES, len(_KEEP_COLS)).
    """
    tag = _run_tag(args)
    per_rate = []
    for rate in _RATES:
        users = np.stack([
            _load_user_rows(
                f'result/N{args.N}/{args.model_type}{tag}'
                f'_result_user{user}_rate__{rate}.csv')
            for user in range(args.user)
        ])
        print(users.shape)
        mean_user_data = np.mean(users, axis=0)
        print(mean_user_data.shape)
        per_rate.append(mean_user_data)
    return np.stack(per_rate)


def _plot_rates(means):
    """Draw one subplot per true intention, with one curve + band per rule across rates."""
    fig, axs = plt.subplots(_N_RULES, sharex=True, sharey=True)
    fig.set_figheight(10)  # all sample rate: 10; 3 row: 8
    fig.set_figwidth(20)

    # Column header above the per-row intention labels; drawn once.
    axs[0].text(-0.125, 0.9, 'True Intention:', horizontalalignment='center',
                verticalalignment='center', transform=axs[0].transAxes, fontsize=20)

    for row, ax in enumerate(axs):
        for rule in range(_N_RULES):
            # Row `row * _N_RULES + rule` of each rate's mean: intention `row`, rule `rule`.
            curve = means[:, row * _N_RULES + rule]
            y, y_low, y_high = curve[:, 0], curve[:, 2], curve[:, 3]

            print()
            print("user mean of mean prob: ", np.mean(y))
            print("user mean of sd prob: ", np.std(y))

            ax.plot(_RATES, y, color=_COLORS[rule], label=_LEGEND[rule])
            ax.fill_between(_RATES, y_low, y_high, color=_COLORS[rule], alpha=0.3)

        ax.set_xticks(_RATES)
        ax.set_ylabel('prob', fontsize=20)
        ax.text(-0.125, 0.5, _LEGEND[row], horizontalalignment='center',
                verticalalignment='center', transform=ax.transAxes, fontsize=20,
                color=_COLORS[row])
        ax.tick_params(axis='both', which='major', labelsize=16)

    # Axes share x/y, so limits set on the bottom row apply to every row.
    bottom = axs[-1]
    bottom.set_xlabel('Percentage of occurred actions in one action sequence', fontsize=20)
    bottom.set_xlim([0, 101])
    bottom.set_ylim([0, 1])
    return fig


def main():
    """Plot per-rule probability curves against sample rate and save the figure.

    Reads ``result/N<N>/<model_type>bs_..._result_user<u>_rate__<r>.csv`` for every
    user and rate, averages over users, and writes ``figure/N<N>_ <model_type>_bs_..._rate_full.png``.
    """
    import os

    args = _parse_args()
    means = _load_rate_means(args)
    fig = _plot_rates(means)

    # NOTE(review): the "_ " (underscore + space) looks like a typo, but it is kept so
    # existing figure file names do not change.
    out_path = f"figure/N{args.N}_ {args.model_type}_{_run_tag(args)}_rate_full.png"
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, bbox_inches='tight')
    plt.show()
|
||
|
|
||
|
# Run the plotting script only when executed directly, not when imported.
if __name__ == '__main__':
    main()
|