75 lines
3.2 KiB
Python
75 lines
3.2 KiB
Python
|
from src.models.model_with_dialogue_moves import Model as ToMModel
|
||
|
from src.models.plan_model_graphs import Model as CPAModel
|
||
|
from src.data.game_parser import GameParser
|
||
|
from src.data.game_parser_graphs_new import GameParser as GameParserCPA
|
||
|
import torch
|
||
|
import json
|
||
|
import numpy as np
|
||
|
|
||
|
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Integer code for each architecture name; main() looks the code up by name
# and passes it to the model constructor.
MODEL_TYPES = dict(GRU=0, LSTM=1, Transformer=2)
|
||
|
|
||
|
def main():
    """Export intermediate ToM6, ToM7 and ToM8 model outputs for every game.

    Reads ``config/dataset_splits.json`` once, then for each of the three
    checkpoints runs every listed game path through the model for players
    1 and 2 and saves each result as a compressed ``.npz`` under that path.
    """
    with open('config/dataset_splits.json') as splits_file:
        dataset_splits = json.load(splits_file)

    # ToM6: the parser's third argument follows the player being exported.
    _export_intermediate(
        'ToM6',
        "models/tom_lstm_baseline/tom6_model.torch",
        lambda path, load_dialogue, pov: GameParser(path, load_dialogue, pov, 0, True),
        dataset_splits,
    )

    # NOTE(review): ToM7 passes a fixed 4 to the parser instead of the loop's
    # pov, so the player1 and player2 files are built from identically
    # constructed games. Kept as in the original -- confirm this is intended.
    _export_intermediate(
        'ToM7',
        "models/tom_lstm_baseline/tom7_model.torch",
        lambda path, load_dialogue, pov: GameParser(path, load_dialogue, 4, 0, True),
        dataset_splits,
    )

    # NOTE(review): ToM8 also passes a fixed 4, and additionally omits the `0`
    # that the two calls above pass as the fourth positional argument, so
    # `True` lands in that slot instead of the fifth. Kept as in the original
    # -- verify against the GameParser signature.
    _export_intermediate(
        'ToM8',
        "models/tom_lstm_baseline/tom8_model.torch",
        lambda path, load_dialogue, pov: GameParser(path, load_dialogue, 4, True),
        dataset_splits,
    )


def _export_intermediate(tag, model_file, make_game, dataset_splits,
                         model_type_name='LSTM', use_dialogue='Yes'):
    """Run one checkpoint over every game and save its intermediate output.

    Args:
        tag: Label embedded in each output file name (e.g. ``'ToM6'``).
        model_file: Path of the state-dict checkpoint to load.
        make_game: Callable ``(path, load_dialogue, pov)`` returning the
            parsed game object that is fed to the model.
        dataset_splits: Mapping whose values are iterables of game paths.
        model_type_name: Key into ``MODEL_TYPES`` selecting the architecture.
        use_dialogue: ``'Yes'`` to request dialogue loading from the parser;
            any other value passes ``False``.

    Each result is written to
    ``{path}/intermediate_{tag}_{last path component}_player{pov}.npz``
    with the array stored under the key ``data``.
    """
    model = ToMModel(MODEL_TYPES[model_type_name]).to(DEVICE)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(model_file, map_location=DEVICE))
    # NOTE(review): the original never calls model.eval(); left unchanged so
    # the exported values are not altered -- confirm whether that is intended.

    load_dialogue = use_dialogue == 'Yes'
    for split_paths in dataset_splits.values():
        for path in split_paths:
            game_name = path.split("/")[-1]
            for pov in (1, 2):
                out_file = f'{path}/intermediate_{tag}_{game_name}_player{pov}.npz'
                game = make_game(path, load_dialogue, pov)
                # Pure inference: skip building the autograd graph.
                with torch.no_grad():
                    output = model(game, global_plan=False, player_plan=True, intermediate=True)
                features = output.detach().cpu().numpy()
                # Context manager guarantees the output file is flushed and closed.
                with open(out_file, 'wb') as out_handle:
                    np.savez_compressed(out_handle, data=features)
                print(out_file, features.shape, model_type_name, use_dialogue, load_dialogue)
|
||
|
|
||
|
|
||
|
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|