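# Evaluate a pretrained plan-prediction model on the test and training splits
# and dump the per-game seq2seq features to pickle files derived from the
# checkpoint path.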
import os
import torch, torch.nn as nn, numpy as np
from sklearn.metrics import accuracy_score, f1_score
from src.data.game_parser_graphs_new import GameParser, make_splits, set_seed
from src.models.plan_model_graphs import Model
import argparse
import pickle
from tqdm import tqdm


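# Report one pass over a split: print the mean loss and mean F1, and return
# per-game accuracy/F1 computed from the (prediction, ground_truth) pairs in `data`.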
def print_epoch(data, acc_loss):
    print(f'{acc_loss:9.4f}', end='; ', flush=True)
    acc = []
    f1 = []
    for x in data:
        a, b, _, _, _, _, _, _ = x
        acc.append(accuracy_score(b, a))
        f1.append(f1_score(b, a, zero_division=1))
    print(f'{np.mean(f1):5.3f},', end=' ', flush=True)
    print('', end='; ', flush=True)
    return np.mean(acc), np.mean(f1), f1


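# Run the model over every parsed game in `lst`, accumulate the BCE loss, and
# collect (prediction, label, metadata) tuples plus the raw seq2seq features.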
def do_split(model, lst, exp, criterion, device, optimizer=None, global_plan=False, player_plan=False, incremental=False):
    data = []
    seq2seq_feats = []
    acc_loss = 0
    for batch, game in enumerate(lst):
        if (exp != 2) and (exp != 3):
            raise ValueError('This script is only for exp == 2 or exp == 3.')
        prediction, ground_truth, sel, feats = model(game, experiment=exp, global_plan=global_plan, player_plan=player_plan, incremental=incremental, return_feats=True)
        seq2seq_feats.append([feats, game.game_path])
        if exp == 2:
            # When sel[0]/sel[1] is set, drop the leading player-1/player-2 plan
            # edges from both prediction and ground truth before scoring.
            if sel[0]:
                prediction = prediction[game.player1_plan.edge_index.shape[1]:]
                ground_truth = ground_truth[game.player1_plan.edge_index.shape[1]:]
            if sel[1]:
                prediction = prediction[game.player2_plan.edge_index.shape[1]:]
                ground_truth = ground_truth[game.player2_plan.edge_index.shape[1]:]
        if prediction.numel() == 0 and ground_truth.numel() == 0:
            continue
        if incremental:
            ground_truth = ground_truth.to(device).repeat(prediction.shape[0], 1)
            data += list(zip(torch.round(torch.sigmoid(prediction)).float().cpu().data.numpy(),
                             ground_truth.cpu().data.numpy(),
                             [game.player1_plan.edge_index.shape[1]] * len(prediction),
                             [game.player2_plan.edge_index.shape[1]] * len(prediction),
                             [game.global_plan.edge_index.shape[1]] * len(prediction),
                             [sel] * len(prediction),
                             [game.game_path] * len(prediction),
                             [batch] * len(prediction)))
        else:
            ground_truth = ground_truth.to(device)
            data.append((
                torch.round(torch.sigmoid(prediction)).float().cpu().data.numpy(),
                ground_truth.cpu().data.numpy(),
                game.player1_plan.edge_index.shape[1],
                game.player2_plan.edge_index.shape[1],
                game.global_plan.edge_index.shape[1],
                sel,
                game.game_path,
                batch,
            ))
        loss = criterion(prediction, ground_truth)
        acc_loss += loss.item()
    acc_loss /= len(lst)
    acc, f1, f1_list = print_epoch(data, acc_loss)
    if not incremental:
        data = [data[i] + (f1_list[i],) for i in range(len(data))]
    return acc_loss, data, acc, f1, seq2seq_feats


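# Validate the command-line arguments, load the pretrained checkpoint, then
# extract and pickle seq2seq features for the test split and the training split.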
def main(args):
    print(args, flush=True)
    print(f'PID: {os.getpid():6d}', flush=True)

    if isinstance(args.device, int) and args.device >= 0:
        DEVICE = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
        print(f'Using {DEVICE}')
    else:
        print('Device must be a zero or positive integer, but got', args.device)
        exit()

    if isinstance(args.seed, int) and args.seed >= 0:
        seed = set_seed(args.seed)
    else:
        print('Seed must be a zero or positive integer, but got', args.seed)
        exit()

    dataset_splits = make_splits('config/dataset_splits_new.json')

    if args.use_dialogue == 'Yes':
        d_flag = True
    elif args.use_dialogue == 'No':
        d_flag = False
    else:
        print('Use dialogue must be in [Yes, No], but got', args.use_dialogue)
        exit()

    if args.use_dialogue_moves == 'Yes':
        d_move_flag = True
    elif args.use_dialogue_moves == 'No':
        d_move_flag = False
    else:
        print('Use dialogue moves must be in [Yes, No], but got', args.use_dialogue_moves)
        exit()

    if args.experiment not in list(range(9)):
        print('Experiment must be in', list(range(9)), ', but got', args.experiment)
        exit()

    if args.intermediate not in list(range(32)):
        print('Intermediate must be in', list(range(32)), ', but got', args.intermediate)
        exit()

    if args.seq_model == 'GRU':
        seq_model = 0
    elif args.seq_model == 'LSTM':
        seq_model = 1
    elif args.seq_model == 'Transformer':
        seq_model = 2
    else:
        print('The sequence model must be in [GRU, LSTM, Transformer], but got', args.seq_model)
        exit()

    if args.plans == 'Yes':
        global_plan = (args.pov == 'Third') or ((args.pov == 'None') and (args.experiment in list(range(3))))
        player_plan = (args.pov == 'First') or ((args.pov == 'None') and (args.experiment in list(range(3, 9))))
    elif args.plans == 'No' or args.plans is None:
        global_plan = False
        player_plan = False
    else:
        print('Use plans must be in [Yes, No], but got', args.plans)
        exit()
    print('global_plan', global_plan, 'player_plan', player_plan)

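    # Binary cross-entropy over plan-graph edge predictions; the model is used
    # for inference only, loaded from the given checkpoint and set to eval mode.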
    criterion = nn.BCEWithLogitsLoss()

    model = Model(seq_model, DEVICE).to(DEVICE)

    model.load_state_dict(torch.load(args.model_path))
    model.eval()

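    # Build the test-set parsers. The third GameParser argument selects the
    # point-of-view mode, mirroring the --pov branches below (0, plus 4 when
    # experiment > 2, for 'None'; 3 for 'Third'; 1 and 2 for 'First').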
    if args.pov == 'None':
        test = [GameParser(f, d_flag, 0, args.intermediate, d_move_flag) for f in tqdm(dataset_splits['test'])]
        if args.experiment > 2:
            test += [GameParser(f, d_flag, 4, args.intermediate, d_move_flag) for f in tqdm(dataset_splits['test'])]
    elif args.pov == 'Third':
        test = [GameParser(f, d_flag, 3, args.intermediate, d_move_flag) for f in tqdm(dataset_splits['test'])]
    elif args.pov == 'First':
        test = [GameParser(f, d_flag, 1, args.intermediate, d_move_flag) for f in tqdm(dataset_splits['test'])]
        test += [GameParser(f, d_flag, 2, args.intermediate, d_move_flag) for f in tqdm(dataset_splits['test'])]
    else:
        print('POV must be in [None, First, Third], but got', args.pov)
        exit()

    ######### TEST
    acc_loss, data, acc, f1, seq2seq_feats = do_split(model, test, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=False)
    with open(f'{args.model_path[:-6]}_feats_test.pkl', 'wb') as f:
        pickle.dump(seq2seq_feats, f)

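    # Repeat the same feature extraction over the training split.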
    if args.pov == 'None':
        train = [GameParser(f, d_flag, 0, args.intermediate, d_move_flag) for f in tqdm(dataset_splits['training'])]
        if args.experiment > 2:
            train += [GameParser(f, d_flag, 4, args.intermediate, d_move_flag) for f in tqdm(dataset_splits['training'])]
    elif args.pov == 'Third':
        train = [GameParser(f, d_flag, 3, args.intermediate, d_move_flag) for f in tqdm(dataset_splits['training'])]
    elif args.pov == 'First':
        train = [GameParser(f, d_flag, 1, args.intermediate, d_move_flag) for f in tqdm(dataset_splits['training'])]
        train += [GameParser(f, d_flag, 2, args.intermediate, d_move_flag) for f in tqdm(dataset_splits['training'])]
    else:
        print('POV must be in [None, First, Third], but got', args.pov)
        exit()

    ######### TRAIN
    acc_loss, data, acc, f1, seq2seq_feats = do_split(model, train, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=False)
    with open(f'{args.model_path[:-6]}_feats_train.pkl', 'wb') as f:
        pickle.dump(seq2seq_feats, f)


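# Example invocation (the script name is illustrative; adjust to this file's
# actual name). The checkpoint is expected to end in a 6-character extension
# such as '.torch', since the feature pickles are named from model_path[:-6]:
#   python extract_plan_feats.py --pov First --experiment 2 --seq_model Transformer \
#       --seed 0 --intermediate 0 --device 0 --model_path models/plan_model.torch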
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Extract seq2seq features from a pretrained plan model.')
    parser.add_argument('--pov', type=str, default='First',
                        help='point of view [None, First, Third]')
    parser.add_argument('--use_dialogue', type=str, default='Yes',
                        help='Use dialogue [Yes, No]')
    parser.add_argument('--use_dialogue_moves', type=str, default='Yes',
                        help='Use dialogue moves [Yes, No]')
    parser.add_argument('--plans', type=str, default='Yes',
                        help='Use plans [Yes, No]')
    parser.add_argument('--seq_model', type=str, default='Transformer',
                        help='sequence model [GRU, LSTM, Transformer]')
    parser.add_argument('--experiment', type=int, default=2,
                        help='experiment [0:Global, 1:Partner, 2:GlobalDif, 3:PartnerDif]')
    parser.add_argument('--intermediate', type=int,
                        help='intermediate setting, an integer in [0, 31]')
    parser.add_argument('--seed', type=int,
                        help='Select random seed by index [0, 1, 2, ...]. 0 -> random seed set to 0. n>0 -> random seed '
                             'set to n\'th random number with original seed set to 0')
    parser.add_argument('--device', type=int, default=0,
                        help='select cuda device number')
    parser.add_argument('--model_path', type=str, default=None,
                        help='path to the pretrained model to be loaded')

    main(parser.parse_args())