# limits-of-tom/intermediate_representations.py
from src.models.model_with_dialogue_moves import Model as ToMModel
# NOTE(review): CPAModel and GameParserCPA are imported but never used in this
# script -- possibly kept for parity with sibling scripts; confirm before removing.
from src.models.plan_model_graphs import Model as CPAModel
from src.data.game_parser import GameParser
from src.data.game_parser_graphs_new import GameParser as GameParserCPA
import torch
import json
import numpy as np

# Run on GPU when available; models are moved to this device before inference.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Integer codes selecting the sequence-encoder variant inside ToMModel.
MODEL_TYPES = {
'GRU' : 0,
'LSTM' : 1,
'Transformer' : 2,
}
def _export_task(model, tom_name, make_game, model_type_name, use_dialogue):
    """Run `model` over every game in every dataset split and save the
    intermediate representations to one .npz file per game per player POV.

    Args:
        model: a loaded ToMModel, already on DEVICE.
        tom_name: task label ("ToM6"/"ToM7"/"ToM8") used in the output filename.
        make_game: callable (path, pov) -> GameParser for this task.
        model_type_name / use_dialogue: echoed in the progress printout only.
    """
    with open('config/dataset_splits.json') as f:
        dataset_splits = json.load(f)
    for split in dataset_splits.values():  # renamed from `set` (shadowed builtin)
        for path in split:
            for pov in [1, 2]:
                out_file = f'{path}/intermediate_{tom_name}_{path.split("/")[-1]}_player{pov}.npz'
                # if os.path.isfile(out_file):
                #     continue
                game = make_game(path, pov)
                feats = model(game, global_plan=False, player_plan=True, intermediate=True).cpu().data.numpy()
                with open(out_file, 'wb') as out:
                    np.savez_compressed(out, data=feats)
                print(out_file, feats.shape, model_type_name, use_dialogue, use_dialogue == 'Yes')


def main():
    """Export intermediate ToM6/ToM7/ToM8 representations for all games.

    For each task, loads the corresponding LSTM-baseline checkpoint and writes
    one compressed .npz per (game, player POV) into the game's directory.
    """
    use_dialogue = "Yes"
    model_type_name = "LSTM"
    model_type = MODEL_TYPES[model_type_name]

    # (task label, checkpoint path, GameParser factory). The per-task parser
    # arguments are preserved exactly from the original script:
    # NOTE(review): ToM7 hard-codes 4 where ToM6 passes the loop's `pov`, and
    # ToM8 passes only four positional args (True lands in the fourth slot,
    # not the fifth) -- confirm against GameParser's signature that both
    # discrepancies are intentional.
    tasks = [
        ("ToM6", "models/tom_lstm_baseline/tom6_model.torch",
         lambda path, pov: GameParser(path, use_dialogue == 'Yes', pov, 0, True)),
        ("ToM7", "models/tom_lstm_baseline/tom7_model.torch",
         lambda path, pov: GameParser(path, use_dialogue == 'Yes', 4, 0, True)),
        ("ToM8", "models/tom_lstm_baseline/tom8_model.torch",
         lambda path, pov: GameParser(path, use_dialogue == 'Yes', 4, True)),
    ]

    for tom_name, model_file, make_game in tasks:
        model = ToMModel(model_type).to(DEVICE)
        # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts.
        model.load_state_dict(torch.load(model_file, map_location=DEVICE))
        _export_task(model, tom_name, make_game, model_type_name, use_dialogue)
if __name__ == "__main__":
main()