296 lines
12 KiB
Python
296 lines
12 KiB
Python
|
import os
|
||
|
import torch, random, torch.nn as nn, numpy as np
|
||
|
from torch import optim
|
||
|
from random import shuffle
|
||
|
from sklearn.metrics import accuracy_score, f1_score
|
||
|
from src.data.game_parser_graphs_new import GameParser, make_splits, DEVICE, set_seed
|
||
|
from src.models.model_with_dialogue_moves_graphs import Model
|
||
|
import argparse
|
||
|
from tqdm import tqdm
|
||
|
import pickle
|
||
|
|
||
|
def print_epoch(data, acc_loss, lst):
    """Print a one-line-fragment summary of a pass over a data split.

    Args:
        data: rows as built by ``do_split``; each row is a list of
            ``(ground_truth, prediction)`` integer pairs, one pair per
            predicted label column.
        acc_loss: loss accumulated over the split.
        lst: the games that were iterated; only ``len(lst)`` is used, to
            report ``acc_loss`` as a per-game average.

    Prints, without a trailing newline, ``"<avg loss>; "`` followed by one
    tuple per label column and a closing ``"; "``. A column whose ground-truth
    labels are all <= 1 prints ``(accuracy, weighted F1, mean ground truth,
    mean prediction, count)``; any other column prints
    ``(accuracy, weighted F1, count)``.
    """
    # max(..., 1) keeps an empty split from raising ZeroDivisionError.
    print(f'{acc_loss / max(len(lst), 1):9.4f}', end='; ', flush=True)
    # zip(*data) transposes rows of pairs into one column per label.
    for column in zip(*data):
        truth, pred = zip(*column)
        acc = accuracy_score(truth, pred)
        f1 = f1_score(truth, pred, average="weighted")
        if max(truth) <= 1:
            print(f'({acc:5.3f},{f1:5.3f},{sum(truth)/len(truth):5.3f},{sum(pred)/len(pred):5.3f},{len(pred)})', end=' ', flush=True)
        else:
            print(f'({acc:5.3f},{f1:5.3f},{len(pred)})', end=' ', flush=True)
    print('', end='; ', flush=True)
|
||
|
|
||
|
# Ordinal encoding used for fields 0 and 1 of each answer tuple.
_ANSWER_LEVELS = ['NO', 'MAYBE', 'YES']

# Random subsampling applied to the loss for experiments 0-2:
# exp -> (label that is always kept, keep probability for any other label).
_SUBSAMPLE = {0: (0, 2/3), 1: (0, 5/6), 2: (1, 5/6)}


def _turn_labels(gt, materials_dict):
    """Return all candidate integer labels for one turn.

    ``gt[0]`` and ``gt[1]`` are two answer tuples. The labels are, in order:
    one agreement flag ``int(gt[0][i] == gt[1][i])`` per zipped field, then
    for ``gt[0]`` and then ``gt[1]``: the ``_ANSWER_LEVELS`` index of field 0,
    the index of field 1, and ``materials_dict[field 2]`` (0 when field 2 is
    not a key). Presumably each tuple has three fields, giving the nine labels
    selected by ``exp`` -- TODO confirm against GameParser.
    """
    labels = [int(a == b) for a, b in zip(gt[0], gt[1])]
    for answer in (gt[0], gt[1]):
        labels.append(_ANSWER_LEVELS.index(answer[0]))
        labels.append(_ANSWER_LEVELS.index(answer[1]))
        labels.append(materials_dict[answer[2]] if answer[2] in materials_dict else 0)
    return labels


def _subsample(logits, labels, exp):
    """Return the ``(logit, label)`` pairs that contribute to the loss.

    For experiments listed in ``_SUBSAMPLE``, pairs whose label differs from
    the always-kept label survive with the configured probability.
    ``random.random()`` is drawn only for those pairs, so the amount of
    randomness consumed depends on the labels. Other experiments keep every
    pair.
    """
    if exp not in _SUBSAMPLE:
        return list(zip(logits, labels))
    keep_label, keep_prob = _SUBSAMPLE[exp]
    return [(lg, lb) for lg, lb in zip(logits, labels)
            if lb == keep_label or random.random() < keep_prob]


def do_split(model, lst, exp, criterion, optimizer=None, global_plan=False, player_plan=False, device=DEVICE):
    """Run ``model`` over every game in ``lst``, optionally training it.

    Args:
        model: called as ``model(game, global_plan=..., player_plan=...)``;
            must yield ``(gt, prd)`` per turn, where ``prd`` is indexable and
            its entries are tensors (presumably per-label logits; they are
            arg-maxed and fed to ``criterion``).
        lst: games to iterate; each must expose ``materials_dict``.
        exp: index selecting which single turn label is predicted.
        criterion: loss called as ``criterion(logits, targets)``.
        optimizer: when given and ``model.training`` is True, one
            clipped optimisation step is taken per game.
        global_plan, player_plan: forwarded unchanged to ``model``.
        device: device the target tensor is moved to.

    Returns:
        ``(acc_loss, data)``: the sum of per-game losses, and one list per
        turn of ``(ground_truth, argmax_prediction)`` pairs, recorded before
        any subsampling. A summary is also printed via ``print_epoch``.
    """
    training = model.training and optimizer is not None
    data = []
    acc_loss = 0
    # Evaluation passes never call backward(), so skip building the autograd
    # graph; computed values are unchanged, memory and time are saved.
    with torch.set_grad_enabled(training):
        for game in lst:
            if training:
                optimizer.zero_grad()
            prediction = []
            ground_truth = []
            for gt, prd in model(game, global_plan=global_plan, player_plan=player_plan):
                labels = _turn_labels(gt, game.materials_dict)[exp:exp + 1]
                logits = prd[exp:exp + 1]
                data.append([(g, torch.argmax(p).item()) for p, g in zip(logits, labels)])
                kept = _subsample(logits, labels, exp)
                if not kept:
                    continue
                p, g = zip(*kept)
                prediction.append(torch.cat(p))
                ground_truth += g
            # Every turn of this game was dropped: nothing to score.
            if not prediction or not ground_truth:
                continue
            targets = torch.tensor(ground_truth).long().to(device)
            loss = criterion(torch.stack(prediction), targets)
            if training:
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1)
                optimizer.step()
            acc_loss += loss.item()
    print_epoch(data, acc_loss, lst)
    return acc_loss, data
|
||
|
|
||
|
# --seq_model value -> integer code passed to Model.
_SEQ_MODELS = {'GRU': 0, 'LSTM': 1, 'Transformer': 2}


def _abort(*message):
    """Print ``message`` exactly like ``print`` would, then exit with status 1."""
    import sys
    print(*message)
    sys.exit(1)


def _yes_no(value, message):
    """Map 'Yes'/'No' to True/False; otherwise print ``message value`` and abort."""
    if value == 'Yes':
        return True
    if value == 'No':
        return False
    _abort(message, value)


def _pov_modes(pov, experiment):
    """Return the GameParser mode numbers this script uses for a validated ``pov``."""
    if pov == 'None':
        return [0, 4] if experiment > 2 else [0]
    return {'Third': [3], 'First': [1, 2]}[pov]


def _load_games(dataset_splits, split_names, modes, d_flag, d_move_flag):
    """Parse the game files of each split in ``split_names`` for every mode.

    Parsing proceeds mode by mode, and within one mode split by split.
    Returns one list per split name, with games of earlier modes first.
    """
    loaded = {name: [] for name in split_names}
    for mode in modes:
        for name in split_names:
            loaded[name] += [GameParser(f, d_flag, mode, 0, d_move_flag) for f in tqdm(dataset_splits[name])]
    return [loaded[name] for name in split_names]


def _last_column_f1(data):
    """Weighted F1 of the last label column of ``do_split`` output (0.0 if empty)."""
    columns = list(zip(*data))
    if not columns:
        return 0.0
    truth, pred = zip(*columns[-1])
    return f1_score(truth, pred, average='weighted')


def _save_checkpoint(model, path, device):
    """Save ``model``'s state dict as CPU tensors to ``path``, then move it back to ``device``."""
    torch.save(model.cpu().state_dict(), path)
    model.to(device)


def main(args):
    """Train a ``Model`` on one experiment, then evaluate the best checkpoint on test.

    ``args`` is the namespace built by the ``__main__`` argument parser.
    Invalid arguments print a message and exit with status 1. After every
    validation pass whose weighted F1 beats the best so far, the model is saved
    to ``args.save_path``. That file is reloaded for the final test pass.
    """
    print(args, flush=True)
    print(f'PID: {os.getpid():6d}', flush=True)

    if not (isinstance(args.device, int) and args.device >= 0):
        _abort('Device must be a zero or positive integer, but got', args.device)
    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    print(f'Using {device}')

    if not (isinstance(args.seed, int) and args.seed >= 0):
        _abort('Seed must be a zero or positive integer, but got', args.seed)
    set_seed(args.seed)

    dataset_splits = make_splits('config/dataset_splits_new.json')

    d_flag = _yes_no(args.use_dialogue, 'Use dialogue must be in [Yes, No], but got')
    d_move_flag = _yes_no(args.use_dialogue_moves, 'Use dialogue moves must be in [Yes, No], but got')

    if args.experiment not in range(9):
        _abort('Experiment must be in', list(range(9)), ', but got', args.experiment)

    seq_model = _SEQ_MODELS.get(args.seq_model)
    if seq_model is None:
        _abort('The sequence model must be in [GRU, LSTM, Transformer], but got', args.seq_model)

    if args.plans == 'Yes':
        global_plan = (args.pov == 'Third') or (args.pov == 'None' and args.experiment in range(3))
        player_plan = (args.pov == 'First') or (args.pov == 'None' and args.experiment in range(3, 9))
    elif args.plans == 'No' or args.plans is None:
        global_plan = False
        player_plan = False
    else:
        _abort('Use Plan must be in [Yes, No], but got', args.plans)
    print('global_plan', global_plan, 'player_plan', player_plan)

    if args.pov not in ('None', 'Third', 'First'):
        _abort('POV must be in [None, First, Third], but got', args.pov)
    modes = _pov_modes(args.pov, args.experiment)
    val, train = _load_games(dataset_splits, ['validation', 'training'], modes, d_flag, d_move_flag)

    model = Model(seq_model, device).to(device)
    print(model)
    model.train()

    learning_rate = 1e-4
    num_epochs = 1000
    weight_decay = 1e-4
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    print(str(criterion), str(optimizer))

    # Start below any reachable F1 so the first validation pass always writes
    # a checkpoint: the test phase reloads args.save_path unconditionally.
    max_f1 = -1.0
    epochs_since_improvement = 0
    wait_epoch = 100
    split_kwargs = dict(global_plan=global_plan, player_plan=player_plan, device=device)

    if args.model_path is not None:
        print(f'Loading {args.model_path}')
        model.load_state_dict(torch.load(args.model_path))
        # Validate in eval mode so train-only layer behaviour is disabled.
        model.eval()
        _, data = do_split(model, val, args.experiment, criterion, **split_kwargs)
        f1 = _last_column_f1(data)
        if max_f1 < f1:
            max_f1 = f1
            epochs_since_improvement = 0
            print('^')
            _save_checkpoint(model, args.save_path, device)
    else:
        print('Training model from scratch', flush=True)

    for epoch in range(num_epochs):
        print(f'{os.getpid():6d} {epoch+1:4d},', end=' ', flush=True)
        shuffle(train)
        model.train()
        do_split(model, train, args.experiment, criterion, optimizer=optimizer, **split_kwargs)
        model.eval()
        _, data = do_split(model, val, args.experiment, criterion, **split_kwargs)

        f1 = _last_column_f1(data)
        if max_f1 < f1:
            max_f1 = f1
            epochs_since_improvement = 0
            print('^')
            _save_checkpoint(model, args.save_path, device)
        else:
            epochs_since_improvement += 1
            print()

        # Early stopping: after wait_epoch warm-up epochs, stop once
        # validation F1 has not improved for more than 20 epochs.
        if epoch > wait_epoch and epochs_since_improvement > 20:
            break

    print()
    print('Test')
    model.load_state_dict(torch.load(args.save_path))

    # Release the training/validation games before parsing the test split.
    del val, train
    test, = _load_games(dataset_splits, ['test'], modes, d_flag, d_move_flag)

    model.eval()
    _, data = do_split(model, test, args.experiment, criterion, **split_kwargs)

    print()
    print(data)
    print()
|
||
|
|
||
|
if __name__ == '__main__':
    # Command-line entry point; all validation of these values happens in main().
    parser = argparse.ArgumentParser(
        description='Train a Model on one experiment and evaluate the best checkpoint on the test split.')
    parser.add_argument('--pov', type=str,
                        help='point of view [None, First, Third]')
    parser.add_argument('--use_dialogue', type=str,
                        help='Use dialogue [Yes, No]')
    parser.add_argument('--use_dialogue_moves', type=str,
                        help='Use dialogue moves [Yes, No]')
    parser.add_argument('--plans', type=str,
                        help='Use plans [Yes, No]')
    parser.add_argument('--seq_model', type=str,
                        help='sequence model [GRU, LSTM, Transformer]')
    parser.add_argument('--experiment', type=int,
                        help='experiment [0:AggQ1, 1:AggQ2, 2:AggQ3, 3:P0Q1, 4:P0Q2, 5:P0Q3, 6:P1Q1, 7:P1Q2, 8:P1Q3]')
    parser.add_argument('--seed', type=int,
                        help='Select random seed by index [0, 1, 2, ...]. 0 -> random seed set to 0. n>0 -> random seed '
                             'set to n\'th random number with original seed set to 0')
    parser.add_argument('--save_path', type=str,
                        help='path where to save model')
    parser.add_argument('--model_path', type=str, default=None,
                        help='path to the pretrained model to be loaded')
    parser.add_argument('--device', type=int, default=0,
                        help='select cuda device number')

    main(parser.parse_args())
|