first commit
This commit is contained in:
commit
83b04e2133
109 changed files with 12081 additions and 0 deletions
88
keyboard_and_mouse/stan/plot_user.py
Normal file
88
keyboard_and_mouse/stan/plot_user.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||
parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
|
||||
parser.add_argument('--user', type=int, default=1, help='number of users')
|
||||
|
||||
args = parser.parse_args()
|
||||
plot_type = 'bar' # line bar
|
||||
width = [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]
|
||||
|
||||
# read data
|
||||
user_data_list = []
|
||||
for i in range(args.user):
|
||||
model_data_list = []
|
||||
path = "result/"+"N"+ str(args.N) + "/" + args.model_type + "bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_result_user" + str(i) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
for j in range(7):
|
||||
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data_temp)
|
||||
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||
user_data_list.append(model_data_list)
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||
fig.set_figheight(14)
|
||||
fig.set_figwidth(25)
|
||||
|
||||
for ax in range(7):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(7):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(user_data_list)):
|
||||
y.append(user_data_list[i][j+ax*7][0])
|
||||
y_low.append(user_data_list[i][j+ax*7][2])
|
||||
y_high.append(user_data_list[i][j+ax*7][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(7):
|
||||
if plot_type == 'line':
|
||||
axs[ax].plot(range(args.user), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(args.user), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
if plot_type == 'bar':
|
||||
width = [-0.36, -0.24, -0.12, 0, 0.12, 0.24, 0.36]
|
||||
yerror = [np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])]
|
||||
axs[ax].bar(np.arange(args.user)+width[i], y_total[i], width=0.08, yerr=yerror, label=legend[i], color=color[i])
|
||||
axs[ax].tick_params(axis='x', which='both', length=0)
|
||||
axs[ax].set_ylabel('prob', fontsize=22)
|
||||
for k,x in enumerate(np.arange(args.user)+width[i]):
|
||||
y = y_total[i][k] + yerror[1][k]
|
||||
axs[ax].annotate(f'{y_total[i][k]:.2f}', (x, y), textcoords='offset points', xytext=(-18,3), fontsize=16)
|
||||
|
||||
axs[0].text(-0.1, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 22) # all: -0.3,0.5 3rows: -0.5,0.5
|
||||
axs[ax].text(-0.1, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 22, color=color[ax])
|
||||
axs[ax].tick_params(axis='both', which='major', labelsize=16)
|
||||
|
||||
plt.xticks(range(args.user),('1', '2', '3', '4', '5'))
|
||||
plt.xlabel('user', fontsize= 22)
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
plt.ylim([0, 1])
|
||||
Path("figure").mkdir(parents=True, exist_ok=True)
|
||||
if plot_type == 'line':
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_line.png", bbox_inches='tight')
|
||||
if plot_type == 'bar':
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_bar.png", bbox_inches='tight')
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
8
keyboard_and_mouse/stan/plot_user.sh
Normal file
8
keyboard_and_mouse/stan/plot_user.sh
Normal file
|
@ -0,0 +1,8 @@
|
|||
python3 plot_user.py \
|
||||
--model_type lstmlast_ \
|
||||
--batch_size 8 \
|
||||
--lr 1e-4 \
|
||||
--hidden_size 128 \
|
||||
--N 1 \
|
||||
--user 5
|
||||
|
99
keyboard_and_mouse/stan/plot_user_all_individual.py
Normal file
99
keyboard_and_mouse/stan/plot_user_all_individual.py
Normal file
|
@ -0,0 +1,99 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||
parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
|
||||
parser.add_argument('--user', type=int, default=1, help='number of users')
|
||||
|
||||
args = parser.parse_args()
|
||||
plot_type = 'bar' # line bar
|
||||
act_series = 5
|
||||
|
||||
# read data
|
||||
plot_list = []
|
||||
for act in range(1,act_series+1):
|
||||
user_data_list = []
|
||||
for i in range(args.user):
|
||||
model_data_list = []
|
||||
path = "result/"+"N"+ str(args.N) + "/" + args.model_type + "bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_result_user" + str(i) + "_rate__100" + "_act_" + str(act) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
for j in range(7):
|
||||
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data_temp)
|
||||
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||
print(model_data_list.shape)
|
||||
user_data_list.append(model_data_list)
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||
fig.set_figheight(14)
|
||||
fig.set_figwidth(25)
|
||||
|
||||
for ax in range(7):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(7):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(user_data_list)):
|
||||
y.append(user_data_list[i][j+ax*7][0])
|
||||
y_low.append(user_data_list[i][j+ax*7][2])
|
||||
y_high.append(user_data_list[i][j+ax*7][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(7):
|
||||
if plot_type == 'line':
|
||||
axs[ax].plot(range(args.user), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(args.user), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
if plot_type == 'bar':
|
||||
width = [-0.36, -0.24, -0.12, 0, 0.12, 0.24, 0.36]
|
||||
yerror = [np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])]
|
||||
axs[ax].bar(np.arange(args.user)+width[i], y_total[i], width=0.08, yerr=[np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])], label=legend[i], color=color[i])
|
||||
axs[ax].tick_params(axis='x', which='both', length=0)
|
||||
axs[ax].set_ylabel('prob', fontsize=36) # was 22,
|
||||
for k,x in enumerate(np.arange(args.user)+width[i]):
|
||||
y = y_total[i][k] + yerror[1][k]
|
||||
axs[ax].annotate(f'{y_total[i][k]:.2f}', (x, y), textcoords='offset points', xytext=(-18,3), fontsize=16) #was 16
|
||||
|
||||
axs[0].text(-0.17, 1.2, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 46) # was -0.1 0.9 25
|
||||
axs[ax].text(-0.17, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 46, color=color[ax]) # was 25
|
||||
axs[ax].tick_params(axis='both', which='major', labelsize=42) # was 18
|
||||
for tick in axs[ax].xaxis.get_major_ticks():
|
||||
tick.set_pad(20)
|
||||
|
||||
plt.xticks(range(args.user),('1', '2', '3', '4', '5'))
|
||||
plt.xlabel('user', fontsize= 42) # was 22
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
|
||||
plt.ylim([0, 1])
|
||||
plt.tight_layout()
|
||||
if plot_type == 'line':
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_line_all_individual.png", bbox_inches='tight')
|
||||
if plot_type == 'bar':
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_bar_all_individual.png", bbox_inches='tight')
|
||||
|
||||
if plot_type == 'line':
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_line_all_individual.eps", bbox_inches='tight', format='eps')
|
||||
if plot_type == 'bar':
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_bar_all_individual.eps", bbox_inches='tight', format='eps')
|
||||
#plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
8
keyboard_and_mouse/stan/plot_user_all_individual.sh
Normal file
8
keyboard_and_mouse/stan/plot_user_all_individual.sh
Normal file
|
@ -0,0 +1,8 @@
|
|||
python3 plot_user_all_individual.py \
|
||||
--model_type lstmlast_ \
|
||||
--batch_size 8 \
|
||||
--lr 1e-4 \
|
||||
--hidden_size 128 \
|
||||
--N 1 \
|
||||
--user 5
|
||||
|
93
keyboard_and_mouse/stan/plot_user_all_individual_chiw.py
Normal file
93
keyboard_and_mouse/stan/plot_user_all_individual_chiw.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
model_type = "lstmlast_"
|
||||
batch_size = 8
|
||||
lr = 1e-4
|
||||
hidden_size = 128
|
||||
N = 1
|
||||
user = 5
|
||||
plot_type = 'bar' # line bar
|
||||
act_series = 5
|
||||
|
||||
# read data
|
||||
plot_list = []
|
||||
for act in range(1,act_series+1):
|
||||
user_data_list = []
|
||||
for i in range(user):
|
||||
model_data_list = []
|
||||
path = "result/"+"N"+ str(N) + "/" + model_type + "bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_result_user" + str(i) + "_rate__100" + "_act_" + str(act) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
for j in range(7):
|
||||
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data_temp)
|
||||
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||
print(model_data_list.shape)
|
||||
user_data_list.append(model_data_list)
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||
fig.set_figheight(14)
|
||||
fig.set_figwidth(25)
|
||||
|
||||
for ax in range(7):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(7):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(user_data_list)):
|
||||
y.append(user_data_list[i][j+ax*7][0])
|
||||
y_low.append(user_data_list[i][j+ax*7][2])
|
||||
y_high.append(user_data_list[i][j+ax*7][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print(legend[ax])
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(7):
|
||||
if plot_type == 'line':
|
||||
axs[ax].plot(range(user), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(user), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
if plot_type == 'bar':
|
||||
width = [-0.36, -0.24, -0.12, 0, 0.12, 0.24, 0.36]
|
||||
yerror = [np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])]
|
||||
axs[ax].bar(np.arange(user)+width[i], y_total[i], width=0.08, yerr=[np.array(y_total[i])-np.array(y_low_total[i]), np.array(y_high_total[i])-np.array(y_total[i])], label=legend[i], color=color[i])
|
||||
axs[ax].tick_params(axis='x', which='both', length=0)
|
||||
axs[ax].set_ylabel('prob', fontsize=26) # was 22,
|
||||
axs[ax].set_title(legend[ax], color=color[ax], fontsize=26)
|
||||
for k,x in enumerate(np.arange(user)+width[i]):
|
||||
y = y_total[i][k] + yerror[1][k]
|
||||
axs[ax].annotate(f'{y_total[i][k]:.2f}', (x, y), textcoords='offset points', xytext=(-18,3), fontsize=16) #was 16
|
||||
|
||||
#axs[0].text(-0.17, 1.2, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 46) # was -0.1 0.9 25
|
||||
#axs[ax].text(-0.17, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 46, color=color[ax]) # was 25
|
||||
axs[ax].tick_params(axis='both', which='major', labelsize=18) # was 18
|
||||
for tick in axs[ax].xaxis.get_major_ticks():
|
||||
tick.set_pad(20)
|
||||
|
||||
plt.xticks(range(user),('1', '2', '3', '4', '5'))
|
||||
plt.xlabel('user', fontsize= 26) # was 22
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
|
||||
plt.ylim([0, 1.2])
|
||||
plt.tight_layout()
|
||||
if plot_type == 'line':
|
||||
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_line_all_individual_chiw.png", bbox_inches='tight')
|
||||
if plot_type == 'bar':
|
||||
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_bar_all_individual_chiw.png", bbox_inches='tight')
|
||||
#plt.show()
|
||||
if plot_type == 'line':
|
||||
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_line_all_individual_chiw.eps", bbox_inches='tight', format='eps')
|
||||
if plot_type == 'bar':
|
||||
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_bar_all_individual_chiw.eps", bbox_inches='tight', format='eps')
|
||||
|
||||
|
||||
|
88
keyboard_and_mouse/stan/plot_user_length_10_steps.py
Normal file
88
keyboard_and_mouse/stan/plot_user_length_10_steps.py
Normal file
|
@ -0,0 +1,88 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||
parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
|
||||
parser.add_argument('--user', type=int, default=1, help='number of users')
|
||||
|
||||
args = parser.parse_args()
|
||||
width = [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]
|
||||
|
||||
rate_user_data_list = []
|
||||
for r in range(0,101,10): # rate = range(0,101,10)
|
||||
# read data
|
||||
user_data_list = []
|
||||
for i in range(args.user):
|
||||
model_data_list = []
|
||||
path = "result/"+"N"+ str(args.N) + "/" + args.model_type + "bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_result_user" + str(i) + "_rate__" + str(r) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
for j in range(7):
|
||||
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data_temp)
|
||||
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||
if i == 4:
|
||||
print(model_data_list.shape, model_data_list)
|
||||
user_data_list.append(model_data_list)
|
||||
model_data_list_total = np.stack(user_data_list)
|
||||
print(model_data_list_total.shape)
|
||||
mean_user_data = np.mean(model_data_list_total,axis=0)
|
||||
print(mean_user_data.shape)
|
||||
rate_user_data_list.append(mean_user_data)
|
||||
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||
fig.set_figheight(10) # all sample rate: 10; 3 row: 8
|
||||
fig.set_figwidth(20)
|
||||
|
||||
for ax in range(7):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(7):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(rate_user_data_list)):
|
||||
y.append(rate_user_data_list[i][j+ax*7][0])
|
||||
y_low.append(rate_user_data_list[i][j+ax*7][2])
|
||||
y_high.append(rate_user_data_list[i][j+ax*7][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(7):
|
||||
axs[ax].plot(range(0,101,10), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(0,101,10), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
axs[ax].set_xticks(range(0,101,10))
|
||||
axs[ax].set_ylabel('prob', fontsize=20)
|
||||
|
||||
|
||||
axs[0].text(-0.125, 0.9, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 20)
|
||||
axs[ax].text(-0.125, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 20, color=color[ax])
|
||||
axs[ax].tick_params(axis='both', which='major', labelsize=16)
|
||||
|
||||
|
||||
plt.xlabel('Percentage of occurred actions in one action sequence', fontsize= 20)
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
|
||||
plt.xlim([0, 101])
|
||||
plt.ylim([0, 1])
|
||||
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_rate_full.png", bbox_inches='tight')
|
||||
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
8
keyboard_and_mouse/stan/plot_user_length_10_steps.sh
Normal file
8
keyboard_and_mouse/stan/plot_user_length_10_steps.sh
Normal file
|
@ -0,0 +1,8 @@
|
|||
python3 plot_user_length_10_steps.py \
|
||||
--model_type lstmlast_ \
|
||||
--batch_size 8 \
|
||||
--lr 1e-4 \
|
||||
--hidden_size 128 \
|
||||
--N 1 \
|
||||
--user 5
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
|
||||
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
|
||||
parser.add_argument('--hidden_size', type=int, default=128, help='hidden_size')
|
||||
parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
|
||||
parser.add_argument('--N', type=int, default=1, help='number of sequence for inference')
|
||||
parser.add_argument('--user', type=int, default=1, help='number of users')
|
||||
|
||||
args = parser.parse_args()
|
||||
width = [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]
|
||||
act_series = 5
|
||||
|
||||
for act in range(1,act_series+1):
|
||||
rate_user_data_list = []
|
||||
for r in range(0,101,10):
|
||||
# read data
|
||||
user_data_list = []
|
||||
for i in range(args.user):
|
||||
model_data_list = []
|
||||
path = "result/"+"N"+ str(args.N) + "/" + args.model_type + "bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_result_user" + str(i) + "_rate__" + str(r) + "_act_" + str(act) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
for j in range(7):
|
||||
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data_temp)
|
||||
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||
user_data_list.append(model_data_list)
|
||||
model_data_list_total = np.stack(user_data_list)
|
||||
mean_user_data = np.mean(model_data_list_total,axis=0)
|
||||
rate_user_data_list.append(mean_user_data)
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||
fig.set_figheight(14) # was 10
|
||||
fig.set_figwidth(20)
|
||||
|
||||
for ax in range(7):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(7):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(rate_user_data_list)):
|
||||
y.append(rate_user_data_list[i][j+ax*7][0])
|
||||
y_low.append(rate_user_data_list[i][j+ax*7][2])
|
||||
y_high.append(rate_user_data_list[i][j+ax*7][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(7):
|
||||
axs[ax].plot(range(0,101,10), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(0,101,10), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
axs[ax].set_xticks(range(0,101,10))
|
||||
axs[ax].set_ylabel('prob', fontsize=26) # was 20
|
||||
|
||||
axs[0].text(-0.15, 1.2, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 36) # was -0.125 20
|
||||
axs[ax].text(-0.15, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 36, color=color[ax]) # -0.125 20
|
||||
axs[ax].tick_params(axis='y', which='major', labelsize=24) # was 16
|
||||
axs[ax].tick_params(axis='x', which='major', labelsize=24) # was 16
|
||||
for tick in axs[ax].xaxis.get_major_ticks():
|
||||
tick.set_pad(20)
|
||||
|
||||
plt.xlabel('Percentage of occurred actions in one action sequence', fontsize= 36) # was 20
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
|
||||
plt.xlim([0, 101])
|
||||
plt.ylim([0, 1])
|
||||
|
||||
plt.savefig("figure/"+"N"+ str(args.N) + "_ "+ args.model_type + "_bs_" + str(args.batch_size) + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size) + '_N' + str(args.N) + "_act_series" + str(act) + "_rate_ful_all_individuall.png", bbox_inches='tight')
|
||||
|
||||
#plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
python3 plot_user_length_10_steps_all_individual.py \
|
||||
--model_type lstmlast_ \
|
||||
--batch_size 8 \
|
||||
--lr 1e-4 \
|
||||
--hidden_size 128 \
|
||||
--N 1 \
|
||||
--user 5
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
import numpy as np
|
||||
from numpy import genfromtxt
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
model_type = "lstmlast_"
|
||||
batch_size = 8
|
||||
lr = 1e-4
|
||||
hidden_size = 128
|
||||
N = 1
|
||||
user = 5
|
||||
width = [-0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3]
|
||||
act_series = 5
|
||||
|
||||
for act in range(1,act_series+1):
|
||||
rate_user_data_list = []
|
||||
for r in range(0,101,10):
|
||||
# read data
|
||||
print(r)
|
||||
user_data_list = []
|
||||
for i in range(user):
|
||||
model_data_list = []
|
||||
path = "result/"+"N"+ str(N) + "/" + model_type + "bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_result_user" + str(i) + "_rate__" + str(r) + "_act_" + str(act) +".csv"
|
||||
data = genfromtxt(path, delimiter=',', skip_header =1)
|
||||
for j in range(7):
|
||||
data_temp = data[[1+7*j+j,2+7*j+j,3+7*j+j,4+7*j+j,5+7*j+j,6+7*j+j,7+7*j+j],:][:,[2,4,6,7]]
|
||||
model_data_list.append(data_temp)
|
||||
model_data_list = np.concatenate(model_data_list, axis=0)
|
||||
user_data_list.append(model_data_list)
|
||||
model_data_list_total = np.stack(user_data_list)
|
||||
print(model_data_list_total.shape)
|
||||
mean_user_data = np.mean(model_data_list_total,axis=0)
|
||||
print(mean_user_data.shape)
|
||||
rate_user_data_list.append(mean_user_data)
|
||||
|
||||
color = ['royalblue', 'lightgreen', 'tomato', 'indigo', 'plum', 'darkorange', 'blue']
|
||||
legend = ['rule 1', 'rule 2', 'rule 3', 'rule 4', 'rule 5', 'rule 6', 'rule 7']
|
||||
fig, axs = plt.subplots(7, sharex=True, sharey=True)
|
||||
fig.set_figheight(14) # was 10
|
||||
fig.set_figwidth(20)
|
||||
|
||||
for ax in range(7):
|
||||
y_total = []
|
||||
y_low_total = []
|
||||
y_high_total = []
|
||||
for j in range(7):
|
||||
y= []
|
||||
y_low = []
|
||||
y_high = []
|
||||
for i in range(len(rate_user_data_list)):
|
||||
y.append(rate_user_data_list[i][j+ax*7][0])
|
||||
y_low.append(rate_user_data_list[i][j+ax*7][2])
|
||||
y_high.append(rate_user_data_list[i][j+ax*7][3])
|
||||
y_total.append(y)
|
||||
y_low_total.append(y_low)
|
||||
y_high_total.append(y_high)
|
||||
print()
|
||||
print(legend[ax])
|
||||
print("user mean of mean prob: ", np.mean(y))
|
||||
print("user mean of sd prob: ", np.std(y))
|
||||
|
||||
for i in range(7):
|
||||
axs[ax].plot(range(0,101,10), y_total[i], color=color[i], label=legend[i])
|
||||
axs[ax].fill_between(range(0,101,10), y_low_total[i], y_high_total[i], color=color[i],alpha=0.3 )
|
||||
axs[ax].set_xticks(range(0,101,10))
|
||||
axs[ax].set_ylabel('prob', fontsize=26) # was 20
|
||||
axs[ax].set_title(legend[ax], color=color[ax], fontsize=26)
|
||||
|
||||
|
||||
#axs[0].text(-0.15, 1.2, 'True Intention:', horizontalalignment='center', verticalalignment='center', transform=axs[0].transAxes, fontsize= 36) # was -0.125 20
|
||||
#axs[ax].text(-0.15, 0.5, legend[ax], horizontalalignment='center', verticalalignment='center', transform=axs[ax].transAxes, fontsize= 36, color=color[ax]) # -0.125 20
|
||||
axs[ax].tick_params(axis='y', which='major', labelsize=18) # was 16
|
||||
axs[ax].tick_params(axis='x', which='major', labelsize=18) # was 16
|
||||
for tick in axs[ax].xaxis.get_major_ticks():
|
||||
tick.set_pad(20)
|
||||
|
||||
plt.xlabel('Percentage of occurred actions in one action sequence', fontsize= 26) # was 20
|
||||
handles, labels = axs[0].get_legend_handles_labels()
|
||||
|
||||
plt.xlim([0, 101])
|
||||
plt.ylim([0, 1.1])
|
||||
plt.tight_layout()
|
||||
plt.savefig("figure/"+"N"+ str(N) + "_ "+ model_type + "_bs_" + str(batch_size) + '_lr_' + str(lr) + '_hidden_size_' + str(hidden_size) + '_N' + str(N) + "_act_series" + str(act) + "_rate_ful_all_individuall_chiw.png", bbox_inches='tight')
|
||||
|
||||
#plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
BIN
keyboard_and_mouse/stan/strategy_inference_model
Executable file
BIN
keyboard_and_mouse/stan/strategy_inference_model
Executable file
Binary file not shown.
26
keyboard_and_mouse/stan/strategy_inference_model.stan
Executable file
26
keyboard_and_mouse/stan/strategy_inference_model.stan
Executable file
|
@ -0,0 +1,26 @@
|
|||
data {
|
||||
int<lower=1> I; // number of question options (22)
|
||||
int<lower=0> N; // number of questions being asked by the user
|
||||
int<lower=1> K; // number of strategies
|
||||
// observed "true" questions of the user
|
||||
int q[N];
|
||||
// array of predicted probabilities of questions given strategies
|
||||
// coming from the forward neural network
|
||||
matrix[I, K] P_q_S[N];
|
||||
}
|
||||
parameters {
|
||||
// probabiliy vector of the strategies being applied by the user
|
||||
// to be inferred by the model here
|
||||
simplex[K] P_S;
|
||||
}
|
||||
model {
|
||||
for (n in 1:N) {
|
||||
// marginal probability vector of the questions being asked
|
||||
vector[I] theta = P_q_S[n] * P_S;
|
||||
// categorical likelihood
|
||||
target += categorical_lpmf(q[n] | theta);
|
||||
}
|
||||
// priors
|
||||
target += dirichlet_lpdf(P_S | rep_vector(1.0, K));
|
||||
}
|
||||
|
157
keyboard_and_mouse/stan/strategy_inference_test.R
Normal file
157
keyboard_and_mouse/stan/strategy_inference_test.R
Normal file
|
@ -0,0 +1,157 @@
|
|||
library(tidyverse)
|
||||
library(cmdstanr)
|
||||
library(dplyr)
|
||||
|
||||
|
||||
model_type <- "lstmlast"
|
||||
batch_size <- "8"
|
||||
lr <- "0.0001"
|
||||
hidden_size <- "128"
|
||||
model_type <- paste0(model_type, "_bs_", batch_size, "_lr_", lr, "_hidden_size_", hidden_size)
|
||||
print(model_type)
|
||||
set.seed(9736734)
|
||||
|
||||
user_num <- 5
|
||||
user <-c(0:(user_num-1))
|
||||
strategies <- c(0:6) # 7 tasks
|
||||
print(strategies)
|
||||
print(length(strategies))
|
||||
N <- 1
|
||||
|
||||
# read data from csv
|
||||
sel <- vector("list", length(strategies))
|
||||
for (u in seq_along(user)){
|
||||
dat <- vector("list", length(strategies))
|
||||
print(paste0('user: ', u))
|
||||
for (i in seq_along(strategies)) {
|
||||
dat[[i]] <- read.csv(paste0("../prediction/task", strategies[[i]], "/", model_type, "/user", user[[u]], "_pred", ".csv"))
|
||||
dat[[i]]$assumed_strategy <- strategies[[i]]
|
||||
dat[[i]]$index <- dat[[i]]$action_id # sample based on intention
|
||||
dat[[i]]$id <- dat[[i]][,1] # sample based on intention
|
||||
}
|
||||
|
||||
# reset N after inference
|
||||
N = 1
|
||||
|
||||
# select one action series from one intention
|
||||
if (user[[u]] == 0){
|
||||
sel[[1]]<-dat[[1]] %>%
|
||||
group_by(task_name) %>%
|
||||
sample_n(N)
|
||||
sel[[1]] <- data.frame(sel[[1]])
|
||||
}
|
||||
|
||||
# filter data from the selected action series, N series per intention
|
||||
for (i in seq_along(strategies)) {
|
||||
dat[[i]]<-subset(dat[[i]], dat[[i]]$action_id == sel[[1]]$action_id[1])
|
||||
}
|
||||
row.names(dat) <- NULL
|
||||
|
||||
# create save path
|
||||
dir.create(file.path("result"), showWarnings = FALSE)
|
||||
dir.create(file.path(paste0("result/", "N", N)), showWarnings = FALSE)
|
||||
save_path <- paste0("result/", "N", N, "/", model_type, "_N", N, "_", "result","_user", user[[u]], ".csv")
|
||||
|
||||
dat <- do.call(rbind, dat) %>%
|
||||
mutate(index = as.numeric(as.factor(id))) %>%
|
||||
rename(true_strategy = task_name) %>%
|
||||
mutate(
|
||||
true_strategy = factor(
|
||||
true_strategy, levels = 0:6,
|
||||
labels = strategies
|
||||
),
|
||||
q_type = case_when(
|
||||
gt %in% c(3,4,5) ~ 0,
|
||||
gt %in% c(1,2,3,4,5,6,7) ~ 1,
|
||||
gt %in% c(1,2,3,4) ~ 2,
|
||||
gt %in% c(1,4,5,6,7) ~ 3,
|
||||
gt %in% c(1,2,3,6,7) ~ 4,
|
||||
gt %in% c(2,3,4,5,6,7) ~ 5,
|
||||
gt %in% c(1,2,3,4,5,6,7) ~ 6,
|
||||
)
|
||||
)
|
||||
|
||||
dat_obs <- dat %>% filter(assumed_strategy == strategies[[i]])
|
||||
N <- nrow(dat_obs)
|
||||
print(c("N: ", N))
|
||||
q <- dat_obs$gt
|
||||
true_strategy <- dat_obs$true_strategy
|
||||
|
||||
K <- length(unique(dat$assumed_strategy))
|
||||
print(c("K: ", K))
|
||||
I <- 7
|
||||
|
||||
P_q_S <- array(dim = c(N, I, K))
|
||||
for (n in 1:N) {
|
||||
#print(n)
|
||||
P_q_S[n, , ] <- dat %>%
|
||||
filter(index == n) %>%
|
||||
select(matches("^act[[:digit:]]+$")) %>%
|
||||
as.matrix() %>%
|
||||
t()
|
||||
for (k in 1:K) {
|
||||
# normalize probabilities
|
||||
P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
|
||||
}
|
||||
}
|
||||
print(c('dim P_q_S',dim(P_q_S)))
|
||||
|
||||
mod <- cmdstan_model("strategy_inference_model.stan")
|
||||
|
||||
sub <- which(true_strategy == 0) # "0"
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_0 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_0$summary(NULL, c("mean","sd")))
|
||||
|
||||
sub <- which(true_strategy == 1)
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_1 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_1$summary(NULL, c("mean","sd")))
|
||||
|
||||
sub <- which(true_strategy == 2)
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_2 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_2$summary(NULL, c("mean","sd")))
|
||||
|
||||
sub <- which(true_strategy == 3)
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_3 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_3$summary(NULL, c("mean","sd")))
|
||||
|
||||
sub <- which(true_strategy == 4)
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_4 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_4$summary(NULL, c("mean","sd")))
|
||||
|
||||
sub <- which(true_strategy == 5)
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_5 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_5$summary(NULL, c("mean","sd")))
|
||||
|
||||
sub <- which(true_strategy == 6)
|
||||
print(c('sub', sub))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
fit_6 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_6$summary(NULL, c("mean","sd")))
|
||||
|
||||
# save csv
|
||||
df <-rbind(fit_0$summary(), fit_1$summary(), fit_2$summary(), fit_3$summary(), fit_4$summary(), fit_5$summary(), fit_6$summary())
|
||||
write.csv(df,file=save_path,quote=FALSE)
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,239 @@
|
|||
library(tidyverse)
|
||||
library(cmdstanr)
|
||||
library(dplyr)
|
||||
|
||||
# using every action sequence from each user
|
||||
model_type <- "lstmlast"
|
||||
batch_size <- "8"
|
||||
lr <- "0.0001"
|
||||
hidden_size <- "128"
|
||||
model_type <- paste0(model_type, "_bs_", batch_size, "_lr_", lr, "_hidden_size_", hidden_size)
|
||||
rates <- c("_0", "_10", "_20", "_30", "_40", "_50", "_60", "_70", "_80", "_90", "_100")
|
||||
|
||||
user_num <- 5
|
||||
user <-c(0:(user_num-1))
|
||||
strategies <- c(0:6) # 7 tasks
|
||||
print('strategies')
|
||||
print(strategies)
|
||||
print('strategies length')
|
||||
print(length(strategies))
|
||||
N <- 1
|
||||
unique_act_id <- c(1:5)
|
||||
print('unique_act_id')
|
||||
print(unique_act_id)
|
||||
set.seed(9746234)
|
||||
|
||||
for (act_id in seq_along(unique_act_id)){
|
||||
for (u in seq_along(user)){
|
||||
print('user')
|
||||
print(u)
|
||||
for (rate in rates) {
|
||||
N <- 1
|
||||
dat <- vector("list", length(strategies))
|
||||
for (i in seq_along(strategies)) {
|
||||
if (rate=="_0"){
|
||||
# read data from csv
|
||||
dat[[i]] <- read.csv(paste0("../prediction/single_act/task", strategies[[i]], "/", model_type, "/user", user[[u]], "/rate_10", "_act_", unique_act_id[act_id], "_pred", ".csv"))
|
||||
} else{
|
||||
dat[[i]] <- read.csv(paste0("../prediction/single_act/task", strategies[[i]], "/", model_type, "/user", user[[u]], "/rate", rate, "_act_", unique_act_id[act_id], "_pred", ".csv"))
|
||||
}
|
||||
# strategy assumed for prediction
|
||||
dat[[i]]$assumed_strategy <- strategies[[i]]
|
||||
dat[[i]]$index <- dat[[i]]$action_id # sample based on intention
|
||||
dat[[i]]$id <- dat[[i]][,1] # sample based on intention
|
||||
}
|
||||
|
||||
save_path <- paste0("result/", "N", N, "/", model_type, "_N", N, "_", "result","_user", user[[u]], "_rate_", rate, "_act_", unique_act_id[act_id], ".csv")
|
||||
|
||||
dat_act <- do.call(rbind, dat) %>%
|
||||
mutate(index = as.numeric(as.factor(id))) %>%
|
||||
rename(true_strategy = task_name) %>%
|
||||
mutate(
|
||||
true_strategy = factor(
|
||||
true_strategy, levels = 0:6,
|
||||
labels = strategies
|
||||
),
|
||||
q_type = case_when(
|
||||
gt %in% c(3,4,5) ~ 0,
|
||||
gt %in% c(1,2,3,4,5,6,7) ~ 1,
|
||||
gt %in% c(1,2,3,4) ~ 2,
|
||||
gt %in% c(1,4,5,6,7) ~ 3,
|
||||
gt %in% c(1,2,3,6,7) ~ 4,
|
||||
gt %in% c(2,3,4,5,6,7) ~ 5,
|
||||
gt %in% c(1,2,3,4,5,6,7) ~ 6,
|
||||
)
|
||||
)
|
||||
|
||||
dat_obs <- dat_act %>% filter(assumed_strategy == strategies[[i]])
|
||||
N <- nrow(dat_obs)
|
||||
print(c("N: ", N))
|
||||
print(c("dim dat_act: ", dim(dat_act)))
|
||||
q <- dat_obs$gt
|
||||
true_strategy <- dat_obs$true_strategy
|
||||
|
||||
K <- length(unique(dat_act$assumed_strategy))
|
||||
I <- 7
|
||||
|
||||
P_q_S <- array(dim = c(N, I, K))
|
||||
for (n in 1:N) {
|
||||
print(n)
|
||||
P_q_S[n, , ] <- dat_act %>%
|
||||
filter(index == n) %>%
|
||||
select(matches("^act[[:digit:]]+$")) %>%
|
||||
as.matrix() %>%
|
||||
t()
|
||||
for (k in 1:K) {
|
||||
# normalize probabilities
|
||||
P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
|
||||
}
|
||||
}
|
||||
|
||||
print(c("dim(P_q_S)", dim(P_q_S)))
|
||||
# read stan model
|
||||
mod <- cmdstan_model(paste0(getwd(),"/strategy_inference_model.stan"))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 0) # "0"
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_0 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_0$summary(NULL, c("mean","sd")))
|
||||
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 1)
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_1 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_1$summary(NULL, c("mean","sd")))
|
||||
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 2)
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_2 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_2$summary(NULL, c("mean","sd")))
|
||||
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 3)
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_3 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_3$summary(NULL, c("mean","sd")))
|
||||
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 4)
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_4 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_4$summary(NULL, c("mean","sd")))
|
||||
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 5)
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_5 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_5$summary(NULL, c("mean","sd")))
|
||||
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 6)
|
||||
}
|
||||
#print(sub)
|
||||
#print(length(sub))
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_6 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),'/temp'))
|
||||
print(fit_6$summary(NULL, c("mean","sd")))
|
||||
|
||||
# save csv
|
||||
df <-rbind(fit_0$summary(), fit_1$summary(), fit_2$summary(), fit_3$summary(), fit_4$summary(), fit_5$summary(), fit_6$summary())
|
||||
write.csv(df,file=save_path,quote=FALSE)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
238
keyboard_and_mouse/stan/strategy_inference_test_full_length.R
Normal file
238
keyboard_and_mouse/stan/strategy_inference_test_full_length.R
Normal file
|
@ -0,0 +1,238 @@
|
|||
library(tidyverse)
|
||||
library(cmdstanr)
|
||||
library(dplyr)
|
||||
|
||||
# index order of the strategies assumed throughout
|
||||
model_type <- "lstmlast"
|
||||
batch_size <- "8"
|
||||
lr <- "0.0001"
|
||||
hidden_size <- "128"
|
||||
model_type <- paste0(model_type, "_bs_", batch_size, "_lr_", lr, "_hidden_size_", hidden_size)
|
||||
rates <- c("_0", "_10", "_20", "_30", "_40", "_50", "_60", "_70", "_80", "_90", "_100")
|
||||
|
||||
user_num <- 5
|
||||
user <-c(0:(user_num-1))
|
||||
strategies <- c(0:6) # 7 tasks
|
||||
print(strategies)
|
||||
print(length(strategies))
|
||||
N <- 1
|
||||
|
||||
set.seed(9736754)
|
||||
|
||||
#read data from csv
|
||||
sel <- vector("list", length(strategies))
|
||||
for (u in seq_along(user)){
|
||||
print('user')
|
||||
print(u)
|
||||
for (rate in rates) {
|
||||
dat <- vector("list", length(strategies))
|
||||
for (i in seq_along(strategies)) {
|
||||
if (rate=="_0"){
|
||||
dat[[i]] <- read.csv(paste0("../prediction/task", strategies[[i]], "/", model_type, "/user", user[[u]], "_rate_10", "_pred", ".csv"))
|
||||
} else if (rate=="_100"){
|
||||
dat[[i]] <- read.csv(paste0("../prediction/task", strategies[[i]], "/", model_type, "/user", user[[u]], "_pred", ".csv"))
|
||||
} else{
|
||||
dat[[i]] <- read.csv(paste0("../prediction/task", strategies[[i]], "/", model_type, "/user", user[[u]], "_rate", rate, "_pred", ".csv"))
|
||||
}
|
||||
# strategy assumed for prediction
|
||||
dat[[i]]$assumed_strategy <- strategies[[i]]
|
||||
dat[[i]]$index <- dat[[i]]$action_id
|
||||
dat[[i]]$id <- dat[[i]][,1]
|
||||
}
|
||||
|
||||
# reset N after inference
|
||||
N <- 1
|
||||
|
||||
# select all action series and infer every one
|
||||
if (rate == "_0"){
|
||||
sel[[1]]<-dat[[1]] %>%
|
||||
group_by(task_name) %>%
|
||||
sample_n(N)
|
||||
sel[[1]] <- data.frame(sel[[1]])
|
||||
unique_act_id <- unique(sel[[1]]$action_id)
|
||||
}
|
||||
print(sel[[1]]$action_id)
|
||||
print(sel[[1]]$task_name)
|
||||
print(dat[[1]]$task_name)
|
||||
|
||||
|
||||
for (i in seq_along(strategies)) {
|
||||
dat[[i]]<-subset(dat[[i]], dat[[i]]$action_id == sel[[1]]$action_id[1])
|
||||
}
|
||||
row.names(dat) <- NULL
|
||||
print(c('action id', dat[[1]]$action_id))
|
||||
print(c('action id', dat[[2]]$action_id))
|
||||
print(c('action id', dat[[3]]$action_id))
|
||||
|
||||
dir.create(file.path(paste0("result/", "N", N)), showWarnings = FALSE)
|
||||
save_path <- paste0("result/", "N", N, "/", model_type, "_N", N, "_", "result","_user", user[[u]], "_rate_", rate, ".csv")
|
||||
|
||||
dat_act <- do.call(rbind, dat) %>%
|
||||
mutate(index = as.numeric(as.factor(id))) %>%
|
||||
rename(true_strategy = task_name) %>%
|
||||
mutate(
|
||||
true_strategy = factor(
|
||||
true_strategy, levels = 0:6,
|
||||
labels = strategies
|
||||
),
|
||||
q_type = case_when(
|
||||
gt %in% c(3,4,5) ~ 0,
|
||||
gt %in% c(1,2,3,4,5,6,7) ~ 1,
|
||||
gt %in% c(1,2,3,4) ~ 2,
|
||||
gt %in% c(1,4,5,6,7) ~ 3,
|
||||
gt %in% c(1,2,3,6,7) ~ 4,
|
||||
gt %in% c(2,3,4,5,6,7) ~ 5,
|
||||
gt %in% c(1,2,3,4,5,6,7) ~ 6,
|
||||
)
|
||||
)
|
||||
|
||||
dat_obs <- dat_act %>% filter(assumed_strategy == strategies[[i]]) # put_fridge, was num
|
||||
N <- nrow(dat_obs)
|
||||
print(c("N: ", N))
|
||||
print(c("dim dat_act: ", dim(dat_act)))
|
||||
|
||||
q <- dat_obs$gt
|
||||
true_strategy <- dat_obs$true_strategy
|
||||
|
||||
K <- length(unique(dat_act$assumed_strategy))
|
||||
I <- 7
|
||||
|
||||
P_q_S <- array(dim = c(N, I, K))
|
||||
for (n in 1:N) {
|
||||
print(n)
|
||||
P_q_S[n, , ] <- dat_act %>%
|
||||
filter(index == n) %>%
|
||||
select(matches("^act[[:digit:]]+$")) %>%
|
||||
as.matrix() %>%
|
||||
t()
|
||||
for (k in 1:K) {
|
||||
# normalize probabilities
|
||||
P_q_S[n, , k] <- P_q_S[n, , k] / sum(P_q_S[n, , k])
|
||||
}
|
||||
}
|
||||
print(c("dim(P_q_S)", dim(P_q_S)))
|
||||
|
||||
mod <- cmdstan_model(paste0(getwd(),"/strategy_inference_model.stan"))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 0) # "0"
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_0 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_0$summary(NULL, c("mean","sd")))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 1)
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_1 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_1$summary(NULL, c("mean","sd")))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 2)
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_2 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_2$summary(NULL, c("mean","sd")))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 3)
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_3 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_3$summary(NULL, c("mean","sd")))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 4)
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_4 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_4$summary(NULL, c("mean","sd")))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 5)
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_5 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_5$summary(NULL, c("mean","sd")))
|
||||
|
||||
if (rate=="_0"){
|
||||
sub <- integer(0)
|
||||
} else {
|
||||
sub <- which(true_strategy == 6)
|
||||
}
|
||||
if (length(sub) == 1){
|
||||
temp <- P_q_S[sub, , ]
|
||||
dim(temp) <- c(1, dim(temp))
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = temp)
|
||||
} else{
|
||||
sdata <- list(N = length(sub), K = K, I = I, q = q[sub], P_q_S = P_q_S[sub, , ])
|
||||
}
|
||||
fit_6 <- mod$sample(data = sdata, refresh=0, output_dir=paste0(getwd(),"/temp"))
|
||||
print(fit_6$summary(NULL, c("mean","sd")))
|
||||
|
||||
# save csv
|
||||
df <-rbind(fit_0$summary(), fit_1$summary(), fit_2$summary(), fit_3$summary(), fit_4$summary(), fit_5$summary(), fit_6$summary())
|
||||
write.csv(df,file=save_path,quote=FALSE)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue