"""Run a pretrained plan model over the test and training splits and pickle
the features the model returns for every game.

Only experiments 2 and 3 are supported by the evaluation loop (see
``do_split``).
"""
import argparse
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm

from src.data.game_parser_graphs_new import GameParser, make_splits, set_seed
from src.models.plan_model_graphs import Model

# Sequence-model name accepted on the command line -> integer code given to Model.
_SEQ_MODEL_CODES = {'GRU': 0, 'LSTM': 1, 'Transformer': 2}

# Point of view -> values passed as the third positional argument of GameParser.
# The values are taken verbatim from the original script; their meaning is
# defined by GameParser (not visible here).
_POV_PARSER_MODES = {'None': (0,), 'Third': (3,), 'First': (1, 2)}


def _fail(*message_parts):
    """Print a validation message and stop the script with a non-zero status."""
    print(*message_parts, flush=True)
    raise SystemExit(1)


def _parse_yes_no(value, option_label):
    """Map 'Yes'/'No' to True/False; abort with a message naming the option otherwise."""
    if value == 'Yes':
        return True
    if value == 'No':
        return False
    _fail(f'{option_label} must be in [Yes, No], but got', value)


def print_epoch(data, acc_loss):
    """Print the loss and the mean F1 score, and return summary metrics.

    Args:
        data: iterable of records whose first entry is the predicted label
            array and whose second entry is the ground-truth label array;
            any further entries are ignored.
        acc_loss: loss value to print.

    Returns:
        Tuple ``(mean accuracy, mean F1, list of per-record F1 scores)``.
    """
    print(f'{acc_loss:9.4f}', end='; ', flush=True)
    acc = []
    f1 = []
    for record in data:
        predicted, target = record[0], record[1]
        acc.append(accuracy_score(target, predicted))
        f1.append(f1_score(target, predicted, zero_division=1))
    # NOTE(review): statement nesting was reconstructed from a
    # whitespace-collapsed source -- confirm this summary print belongs
    # after the loop rather than inside it.
    print(f'{np.mean(f1):5.3f},', end=' ', flush=True)
    print('', end='; ', flush=True)
    return np.mean(acc), np.mean(f1), f1


def do_split(model, lst, exp, criterion, device, optimizer=None, global_plan=False, player_plan=False,
             incremental=False):
    """Run the model over every game in ``lst`` and collect predictions and features.

    Args:
        model: callable invoked as ``model(game, experiment=..., global_plan=...,
            player_plan=..., incremental=..., return_feats=True)`` and returning
            ``(prediction, ground_truth, sel, feats)``.
        lst: sequence of parsed games.
        exp: experiment id; must be 2 or 3.
        criterion: loss function applied to ``(prediction, ground_truth)``.
        device: device the ground truth is moved to before computing the loss.
        optimizer: never stepped here. When it is ``None`` (evaluation) gradient
            tracking is switched off so the stored features do not keep the
            autograd graph of every game alive.
        global_plan, player_plan, incremental: forwarded to the model.

    Returns:
        ``(acc_loss, data, acc, f1, seq2seq_feats)`` where ``data`` holds one
        record per game (or one per prediction row when ``incremental``):
        ``(rounded sigmoid prediction, ground truth, player1/player2/global
        plan edge_index.shape[1], sel, game_path, game index)``, with the
        record's F1 appended when not ``incremental``; ``seq2seq_feats`` is a
        list of ``[feats, game_path]``.

    Raises:
        ValueError: if ``exp`` is neither 2 nor 3.
    """
    # Loop-invariant check: fail before any work is done (also for an empty list).
    if exp not in (2, 3):
        raise ValueError('This script is only for exp == 2 or exp == 3.')

    data = []
    seq2seq_feats = []
    acc_loss = 0
    with torch.set_grad_enabled(optimizer is not None):
        for batch, game in enumerate(lst):
            prediction, ground_truth, sel, feats = model(
                game,
                experiment=exp,
                global_plan=global_plan,
                player_plan=player_plan,
                incremental=incremental,
                return_feats=True,
            )
            seq2seq_feats.append([feats, game.game_path])

            if exp == 2:
                # Drop the leading entries covered by the selected player plan(s);
                # presumably those edges are already known -- TODO confirm.
                if sel[0]:
                    offset = game.player1_plan.edge_index.shape[1]
                    prediction = prediction[offset:]
                    ground_truth = ground_truth[offset:]
                if sel[1]:
                    offset = game.player2_plan.edge_index.shape[1]
                    prediction = prediction[offset:]
                    ground_truth = ground_truth[offset:]

            # Nothing left to score for this game.
            if prediction.numel() == 0 and ground_truth.numel() == 0:
                continue

            player1_size = game.player1_plan.edge_index.shape[1]
            player2_size = game.player2_plan.edge_index.shape[1]
            global_size = game.global_plan.edge_index.shape[1]
            predicted_labels = torch.round(torch.sigmoid(prediction)).float().detach().cpu().numpy()

            if incremental:
                # One ground-truth row per prediction row.
                ground_truth = ground_truth.to(device).repeat(prediction.shape[0], 1)
                target_labels = ground_truth.detach().cpu().numpy()
                data += [
                    (pred_row, target_row, player1_size, player2_size, global_size, sel, game.game_path, batch)
                    for pred_row, target_row in zip(predicted_labels, target_labels)
                ]
            else:
                ground_truth = ground_truth.to(device)
                data.append((
                    predicted_labels,
                    ground_truth.detach().cpu().numpy(),
                    player1_size,
                    player2_size,
                    global_size,
                    sel,
                    game.game_path,
                    batch,
                ))

            loss = criterion(prediction, ground_truth)
            acc_loss += loss.item()

    # Skipped games still count in the denominator, as in the original;
    # an empty list no longer raises ZeroDivisionError.
    num_games = len(lst)
    if num_games:
        acc_loss /= num_games

    acc, f1, f1_list = print_epoch(data, acc_loss)
    if not incremental:
        data = [record + (score,) for record, score in zip(data, f1_list)]
    return acc_loss, data, acc, f1, seq2seq_feats


def _load_games(files, pov, experiment, use_dialogue, intermediate, use_dialogue_moves):
    """Parse every file of a split once per GameParser mode required by ``pov``."""
    if pov not in _POV_PARSER_MODES:
        _fail('POV must be in [None, First, Third], but got', pov)
    modes = list(_POV_PARSER_MODES[pov])
    if pov == 'None' and experiment > 2:
        modes.append(4)
    games = []
    for mode in modes:
        games += [GameParser(f, use_dialogue, mode, intermediate, use_dialogue_moves) for f in tqdm(files)]
    return games


def _extract_and_save(model, games, args, criterion, device, global_plan, player_plan, split_name):
    """Run ``do_split`` on ``games`` and pickle the returned features next to the checkpoint."""
    _, _, _, _, seq2seq_feats = do_split(model, games, args.experiment, criterion, device=device,
                                         global_plan=global_plan, player_plan=player_plan, incremental=False)
    # NOTE(review): [:-6] drops the last six characters of the checkpoint path,
    # which looks like it assumes a 6-character suffix such as '.torch' -- confirm.
    with open(f'{args.model_path[:-6]}_feats_{split_name}.pkl', 'wb') as f:
        pickle.dump(seq2seq_feats, f)


def main(args):
    """Validate the command-line arguments, load the checkpoint and dump features.

    Writes ``<model_path minus last 6 chars>_feats_test.pkl`` and
    ``..._feats_train.pkl``. Invalid arguments print a message and terminate
    the process with a non-zero exit status.
    """
    print(args, flush=True)
    print(f'PID: {os.getpid():6d}', flush=True)

    if isinstance(args.device, int) and args.device >= 0:
        device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
        print(f'Using {device}')
    else:
        _fail('Device must be a zero or positive integer, but got', args.device)

    if isinstance(args.seed, int) and args.seed >= 0:
        set_seed(args.seed)
    else:
        _fail('Seed must be a zero or positive integer, but got', args.seed)

    dataset_splits = make_splits('config/dataset_splits_new.json')

    d_flag = _parse_yes_no(args.use_dialogue, 'Use dialogue')
    d_move_flag = _parse_yes_no(args.use_dialogue_moves, 'Use dialogue moves')

    if args.experiment not in range(9):
        _fail('Experiment must be in', list(range(9)), ', but got', args.experiment)

    if args.intermediate not in range(32):
        _fail('Intermediate must be in', list(range(32)), ', but got', args.intermediate)

    if args.seq_model not in _SEQ_MODEL_CODES:
        _fail('The sequence model must be in [GRU, LSTM, Transformer], but got', args.seq_model)
    seq_model = _SEQ_MODEL_CODES[args.seq_model]

    if args.plans == 'Yes':
        global_plan = (args.pov == 'Third') or ((args.pov == 'None') and (args.experiment in range(3)))
        player_plan = (args.pov == 'First') or ((args.pov == 'None') and (args.experiment in range(3, 9)))
    elif args.plans == 'No' or args.plans is None:
        global_plan = False
        player_plan = False
    else:
        _fail('Use Plan must be in [Yes, No], but got', args.plans)

    print('global_plan', global_plan, 'player_plan', player_plan)

    # Check these before the (potentially slow) checkpoint and data loading.
    if args.pov not in _POV_PARSER_MODES:
        _fail('POV must be in [None, First, Third], but got', args.pov)
    if args.model_path is None:
        _fail('Model path must be given with --model_path, but got', args.model_path)

    criterion = nn.BCEWithLogitsLoss()

    model = Model(seq_model, device).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()

    ######### TEST
    test = _load_games(dataset_splits['test'], args.pov, args.experiment,
                       d_flag, args.intermediate, d_move_flag)
    _extract_and_save(model, test, args, criterion, device, global_plan, player_plan, 'test')

    ######### TRAIN
    train = _load_games(dataset_splits['training'], args.pov, args.experiment,
                        d_flag, args.intermediate, d_move_flag)
    _extract_and_save(model, train, args, criterion, device, global_plan, player_plan, 'train')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Extract sequence-model features from a pretrained plan model.')
    parser.add_argument('--pov', type=str, default='First',
                        help='point of view [None, First, Third]')
    parser.add_argument('--use_dialogue', type=str, default='Yes',
                        help='Use dialogue [Yes, No]')
    parser.add_argument('--use_dialogue_moves', type=str, default='Yes',
                        help='Use dialogue moves [Yes, No]')
    parser.add_argument('--plans', type=str, default='Yes',
                        help='Use plans [Yes, No]')
    parser.add_argument('--seq_model', type=str, default='Transformer',
                        help='sequence model [GRU, LSTM, Transformer]')
    parser.add_argument('--experiment', type=int, default=2,
                        help='experiment [0:Global, 1:Partner, 2:GlobalDif, 3:PartnerDif]')
    parser.add_argument('--intermediate', type=int,
                        help='')
    parser.add_argument('--seed', type=int,
                        help='Select random seed by index [0, 1, 2, ...]. 0 -> random seed set to 0. n>0 -> random seed '
                             'set to n\'th random number with original seed set to 0')
    parser.add_argument('--device', type=int, default=0,
                        help='select cuda device number')
    parser.add_argument('--model_path', type=str, default=None,
                        help='path to the pretrained model to be loaded')

    main(parser.parse_args())