This commit is contained in:
Matteo Bortoletto 2024-06-11 15:36:55 +02:00
parent 08780752d9
commit e15b0d7b50
46 changed files with 14927 additions and 0 deletions

BIN
src/.DS_Store vendored Normal file

Binary file not shown.

0
src/__init__.py Normal file
View file

BIN
src/data/.DS_Store vendored Normal file

Binary file not shown.

0
src/data/__init__.py Normal file
View file

761
src/data/game_parser.py Executable file
View file

@ -0,0 +1,761 @@
from email.mime import base
from glob import glob
import os, string, json, pickle
import torch, random, numpy as np
from transformers import BertTokenizer, BertModel
import cv2
import imageio
from src.data.action_extractor import proc_action
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# def set_seed(seed_idx):
# seed = 0
# random.seed(0)
# for _ in range(seed_idx):
# seed = random.random()
# random.seed(seed)
# torch.manual_seed(seed)
# print('Random seed set to', seed)
# return seed
def set_seed(seed):
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
print('Random seed set to', seed)
return seed
def make_splits(split_file = 'config/dataset_splits.json'):
if not os.path.isfile(split_file):
dirs = sorted(glob('data/saved_logs/*') + glob('data/main_logs/*'))
games = sorted(list(map(GameParser, dirs)), key=lambda x: len(x.question_pairs), reverse=True)
test = games[0::5]
val = games[1::5]
train = games[2::5]+games[3::5]+games[4::5]
dataset_splits = {'test' : [g.game_path for g in test], 'validation' : [g.game_path for g in val], 'training' : [g.game_path for g in train]}
json.dump(dataset_splits, open('config/dataset_splits_old.json','w'), indent=4)
dirs = sorted(glob('data/new_logs/*'))
games = sorted(list(map(GameParser, dirs)), key=lambda x: len(x.question_pairs), reverse=True)
test = games[0::5]
val = games[1::5]
train = games[2::5]+games[3::5]+games[4::5]
dataset_splits['test'] += [g.game_path for g in test]
dataset_splits['validation'] += [g.game_path for g in val]
dataset_splits['training'] += [g.game_path for g in train]
json.dump(dataset_splits, open('config/dataset_splits_new.json','w'), indent=4)
json.dump(dataset_splits, open('config/dataset_splits.json','w'), indent=4)
dataset_splits['test'] = dataset_splits['test'][:2]
dataset_splits['validation'] = dataset_splits['validation'][:2]
dataset_splits['training'] = dataset_splits['training'][:2]
json.dump(dataset_splits, open('config/dataset_splits_dev.json','w'), indent=4)
dataset_splits = json.load(open(split_file))
return dataset_splits
def onehot(x,n):
retval = np.zeros(n)
if x > 0:
retval[x-1] = 1
return retval
class GameParser:
tokenizer = None
model = None
def __init__(self, game_path, load_dialogue=True, pov=0, intermediate=0, use_dialogue_moves=False):
# print(game_path,end = ' ')
self.load_dialogue = load_dialogue
if pov not in (0,1,2,3,4):
print('Point of view must be in (0,1,2,3,4), but got ', pov)
exit()
self.pov = pov
self.use_dialogue_moves = use_dialogue_moves
self.load_player1 = pov==1
self.load_player2 = pov==2
self.load_third_person = pov==3
self.game_path = game_path
# print(game_path)
self.dialogue_file = glob(os.path.join(game_path,'mcc*log'))[0]
self.questions_file = glob(os.path.join(game_path,'web*log'))[0]
self.plan_file = glob(os.path.join(game_path,'plan*json'))[0]
self.plan = json.load(open(self.plan_file))
self.img_w = 96
self.img_h = 96
self.intermediate = intermediate
self.flip_video = False
for l in open(self.dialogue_file):
if 'HAS JOINED' in l:
player_name = l.strip().split()[1]
self.flip_video = player_name[-1] == '2'
break
if not os.path.isfile("config/materials.json") or \
not os.path.isfile("config/mines.json") or \
not os.path.isfile("config/tools.json"):
plan_files = sorted(glob('data/*_logs/*/plan*.json'))
materials = []
tools = []
mines = []
for plan_file in plan_files:
plan = json.load(open(plan_file))
materials += plan['materials']
tools += plan['tools']
mines += plan['mines']
materials = sorted(list(set(materials)))
tools = sorted(list(set(tools)))
mines = sorted(list(set(mines)))
json.dump(materials, open('config/materials.json','w'), indent=4)
json.dump(mines, open('config/mines.json','w'), indent=4)
json.dump(tools, open('config/tools.json','w'), indent=4)
materials = json.load(open('config/materials.json'))
mines = json.load(open('config/mines.json'))
tools = json.load(open('config/tools.json'))
self.materials_dict = {x:i+1 for i,x in enumerate(materials)}
self.mines_dict = {x:i+1 for i,x in enumerate(mines)}
self.tools_dict = {x:i+1 for i,x in enumerate(tools)}
self.__load_dialogue_act_labels()
self.__load_dialogue_move_labels()
self.__parse_dialogue()
self.__parse_questions()
self.__parse_start_end()
self.__parse_question_pairs()
self.__load_videos()
self.__assign_dialogue_act_labels()
self.__assign_dialogue_move_labels()
self.__load_replay_data()
self.__load_intermediate()
# print(len(self.materials_dict))
self.global_plan = []
self.global_plan_mat = np.zeros((21,21))
mine_counter = 0
for n,v in zip(self.plan['materials'],self.plan['full']):
if v['make']:
mine = 0
m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
self.global_plan_mat[self.materials_dict[n]-1][m1-1] = 1
self.global_plan_mat[self.materials_dict[n]-1][m2-1] = 1
else:
mine = self.mines_dict[self.plan['mines'][mine_counter]]
mine_counter += 1
m1 = 0
m2 = 0
mine = onehot(mine, len(self.mines_dict))
m1 = onehot(m1,len(self.materials_dict))
m2 = onehot(m2,len(self.materials_dict))
mat = onehot(self.materials_dict[n],len(self.materials_dict))
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]],len(self.tools_dict))
step = np.concatenate((mat,m1,m2,mine,t))
self.global_plan.append(step)
self.player1_plan = []
self.player1_plan_mat = np.zeros((21,21))
mine_counter = 0
for n,v in zip(self.plan['materials'],self.plan['player1']):
if v['make']:
mine = 0
if v['make'][0][0] < 0:
m1 = 0
m2 = 0
else:
m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
self.player1_plan_mat[self.materials_dict[n]-1][m1-1] = 1
self.player1_plan_mat[self.materials_dict[n]-1][m2-1] = 1
else:
mine = self.mines_dict[self.plan['mines'][mine_counter]]
mine_counter += 1
m1 = 0
m2 = 0
mine = onehot(mine, len(self.mines_dict))
m1 = onehot(m1,len(self.materials_dict))
m2 = onehot(m2,len(self.materials_dict))
mat = onehot(self.materials_dict[n],len(self.materials_dict))
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]],len(self.tools_dict))
step = np.concatenate((mat,m1,m2,mine,t))
self.player1_plan.append(step)
self.player2_plan = []
self.player2_plan_mat = np.zeros((21,21))
mine_counter = 0
for n,v in zip(self.plan['materials'],self.plan['player2']):
if v['make']:
mine = 0
if v['make'][0][0] < 0:
m1 = 0
m2 = 0
else:
m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
self.player2_plan_mat[self.materials_dict[n]-1][m1-1] = 1
self.player2_plan_mat[self.materials_dict[n]-1][m2-1] = 1
else:
mine = self.mines_dict[self.plan['mines'][mine_counter]]
mine_counter += 1
m1 = 0
m2 = 0
mine = onehot(mine, len(self.mines_dict))
m1 = onehot(m1,len(self.materials_dict))
m2 = onehot(m2,len(self.materials_dict))
mat = onehot(self.materials_dict[n],len(self.materials_dict))
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]],len(self.tools_dict))
step = np.concatenate((mat,m1,m2,mine,t))
self.player2_plan.append(step)
# print(self.global_plan_mat.reshape(-1))
# print(self.player1_plan_mat.reshape(-1))
# print(self.player2_plan_mat.reshape(-1))
# for x in zip(self.global_plan_mat.reshape(-1),self.player1_plan_mat.reshape(-1),self.player2_plan_mat.reshape(-1)):
# if sum(x) > 0:
# print(x)
# exit()
if self.load_player1:
self.plan_repr = self.player1_plan_mat
self.partner_plan = self.player2_plan_mat
elif self.load_player2:
self.plan_repr = self.player2_plan_mat
self.partner_plan = self.player1_plan_mat
else:
self.plan_repr = self.global_plan_mat
self.partner_plan = self.global_plan_mat
self.global_diff_plan_mat = self.global_plan_mat - self.plan_repr
self.partner_diff_plan_mat = self.global_plan_mat - self.partner_plan
self.__iter_ts = self.start_ts
# self.action_labels = sorted([t for a in self.actions for t in a if t.PacketData in ['BlockChangeData']], key=lambda x: x.TickIndex)
self.action_labels = None
# for tick in ticks:
# print(int(tick.TickIndex/30), self.plan['materials'].index( tick.items[0]), int(tick.Name[-1]))
# print(self.start_ts, self.end_ts, self.start_ts - self.end_ts, int(ticks[-1].TickIndex/30) if ticks else 0,self.action_file)
# exit()
self.materials = sorted(self.plan['materials'])
def __len__(self):
return self.end_ts - self.start_ts
def __next__(self):
if self.__iter_ts < self.end_ts:
if self.load_dialogue:
d = [x for x in self.dialogue_events if x[0] == self.__iter_ts]
l = [x for x in self.dialogue_act_labels if x[0] == self.__iter_ts]
d = d if d else None
l = l if l else None
else:
d = None
l = None
if self.use_dialogue_moves:
m = [x for x in self.dialogue_move_labels if x[0] == self.__iter_ts]
m = m if m else None
else:
m = None
if self.action_labels:
a = [x for x in self.action_labels if (x.TickIndex//30 + self.start_ts) >= self.__iter_ts]
if a:
try:
while not a[0].items:
a = a[1:]
al = self.materials.index(a[0].items[0]) if a else 0
except Exception:
print(a)
print(a[0])
print(a[0].items)
print(a[0].items[0])
exit()
at = a[0].TickIndex//30 + self.start_ts
an = int(a[0].Name[-1])
a = [(at,al,an)]
else:
a = [(self.__iter_ts, self.materials.index(self.plan['materials'][0]), 1)]
a = None
else:
if self.end_ts - self.__iter_ts < 10:
# a = [(self.__iter_ts, self.materials.index(self.plan['materials'][0]), 1)]
a = None
else:
a = None
# if not self.__iter_ts % 30 == 0:
# a= None
if not a is None:
if not a[0][0] == self.__iter_ts:
a = None
# q = [x for x in self.question_pairs if (x[0][0] < self.__iter_ts) and (x[0][1] > self.__iter_ts)]
q = [x for x in self.question_pairs if (x[0][1] == self.__iter_ts)]
q = q[0] if q else None
frame_idx = self.__iter_ts - self.start_ts
if self.load_third_person:
frames = self.third_pers_frames
elif self.load_player1:
frames = self.player1_pov_frames
elif self.load_player2:
frames = self.player2_pov_frames
else:
frames = np.array([0])
if len(frames) == 1:
f = np.zeros((self.img_h,self.img_w,3))
else:
if frame_idx < frames.shape[0]:
f = frames[frame_idx]
else:
f = np.zeros((self.img_h,self.img_w,3))
if self.do_upperbound:
if not q is None:
qnum = 0
base_rep = np.concatenate([
onehot(q[0][2],2),
onehot(q[0][3],2),
onehot(q[0][4][qnum][0]+1,2),
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
onehot(q[0][4][qnum][0]+1,2),
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
onehot(['YES','MAYBE','NO'].index(q[1][0][qnum])+1,3),
onehot(['YES','MAYBE','NO'].index(q[1][1][qnum])+1,3)
])
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
ToM6 = base_rep if self.ToM6 is not None else np.zeros(1024)
qnum = 1
base_rep = np.concatenate([
onehot(q[0][2],2),
onehot(q[0][3],2),
onehot(q[0][4][qnum][0]+1,2),
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
onehot(q[0][4][qnum][0]+1,2),
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
onehot(['YES','MAYBE','NO'].index(q[1][0][qnum])+1,3),
onehot(['YES','MAYBE','NO'].index(q[1][1][qnum])+1,3)
])
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
ToM7 = base_rep if self.ToM7 is not None else np.zeros(1024)
qnum = 2
base_rep = np.concatenate([
onehot(q[0][2],2),
onehot(q[0][3],2),
onehot(q[0][4][qnum]+1,2),
onehot(q[0][4][qnum]+1,2),
onehot(self.materials_dict[q[1][0][qnum]] if q[1][0][qnum] in self.materials_dict else len(self.materials_dict)+1,len(self.materials_dict)+1),
onehot(self.materials_dict[q[1][1][qnum]] if q[1][1][qnum] in self.materials_dict else len(self.materials_dict)+1,len(self.materials_dict)+1)
])
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
ToM8 = base_rep if self.ToM8 is not None else np.zeros(1024)
else:
ToM6 = np.zeros(1024)
ToM7 = np.zeros(1024)
ToM8 = np.zeros(1024)
if not l is None:
base_rep = np.concatenate([
onehot(l[0][1],2),
onehot(l[0][2],len(self.dialogue_act_labels_dict))
])
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
DAct = base_rep if self.DAct is not None else np.zeros(1024)
else:
DAct = np.zeros(1024)
if not m is None:
base_rep = np.concatenate([
onehot(m[0][1],2),
onehot(m[0][2][0],len(self.dialogue_move_labels_dict)),
onehot(m[0][2][1],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
onehot(m[0][2][2],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
onehot(m[0][2][3],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
])
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
DMove = base_rep if self.DMove is not None else np.zeros(1024)
else:
DMove = np.zeros(1024)
else:
ToM6 = self.ToM6[frame_idx] if self.ToM6 is not None else np.zeros(1024)
ToM7 = self.ToM7[frame_idx] if self.ToM7 is not None else np.zeros(1024)
ToM8 = self.ToM8[frame_idx] if self.ToM8 is not None else np.zeros(1024)
DAct = self.DAct[frame_idx] if self.DAct is not None else np.zeros(1024)
DMove = self.DAct[frame_idx] if self.DMove is not None else np.zeros(1024)
# if not m is None:
# base_rep = np.concatenate([
# onehot(m[0][1],2),
# onehot(m[0][2][0],len(self.dialogue_move_labels_dict)),
# onehot(m[0][2][1],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
# onehot(m[0][2][2],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
# onehot(m[0][2][3],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
# ])
# base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
# DMove = base_rep if self.DMove is not None else np.zeros(1024)
# else:
# DMove = np.zeros(1024)
intermediate = np.concatenate([ToM6,ToM7,ToM8,DAct,DMove])
retval = ((self.__iter_ts,self.pov),d,l,q,f,a,intermediate,m)
self.__iter_ts += 1
return retval
self.__iter_ts = self.start_ts
raise StopIteration()
def __iter__(self):
return self
def __load_videos(self):
d = self.end_ts - self.start_ts
if self.load_third_person:
try:
self.third_pers_file = glob(os.path.join(self.game_path,'third*gif'))[0]
np_file = self.third_pers_file[:-3]+'npz'
if os.path.isfile(np_file):
self.third_pers_frames = np.load(np_file)['data']
else:
frames = imageio.get_reader(self.third_pers_file, '.gif')
reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
if 'main' in self.game_path:
self.third_pers_frames = np.array([reshaper(f[95:4*95,250:-249,2::-1]) for f in frames])
else:
self.third_pers_frames = np.array([reshaper(f[-3*95:,250:-249,2::-1]) for f in frames])
print(np_file,end=' ')
np.savez_compressed(open(np_file,'wb'), data=self.third_pers_frames)
print('saved')
except Exception as e:
self.third_pers_frames = np.array([0])
if self.third_pers_frames.shape[0]//d < 10:
self.third_pov_frame_rate = 6
else:
if self.third_pers_frames.shape[0]//d < 20:
self.third_pov_frame_rate = 12
else:
if self.third_pers_frames.shape[0]//d < 45:
self.third_pov_frame_rate = 30
else:
self.third_pov_frame_rate = 60
self.third_pers_frames = self.third_pers_frames[::self.third_pov_frame_rate]
else:
self.third_pers_frames = np.array([0])
if self.load_player1:
try:
search_str = 'play2*gif' if self.flip_video else 'play1*gif'
self.player1_pov_file = glob(os.path.join(self.game_path,search_str))[0]
np_file = self.player1_pov_file[:-3]+'npz'
if os.path.isfile(np_file):
self.player1_pov_frames = np.load(np_file)['data']
else:
frames = imageio.get_reader(self.player1_pov_file, '.gif')
reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
self.player1_pov_frames = np.array([reshaper(f[:,:,2::-1]) for f in frames])
print(np_file,end=' ')
np.savez_compressed(open(np_file,'wb'), data=self.player1_pov_frames)
print('saved')
except Exception as e:
self.player1_pov_frames = np.array([0])
if self.player1_pov_frames.shape[0]//d < 10:
self.player1_pov_frame_rate = 6
else:
if self.player1_pov_frames.shape[0]//d < 20:
self.player1_pov_frame_rate = 12
else:
if self.player1_pov_frames.shape[0]//d < 45:
self.player1_pov_frame_rate = 30
else:
self.player1_pov_frame_rate = 60
self.player1_pov_frames = self.player1_pov_frames[::self.player1_pov_frame_rate]
else:
self.player1_pov_frames = np.array([0])
if self.load_player2:
try:
search_str = 'play1*gif' if self.flip_video else 'play2*gif'
self.player2_pov_file = glob(os.path.join(self.game_path,search_str))[0]
np_file = self.player2_pov_file[:-3]+'npz'
if os.path.isfile(np_file):
self.player2_pov_frames = np.load(np_file)['data']
else:
frames = imageio.get_reader(self.player2_pov_file, '.gif')
reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
self.player2_pov_frames = np.array([reshaper(f[:,:,2::-1]) for f in frames])
print(np_file,end=' ')
np.savez_compressed(open(np_file,'wb'), data=self.player2_pov_frames)
print('saved')
except Exception as e:
self.player2_pov_frames = np.array([0])
if self.player2_pov_frames.shape[0]//d < 10:
self.player2_pov_frame_rate = 6
else:
if self.player2_pov_frames.shape[0]//d < 20:
self.player2_pov_frame_rate = 12
else:
if self.player2_pov_frames.shape[0]//d < 45:
self.player2_pov_frame_rate = 30
else:
self.player2_pov_frame_rate = 60
self.player2_pov_frames = self.player2_pov_frames[::self.player2_pov_frame_rate]
else:
self.player2_pov_frames = np.array([0])
def __parse_question_pairs(self):
question_dict = {}
for q in self.questions:
k = q[2][0][1] + q[2][1][1]
if not k in question_dict:
question_dict[k] = []
question_dict[k].append(q)
self.question_pairs = []
for k,v in question_dict.items():
if len(v) == 2:
if v[0][1]+v[1][1] == 3:
self.question_pairs.append(v)
else:
while len(v) > 1:
pair = []
pair.append(v.pop(0))
pair.append(v.pop(0))
while not pair[0][1]+pair[1][1] == 3:
if not v:
break
# print(game_path,pair)
pair.append(v.pop(0))
pair.pop(0)
if not v:
break
self.question_pairs.append(pair)
self.question_pairs = sorted(self.question_pairs, key=lambda x: x[0][0])
if self.load_player2 or self.pov==4:
self.question_pairs = [sorted(q, key=lambda x: x[1],reverse=True) for q in self.question_pairs]
else:
self.question_pairs = [sorted(q, key=lambda x: x[1]) for q in self.question_pairs]
self.question_pairs = [((a[0], b[0], a[1], b[1], a[2], b[2]), (a[3], b[3])) for a,b in self.question_pairs]
def __parse_dialogue(self):
self.dialogue_events = []
# if not self.load_dialogue:
# return
save_path = os.path.join(self.game_path,f'dialogue_{self.game_path.split("/")[-1]}.pkl')
# print(save_path)
# exit()
if os.path.isfile(save_path):
self.dialogue_events = pickle.load(open( save_path, "rb" ))
return
for x in open(self.dialogue_file):
if '[Async Chat Thread' in x:
ts = list(map(int,x.split(' [')[0].strip('[]').split(':')))
ts = 3600*ts[0] + 60*ts[1] + ts[2]
player, event = x.strip().split('/INFO]: []<sledmcc')[1].split('> ',1)
event = event.lower()
event = ''.join([x if x in string.ascii_lowercase else f' {x} ' for x in event]).strip()
event = event.replace(' ',' ').replace(' ',' ')
player = int(player)
if GameParser.tokenizer is None:
GameParser.tokenizer = BertTokenizer.from_pretrained('bert-large-uncased', do_lower_case=True)
if self.model is None:
GameParser.model = BertModel.from_pretrained('bert-large-uncased', output_hidden_states=True).to(DEVICE)
encoded_dict = GameParser.tokenizer.encode_plus(
event, # Sentence to encode.
add_special_tokens=True, # Add '[CLS]' and '[SEP]'
return_tensors='pt', # Return pytorch tensors.
)
token_ids = encoded_dict['input_ids'].to(DEVICE)
segment_ids = torch.ones(token_ids.size()).long().to(DEVICE)
GameParser.model.eval()
with torch.no_grad():
outputs = GameParser.model(input_ids=token_ids, token_type_ids=segment_ids)
outputs = outputs[1][0].cpu().data.numpy()
self.dialogue_events.append((ts,player,event,outputs))
pickle.dump(self.dialogue_events, open( save_path, "wb" ))
print(f'Saved to {save_path}',flush=True)
def __parse_questions(self):
self.questions = []
for x in open(self.questions_file):
if x[0] == '#':
ts, qs = x.strip().split(' Number of records inserted: 1 # player')
# print(ts,qs)
ts = list(map(int,ts.split(' ')[5].split(':')))
ts = 3600*ts[0] + 60*ts[1] + ts[2]
player = int(qs[0])
questions = qs[2:].split(';')
answers =[x[7:] for x in questions[3:]]
questions = [x[9:].split(' ') for x in questions[:3]]
questions[0] = (int(questions[0][0] == 'Have'), questions[0][-3])
questions[1] = (int(questions[1][2] == 'know'), questions[1][-1])
questions[2] = int(questions[2][1] == 'are')
self.questions.append((ts,player,questions,answers))
def __parse_start_end(self):
self.start_ts = [x.strip() for x in open(self.dialogue_file) if 'THEY ARE PLAYER' in x][1]
self.start_ts = list(map(int,self.start_ts.split('] [')[0][1:].split(':')))
self.start_ts = 3600*self.start_ts[0] + 60*self.start_ts[1] + self.start_ts[2]
try:
self.start_ts = max(self.start_ts, self.questions[0][0]-75)
except Exception as e:
pass
self.end_ts = [x.strip() for x in open(self.dialogue_file) if 'Stopping' in x]
if self.end_ts:
self.end_ts = self.end_ts[0]
self.end_ts = list(map(int,self.end_ts.split('] [')[0][1:].split(':')))
self.end_ts = 3600*self.end_ts[0] + 60*self.end_ts[1] + self.end_ts[2]
else:
self.end_ts = self.dialogue_events[-1][0]
try:
self.end_ts = max(self.end_ts, self.questions[-1][0]) + 1
except Exception as e:
pass
def __load_dialogue_act_labels(self):
file_name = 'config/dialogue_act_labels.json'
if not os.path.isfile(file_name):
files = sorted(glob('/home/*/MCC/*done.txt'))
dialogue_act_dict = {}
for file in files:
game_str = ''
for line in open(file):
line = line.strip()
if '_logs/' in line:
game_str = line
else:
if line:
line = line.split()
key = f'{game_str}#{line[0]}'
dialogue_act_dict[key] = line[-1]
json.dump(dialogue_act_dict,open(file_name,'w'), indent=4)
self.dialogue_act_dict = json.load(open(file_name))
self.dialogue_act_labels_dict = {l : i for i, l in enumerate(sorted(list(set(self.dialogue_act_dict.values()))))}
self.dialogue_act_bias = {l : sum([int(x==l) for x in self.dialogue_act_dict.values()]) for l in self.dialogue_act_labels_dict.keys()}
json.dump(self.dialogue_act_labels_dict,open('config/dialogue_act_label_names.json','w'), indent=4)
# print(self.dialogue_act_bias)
# print(self.dialogue_act_labels_dict)
# exit()
def __assign_dialogue_act_labels(self):
# log_file = glob('/'.join([self.game_path,'mcc*log']))[0][5:]
log_file = glob('/'.join([self.game_path,'mcc*log']))[0].split('mindcraft/')[1]
self.dialogue_act_labels = []
for emb in self.dialogue_events:
ts = emb[0]
h = ts//3600
m = (ts%3600)//60
s = ts%60
key = f'{log_file}#[{h:02d}:{m:02d}:{s:02d}]:{emb[1]}>'
self.dialogue_act_labels.append((emb[0],emb[1],self.dialogue_act_labels_dict[self.dialogue_act_dict[key]]))
def __load_dialogue_move_labels(self):
file_name = "config/dialogue_move_labels.json"
dialogue_move_dict = {}
if not os.path.isfile(file_name):
file_text = ''
dialogue_moves = set()
for line in open("XXX"):
line = line.strip()
if not line:
continue
if line[0] == '#':
continue
if line[0] == '[':
tag_text = glob(f'data/*/*/mcc_{file_text}.log')[0].split('/',1)[-1]
key = f'{tag_text}#{line.split()[0]}'
value = line.split()[-1].split('#')
if len(value) < 4:
value += ['IGNORE']*(4-len(value))
dialogue_moves.add(value[0])
value = '#'.join(value)
dialogue_move_dict[key] = value
# print(key,value)
# break
else:
file_text = line
# print(line)
dialogue_moves = sorted(list(dialogue_moves))
# print(dialogue_moves)
json.dump(dialogue_move_dict,open(file_name,'w'), indent=4)
self.dialogue_move_dict = json.load(open(file_name))
self.dialogue_move_labels_dict = {l : i for i, l in enumerate(sorted(list(set([lbl.split('#')[0] for lbl in self.dialogue_move_dict.values()]))))}
self.dialogue_move_bias = {l : sum([int(x==l) for x in self.dialogue_move_dict.values()]) for l in self.dialogue_move_labels_dict.keys()}
json.dump(self.dialogue_move_labels_dict,open('config/dialogue_move_label_names.json','w'), indent=4)
def __assign_dialogue_move_labels(self):
# log_file = glob('/'.join([self.game_path,'mcc*log']))[0][5:]
log_file = glob('/'.join([self.game_path,'mcc*log']))[0].split('mindcraft/')[1]
self.dialogue_move_labels = []
for emb in self.dialogue_events:
ts = emb[0]
h = ts//3600
m = (ts%3600)//60
s = ts%60
key = f'{log_file}#[{h:02d}:{m:02d}:{s:02d}]:{emb[1]}>'
move = self.dialogue_move_dict[key].split('#')
move[0] = self.dialogue_move_labels_dict[move[0]]
for i,m in enumerate(move[1:]):
if m == 'IGNORE':
move[i+1] = 0
elif m in self.materials_dict:
move[i+1] = self.materials_dict[m]
elif m in self.mines_dict:
move[i+1] = self.mines_dict[m] + len(self.materials_dict)
elif m in self.tools_dict:
move[i+1] = self.tools_dict[m] + len(self.materials_dict) + len(self.mines_dict)
else:
print(move)
exit()
# print(move,self.dialogue_move_dict[key],key)
# exit()
self.dialogue_move_labels.append((emb[0],emb[1],move))
def __load_replay_data(self):
# self.action_file = "data/ReplayData/ActionsData_mcc_" + self.game_path.split('/')[-1]
# with open(self.action_file) as f:
# data = ' '.join(x.strip() for x in f).split('action')
# # preface = data[0]
# self.actions = list(map(proc_action, data[1:]))
self.actions = None
def __load_intermediate(self):
if self.intermediate > 15:
self.do_upperbound = True
else:
self.do_upperbound = False
if self.pov in [1,2]:
self.ToM6 = np.load(glob(f'{self.game_path}/intermediate_baseline_ToM6*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
self.intermediate = self.intermediate // 2
self.ToM7 = np.load(glob(f'{self.game_path}/intermediate_baseline_ToM7*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
self.intermediate = self.intermediate // 2
self.ToM8 = np.load(glob(f'{self.game_path}/intermediate_baseline_ToM8*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
self.intermediate = self.intermediate // 2
self.DAct = np.load(glob(f'{self.game_path}/intermediate_baseline_DAct*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
self.intermediate = self.intermediate // 2
self.DMove = None
# print(self.ToM6)
# print(self.ToM7)
# print(self.ToM8)
# print(self.DAct)
else:
self.ToM6 = None
self.ToM7 = None
self.ToM8 = None
self.DAct = None
self.DMove = None
# exit()

View file

@ -0,0 +1,837 @@
from glob import glob
import os, string, json, pickle
import torch, random, numpy as np
from transformers import BertTokenizer, BertModel
import cv2
import imageio
import networkx as nx
from torch_geometric.utils.convert import from_networkx
import matplotlib.pyplot as plt
from torch_geometric.utils import degree
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# def set_seed(seed_idx):
# seed = 0
# random.seed(0)
# for _ in range(seed_idx):
# seed = random.random()
# random.seed(seed)
# torch.manual_seed(seed)
# print('Random seed set to', seed)
# return seed
def set_seed(seed):
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
print('Random seed set to', seed)
return seed
def make_splits(split_file = 'config/dataset_splits.json'):
if not os.path.isfile(split_file):
dirs = sorted(glob('data/saved_logs/*') + glob('data/main_logs/*'))
games = sorted(list(map(GameParser, dirs)), key=lambda x: len(x.question_pairs), reverse=True)
test = games[0::5]
val = games[1::5]
train = games[2::5]+games[3::5]+games[4::5]
dataset_splits = {'test' : [g.game_path for g in test], 'validation' : [g.game_path for g in val], 'training' : [g.game_path for g in train]}
json.dump(dataset_splits, open('config/dataset_splits_old.json','w'), indent=4)
dirs = sorted(glob('data/new_logs/*'))
games = sorted(list(map(GameParser, dirs)), key=lambda x: len(x.question_pairs), reverse=True)
test = games[0::5]
val = games[1::5]
train = games[2::5]+games[3::5]+games[4::5]
dataset_splits['test'] += [g.game_path for g in test]
dataset_splits['validation'] += [g.game_path for g in val]
dataset_splits['training'] += [g.game_path for g in train]
json.dump(dataset_splits, open('config/dataset_splits_new.json','w'), indent=4)
json.dump(dataset_splits, open('config/dataset_splits.json','w'), indent=4)
dataset_splits['test'] = dataset_splits['test'][:2]
dataset_splits['validation'] = dataset_splits['validation'][:2]
dataset_splits['training'] = dataset_splits['training'][:2]
json.dump(dataset_splits, open('config/dataset_splits_dev.json','w'), indent=4)
dataset_splits = json.load(open(split_file))
return dataset_splits
def onehot(x,n):
retval = np.zeros(n)
if x > 0:
retval[x-1] = 1
return retval
class GameParser:
tokenizer = None
model = None
def __init__(self, game_path, load_dialogue=True, pov=0, intermediate=0, use_dialogue_moves=False, load_int0_feats=False):
self.load_dialogue = load_dialogue
if pov not in (0,1,2,3,4):
print('Point of view must be in (0,1,2,3,4), but got ', pov)
exit()
self.pov = pov
self.use_dialogue_moves = use_dialogue_moves
self.load_player1 = pov==1
self.load_player2 = pov==2
self.load_third_person = pov==3
self.game_path = game_path
self.dialogue_file = glob(os.path.join(game_path,'mcc*log'))[0]
self.questions_file = glob(os.path.join(game_path,'web*log'))[0]
self.plan_file = glob(os.path.join(game_path,'plan*json'))[0]
self.plan = json.load(open(self.plan_file))
self.img_w = 96
self.img_h = 96
self.intermediate = intermediate
self.flip_video = False
for l in open(self.dialogue_file):
if 'HAS JOINED' in l:
player_name = l.strip().split()[1]
self.flip_video = player_name[-1] == '2'
break
if not os.path.isfile("config/materials.json") or \
not os.path.isfile("config/mines.json") or \
not os.path.isfile("config/tools.json"):
plan_files = sorted(glob('data/*_logs/*/plan*.json'))
materials = []
tools = []
mines = []
for plan_file in plan_files:
plan = json.load(open(plan_file))
materials += plan['materials']
tools += plan['tools']
mines += plan['mines']
materials = sorted(list(set(materials)))
tools = sorted(list(set(tools)))
mines = sorted(list(set(mines)))
json.dump(materials, open('config/materials.json','w'), indent=4)
json.dump(mines, open('config/mines.json','w'), indent=4)
json.dump(tools, open('config/tools.json','w'), indent=4)
materials = json.load(open('config/materials.json'))
mines = json.load(open('config/mines.json'))
tools = json.load(open('config/tools.json'))
self.materials_dict = {x:i+1 for i,x in enumerate(materials)}
self.mines_dict = {x:i+1 for i,x in enumerate(mines)}
self.tools_dict = {x:i+1 for i,x in enumerate(tools)}
# NOTE new
shift_value = max(self.materials_dict.values())
self.materials_mines_dict = {**self.materials_dict, **{key: value + shift_value for key, value in self.mines_dict.items()}}
self.inverse_materials_mines_dict = {v: k for k, v in self.materials_mines_dict.items()}
# 
self.__load_dialogue_act_labels()
self.__load_dialogue_move_labels()
self.__parse_dialogue()
self.__parse_questions()
self.__parse_start_end()
self.__parse_question_pairs()
self.__load_videos()
self.__assign_dialogue_act_labels()
self.__assign_dialogue_move_labels()
self.__load_replay_data()
self.__load_intermediate()
self.load_int0 = load_int0_feats
if load_int0_feats:
self.__load_int0_feats()
#############################################
################## GRAPHS ###################
#############################################
################ Global Plan ################
self.global_plan = nx.DiGraph()
mine_counter = 0
for n, v in zip(self.plan['materials'], self.plan['full']):
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
self.global_plan = self._add_node(self.global_plan, n, features=mat)
if v['make']:
#print(n, v, self.plan['materials'][v['make'][0][0]], self.plan['materials'][v['make'][0][1]])
mine = 0
m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
m1 = onehot(m1, len(self.materials_mines_dict))
m2 = onehot(m2, len(self.materials_mines_dict))
self.global_plan = self._add_node(self.global_plan, self.plan['materials'][v['make'][0][0]], features=m1)
self.global_plan = self._add_node(self.global_plan, self.plan['materials'][v['make'][0][1]], features=m2)
# m1 -> mat
self.global_plan = self._add_edge(self.global_plan, self.plan['materials'][v['make'][0][0]], n, tool=t)
# m2 -> mat
self.global_plan = self._add_edge(self.global_plan, self.plan['materials'][v['make'][0][1]], n, tool=t)
else:
#print(n, v, self.plan['mines'][mine_counter])
mine = self.materials_mines_dict[self.plan['mines'][mine_counter]]
mine_counter += 1
mine = onehot(mine, len(self.materials_mines_dict))
self.global_plan = self._add_node(self.global_plan, self.plan['mines'][mine_counter], features=mine)
self.global_plan = self._add_edge(self.global_plan, self.plan['mines'][mine_counter], n, tool=t)
#self._plot_plan_graph(self.global_plan, filename=f"plots/global_{game_path.split('/')[-2]}_{game_path.split('/')[-1]}.png")
self.global_plan = from_networkx(self.global_plan) # NOTE: I modified /torch_geometric/utils/convert.py, line 250
#############################################
#############################################
############### Player 1 Plan ###############
self.player1_plan = nx.DiGraph()
mine_counter = 0
for n,v in zip(self.plan['materials'], self.plan['player1']):
if v['make']:
mine = 0
if v['make'][0][0] < 0:
#print(n, v, "unknown", "unknown")
m1 = 0
m2 = 0
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
self.player1_plan = self._add_node(self.player1_plan, n, features=mat)
else:
#print(n, v, self.plan['materials'][v['make'][0][0]], self.plan['materials'][v['make'][0][1]])
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
self.player1_plan = self._add_node(self.player1_plan, n, features=mat)
m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
m1 = onehot(m1, len(self.materials_mines_dict))
m2 = onehot(m2, len(self.materials_mines_dict))
self.player1_plan = self._add_node(self.player1_plan, self.plan['materials'][v['make'][0][0]], features=m1)
self.player1_plan = self._add_node(self.player1_plan, self.plan['materials'][v['make'][0][1]], features=m2)
# m1 -> mat
self.player1_plan = self._add_edge(self.player1_plan, self.plan['materials'][v['make'][0][0]], n, tool=t)
# m2 -> mat
self.player1_plan = self._add_edge(self.player1_plan, self.plan['materials'][v['make'][0][1]], n, tool=t)
else:
#print(n, v, self.plan['mines'][mine_counter])
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
self.player1_plan = self._add_node(self.player1_plan, n, features=mat)
mine = self.materials_mines_dict[self.plan['mines'][mine_counter]]
mine_counter += 1
mine = onehot(mine, len(self.materials_mines_dict))
self.player1_plan = self._add_node(self.player1_plan, self.plan['mines'][mine_counter], features=mine)
self.player1_plan = self._add_edge(self.player1_plan, self.plan['mines'][mine_counter], n, tool=t)
#self._plot_plan_graph(self.player1_plan, filename=f"plots/player1_{game_path.split('/')[-2]}_{game_path.split('/')[-1]}.png")
self.player1_plan = from_networkx(self.player1_plan)
#############################################
#############################################
############### Player 2 Plan ###############
self.player2_plan = nx.DiGraph()
mine_counter = 0
for n,v in zip(self.plan['materials'], self.plan['player2']):
if v['make']:
mine = 0
if v['make'][0][0] < 0:
#print(n, v, "unknown", "unknown")
m1 = 0
m2 = 0
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
self.player2_plan = self._add_node(self.player2_plan, n, features=mat)
else:
#print(n, v, self.plan['materials'][v['make'][0][0]], self.plan['materials'][v['make'][0][1]])
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
self.player2_plan = self._add_node(self.player2_plan, n, features=mat)
m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
m1 = onehot(m1, len(self.materials_mines_dict))
m2 = onehot(m2, len(self.materials_mines_dict))
self.player2_plan = self._add_node(self.player2_plan, self.plan['materials'][v['make'][0][0]], features=m1)
self.player2_plan = self._add_node(self.player2_plan, self.plan['materials'][v['make'][0][1]], features=m2)
# m1 -> mat
self.player2_plan = self._add_edge(self.player2_plan, self.plan['materials'][v['make'][0][0]], n, tool=t)
# m2 -> mat
self.player2_plan = self._add_edge(self.player2_plan, self.plan['materials'][v['make'][0][1]], n, tool=t)
else:
#print(n, v, self.plan['mines'][mine_counter])
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
self.player2_plan = self._add_node(self.player2_plan, n, features=mat)
mine = self.materials_mines_dict[self.plan['mines'][mine_counter]]
mine_counter += 1
mine = onehot(mine, len(self.materials_mines_dict))
self.player2_plan = self._add_node(self.player2_plan, self.plan['mines'][mine_counter], features=mine)
self.player2_plan = self._add_edge(self.player2_plan, self.plan['mines'][mine_counter], n, tool=t)
#self._plot_plan_graph(self.player2_plan, filename=f"plots/player2_{game_path.split('/')[-2]}_{game_path.split('/')[-1]}.png")
self.player2_plan = from_networkx(self.player2_plan)
# with open('graphs.pkl', 'wb') as f:
# pickle.dump([self.global_plan, self.player1_plan, self.player2_plan], f)
# construct a dict mapping materials to node indexes for each graph
p1_dict = {self.inverse_materials_mines_dict[torch.argmax(features).item()+1]: node_index for node_index, features in enumerate(self.player1_plan.features)}
p2_dict = {self.inverse_materials_mines_dict[torch.argmax(features).item()+1]: node_index for node_index, features in enumerate(self.player2_plan.features)}
# candidate edge = (u,v)
# u is from nodes with no out degree, v is from nodes with no in degree
p1_u_candidates = [p1_dict[i] for i in self.find_nodes_with_less_than_four_out_degree(self.player1_plan)]
p1_v_candidates = [p1_dict[i] for i in self.find_nodes_with_no_in_degree(self.player1_plan)]
p2_u_candidates = [p2_dict[i] for i in self.find_nodes_with_less_than_four_out_degree(self.player2_plan)]
p2_v_candidates = [p2_dict[i] for i in self.find_nodes_with_no_in_degree(self.player2_plan)]
# convert candidates to indexes
p1_edge_candidates = torch.tensor([(start, end) for start in p1_u_candidates for end in p1_v_candidates])
p2_edge_candidates = torch.tensor([(start, end) for start in p2_u_candidates for end in p2_v_candidates])
# find missing edges
gl_edges = [[self.inverse_materials_mines_dict[torch.argmax(self.global_plan.features[edge[0]]).item()+1], self.inverse_materials_mines_dict[torch.argmax(self.global_plan.features[edge[1]]).item()+1]] for edge in self.global_plan.edge_index.t().tolist()]
p1_edges = [[self.inverse_materials_mines_dict[torch.argmax(self.player1_plan.features[edge[0]]).item()+1], self.inverse_materials_mines_dict[torch.argmax(self.player1_plan.features[edge[1]]).item()+1]] for edge in self.player1_plan.edge_index.t().tolist()]
p2_edges = [[self.inverse_materials_mines_dict[torch.argmax(self.player2_plan.features[edge[0]]).item()+1], self.inverse_materials_mines_dict[torch.argmax(self.player2_plan.features[edge[1]]).item()+1]] for edge in self.player2_plan.edge_index.t().tolist()]
p1_missing_edges = [list(sublist) for sublist in set(map(tuple, gl_edges)) - set(map(tuple, p1_edges))]
p2_missing_edges = [list(sublist) for sublist in set(map(tuple, gl_edges)) - set(map(tuple, p2_edges))]
# convert missing edges as indexes
p1_missing_edges_idx = torch.tensor([(p1_dict[e[0]], p1_dict[e[1]]) for e in p1_missing_edges])
p2_missing_edges_idx = torch.tensor([(p2_dict[e[0]], p2_dict[e[1]]) for e in p2_missing_edges])
# check if all missing edges are present in the candidates
assert all(any(torch.equal(element, row) for row in p1_edge_candidates) for element in p1_missing_edges_idx)
assert all(any(torch.equal(element, row) for row in p2_edge_candidates) for element in p2_missing_edges_idx)
# concat candidates to plan graph
if p1_edge_candidates.numel() != 0:
self.player1_edge_label_index = torch.cat([self.player1_plan.edge_index, p1_edge_candidates.permute(1, 0)], dim=-1)
# create labels
self.player1_edge_label_own_missing_knowledge = torch.cat((torch.ones(self.player1_plan.edge_index.shape[1]), torch.zeros(p1_edge_candidates.shape[0])))
else:
# no missing knowledge
self.player1_edge_label_index = self.player1_plan.edge_index
# create labels
self.player1_edge_label_own_missing_knowledge = torch.ones(self.player1_plan.edge_index.shape[1])
if p2_edge_candidates.numel() != 0:
self.player2_edge_label_index = torch.cat([self.player2_plan.edge_index, p2_edge_candidates.permute(1, 0)], dim=-1)
# create labels
self.player2_edge_label_own_missing_knowledge = torch.cat((torch.ones(self.player2_plan.edge_index.shape[1]), torch.zeros(p2_edge_candidates.shape[0])))
else:
# no missing knowledge
self.player2_edge_label_index = self.player2_plan.edge_index
# create labels
self.player2_edge_label_own_missing_knowledge = torch.ones(self.player2_plan.edge_index.shape[1])
p1_edge_list = [tuple(x) for x in self.player1_edge_label_index.T.tolist()]
p1_missing_edges_idx_list = [tuple(x) for x in p1_missing_edges_idx.tolist()]
self.player1_edge_label_own_missing_knowledge[[p1_edge_list.index(x) for x in p1_missing_edges_idx_list]] = 1.
p2_edge_list = [tuple(x) for x in self.player2_edge_label_index.T.tolist()]
p2_missing_edges_idx_list = [tuple(x) for x in p2_missing_edges_idx.tolist()]
self.player2_edge_label_own_missing_knowledge[[p2_edge_list.index(x) for x in p2_missing_edges_idx_list]] = 1.
# compute other's missing knowledge == identify which one of my edges is unknown to the other player
p1_original_edges_list = [tuple(x) for x in self.player1_plan.edge_index.T.tolist()]
p2_original_edges_list = [tuple(x) for x in self.player2_plan.edge_index.T.tolist()]
p1_other_missing_edges_idx = [(p1_dict[e[0]], p1_dict[e[1]]) for e in p2_missing_edges] # note here is p2_missing_edges
p2_other_missing_edges_idx = [(p2_dict[e[0]], p2_dict[e[1]]) for e in p1_missing_edges] # note here is p1_missing_edges
self.player1_edge_label_other_missing_knowledge = torch.zeros(self.player1_plan.edge_index.shape[1])
self.player1_edge_label_other_missing_knowledge[[p1_original_edges_list.index(x) for x in p1_other_missing_edges_idx]] = 1.
self.player2_edge_label_other_missing_knowledge = torch.zeros(self.player2_plan.edge_index.shape[1])
self.player2_edge_label_other_missing_knowledge[[p2_original_edges_list.index(x) for x in p2_other_missing_edges_idx]] = 1.
self.__iter_ts = self.start_ts
self.action_labels = None
self.materials = sorted(self.plan['materials'])
def _add_node(self, g, material, features):
if material not in g.nodes:
#print(f'Add node {material}')
g.add_node(material, features=features)
return g
def _add_edge(self, g, u, v, tool):
if not g.has_edge(u, v):
#print(f'Add edge ({u}, {v})')
g.add_edge(u, v, tool=tool)
return g
def _plot_plan_graph(self, g, filename):
plt.figure(figsize=(20,20))
pos = nx.spring_layout(g, seed=42)
nx.draw(g, pos, with_labels=True, node_color='lightblue', edge_color='gray')
plt.savefig(filename)
plt.close()
def find_nodes_with_less_than_four_out_degree(self, data):
edge_index = data.edge_index
num_nodes = data.num_nodes
degrees = degree(edge_index[0], num_nodes) # out degrees
# find all nodes that have out degree less than 2
nodes = torch.nonzero(degrees < 4).view(-1)
nodes = [self.inverse_materials_mines_dict[torch.argmax(data.features[i]).item()+1] for i in nodes]
# remove planks (bc all planks have out degree less than 2)
nodes = [n for n in nodes if n.split('_')[-1] != 'PLANKS']
# now check for planks with out degree 0
check_zero_out_degree_planks = torch.nonzero(degrees < 1).view(-1)
check_zero_out_degree_planks = [self.inverse_materials_mines_dict[torch.argmax(data.features[i]).item()+1] for i in check_zero_out_degree_planks]
check_zero_out_degree_planks = [n for n in nodes if n.split('_')[-1] == 'PLANKS']
nodes = nodes + check_zero_out_degree_planks
return nodes
def find_nodes_with_no_in_degree(self, data):
edge_index = data.edge_index
num_nodes = data.num_nodes
degrees = degree(edge_index[1], num_nodes) # in degrees
nodes = torch.nonzero(degrees < 1).view(-1)
nodes = [self.inverse_materials_mines_dict[torch.argmax(data.features[i]).item()+1] for i in nodes]
nodes = [n for n in nodes if n.split('_')[-1] != 'PLANKS']
return nodes
def __len__(self):
return self.end_ts - self.start_ts
def __next__(self):
if self.__iter_ts < self.end_ts:
if self.load_dialogue:
d = [x for x in self.dialogue_events if x[0] == self.__iter_ts]
l = [x for x in self.dialogue_act_labels if x[0] == self.__iter_ts]
d = d if d else None
l = l if l else None
else:
d = None
l = None
if self.use_dialogue_moves:
m = [x for x in self.dialogue_move_labels if x[0] == self.__iter_ts]
m = m if m else None
else:
m = None
if self.action_labels:
a = [x for x in self.action_labels if (x.TickIndex//30 + self.start_ts) >= self.__iter_ts]
if a:
try:
while not a[0].items:
a = a[1:]
al = self.materials.index(a[0].items[0]) if a else 0
except Exception:
print(a)
print(a[0])
print(a[0].items)
print(a[0].items[0])
exit()
at = a[0].TickIndex//30 + self.start_ts
an = int(a[0].Name[-1])
a = [(at,al,an)]
else:
a = [(self.__iter_ts, self.materials.index(self.plan['materials'][0]), 1)]
a = None
else:
if self.end_ts - self.__iter_ts < 10:
a = None
else:
a = None
if not a is None:
if not a[0][0] == self.__iter_ts:
a = None
q = [x for x in self.question_pairs if (x[0][1] == self.__iter_ts)]
q = q[0] if q else None
frame_idx = self.__iter_ts - self.start_ts
if self.load_third_person:
frames = self.third_pers_frames
elif self.load_player1:
frames = self.player1_pov_frames
elif self.load_player2:
frames = self.player2_pov_frames
else:
frames = np.array([0])
if len(frames) == 1:
f = np.zeros((self.img_h,self.img_w,3))
else:
if frame_idx < frames.shape[0]:
f = frames[frame_idx]
else:
f = np.zeros((self.img_h,self.img_w,3))
if self.do_upperbound:
if not q is None:
qnum = 0
base_rep = np.concatenate([
onehot(q[0][2],2),
onehot(q[0][3],2),
onehot(q[0][4][qnum][0]+1,2),
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
onehot(q[0][4][qnum][0]+1,2),
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
onehot(['YES','MAYBE','NO'].index(q[1][0][qnum])+1,3),
onehot(['YES','MAYBE','NO'].index(q[1][1][qnum])+1,3)
])
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
ToM6 = base_rep if self.ToM6 is not None else np.zeros(1024)
qnum = 1
base_rep = np.concatenate([
onehot(q[0][2],2),
onehot(q[0][3],2),
onehot(q[0][4][qnum][0]+1,2),
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
onehot(q[0][4][qnum][0]+1,2),
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
onehot(['YES','MAYBE','NO'].index(q[1][0][qnum])+1,3),
onehot(['YES','MAYBE','NO'].index(q[1][1][qnum])+1,3)
])
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
ToM7 = base_rep if self.ToM7 is not None else np.zeros(1024)
qnum = 2
base_rep = np.concatenate([
onehot(q[0][2],2),
onehot(q[0][3],2),
onehot(q[0][4][qnum]+1,2),
onehot(q[0][4][qnum]+1,2),
onehot(self.materials_dict[q[1][0][qnum]] if q[1][0][qnum] in self.materials_dict else len(self.materials_dict)+1,len(self.materials_dict)+1),
onehot(self.materials_dict[q[1][1][qnum]] if q[1][1][qnum] in self.materials_dict else len(self.materials_dict)+1,len(self.materials_dict)+1)
])
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
ToM8 = base_rep if self.ToM8 is not None else np.zeros(1024)
else:
ToM6 = np.zeros(1024)
ToM7 = np.zeros(1024)
ToM8 = np.zeros(1024)
if not l is None:
base_rep = np.concatenate([
onehot(l[0][1],2),
onehot(l[0][2],len(self.dialogue_act_labels_dict))
])
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
DAct = base_rep if self.DAct is not None else np.zeros(1024)
else:
DAct = np.zeros(1024)
if not m is None:
base_rep = np.concatenate([
onehot(m[0][1],2),
onehot(m[0][2][0],len(self.dialogue_move_labels_dict)),
onehot(m[0][2][1],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
onehot(m[0][2][2],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
onehot(m[0][2][3],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
])
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
DMove = base_rep if self.DMove is not None else np.zeros(1024)
else:
DMove = np.zeros(1024)
else:
ToM6 = self.ToM6[frame_idx] if self.ToM6 is not None else np.zeros(1024)
ToM7 = self.ToM7[frame_idx] if self.ToM7 is not None else np.zeros(1024)
ToM8 = self.ToM8[frame_idx] if self.ToM8 is not None else np.zeros(1024)
DAct = self.DAct[frame_idx] if self.DAct is not None else np.zeros(1024)
DMove = self.DAct[frame_idx] if self.DMove is not None else np.zeros(1024)
intermediate = np.concatenate([ToM6,ToM7,ToM8,DAct,DMove])
if self.load_int0:
intermediate = np.zeros(1024*5)
intermediate[:1024] = self.int0_exp2_feats[frame_idx]
retval = ((self.__iter_ts,self.pov),d,l,q,f,a,intermediate,m)
self.__iter_ts += 1
return retval
self.__iter_ts = self.start_ts
raise StopIteration()
def __iter__(self):
return self
def __load_videos(self):
d = self.end_ts - self.start_ts
if self.load_third_person:
try:
self.third_pers_file = glob(os.path.join(self.game_path,'third*gif'))[0]
np_file = self.third_pers_file[:-3]+'npz'
if os.path.isfile(np_file):
self.third_pers_frames = np.load(np_file)['data']
else:
frames = imageio.get_reader(self.third_pers_file, '.gif')
reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
if 'main' in self.game_path:
self.third_pers_frames = np.array([reshaper(f[95:4*95,250:-249,2::-1]) for f in frames])
else:
self.third_pers_frames = np.array([reshaper(f[-3*95:,250:-249,2::-1]) for f in frames])
print(np_file,end=' ')
np.savez_compressed(open(np_file,'wb'), data=self.third_pers_frames)
print('saved')
except Exception as e:
self.third_pers_frames = np.array([0])
if self.third_pers_frames.shape[0]//d < 10:
self.third_pov_frame_rate = 6
else:
if self.third_pers_frames.shape[0]//d < 20:
self.third_pov_frame_rate = 12
else:
if self.third_pers_frames.shape[0]//d < 45:
self.third_pov_frame_rate = 30
else:
self.third_pov_frame_rate = 60
self.third_pers_frames = self.third_pers_frames[::self.third_pov_frame_rate]
else:
self.third_pers_frames = np.array([0])
if self.load_player1:
try:
search_str = 'play2*gif' if self.flip_video else 'play1*gif'
self.player1_pov_file = glob(os.path.join(self.game_path,search_str))[0]
np_file = self.player1_pov_file[:-3]+'npz'
if os.path.isfile(np_file):
self.player1_pov_frames = np.load(np_file)['data']
else:
frames = imageio.get_reader(self.player1_pov_file, '.gif')
reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
self.player1_pov_frames = np.array([reshaper(f[:,:,2::-1]) for f in frames])
print(np_file,end=' ')
np.savez_compressed(open(np_file,'wb'), data=self.player1_pov_frames)
print('saved')
except Exception as e:
self.player1_pov_frames = np.array([0])
if self.player1_pov_frames.shape[0]//d < 10:
self.player1_pov_frame_rate = 6
else:
if self.player1_pov_frames.shape[0]//d < 20:
self.player1_pov_frame_rate = 12
else:
if self.player1_pov_frames.shape[0]//d < 45:
self.player1_pov_frame_rate = 30
else:
self.player1_pov_frame_rate = 60
self.player1_pov_frames = self.player1_pov_frames[::self.player1_pov_frame_rate]
else:
self.player1_pov_frames = np.array([0])
if self.load_player2:
try:
search_str = 'play1*gif' if self.flip_video else 'play2*gif'
self.player2_pov_file = glob(os.path.join(self.game_path,search_str))[0]
np_file = self.player2_pov_file[:-3]+'npz'
if os.path.isfile(np_file):
self.player2_pov_frames = np.load(np_file)['data']
else:
frames = imageio.get_reader(self.player2_pov_file, '.gif')
reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
self.player2_pov_frames = np.array([reshaper(f[:,:,2::-1]) for f in frames])
print(np_file,end=' ')
np.savez_compressed(open(np_file,'wb'), data=self.player2_pov_frames)
print('saved')
except Exception as e:
self.player2_pov_frames = np.array([0])
if self.player2_pov_frames.shape[0]//d < 10:
self.player2_pov_frame_rate = 6
else:
if self.player2_pov_frames.shape[0]//d < 20:
self.player2_pov_frame_rate = 12
else:
if self.player2_pov_frames.shape[0]//d < 45:
self.player2_pov_frame_rate = 30
else:
self.player2_pov_frame_rate = 60
self.player2_pov_frames = self.player2_pov_frames[::self.player2_pov_frame_rate]
else:
self.player2_pov_frames = np.array([0])
def __parse_question_pairs(self):
question_dict = {}
for q in self.questions:
k = q[2][0][1] + q[2][1][1]
if not k in question_dict:
question_dict[k] = []
question_dict[k].append(q)
self.question_pairs = []
for k,v in question_dict.items():
if len(v) == 2:
if v[0][1]+v[1][1] == 3:
self.question_pairs.append(v)
else:
while len(v) > 1:
pair = []
pair.append(v.pop(0))
pair.append(v.pop(0))
while not pair[0][1]+pair[1][1] == 3:
if not v:
break
# print(game_path,pair)
pair.append(v.pop(0))
pair.pop(0)
if not v:
break
self.question_pairs.append(pair)
self.question_pairs = sorted(self.question_pairs, key=lambda x: x[0][0])
if self.load_player2 or self.pov==4:
self.question_pairs = [sorted(q, key=lambda x: x[1],reverse=True) for q in self.question_pairs]
else:
self.question_pairs = [sorted(q, key=lambda x: x[1]) for q in self.question_pairs]
self.question_pairs = [((a[0], b[0], a[1], b[1], a[2], b[2]), (a[3], b[3])) for a,b in self.question_pairs]
def __parse_dialogue(self):
self.dialogue_events = []
save_path = os.path.join(self.game_path,f'dialogue_{self.game_path.split("/")[-1]}.pkl')
if os.path.isfile(save_path):
self.dialogue_events = pickle.load(open( save_path, "rb" ))
return
for x in open(self.dialogue_file):
if '[Async Chat Thread' in x:
ts = list(map(int,x.split(' [')[0].strip('[]').split(':')))
ts = 3600*ts[0] + 60*ts[1] + ts[2]
player, event = x.strip().split('/INFO]: []<sledmcc')[1].split('> ',1)
event = event.lower()
event = ''.join([x if x in string.ascii_lowercase else f' {x} ' for x in event]).strip()
event = event.replace(' ',' ').replace(' ',' ')
player = int(player)
if GameParser.tokenizer is None:
GameParser.tokenizer = BertTokenizer.from_pretrained('bert-large-uncased', do_lower_case=True)
if self.model is None:
GameParser.model = BertModel.from_pretrained('bert-large-uncased', output_hidden_states=True).to(DEVICE)
encoded_dict = GameParser.tokenizer.encode_plus(
event, # Sentence to encode.
add_special_tokens=True, # Add '[CLS]' and '[SEP]'
return_tensors='pt', # Return pytorch tensors.
)
token_ids = encoded_dict['input_ids'].to(DEVICE)
segment_ids = torch.ones(token_ids.size()).long().to(DEVICE)
GameParser.model.eval()
with torch.no_grad():
outputs = GameParser.model(input_ids=token_ids, token_type_ids=segment_ids)
outputs = outputs[1][0].cpu().data.numpy()
self.dialogue_events.append((ts,player,event,outputs))
pickle.dump(self.dialogue_events, open( save_path, "wb" ))
print(f'Saved to {save_path}',flush=True)
def __parse_questions(self):
self.questions = []
for x in open(self.questions_file):
if x[0] == '#':
ts, qs = x.strip().split(' Number of records inserted: 1 # player')
ts = list(map(int,ts.split(' ')[5].split(':')))
ts = 3600*ts[0] + 60*ts[1] + ts[2]
player = int(qs[0])
questions = qs[2:].split(';')
answers =[x[7:] for x in questions[3:]]
questions = [x[9:].split(' ') for x in questions[:3]]
questions[0] = (int(questions[0][0] == 'Have'), questions[0][-3])
questions[1] = (int(questions[1][2] == 'know'), questions[1][-1])
questions[2] = int(questions[2][1] == 'are')
self.questions.append((ts,player,questions,answers))
def __parse_start_end(self):
self.start_ts = [x.strip() for x in open(self.dialogue_file) if 'THEY ARE PLAYER' in x][1]
self.start_ts = list(map(int,self.start_ts.split('] [')[0][1:].split(':')))
self.start_ts = 3600*self.start_ts[0] + 60*self.start_ts[1] + self.start_ts[2]
try:
self.start_ts = max(self.start_ts, self.questions[0][0]-75)
except Exception as e:
pass
self.end_ts = [x.strip() for x in open(self.dialogue_file) if 'Stopping' in x]
if self.end_ts:
self.end_ts = self.end_ts[0]
self.end_ts = list(map(int,self.end_ts.split('] [')[0][1:].split(':')))
self.end_ts = 3600*self.end_ts[0] + 60*self.end_ts[1] + self.end_ts[2]
else:
self.end_ts = self.dialogue_events[-1][0]
try:
self.end_ts = max(self.end_ts, self.questions[-1][0]) + 1
except Exception as e:
pass
def __load_dialogue_act_labels(self):
file_name = 'config/dialogue_act_labels.json'
if not os.path.isfile(file_name):
files = sorted(glob('/home/*/MCC/*done.txt'))
dialogue_act_dict = {}
for file in files:
game_str = ''
for line in open(file):
line = line.strip()
if '_logs/' in line:
game_str = line
else:
if line:
line = line.split()
key = f'{game_str}#{line[0]}'
dialogue_act_dict[key] = line[-1]
json.dump(dialogue_act_dict,open(file_name,'w'), indent=4)
self.dialogue_act_dict = json.load(open(file_name))
self.dialogue_act_labels_dict = {l : i for i, l in enumerate(sorted(list(set(self.dialogue_act_dict.values()))))}
self.dialogue_act_bias = {l : sum([int(x==l) for x in self.dialogue_act_dict.values()]) for l in self.dialogue_act_labels_dict.keys()}
json.dump(self.dialogue_act_labels_dict,open('config/dialogue_act_label_names.json','w'), indent=4)
def __assign_dialogue_act_labels(self):
log_file = glob('/'.join([self.game_path,'mcc*log']))[0].split('mindcraft/')[1]
self.dialogue_act_labels = []
for emb in self.dialogue_events:
ts = emb[0]
h = ts//3600
m = (ts%3600)//60
s = ts%60
key = f'{log_file}#[{h:02d}:{m:02d}:{s:02d}]:{emb[1]}>'
self.dialogue_act_labels.append((emb[0],emb[1],self.dialogue_act_labels_dict[self.dialogue_act_dict[key]]))
def __load_dialogue_move_labels(self):
file_name = "config/dialogue_move_labels.json"
dialogue_move_dict = {}
if not os.path.isfile(file_name):
file_text = ''
dialogue_moves = set()
for line in open("XXX"):
line = line.strip()
if not line:
continue
if line[0] == '#':
continue
if line[0] == '[':
tag_text = glob(f'data/*/*/mcc_{file_text}.log')[0].split('/',1)[-1]
key = f'{tag_text}#{line.split()[0]}'
value = line.split()[-1].split('#')
if len(value) < 4:
value += ['IGNORE']*(4-len(value))
dialogue_moves.add(value[0])
value = '#'.join(value)
dialogue_move_dict[key] = value
else:
file_text = line
dialogue_moves = sorted(list(dialogue_moves))
json.dump(dialogue_move_dict,open(file_name,'w'), indent=4)
self.dialogue_move_dict = json.load(open(file_name))
self.dialogue_move_labels_dict = {l : i for i, l in enumerate(sorted(list(set([lbl.split('#')[0] for lbl in self.dialogue_move_dict.values()]))))}
self.dialogue_move_bias = {l : sum([int(x==l) for x in self.dialogue_move_dict.values()]) for l in self.dialogue_move_labels_dict.keys()}
json.dump(self.dialogue_move_labels_dict,open('config/dialogue_move_label_names.json','w'), indent=4)
def __assign_dialogue_move_labels(self):
log_file = glob('/'.join([self.game_path,'mcc*log']))[0].split('mindcraft/')[1]
self.dialogue_move_labels = []
for emb in self.dialogue_events:
ts = emb[0]
h = ts//3600
m = (ts%3600)//60
s = ts%60
key = f'{log_file}#[{h:02d}:{m:02d}:{s:02d}]:{emb[1]}>'
move = self.dialogue_move_dict[key].split('#')
move[0] = self.dialogue_move_labels_dict[move[0]]
for i,m in enumerate(move[1:]):
if m == 'IGNORE':
move[i+1] = 0
elif m in self.materials_dict:
move[i+1] = self.materials_dict[m]
elif m in self.mines_dict:
move[i+1] = self.mines_dict[m] + len(self.materials_dict)
elif m in self.tools_dict:
move[i+1] = self.tools_dict[m] + len(self.materials_dict) + len(self.mines_dict)
else:
print(move)
exit()
self.dialogue_move_labels.append((emb[0],emb[1],move))
def __load_replay_data(self):
self.actions = None
def __load_intermediate(self):
if self.intermediate > 15:
self.do_upperbound = True
else:
self.do_upperbound = False
if self.pov in [1,2]:
self.ToM6 = np.load(glob(f'{self.game_path}/intermediate_ToM6*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
self.intermediate = self.intermediate // 2
self.ToM7 = np.load(glob(f'{self.game_path}/intermediate_ToM7*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
self.intermediate = self.intermediate // 2
self.ToM8 = np.load(glob(f'{self.game_path}/intermediate_ToM8*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
self.intermediate = self.intermediate // 2
self.DAct = np.load(glob(f'{self.game_path}/intermediate_DAct*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
self.intermediate = self.intermediate // 2
self.DMove = None
else:
self.ToM6 = None
self.ToM7 = None
self.ToM8 = None
self.DAct = None
self.DMove = None
def __load_int0_feats(self):
self.int0_exp2_feats = np.load(glob(f'{self.game_path}/int0_exp2*player{self.pov}.npz')[0])['data']
# self.int0_exp3_feats = np.load(np.load(glob(f'{self.game_path}/int0_exp3*player{self.pov}.npz')[0])['data'])

BIN
src/models/.DS_Store vendored Normal file

Binary file not shown.

0
src/models/__init__.py Normal file
View file

207
src/models/losses.py Normal file
View file

@ -0,0 +1,207 @@
import torch
from torch import nn
import numpy as np
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
def onehot(x,n):
retval = np.zeros(n)
if x > 0:
retval[x-1] = 1
return retval
class PlanLoss(nn.Module):
def __init__(self):
super(PlanLoss, self).__init__()
def getWeights(self, output, target):
# return 1
f1 = (1+5*torch.stack([2-torch.sum(target.reshape(-1,21,21),dim=-1)]*21,dim=-1)).reshape(-1,21*21)
f2 = 100*target + 1
return (f1+f2)/60
exit(0)
# print(max(torch.sum(target.reshape(21,21),dim=-1)))
return (target*torch.sum(target,dim=-1) + 1)
def MSELoss(self, output, target):
retval = (output - target)**2
retval *= self.getWeights(output,target)
return torch.mean(retval)
def BCELoss(self, output, target, loss_mask=None):
mask_factor = torch.ones(target.shape).to(output.device)
if loss_mask is not None:
loss_mask = loss_mask.reshape(-1,21,21)
mask_factor = mask_factor.reshape(-1,21,21)
# print(mask_factor.shape,loss_mask.shape,output.shape,target.shape)
for idx, tgt in enumerate(loss_mask):
for jdx, tgt_node in enumerate(tgt):
if sum(tgt_node) == 0:
mask_factor[idx,jdx] *= 0
# print(loss_mask[0].data.cpu().numpy())
# print(mask_factor[0].data.cpu().numpy())
# print()
# print(loss_mask[45].data.cpu().numpy())
# print(mask_factor[45].data.cpu().numpy())
# print()
# print(loss_mask[-1].data.cpu().numpy())
# print(mask_factor[-1].data.cpu().numpy())
# print()
loss_mask = loss_mask.reshape(-1,21*21)
mask_factor = mask_factor.reshape(-1,21*21)
# print(loss_mask.shape, target.shape, mask_factor.shape)
# exit()
factor = (10 if target.shape[-1]==441 else 1)# * torch.sum(target,dim=-1)+1
retval = -1 * mask_factor * (factor * target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
factor = torch.stack([torch.sum(target,dim=-1)+1]*target.shape[-1],dim=-1)
return torch.mean(factor*retval)
return torch.mean(retval)
def forward(self, output, target, loss_mask=None):
return self.BCELoss(output,target,loss_mask) + 0.01*torch.sum(output - 1/21)
# return self.MSELoss(output,target)
class DialogueActLoss(nn.Module):
def __init__(self):
super(DialogueActLoss, self).__init__()
self.bias = torch.tensor([289,51,45,57,14,12,1,113,6,264,27,63,22,66,2,761,129,163,5]).float()
self.bias = max(self.bias) - self.bias + 1
self.bias /= torch.sum(self.bias)
self.bias = 1-self.bias
# self.bias *= self.bias
def BCELoss(self, output, target):
target = torch.stack([torch.tensor(onehot(x + 1,19)).long() for x in target]).to(output.device)
retval = -1 * (target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
retval *= torch.stack([self.bias] * output.shape[0]).to(output.device)
# print(output)
# print(target)
# print(retval)
# print(torch.mean(retval))
# exit()
return torch.mean(retval)
def forward(self, output, target):
return self.BCELoss(output,target)
# return self.MSELoss(output,target)
class DialogueMoveLoss(nn.Module):
def __init__(self, device):
super(DialogueMoveLoss, self).__init__()
# self.bias = torch.tensor([289,51,45,57,14,12,1,113,6,264,27,63,22,66,2,761,129,163,5]).float()
# self.bias = max(self.bias) - self.bias + 1
# self.bias /= torch.sum(self.bias)
# self.bias = 1-self.bias
move_weights = torch.tensor(np.array([202, 34, 34, 48, 4, 2, 420, 10, 54, 1, 10, 11, 30, 28, 14, 2, 16, 6, 2, 86, 4, 12, 28, 2, 2, 16, 12, 14, 4, 1, 12, 258, 12, 26, 2])).float().to(device)
move_weights = 1+ max(move_weights) - move_weights
self.loss1 = CrossEntropyLoss(weight=move_weights)
zero_bias = 0.773
num_classes = 40
weight = torch.tensor(np.array([50 if not x else 1 for x in range(num_classes)])).float().to(device)
weight = 1+ max(weight) - weight
self.loss2 = CrossEntropyLoss(weight=weight)
# self.bias *= self.bias
def BCELoss(self, output, target,zero_bias):
# # print(output.shape,target.shape)
# bias = torch.tensor(np.array([1 if t else zero_bias for t in target])).to(output.device)
# target = torch.stack([torch.tensor(onehot(x,output.shape[-1])).long() for x in target]).to(output.device)
# # print(target.shape, bias.shape, bias)
# retval = -1 * (target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
# retval = torch.mean(retval,-1)
# # print(retval.shape)
# retval *= bias
# # retval *= torch.stack([self.bias] * output.shape[0]).to(output.device)
# # print(output)
# # print(target)
# # print(retval)
# # print(torch.mean(retval))
# # exit()
# # retval = self.loss(output,target)
# return torch.mean(retval) # retval #
# weight = [zero_bias if x else (1-zero_bias)/(output.shape[-1]-1) for x in range(output.shape[-1])]
retval = self.loss2(output,target) if zero_bias else self.loss1(output,target)
return retval #
def forward(self, output, target):
o1, o2, o3, o4 = output
t1, t2, t3, t4 = target
# print(t2,t2.shape, o2.shape)
# if sum(t2):
# o2, t2 = zip(*[(a,b) for a,b in zip(o2,t2) if b])
# o2 = torch.stack(o2)
# t2 = torch.stack(t2)
# if sum(t3):
# o3, t3 = zip(*[(a,b) for a,b in zip(o3,t3) if b])
# o3 = torch.stack(o3)
# t3 = torch.stack(t3)
# if sum(t4):
# o4, t4 = zip(*[(a,b) for a,b in zip(o4,t4) if b])
# o4 = torch.stack(o4)
# t4 = torch.stack(t4)
# print(t2,t2.shape, o2.shape)
# exit()
retval = sum([
1*self.BCELoss(output[0],target[0],0),
0*self.BCELoss(output[1],target[1],1),
0*self.BCELoss(output[2],target[2],1),
0*self.BCELoss(output[3],target[3],1)
])
return retval #sum([fact*self.BCELoss(o,t,zbias) for fact,zbias,o,t in zip([1,0,0,0],[0,1,1,1],output,target)])
# return self.MSELoss(output,target)
class DialoguePredLoss(nn.Module):
def __init__(self):
super(DialoguePredLoss, self).__init__()
self.bias = torch.tensor([289,51,45,57,14,12,1,113,6,264,27,63,22,66,2,761,129,163,5,0]).float()
self.bias[-1] = 1460#2 * torch.sum(self.bias) // 3
self.bias = max(self.bias) - self.bias + 1
self.bias /= torch.sum(self.bias)
self.bias = 1-self.bias
# self.bias *= self.bias
def BCELoss(self, output, target):
target = torch.stack([torch.tensor(onehot(x + 1,20)).long() for x in target]).to(output.device)
retval = -1 * (target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
retval *= torch.stack([self.bias] * output.shape[0]).to(output.device)
# print(output)
# print(target)
# print(retval)
# print(torch.mean(retval))
# exit()
return torch.mean(retval)
def forward(self, output, target):
return self.BCELoss(output,target)
# return self.MSELoss(output,target)
class ActionLoss(nn.Module):
def __init__(self):
super(ActionLoss, self).__init__()
self.bias1 = torch.tensor([134,1370,154,128,220,166,46,76,106,78,88,124,102,120,276,122,112,106,44,174,20]).float()
# self.bias[-1] = 1460#2 * torch.sum(self.bias) // 3
# self.bias1 = torch.ones(21).float()
self.bias1 = max(self.bias1) - self.bias1 + 1
self.bias1 /= torch.sum(self.bias1)
self.bias1 = 1-self.bias1
self.bias2 = torch.tensor([1168,1310]).float()
# self.bias2[-1] = 1460#2 * torch.sum(self.bias) // 3
# self.bias2 = torch.ones(21).float()
self.bias2 = max(self.bias2) - self.bias2 + 1
self.bias2 /= torch.sum(self.bias2)
self.bias2 = 1-self.bias2
# self.bias *= self.bias
def BCELoss(self, output, target):
# target = torch.stack([torch.tensor(onehot(x + 1,20)).long() for x in target]).to(output.device)
retval = -1 * (target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
# print(self.bias1.shape,self.bias2.shape,output.shape[-1])
retval *= torch.stack([self.bias2 if output.shape[-1]==2 else self.bias1] * output.shape[0]).to(output.device)
# print(output)
# print(target)
# print(retval)
# print(torch.mean(retval))
# exit()
return torch.mean(retval)
def forward(self, output, target):
return self.BCELoss(output,target)
# return self.MSELoss(output,target)

View file

@ -0,0 +1,205 @@
import sys, torch, random
from numpy.core.fromnumeric import reshape
import torch.nn as nn, numpy as np
from src.data.game_parser import DEVICE
def onehot(x,n):
retval = np.zeros(n)
if x > 0:
retval[x-1] = 1
return retval
class Model(nn.Module):
def __init__(self, seq_model_type=0,device=DEVICE):
super(Model, self).__init__()
self.device = device
print("model set to device", self.device)
my_rnn = lambda i,o: nn.GRU(i,o)
#my_rnn = lambda i,o: nn.LSTM(i,o)
plan_emb_in = 81
plan_emb_out = 32
q_emb = 100
self.plan_embedder0 = my_rnn(plan_emb_in,plan_emb_out)
self.plan_embedder1 = my_rnn(plan_emb_in,plan_emb_out)
self.plan_embedder2 = my_rnn(plan_emb_in,plan_emb_out)
# self.dialogue_listener = my_rnn(1126,768)
dlist_hidden = 1024
frame_emb = 512
self.move_emb = 157
drnn_in = 1024 + 2 + q_emb + frame_emb + self.move_emb
# drnn_in = 1024 + 2
# my_rnn = lambda i,o: nn.GRU(i,o)
my_rnn = lambda i,o: nn.LSTM(i,o)
if seq_model_type==0:
self.dialogue_listener_rnn = nn.GRU(drnn_in,dlist_hidden)
self.dialogue_listener = lambda x: \
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
elif seq_model_type==1:
self.dialogue_listener_rnn = nn.LSTM(drnn_in,dlist_hidden)
self.dialogue_listener = lambda x: \
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
elif seq_model_type==2:
mask_fun = lambda x: torch.triu(torch.ones(x.shape[0],x.shape[0]),diagonal=1).bool().to(self.device)
sincos_fun = lambda x:torch.transpose(torch.stack([
torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
]),0,1).reshape(-1,1,2)
self.dialogue_listener_lin1 = nn.Linear(drnn_in,dlist_hidden-2)
self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x,x,x,attn_mask=mask_fun(x))
self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
sincos_fun(x.shape[0]).float().to(self.device),
self.dialogue_listener_lin1(x).reshape(-1,1,dlist_hidden-2)
], axis=-1))[0]
else:
print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
exit()
conv_block = lambda i,o,k,p,s: nn.Sequential(
nn.Conv2d( i, o, k, padding=p, stride=s),
nn.BatchNorm2d(o),
nn.Dropout(0.5),
nn.MaxPool2d(2),
nn.BatchNorm2d(o),
nn.Dropout(0.5),
nn.ReLU()
)
self.conv = nn.Sequential(
conv_block( 3, 8, 3, 1, 1),
# conv_block( 3, 8, 5, 2, 2),
conv_block( 8, 32, 5, 2, 2),
conv_block( 32, frame_emb//4, 5, 2, 2),
nn.Conv2d( frame_emb//4, frame_emb, 3),nn.ReLU(),
)
qlayer = lambda i,o : nn.Sequential(
nn.Linear(i,512),
nn.Dropout(0.5),
nn.GELU(),
nn.Dropout(0.5),
nn.Linear(512,o),
nn.Dropout(0.5),
# nn.Softmax(-1)
)
q_in_size = 3*plan_emb_out+dlist_hidden+q_emb
self.q01 = qlayer(q_in_size,2)
self.q02 = qlayer(q_in_size,2)
self.q03 = qlayer(q_in_size,2)
self.q11 = qlayer(q_in_size,3)
self.q12 = qlayer(q_in_size,3)
self.q13 = qlayer(q_in_size,22)
self.q21 = qlayer(q_in_size,3)
self.q22 = qlayer(q_in_size,3)
self.q23 = qlayer(q_in_size,22)
def forward(self,game,global_plan=False, player_plan=False, intermediate=False):
retval = []
l = list(game)
_,d,_,q,f,_,_,m = zip(*list(game))
parse_move = lambda m: np.concatenate([
onehot(m[0][1], 2),
onehot(m[0][2][0]+1, len(game.dialogue_move_labels_dict)),
onehot(m[0][2][1]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1),
onehot(m[0][2][2]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1),
onehot(m[0][2][3]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1)
])
# print(2+len(game.dialogue_move_labels_dict)+3*(len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1))
# print(len(game.dialogue_move_labels_dict))
m = np.stack([np.zeros(self.move_emb) if move is None else parse_move(move) for move in m])
h = None
f = np.array(f, dtype=np.uint8)
# f = torch.tensor(f).permute(0,3,1,2).float().to(self.device)
# flt_lst = [(a,b) for a,b in zip(d,q) if (not a is None) or (not b is None)]
# if not flt_lst:
# return []
# d,q = zip(*flt_lst)
d = np.stack([np.concatenate(([int(x[0][1]==2),int(x[0][1]==1)],x[0][-1])) if not x is None else np.zeros(1026) for x in d])
def parse_q(q):
if not q is None:
q ,l = q
q = np.concatenate([
onehot(q[2],2),
onehot(q[3],2),
onehot(q[4][0][0]+1,2),
onehot(game.materials_dict[q[4][0][1]],len(game.materials_dict)),
onehot(q[4][1][0]+1,2),
onehot(game.materials_dict[q[4][1][1]],len(game.materials_dict)),
onehot(q[4][2]+1,2),
onehot(q[5][0][0]+1,2),
onehot(game.materials_dict[q[5][0][1]],len(game.materials_dict)),
onehot(q[5][1][0]+1,2),
onehot(game.materials_dict[q[5][1][1]],len(game.materials_dict)),
onehot(q[5][2]+1,2)
])
else:
q = np.zeros(100)
l = None
return q, l
try:
sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
sel2 = 1 - sel1
except Exception as e:
sel1 = 0
sel2 = 0
q = [parse_q(x) for x in q]
q, l = zip(*q)
q = np.stack(q)
if not global_plan and not player_plan:
plan_emb = torch.cat([
self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
])
plan_emb = 0*plan_emb
elif global_plan:
plan_emb = torch.cat([
self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
])
else:
plan_emb = torch.cat([
0*self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
sel1*self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
sel2*self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
])
u = torch.cat((
torch.tensor(d).float().to(self.device),
torch.tensor(q).float().to(self.device),
self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512),
torch.tensor(m).float().to(self.device)
),axis=-1)
u = u.float().to(self.device)
y = self.dialogue_listener(u)
y = y.reshape(-1,y.shape[-1])
if intermediate:
return y
if all([x is None for x in l]):
return []
fun_lst = [self.q01,self.q02,self.q03,self.q11,self.q12,self.q13,self.q21,self.q22,self.q23]
fun = lambda x: [f(x) for f in fun_lst]
retval = [(_l,fun(torch.cat((plan_emb,torch.tensor(_q).float().to(self.device),_y)))) for _y, _q, _l in zip(y,q,l)if not _l is None]
return retval

View file

@ -0,0 +1,226 @@
import torch
import torch.nn as nn, numpy as np
from src.data.game_parser import DEVICE
from torch_geometric.nn import GATv2Conv, MeanAggregation
def onehot(x,n):
retval = np.zeros(n)
if x > 0:
retval[x-1] = 1
return retval
class PlanGraphEmbedder(nn.Module):
def __init__(self, h_dim, dropout=0.0, heads=4):
super().__init__()
self.proj_x = nn.Sequential(
nn.Linear(27, h_dim),
nn.GELU(),
nn.Dropout(dropout)
)
self.proj_edge_attr = nn.Sequential(
nn.Linear(12, h_dim),
nn.GELU(),
nn.Dropout(dropout)
)
self.conv1 = GATv2Conv(h_dim, h_dim, heads=heads, edge_dim=h_dim)
self.conv2 = GATv2Conv(h_dim*heads, h_dim, heads=1, edge_dim=h_dim)
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.pool = MeanAggregation()
def forward(self, data):
x, edge_index, edge_attr = data.features.to(DEVICE), data.edge_index.to(DEVICE), data.tool.to(DEVICE)
x = self.proj_x(x)
edge_attr = self.proj_edge_attr(edge_attr)
x = self.conv1(x, edge_index, edge_attr)
x = self.act(x)
x = self.dropout(x)
x = self.conv2(x, edge_index, edge_attr)
x = self.pool(x)
return x.squeeze(0)
class Model(nn.Module):
def __init__(self, seq_model_type=0,device=DEVICE):
super(Model, self).__init__()
self.device = device
print("model set to device", self.device)
plan_emb_out = 32*3
q_emb = 100
self.plan_embedder0 = PlanGraphEmbedder(plan_emb_out)
self.plan_embedder1 = PlanGraphEmbedder(plan_emb_out)
self.plan_embedder2 = PlanGraphEmbedder(plan_emb_out)
dlist_hidden = 1024
frame_emb = 512
self.move_emb = 157
drnn_in = 1024 + 2 + q_emb + frame_emb + self.move_emb
if seq_model_type==0:
self.dialogue_listener_rnn = nn.GRU(drnn_in,dlist_hidden)
self.dialogue_listener = lambda x: \
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
elif seq_model_type==1:
self.dialogue_listener_rnn = nn.LSTM(drnn_in,dlist_hidden)
self.dialogue_listener = lambda x: \
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
elif seq_model_type==2:
mask_fun = lambda x: torch.triu(torch.ones(x.shape[0],x.shape[0]),diagonal=1).bool().to(self.device)
sincos_fun = lambda x:torch.transpose(torch.stack([
torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
]),0,1).reshape(-1,1,2)
self.dialogue_listener_lin1 = nn.Linear(drnn_in,dlist_hidden-2)
self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x,x,x,attn_mask=mask_fun(x))
self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
sincos_fun(x.shape[0]).float().to(self.device),
self.dialogue_listener_lin1(x).reshape(-1,1,dlist_hidden-2)
], axis=-1))[0]
else:
print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
exit()
conv_block = lambda i,o,k,p,s: nn.Sequential(
nn.Conv2d( i, o, k, padding=p, stride=s),
nn.BatchNorm2d(o),
nn.Dropout(0.5),
nn.MaxPool2d(2),
nn.BatchNorm2d(o),
nn.Dropout(0.5),
nn.ReLU()
)
self.conv = nn.Sequential(
conv_block( 3, 8, 3, 1, 1),
conv_block( 8, 32, 5, 2, 2),
conv_block( 32, frame_emb//4, 5, 2, 2),
nn.Conv2d( frame_emb//4, frame_emb, 3),nn.ReLU(),
)
qlayer = lambda i,o : nn.Sequential(
nn.Linear(i,512),
nn.Dropout(0.5),
nn.GELU(),
nn.Dropout(0.5),
nn.Linear(512,o),
nn.Dropout(0.5),
)
q_in_size = plan_emb_out+dlist_hidden+q_emb
self.q01 = qlayer(q_in_size,2)
self.q02 = qlayer(q_in_size,2)
self.q03 = qlayer(q_in_size,2)
self.q11 = qlayer(q_in_size,3)
self.q12 = qlayer(q_in_size,3)
self.q13 = qlayer(q_in_size,22)
self.q21 = qlayer(q_in_size,3)
self.q22 = qlayer(q_in_size,3)
self.q23 = qlayer(q_in_size,22)
def forward(self,game,global_plan=False, player_plan=False, intermediate=False):
retval = []
l = list(game)
_,d,_,q,f,_,_,m = zip(*list(game))
parse_move = lambda m: np.concatenate([
onehot(m[0][1], 2),
onehot(m[0][2][0]+1, len(game.dialogue_move_labels_dict)),
onehot(m[0][2][1]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1),
onehot(m[0][2][2]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1),
onehot(m[0][2][3]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1)
])
m = np.stack([np.zeros(self.move_emb) if move is None else parse_move(move) for move in m])
f = np.array(f, dtype=np.uint8)
d = np.stack([np.concatenate(([int(x[0][1]==2),int(x[0][1]==1)],x[0][-1])) if not x is None else np.zeros(1026) for x in d])
def parse_q(q):
if not q is None:
q ,l = q
q = np.concatenate([
onehot(q[2],2),
onehot(q[3],2),
onehot(q[4][0][0]+1,2),
onehot(game.materials_dict[q[4][0][1]],len(game.materials_dict)),
onehot(q[4][1][0]+1,2),
onehot(game.materials_dict[q[4][1][1]],len(game.materials_dict)),
onehot(q[4][2]+1,2),
onehot(q[5][0][0]+1,2),
onehot(game.materials_dict[q[5][0][1]],len(game.materials_dict)),
onehot(q[5][1][0]+1,2),
onehot(game.materials_dict[q[5][1][1]],len(game.materials_dict)),
onehot(q[5][2]+1,2)
])
else:
q = np.zeros(100)
l = None
return q, l
try:
sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
sel2 = 1 - sel1
except Exception as e:
sel1 = 0
sel2 = 0
q = [parse_q(x) for x in q]
q, l = zip(*q)
q = np.stack(q)
if not global_plan and not player_plan:
# plan_emb = torch.cat([
# self.plan_embedder0(game.global_plan),
# self.plan_embedder1(game.player1_plan),
# self.plan_embedder2(game.player2_plan)
# ])
plan_emb = self.plan_embedder0(game.global_plan)
plan_emb = 0*plan_emb
elif global_plan:
# plan_emb = torch.cat([
# self.plan_embedder0(game.global_plan),
# self.plan_embedder1(game.player1_plan),
# self.plan_embedder2(game.player2_plan)
# ])
plan_emb = self.plan_embedder0(game.global_plan)
else:
# plan_emb = torch.cat([
# 0*self.plan_embedder0(game.global_plan),
# sel1*self.plan_embedder1(game.player1_plan),
# sel2*self.plan_embedder2(game.player2_plan)
# ])
if sel1:
plan_emb = self.plan_embedder1(game.player1_plan)
elif sel2:
plan_emb = self.plan_embedder2(game.player2_plan)
else:
plan_emb = self.plan_embedder0(game.global_plan)
plan_emb = 0*plan_emb
u = torch.cat((
torch.tensor(d).float().to(self.device),
torch.tensor(q).float().to(self.device),
self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512),
torch.tensor(m).float().to(self.device)
),axis=-1)
u = u.float().to(self.device)
y = self.dialogue_listener(u)
y = y.reshape(-1,y.shape[-1])
if intermediate:
return y
if all([x is None for x in l]):
return []
fun_lst = [self.q01,self.q02,self.q03,self.q11,self.q12,self.q13,self.q21,self.q22,self.q23]
fun = lambda x: [f(x) for f in fun_lst]
retval = [(_l,fun(torch.cat((plan_emb,torch.tensor(_q).float().to(self.device),_y)))) for _y, _q, _l in zip(y,q,l)if not _l is None]
return retval

225
src/models/plan_model.py Executable file
View file

@ -0,0 +1,225 @@
import sys, torch, random
from numpy.core.fromnumeric import reshape
import torch.nn as nn, numpy as np
from torch.nn import functional as F
from src.data.game_parser import DEVICE
def onehot(x,n):
retval = np.zeros(n)
if x > 0:
retval[x-1] = 1
return retval
class Model(nn.Module):
def __init__(self, seq_model_type=0,device=DEVICE):
super(Model, self).__init__()
self.device = device
my_rnn = lambda i,o: nn.GRU(i,o)
#my_rnn = lambda i,o: nn.LSTM(i,o)
plan_emb_in = 81
plan_emb_out = 32
q_emb = 100
self.plan_embedder0 = my_rnn(plan_emb_in,plan_emb_out)
self.plan_embedder1 = my_rnn(plan_emb_in,plan_emb_out)
self.plan_embedder2 = my_rnn(plan_emb_in,plan_emb_out)
# self.dialogue_listener = my_rnn(1126,768)
dlist_hidden = 1024
frame_emb = 512
drnn_in = 5*1024 + 2 + frame_emb + 1024
# drnn_in = 1024 + 2
# my_rnn = lambda i,o: nn.GRU(i,o)
my_rnn = lambda i,o: nn.LSTM(i,o)
if seq_model_type==0:
self.dialogue_listener_rnn = nn.GRU(drnn_in,dlist_hidden)
self.dialogue_listener = lambda x: \
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
elif seq_model_type==1:
self.dialogue_listener_rnn = nn.LSTM(drnn_in,dlist_hidden)
self.dialogue_listener = lambda x: \
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
elif seq_model_type==2:
mask_fun = lambda x: torch.triu(torch.ones(x.shape[0],x.shape[0]),diagonal=1).bool().to(device)
sincos_fun = lambda x:torch.transpose(torch.stack([
torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
]),0,1).reshape(-1,1,2)
self.dialogue_listener_lin1 = nn.Linear(drnn_in,dlist_hidden-2)
self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x,x,x,attn_mask=mask_fun(x))
self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
sincos_fun(x.shape[0]).float().to(self.device),
self.dialogue_listener_lin1(x).reshape(-1,1,dlist_hidden-2)
], axis=-1))[0]
else:
print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
exit()
conv_block = lambda i,o,k,p,s: nn.Sequential(
nn.Conv2d( i, o, k, padding=p, stride=s),
nn.BatchNorm2d(o),
nn.Dropout(0.5),
nn.MaxPool2d(2),
nn.BatchNorm2d(o),
nn.Dropout(0.5),
nn.ReLU()
)
self.conv = nn.Sequential(
conv_block( 3, 8, 3, 1, 1),
# conv_block( 3, 8, 5, 2, 2),
conv_block( 8, 32, 5, 2, 2),
conv_block( 32, frame_emb//4, 5, 2, 2),
nn.Conv2d( frame_emb//4, frame_emb, 3),nn.ReLU(),
)
plan_layer = lambda i,o : nn.Sequential(
# nn.Linear(i,(i+o)//2),
nn.Linear(i,(i+2*o)//3),
nn.Dropout(0.5),
nn.GELU(),
nn.Dropout(0.5),
# nn.Linear((i+o)//2,o),
nn.Linear((i+2*o)//3,o),
nn.GELU(),
nn.Dropout(0.5),
# nn.Sigmoid()
)
plan_mat_size = 21*21
q_in_size = 3*plan_emb_out+dlist_hidden
q_in_size = 3*plan_emb_out+dlist_hidden+plan_mat_size
q_in_size = dlist_hidden+plan_mat_size
# self.plan_out = plan_layer(q_in_size,plan_mat_size)
self.plan_out = plan_layer(q_in_size,plan_mat_size)
# self.q01 = qlayer(q_in_size,2)
# self.q02 = qlayer(q_in_size,2)
# self.q03 = qlayer(q_in_size,2)
# self.q11 = qlayer(q_in_size,3)
# self.q12 = qlayer(q_in_size,3)
# self.q13 = qlayer(q_in_size,22)
# self.q21 = qlayer(q_in_size,3)
# self.q22 = qlayer(q_in_size,3)
# self.q23 = qlayer(q_in_size,22)
def forward(self,game,global_plan=False, player_plan=False,evaluation=False, incremental=False):
retval = []
l = list(game)
_,d,l,q,f,_,intermediate,_ = zip(*list(game))
# print(np.array(intermediate).shape)
# exit()
h = None
intermediate = np.array(intermediate)
f = np.array(f, dtype=np.uint8)
# f = torch.tensor(f).permute(0,3,1,2).float().to(self.device)
# flt_lst = [(a,b) for a,b in zip(d,q) if (not a is None) or (not b is None)]
# if not flt_lst:
# return []
# d,q = zip(*flt_lst)
d = np.stack([np.concatenate(([int(x[0][1]==2),int(x[0][1]==1)],x[0][-1])) if not x is None else np.zeros(1026) for x in d])
# def parse_q(q):
# if not q is None:
# q ,l = q
# q = np.concatenate([
# onehot(q[2],2),
# onehot(q[3],2),
# onehot(q[4][0][0]+1,2),
# onehot(game.materials_dict[q[4][0][1]],len(game.materials_dict)),
# onehot(q[4][1][0]+1,2),
# onehot(game.materials_dict[q[4][1][1]],len(game.materials_dict)),
# onehot(q[4][2]+1,2),
# onehot(q[5][0][0]+1,2),
# onehot(game.materials_dict[q[5][0][1]],len(game.materials_dict)),
# onehot(q[5][1][0]+1,2),
# onehot(game.materials_dict[q[5][1][1]],len(game.materials_dict)),
# onehot(q[5][2]+1,2)
# ])
# else:
# q = np.zeros(100)
# l = None
# return q, l
try:
sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
sel2 = 1 - sel1
except Exception as e:
sel1 = 0
sel2 = 0
# q = [parse_q(x) for x in q]
# q, l = zip(*q)
if not global_plan and not player_plan:
plan_emb = torch.cat([
self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
])
plan_emb = 0*plan_emb
elif global_plan:
plan_emb = torch.cat([
self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
])
else:
plan_emb = torch.cat([
0*self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
sel1*self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
sel2*self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
])
# if sel1 == 0 and sel2 == 0:
# print(torch.unique(plan_emb))
u = torch.cat((
torch.tensor(d).float().to(self.device),
# torch.tensor(q).float().to(self.device),
self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512),
torch.tensor(intermediate).float().to(self.device)
),axis=-1)
u = u.float().to(self.device)
# print(d.shape)
# print(self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512).shape)
# print(intermediate.shape)
# print(u.shape)
y = self.dialogue_listener(u)
y = y.reshape(-1,y.shape[-1])
# print(y[-1].shape,plan_emb.shape,torch.tensor(game.plan_repr).float().to(self.device).shape)
# return self.plan_out(torch.cat((y[-1],plan_emb))), y
# return self.plan_out(torch.cat((y[-1],plan_emb,torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device)))), y
if incremental:
prediction = torch.stack([
self.plan_out(torch.cat((y[0],torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device))))] + [
self.plan_out(torch.cat((f,torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device)))) for f in y[len(y)%10-1::10]
])
prediction = F.softmax(prediction.reshape(-1,21,21),-1).reshape(-1,21*21)
else:
prediction = self.plan_out(torch.cat((y[-1],torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device))))
prediction = F.softmax(prediction.reshape(21,21),-1).reshape(21*21)
# prediction = F.softmax(prediction,-1)
# prediction = F.softmax(prediction,-1)
# exit()
return prediction, y
# exit()
# if all([x is None for x in l]):
# return []
# fun_lst = [self.q01,self.q02,self.q03,self.q11,self.q12,self.q13,self.q21,self.q22,self.q23]
# fun = lambda x: [f(x) for f in fun_lst]
# retval = [(_l,fun(torch.cat((plan_emb,torch.tensor(_q).float().to(self.device),_y)))) for _y, _q, _l in zip(y,q,l) if not _l is None]
# return retval

View file

@ -0,0 +1,230 @@
import torch
import torch.nn as nn, numpy as np
from torch.nn import functional as F
from torch_geometric.nn import MeanAggregation, GATv2Conv
class PlanGraphEmbedder(nn.Module):
def __init__(self, device, h_dim, dropout=0.0, heads=4):
super().__init__()
self.device = device
self.proj_x = nn.Sequential(
nn.Linear(27, h_dim),
nn.GELU(),
nn.Dropout(dropout)
)
self.proj_edge_attr = nn.Sequential(
nn.Linear(12, h_dim),
nn.GELU(),
nn.Dropout(dropout)
)
self.conv1 = GATv2Conv(h_dim, h_dim, heads=heads, edge_dim=h_dim)
self.conv2 = GATv2Conv(h_dim*heads, h_dim, heads=1, edge_dim=h_dim)
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.dec = nn.Linear(h_dim*3, 1)
def encode(self, data):
x, edge_index, edge_attr = data.features.to(self.device), data.edge_index.to(self.device), data.tool.to(self.device)
x = self.proj_x(x)
edge_attr = self.proj_edge_attr(edge_attr)
x = self.conv1(x, edge_index, edge_attr)
x = self.act(x)
x = self.dropout(x)
x = self.conv2(x, edge_index, edge_attr)
return x, edge_attr
def decode(self, z, context, edge_label_index):
u = z[edge_label_index[0]]
v = z[edge_label_index[1]]
return self.dec(torch.cat((u, v, context), -1))
# def decode(self, z, edge_index, edge_attr, edge_label_index):
# z = self.conv3(z, edge_index.to(self.device), edge_attr)
# return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
class Model(nn.Module):
def __init__(self, seq_model_type, device):
super(Model, self).__init__()
self.device = device
plan_emb_out = 128
self.plan_embedder0 = PlanGraphEmbedder(device, plan_emb_out)
self.plan_embedder1 = PlanGraphEmbedder(device, plan_emb_out)
self.plan_embedder2 = PlanGraphEmbedder(device, plan_emb_out)
self.plan_pool = MeanAggregation()
dlist_hidden = 1024
frame_emb = 512
drnn_in = 5*1024 + 2 + frame_emb + 1024
self.dialogue_listener_pre_ln = nn.LayerNorm(drnn_in)
if seq_model_type==0:
self.dialogue_listener_rnn = nn.GRU(drnn_in, dlist_hidden)
self.dialogue_listener = lambda x: \
self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
elif seq_model_type==1:
self.dialogue_listener_rnn = nn.LSTM(drnn_in, dlist_hidden)
self.dialogue_listener = lambda x: \
self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
elif seq_model_type==2:
mask_fun = lambda x: torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1).bool().to(device)
sincos_fun = lambda x:torch.transpose(torch.stack([
torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
]),0,1).reshape(-1,1,2)
self.dialogue_listener_lin1 = nn.Linear(drnn_in, dlist_hidden-2)
self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x, x, x, attn_mask=mask_fun(x))
self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
sincos_fun(x.shape[0]).float().to(self.device),
self.dialogue_listener_lin1(x).reshape(-1, 1, dlist_hidden-2)
], axis=-1))[0]
else:
print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
exit()
conv_block = lambda i, o, k, p, s: nn.Sequential(
nn.Conv2d(i, o, k, padding=p, stride=s),
nn.BatchNorm2d(o),
nn.Dropout(0.5),
nn.MaxPool2d(2),
nn.BatchNorm2d(o),
nn.Dropout(0.5),
nn.ReLU()
)
self.conv = nn.Sequential(
conv_block(3, 8, 3, 1, 1),
conv_block(8, 32, 5, 2, 2),
conv_block(32, frame_emb//4, 5, 2, 2),
nn.Conv2d(frame_emb//4, frame_emb, 3),
nn.ReLU(),
)
self.proj_y = nn.Sequential(
nn.Linear(1024, 512),
nn.GELU(),
nn.Dropout(0.5),
nn.Linear(512, plan_emb_out)
)
def forward(self, game, experiment, global_plan=False, player_plan=False, incremental=False, return_feats=False):
_, d, l, q, f, _, intermediate, _ = zip(*list(game))
intermediate = np.array(intermediate)
f = np.array(f, dtype=np.uint8)
d = np.stack([np.concatenate(([int(x[0][1]==2), int(x[0][1]==1)], x[0][-1])) if not x is None else np.zeros(1026) for x in d])
try:
sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
sel2 = 1 - sel1
except Exception as e:
sel1 = 0
sel2 = 0
if player_plan:
if sel1:
z, _ = self.plan_embedder1.encode(game.player1_plan)
elif sel2:
z, _ = self.plan_embedder2.encode(game.player2_plan)
else:
z, _ = self.plan_embedder0.encode(game.global_plan)
else:
raise ValueError('There should never be a global plan!')
u = torch.cat((
torch.tensor(d).float().to(self.device),
self.conv(torch.tensor(f).permute(0, 3, 1, 2).float().to(self.device) / 255.0).reshape(-1, 512),
torch.tensor(intermediate).float().to(self.device)
), axis=-1)
u = u.float().to(self.device)
u = self.dialogue_listener_pre_ln(u)
y = self.dialogue_listener(u)
y = y.reshape(-1, y.shape[-1])
if return_feats:
_y = y.clone().detach().cpu().numpy()
y = self.proj_y(y)
if experiment == 2:
pred, label = self.decode_own_missing_knowledge(z, y, game, sel1, sel2, incremental)
elif experiment == 3:
pred, label = self.decode_partner_missing_knowledge(z, y, game, sel1, sel2, incremental)
else:
raise ValueError('Wrong experiment id! Valid values are 2 and 3.')
if return_feats:
return pred, label, [sel1, sel2], _y
return pred, label, [sel1, sel2]
def decode_own_missing_knowledge(self, z, y, game, sel1, sel2, incremental):
if incremental:
if sel1:
pred = torch.stack(
[self.plan_embedder1.decode(z, y[0].repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1)]
+
[self.plan_embedder1.decode(z, f.repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1) for f in y[len(y)%10-1::10]]
)
label = game.player1_edge_label_own_missing_knowledge.to(self.device)
elif sel2:
pred = torch.stack(
[self.plan_embedder2.decode(z, y[0].repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1)]
+
[self.plan_embedder2.decode(z, f.repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1) for f in y[len(y)%10-1::10]]
)
label = game.player2_edge_label_own_missing_knowledge.to(self.device)
else:
pred = torch.stack(
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
[self.plan_embedder0.decode(z, y[0].repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)]
+
[self.plan_embedder0.decode(z, f.repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
)
label = torch.zeros(game.global_plan.edge_index.shape[1])
else:
if sel1:
pred = self.plan_embedder1.decode(z, y.mean(0, keepdim=True).repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1)
label = game.player1_edge_label_own_missing_knowledge.to(self.device)
elif sel2:
pred = self.plan_embedder2.decode(z, y.mean(0, keepdim=True).repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1)
label = game.player2_edge_label_own_missing_knowledge.to(self.device)
else:
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
pred = self.plan_embedder0.decode(z, y.mean(0, keepdim=True).repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)
label = torch.zeros(game.global_plan.edge_index.shape[1])
return (pred, label)
def decode_partner_missing_knowledge(self, z, y, game, sel1, sel2, incremental):
if incremental:
if sel1:
pred = torch.stack(
[self.plan_embedder1.decode(z, y[0].repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1)]
+
[self.plan_embedder1.decode(z, f.repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
)
label = game.player1_edge_label_other_missing_knowledge.to(self.device)
elif sel2:
pred = torch.stack(
[self.plan_embedder2.decode(z, y[0].repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1)]
+
[self.plan_embedder2.decode(z, f.repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
)
label = game.player2_edge_label_other_missing_knowledge.to(self.device)
else:
pred = torch.stack(
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
[self.plan_embedder0.decode(z, y[0].repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)]
+
[self.plan_embedder0.decode(z, f.repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
)
label = torch.zeros(game.global_plan.edge_index.shape[1])
else:
if sel1:
pred = self.plan_embedder1.decode(z, y.mean(0, keepdim=True).repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1)
label = game.player1_edge_label_other_missing_knowledge.to(self.device)
elif sel2:
pred = self.plan_embedder2.decode(z, y.mean(0, keepdim=True).repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1)
label = game.player2_edge_label_other_missing_knowledge.to(self.device)
else:
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
pred = self.plan_embedder0.decode(z, y.mean(0, keepdim=True).repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)
label = torch.zeros(game.global_plan.edge_index.shape[1])
return (pred, label)

View file

@ -0,0 +1,291 @@
import torch
import torch.nn as nn, numpy as np
from torch.nn import functional as F
from torch_geometric.nn import MeanAggregation, GATv2Conv
def onehot(x,n):
retval = np.zeros(n)
if x > 0:
retval[x-1] = 1
return retval
class PlanGraphEmbedder(nn.Module):
def __init__(self, device, h_dim, dropout=0.0, heads=4):
super().__init__()
self.device = device
self.proj_x = nn.Sequential(
nn.Linear(27, h_dim),
nn.GELU(),
nn.Dropout(dropout)
)
self.proj_edge_attr = nn.Sequential(
nn.Linear(12, h_dim),
nn.GELU(),
nn.Dropout(dropout)
)
self.conv1 = GATv2Conv(h_dim, h_dim, heads=heads, edge_dim=h_dim)
self.conv2 = GATv2Conv(h_dim*heads, h_dim, heads=1, edge_dim=h_dim)
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.dec = nn.Linear(h_dim*3, 1)
def encode(self, data):
x, edge_index, edge_attr = data.features.to(self.device), data.edge_index.to(self.device), data.tool.to(self.device)
x = self.proj_x(x)
edge_attr = self.proj_edge_attr(edge_attr)
x = self.conv1(x, edge_index, edge_attr)
x = self.act(x)
x = self.dropout(x)
x = self.conv2(x, edge_index, edge_attr)
return x, edge_attr
def decode(self, z, context, edge_label_index):
u = z[edge_label_index[0]]
v = z[edge_label_index[1]]
return self.dec(torch.cat((u, v, context), -1))
class Model(nn.Module):
def __init__(self, seq_model_type, device):
super(Model, self).__init__()
self.device = device
plan_emb_out = 128
self.plan_embedder0 = PlanGraphEmbedder(device, plan_emb_out)
self.plan_embedder1 = PlanGraphEmbedder(device, plan_emb_out)
self.plan_embedder2 = PlanGraphEmbedder(device, plan_emb_out)
self.plan_pool = MeanAggregation()
dlist_hidden = 1024
frame_emb = 512
drnn_in = 5*1024 + 2 + frame_emb + 1024
self.dialogue_listener_pre_ln = nn.LayerNorm(drnn_in)
if seq_model_type==0:
self.dialogue_listener_rnn = nn.GRU(drnn_in, dlist_hidden)
self.dialogue_listener = lambda x: \
self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
elif seq_model_type==1:
self.dialogue_listener_rnn = nn.LSTM(drnn_in, dlist_hidden)
self.dialogue_listener = lambda x: \
self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
elif seq_model_type==2:
mask_fun = lambda x: torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1).bool().to(device)
sincos_fun = lambda x:torch.transpose(torch.stack([
torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
]),0,1).reshape(-1,1,2)
self.dialogue_listener_lin1 = nn.Linear(drnn_in, dlist_hidden-2)
self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x, x, x, attn_mask=mask_fun(x))
self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
sincos_fun(x.shape[0]).float().to(self.device),
self.dialogue_listener_lin1(x).reshape(-1, 1, dlist_hidden-2)
], axis=-1))[0]
else:
print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
exit()
conv_block = lambda i, o, k, p, s: nn.Sequential(
nn.Conv2d(i, o, k, padding=p, stride=s),
nn.BatchNorm2d(o),
nn.Dropout(0.5),
nn.MaxPool2d(2),
nn.BatchNorm2d(o),
nn.Dropout(0.5),
nn.ReLU()
)
self.conv = nn.Sequential(
conv_block(3, 8, 3, 1, 1),
conv_block(8, 32, 5, 2, 2),
conv_block(32, frame_emb//4, 5, 2, 2),
nn.Conv2d(frame_emb//4, frame_emb, 3),
nn.ReLU(),
)
self.proj_y = nn.Sequential(
nn.Linear(1024, 512),
nn.GELU(),
nn.Dropout(0.5),
nn.Linear(512, plan_emb_out)
)
self.proj_tom = nn.Linear(154, 5*1024)
def parse_ql(self, q, game, intermediate):
tom12_answ = ['YES', 'NO', 'MAYBE']
materials_dict = game.materials_dict.copy()
materials_dict['NOT_SURE'] = 0
if not q is None:
q, l = q
tom_gt = np.concatenate([onehot(q[2],2), onehot(q[3],2)])
#### q1
q1_1 = np.concatenate([onehot(q[4][0][0]+1,2), onehot(materials_dict[q[4][0][1]], len(game.materials_dict))])
## r1
r1_1 = np.eye(len(tom12_answ))[tom12_answ.index(l[0][0])]
#### q2
q2_1 = np.concatenate([onehot(q[4][1][0]+1,2), onehot(materials_dict[q[4][1][1]], len(game.materials_dict))])
## r2
r2_1 = np.eye(len(tom12_answ))[tom12_answ.index(l[0][1])]
#### q3
q3_1 = onehot(q[4][2]+1, 2)
## r3
r3_1 = onehot(materials_dict[l[0][2]], len(game.materials_dict))
#### q1
q1_2 = np.concatenate([onehot(q[5][0][0]+1,2), onehot(materials_dict[q[5][0][1]], len(game.materials_dict))])
## r1
r1_2 = np.eye(len(tom12_answ))[tom12_answ.index(l[1][0])]
#### q2
q2_2 = np.concatenate([onehot(q[5][1][0]+1,2), onehot(materials_dict[q[5][1][1]], len(game.materials_dict))])
## r2
r2_2 = np.eye(len(tom12_answ))[tom12_answ.index(l[1][1])]
#### q3
q3_2 = onehot(q[5][2]+1,2)
## r3
r3_2 = onehot(materials_dict[l[1][2]], len(game.materials_dict))
if intermediate == 0:
tom_gt = np.zeros(154)
elif intermediate == 1:
# tom6
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, np.zeros(q2_1.shape[0] + r2_1.shape[0] + q3_1.shape[0] + r3_1.shape[0]), q1_2, r1_2, np.zeros(q2_2.shape[0] + r2_2.shape[0] + q3_2.shape[0] + r3_2.shape[0])])
elif intermediate == 2:
# tom7
tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0]), q2_1, r2_1, np.zeros(q3_1.shape[0] + r3_1.shape[0] + q1_2.shape[0] + r1_2.shape[0]), q2_2, r2_2, np.zeros(q3_2.shape[0] + r3_2.shape[0])])
elif intermediate == 3:
# tom6 + tom7
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, q2_1, r2_1, np.zeros(q3_1.shape[0] + r3_1.shape[0]), q1_2, r1_2, q2_2, r2_2, np.zeros(q3_2.shape[0] + r3_2.shape[0])])
elif intermediate == 4:
# tom8
tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0] + q2_1.shape[0] + r2_1.shape[0]), q3_1, r3_1, np.zeros(q1_2.shape[0] + r1_2.shape[0] + q2_2.shape[0] + r2_2.shape[0]), q3_2, r3_2])
elif intermediate == 5:
# tom6 + tom8
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, np.zeros(q2_1.shape[0] + r2_1.shape[0]), q3_1, r3_1, q1_2, r1_2, np.zeros(q2_2.shape[0] + r2_2.shape[0]), q3_2, r3_2])
elif intermediate == 6:
# tom7 + tom8
tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0]), q2_1, r2_1, q3_1, r3_1, np.zeros(q1_2.shape[0] + r1_2.shape[0]), q2_2, r2_2, q3_2, r3_2])
elif intermediate == 7:
# tom6 + tom7 + tom8
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, q2_1, r2_1, q3_1, r3_1, q1_2, r1_2, q2_2, r2_2, q3_2, r3_2])
else:
tom_gt = np.zeros(154)
if tom_gt.shape[0] != 154: breakpoint()
return tom_gt
def forward(self, game, experiment, global_plan=False, player_plan=False, incremental=False, intermediate=0):
l = list(game)
_, d, l, q, f, _, _, _ = zip(*list(game))
tom_gt = [self.parse_ql(x, game, intermediate) for x in q]
tom_gt = torch.tensor(np.stack(tom_gt), device=self.device, dtype=torch.float32)
tom_gt = self.proj_tom(tom_gt)
f = np.array(f, dtype=np.uint8)
d = np.stack([np.concatenate(([int(x[0][1]==2), int(x[0][1]==1)], x[0][-1])) if not x is None else np.zeros(1026) for x in d])
try:
sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
sel2 = 1 - sel1
except Exception as e:
sel1 = 0
sel2 = 0
if player_plan:
if sel1:
z, _ = self.plan_embedder1.encode(game.player1_plan)
elif sel2:
z, _ = self.plan_embedder2.encode(game.player2_plan)
else:
z, _ = self.plan_embedder0.encode(game.global_plan)
else:
raise ValueError('There should never be a global plan!')
u = torch.cat((
torch.tensor(d).float().to(self.device),
self.conv(torch.tensor(f).permute(0, 3, 1, 2).float().to(self.device) / 255.0).reshape(-1, 512),
tom_gt
), axis=-1)
u = u.float().to(self.device)
u = self.dialogue_listener_pre_ln(u)
y = self.dialogue_listener(u)
y = y.reshape(-1, y.shape[-1])
y = self.proj_y(y)
if experiment == 2:
pred, label = self.decode_own_missing_knowledge(z, y, game, sel1, sel2, incremental)
elif experiment == 3:
pred, label = self.decode_partner_missing_knowledge(z, y, game, sel1, sel2, incremental)
else:
raise ValueError('Wrong experiment id! Valid values are 2 and 3.')
return pred, label, [sel1, sel2]
def decode_own_missing_knowledge(self, z, y, game, sel1, sel2, incremental):
if incremental:
if sel1:
pred = torch.stack(
[self.plan_embedder1.decode(z, y[0].repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1)]
+
[self.plan_embedder1.decode(z, f.repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1) for f in y[len(y)%10-1::10]]
)
label = game.player1_edge_label_own_missing_knowledge.to(self.device)
elif sel2:
pred = torch.stack(
[self.plan_embedder2.decode(z, y[0].repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1)]
+
[self.plan_embedder2.decode(z, f.repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1) for f in y[len(y)%10-1::10]]
)
label = game.player2_edge_label_own_missing_knowledge.to(self.device)
else:
pred = torch.stack(
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
[self.plan_embedder0.decode(z, y[0].repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)]
+
[self.plan_embedder0.decode(z, f.repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
)
label = torch.zeros(game.global_plan.edge_index.shape[1])
else:
if sel1:
pred = self.plan_embedder1.decode(z, y.mean(0, keepdim=True).repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1)
label = game.player1_edge_label_own_missing_knowledge.to(self.device)
elif sel2:
pred = self.plan_embedder2.decode(z, y.mean(0, keepdim=True).repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1)
label = game.player2_edge_label_own_missing_knowledge.to(self.device)
else:
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
pred = self.plan_embedder0.decode(z, y.mean(0, keepdim=True).repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)
label = torch.zeros(game.global_plan.edge_index.shape[1])
return (pred, label)
def decode_partner_missing_knowledge(self, z, y, game, sel1, sel2, incremental):
if incremental:
if sel1:
pred = torch.stack(
[self.plan_embedder1.decode(z, y[0].repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1)]
+
[self.plan_embedder1.decode(z, f.repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
)
label = game.player1_edge_label_other_missing_knowledge.to(self.device)
elif sel2:
pred = torch.stack(
[self.plan_embedder2.decode(z, y[0].repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1)]
+
[self.plan_embedder2.decode(z, f.repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
)
label = game.player2_edge_label_other_missing_knowledge.to(self.device)
else:
pred = torch.stack(
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
[self.plan_embedder0.decode(z, y[0].repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)]
+
[self.plan_embedder0.decode(z, f.repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
)
label = torch.zeros(game.global_plan.edge_index.shape[1])
else:
if sel1:
pred = self.plan_embedder1.decode(z, y.mean(0, keepdim=True).repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1)
label = game.player1_edge_label_other_missing_knowledge.to(self.device)
elif sel2:
pred = self.plan_embedder2.decode(z, y.mean(0, keepdim=True).repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1)
label = game.player2_edge_label_other_missing_knowledge.to(self.device)
else:
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
pred = self.plan_embedder0.decode(z, y.mean(0, keepdim=True).repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)
label = torch.zeros(game.global_plan.edge_index.shape[1])
return (pred, label)

View file

@ -0,0 +1,214 @@
import sys, torch, random
from numpy.core.fromnumeric import reshape
import torch.nn as nn, numpy as np
from torch.nn import functional as F
from src.data.game_parser import DEVICE
def onehot(x,n):
retval = np.zeros(n)
if x > 0:
retval[x-1] = 1
return retval
class Model(nn.Module):
def __init__(self, seq_model_type=0,device=DEVICE):
super(Model, self).__init__()
self.device = device
my_rnn = lambda i,o: nn.GRU(i,o)
plan_emb_in = 81
plan_emb_out = 32
q_emb = 100
self.plan_embedder0 = my_rnn(plan_emb_in,plan_emb_out)
self.plan_embedder1 = my_rnn(plan_emb_in,plan_emb_out)
self.plan_embedder2 = my_rnn(plan_emb_in,plan_emb_out)
dlist_hidden = 1024
frame_emb = 512
drnn_in = 5*1024 + 2 + frame_emb + 1024
my_rnn = lambda i,o: nn.LSTM(i,o)
if seq_model_type==0:
self.dialogue_listener_rnn = nn.GRU(drnn_in,dlist_hidden)
self.dialogue_listener = lambda x: \
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
elif seq_model_type==1:
self.dialogue_listener_rnn = nn.LSTM(drnn_in,dlist_hidden)
self.dialogue_listener = lambda x: \
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
elif seq_model_type==2:
mask_fun = lambda x: torch.triu(torch.ones(x.shape[0],x.shape[0]),diagonal=1).bool().to(device)
sincos_fun = lambda x:torch.transpose(torch.stack([
torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
]),0,1).reshape(-1,1,2)
self.dialogue_listener_lin1 = nn.Linear(drnn_in,dlist_hidden-2)
self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x,x,x,attn_mask=mask_fun(x))
self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
sincos_fun(x.shape[0]).float().to(self.device),
self.dialogue_listener_lin1(x).reshape(-1,1,dlist_hidden-2)
], axis=-1))[0]
else:
print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
exit()
conv_block = lambda i,o,k,p,s: nn.Sequential(
nn.Conv2d( i, o, k, padding=p, stride=s),
nn.BatchNorm2d(o),
nn.Dropout(0.5),
nn.MaxPool2d(2),
nn.BatchNorm2d(o),
nn.Dropout(0.5),
nn.ReLU()
)
self.conv = nn.Sequential(
conv_block( 3, 8, 3, 1, 1),
conv_block( 8, 32, 5, 2, 2),
conv_block( 32, frame_emb//4, 5, 2, 2),
nn.Conv2d( frame_emb//4, frame_emb, 3),nn.ReLU(),
)
plan_layer = lambda i,o : nn.Sequential(
nn.Linear(i,(i+2*o)//3),
nn.Dropout(0.5),
nn.GELU(),
nn.Dropout(0.5),
nn.Linear((i+2*o)//3,o),
nn.GELU(),
nn.Dropout(0.5),
)
plan_mat_size = 21*21
q_in_size = 3*plan_emb_out+dlist_hidden
q_in_size = 3*plan_emb_out+dlist_hidden+plan_mat_size
q_in_size = dlist_hidden+plan_mat_size
self.plan_out = plan_layer(q_in_size,plan_mat_size)
self.proj_tom = nn.Linear(154, 5*1024)
def forward(self,game,global_plan=False, player_plan=False,evaluation=False, incremental=False, intermediate=0):
_,d,l,q,f,_,_,_ = zip(*list(game))
tom_gt = [self.parse_ql(x, game, intermediate) for x in q]
tom_gt = torch.tensor(np.stack(tom_gt), device=self.device, dtype=torch.float32)
tom_gt = self.proj_tom(tom_gt)
f = np.array(f, dtype=np.uint8)
d = np.stack([np.concatenate(([int(x[0][1]==2),int(x[0][1]==1)],x[0][-1])) if not x is None else np.zeros(1026) for x in d])
try:
sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
sel2 = 1 - sel1
except Exception as e:
sel1 = 0
sel2 = 0
if not global_plan and not player_plan:
plan_emb = torch.cat([
self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
])
plan_emb = 0*plan_emb
elif global_plan:
plan_emb = torch.cat([
self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
])
else:
plan_emb = torch.cat([
0*self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
sel1*self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
sel2*self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
])
u = torch.cat((
torch.tensor(d).float().to(self.device),
self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512),
tom_gt
),axis=-1)
u = u.float().to(self.device)
y = self.dialogue_listener(u)
y = y.reshape(-1,y.shape[-1])
if incremental:
prediction = torch.stack([
self.plan_out(torch.cat((y[0],torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device))))] + [
self.plan_out(torch.cat((f,torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device)))) for f in y[len(y)%10-1::10]
])
prediction = F.softmax(prediction.reshape(-1,21,21),-1).reshape(-1,21*21)
else:
prediction = self.plan_out(torch.cat((y[-1],torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device))))
prediction = F.softmax(prediction.reshape(21,21),-1).reshape(21*21)
return prediction, y
def parse_ql(self, q, game, intermediate):
tom12_answ = ['YES', 'NO', 'MAYBE']
materials_dict = game.materials_dict.copy()
materials_dict['NOT_SURE'] = 0
if not q is None:
q, l = q
tom_gt = np.concatenate([onehot(q[2],2), onehot(q[3],2)])
#### q1
q1_1 = np.concatenate([onehot(q[4][0][0]+1,2), onehot(materials_dict[q[4][0][1]], len(game.materials_dict))])
## r1
r1_1 = np.eye(len(tom12_answ))[tom12_answ.index(l[0][0])]
#### q2
q2_1 = np.concatenate([onehot(q[4][1][0]+1,2), onehot(materials_dict[q[4][1][1]], len(game.materials_dict))])
## r2
r2_1 = np.eye(len(tom12_answ))[tom12_answ.index(l[0][1])]
#### q3
q3_1 = onehot(q[4][2]+1, 2)
## r3
r3_1 = onehot(materials_dict[l[0][2]], len(game.materials_dict))
#### q1
q1_2 = np.concatenate([onehot(q[5][0][0]+1,2), onehot(materials_dict[q[5][0][1]], len(game.materials_dict))])
## r1
r1_2 = np.eye(len(tom12_answ))[tom12_answ.index(l[1][0])]
#### q2
q2_2 = np.concatenate([onehot(q[5][1][0]+1,2), onehot(materials_dict[q[5][1][1]], len(game.materials_dict))])
## r2
r2_2 = np.eye(len(tom12_answ))[tom12_answ.index(l[1][1])]
#### q3
q3_2 = onehot(q[5][2]+1,2)
## r3
r3_2 = onehot(materials_dict[l[1][2]], len(game.materials_dict))
if intermediate == 0:
tom_gt = np.zeros(154)
elif intermediate == 1:
# tom6
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, np.zeros(q2_1.shape[0] + r2_1.shape[0] + q3_1.shape[0] + r3_1.shape[0]), q1_2, r1_2, np.zeros(q2_2.shape[0] + r2_2.shape[0] + q3_2.shape[0] + r3_2.shape[0])])
elif intermediate == 2:
# tom7
tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0]), q2_1, r2_1, np.zeros(q3_1.shape[0] + r3_1.shape[0] + q1_2.shape[0] + r1_2.shape[0]), q2_2, r2_2, np.zeros(q3_2.shape[0] + r3_2.shape[0])])
elif intermediate == 3:
# tom6 + tom7
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, q2_1, r2_1, np.zeros(q3_1.shape[0] + r3_1.shape[0]), q1_2, r1_2, q2_2, r2_2, np.zeros(q3_2.shape[0] + r3_2.shape[0])])
elif intermediate == 4:
# tom8
tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0] + q2_1.shape[0] + r2_1.shape[0]), q3_1, r3_1, np.zeros(q1_2.shape[0] + r1_2.shape[0] + q2_2.shape[0] + r2_2.shape[0]), q3_2, r3_2])
elif intermediate == 5:
# tom6 + tom8
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, np.zeros(q2_1.shape[0] + r2_1.shape[0]), q3_1, r3_1, q1_2, r1_2, np.zeros(q2_2.shape[0] + r2_2.shape[0]), q3_2, r3_2])
elif intermediate == 6:
# tom7 + tom8
tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0]), q2_1, r2_1, q3_1, r3_1, np.zeros(q1_2.shape[0] + r1_2.shape[0]), q2_2, r2_2, q3_2, r3_2])
elif intermediate == 7:
# tom6 + tom7 + tom8
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, q2_1, r2_1, q3_1, r3_1, q1_2, r1_2, q2_2, r2_2, q3_2, r3_2])
else:
tom_gt = np.zeros(154)
if tom_gt.shape[0] != 154: breakpoint()
return tom_gt