"""Dump intermediate ToM model outputs to ``.npz`` files for every game.

For each configured checkpoint (ToM6, ToM7, ToM8) the script loads the model,
runs it on every game directory listed in ``config/dataset_splits.json`` once
per player index (1 and 2), and writes the resulting array next to the game
data as ``intermediate_<tag>_<game-dir-name>_player<pov>.npz``.
"""
import json

import numpy as np
import torch

from src.data.game_parser import GameParser
from src.data.game_parser_graphs_new import GameParser as GameParserCPA
from src.models.model_with_dialogue_moves import Model as ToMModel
from src.models.plan_model_graphs import Model as CPAModel

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_TYPES = {
    'GRU': 0,
    'LSTM': 1,
    'Transformer': 2,
}

_SPLITS_FILE = 'config/dataset_splits.json'

# One entry per checkpoint: (output tag, checkpoint path, function mapping the
# loop's ``pov`` to the positional GameParser arguments that follow
# ``(path, use_dialogue)``).
#
# NOTE(review): ToM7 and ToM8 pass the literal 4 where ToM6 passes ``pov``, so
# the parser input is the same for both players while the output file name
# still differs per player. ToM8 additionally passes only ``(4, True)`` where
# ToM7 passes ``(4, 0, True)``, i.e. ``True`` lands in the slot that receives 0
# in the other two jobs. Both are preserved exactly as in the original script;
# confirm against GameParser's signature whether they are intentional.
_EXPORT_JOBS = (
    ('ToM6', 'models/tom_lstm_baseline/tom6_model.torch',
     lambda pov: (pov, 0, True)),
    ('ToM7', 'models/tom_lstm_baseline/tom7_model.torch',
     lambda pov: (4, 0, True)),
    ('ToM8', 'models/tom_lstm_baseline/tom8_model.torch',
     lambda pov: (4, True)),
)


def _export_intermediates(tag, model_file, parser_tail, game_paths,
                          model_type_name='LSTM', use_dialogue='Yes'):
    """Run one checkpoint over all games and save its intermediate outputs.

    Args:
        tag: Label embedded in the output file name (e.g. ``'ToM6'``).
        model_file: Path of the ``state_dict`` checkpoint to load.
        parser_tail: Callable taking the player index and returning the tuple
            of positional arguments passed to ``GameParser`` after
            ``(path, use_dialogue == 'Yes')``.
        game_paths: Iterable of game directory paths (``/``-separated).
        model_type_name: Key into ``MODEL_TYPES`` selecting the model variant.
        use_dialogue: ``'Yes'`` to hand ``True`` to ``GameParser`` as its
            second argument; any other value hands it ``False``.
    """
    model = ToMModel(MODEL_TYPES[model_type_name]).to(DEVICE)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host,
    # which the DEVICE fallback above is meant to support.
    model.load_state_dict(torch.load(model_file, map_location=DEVICE))

    with_dialogue = use_dialogue == 'Yes'
    for path in game_paths:
        for pov in (1, 2):
            out_file = f'{path}/intermediate_{tag}_{path.split("/")[-1]}_player{pov}.npz'
            game = GameParser(path, with_dialogue, *parser_tail(pov))
            features = model(
                game,
                global_plan=False,
                player_plan=True,
                intermediate=True,
            ).cpu().data.numpy()
            # Existing files are overwritten; the context manager guarantees
            # the handle is closed (and data flushed) after each game.
            with open(out_file, 'wb') as out_handle:
                np.savez_compressed(out_handle, data=features)
            print(out_file, features.shape, model_type_name, use_dialogue, with_dialogue)


def main():
    """Export intermediate outputs of every configured checkpoint."""
    with open(_SPLITS_FILE) as splits_handle:
        dataset_splits = json.load(splits_handle)
    # The split names are irrelevant here: every listed game is processed.
    game_paths = [path for split in dataset_splits.values() for path in split]

    for tag, model_file, parser_tail in _EXPORT_JOBS:
        _export_intermediates(tag, model_file, parser_tail, game_paths)


if __name__ == "__main__":
    main()