# limits-of-tom/plan_predictor_graphs.py
import os
import torch, torch.nn as nn, numpy as np
from torch import optim
from random import shuffle
from sklearn.metrics import accuracy_score, f1_score
from src.data.game_parser_graphs_new import GameParser, make_splits, set_seed
from src.models.plan_model_graphs import Model
import argparse
from tqdm import tqdm
import pickle
def print_epoch(data, acc_loss):
    """Print the split's loss and mean F1 to stdout; return summary metrics.

    Args:
        data: iterable of 8-tuples whose first two entries are the predicted
            and ground-truth label arrays for one sample (remaining entries
            are metadata and are ignored here).
        acc_loss: already-averaged loss for the split, printed as-is.

    Returns:
        Tuple of (mean accuracy, mean F1, per-sample F1 list).
    """
    print(f'{acc_loss:9.4f}', end='; ', flush=True)
    accuracies = []
    f1_per_sample = []
    for sample in data:
        predicted, target, _, _, _, _, _, _ = sample
        accuracies.append(accuracy_score(target, predicted))
        # zero_division=1: samples with no positive labels count as perfect.
        f1_per_sample.append(f1_score(target, predicted, zero_division=1))
    print(f'{np.mean(f1_per_sample):5.3f},', end=' ', flush=True)
    print('', end='; ', flush=True)
    return np.mean(accuracies), np.mean(f1_per_sample), f1_per_sample
def do_split(model, lst, exp, criterion, device, optimizer=None, global_plan=False, player_plan=False, incremental=False):
    """Run the model over one data split; train when `optimizer` is given and
    `model.training` is set, otherwise just evaluate.

    Args:
        model: plan-prediction model; `model(game, ...)` returns a tuple of
            (prediction logits, ground truth, sel) for one game.
        lst: list of parsed games (GameParser instances) forming the split.
        exp: experiment id; only 2 and 3 are supported by this script.
        criterion: loss applied to raw logits (e.g. BCEWithLogitsLoss).
        device: device the ground truth is moved to before the loss.
        optimizer: when not None and the model is in training mode, gradients
            are applied with accumulation over pairs of games.
        global_plan, player_plan: flags forwarded to the model call.
        incremental: when True the model emits one prediction per step and the
            ground truth is repeated to match the prediction's first dimension.

    Returns:
        (mean loss over the split, per-sample data tuples, mean acc, mean F1).
    """
    data = []
    acc_loss = 0
    for batch, game in enumerate(lst):
        if (exp != 2) and (exp != 3):
            raise ValueError('This script is only for exp == 2 or exp == 3.')
        prediction, ground_truth, sel = model(game, experiment=exp, global_plan=global_plan, player_plan=player_plan, incremental=incremental)
        if exp == 2:
            # NOTE(review): presumably this trims the leading entries that
            # correspond to a plan the model was shown (sel flags which), so
            # only the unseen portion is scored — confirm against Model.
            if sel[0]:
                prediction = prediction[game.player1_plan.edge_index.shape[1]:]
                ground_truth = ground_truth[game.player1_plan.edge_index.shape[1]:]
            if sel[1]:
                prediction = prediction[game.player2_plan.edge_index.shape[1]:]
                ground_truth = ground_truth[game.player2_plan.edge_index.shape[1]:]
            # Nothing left to score for this game after trimming.
            if prediction.numel() == 0 and ground_truth.numel() == 0: continue
        if incremental:
            # One prediction per time step: repeat the single ground truth row
            # so criterion() sees matching shapes.
            ground_truth = ground_truth.to(device).repeat(prediction.shape[0], 1)
            data += list(zip(torch.round(torch.sigmoid(prediction)).float().cpu().data.numpy(),
                             ground_truth.cpu().data.numpy(),
                             [game.player1_plan.edge_index.shape[1]]*len(prediction),
                             [game.player2_plan.edge_index.shape[1]]*len(prediction),
                             [game.global_plan.edge_index.shape[1]]*len(prediction),
                             [sel]*len(prediction),
                             [game.game_path]*len(prediction),
                             [batch]*len(prediction)))
        else:
            ground_truth = ground_truth.to(device)
            data.append((
                torch.round(torch.sigmoid(prediction)).float().cpu().data.numpy(),
                ground_truth.cpu().data.numpy(),
                game.player1_plan.edge_index.shape[1],
                game.player2_plan.edge_index.shape[1],
                game.global_plan.edge_index.shape[1],
                sel,
                game.game_path,
                batch,
            ))
        # Loss is computed on raw logits (criterion is expected to apply the
        # sigmoid itself, as BCEWithLogitsLoss does).
        loss = criterion(prediction, ground_truth)
        # loss += 1e-5 * sum(p.pow(2.0).sum() for p in model.parameters())
        acc_loss += loss.item()
        if model.training and (not optimizer is None):
            loss.backward()
            if (batch+1) % 2 == 0: # gradient accumulation over pairs of games
                # nn.utils.clip_grad_norm_(model.parameters(), 1)
                optimizer.step()
                optimizer.zero_grad()
    # NOTE(review): games skipped by `continue` still count in this mean.
    acc_loss /= len(lst)
    acc, f1, f1_list = print_epoch(data, acc_loss)
    if not incremental:
        # Attach each sample's own F1 so downstream analysis can rank games.
        data = [data[i] + (f1_list[i],) for i in range(len(data))]
    return acc_loss, data, acc, f1
def init_weights(m):
    """Xavier-initialize a Linear layer's weight and zero its bias.

    Intended for use with ``model.apply(init_weights)``; modules other than
    ``nn.Linear`` are left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    m.bias.data.fill_(0.00)
def _pov_codes(pov, experiment):
    """Map the --pov choice to the list of GameParser pov codes to load.

    Returns None for an invalid pov string so the caller can report and exit.
    """
    if pov == 'None':
        # pov 0 always; experiments above 2 additionally load pov 4.
        return [0, 4] if experiment > 2 else [0]
    if pov == 'Third':
        return [3]
    if pov == 'First':
        # First-person data exists for both players (pov codes 1 and 2).
        return [1, 2]
    return None


def _load_games(files, pov_codes, d_flag, intermediate, d_move_flag, load_int0):
    """Parse every game file once per pov code and concatenate the results."""
    games = []
    for pov in pov_codes:
        if load_int0:
            games += [GameParser(f, d_flag, pov, intermediate, d_move_flag, load_int0_feats=True) for f in tqdm(files)]
        else:
            games += [GameParser(f, d_flag, pov, intermediate, d_move_flag) for f in tqdm(files)]
    return games


def main(args):
    """Train (or warm-start and fine-tune) the plan predictor, early-stopping
    on validation F1, then evaluate the best checkpoint on the test split and
    pickle the per-sample results next to the saved model."""
    print(args, flush=True)
    print(f'PID: {os.getpid():6d}', flush=True)

    # --- argument validation -------------------------------------------------
    if isinstance(args.device, int) and args.device >= 0:
        DEVICE = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
        print(f'Using {DEVICE}')
    else:
        print('Device must be a zero or positive integer, but got', args.device)
        exit()
    if isinstance(args.seed, int) and args.seed >= 0:
        set_seed(args.seed)
    else:
        print('Seed must be a zero or positive integer, but got', args.seed)
        exit()
    dataset_splits = make_splits('config/dataset_splits_new.json')
    # dataset_splits = make_splits('config/dataset_splits_dev.json')
    if args.use_dialogue == 'Yes':
        d_flag = True
    elif args.use_dialogue == 'No':
        d_flag = False
    else:
        print('Use dialogue must be in [Yes, No], but got', args.use_dialogue)
        exit()
    if args.use_dialogue_moves == 'Yes':
        d_move_flag = True
    elif args.use_dialogue_moves == 'No':
        d_move_flag = False
    else:
        # BUGFIX: previously reported args.use_dialogue in this message.
        print('Use dialogue moves must be in [Yes, No], but got', args.use_dialogue_moves)
        exit()
    if not args.experiment in list(range(9)):
        print('Experiment must be in', list(range(9)), ', but got', args.experiment)
        exit()
    if args.experiment not in (2, 3):
        # BUGFIX: do_split() raises for any other experiment and pos_weight
        # below would be undefined (NameError); fail fast with a clear message.
        print('This script only supports experiment 2 or 3, but got', args.experiment)
        exit()
    if not args.intermediate in list(range(32)):
        print('Intermediate must be in', list(range(32)), ', but got', args.intermediate)
        exit()
    if args.seq_model == 'GRU':
        seq_model = 0
    elif args.seq_model == 'LSTM':
        seq_model = 1
    elif args.seq_model == 'Transformer':
        seq_model = 2
    else:
        print('The sequence model must be in [GRU, LSTM, Transformer], but got', args.seq_model)
        exit()
    if args.plans == 'Yes':
        global_plan = (args.pov == 'Third') or ((args.pov == 'None') and (args.experiment in list(range(3))))
        player_plan = (args.pov == 'First') or ((args.pov == 'None') and (args.experiment in list(range(3, 9))))
    elif args.plans == 'No' or args.plans is None:
        global_plan = False
        player_plan = False
    else:
        # BUGFIX: args.plan does not exist; the argument is args.plans.
        print('Use Plan must be in [Yes, No], but got', args.plans)
        exit()
    print('global_plan', global_plan, 'player_plan', player_plan)
    pov_codes = _pov_codes(args.pov, args.experiment)
    if pov_codes is None:
        # Validating pov once here also fixes the original test-split loader,
        # which printed this message but fell through without exiting.
        print('POV must be in [None, First, Third], but got', args.pov)
        exit()

    # --- data loading --------------------------------------------------------
    load_int0 = args.use_int0_instead_of_intermediate
    val = _load_games(dataset_splits['validation'], pov_codes, d_flag, args.intermediate, d_move_flag, load_int0)
    train = _load_games(dataset_splits['training'], pov_codes, d_flag, args.intermediate, d_move_flag, load_int0)

    # --- model / optimizer / loss --------------------------------------------
    model = Model(seq_model, DEVICE).to(DEVICE)
    # model.apply(init_weights)
    print(model)
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Positive-class weight compensates for label imbalance; tuned per experiment.
    pos_weight = torch.tensor([2.5] if args.experiment == 2 else [10.0], device=DEVICE)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    # criterion = nn.BCEWithLogitsLoss()
    print(str(criterion), str(optimizer))

    # --- training with early stopping on validation F1 -----------------------
    num_epochs = 200
    best_f1 = 0.0
    epochs_since_improvement = 0
    wait_epoch = 15   # minimum epochs before early stopping may trigger
    max_fails = 5     # patience: epochs without a validation-F1 improvement
    if args.model_path is not None:
        # Warm start: load a pretrained checkpoint and record its validation F1
        # so training only saves checkpoints that beat it.
        print(f'Loading {args.model_path}')
        model.load_state_dict(torch.load(args.model_path))
        model.eval()
        acc_loss0, data, acc, f1 = do_split(model, val, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=False)
        if f1 > best_f1:
            best_f1 = f1
            epochs_since_improvement = 0
            print('^')
            torch.save(model.cpu().state_dict(), args.save_path)
            model = model.to(DEVICE)
    else:
        print('Training model from scratch', flush=True)
    for epoch in range(num_epochs):
        print(f'{os.getpid():6d} {epoch+1:4d},', end=' ', flush=True)
        shuffle(train)
        model.train()
        do_split(model, train, args.experiment, criterion, device=DEVICE, optimizer=optimizer, global_plan=global_plan, player_plan=player_plan, incremental=False)
        model.eval()
        acc_loss0, data, acc, f1 = do_split(model, val, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=False)
        if f1 > best_f1:
            best_f1 = f1
            epochs_since_improvement = 0
            print('^')  # marks an improving epoch in the log
            torch.save(model.cpu().state_dict(), args.save_path)
            model = model.to(DEVICE)
        else:
            epochs_since_improvement += 1
            print()
        if epoch > wait_epoch and epochs_since_improvement > max_fails:
            break

    # --- test the best checkpoint --------------------------------------------
    print()
    print('Test')
    model.load_state_dict(torch.load(args.save_path))
    # Release the training data before parsing the test split.
    val = None
    train = None
    test = _load_games(dataset_splits['test'], pov_codes, d_flag, args.intermediate, d_move_flag, load_int0)
    model.eval()
    acc_loss, data, acc, f1 = do_split(model, test, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=False)
    print()
    print(data)
    print()
    # NOTE(review): assumes save_path ends in a 6-char extension (e.g. '.torch')
    # so that stripping it yields the pickle's base name — confirm with callers.
    with open(f'{args.save_path[:-6]}_data.pkl', 'wb') as f:
        pickle.dump(data, f)
if __name__ == '__main__':
    # CLI entry point: parse hyperparameters/flags and hand off to main().
    # FIX: several help strings were copy-pasted from other options and were
    # simply wrong; the flags, types, and defaults are unchanged.
    parser = argparse.ArgumentParser(description='Train and evaluate the graph plan predictor.')
    parser.add_argument('--pov', type=str,
                        help='point of view [None, First, Third]')
    parser.add_argument('--use_dialogue', type=str,
                        help='Use dialogue [Yes, No]')
    parser.add_argument('--use_dialogue_moves', type=str,
                        help='Use dialogue moves [Yes, No]')
    parser.add_argument('--plans', type=str,
                        help='Use plans [Yes, No]')
    parser.add_argument('--seq_model', type=str,
                        help='sequence model [GRU, LSTM, Transformer]')
    parser.add_argument('--experiment', type=int,
                        help='experiment id [0:Global, 1:Partner, 2:GlobalDif, 3:PartnerDif]')
    parser.add_argument('--intermediate', type=int,
                        help='intermediate representation index [0..31]')
    parser.add_argument('--save_path', type=str,
                        help='path where to save model')
    parser.add_argument('--seed', type=int,
                        help='Select random seed by index [0, 1, 2, ...]. 0 -> random seed set to 0. n>0 -> random seed '
                             'set to n\'th random number with original seed set to 0')
    parser.add_argument('--device', type=int, default=0,
                        help='select cuda device number')
    parser.add_argument('--model_path', type=str, default=None,
                        help='path to the pretrained model to be loaded')
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--use_int0_instead_of_intermediate', action='store_true')
    main(parser.parse_args())