This commit is contained in: parent 08780752d9, commit e15b0d7b50
46 changed files with 14927 additions and 0 deletions
BIN
src/.DS_Store
vendored
Normal file
Binary file not shown.
0
src/__init__.py
Normal file
BIN
src/data/.DS_Store
vendored
Normal file
Binary file not shown.
0
src/data/__init__.py
Normal file
761
src/data/game_parser.py
Executable file
@@ -0,0 +1,761 @@
from glob import glob
import os, string, json, pickle
import torch, random, numpy as np
from transformers import BertTokenizer, BertModel
import cv2
import imageio
from src.data.action_extractor import proc_action


DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# def set_seed(seed_idx):
#     seed = 0
#     random.seed(0)
#     for _ in range(seed_idx):
#         seed = random.random()
#     random.seed(seed)
#     torch.manual_seed(seed)
#     print('Random seed set to', seed)
#     return seed

def set_seed(seed):
    # Seed every source of randomness used downstream (hashing, torch, numpy, random).
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    print('Random seed set to', seed)
    return seed
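
# Illustrative call sequence (not part of this module):
#     set_seed(0)
#     splits = make_splits()           # {'test': [...], 'validation': [...], 'training': [...]}
#     train_dirs = splits['training']  # each entry is a game_path directory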

def make_splits(split_file='config/dataset_splits.json'):
    # Build 20/20/60 test/validation/training splits over the game logs the
    # first time around; afterwards just load the cached split file.
    if not os.path.isfile(split_file):
        dirs = sorted(glob('data/saved_logs/*') + glob('data/main_logs/*'))
        games = sorted(list(map(GameParser, dirs)), key=lambda x: len(x.question_pairs), reverse=True)

        test = games[0::5]
        val = games[1::5]
        train = games[2::5] + games[3::5] + games[4::5]

        dataset_splits = {'test': [g.game_path for g in test], 'validation': [g.game_path for g in val], 'training': [g.game_path for g in train]}
        json.dump(dataset_splits, open('config/dataset_splits_old.json', 'w'), indent=4)

        dirs = sorted(glob('data/new_logs/*'))
        games = sorted(list(map(GameParser, dirs)), key=lambda x: len(x.question_pairs), reverse=True)

        test = games[0::5]
        val = games[1::5]
        train = games[2::5] + games[3::5] + games[4::5]

        dataset_splits['test'] += [g.game_path for g in test]
        dataset_splits['validation'] += [g.game_path for g in val]
        dataset_splits['training'] += [g.game_path for g in train]
        json.dump(dataset_splits, open('config/dataset_splits_new.json', 'w'), indent=4)
        json.dump(dataset_splits, open('config/dataset_splits.json', 'w'), indent=4)

        dataset_splits['test'] = dataset_splits['test'][:2]
        dataset_splits['validation'] = dataset_splits['validation'][:2]
        dataset_splits['training'] = dataset_splits['training'][:2]
        json.dump(dataset_splits, open('config/dataset_splits_dev.json', 'w'), indent=4)

    dataset_splits = json.load(open(split_file))

    return dataset_splits

def onehot(x, n):
    # One-hot encode x in {1..n}; x == 0 yields the all-zero vector.
    retval = np.zeros(n)
    if x > 0:
        retval[x - 1] = 1
    return retval
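
# Worked example (follows directly from the definition above):
#     onehot(3, 5) -> array([0., 0., 1., 0., 0.])
#     onehot(0, 5) -> array([0., 0., 0., 0., 0.])  # 0 encodes "none"/IGNORE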

class GameParser:
    tokenizer = None
    model = None

    def __init__(self, game_path, load_dialogue=True, pov=0, intermediate=0, use_dialogue_moves=False):
        # print(game_path, end=' ')
        self.load_dialogue = load_dialogue
        if pov not in (0, 1, 2, 3, 4):
            print('Point of view must be in (0,1,2,3,4), but got', pov)
            exit()
        self.pov = pov
        self.use_dialogue_moves = use_dialogue_moves
        self.load_player1 = pov == 1
        self.load_player2 = pov == 2
        self.load_third_person = pov == 3
        self.game_path = game_path
        # print(game_path)
        self.dialogue_file = glob(os.path.join(game_path, 'mcc*log'))[0]
        self.questions_file = glob(os.path.join(game_path, 'web*log'))[0]
        self.plan_file = glob(os.path.join(game_path, 'plan*json'))[0]
        self.plan = json.load(open(self.plan_file))
        self.img_w = 96
        self.img_h = 96
        self.intermediate = intermediate

        self.flip_video = False
        for l in open(self.dialogue_file):
            if 'HAS JOINED' in l:
                player_name = l.strip().split()[1]
                self.flip_video = player_name[-1] == '2'
                break

        if not os.path.isfile("config/materials.json") or \
                not os.path.isfile("config/mines.json") or \
                not os.path.isfile("config/tools.json"):
            plan_files = sorted(glob('data/*_logs/*/plan*.json'))
            materials = []
            tools = []
            mines = []
            for plan_file in plan_files:
                plan = json.load(open(plan_file))
                materials += plan['materials']
                tools += plan['tools']
                mines += plan['mines']
            materials = sorted(list(set(materials)))
            tools = sorted(list(set(tools)))
            mines = sorted(list(set(mines)))
            json.dump(materials, open('config/materials.json', 'w'), indent=4)
            json.dump(mines, open('config/mines.json', 'w'), indent=4)
            json.dump(tools, open('config/tools.json', 'w'), indent=4)

        materials = json.load(open('config/materials.json'))
        mines = json.load(open('config/mines.json'))
        tools = json.load(open('config/tools.json'))

        # Reserve 0 for "none"; real entries are indexed from 1 (see onehot above).
        self.materials_dict = {x: i + 1 for i, x in enumerate(materials)}
        self.mines_dict = {x: i + 1 for i, x in enumerate(mines)}
        self.tools_dict = {x: i + 1 for i, x in enumerate(tools)}

        self.__load_dialogue_act_labels()
        self.__load_dialogue_move_labels()
        self.__parse_dialogue()
        self.__parse_questions()
        self.__parse_start_end()
        self.__parse_question_pairs()
        self.__load_videos()
        self.__assign_dialogue_act_labels()
        self.__assign_dialogue_move_labels()
        self.__load_replay_data()
        self.__load_intermediate()

        # print(len(self.materials_dict))

        # Each plan step is encoded as [material, ingredient1, ingredient2, mine, tool],
        # all one-hot; the 21x21 matrix records ingredient -> product dependencies.
        self.global_plan = []
        self.global_plan_mat = np.zeros((21, 21))
        mine_counter = 0
        for n, v in zip(self.plan['materials'], self.plan['full']):
            if v['make']:
                mine = 0
                m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
                m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
                self.global_plan_mat[self.materials_dict[n] - 1][m1 - 1] = 1
                self.global_plan_mat[self.materials_dict[n] - 1][m2 - 1] = 1
            else:
                mine = self.mines_dict[self.plan['mines'][mine_counter]]
                mine_counter += 1
                m1 = 0
                m2 = 0
            mine = onehot(mine, len(self.mines_dict))
            m1 = onehot(m1, len(self.materials_dict))
            m2 = onehot(m2, len(self.materials_dict))
            mat = onehot(self.materials_dict[n], len(self.materials_dict))
            t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
            step = np.concatenate((mat, m1, m2, mine, t))
            self.global_plan.append(step)

        self.player1_plan = []
        self.player1_plan_mat = np.zeros((21, 21))
        mine_counter = 0
        for n, v in zip(self.plan['materials'], self.plan['player1']):
            if v['make']:
                mine = 0
                if v['make'][0][0] < 0:
                    # Negative indices appear to mark recipes hidden from this player.
                    m1 = 0
                    m2 = 0
                else:
                    m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
                    m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
                    self.player1_plan_mat[self.materials_dict[n] - 1][m1 - 1] = 1
                    self.player1_plan_mat[self.materials_dict[n] - 1][m2 - 1] = 1
            else:
                mine = self.mines_dict[self.plan['mines'][mine_counter]]
                mine_counter += 1
                m1 = 0
                m2 = 0
            mine = onehot(mine, len(self.mines_dict))
            m1 = onehot(m1, len(self.materials_dict))
            m2 = onehot(m2, len(self.materials_dict))
            mat = onehot(self.materials_dict[n], len(self.materials_dict))
            t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
            step = np.concatenate((mat, m1, m2, mine, t))
            self.player1_plan.append(step)

        self.player2_plan = []
        self.player2_plan_mat = np.zeros((21, 21))
        mine_counter = 0
        for n, v in zip(self.plan['materials'], self.plan['player2']):
            if v['make']:
                mine = 0
                if v['make'][0][0] < 0:
                    m1 = 0
                    m2 = 0
                else:
                    m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
                    m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
                    self.player2_plan_mat[self.materials_dict[n] - 1][m1 - 1] = 1
                    self.player2_plan_mat[self.materials_dict[n] - 1][m2 - 1] = 1
            else:
                mine = self.mines_dict[self.plan['mines'][mine_counter]]
                mine_counter += 1
                m1 = 0
                m2 = 0
            mine = onehot(mine, len(self.mines_dict))
            m1 = onehot(m1, len(self.materials_dict))
            m2 = onehot(m2, len(self.materials_dict))
            mat = onehot(self.materials_dict[n], len(self.materials_dict))
            t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
            step = np.concatenate((mat, m1, m2, mine, t))
            self.player2_plan.append(step)
        # print(self.global_plan_mat.reshape(-1))
        # print(self.player1_plan_mat.reshape(-1))
        # print(self.player2_plan_mat.reshape(-1))
        # for x in zip(self.global_plan_mat.reshape(-1), self.player1_plan_mat.reshape(-1), self.player2_plan_mat.reshape(-1)):
        #     if sum(x) > 0:
        #         print(x)
        # exit()
        if self.load_player1:
            self.plan_repr = self.player1_plan_mat
            self.partner_plan = self.player2_plan_mat
        elif self.load_player2:
            self.plan_repr = self.player2_plan_mat
            self.partner_plan = self.player1_plan_mat
        else:
            self.plan_repr = self.global_plan_mat
            self.partner_plan = self.global_plan_mat
        # Knowledge each side is missing relative to the full plan.
        self.global_diff_plan_mat = self.global_plan_mat - self.plan_repr
        self.partner_diff_plan_mat = self.global_plan_mat - self.partner_plan

        self.__iter_ts = self.start_ts

        # self.action_labels = sorted([t for a in self.actions for t in a if t.PacketData in ['BlockChangeData']], key=lambda x: x.TickIndex)
        self.action_labels = None
        # for tick in ticks:
        #     print(int(tick.TickIndex/30), self.plan['materials'].index(tick.items[0]), int(tick.Name[-1]))
        # print(self.start_ts, self.end_ts, self.start_ts - self.end_ts, int(ticks[-1].TickIndex/30) if ticks else 0, self.action_file)
        # exit()
        self.materials = sorted(self.plan['materials'])

    def __len__(self):
        return self.end_ts - self.start_ts

    def __next__(self):
        if self.__iter_ts < self.end_ts:

            if self.load_dialogue:
                d = [x for x in self.dialogue_events if x[0] == self.__iter_ts]
                l = [x for x in self.dialogue_act_labels if x[0] == self.__iter_ts]
                d = d if d else None
                l = l if l else None
            else:
                d = None
                l = None

            if self.use_dialogue_moves:
                m = [x for x in self.dialogue_move_labels if x[0] == self.__iter_ts]
                m = m if m else None
            else:
                m = None

            if self.action_labels:
                a = [x for x in self.action_labels if (x.TickIndex // 30 + self.start_ts) >= self.__iter_ts]
                if a:
                    try:
                        while not a[0].items:
                            a = a[1:]
                        al = self.materials.index(a[0].items[0]) if a else 0
                    except Exception:
                        print(a)
                        print(a[0])
                        print(a[0].items)
                        print(a[0].items[0])
                        exit()
                    at = a[0].TickIndex // 30 + self.start_ts
                    an = int(a[0].Name[-1])
                    a = [(at, al, an)]
                else:
                    a = [(self.__iter_ts, self.materials.index(self.plan['materials'][0]), 1)]
                    a = None
            else:
                if self.end_ts - self.__iter_ts < 10:
                    # a = [(self.__iter_ts, self.materials.index(self.plan['materials'][0]), 1)]
                    a = None
                else:
                    a = None
            # if not self.__iter_ts % 30 == 0:
            #     a = None
            if a is not None:
                if not a[0][0] == self.__iter_ts:
                    a = None

            # q = [x for x in self.question_pairs if (x[0][0] < self.__iter_ts) and (x[0][1] > self.__iter_ts)]
            q = [x for x in self.question_pairs if (x[0][1] == self.__iter_ts)]
            q = q[0] if q else None
            frame_idx = self.__iter_ts - self.start_ts
            if self.load_third_person:
                frames = self.third_pers_frames
            elif self.load_player1:
                frames = self.player1_pov_frames
            elif self.load_player2:
                frames = self.player2_pov_frames
            else:
                frames = np.array([0])
            if len(frames) == 1:
                f = np.zeros((self.img_h, self.img_w, 3))
            else:
                if frame_idx < frames.shape[0]:
                    f = frames[frame_idx]
                else:
                    f = np.zeros((self.img_h, self.img_w, 3))
            if self.do_upperbound:
                # Upper-bound (oracle) mode: build the intermediate features
                # directly from ground-truth labels instead of saved predictions.
                if q is not None:
                    qnum = 0
                    base_rep = np.concatenate([
                        onehot(q[0][2], 2),
                        onehot(q[0][3], 2),
                        onehot(q[0][4][qnum][0] + 1, 2),
                        onehot(self.materials_dict[q[0][5][qnum][1]], len(self.materials_dict)),
                        onehot(q[0][4][qnum][0] + 1, 2),
                        onehot(self.materials_dict[q[0][5][qnum][1]], len(self.materials_dict)),
                        onehot(['YES', 'MAYBE', 'NO'].index(q[1][0][qnum]) + 1, 3),
                        onehot(['YES', 'MAYBE', 'NO'].index(q[1][1][qnum]) + 1, 3)
                    ])
                    base_rep = np.concatenate([base_rep, np.zeros(1024 - base_rep.shape[0])])
                    ToM6 = base_rep if self.ToM6 is not None else np.zeros(1024)
                    qnum = 1
                    base_rep = np.concatenate([
                        onehot(q[0][2], 2),
                        onehot(q[0][3], 2),
                        onehot(q[0][4][qnum][0] + 1, 2),
                        onehot(self.materials_dict[q[0][5][qnum][1]], len(self.materials_dict)),
                        onehot(q[0][4][qnum][0] + 1, 2),
                        onehot(self.materials_dict[q[0][5][qnum][1]], len(self.materials_dict)),
                        onehot(['YES', 'MAYBE', 'NO'].index(q[1][0][qnum]) + 1, 3),
                        onehot(['YES', 'MAYBE', 'NO'].index(q[1][1][qnum]) + 1, 3)
                    ])
                    base_rep = np.concatenate([base_rep, np.zeros(1024 - base_rep.shape[0])])
                    ToM7 = base_rep if self.ToM7 is not None else np.zeros(1024)
                    qnum = 2
                    base_rep = np.concatenate([
                        onehot(q[0][2], 2),
                        onehot(q[0][3], 2),
                        onehot(q[0][4][qnum] + 1, 2),
                        onehot(q[0][4][qnum] + 1, 2),
                        onehot(self.materials_dict[q[1][0][qnum]] if q[1][0][qnum] in self.materials_dict else len(self.materials_dict) + 1, len(self.materials_dict) + 1),
                        onehot(self.materials_dict[q[1][1][qnum]] if q[1][1][qnum] in self.materials_dict else len(self.materials_dict) + 1, len(self.materials_dict) + 1)
                    ])
                    base_rep = np.concatenate([base_rep, np.zeros(1024 - base_rep.shape[0])])
                    ToM8 = base_rep if self.ToM8 is not None else np.zeros(1024)
                else:
                    ToM6 = np.zeros(1024)
                    ToM7 = np.zeros(1024)
                    ToM8 = np.zeros(1024)
                if l is not None:
                    base_rep = np.concatenate([
                        onehot(l[0][1], 2),
                        onehot(l[0][2], len(self.dialogue_act_labels_dict))
                    ])
                    base_rep = np.concatenate([base_rep, np.zeros(1024 - base_rep.shape[0])])
                    DAct = base_rep if self.DAct is not None else np.zeros(1024)
                else:
                    DAct = np.zeros(1024)
                if m is not None:
                    base_rep = np.concatenate([
                        onehot(m[0][1], 2),
                        onehot(m[0][2][0], len(self.dialogue_move_labels_dict)),
                        onehot(m[0][2][1], len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict) + 1),
                        onehot(m[0][2][2], len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict) + 1),
                        onehot(m[0][2][3], len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict) + 1),
                    ])
                    base_rep = np.concatenate([base_rep, np.zeros(1024 - base_rep.shape[0])])
                    DMove = base_rep if self.DMove is not None else np.zeros(1024)
                else:
                    DMove = np.zeros(1024)
            else:
                ToM6 = self.ToM6[frame_idx] if self.ToM6 is not None else np.zeros(1024)
                ToM7 = self.ToM7[frame_idx] if self.ToM7 is not None else np.zeros(1024)
                ToM8 = self.ToM8[frame_idx] if self.ToM8 is not None else np.zeros(1024)
                DAct = self.DAct[frame_idx] if self.DAct is not None else np.zeros(1024)
                DMove = self.DMove[frame_idx] if self.DMove is not None else np.zeros(1024)
                # if not m is None:
                #     base_rep = np.concatenate([
                #         onehot(m[0][1], 2),
                #         onehot(m[0][2][0], len(self.dialogue_move_labels_dict)),
                #         onehot(m[0][2][1], len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict) + 1),
                #         onehot(m[0][2][2], len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict) + 1),
                #         onehot(m[0][2][3], len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict) + 1),
                #     ])
                #     base_rep = np.concatenate([base_rep, np.zeros(1024 - base_rep.shape[0])])
                #     DMove = base_rep if self.DMove is not None else np.zeros(1024)
                # else:
                #     DMove = np.zeros(1024)
            intermediate = np.concatenate([ToM6, ToM7, ToM8, DAct, DMove])
            retval = ((self.__iter_ts, self.pov), d, l, q, f, a, intermediate, m)
            self.__iter_ts += 1
            return retval
        self.__iter_ts = self.start_ts
        raise StopIteration()

    def __iter__(self):
        return self
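
    # Illustrative iteration (hypothetical game directory). Each step yields one
    # second of game time as
    # ((timestamp, pov), dialogue_events, act_labels, question_pair, frame,
    #  action, intermediate_features, move_labels):
    #     game = GameParser('data/main_logs/<game_id>', pov=1)
    #     for (ts, pov), d, l, q, f, a, inter, m in game:
    #         ...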

    def __load_videos(self):
        d = self.end_ts - self.start_ts

        if self.load_third_person:
            try:
                self.third_pers_file = glob(os.path.join(self.game_path, 'third*gif'))[0]
                np_file = self.third_pers_file[:-3] + 'npz'
                if os.path.isfile(np_file):
                    self.third_pers_frames = np.load(np_file)['data']
                else:
                    frames = imageio.get_reader(self.third_pers_file, '.gif')
                    reshaper = lambda x: cv2.resize(x, (self.img_h, self.img_w))
                    if 'main' in self.game_path:
                        self.third_pers_frames = np.array([reshaper(f[95:4*95, 250:-249, 2::-1]) for f in frames])
                    else:
                        self.third_pers_frames = np.array([reshaper(f[-3*95:, 250:-249, 2::-1]) for f in frames])
                    print(np_file, end=' ')
                    np.savez_compressed(open(np_file, 'wb'), data=self.third_pers_frames)
                    print('saved')
            except Exception:
                self.third_pers_frames = np.array([0])

            # Pick a subsampling rate from the observed frames per second of game
            # time, so roughly one frame per second survives.
            if self.third_pers_frames.shape[0] // d < 10:
                self.third_pov_frame_rate = 6
            elif self.third_pers_frames.shape[0] // d < 20:
                self.third_pov_frame_rate = 12
            elif self.third_pers_frames.shape[0] // d < 45:
                self.third_pov_frame_rate = 30
            else:
                self.third_pov_frame_rate = 60
            self.third_pers_frames = self.third_pers_frames[::self.third_pov_frame_rate]
        else:
            self.third_pers_frames = np.array([0])

        if self.load_player1:
            try:
                # The two POV recordings are swapped for some games; flip_video
                # (set from the join log in __init__) compensates.
                search_str = 'play2*gif' if self.flip_video else 'play1*gif'
                self.player1_pov_file = glob(os.path.join(self.game_path, search_str))[0]
                np_file = self.player1_pov_file[:-3] + 'npz'
                if os.path.isfile(np_file):
                    self.player1_pov_frames = np.load(np_file)['data']
                else:
                    frames = imageio.get_reader(self.player1_pov_file, '.gif')
                    reshaper = lambda x: cv2.resize(x, (self.img_h, self.img_w))
                    self.player1_pov_frames = np.array([reshaper(f[:, :, 2::-1]) for f in frames])
                    print(np_file, end=' ')
                    np.savez_compressed(open(np_file, 'wb'), data=self.player1_pov_frames)
                    print('saved')
            except Exception:
                self.player1_pov_frames = np.array([0])

            if self.player1_pov_frames.shape[0] // d < 10:
                self.player1_pov_frame_rate = 6
            elif self.player1_pov_frames.shape[0] // d < 20:
                self.player1_pov_frame_rate = 12
            elif self.player1_pov_frames.shape[0] // d < 45:
                self.player1_pov_frame_rate = 30
            else:
                self.player1_pov_frame_rate = 60
            self.player1_pov_frames = self.player1_pov_frames[::self.player1_pov_frame_rate]
        else:
            self.player1_pov_frames = np.array([0])

        if self.load_player2:
            try:
                search_str = 'play1*gif' if self.flip_video else 'play2*gif'
                self.player2_pov_file = glob(os.path.join(self.game_path, search_str))[0]
                np_file = self.player2_pov_file[:-3] + 'npz'
                if os.path.isfile(np_file):
                    self.player2_pov_frames = np.load(np_file)['data']
                else:
                    frames = imageio.get_reader(self.player2_pov_file, '.gif')
                    reshaper = lambda x: cv2.resize(x, (self.img_h, self.img_w))
                    self.player2_pov_frames = np.array([reshaper(f[:, :, 2::-1]) for f in frames])
                    print(np_file, end=' ')
                    np.savez_compressed(open(np_file, 'wb'), data=self.player2_pov_frames)
                    print('saved')
            except Exception:
                self.player2_pov_frames = np.array([0])

            if self.player2_pov_frames.shape[0] // d < 10:
                self.player2_pov_frame_rate = 6
            elif self.player2_pov_frames.shape[0] // d < 20:
                self.player2_pov_frame_rate = 12
            elif self.player2_pov_frames.shape[0] // d < 45:
                self.player2_pov_frame_rate = 30
            else:
                self.player2_pov_frame_rate = 60
            self.player2_pov_frames = self.player2_pov_frames[::self.player2_pov_frame_rate]
        else:
            self.player2_pov_frames = np.array([0])

    def __parse_question_pairs(self):
        question_dict = {}
        for q in self.questions:
            k = q[2][0][1] + q[2][1][1]
            if k not in question_dict:
                question_dict[k] = []
            question_dict[k].append(q)

        self.question_pairs = []
        for k, v in question_dict.items():
            if len(v) == 2:
                # Player ids are 1 and 2, so a valid pair has one answer from each.
                if v[0][1] + v[1][1] == 3:
                    self.question_pairs.append(v)
            else:
                while len(v) > 1:
                    pair = []
                    pair.append(v.pop(0))
                    pair.append(v.pop(0))
                    while not pair[0][1] + pair[1][1] == 3:
                        if not v:
                            break
                        # print(game_path, pair)
                        pair.append(v.pop(0))
                        pair.pop(0)
                    if not v:
                        break
                    self.question_pairs.append(pair)
        self.question_pairs = sorted(self.question_pairs, key=lambda x: x[0][0])
        if self.load_player2 or self.pov == 4:
            self.question_pairs = [sorted(q, key=lambda x: x[1], reverse=True) for q in self.question_pairs]
        else:
            self.question_pairs = [sorted(q, key=lambda x: x[1]) for q in self.question_pairs]

        self.question_pairs = [((a[0], b[0], a[1], b[1], a[2], b[2]), (a[3], b[3])) for a, b in self.question_pairs]
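        # Resulting layout per pair:
        #   ((ts_a, ts_b, player_a, player_b, questions_a, questions_b),
        #    (answers_a, answers_b))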

    def __parse_dialogue(self):
        self.dialogue_events = []
        # if not self.load_dialogue:
        #     return
        save_path = os.path.join(self.game_path, f'dialogue_{self.game_path.split("/")[-1]}.pkl')
        # print(save_path)
        # exit()
        if os.path.isfile(save_path):
            self.dialogue_events = pickle.load(open(save_path, "rb"))
            return
        for x in open(self.dialogue_file):
            if '[Async Chat Thread' in x:
                ts = list(map(int, x.split(' [')[0].strip('[]').split(':')))
                ts = 3600 * ts[0] + 60 * ts[1] + ts[2]
                player, event = x.strip().split('/INFO]: []<sledmcc')[1].split('> ', 1)
                event = event.lower()
                # Put spaces around every non-letter character, then collapse runs of spaces.
                event = ''.join([x if x in string.ascii_lowercase else f' {x} ' for x in event]).strip()
                event = event.replace('  ', ' ').replace('  ', ' ')
                player = int(player)
                if GameParser.tokenizer is None:
                    GameParser.tokenizer = BertTokenizer.from_pretrained('bert-large-uncased', do_lower_case=True)
                if self.model is None:
                    GameParser.model = BertModel.from_pretrained('bert-large-uncased', output_hidden_states=True).to(DEVICE)
                encoded_dict = GameParser.tokenizer.encode_plus(
                    event,                    # Sentence to encode.
                    add_special_tokens=True,  # Add '[CLS]' and '[SEP]'
                    return_tensors='pt',      # Return pytorch tensors.
                )
                token_ids = encoded_dict['input_ids'].to(DEVICE)
                segment_ids = torch.ones(token_ids.size()).long().to(DEVICE)
                GameParser.model.eval()
                with torch.no_grad():
                    outputs = GameParser.model(input_ids=token_ids, token_type_ids=segment_ids)
                    outputs = outputs[1][0].cpu().data.numpy()
                self.dialogue_events.append((ts, player, event, outputs))
        pickle.dump(self.dialogue_events, open(save_path, "wb"))
        print(f'Saved to {save_path}', flush=True)
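        # Each cached event is (timestamp_seconds, player_id, normalized_text,
        # embedding); the embedding appears to be BERT-large's pooled utterance
        # output (1024-d), taken from outputs[1] above.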

    def __parse_questions(self):
        self.questions = []
        for x in open(self.questions_file):
            if x[0] == '#':
                ts, qs = x.strip().split(' Number of records inserted: 1 # player')
                # print(ts, qs)

                ts = list(map(int, ts.split(' ')[5].split(':')))
                ts = 3600 * ts[0] + 60 * ts[1] + ts[2]

                player = int(qs[0])
                questions = qs[2:].split(';')
                answers = [x[7:] for x in questions[3:]]
                questions = [x[9:].split(' ') for x in questions[:3]]
                questions[0] = (int(questions[0][0] == 'Have'), questions[0][-3])
                questions[1] = (int(questions[1][2] == 'know'), questions[1][-1])
                questions[2] = int(questions[2][1] == 'are')

                self.questions.append((ts, player, questions, answers))

    def __parse_start_end(self):
        self.start_ts = [x.strip() for x in open(self.dialogue_file) if 'THEY ARE PLAYER' in x][1]
        self.start_ts = list(map(int, self.start_ts.split('] [')[0][1:].split(':')))
        self.start_ts = 3600 * self.start_ts[0] + 60 * self.start_ts[1] + self.start_ts[2]
        try:
            # Clamp the start to no earlier than 75 s before the first question.
            self.start_ts = max(self.start_ts, self.questions[0][0] - 75)
        except Exception:
            pass

        self.end_ts = [x.strip() for x in open(self.dialogue_file) if 'Stopping' in x]
        if self.end_ts:
            self.end_ts = self.end_ts[0]
            self.end_ts = list(map(int, self.end_ts.split('] [')[0][1:].split(':')))
            self.end_ts = 3600 * self.end_ts[0] + 60 * self.end_ts[1] + self.end_ts[2]
        else:
            self.end_ts = self.dialogue_events[-1][0]
        try:
            self.end_ts = max(self.end_ts, self.questions[-1][0]) + 1
        except Exception:
            pass
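        # Example: a log stamp '[00:12:34] [...' parses to
        # 0*3600 + 12*60 + 34 = 754 seconds.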

    def __load_dialogue_act_labels(self):
        file_name = 'config/dialogue_act_labels.json'
        if not os.path.isfile(file_name):
            files = sorted(glob('/home/*/MCC/*done.txt'))
            dialogue_act_dict = {}
            for file in files:
                game_str = ''
                for line in open(file):
                    line = line.strip()
                    if '_logs/' in line:
                        game_str = line
                    else:
                        if line:
                            line = line.split()
                            key = f'{game_str}#{line[0]}'
                            dialogue_act_dict[key] = line[-1]
            json.dump(dialogue_act_dict, open(file_name, 'w'), indent=4)
        self.dialogue_act_dict = json.load(open(file_name))
        self.dialogue_act_labels_dict = {l: i for i, l in enumerate(sorted(list(set(self.dialogue_act_dict.values()))))}
        self.dialogue_act_bias = {l: sum([int(x == l) for x in self.dialogue_act_dict.values()]) for l in self.dialogue_act_labels_dict.keys()}
        json.dump(self.dialogue_act_labels_dict, open('config/dialogue_act_label_names.json', 'w'), indent=4)
        # print(self.dialogue_act_bias)
        # print(self.dialogue_act_labels_dict)
        # exit()

    def __assign_dialogue_act_labels(self):
        # log_file = glob('/'.join([self.game_path, 'mcc*log']))[0][5:]
        log_file = glob('/'.join([self.game_path, 'mcc*log']))[0].split('mindcraft/')[1]
        self.dialogue_act_labels = []
        for emb in self.dialogue_events:
            ts = emb[0]
            h = ts // 3600
            m = (ts % 3600) // 60
            s = ts % 60
            key = f'{log_file}#[{h:02d}:{m:02d}:{s:02d}]:{emb[1]}>'
            self.dialogue_act_labels.append((emb[0], emb[1], self.dialogue_act_labels_dict[self.dialogue_act_dict[key]]))

    def __load_dialogue_move_labels(self):
        file_name = "config/dialogue_move_labels.json"
        dialogue_move_dict = {}
        if not os.path.isfile(file_name):
            file_text = ''
            dialogue_moves = set()
            for line in open("XXX"):  # "XXX": annotation-file path placeholder
                line = line.strip()
                if not line:
                    continue
                if line[0] == '#':
                    continue
                if line[0] == '[':
                    tag_text = glob(f'data/*/*/mcc_{file_text}.log')[0].split('/', 1)[-1]
                    key = f'{tag_text}#{line.split()[0]}'
                    value = line.split()[-1].split('#')
                    if len(value) < 4:
                        # Pad missing argument slots so every move has exactly four fields.
                        value += ['IGNORE'] * (4 - len(value))
                    dialogue_moves.add(value[0])
                    value = '#'.join(value)
                    dialogue_move_dict[key] = value
                    # print(key, value)
                    # break
                else:
                    file_text = line
                    # print(line)
            dialogue_moves = sorted(list(dialogue_moves))
            # print(dialogue_moves)

            json.dump(dialogue_move_dict, open(file_name, 'w'), indent=4)
        self.dialogue_move_dict = json.load(open(file_name))
        self.dialogue_move_labels_dict = {l: i for i, l in enumerate(sorted(list(set([lbl.split('#')[0] for lbl in self.dialogue_move_dict.values()]))))}
        self.dialogue_move_bias = {l: sum([int(x == l) for x in self.dialogue_move_dict.values()]) for l in self.dialogue_move_labels_dict.keys()}
        json.dump(self.dialogue_move_labels_dict, open('config/dialogue_move_label_names.json', 'w'), indent=4)

    def __assign_dialogue_move_labels(self):
        # log_file = glob('/'.join([self.game_path, 'mcc*log']))[0][5:]
        log_file = glob('/'.join([self.game_path, 'mcc*log']))[0].split('mindcraft/')[1]
        self.dialogue_move_labels = []
        for emb in self.dialogue_events:
            ts = emb[0]
            h = ts // 3600
            m = (ts % 3600) // 60
            s = ts % 60
            key = f'{log_file}#[{h:02d}:{m:02d}:{s:02d}]:{emb[1]}>'
            move = self.dialogue_move_dict[key].split('#')
            move[0] = self.dialogue_move_labels_dict[move[0]]
            # Map each argument into a shared index space: 0 = IGNORE, then
            # materials, then mines, then tools.
            for i, m in enumerate(move[1:]):
                if m == 'IGNORE':
                    move[i + 1] = 0
                elif m in self.materials_dict:
                    move[i + 1] = self.materials_dict[m]
                elif m in self.mines_dict:
                    move[i + 1] = self.mines_dict[m] + len(self.materials_dict)
                elif m in self.tools_dict:
                    move[i + 1] = self.tools_dict[m] + len(self.materials_dict) + len(self.mines_dict)
                else:
                    print(move)
                    exit()
            # print(move, self.dialogue_move_dict[key], key)
            # exit()
            self.dialogue_move_labels.append((emb[0], emb[1], move))
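        # Example: an annotated move 'REQUEST#COBBLESTONE#IGNORE#IGNORE' (labels
        # assumed for illustration) becomes [move_id, arg_index, 0, 0], with
        # arg_index taken from the shared materials/mines/tools space above.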

    def __load_replay_data(self):
        # self.action_file = "data/ReplayData/ActionsData_mcc_" + self.game_path.split('/')[-1]
        # with open(self.action_file) as f:
        #     data = ' '.join(x.strip() for x in f).split('action')
        #     # preface = data[0]
        #     self.actions = list(map(proc_action, data[1:]))
        self.actions = None

    def __load_intermediate(self):
        # 'intermediate' is read as a bit mask selecting which predicted feature
        # streams (ToM6, ToM7, ToM8, DAct) to load; values above 15 switch to the
        # ground-truth upper bound instead.
        if self.intermediate > 15:
            self.do_upperbound = True
        else:
            self.do_upperbound = False
        if self.pov in [1, 2]:
            self.ToM6 = np.load(glob(f'{self.game_path}/intermediate_baseline_ToM6*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
            self.intermediate = self.intermediate // 2
            self.ToM7 = np.load(glob(f'{self.game_path}/intermediate_baseline_ToM7*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
            self.intermediate = self.intermediate // 2
            self.ToM8 = np.load(glob(f'{self.game_path}/intermediate_baseline_ToM8*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
            self.intermediate = self.intermediate // 2
            self.DAct = np.load(glob(f'{self.game_path}/intermediate_baseline_DAct*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
            self.intermediate = self.intermediate // 2
            self.DMove = None
            # print(self.ToM6)
            # print(self.ToM7)
            # print(self.ToM8)
            # print(self.DAct)
        else:
            self.ToM6 = None
            self.ToM7 = None
            self.ToM8 = None
            self.DAct = None
            self.DMove = None
        # exit()
837
src/data/game_parser_graphs_new.py
Normal file
@@ -0,0 +1,837 @@
from glob import glob
import os, string, json, pickle
import torch, random, numpy as np
from transformers import BertTokenizer, BertModel
import cv2
import imageio
import networkx as nx
from torch_geometric.utils.convert import from_networkx
import matplotlib.pyplot as plt
from torch_geometric.utils import degree


DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# def set_seed(seed_idx):
#     seed = 0
#     random.seed(0)
#     for _ in range(seed_idx):
#         seed = random.random()
#     random.seed(seed)
#     torch.manual_seed(seed)
#     print('Random seed set to', seed)
#     return seed

def set_seed(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    print('Random seed set to', seed)
    return seed

def make_splits(split_file='config/dataset_splits.json'):
    if not os.path.isfile(split_file):
        dirs = sorted(glob('data/saved_logs/*') + glob('data/main_logs/*'))
        games = sorted(list(map(GameParser, dirs)), key=lambda x: len(x.question_pairs), reverse=True)

        test = games[0::5]
        val = games[1::5]
        train = games[2::5] + games[3::5] + games[4::5]

        dataset_splits = {'test': [g.game_path for g in test], 'validation': [g.game_path for g in val], 'training': [g.game_path for g in train]}
        json.dump(dataset_splits, open('config/dataset_splits_old.json', 'w'), indent=4)

        dirs = sorted(glob('data/new_logs/*'))
        games = sorted(list(map(GameParser, dirs)), key=lambda x: len(x.question_pairs), reverse=True)

        test = games[0::5]
        val = games[1::5]
        train = games[2::5] + games[3::5] + games[4::5]

        dataset_splits['test'] += [g.game_path for g in test]
        dataset_splits['validation'] += [g.game_path for g in val]
        dataset_splits['training'] += [g.game_path for g in train]
        json.dump(dataset_splits, open('config/dataset_splits_new.json', 'w'), indent=4)
        json.dump(dataset_splits, open('config/dataset_splits.json', 'w'), indent=4)

        dataset_splits['test'] = dataset_splits['test'][:2]
        dataset_splits['validation'] = dataset_splits['validation'][:2]
        dataset_splits['training'] = dataset_splits['training'][:2]
        json.dump(dataset_splits, open('config/dataset_splits_dev.json', 'w'), indent=4)

    dataset_splits = json.load(open(split_file))

    return dataset_splits

def onehot(x, n):
    retval = np.zeros(n)
    if x > 0:
        retval[x - 1] = 1
    return retval

class GameParser:
    tokenizer = None
    model = None

    def __init__(self, game_path, load_dialogue=True, pov=0, intermediate=0, use_dialogue_moves=False, load_int0_feats=False):
        self.load_dialogue = load_dialogue
        if pov not in (0, 1, 2, 3, 4):
            print('Point of view must be in (0,1,2,3,4), but got', pov)
            exit()
        self.pov = pov
        self.use_dialogue_moves = use_dialogue_moves
        self.load_player1 = pov == 1
        self.load_player2 = pov == 2
        self.load_third_person = pov == 3
        self.game_path = game_path
        self.dialogue_file = glob(os.path.join(game_path, 'mcc*log'))[0]
        self.questions_file = glob(os.path.join(game_path, 'web*log'))[0]
        self.plan_file = glob(os.path.join(game_path, 'plan*json'))[0]
        self.plan = json.load(open(self.plan_file))
        self.img_w = 96
        self.img_h = 96
        self.intermediate = intermediate

        self.flip_video = False
        for l in open(self.dialogue_file):
            if 'HAS JOINED' in l:
                player_name = l.strip().split()[1]
                self.flip_video = player_name[-1] == '2'
                break

        if not os.path.isfile("config/materials.json") or \
                not os.path.isfile("config/mines.json") or \
                not os.path.isfile("config/tools.json"):
            plan_files = sorted(glob('data/*_logs/*/plan*.json'))
            materials = []
            tools = []
            mines = []
            for plan_file in plan_files:
                plan = json.load(open(plan_file))
                materials += plan['materials']
                tools += plan['tools']
                mines += plan['mines']
            materials = sorted(list(set(materials)))
            tools = sorted(list(set(tools)))
            mines = sorted(list(set(mines)))
            json.dump(materials, open('config/materials.json', 'w'), indent=4)
            json.dump(mines, open('config/mines.json', 'w'), indent=4)
            json.dump(tools, open('config/tools.json', 'w'), indent=4)

        materials = json.load(open('config/materials.json'))
        mines = json.load(open('config/mines.json'))
        tools = json.load(open('config/tools.json'))

        self.materials_dict = {x: i + 1 for i, x in enumerate(materials)}
        self.mines_dict = {x: i + 1 for i, x in enumerate(mines)}
        self.tools_dict = {x: i + 1 for i, x in enumerate(tools)}

        # NOTE new: materials and mines share one index space; mine indices are
        # shifted past the last material index.
        shift_value = max(self.materials_dict.values())
        self.materials_mines_dict = {**self.materials_dict, **{key: value + shift_value for key, value in self.mines_dict.items()}}
        self.inverse_materials_mines_dict = {v: k for k, v in self.materials_mines_dict.items()}
        #

        self.__load_dialogue_act_labels()
        self.__load_dialogue_move_labels()
        self.__parse_dialogue()
        self.__parse_questions()
        self.__parse_start_end()
        self.__parse_question_pairs()
        self.__load_videos()
        self.__assign_dialogue_act_labels()
        self.__assign_dialogue_move_labels()
        self.__load_replay_data()
        self.__load_intermediate()
        self.load_int0 = load_int0_feats
        if load_int0_feats:
            self.__load_int0_feats()

        #############################################
        ################## GRAPHS ###################

        #############################################
        ################ Global Plan ################
        self.global_plan = nx.DiGraph()
        mine_counter = 0
        for n, v in zip(self.plan['materials'], self.plan['full']):
            mat = onehot(self.materials_mines_dict[n], len(self.materials_mines_dict))
            t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
            self.global_plan = self._add_node(self.global_plan, n, features=mat)
            if v['make']:
                # print(n, v, self.plan['materials'][v['make'][0][0]], self.plan['materials'][v['make'][0][1]])
                m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
                m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
                m1 = onehot(m1, len(self.materials_mines_dict))
                m2 = onehot(m2, len(self.materials_mines_dict))
                self.global_plan = self._add_node(self.global_plan, self.plan['materials'][v['make'][0][0]], features=m1)
                self.global_plan = self._add_node(self.global_plan, self.plan['materials'][v['make'][0][1]], features=m2)
                # m1 -> mat
                self.global_plan = self._add_edge(self.global_plan, self.plan['materials'][v['make'][0][0]], n, tool=t)
                # m2 -> mat
                self.global_plan = self._add_edge(self.global_plan, self.plan['materials'][v['make'][0][1]], n, tool=t)
            else:
                # print(n, v, self.plan['mines'][mine_counter])
                # Read the mine name before advancing the counter so the node
                # keeps the same name its features encode.
                mine_name = self.plan['mines'][mine_counter]
                mine_counter += 1
                mine = onehot(self.materials_mines_dict[mine_name], len(self.materials_mines_dict))
                self.global_plan = self._add_node(self.global_plan, mine_name, features=mine)
                self.global_plan = self._add_edge(self.global_plan, mine_name, n, tool=t)
        # self._plot_plan_graph(self.global_plan, filename=f"plots/global_{game_path.split('/')[-2]}_{game_path.split('/')[-1]}.png")
        self.global_plan = from_networkx(self.global_plan)  # NOTE: I modified /torch_geometric/utils/convert.py, line 250
        #############################################

        #############################################
        ############### Player 1 Plan ###############
        self.player1_plan = nx.DiGraph()
        mine_counter = 0
        for n, v in zip(self.plan['materials'], self.plan['player1']):
            if v['make']:
                if v['make'][0][0] < 0:
                    # print(n, v, "unknown", "unknown")
                    # Recipe unknown to this player: add the product node only.
                    mat = onehot(self.materials_mines_dict[n], len(self.materials_mines_dict))
                    self.player1_plan = self._add_node(self.player1_plan, n, features=mat)
                else:
                    # print(n, v, self.plan['materials'][v['make'][0][0]], self.plan['materials'][v['make'][0][1]])
                    mat = onehot(self.materials_mines_dict[n], len(self.materials_mines_dict))
                    t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
                    self.player1_plan = self._add_node(self.player1_plan, n, features=mat)
                    m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
                    m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
                    m1 = onehot(m1, len(self.materials_mines_dict))
                    m2 = onehot(m2, len(self.materials_mines_dict))
                    self.player1_plan = self._add_node(self.player1_plan, self.plan['materials'][v['make'][0][0]], features=m1)
                    self.player1_plan = self._add_node(self.player1_plan, self.plan['materials'][v['make'][0][1]], features=m2)
                    # m1 -> mat
                    self.player1_plan = self._add_edge(self.player1_plan, self.plan['materials'][v['make'][0][0]], n, tool=t)
                    # m2 -> mat
                    self.player1_plan = self._add_edge(self.player1_plan, self.plan['materials'][v['make'][0][1]], n, tool=t)
            else:
                # print(n, v, self.plan['mines'][mine_counter])
                mat = onehot(self.materials_mines_dict[n], len(self.materials_mines_dict))
                t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
                self.player1_plan = self._add_node(self.player1_plan, n, features=mat)
                mine_name = self.plan['mines'][mine_counter]
                mine_counter += 1
                mine = onehot(self.materials_mines_dict[mine_name], len(self.materials_mines_dict))
                self.player1_plan = self._add_node(self.player1_plan, mine_name, features=mine)
                self.player1_plan = self._add_edge(self.player1_plan, mine_name, n, tool=t)
        # self._plot_plan_graph(self.player1_plan, filename=f"plots/player1_{game_path.split('/')[-2]}_{game_path.split('/')[-1]}.png")
        self.player1_plan = from_networkx(self.player1_plan)
        #############################################

        #############################################
        ############### Player 2 Plan ###############
        self.player2_plan = nx.DiGraph()
        mine_counter = 0
        for n, v in zip(self.plan['materials'], self.plan['player2']):
            if v['make']:
                if v['make'][0][0] < 0:
                    # print(n, v, "unknown", "unknown")
                    mat = onehot(self.materials_mines_dict[n], len(self.materials_mines_dict))
                    self.player2_plan = self._add_node(self.player2_plan, n, features=mat)
                else:
                    # print(n, v, self.plan['materials'][v['make'][0][0]], self.plan['materials'][v['make'][0][1]])
                    mat = onehot(self.materials_mines_dict[n], len(self.materials_mines_dict))
                    t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
                    self.player2_plan = self._add_node(self.player2_plan, n, features=mat)
                    m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
                    m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
                    m1 = onehot(m1, len(self.materials_mines_dict))
                    m2 = onehot(m2, len(self.materials_mines_dict))
                    self.player2_plan = self._add_node(self.player2_plan, self.plan['materials'][v['make'][0][0]], features=m1)
                    self.player2_plan = self._add_node(self.player2_plan, self.plan['materials'][v['make'][0][1]], features=m2)
                    # m1 -> mat
                    self.player2_plan = self._add_edge(self.player2_plan, self.plan['materials'][v['make'][0][0]], n, tool=t)
                    # m2 -> mat
                    self.player2_plan = self._add_edge(self.player2_plan, self.plan['materials'][v['make'][0][1]], n, tool=t)
            else:
                # print(n, v, self.plan['mines'][mine_counter])
                mat = onehot(self.materials_mines_dict[n], len(self.materials_mines_dict))
                t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
                self.player2_plan = self._add_node(self.player2_plan, n, features=mat)
                mine_name = self.plan['mines'][mine_counter]
                mine_counter += 1
                mine = onehot(self.materials_mines_dict[mine_name], len(self.materials_mines_dict))
                self.player2_plan = self._add_node(self.player2_plan, mine_name, features=mine)
                self.player2_plan = self._add_edge(self.player2_plan, mine_name, n, tool=t)
        # self._plot_plan_graph(self.player2_plan, filename=f"plots/player2_{game_path.split('/')[-2]}_{game_path.split('/')[-1]}.png")
        self.player2_plan = from_networkx(self.player2_plan)

        # with open('graphs.pkl', 'wb') as f:
        #     pickle.dump([self.global_plan, self.player1_plan, self.player2_plan], f)

        # construct a dict mapping materials to node indexes for each graph
        p1_dict = {self.inverse_materials_mines_dict[torch.argmax(features).item() + 1]: node_index for node_index, features in enumerate(self.player1_plan.features)}
        p2_dict = {self.inverse_materials_mines_dict[torch.argmax(features).item() + 1]: node_index for node_index, features in enumerate(self.player2_plan.features)}
        # candidate edge = (u, v):
        # u is drawn from nodes with spare out-degree, v from nodes with no in-degree
        p1_u_candidates = [p1_dict[i] for i in self.find_nodes_with_less_than_four_out_degree(self.player1_plan)]
        p1_v_candidates = [p1_dict[i] for i in self.find_nodes_with_no_in_degree(self.player1_plan)]
        p2_u_candidates = [p2_dict[i] for i in self.find_nodes_with_less_than_four_out_degree(self.player2_plan)]
        p2_v_candidates = [p2_dict[i] for i in self.find_nodes_with_no_in_degree(self.player2_plan)]
        # convert candidates to index pairs
        p1_edge_candidates = torch.tensor([(start, end) for start in p1_u_candidates for end in p1_v_candidates])
        p2_edge_candidates = torch.tensor([(start, end) for start in p2_u_candidates for end in p2_v_candidates])
        # find missing edges: global-plan edges absent from each player's plan
        gl_edges = [[self.inverse_materials_mines_dict[torch.argmax(self.global_plan.features[edge[0]]).item() + 1], self.inverse_materials_mines_dict[torch.argmax(self.global_plan.features[edge[1]]).item() + 1]] for edge in self.global_plan.edge_index.t().tolist()]
        p1_edges = [[self.inverse_materials_mines_dict[torch.argmax(self.player1_plan.features[edge[0]]).item() + 1], self.inverse_materials_mines_dict[torch.argmax(self.player1_plan.features[edge[1]]).item() + 1]] for edge in self.player1_plan.edge_index.t().tolist()]
        p2_edges = [[self.inverse_materials_mines_dict[torch.argmax(self.player2_plan.features[edge[0]]).item() + 1], self.inverse_materials_mines_dict[torch.argmax(self.player2_plan.features[edge[1]]).item() + 1]] for edge in self.player2_plan.edge_index.t().tolist()]
        p1_missing_edges = [list(sublist) for sublist in set(map(tuple, gl_edges)) - set(map(tuple, p1_edges))]
        p2_missing_edges = [list(sublist) for sublist in set(map(tuple, gl_edges)) - set(map(tuple, p2_edges))]
        # convert missing edges to index pairs
        p1_missing_edges_idx = torch.tensor([(p1_dict[e[0]], p1_dict[e[1]]) for e in p1_missing_edges])
        p2_missing_edges_idx = torch.tensor([(p2_dict[e[0]], p2_dict[e[1]]) for e in p2_missing_edges])
        # check that all missing edges are present in the candidates
        assert all(any(torch.equal(element, row) for row in p1_edge_candidates) for element in p1_missing_edges_idx)
        assert all(any(torch.equal(element, row) for row in p2_edge_candidates) for element in p2_missing_edges_idx)
        # concat candidates to plan graph
        if p1_edge_candidates.numel() != 0:
            self.player1_edge_label_index = torch.cat([self.player1_plan.edge_index, p1_edge_candidates.permute(1, 0)], dim=-1)
            # create labels
            self.player1_edge_label_own_missing_knowledge = torch.cat((torch.ones(self.player1_plan.edge_index.shape[1]), torch.zeros(p1_edge_candidates.shape[0])))
        else:
            # no missing knowledge
            self.player1_edge_label_index = self.player1_plan.edge_index
            # create labels
            self.player1_edge_label_own_missing_knowledge = torch.ones(self.player1_plan.edge_index.shape[1])
        if p2_edge_candidates.numel() != 0:
            self.player2_edge_label_index = torch.cat([self.player2_plan.edge_index, p2_edge_candidates.permute(1, 0)], dim=-1)
            # create labels
            self.player2_edge_label_own_missing_knowledge = torch.cat((torch.ones(self.player2_plan.edge_index.shape[1]), torch.zeros(p2_edge_candidates.shape[0])))
        else:
            # no missing knowledge
            self.player2_edge_label_index = self.player2_plan.edge_index
            # create labels
            self.player2_edge_label_own_missing_knowledge = torch.ones(self.player2_plan.edge_index.shape[1])
        p1_edge_list = [tuple(x) for x in self.player1_edge_label_index.T.tolist()]
        p1_missing_edges_idx_list = [tuple(x) for x in p1_missing_edges_idx.tolist()]
        self.player1_edge_label_own_missing_knowledge[[p1_edge_list.index(x) for x in p1_missing_edges_idx_list]] = 1.
        p2_edge_list = [tuple(x) for x in self.player2_edge_label_index.T.tolist()]
        p2_missing_edges_idx_list = [tuple(x) for x in p2_missing_edges_idx.tolist()]
        self.player2_edge_label_own_missing_knowledge[[p2_edge_list.index(x) for x in p2_missing_edges_idx_list]] = 1.
        # compute other's missing knowledge == identify which of my edges are unknown to the other player
        p1_original_edges_list = [tuple(x) for x in self.player1_plan.edge_index.T.tolist()]
        p2_original_edges_list = [tuple(x) for x in self.player2_plan.edge_index.T.tolist()]
        p1_other_missing_edges_idx = [(p1_dict[e[0]], p1_dict[e[1]]) for e in p2_missing_edges]  # note: here it is p2_missing_edges
        p2_other_missing_edges_idx = [(p2_dict[e[0]], p2_dict[e[1]]) for e in p1_missing_edges]  # note: here it is p1_missing_edges
        self.player1_edge_label_other_missing_knowledge = torch.zeros(self.player1_plan.edge_index.shape[1])
        self.player1_edge_label_other_missing_knowledge[[p1_original_edges_list.index(x) for x in p1_other_missing_edges_idx]] = 1.
        self.player2_edge_label_other_missing_knowledge = torch.zeros(self.player2_plan.edge_index.shape[1])
        self.player2_edge_label_other_missing_knowledge[[p2_original_edges_list.index(x) for x in p2_other_missing_edges_idx]] = 1.

        self.__iter_ts = self.start_ts

        self.action_labels = None
        self.materials = sorted(self.plan['materials'])
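
        # Supervision targets built above, for reference (E_N = player N's known
        # edges, C_N = candidate edges):
        #   playerN_edge_label_index: [2, E_N + C_N]
        #   playerN_edge_label_own_missing_knowledge: length E_N + C_N; 1.0 on
        #       known and truly-missing edges, 0.0 on the remaining candidates
        #   playerN_edge_label_other_missing_knowledge: length E_N; 1.0 where
        #       the other player lacks that edge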

    def _add_node(self, g, material, features):
        # Idempotent: only add the node if it is not already in the graph.
        if material not in g.nodes:
            # print(f'Add node {material}')
            g.add_node(material, features=features)
        return g

    def _add_edge(self, g, u, v, tool):
        # Idempotent: only add the edge (u, v) if it is not already present.
        if not g.has_edge(u, v):
            # print(f'Add edge ({u}, {v})')
            g.add_edge(u, v, tool=tool)
        return g
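
    # Minimal standalone sketch of the pattern above (hypothetical node names
    # and feature sizes, for illustration only):
    #     g = nx.DiGraph()
    #     g.add_node('OAK_PLANKS', features=onehot(1, 27))
    #     g.add_node('STICK', features=onehot(2, 27))
    #     g.add_edge('OAK_PLANKS', 'STICK', tool=onehot(1, 5))  # ingredient -> product
    #     data = from_networkx(g)  # node attr -> data.features, edge attr -> data.tool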

    def _plot_plan_graph(self, g, filename):
        plt.figure(figsize=(20, 20))
        pos = nx.spring_layout(g, seed=42)
        nx.draw(g, pos, with_labels=True, node_color='lightblue', edge_color='gray')
        plt.savefig(filename)
        plt.close()

    def find_nodes_with_less_than_four_out_degree(self, data):
        edge_index = data.edge_index
        num_nodes = data.num_nodes
        degrees = degree(edge_index[0], num_nodes)  # out-degrees
        # find all nodes with out-degree less than 4
        nodes = torch.nonzero(degrees < 4).view(-1)
        nodes = [self.inverse_materials_mines_dict[torch.argmax(data.features[i]).item() + 1] for i in nodes]
        # remove planks (because all planks have small out-degree anyway)
        nodes = [n for n in nodes if n.split('_')[-1] != 'PLANKS']
        # now re-admit only those planks that have out-degree 0
        check_zero_out_degree_planks = torch.nonzero(degrees < 1).view(-1)
        check_zero_out_degree_planks = [self.inverse_materials_mines_dict[torch.argmax(data.features[i]).item() + 1] for i in check_zero_out_degree_planks]
        check_zero_out_degree_planks = [n for n in check_zero_out_degree_planks if n.split('_')[-1] == 'PLANKS']
        nodes = nodes + check_zero_out_degree_planks
        return nodes
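
    # Example of the degree() call used above: for edge_index[0] = [0, 0, 1]
    # over 3 nodes, degree(...) returns tensor([2., 1., 0.]), so node 2 is the
    # only one with out-degree < 1.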

    def find_nodes_with_no_in_degree(self, data):
        edge_index = data.edge_index
        num_nodes = data.num_nodes
        degrees = degree(edge_index[1], num_nodes)  # in-degrees
        nodes = torch.nonzero(degrees < 1).view(-1)
        nodes = [self.inverse_materials_mines_dict[torch.argmax(data.features[i]).item() + 1] for i in nodes]
        nodes = [n for n in nodes if n.split('_')[-1] != 'PLANKS']
        return nodes
|
||||
|
||||
def __len__(self):
|
||||
return self.end_ts - self.start_ts
|
||||
|
||||
def __next__(self):
|
||||
if self.__iter_ts < self.end_ts:
|
||||
if self.load_dialogue:
|
||||
            d = [x for x in self.dialogue_events if x[0] == self.__iter_ts]
            l = [x for x in self.dialogue_act_labels if x[0] == self.__iter_ts]
            d = d if d else None
            l = l if l else None
        else:
            d = None
            l = None
        if self.use_dialogue_moves:
            m = [x for x in self.dialogue_move_labels if x[0] == self.__iter_ts]
            m = m if m else None
        else:
            m = None
        if self.action_labels:
            a = [x for x in self.action_labels if (x.TickIndex//30 + self.start_ts) >= self.__iter_ts]
            if a:
                try:
                    # skip action records that carry no items
                    while not a[0].items:
                        a = a[1:]
                    al = self.materials.index(a[0].items[0]) if a else 0
                except Exception:
                    print(a)
                    print(a[0])
                    print(a[0].items)
                    print(a[0].items[0])
                    exit()
                at = a[0].TickIndex//30 + self.start_ts
                an = int(a[0].Name[-1])
                a = [(at,al,an)]
            else:
                a = [(self.__iter_ts, self.materials.index(self.plan['materials'][0]), 1)]
                a = None
        else:
            if self.end_ts - self.__iter_ts < 10:
                a = None
            else:
                a = None
        if a is not None:
            if a[0][0] != self.__iter_ts:
                a = None
        q = [x for x in self.question_pairs if (x[0][1] == self.__iter_ts)]
        q = q[0] if q else None
        frame_idx = self.__iter_ts - self.start_ts
        if self.load_third_person:
            frames = self.third_pers_frames
        elif self.load_player1:
            frames = self.player1_pov_frames
        elif self.load_player2:
            frames = self.player2_pov_frames
        else:
            frames = np.array([0])
        if len(frames) == 1:
            f = np.zeros((self.img_h,self.img_w,3))
        elif frame_idx < frames.shape[0]:
            f = frames[frame_idx]
        else:
            f = np.zeros((self.img_h,self.img_w,3))
        if self.do_upperbound:
            # oracle upper bound: encode the ground-truth question/answer pair directly
            if q is not None:
                qnum = 0
                base_rep = np.concatenate([
                    onehot(q[0][2],2),
                    onehot(q[0][3],2),
                    onehot(q[0][4][qnum][0]+1,2),
                    onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
                    onehot(q[0][4][qnum][0]+1,2),
                    onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
                    onehot(['YES','MAYBE','NO'].index(q[1][0][qnum])+1,3),
                    onehot(['YES','MAYBE','NO'].index(q[1][1][qnum])+1,3)
                ])
                base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
                ToM6 = base_rep if self.ToM6 is not None else np.zeros(1024)
                qnum = 1
                base_rep = np.concatenate([
                    onehot(q[0][2],2),
                    onehot(q[0][3],2),
                    onehot(q[0][4][qnum][0]+1,2),
                    onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
                    onehot(q[0][4][qnum][0]+1,2),
                    onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
                    onehot(['YES','MAYBE','NO'].index(q[1][0][qnum])+1,3),
                    onehot(['YES','MAYBE','NO'].index(q[1][1][qnum])+1,3)
                ])
                base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
                ToM7 = base_rep if self.ToM7 is not None else np.zeros(1024)
                qnum = 2
                base_rep = np.concatenate([
                    onehot(q[0][2],2),
                    onehot(q[0][3],2),
                    onehot(q[0][4][qnum]+1,2),
                    onehot(q[0][4][qnum]+1,2),
                    onehot(self.materials_dict[q[1][0][qnum]] if q[1][0][qnum] in self.materials_dict else len(self.materials_dict)+1,len(self.materials_dict)+1),
                    onehot(self.materials_dict[q[1][1][qnum]] if q[1][1][qnum] in self.materials_dict else len(self.materials_dict)+1,len(self.materials_dict)+1)
                ])
                base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
                ToM8 = base_rep if self.ToM8 is not None else np.zeros(1024)
            else:
                ToM6 = np.zeros(1024)
                ToM7 = np.zeros(1024)
                ToM8 = np.zeros(1024)
            if l is not None:
                base_rep = np.concatenate([
                    onehot(l[0][1],2),
                    onehot(l[0][2],len(self.dialogue_act_labels_dict))
                ])
                base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
                DAct = base_rep if self.DAct is not None else np.zeros(1024)
            else:
                DAct = np.zeros(1024)
            if m is not None:
                base_rep = np.concatenate([
                    onehot(m[0][1],2),
                    onehot(m[0][2][0],len(self.dialogue_move_labels_dict)),
                    onehot(m[0][2][1],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
                    onehot(m[0][2][2],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
                    onehot(m[0][2][3],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
                ])
                base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
                DMove = base_rep if self.DMove is not None else np.zeros(1024)
            else:
                DMove = np.zeros(1024)
        else:
            ToM6 = self.ToM6[frame_idx] if self.ToM6 is not None else np.zeros(1024)
            ToM7 = self.ToM7[frame_idx] if self.ToM7 is not None else np.zeros(1024)
            ToM8 = self.ToM8[frame_idx] if self.ToM8 is not None else np.zeros(1024)
            DAct = self.DAct[frame_idx] if self.DAct is not None else np.zeros(1024)
            DMove = self.DMove[frame_idx] if self.DMove is not None else np.zeros(1024)
        intermediate = np.concatenate([ToM6,ToM7,ToM8,DAct,DMove])
        if self.load_int0:
            intermediate = np.zeros(1024*5)
            intermediate[:1024] = self.int0_exp2_feats[frame_idx]
        retval = ((self.__iter_ts,self.pov),d,l,q,f,a,intermediate,m)
        self.__iter_ts += 1
        return retval
    self.__iter_ts = self.start_ts
    raise StopIteration()

    def __iter__(self):
        return self

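    # Usage sketch (the log path is hypothetical): iterating a parser yields one
    # tuple per in-game second, and StopIteration resets the cursor, so the same
    # game can be re-iterated:
    #
    #     game = GameParser('data/saved_logs/game_0', load_dialogue=True, pov=1)
    #     for (ts, pov), d, l, q, f, a, intermediate, m in game:
    #         ...  # d: chat embeddings, l: dialogue-act labels, q: question pair,
    #              # f: video frame, a: action, intermediate: ToM features, m: move
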
    def __load_videos(self):
        d = self.end_ts - self.start_ts
        if self.load_third_person:
            try:
                self.third_pers_file = glob(os.path.join(self.game_path,'third*gif'))[0]
                np_file = self.third_pers_file[:-3]+'npz'
                if os.path.isfile(np_file):
                    self.third_pers_frames = np.load(np_file)['data']
                else:
                    frames = imageio.get_reader(self.third_pers_file, '.gif')
                    reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
                    if 'main' in self.game_path:
                        self.third_pers_frames = np.array([reshaper(f[95:4*95,250:-249,2::-1]) for f in frames])
                    else:
                        self.third_pers_frames = np.array([reshaper(f[-3*95:,250:-249,2::-1]) for f in frames])
                    print(np_file,end=' ')
                    np.savez_compressed(open(np_file,'wb'), data=self.third_pers_frames)
                    print('saved')
            except Exception:
                self.third_pers_frames = np.array([0])
            # bucket the observed frame rate into 6/12/30/60 fps, then keep roughly one frame per second
            if self.third_pers_frames.shape[0]//d < 10:
                self.third_pov_frame_rate = 6
            elif self.third_pers_frames.shape[0]//d < 20:
                self.third_pov_frame_rate = 12
            elif self.third_pers_frames.shape[0]//d < 45:
                self.third_pov_frame_rate = 30
            else:
                self.third_pov_frame_rate = 60
            self.third_pers_frames = self.third_pers_frames[::self.third_pov_frame_rate]
        else:
            self.third_pers_frames = np.array([0])
        if self.load_player1:
            try:
                # flipped games swap which recording belongs to which player
                search_str = 'play2*gif' if self.flip_video else 'play1*gif'
                self.player1_pov_file = glob(os.path.join(self.game_path,search_str))[0]
                np_file = self.player1_pov_file[:-3]+'npz'
                if os.path.isfile(np_file):
                    self.player1_pov_frames = np.load(np_file)['data']
                else:
                    frames = imageio.get_reader(self.player1_pov_file, '.gif')
                    reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
                    self.player1_pov_frames = np.array([reshaper(f[:,:,2::-1]) for f in frames])
                    print(np_file,end=' ')
                    np.savez_compressed(open(np_file,'wb'), data=self.player1_pov_frames)
                    print('saved')
            except Exception:
                self.player1_pov_frames = np.array([0])
            if self.player1_pov_frames.shape[0]//d < 10:
                self.player1_pov_frame_rate = 6
            elif self.player1_pov_frames.shape[0]//d < 20:
                self.player1_pov_frame_rate = 12
            elif self.player1_pov_frames.shape[0]//d < 45:
                self.player1_pov_frame_rate = 30
            else:
                self.player1_pov_frame_rate = 60
            self.player1_pov_frames = self.player1_pov_frames[::self.player1_pov_frame_rate]
        else:
            self.player1_pov_frames = np.array([0])
        if self.load_player2:
            try:
                search_str = 'play1*gif' if self.flip_video else 'play2*gif'
                self.player2_pov_file = glob(os.path.join(self.game_path,search_str))[0]
                np_file = self.player2_pov_file[:-3]+'npz'
                if os.path.isfile(np_file):
                    self.player2_pov_frames = np.load(np_file)['data']
                else:
                    frames = imageio.get_reader(self.player2_pov_file, '.gif')
                    reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
                    self.player2_pov_frames = np.array([reshaper(f[:,:,2::-1]) for f in frames])
                    print(np_file,end=' ')
                    np.savez_compressed(open(np_file,'wb'), data=self.player2_pov_frames)
                    print('saved')
            except Exception:
                self.player2_pov_frames = np.array([0])
            if self.player2_pov_frames.shape[0]//d < 10:
                self.player2_pov_frame_rate = 6
            elif self.player2_pov_frames.shape[0]//d < 20:
                self.player2_pov_frame_rate = 12
            elif self.player2_pov_frames.shape[0]//d < 45:
                self.player2_pov_frame_rate = 30
            else:
                self.player2_pov_frame_rate = 60
            self.player2_pov_frames = self.player2_pov_frames[::self.player2_pov_frame_rate]
        else:
            self.player2_pov_frames = np.array([0])

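    # Worked example of the frame-rate buckets above (hypothetical numbers):
    # a 600 s game whose GIF holds 7200 frames shows 7200//600 = 12 fps, which
    # falls in the 10..19 bucket, so every 12th frame is kept -- about one frame
    # per game second, matching the per-second iterator in __next__.
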
    def __parse_question_pairs(self):
        question_dict = {}
        for q in self.questions:
            k = q[2][0][1] + q[2][1][1]
            if not k in question_dict:
                question_dict[k] = []
            question_dict[k].append(q)
        self.question_pairs = []
        for k,v in question_dict.items():
            if len(v) == 2:
                # a valid pair has one answer from each player (the player ids 1 and 2 sum to 3)
                if v[0][1]+v[1][1] == 3:
                    self.question_pairs.append(v)
            else:
                while len(v) > 1:
                    pair = []
                    pair.append(v.pop(0))
                    pair.append(v.pop(0))
                    while not pair[0][1]+pair[1][1] == 3:
                        if not v:
                            break
                        # print(game_path,pair)
                        pair.append(v.pop(0))
                        pair.pop(0)
                    if not v:
                        break
                    self.question_pairs.append(pair)
        self.question_pairs = sorted(self.question_pairs, key=lambda x: x[0][0])
        if self.load_player2 or self.pov==4:
            self.question_pairs = [sorted(q, key=lambda x: x[1],reverse=True) for q in self.question_pairs]
        else:
            self.question_pairs = [sorted(q, key=lambda x: x[1]) for q in self.question_pairs]
        self.question_pairs = [((a[0], b[0], a[1], b[1], a[2], b[2]), (a[3], b[3])) for a,b in self.question_pairs]

    def __parse_dialogue(self):
        self.dialogue_events = []
        save_path = os.path.join(self.game_path,f'dialogue_{self.game_path.split("/")[-1]}.pkl')
        if os.path.isfile(save_path):
            self.dialogue_events = pickle.load(open(save_path, "rb"))
            return
        for x in open(self.dialogue_file):
            if '[Async Chat Thread' in x:
                ts = list(map(int,x.split(' [')[0].strip('[]').split(':')))
                ts = 3600*ts[0] + 60*ts[1] + ts[2]
                player, event = x.strip().split('/INFO]: []<sledmcc')[1].split('> ',1)
                event = event.lower()
                # put spaces around every non-letter character, then collapse doubled spaces
                event = ''.join([x if x in string.ascii_lowercase else f' {x} ' for x in event]).strip()
                event = event.replace('  ',' ').replace('  ',' ')
                player = int(player)
                if GameParser.tokenizer is None:
                    GameParser.tokenizer = BertTokenizer.from_pretrained('bert-large-uncased', do_lower_case=True)
                if GameParser.model is None:
                    GameParser.model = BertModel.from_pretrained('bert-large-uncased', output_hidden_states=True).to(DEVICE)
                encoded_dict = GameParser.tokenizer.encode_plus(
                    event,                    # sentence to encode
                    add_special_tokens=True,  # add '[CLS]' and '[SEP]'
                    return_tensors='pt',      # return pytorch tensors
                )
                token_ids = encoded_dict['input_ids'].to(DEVICE)
                segment_ids = torch.ones(token_ids.size()).long().to(DEVICE)
                GameParser.model.eval()
                with torch.no_grad():
                    outputs = GameParser.model(input_ids=token_ids, token_type_ids=segment_ids)
                    outputs = outputs[1][0].cpu().data.numpy()
                self.dialogue_events.append((ts,player,event,outputs))
        pickle.dump(self.dialogue_events, open(save_path, "wb"))
        print(f'Saved to {save_path}',flush=True)

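    # Each utterance is embedded once with bert-large-uncased's pooled [CLS]
    # vector (1024-d) and cached next to the log. Reading the cache back
    # (sketch; the path is hypothetical):
    #
    #     events = pickle.load(open('data/saved_logs/game_0/dialogue_game_0.pkl', 'rb'))
    #     ts, player, text, emb = events[0]   # emb.shape == (1024,)
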
    def __parse_questions(self):
        self.questions = []
        for x in open(self.questions_file):
            if x[0] == '#':
                ts, qs = x.strip().split(' Number of records inserted: 1 # player')
                ts = list(map(int,ts.split(' ')[5].split(':')))
                ts = 3600*ts[0] + 60*ts[1] + ts[2]
                player = int(qs[0])
                questions = qs[2:].split(';')
                answers = [x[7:] for x in questions[3:]]
                questions = [x[9:].split(' ') for x in questions[:3]]
                questions[0] = (int(questions[0][0] == 'Have'), questions[0][-3])
                questions[1] = (int(questions[1][2] == 'know'), questions[1][-1])
                questions[2] = int(questions[2][1] == 'are')

                self.questions.append((ts,player,questions,answers))

    def __parse_start_end(self):
        self.start_ts = [x.strip() for x in open(self.dialogue_file) if 'THEY ARE PLAYER' in x][1]
        self.start_ts = list(map(int,self.start_ts.split('] [')[0][1:].split(':')))
        self.start_ts = 3600*self.start_ts[0] + 60*self.start_ts[1] + self.start_ts[2]
        try:
            self.start_ts = max(self.start_ts, self.questions[0][0]-75)
        except Exception:
            pass
        self.end_ts = [x.strip() for x in open(self.dialogue_file) if 'Stopping' in x]
        if self.end_ts:
            self.end_ts = self.end_ts[0]
            self.end_ts = list(map(int,self.end_ts.split('] [')[0][1:].split(':')))
            self.end_ts = 3600*self.end_ts[0] + 60*self.end_ts[1] + self.end_ts[2]
        else:
            self.end_ts = self.dialogue_events[-1][0]
        try:
            self.end_ts = max(self.end_ts, self.questions[-1][0]) + 1
        except Exception:
            pass

    def __load_dialogue_act_labels(self):
        file_name = 'config/dialogue_act_labels.json'
        if not os.path.isfile(file_name):
            files = sorted(glob('/home/*/MCC/*done.txt'))
            dialogue_act_dict = {}
            for file in files:
                game_str = ''
                for line in open(file):
                    line = line.strip()
                    if '_logs/' in line:
                        game_str = line
                    else:
                        if line:
                            line = line.split()
                            key = f'{game_str}#{line[0]}'
                            dialogue_act_dict[key] = line[-1]
            json.dump(dialogue_act_dict,open(file_name,'w'), indent=4)
        self.dialogue_act_dict = json.load(open(file_name))
        self.dialogue_act_labels_dict = {l : i for i, l in enumerate(sorted(list(set(self.dialogue_act_dict.values()))))}
        self.dialogue_act_bias = {l : sum([int(x==l) for x in self.dialogue_act_dict.values()]) for l in self.dialogue_act_labels_dict.keys()}
        json.dump(self.dialogue_act_labels_dict,open('config/dialogue_act_label_names.json','w'), indent=4)

    def __assign_dialogue_act_labels(self):
        log_file = glob('/'.join([self.game_path,'mcc*log']))[0].split('mindcraft/')[1]
        self.dialogue_act_labels = []
        for emb in self.dialogue_events:
            ts = emb[0]
            h = ts//3600
            m = (ts%3600)//60
            s = ts%60
            # annotation keys look like '<log file>#[HH:MM:SS]:<player>>'
            key = f'{log_file}#[{h:02d}:{m:02d}:{s:02d}]:{emb[1]}>'
            self.dialogue_act_labels.append((emb[0],emb[1],self.dialogue_act_labels_dict[self.dialogue_act_dict[key]]))

    def __load_dialogue_move_labels(self):
        file_name = "config/dialogue_move_labels.json"
        dialogue_move_dict = {}
        if not os.path.isfile(file_name):
            file_text = ''
            dialogue_moves = set()
            for line in open("XXX"):
                line = line.strip()
                if not line:
                    continue
                if line[0] == '#':
                    continue
                if line[0] == '[':
                    tag_text = glob(f'data/*/*/mcc_{file_text}.log')[0].split('/',1)[-1]
                    key = f'{tag_text}#{line.split()[0]}'
                    value = line.split()[-1].split('#')
                    if len(value) < 4:
                        value += ['IGNORE']*(4-len(value))
                    dialogue_moves.add(value[0])
                    value = '#'.join(value)
                    dialogue_move_dict[key] = value
                else:
                    file_text = line
            dialogue_moves = sorted(list(dialogue_moves))

            json.dump(dialogue_move_dict,open(file_name,'w'), indent=4)
        self.dialogue_move_dict = json.load(open(file_name))
        self.dialogue_move_labels_dict = {l : i for i, l in enumerate(sorted(list(set([lbl.split('#')[0] for lbl in self.dialogue_move_dict.values()]))))}
        self.dialogue_move_bias = {l : sum([int(x==l) for x in self.dialogue_move_dict.values()]) for l in self.dialogue_move_labels_dict.keys()}
        json.dump(self.dialogue_move_labels_dict,open('config/dialogue_move_label_names.json','w'), indent=4)

    def __assign_dialogue_move_labels(self):
        log_file = glob('/'.join([self.game_path,'mcc*log']))[0].split('mindcraft/')[1]
        self.dialogue_move_labels = []
        for emb in self.dialogue_events:
            ts = emb[0]
            h = ts//3600
            m = (ts%3600)//60
            s = ts%60
            key = f'{log_file}#[{h:02d}:{m:02d}:{s:02d}]:{emb[1]}>'
            move = self.dialogue_move_dict[key].split('#')
            move[0] = self.dialogue_move_labels_dict[move[0]]
            # move arguments are indexed into one shared vocabulary: materials, then mines, then tools
            for i,m in enumerate(move[1:]):
                if m == 'IGNORE':
                    move[i+1] = 0
                elif m in self.materials_dict:
                    move[i+1] = self.materials_dict[m]
                elif m in self.mines_dict:
                    move[i+1] = self.mines_dict[m] + len(self.materials_dict)
                elif m in self.tools_dict:
                    move[i+1] = self.tools_dict[m] + len(self.materials_dict) + len(self.mines_dict)
                else:
                    print(move)
                    exit()
            self.dialogue_move_labels.append((emb[0],emb[1],move))

    def __load_replay_data(self):
        self.actions = None

    def __load_intermediate(self):
        # 'intermediate' is a bit mask: bit 0 -> ToM6, bit 1 -> ToM7, bit 2 -> ToM8,
        # bit 3 -> DAct; values above 15 switch on the oracle upper bound instead
        if self.intermediate > 15:
            self.do_upperbound = True
        else:
            self.do_upperbound = False
        if self.pov in [1,2]:
            self.ToM6 = np.load(glob(f'{self.game_path}/intermediate_ToM6*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
            self.intermediate = self.intermediate // 2
            self.ToM7 = np.load(glob(f'{self.game_path}/intermediate_ToM7*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
            self.intermediate = self.intermediate // 2
            self.ToM8 = np.load(glob(f'{self.game_path}/intermediate_ToM8*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
            self.intermediate = self.intermediate // 2
            self.DAct = np.load(glob(f'{self.game_path}/intermediate_DAct*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
            self.intermediate = self.intermediate // 2
            self.DMove = None
        else:
            self.ToM6 = None
            self.ToM7 = None
            self.ToM8 = None
            self.DAct = None
            self.DMove = None

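    # Decoding sketch for the bit mask above: intermediate=5 == 0b0101 loads
    # ToM6 (bit 0) and ToM8 (bit 2) while leaving ToM7 and DAct disabled:
    #
    #     flags = 5
    #     use_tom6, flags = bool(flags % 2), flags // 2   # True
    #     use_tom7, flags = bool(flags % 2), flags // 2   # False
    #     use_tom8, flags = bool(flags % 2), flags // 2   # True
    #     use_dact = bool(flags % 2)                      # False
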
    def __load_int0_feats(self):
        self.int0_exp2_feats = np.load(glob(f'{self.game_path}/int0_exp2*player{self.pov}.npz')[0])['data']
        # self.int0_exp3_feats = np.load(glob(f'{self.game_path}/int0_exp3*player{self.pov}.npz')[0])['data']
BIN
src/models/.DS_Store
vendored
Normal file
Binary file not shown.
0
src/models/__init__.py
Normal file
207
src/models/losses.py
Normal file
@@ -0,0 +1,207 @@
import torch
from torch import nn
import numpy as np
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss

def onehot(x,n):
    retval = np.zeros(n)
    if x > 0:
        retval[x-1] = 1
    return retval

class PlanLoss(nn.Module):
    def __init__(self):
        super(PlanLoss, self).__init__()
    def getWeights(self, output, target):
        # return 1
        f1 = (1+5*torch.stack([2-torch.sum(target.reshape(-1,21,21),dim=-1)]*21,dim=-1)).reshape(-1,21*21)
        f2 = 100*target + 1
        return (f1+f2)/60
        exit(0)
        # print(max(torch.sum(target.reshape(21,21),dim=-1)))
        return (target*torch.sum(target,dim=-1) + 1)
    def MSELoss(self, output, target):
        retval = (output - target)**2
        retval *= self.getWeights(output,target)
        return torch.mean(retval)
    def BCELoss(self, output, target, loss_mask=None):
        mask_factor = torch.ones(target.shape).to(output.device)
        if loss_mask is not None:
            loss_mask = loss_mask.reshape(-1,21,21)
            mask_factor = mask_factor.reshape(-1,21,21)
            # print(mask_factor.shape,loss_mask.shape,output.shape,target.shape)
            # zero out rows of the 21x21 plan matrix that carry no target edges
            for idx, tgt in enumerate(loss_mask):
                for jdx, tgt_node in enumerate(tgt):
                    if sum(tgt_node) == 0:
                        mask_factor[idx,jdx] *= 0

            # print(loss_mask[0].data.cpu().numpy())
            # print(mask_factor[0].data.cpu().numpy())
            # print()
            # print(loss_mask[45].data.cpu().numpy())
            # print(mask_factor[45].data.cpu().numpy())
            # print()
            # print(loss_mask[-1].data.cpu().numpy())
            # print(mask_factor[-1].data.cpu().numpy())
            # print()

            loss_mask = loss_mask.reshape(-1,21*21)
            mask_factor = mask_factor.reshape(-1,21*21)
            # print(loss_mask.shape, target.shape, mask_factor.shape)
            # exit()

        factor = (10 if target.shape[-1]==441 else 1)  # * torch.sum(target,dim=-1)+1
        retval = -1 * mask_factor * (factor * target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
        factor = torch.stack([torch.sum(target,dim=-1)+1]*target.shape[-1],dim=-1)
        return torch.mean(factor*retval)
        return torch.mean(retval)
    def forward(self, output, target, loss_mask=None):
        # weighted BCE plus a small regularizer pushing outputs toward a uniform 1/21
        return self.BCELoss(output,target,loss_mask) + 0.01*torch.sum(output - 1/21)
        # return self.MSELoss(output,target)

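# Shape smoke test for PlanLoss (hypothetical values, purely illustrative):
#
#     loss_fn = PlanLoss()
#     output = torch.rand(4, 21*21)   # flattened 21x21 plan adjacency predictions
#     target = torch.zeros(4, 21*21); target[:, 0] = 1
#     loss_fn(output, target)         # scalar tensor
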
class DialogueActLoss(nn.Module):
    def __init__(self):
        super(DialogueActLoss, self).__init__()
        # per-class counts of the 19 dialogue acts, turned into inverse-frequency weights
        self.bias = torch.tensor([289,51,45,57,14,12,1,113,6,264,27,63,22,66,2,761,129,163,5]).float()
        self.bias = max(self.bias) - self.bias + 1
        self.bias /= torch.sum(self.bias)
        self.bias = 1-self.bias
        # self.bias *= self.bias
    def BCELoss(self, output, target):
        target = torch.stack([torch.tensor(onehot(x + 1,19)).long() for x in target]).to(output.device)
        retval = -1 * (target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
        retval *= torch.stack([self.bias] * output.shape[0]).to(output.device)
        # print(output)
        # print(target)
        # print(retval)
        # print(torch.mean(retval))
        # exit()
        return torch.mean(retval)
    def forward(self, output, target):
        return self.BCELoss(output,target)
        # return self.MSELoss(output,target)

class DialogueMoveLoss(nn.Module):
    def __init__(self, device):
        super(DialogueMoveLoss, self).__init__()
        # self.bias = torch.tensor([289,51,45,57,14,12,1,113,6,264,27,63,22,66,2,761,129,163,5]).float()
        # self.bias = max(self.bias) - self.bias + 1
        # self.bias /= torch.sum(self.bias)
        # self.bias = 1-self.bias
        # inverse-frequency weights over the 35 dialogue-move types
        move_weights = torch.tensor(np.array([202, 34, 34, 48, 4, 2, 420, 10, 54, 1, 10, 11, 30, 28, 14, 2, 16, 6, 2, 86, 4, 12, 28, 2, 2, 16, 12, 14, 4, 1, 12, 258, 12, 26, 2])).float().to(device)
        move_weights = 1 + max(move_weights) - move_weights
        self.loss1 = CrossEntropyLoss(weight=move_weights)
        zero_bias = 0.773
        num_classes = 40

        weight = torch.tensor(np.array([50 if not x else 1 for x in range(num_classes)])).float().to(device)
        weight = 1 + max(weight) - weight
        self.loss2 = CrossEntropyLoss(weight=weight)
        # self.bias *= self.bias
    def BCELoss(self, output, target, zero_bias):
        # # print(output.shape,target.shape)
        # bias = torch.tensor(np.array([1 if t else zero_bias for t in target])).to(output.device)
        # target = torch.stack([torch.tensor(onehot(x,output.shape[-1])).long() for x in target]).to(output.device)
        # # print(target.shape, bias.shape, bias)

        # retval = -1 * (target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
        # retval = torch.mean(retval,-1)

        # # print(retval.shape)
        # retval *= bias
        # # retval *= torch.stack([self.bias] * output.shape[0]).to(output.device)
        # # print(output)
        # # print(target)
        # # print(retval)
        # # print(torch.mean(retval))
        # # exit()
        # # retval = self.loss(output,target)
        # return torch.mean(retval) # retval #
        # weight = [zero_bias if x else (1-zero_bias)/(output.shape[-1]-1) for x in range(output.shape[-1])]
        retval = self.loss2(output,target) if zero_bias else self.loss1(output,target)
        return retval
    def forward(self, output, target):
        o1, o2, o3, o4 = output
        t1, t2, t3, t4 = target

        # print(t2,t2.shape, o2.shape)

        # if sum(t2):
        #     o2, t2 = zip(*[(a,b) for a,b in zip(o2,t2) if b])
        #     o2 = torch.stack(o2)
        #     t2 = torch.stack(t2)
        # if sum(t3):
        #     o3, t3 = zip(*[(a,b) for a,b in zip(o3,t3) if b])
        #     o3 = torch.stack(o3)
        #     t3 = torch.stack(t3)
        # if sum(t4):
        #     o4, t4 = zip(*[(a,b) for a,b in zip(o4,t4) if b])
        #     o4 = torch.stack(o4)
        #     t4 = torch.stack(t4)

        # print(t2,t2.shape, o2.shape)
        # exit()

        # only the move-type head currently contributes; the three argument heads are weighted 0
        retval = sum([
            1*self.BCELoss(output[0],target[0],0),
            0*self.BCELoss(output[1],target[1],1),
            0*self.BCELoss(output[2],target[2],1),
            0*self.BCELoss(output[3],target[3],1)
        ])
        return retval  # sum([fact*self.BCELoss(o,t,zbias) for fact,zbias,o,t in zip([1,0,0,0],[0,1,1,1],output,target)])
        # return self.MSELoss(output,target)

class DialoguePredLoss(nn.Module):
    def __init__(self):
        super(DialoguePredLoss, self).__init__()
        self.bias = torch.tensor([289,51,45,57,14,12,1,113,6,264,27,63,22,66,2,761,129,163,5,0]).float()
        self.bias[-1] = 1460  # 2 * torch.sum(self.bias) // 3
        self.bias = max(self.bias) - self.bias + 1
        self.bias /= torch.sum(self.bias)
        self.bias = 1-self.bias
        # self.bias *= self.bias
    def BCELoss(self, output, target):
        target = torch.stack([torch.tensor(onehot(x + 1,20)).long() for x in target]).to(output.device)
        retval = -1 * (target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
        retval *= torch.stack([self.bias] * output.shape[0]).to(output.device)
        # print(output)
        # print(target)
        # print(retval)
        # print(torch.mean(retval))
        # exit()
        return torch.mean(retval)
    def forward(self, output, target):
        return self.BCELoss(output,target)
        # return self.MSELoss(output,target)

class ActionLoss(nn.Module):
    def __init__(self):
        super(ActionLoss, self).__init__()
        self.bias1 = torch.tensor([134,1370,154,128,220,166,46,76,106,78,88,124,102,120,276,122,112,106,44,174,20]).float()
        # self.bias[-1] = 1460  # 2 * torch.sum(self.bias) // 3
        # self.bias1 = torch.ones(21).float()
        self.bias1 = max(self.bias1) - self.bias1 + 1
        self.bias1 /= torch.sum(self.bias1)
        self.bias1 = 1-self.bias1
        self.bias2 = torch.tensor([1168,1310]).float()
        # self.bias2[-1] = 1460  # 2 * torch.sum(self.bias) // 3
        # self.bias2 = torch.ones(21).float()
        self.bias2 = max(self.bias2) - self.bias2 + 1
        self.bias2 /= torch.sum(self.bias2)
        self.bias2 = 1-self.bias2
        # self.bias *= self.bias
    def BCELoss(self, output, target):
        # target = torch.stack([torch.tensor(onehot(x + 1,20)).long() for x in target]).to(output.device)
        retval = -1 * (target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
        # print(self.bias1.shape,self.bias2.shape,output.shape[-1])
        # pick the binary bias for 2-way outputs, the 21-way material bias otherwise
        retval *= torch.stack([self.bias2 if output.shape[-1]==2 else self.bias1] * output.shape[0]).to(output.device)
        # print(output)
        # print(target)
        # print(retval)
        # print(torch.mean(retval))
        # exit()
        return torch.mean(retval)
    def forward(self, output, target):
        return self.BCELoss(output,target)
        # return self.MSELoss(output,target)
205
src/models/model_with_dialogue_moves.py
Executable file
@@ -0,0 +1,205 @@
import sys, torch, random
from numpy.core.fromnumeric import reshape
import torch.nn as nn, numpy as np
from src.data.game_parser import DEVICE

def onehot(x,n):
    retval = np.zeros(n)
    if x > 0:
        retval[x-1] = 1
    return retval

class Model(nn.Module):
    def __init__(self, seq_model_type=0, device=DEVICE):
        super(Model, self).__init__()
        self.device = device
        print("model set to device", self.device)

        my_rnn = lambda i,o: nn.GRU(i,o)
        # my_rnn = lambda i,o: nn.LSTM(i,o)

        plan_emb_in = 81
        plan_emb_out = 32
        q_emb = 100

        self.plan_embedder0 = my_rnn(plan_emb_in,plan_emb_out)
        self.plan_embedder1 = my_rnn(plan_emb_in,plan_emb_out)
        self.plan_embedder2 = my_rnn(plan_emb_in,plan_emb_out)

        # self.dialogue_listener = my_rnn(1126,768)
        dlist_hidden = 1024
        frame_emb = 512
        self.move_emb = 157
        drnn_in = 1024 + 2 + q_emb + frame_emb + self.move_emb
        # drnn_in = 1024 + 2

        # my_rnn = lambda i,o: nn.GRU(i,o)
        my_rnn = lambda i,o: nn.LSTM(i,o)

        if seq_model_type==0:
            self.dialogue_listener_rnn = nn.GRU(drnn_in,dlist_hidden)
            self.dialogue_listener = lambda x: \
                self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
        elif seq_model_type==1:
            self.dialogue_listener_rnn = nn.LSTM(drnn_in,dlist_hidden)
            self.dialogue_listener = lambda x: \
                self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
        elif seq_model_type==2:
            # causal self-attention: mask strictly-future positions and prepend
            # a two-channel sin/cos position encoding
            mask_fun = lambda x: torch.triu(torch.ones(x.shape[0],x.shape[0]),diagonal=1).bool().to(self.device)
            sincos_fun = lambda x: torch.transpose(torch.stack([
                torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
                torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
            ]),0,1).reshape(-1,1,2)
            self.dialogue_listener_lin1 = nn.Linear(drnn_in,dlist_hidden-2)
            self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
            self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x,x,x,attn_mask=mask_fun(x))
            self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
                sincos_fun(x.shape[0]).float().to(self.device),
                self.dialogue_listener_lin1(x).reshape(-1,1,dlist_hidden-2)
            ], axis=-1))[0]
        else:
            print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
            exit()

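        # For a length-4 sequence the attn_mask built above is (True = blocked):
        #   [[False,  True,  True,  True],
        #    [False, False,  True,  True],
        #    [False, False, False,  True],
        #    [False, False, False, False]]
        # i.e. timestep t may only attend to timesteps <= t (causal attention).
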
        conv_block = lambda i,o,k,p,s: nn.Sequential(
            nn.Conv2d( i, o, k, padding=p, stride=s),
            nn.BatchNorm2d(o),
            nn.Dropout(0.5),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(o),
            nn.Dropout(0.5),
            nn.ReLU()
        )

        self.conv = nn.Sequential(
            conv_block( 3, 8, 3, 1, 1),
            # conv_block( 3, 8, 5, 2, 2),
            conv_block( 8, 32, 5, 2, 2),
            conv_block( 32, frame_emb//4, 5, 2, 2),
            nn.Conv2d( frame_emb//4, frame_emb, 3), nn.ReLU(),
        )

        qlayer = lambda i,o: nn.Sequential(
            nn.Linear(i,512),
            nn.Dropout(0.5),
            nn.GELU(),
            nn.Dropout(0.5),
            nn.Linear(512,o),
            nn.Dropout(0.5),
            # nn.Softmax(-1)
        )

        q_in_size = 3*plan_emb_out+dlist_hidden+q_emb

        # one head per question slot: binary, 3-way, and 22-way outputs
        self.q01 = qlayer(q_in_size,2)
        self.q02 = qlayer(q_in_size,2)
        self.q03 = qlayer(q_in_size,2)

        self.q11 = qlayer(q_in_size,3)
        self.q12 = qlayer(q_in_size,3)
        self.q13 = qlayer(q_in_size,22)

        self.q21 = qlayer(q_in_size,3)
        self.q22 = qlayer(q_in_size,3)
        self.q23 = qlayer(q_in_size,22)

    def forward(self, game, global_plan=False, player_plan=False, intermediate=False):
        retval = []

        l = list(game)
        _,d,_,q,f,_,_,m = zip(*list(game))

        parse_move = lambda m: np.concatenate([
            onehot(m[0][1], 2),
            onehot(m[0][2][0]+1, len(game.dialogue_move_labels_dict)),
            onehot(m[0][2][1]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1),
            onehot(m[0][2][2]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1),
            onehot(m[0][2][3]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1)
        ])
        # print(2+len(game.dialogue_move_labels_dict)+3*(len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1))
        # print(len(game.dialogue_move_labels_dict))
        m = np.stack([np.zeros(self.move_emb) if move is None else parse_move(move) for move in m])

        h = None
        f = np.array(f, dtype=np.uint8)
        # f = torch.tensor(f).permute(0,3,1,2).float().to(self.device)
        # flt_lst = [(a,b) for a,b in zip(d,q) if (not a is None) or (not b is None)]
        # if not flt_lst:
        #     return []
        # d,q = zip(*flt_lst)
        d = np.stack([np.concatenate(([int(x[0][1]==2),int(x[0][1]==1)],x[0][-1])) if not x is None else np.zeros(1026) for x in d])
        def parse_q(q):
            if not q is None:
                q, l = q
                q = np.concatenate([
                    onehot(q[2],2),
                    onehot(q[3],2),
                    onehot(q[4][0][0]+1,2),
                    onehot(game.materials_dict[q[4][0][1]],len(game.materials_dict)),
                    onehot(q[4][1][0]+1,2),
                    onehot(game.materials_dict[q[4][1][1]],len(game.materials_dict)),
                    onehot(q[4][2]+1,2),
                    onehot(q[5][0][0]+1,2),
                    onehot(game.materials_dict[q[5][0][1]],len(game.materials_dict)),
                    onehot(q[5][1][0]+1,2),
                    onehot(game.materials_dict[q[5][1][1]],len(game.materials_dict)),
                    onehot(q[5][2]+1,2)
                ])
            else:
                q = np.zeros(100)
                l = None
            return q, l
        try:
            # sel1/sel2 select which player's plan embedding is active below
            sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
            sel2 = 1 - sel1
        except Exception:
            sel1 = 0
            sel2 = 0
        q = [parse_q(x) for x in q]
        q, l = zip(*q)
        q = np.stack(q)

        if not global_plan and not player_plan:
            plan_emb = torch.cat([
                self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
                self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
                self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
            ])
            plan_emb = 0*plan_emb
        elif global_plan:
            plan_emb = torch.cat([
                self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
                self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
                self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
            ])
        else:
            plan_emb = torch.cat([
                0*self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
                sel1*self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
                sel2*self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
            ])

        u = torch.cat((
            torch.tensor(d).float().to(self.device),
            torch.tensor(q).float().to(self.device),
            self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512),
            torch.tensor(m).float().to(self.device)
        ), axis=-1)
        u = u.float().to(self.device)

        y = self.dialogue_listener(u)
        y = y.reshape(-1,y.shape[-1])

        if intermediate:
            return y

        if all([x is None for x in l]):
            return []

        fun_lst = [self.q01,self.q02,self.q03,self.q11,self.q12,self.q13,self.q21,self.q22,self.q23]
        fun = lambda x: [f(x) for f in fun_lst]

        retval = [(_l,fun(torch.cat((plan_emb,torch.tensor(_q).float().to(self.device),_y)))) for _y, _q, _l in zip(y,q,l) if _l is not None]
        return retval
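    # Driving this forward pass (sketch; path and flags hypothetical):
    #
    #     model = Model(seq_model_type=2).to(DEVICE)
    #     game = GameParser('data/saved_logs/game_0', pov=1, use_dialogue_moves=True)
    #     for label, head_outputs in model(game, player_plan=True):
    #         ...  # head_outputs: nine tensors, one per question head q01..q23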
226
src/models/model_with_dialogue_moves_graphs.py
Normal file
@@ -0,0 +1,226 @@
import torch
import torch.nn as nn, numpy as np
from src.data.game_parser import DEVICE
from torch_geometric.nn import GATv2Conv, MeanAggregation


def onehot(x,n):
    retval = np.zeros(n)
    if x > 0:
        retval[x-1] = 1
    return retval


class PlanGraphEmbedder(nn.Module):
    def __init__(self, h_dim, dropout=0.0, heads=4):
        super().__init__()
        self.proj_x = nn.Sequential(
            nn.Linear(27, h_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.proj_edge_attr = nn.Sequential(
            nn.Linear(12, h_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.conv1 = GATv2Conv(h_dim, h_dim, heads=heads, edge_dim=h_dim)
        self.conv2 = GATv2Conv(h_dim*heads, h_dim, heads=1, edge_dim=h_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.pool = MeanAggregation()

    def forward(self, data):
        x, edge_index, edge_attr = data.features.to(DEVICE), data.edge_index.to(DEVICE), data.tool.to(DEVICE)
        x = self.proj_x(x)
        edge_attr = self.proj_edge_attr(edge_attr)
        x = self.conv1(x, edge_index, edge_attr)
        x = self.act(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index, edge_attr)
        x = self.pool(x)
        return x.squeeze(0)

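# Shape sketch for PlanGraphEmbedder: two GATv2 layers with projected edge
# features, mean-pooled into a single plan vector (toy graph, purely
# illustrative):
#
#     from torch_geometric.data import Data
#     data = Data(edge_index=torch.tensor([[0, 1], [1, 2]]))  # edges 0->1, 1->2
#     data.features = torch.rand(3, 27)   # 3 nodes, 27 raw node features
#     data.tool = torch.rand(2, 12)       # 2 edges, 12 raw edge features
#     PlanGraphEmbedder(h_dim=96).to(DEVICE)(data).shape      # torch.Size([96])
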
class Model(nn.Module):
    def __init__(self, seq_model_type=0, device=DEVICE):
        super(Model, self).__init__()
        self.device = device
        print("model set to device", self.device)

        plan_emb_out = 32*3
        q_emb = 100

        self.plan_embedder0 = PlanGraphEmbedder(plan_emb_out)
        self.plan_embedder1 = PlanGraphEmbedder(plan_emb_out)
        self.plan_embedder2 = PlanGraphEmbedder(plan_emb_out)

        dlist_hidden = 1024
        frame_emb = 512
        self.move_emb = 157
        drnn_in = 1024 + 2 + q_emb + frame_emb + self.move_emb

        if seq_model_type==0:
            self.dialogue_listener_rnn = nn.GRU(drnn_in,dlist_hidden)
            self.dialogue_listener = lambda x: \
                self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
        elif seq_model_type==1:
            self.dialogue_listener_rnn = nn.LSTM(drnn_in,dlist_hidden)
            self.dialogue_listener = lambda x: \
                self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
        elif seq_model_type==2:
            mask_fun = lambda x: torch.triu(torch.ones(x.shape[0],x.shape[0]),diagonal=1).bool().to(self.device)
            sincos_fun = lambda x: torch.transpose(torch.stack([
                torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
                torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
            ]),0,1).reshape(-1,1,2)
            self.dialogue_listener_lin1 = nn.Linear(drnn_in,dlist_hidden-2)
            self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
            self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x,x,x,attn_mask=mask_fun(x))
            self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
                sincos_fun(x.shape[0]).float().to(self.device),
                self.dialogue_listener_lin1(x).reshape(-1,1,dlist_hidden-2)
            ], axis=-1))[0]
        else:
            print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
            exit()

        conv_block = lambda i,o,k,p,s: nn.Sequential(
            nn.Conv2d( i, o, k, padding=p, stride=s),
            nn.BatchNorm2d(o),
            nn.Dropout(0.5),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(o),
            nn.Dropout(0.5),
            nn.ReLU()
        )

        self.conv = nn.Sequential(
            conv_block( 3, 8, 3, 1, 1),
            conv_block( 8, 32, 5, 2, 2),
            conv_block( 32, frame_emb//4, 5, 2, 2),
            nn.Conv2d( frame_emb//4, frame_emb, 3), nn.ReLU(),
        )

        qlayer = lambda i,o: nn.Sequential(
            nn.Linear(i,512),
            nn.Dropout(0.5),
            nn.GELU(),
            nn.Dropout(0.5),
            nn.Linear(512,o),
            nn.Dropout(0.5),
        )

        q_in_size = plan_emb_out+dlist_hidden+q_emb

        self.q01 = qlayer(q_in_size,2)
        self.q02 = qlayer(q_in_size,2)
        self.q03 = qlayer(q_in_size,2)

        self.q11 = qlayer(q_in_size,3)
        self.q12 = qlayer(q_in_size,3)
        self.q13 = qlayer(q_in_size,22)

        self.q21 = qlayer(q_in_size,3)
        self.q22 = qlayer(q_in_size,3)
        self.q23 = qlayer(q_in_size,22)

    def forward(self, game, global_plan=False, player_plan=False, intermediate=False):
        retval = []

        l = list(game)
        _,d,_,q,f,_,_,m = zip(*list(game))

        parse_move = lambda m: np.concatenate([
            onehot(m[0][1], 2),
            onehot(m[0][2][0]+1, len(game.dialogue_move_labels_dict)),
            onehot(m[0][2][1]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1),
            onehot(m[0][2][2]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1),
            onehot(m[0][2][3]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1)
        ])
        m = np.stack([np.zeros(self.move_emb) if move is None else parse_move(move) for move in m])

        f = np.array(f, dtype=np.uint8)
        d = np.stack([np.concatenate(([int(x[0][1]==2),int(x[0][1]==1)],x[0][-1])) if not x is None else np.zeros(1026) for x in d])
        def parse_q(q):
            if not q is None:
                q, l = q
                q = np.concatenate([
                    onehot(q[2],2),
                    onehot(q[3],2),
                    onehot(q[4][0][0]+1,2),
                    onehot(game.materials_dict[q[4][0][1]],len(game.materials_dict)),
                    onehot(q[4][1][0]+1,2),
                    onehot(game.materials_dict[q[4][1][1]],len(game.materials_dict)),
                    onehot(q[4][2]+1,2),
                    onehot(q[5][0][0]+1,2),
                    onehot(game.materials_dict[q[5][0][1]],len(game.materials_dict)),
                    onehot(q[5][1][0]+1,2),
                    onehot(game.materials_dict[q[5][1][1]],len(game.materials_dict)),
                    onehot(q[5][2]+1,2)
                ])
            else:
                q = np.zeros(100)
                l = None
            return q, l
        try:
            sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
            sel2 = 1 - sel1
        except Exception:
            sel1 = 0
            sel2 = 0
        q = [parse_q(x) for x in q]
        q, l = zip(*q)
        q = np.stack(q)

        if not global_plan and not player_plan:
            # plan_emb = torch.cat([
            #     self.plan_embedder0(game.global_plan),
            #     self.plan_embedder1(game.player1_plan),
            #     self.plan_embedder2(game.player2_plan)
            # ])
            plan_emb = self.plan_embedder0(game.global_plan)
            plan_emb = 0*plan_emb
        elif global_plan:
            # plan_emb = torch.cat([
            #     self.plan_embedder0(game.global_plan),
            #     self.plan_embedder1(game.player1_plan),
            #     self.plan_embedder2(game.player2_plan)
            # ])
            plan_emb = self.plan_embedder0(game.global_plan)
        else:
            # plan_emb = torch.cat([
            #     0*self.plan_embedder0(game.global_plan),
            #     sel1*self.plan_embedder1(game.player1_plan),
            #     sel2*self.plan_embedder2(game.player2_plan)
            # ])
            if sel1:
                plan_emb = self.plan_embedder1(game.player1_plan)
            elif sel2:
                plan_emb = self.plan_embedder2(game.player2_plan)
            else:
                plan_emb = self.plan_embedder0(game.global_plan)
                plan_emb = 0*plan_emb

        u = torch.cat((
            torch.tensor(d).float().to(self.device),
            torch.tensor(q).float().to(self.device),
            self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512),
            torch.tensor(m).float().to(self.device)
        ), axis=-1)
        u = u.float().to(self.device)

        y = self.dialogue_listener(u)
        y = y.reshape(-1,y.shape[-1])

        if intermediate:
            return y

        if all([x is None for x in l]):
            return []

        fun_lst = [self.q01,self.q02,self.q03,self.q11,self.q12,self.q13,self.q21,self.q22,self.q23]
        fun = lambda x: [f(x) for f in fun_lst]

        retval = [(_l,fun(torch.cat((plan_emb,torch.tensor(_q).float().to(self.device),_y)))) for _y, _q, _l in zip(y,q,l) if _l is not None]
        return retval
225
src/models/plan_model.py
Executable file
@@ -0,0 +1,225 @@
import sys, torch, random
from numpy.core.fromnumeric import reshape
import torch.nn as nn, numpy as np
from torch.nn import functional as F
from src.data.game_parser import DEVICE

def onehot(x,n):
    retval = np.zeros(n)
    if x > 0:
        retval[x-1] = 1
    return retval

class Model(nn.Module):
    def __init__(self, seq_model_type=0, device=DEVICE):
        super(Model, self).__init__()

        self.device = device

        my_rnn = lambda i,o: nn.GRU(i,o)
        # my_rnn = lambda i,o: nn.LSTM(i,o)

        plan_emb_in = 81
        plan_emb_out = 32
        q_emb = 100

        self.plan_embedder0 = my_rnn(plan_emb_in,plan_emb_out)
        self.plan_embedder1 = my_rnn(plan_emb_in,plan_emb_out)
        self.plan_embedder2 = my_rnn(plan_emb_in,plan_emb_out)

        # self.dialogue_listener = my_rnn(1126,768)
        dlist_hidden = 1024
        frame_emb = 512
        drnn_in = 5*1024 + 2 + frame_emb + 1024
        # drnn_in = 1024 + 2

        # my_rnn = lambda i,o: nn.GRU(i,o)
        my_rnn = lambda i,o: nn.LSTM(i,o)

        if seq_model_type==0:
            self.dialogue_listener_rnn = nn.GRU(drnn_in,dlist_hidden)
            self.dialogue_listener = lambda x: \
                self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
        elif seq_model_type==1:
            self.dialogue_listener_rnn = nn.LSTM(drnn_in,dlist_hidden)
            self.dialogue_listener = lambda x: \
                self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
        elif seq_model_type==2:
            mask_fun = lambda x: torch.triu(torch.ones(x.shape[0],x.shape[0]),diagonal=1).bool().to(device)
            sincos_fun = lambda x: torch.transpose(torch.stack([
                torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
                torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
            ]),0,1).reshape(-1,1,2)
            self.dialogue_listener_lin1 = nn.Linear(drnn_in,dlist_hidden-2)
            self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
            self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x,x,x,attn_mask=mask_fun(x))
            self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
                sincos_fun(x.shape[0]).float().to(self.device),
                self.dialogue_listener_lin1(x).reshape(-1,1,dlist_hidden-2)
            ], axis=-1))[0]
        else:
            print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
            exit()

        conv_block = lambda i,o,k,p,s: nn.Sequential(
            nn.Conv2d( i, o, k, padding=p, stride=s),
            nn.BatchNorm2d(o),
            nn.Dropout(0.5),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(o),
            nn.Dropout(0.5),
            nn.ReLU()
        )

        self.conv = nn.Sequential(
            conv_block( 3, 8, 3, 1, 1),
            # conv_block( 3, 8, 5, 2, 2),
            conv_block( 8, 32, 5, 2, 2),
            conv_block( 32, frame_emb//4, 5, 2, 2),
            nn.Conv2d( frame_emb//4, frame_emb, 3), nn.ReLU(),
        )

        plan_layer = lambda i,o: nn.Sequential(
            # nn.Linear(i,(i+o)//2),
            nn.Linear(i,(i+2*o)//3),
            nn.Dropout(0.5),
            nn.GELU(),
            nn.Dropout(0.5),
            # nn.Linear((i+o)//2,o),
            nn.Linear((i+2*o)//3,o),
            nn.GELU(),
            nn.Dropout(0.5),
            # nn.Sigmoid()
        )

        plan_mat_size = 21*21
        # successive redefinitions; only the last assignment is live
        q_in_size = 3*plan_emb_out+dlist_hidden
        q_in_size = 3*plan_emb_out+dlist_hidden+plan_mat_size
        q_in_size = dlist_hidden+plan_mat_size

        # self.plan_out = plan_layer(q_in_size,plan_mat_size)
        self.plan_out = plan_layer(q_in_size,plan_mat_size)
        # self.q01 = qlayer(q_in_size,2)
        # self.q02 = qlayer(q_in_size,2)
        # self.q03 = qlayer(q_in_size,2)

        # self.q11 = qlayer(q_in_size,3)
        # self.q12 = qlayer(q_in_size,3)
        # self.q13 = qlayer(q_in_size,22)

        # self.q21 = qlayer(q_in_size,3)
        # self.q22 = qlayer(q_in_size,3)
        # self.q23 = qlayer(q_in_size,22)

    def forward(self, game, global_plan=False, player_plan=False, evaluation=False, incremental=False):
        retval = []

        l = list(game)
        _,d,l,q,f,_,intermediate,_ = zip(*list(game))
        # print(np.array(intermediate).shape)
        # exit()

        h = None
        intermediate = np.array(intermediate)
        f = np.array(f, dtype=np.uint8)
        # f = torch.tensor(f).permute(0,3,1,2).float().to(self.device)
        # flt_lst = [(a,b) for a,b in zip(d,q) if (not a is None) or (not b is None)]
        # if not flt_lst:
        #     return []
        # d,q = zip(*flt_lst)
        d = np.stack([np.concatenate(([int(x[0][1]==2),int(x[0][1]==1)],x[0][-1])) if not x is None else np.zeros(1026) for x in d])
        # def parse_q(q):
        #     ... (question featurizer kept commented out in the original; it is
        #     the same parse_q defined in model_with_dialogue_moves.py above)
        try:
            sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
            sel2 = 1 - sel1
        except Exception:
            sel1 = 0
            sel2 = 0
        # q = [parse_q(x) for x in q]
        # q, l = zip(*q)

        if not global_plan and not player_plan:
            plan_emb = torch.cat([
                self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
                self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
                self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
            ])
            plan_emb = 0*plan_emb
        elif global_plan:
            plan_emb = torch.cat([
                self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
                self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
                self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
            ])
        else:
            plan_emb = torch.cat([
                0*self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
                sel1*self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
                sel2*self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
            ])

        # if sel1 == 0 and sel2 == 0:
        #     print(torch.unique(plan_emb))

        u = torch.cat((
            torch.tensor(d).float().to(self.device),
            # torch.tensor(q).float().to(self.device),
            self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512),
            torch.tensor(intermediate).float().to(self.device)
        ), axis=-1)
        u = u.float().to(self.device)
        # print(d.shape)
        # print(self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512).shape)
        # print(intermediate.shape)
        # print(u.shape)

        y = self.dialogue_listener(u)
        y = y.reshape(-1,y.shape[-1])
        # print(y[-1].shape,plan_emb.shape,torch.tensor(game.plan_repr).float().to(self.device).shape)
        # return self.plan_out(torch.cat((y[-1],plan_emb))), y
        # return self.plan_out(torch.cat((y[-1],plan_emb,torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device)))), y
        if incremental:
            # one prediction from the first step, then one every 10 steps
            prediction = torch.stack([
                self.plan_out(torch.cat((y[0],torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device))))] + [
                self.plan_out(torch.cat((f,torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device)))) for f in y[len(y)%10-1::10]
            ])
            prediction = F.softmax(prediction.reshape(-1,21,21),-1).reshape(-1,21*21)
        else:
            prediction = self.plan_out(torch.cat((y[-1],torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device))))
            prediction = F.softmax(prediction.reshape(21,21),-1).reshape(21*21)
        # prediction = F.softmax(prediction,-1)
        # prediction = F.softmax(prediction,-1)
        # exit()
        return prediction, y

        # exit()
        # if all([x is None for x in l]):
        #     return []

        # fun_lst = [self.q01,self.q02,self.q03,self.q11,self.q12,self.q13,self.q21,self.q22,self.q23]
        # fun = lambda x: [f(x) for f in fun_lst]

        # retval = [(_l,fun(torch.cat((plan_emb,torch.tensor(_q).float().to(self.device),_y)))) for _y, _q, _l in zip(y,q,l) if not _l is None]
        # return retval
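        # The plan head emits 21*21 logits read as a 21x21 matrix; softmax is
        # applied row-wise, so each row becomes a distribution over the 21
        # possible targets of that node:
        #
        #     probs = F.softmax(logits.reshape(21, 21), -1)
        #     probs.sum(-1)   # == torch.ones(21)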
230
src/models/plan_model_graphs.py
Normal file
@@ -0,0 +1,230 @@
import torch
import torch.nn as nn, numpy as np
from torch.nn import functional as F
from torch_geometric.nn import MeanAggregation, GATv2Conv


class PlanGraphEmbedder(nn.Module):
    def __init__(self, device, h_dim, dropout=0.0, heads=4):
        super().__init__()
        self.device = device
        self.proj_x = nn.Sequential(
            nn.Linear(27, h_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.proj_edge_attr = nn.Sequential(
            nn.Linear(12, h_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.conv1 = GATv2Conv(h_dim, h_dim, heads=heads, edge_dim=h_dim)
        self.conv2 = GATv2Conv(h_dim*heads, h_dim, heads=1, edge_dim=h_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.dec = nn.Linear(h_dim*3, 1)

    def encode(self, data):
        x, edge_index, edge_attr = data.features.to(self.device), data.edge_index.to(self.device), data.tool.to(self.device)
        x = self.proj_x(x)
        edge_attr = self.proj_edge_attr(edge_attr)
        x = self.conv1(x, edge_index, edge_attr)
        x = self.act(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index, edge_attr)
        return x, edge_attr

    def decode(self, z, context, edge_label_index):
        # score a candidate edge (u, v) from its node embeddings plus the dialogue context
        u = z[edge_label_index[0]]
        v = z[edge_label_index[1]]
        return self.dec(torch.cat((u, v, context), -1))

    # def decode(self, z, edge_index, edge_attr, edge_label_index):
    #     z = self.conv3(z, edge_index.to(self.device), edge_attr)
    #     return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

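    # decode() is a context-conditioned link predictor: each candidate edge in
    # edge_label_index gets a logit from [z_u ; z_v ; context]. Toy sketch
    # (numbers purely illustrative):
    #
    #     z = torch.rand(5, h_dim)                           # 5 plan-node embeddings
    #     edge_label_index = torch.tensor([[0, 2], [1, 3]])  # candidate edges 0-1, 2-3
    #     context = torch.rand(2, h_dim)                     # one dialogue vector per edge
    #     embedder.decode(z, context, edge_label_index)      # logits of shape [2, 1]
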
class Model(nn.Module):
|
||||
def __init__(self, seq_model_type, device):
|
||||
super(Model, self).__init__()
|
||||
|
||||
self.device = device
|
||||
plan_emb_out = 128
|
||||
self.plan_embedder0 = PlanGraphEmbedder(device, plan_emb_out)
|
||||
self.plan_embedder1 = PlanGraphEmbedder(device, plan_emb_out)
|
||||
self.plan_embedder2 = PlanGraphEmbedder(device, plan_emb_out)
|
||||
self.plan_pool = MeanAggregation()
|
||||
dlist_hidden = 1024
|
||||
frame_emb = 512
|
||||
drnn_in = 5*1024 + 2 + frame_emb + 1024
|
||||
self.dialogue_listener_pre_ln = nn.LayerNorm(drnn_in)
|
||||
if seq_model_type==0:
|
||||
self.dialogue_listener_rnn = nn.GRU(drnn_in, dlist_hidden)
|
||||
self.dialogue_listener = lambda x: \
|
||||
self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
|
||||
elif seq_model_type==1:
|
||||
self.dialogue_listener_rnn = nn.LSTM(drnn_in, dlist_hidden)
|
||||
self.dialogue_listener = lambda x: \
|
||||
self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
|
||||
elif seq_model_type==2:
|
||||
mask_fun = lambda x: torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1).bool().to(device)
|
||||
sincos_fun = lambda x:torch.transpose(torch.stack([
|
||||
torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
|
||||
torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
|
||||
]),0,1).reshape(-1,1,2)
|
||||
self.dialogue_listener_lin1 = nn.Linear(drnn_in, dlist_hidden-2)
|
||||
self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
|
||||
self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x, x, x, attn_mask=mask_fun(x))
|
||||
self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
|
||||
sincos_fun(x.shape[0]).float().to(self.device),
|
||||
self.dialogue_listener_lin1(x).reshape(-1, 1, dlist_hidden-2)
|
||||
], axis=-1))[0]
|
||||
else:
|
||||
print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
|
||||
exit()
|
||||
conv_block = lambda i, o, k, p, s: nn.Sequential(
|
||||
nn.Conv2d(i, o, k, padding=p, stride=s),
|
||||
nn.BatchNorm2d(o),
|
||||
nn.Dropout(0.5),
|
||||
nn.MaxPool2d(2),
|
||||
nn.BatchNorm2d(o),
|
||||
nn.Dropout(0.5),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.conv = nn.Sequential(
|
||||
conv_block(3, 8, 3, 1, 1),
|
||||
conv_block(8, 32, 5, 2, 2),
|
||||
conv_block(32, frame_emb//4, 5, 2, 2),
|
||||
nn.Conv2d(frame_emb//4, frame_emb, 3),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.proj_y = nn.Sequential(
|
||||
nn.Linear(1024, 512),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.5),
|
||||
nn.Linear(512, plan_emb_out)
|
||||
)
|
||||
|
    def forward(self, game, experiment, global_plan=False, player_plan=False, incremental=False, return_feats=False):
        # Unpack per-timestep tuples: dialogue (d), questions (q), frames (f),
        # and the intermediate ToM feature vector.
        _, d, l, q, f, _, intermediate, _ = zip(*list(game))
        intermediate = np.array(intermediate)
        f = np.array(f, dtype=np.uint8)
        # Two speaker-indicator bits followed by the 1024-d utterance embedding;
        # timesteps without dialogue become zero vectors.
        d = np.stack([np.concatenate(([int(x[0][1]==2), int(x[0][1]==1)], x[0][-1])) if not x is None else np.zeros(1026) for x in d])
        try:
            sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
            sel2 = 1 - sel1
        except Exception:
            sel1 = 0
            sel2 = 0

        if player_plan:
            if sel1:
                z, _ = self.plan_embedder1.encode(game.player1_plan)
            elif sel2:
                z, _ = self.plan_embedder2.encode(game.player2_plan)
            else:
                z, _ = self.plan_embedder0.encode(game.global_plan)
        else:
            raise ValueError('There should never be a global plan!')

        u = torch.cat((
            torch.tensor(d).float().to(self.device),
            self.conv(torch.tensor(f).permute(0, 3, 1, 2).float().to(self.device) / 255.0).reshape(-1, 512),
            torch.tensor(intermediate).float().to(self.device)
        ), axis=-1)
        u = u.float().to(self.device)
        u = self.dialogue_listener_pre_ln(u)
        y = self.dialogue_listener(u)
        y = y.reshape(-1, y.shape[-1])
        if return_feats:
            _y = y.clone().detach().cpu().numpy()
        y = self.proj_y(y)

        if experiment == 2:
            pred, label = self.decode_own_missing_knowledge(z, y, game, sel1, sel2, incremental)
        elif experiment == 3:
            pred, label = self.decode_partner_missing_knowledge(z, y, game, sel1, sel2, incremental)
        else:
            raise ValueError('Wrong experiment id! Valid values are 2 and 3.')

        if return_feats:
            return pred, label, [sel1, sel2], _y

        return pred, label, [sel1, sel2]
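    # Minimal usage sketch (illustrative only; `game`, `DEVICE`, and the loss
    # choice are assumptions, not taken from this file):
    #
    #   model = Model(seq_model_type=2, device=DEVICE)
    #   pred, label, (sel1, sel2) = model(game, experiment=2, player_plan=True)
    #   loss = nn.functional.binary_cross_entropy_with_logits(pred.view(-1), label.float().view(-1))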
    def decode_own_missing_knowledge(self, z, y, game, sel1, sel2, incremental):
        if incremental:
            # Score the edges once with the first hidden state, then with every
            # 10th hidden state (offset so the last state is always included).
            if sel1:
                pred = torch.stack(
                    [self.plan_embedder1.decode(z, y[0].repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1)]
                    +
                    [self.plan_embedder1.decode(z, f.repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1) for f in y[len(y)%10-1::10]]
                )
                label = game.player1_edge_label_own_missing_knowledge.to(self.device)
            elif sel2:
                pred = torch.stack(
                    [self.plan_embedder2.decode(z, y[0].repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1)]
                    +
                    [self.plan_embedder2.decode(z, f.repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1) for f in y[len(y)%10-1::10]]
                )
                label = game.player2_edge_label_own_missing_knowledge.to(self.device)
            else:
                # NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
                pred = torch.stack(
                    [self.plan_embedder0.decode(z, y[0].repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)]
                    +
                    [self.plan_embedder0.decode(z, f.repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
                )
                label = torch.zeros(game.global_plan.edge_index.shape[1])
        else:
            if sel1:
                pred = self.plan_embedder1.decode(z, y.mean(0, keepdim=True).repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1)
                label = game.player1_edge_label_own_missing_knowledge.to(self.device)
            elif sel2:
                pred = self.plan_embedder2.decode(z, y.mean(0, keepdim=True).repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1)
                label = game.player2_edge_label_own_missing_knowledge.to(self.device)
            else:
                # NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
                pred = self.plan_embedder0.decode(z, y.mean(0, keepdim=True).repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)
                label = torch.zeros(game.global_plan.edge_index.shape[1])

        return (pred, label)
    def decode_partner_missing_knowledge(self, z, y, game, sel1, sel2, incremental):
        if incremental:
            if sel1:
                pred = torch.stack(
                    [self.plan_embedder1.decode(z, y[0].repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1)]
                    +
                    [self.plan_embedder1.decode(z, f.repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
                )
                label = game.player1_edge_label_other_missing_knowledge.to(self.device)
            elif sel2:
                pred = torch.stack(
                    [self.plan_embedder2.decode(z, y[0].repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1)]
                    +
                    [self.plan_embedder2.decode(z, f.repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
                )
                label = game.player2_edge_label_other_missing_knowledge.to(self.device)
            else:
                # NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
                pred = torch.stack(
                    [self.plan_embedder0.decode(z, y[0].repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)]
                    +
                    [self.plan_embedder0.decode(z, f.repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
                )
                label = torch.zeros(game.global_plan.edge_index.shape[1])
        else:
            if sel1:
                pred = self.plan_embedder1.decode(z, y.mean(0, keepdim=True).repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1)
                label = game.player1_edge_label_other_missing_knowledge.to(self.device)
            elif sel2:
                pred = self.plan_embedder2.decode(z, y.mean(0, keepdim=True).repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1)
                label = game.player2_edge_label_other_missing_knowledge.to(self.device)
            else:
                # NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
                pred = self.plan_embedder0.decode(z, y.mean(0, keepdim=True).repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)
                label = torch.zeros(game.global_plan.edge_index.shape[1])

        return (pred, label)
291
src/models/plan_model_graphs_oracle.py
Normal file
@@ -0,0 +1,291 @@
import torch
import torch.nn as nn, numpy as np
from torch.nn import functional as F
from torch_geometric.nn import MeanAggregation, GATv2Conv


def onehot(x, n):
    # 1-indexed one-hot; x == 0 yields the all-zero vector.
    retval = np.zeros(n)
    if x > 0:
        retval[x-1] = 1
    return retval
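# For instance (values computed from the definition above):
#   onehot(2, 4) -> array([0., 1., 0., 0.])
#   onehot(0, 4) -> array([0., 0., 0., 0.])   # "absent"/NOT_SURE maps to zeros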
class PlanGraphEmbedder(nn.Module):
    def __init__(self, device, h_dim, dropout=0.0, heads=4):
        super().__init__()
        self.device = device
        self.proj_x = nn.Sequential(
            nn.Linear(27, h_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.proj_edge_attr = nn.Sequential(
            nn.Linear(12, h_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.conv1 = GATv2Conv(h_dim, h_dim, heads=heads, edge_dim=h_dim)
        self.conv2 = GATv2Conv(h_dim*heads, h_dim, heads=1, edge_dim=h_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.dec = nn.Linear(h_dim*3, 1)

    def encode(self, data):
        x, edge_index, edge_attr = data.features.to(self.device), data.edge_index.to(self.device), data.tool.to(self.device)
        x = self.proj_x(x)
        edge_attr = self.proj_edge_attr(edge_attr)
        x = self.conv1(x, edge_index, edge_attr)
        x = self.act(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index, edge_attr)
        return x, edge_attr

    def decode(self, z, context, edge_label_index):
        # Score each candidate edge from its endpoint embeddings plus a shared context vector.
        u = z[edge_label_index[0]]
        v = z[edge_label_index[1]]
        return self.dec(torch.cat((u, v, context), -1))
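# Shape sketch (a reading aid; `plan` stands for any graph with the attributes
# used in encode() -- features: [N, 27], edge_index: [2, E], tool: [E, 12]):
#
#   emb = PlanGraphEmbedder(device='cpu', h_dim=128)
#   z, edge_attr = emb.encode(plan)         # z: [N, 128], edge_attr: [E, 128]
#   scores = emb.decode(z, z.mean(0, keepdim=True).repeat(plan.edge_index.shape[1], 1),
#                       plan.edge_index)    # scores: [E, 1] raw logits, one per edge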
class Model(nn.Module):
    def __init__(self, seq_model_type, device):
        super(Model, self).__init__()

        self.device = device
        plan_emb_out = 128
        self.plan_embedder0 = PlanGraphEmbedder(device, plan_emb_out)
        self.plan_embedder1 = PlanGraphEmbedder(device, plan_emb_out)
        self.plan_embedder2 = PlanGraphEmbedder(device, plan_emb_out)
        self.plan_pool = MeanAggregation()
        dlist_hidden = 1024
        frame_emb = 512
        drnn_in = 5*1024 + 2 + frame_emb + 1024
        self.dialogue_listener_pre_ln = nn.LayerNorm(drnn_in)
        if seq_model_type == 0:
            self.dialogue_listener_rnn = nn.GRU(drnn_in, dlist_hidden)
            self.dialogue_listener = lambda x: \
                self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
        elif seq_model_type == 1:
            self.dialogue_listener_rnn = nn.LSTM(drnn_in, dlist_hidden)
            self.dialogue_listener = lambda x: \
                self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
        elif seq_model_type == 2:
            mask_fun = lambda x: torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1).bool().to(device)
            sincos_fun = lambda x: torch.transpose(torch.stack([
                torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
                torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
            ]), 0, 1).reshape(-1, 1, 2)
            self.dialogue_listener_lin1 = nn.Linear(drnn_in, dlist_hidden-2)
            self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
            self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x, x, x, attn_mask=mask_fun(x))
            self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
                sincos_fun(x.shape[0]).float().to(self.device),
                self.dialogue_listener_lin1(x).reshape(-1, 1, dlist_hidden-2)
            ], axis=-1))[0]
        else:
            print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
            exit()
        conv_block = lambda i, o, k, p, s: nn.Sequential(
            nn.Conv2d(i, o, k, padding=p, stride=s),
            nn.BatchNorm2d(o),
            nn.Dropout(0.5),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(o),
            nn.Dropout(0.5),
            nn.ReLU()
        )
        self.conv = nn.Sequential(
            conv_block(3, 8, 3, 1, 1),
            conv_block(8, 32, 5, 2, 2),
            conv_block(32, frame_emb//4, 5, 2, 2),
            nn.Conv2d(frame_emb//4, frame_emb, 3),
            nn.ReLU(),
        )
        self.proj_y = nn.Sequential(
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Dropout(0.5),
            nn.Linear(512, plan_emb_out)
        )
        self.proj_tom = nn.Linear(154, 5*1024)
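    # How the pieces line up with drnn_in = 5*1024 + 2 + frame_emb + 1024 (my
    # reading of forward() below, not a statement from the source): per timestep,
    # u concatenates 2 speaker bits + a 1024-d utterance embedding (from d),
    # a 512-d frame embedding (from self.conv), and the 5*1024-d projection of
    # the 154-d oracle ToM vector (proj_tom), i.e. 2 + 1024 + 512 + 5120 = 6658.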
    def parse_ql(self, q, game, intermediate):
        tom12_answ = ['YES', 'NO', 'MAYBE']
        materials_dict = game.materials_dict.copy()
        materials_dict['NOT_SURE'] = 0
        if not q is None:
            q, l = q
            tom_gt = np.concatenate([onehot(q[2], 2), onehot(q[3], 2)])
            #### q1
            q1_1 = np.concatenate([onehot(q[4][0][0]+1, 2), onehot(materials_dict[q[4][0][1]], len(game.materials_dict))])
            ## r1
            r1_1 = np.eye(len(tom12_answ))[tom12_answ.index(l[0][0])]
            #### q2
            q2_1 = np.concatenate([onehot(q[4][1][0]+1, 2), onehot(materials_dict[q[4][1][1]], len(game.materials_dict))])
            ## r2
            r2_1 = np.eye(len(tom12_answ))[tom12_answ.index(l[0][1])]
            #### q3
            q3_1 = onehot(q[4][2]+1, 2)
            ## r3
            r3_1 = onehot(materials_dict[l[0][2]], len(game.materials_dict))
            #### q1
            q1_2 = np.concatenate([onehot(q[5][0][0]+1, 2), onehot(materials_dict[q[5][0][1]], len(game.materials_dict))])
            ## r1
            r1_2 = np.eye(len(tom12_answ))[tom12_answ.index(l[1][0])]
            #### q2
            q2_2 = np.concatenate([onehot(q[5][1][0]+1, 2), onehot(materials_dict[q[5][1][1]], len(game.materials_dict))])
            ## r2
            r2_2 = np.eye(len(tom12_answ))[tom12_answ.index(l[1][1])]
            #### q3
            q3_2 = onehot(q[5][2]+1, 2)
            ## r3
            r3_2 = onehot(materials_dict[l[1][2]], len(game.materials_dict))
            if intermediate == 0:
                tom_gt = np.zeros(154)
            elif intermediate == 1:
                # tom6
                tom_gt = np.concatenate([tom_gt, q1_1, r1_1, np.zeros(q2_1.shape[0] + r2_1.shape[0] + q3_1.shape[0] + r3_1.shape[0]), q1_2, r1_2, np.zeros(q2_2.shape[0] + r2_2.shape[0] + q3_2.shape[0] + r3_2.shape[0])])
            elif intermediate == 2:
                # tom7
                tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0]), q2_1, r2_1, np.zeros(q3_1.shape[0] + r3_1.shape[0] + q1_2.shape[0] + r1_2.shape[0]), q2_2, r2_2, np.zeros(q3_2.shape[0] + r3_2.shape[0])])
            elif intermediate == 3:
                # tom6 + tom7
                tom_gt = np.concatenate([tom_gt, q1_1, r1_1, q2_1, r2_1, np.zeros(q3_1.shape[0] + r3_1.shape[0]), q1_2, r1_2, q2_2, r2_2, np.zeros(q3_2.shape[0] + r3_2.shape[0])])
            elif intermediate == 4:
                # tom8
                tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0] + q2_1.shape[0] + r2_1.shape[0]), q3_1, r3_1, np.zeros(q1_2.shape[0] + r1_2.shape[0] + q2_2.shape[0] + r2_2.shape[0]), q3_2, r3_2])
            elif intermediate == 5:
                # tom6 + tom8
                tom_gt = np.concatenate([tom_gt, q1_1, r1_1, np.zeros(q2_1.shape[0] + r2_1.shape[0]), q3_1, r3_1, q1_2, r1_2, np.zeros(q2_2.shape[0] + r2_2.shape[0]), q3_2, r3_2])
            elif intermediate == 6:
                # tom7 + tom8
                tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0]), q2_1, r2_1, q3_1, r3_1, np.zeros(q1_2.shape[0] + r1_2.shape[0]), q2_2, r2_2, q3_2, r3_2])
            elif intermediate == 7:
                # tom6 + tom7 + tom8
                tom_gt = np.concatenate([tom_gt, q1_1, r1_1, q2_1, r2_1, q3_1, r3_1, q1_2, r1_2, q2_2, r2_2, q3_2, r3_2])
        else:
            tom_gt = np.zeros(154)
        assert tom_gt.shape[0] == 154, f'expected a 154-d ToM vector, got {tom_gt.shape[0]}'
        return tom_gt
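    # Dimension check for the 154-d vector (derived from the code above, and
    # consistent only if len(game.materials_dict) == 21): the base is
    # onehot(.,2) + onehot(.,2) = 4 dims; per player, q1 (2+21) + r1 (3) +
    # q2 (2+21) + r2 (3) + q3 (2) + r3 (21) = 75 dims; 4 + 2*75 = 154.
    # The 21*21 plan matrix in src/models/plan_model_oracle.py matches the
    # same 21-material vocabulary.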
    def forward(self, game, experiment, global_plan=False, player_plan=False, incremental=False, intermediate=0):

        _, d, l, q, f, _, _, _ = zip(*list(game))

        # Oracle ToM features: parse the ground-truth question/answer pairs and
        # project the 154-d vector up to the 5*1024-d slot of the listener input.
        tom_gt = [self.parse_ql(x, game, intermediate) for x in q]
        tom_gt = torch.tensor(np.stack(tom_gt), device=self.device, dtype=torch.float32)
        tom_gt = self.proj_tom(tom_gt)

        f = np.array(f, dtype=np.uint8)
        d = np.stack([np.concatenate(([int(x[0][1]==2), int(x[0][1]==1)], x[0][-1])) if not x is None else np.zeros(1026) for x in d])
        try:
            sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
            sel2 = 1 - sel1
        except Exception:
            sel1 = 0
            sel2 = 0

        if player_plan:
            if sel1:
                z, _ = self.plan_embedder1.encode(game.player1_plan)
            elif sel2:
                z, _ = self.plan_embedder2.encode(game.player2_plan)
            else:
                z, _ = self.plan_embedder0.encode(game.global_plan)
        else:
            raise ValueError('There should never be a global plan!')

        u = torch.cat((
            torch.tensor(d).float().to(self.device),
            self.conv(torch.tensor(f).permute(0, 3, 1, 2).float().to(self.device) / 255.0).reshape(-1, 512),
            tom_gt
        ), axis=-1)
        u = u.float().to(self.device)
        u = self.dialogue_listener_pre_ln(u)
        y = self.dialogue_listener(u)
        y = y.reshape(-1, y.shape[-1])
        y = self.proj_y(y)

        if experiment == 2:
            pred, label = self.decode_own_missing_knowledge(z, y, game, sel1, sel2, incremental)
        elif experiment == 3:
            pred, label = self.decode_partner_missing_knowledge(z, y, game, sel1, sel2, incremental)
        else:
            raise ValueError('Wrong experiment id! Valid values are 2 and 3.')

        return pred, label, [sel1, sel2]
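    # Minimal usage sketch (illustrative; `game` and DEVICE are assumptions):
    #
    #   model = Model(seq_model_type=2, device=DEVICE)
    #   # intermediate in 0..7 selects which oracle ToM question/answer groups
    #   # (tom6/tom7/tom8) are revealed to the model, 0 meaning none:
    #   pred, label, (sel1, sel2) = model(game, experiment=3, player_plan=True, intermediate=7)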
    def decode_own_missing_knowledge(self, z, y, game, sel1, sel2, incremental):
        if incremental:
            if sel1:
                pred = torch.stack(
                    [self.plan_embedder1.decode(z, y[0].repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1)]
                    +
                    [self.plan_embedder1.decode(z, f.repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1) for f in y[len(y)%10-1::10]]
                )
                label = game.player1_edge_label_own_missing_knowledge.to(self.device)
            elif sel2:
                pred = torch.stack(
                    [self.plan_embedder2.decode(z, y[0].repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1)]
                    +
                    [self.plan_embedder2.decode(z, f.repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1) for f in y[len(y)%10-1::10]]
                )
                label = game.player2_edge_label_own_missing_knowledge.to(self.device)
            else:
                # NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
                pred = torch.stack(
                    [self.plan_embedder0.decode(z, y[0].repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)]
                    +
                    [self.plan_embedder0.decode(z, f.repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
                )
                label = torch.zeros(game.global_plan.edge_index.shape[1])
        else:
            if sel1:
                pred = self.plan_embedder1.decode(z, y.mean(0, keepdim=True).repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1)
                label = game.player1_edge_label_own_missing_knowledge.to(self.device)
            elif sel2:
                pred = self.plan_embedder2.decode(z, y.mean(0, keepdim=True).repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1)
                label = game.player2_edge_label_own_missing_knowledge.to(self.device)
            else:
                # NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
                pred = self.plan_embedder0.decode(z, y.mean(0, keepdim=True).repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)
                label = torch.zeros(game.global_plan.edge_index.shape[1])

        return (pred, label)
    def decode_partner_missing_knowledge(self, z, y, game, sel1, sel2, incremental):
        if incremental:
            if sel1:
                pred = torch.stack(
                    [self.plan_embedder1.decode(z, y[0].repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1)]
                    +
                    [self.plan_embedder1.decode(z, f.repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
                )
                label = game.player1_edge_label_other_missing_knowledge.to(self.device)
            elif sel2:
                pred = torch.stack(
                    [self.plan_embedder2.decode(z, y[0].repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1)]
                    +
                    [self.plan_embedder2.decode(z, f.repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
                )
                label = game.player2_edge_label_other_missing_knowledge.to(self.device)
            else:
                # NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
                pred = torch.stack(
                    [self.plan_embedder0.decode(z, y[0].repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)]
                    +
                    [self.plan_embedder0.decode(z, f.repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
                )
                label = torch.zeros(game.global_plan.edge_index.shape[1])
        else:
            if sel1:
                pred = self.plan_embedder1.decode(z, y.mean(0, keepdim=True).repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1)
                label = game.player1_edge_label_other_missing_knowledge.to(self.device)
            elif sel2:
                pred = self.plan_embedder2.decode(z, y.mean(0, keepdim=True).repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1)
                label = game.player2_edge_label_other_missing_knowledge.to(self.device)
            else:
                # NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
                pred = self.plan_embedder0.decode(z, y.mean(0, keepdim=True).repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)
                label = torch.zeros(game.global_plan.edge_index.shape[1])

        return (pred, label)
214
src/models/plan_model_oracle.py
Normal file
@@ -0,0 +1,214 @@
import sys, torch, random
import torch.nn as nn, numpy as np
from torch.nn import functional as F
from src.data.game_parser import DEVICE


def onehot(x, n):
    retval = np.zeros(n)
    if x > 0:
        retval[x-1] = 1
    return retval
class Model(nn.Module):
    def __init__(self, seq_model_type=0, device=DEVICE):
        super(Model, self).__init__()

        self.device = device

        my_rnn = lambda i, o: nn.GRU(i, o)

        plan_emb_in = 81
        plan_emb_out = 32
        q_emb = 100

        self.plan_embedder0 = my_rnn(plan_emb_in, plan_emb_out)
        self.plan_embedder1 = my_rnn(plan_emb_in, plan_emb_out)
        self.plan_embedder2 = my_rnn(plan_emb_in, plan_emb_out)

        dlist_hidden = 1024
        frame_emb = 512
        drnn_in = 5*1024 + 2 + frame_emb + 1024

        if seq_model_type == 0:
            self.dialogue_listener_rnn = nn.GRU(drnn_in, dlist_hidden)
            self.dialogue_listener = lambda x: \
                self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
        elif seq_model_type == 1:
            self.dialogue_listener_rnn = nn.LSTM(drnn_in, dlist_hidden)
            self.dialogue_listener = lambda x: \
                self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
        elif seq_model_type == 2:
            mask_fun = lambda x: torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1).bool().to(device)
            sincos_fun = lambda x: torch.transpose(torch.stack([
                torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
                torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
            ]), 0, 1).reshape(-1, 1, 2)
            self.dialogue_listener_lin1 = nn.Linear(drnn_in, dlist_hidden-2)
            self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
            self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x, x, x, attn_mask=mask_fun(x))
            self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
                sincos_fun(x.shape[0]).float().to(self.device),
                self.dialogue_listener_lin1(x).reshape(-1, 1, dlist_hidden-2)
            ], axis=-1))[0]
        else:
            print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
            exit()

        conv_block = lambda i, o, k, p, s: nn.Sequential(
            nn.Conv2d(i, o, k, padding=p, stride=s),
            nn.BatchNorm2d(o),
            nn.Dropout(0.5),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(o),
            nn.Dropout(0.5),
            nn.ReLU()
        )

        self.conv = nn.Sequential(
            conv_block(3, 8, 3, 1, 1),
            conv_block(8, 32, 5, 2, 2),
            conv_block(32, frame_emb//4, 5, 2, 2),
            nn.Conv2d(frame_emb//4, frame_emb, 3), nn.ReLU(),
        )

        plan_layer = lambda i, o: nn.Sequential(
            nn.Linear(i, (i+2*o)//3),
            nn.Dropout(0.5),
            nn.GELU(),
            nn.Dropout(0.5),
            nn.Linear((i+2*o)//3, o),
            nn.GELU(),
            nn.Dropout(0.5),
        )

        plan_mat_size = 21*21
        # Earlier variants also concatenated the three plan embeddings here;
        # the final input is the listener state plus the flattened plan matrix.
        q_in_size = dlist_hidden + plan_mat_size

        self.plan_out = plan_layer(q_in_size, plan_mat_size)

        self.proj_tom = nn.Linear(154, 5*1024)
    def forward(self, game, global_plan=False, player_plan=False, evaluation=False, incremental=False, intermediate=0):

        _, d, l, q, f, _, _, _ = zip(*list(game))

        tom_gt = [self.parse_ql(x, game, intermediate) for x in q]
        tom_gt = torch.tensor(np.stack(tom_gt), device=self.device, dtype=torch.float32)
        tom_gt = self.proj_tom(tom_gt)

        f = np.array(f, dtype=np.uint8)

        d = np.stack([np.concatenate(([int(x[0][1]==2), int(x[0][1]==1)], x[0][-1])) if not x is None else np.zeros(1026) for x in d])

        try:
            sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
            sel2 = 1 - sel1
        except Exception:
            sel1 = 0
            sel2 = 0

        if not global_plan and not player_plan:
            plan_emb = torch.cat([
                self.plan_embedder0(torch.stack(list(map(torch.tensor, game.global_plan))).reshape(-1, 1, 81).float().to(self.device))[0][-1][0],
                self.plan_embedder1(torch.stack(list(map(torch.tensor, game.player1_plan))).reshape(-1, 1, 81).float().to(self.device))[0][-1][0],
                self.plan_embedder2(torch.stack(list(map(torch.tensor, game.player2_plan))).reshape(-1, 1, 81).float().to(self.device))[0][-1][0]
            ])
            plan_emb = 0*plan_emb
        elif global_plan:
            plan_emb = torch.cat([
                self.plan_embedder0(torch.stack(list(map(torch.tensor, game.global_plan))).reshape(-1, 1, 81).float().to(self.device))[0][-1][0],
                self.plan_embedder1(torch.stack(list(map(torch.tensor, game.player1_plan))).reshape(-1, 1, 81).float().to(self.device))[0][-1][0],
                self.plan_embedder2(torch.stack(list(map(torch.tensor, game.player2_plan))).reshape(-1, 1, 81).float().to(self.device))[0][-1][0]
            ])
        else:
            # Player plan: zero out the global embedding and keep only the
            # selected player's plan embedding.
            plan_emb = torch.cat([
                0*self.plan_embedder0(torch.stack(list(map(torch.tensor, game.global_plan))).reshape(-1, 1, 81).float().to(self.device))[0][-1][0],
                sel1*self.plan_embedder1(torch.stack(list(map(torch.tensor, game.player1_plan))).reshape(-1, 1, 81).float().to(self.device))[0][-1][0],
                sel2*self.plan_embedder2(torch.stack(list(map(torch.tensor, game.player2_plan))).reshape(-1, 1, 81).float().to(self.device))[0][-1][0]
            ])

        u = torch.cat((
            torch.tensor(d).float().to(self.device),
            self.conv(torch.tensor(f).permute(0, 3, 1, 2).float().to(self.device)).reshape(-1, 512),
            tom_gt
        ), axis=-1)
        u = u.float().to(self.device)

        y = self.dialogue_listener(u)
        y = y.reshape(-1, y.shape[-1])

        if incremental:
            prediction = torch.stack([
                self.plan_out(torch.cat((y[0], torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device))))] + [
                self.plan_out(torch.cat((f, torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device)))) for f in y[len(y)%10-1::10]
            ])
            prediction = F.softmax(prediction.reshape(-1, 21, 21), -1).reshape(-1, 21*21)
        else:
            prediction = self.plan_out(torch.cat((y[-1], torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device))))
            prediction = F.softmax(prediction.reshape(21, 21), -1).reshape(21*21)

        return prediction, y
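    # Output shapes (derived from the code above): with incremental=True the
    # model emits one prediction per sampled listener state, each a row-wise
    # softmax over a 21x21 material-dependency matrix flattened to 441 values;
    # with incremental=False only the final state y[-1] is decoded, giving a
    # single 441-d distribution.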
    def parse_ql(self, q, game, intermediate):
        tom12_answ = ['YES', 'NO', 'MAYBE']
        materials_dict = game.materials_dict.copy()
        materials_dict['NOT_SURE'] = 0
        if not q is None:
            q, l = q
            tom_gt = np.concatenate([onehot(q[2], 2), onehot(q[3], 2)])
            #### q1
            q1_1 = np.concatenate([onehot(q[4][0][0]+1, 2), onehot(materials_dict[q[4][0][1]], len(game.materials_dict))])
            ## r1
            r1_1 = np.eye(len(tom12_answ))[tom12_answ.index(l[0][0])]
            #### q2
            q2_1 = np.concatenate([onehot(q[4][1][0]+1, 2), onehot(materials_dict[q[4][1][1]], len(game.materials_dict))])
            ## r2
            r2_1 = np.eye(len(tom12_answ))[tom12_answ.index(l[0][1])]
            #### q3
            q3_1 = onehot(q[4][2]+1, 2)
            ## r3
            r3_1 = onehot(materials_dict[l[0][2]], len(game.materials_dict))
            #### q1
            q1_2 = np.concatenate([onehot(q[5][0][0]+1, 2), onehot(materials_dict[q[5][0][1]], len(game.materials_dict))])
            ## r1
            r1_2 = np.eye(len(tom12_answ))[tom12_answ.index(l[1][0])]
            #### q2
            q2_2 = np.concatenate([onehot(q[5][1][0]+1, 2), onehot(materials_dict[q[5][1][1]], len(game.materials_dict))])
            ## r2
            r2_2 = np.eye(len(tom12_answ))[tom12_answ.index(l[1][1])]
            #### q3
            q3_2 = onehot(q[5][2]+1, 2)
            ## r3
            r3_2 = onehot(materials_dict[l[1][2]], len(game.materials_dict))
            if intermediate == 0:
                tom_gt = np.zeros(154)
            elif intermediate == 1:
                # tom6
                tom_gt = np.concatenate([tom_gt, q1_1, r1_1, np.zeros(q2_1.shape[0] + r2_1.shape[0] + q3_1.shape[0] + r3_1.shape[0]), q1_2, r1_2, np.zeros(q2_2.shape[0] + r2_2.shape[0] + q3_2.shape[0] + r3_2.shape[0])])
            elif intermediate == 2:
                # tom7
                tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0]), q2_1, r2_1, np.zeros(q3_1.shape[0] + r3_1.shape[0] + q1_2.shape[0] + r1_2.shape[0]), q2_2, r2_2, np.zeros(q3_2.shape[0] + r3_2.shape[0])])
            elif intermediate == 3:
                # tom6 + tom7
                tom_gt = np.concatenate([tom_gt, q1_1, r1_1, q2_1, r2_1, np.zeros(q3_1.shape[0] + r3_1.shape[0]), q1_2, r1_2, q2_2, r2_2, np.zeros(q3_2.shape[0] + r3_2.shape[0])])
            elif intermediate == 4:
                # tom8
                tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0] + q2_1.shape[0] + r2_1.shape[0]), q3_1, r3_1, np.zeros(q1_2.shape[0] + r1_2.shape[0] + q2_2.shape[0] + r2_2.shape[0]), q3_2, r3_2])
            elif intermediate == 5:
                # tom6 + tom8
                tom_gt = np.concatenate([tom_gt, q1_1, r1_1, np.zeros(q2_1.shape[0] + r2_1.shape[0]), q3_1, r3_1, q1_2, r1_2, np.zeros(q2_2.shape[0] + r2_2.shape[0]), q3_2, r3_2])
            elif intermediate == 6:
                # tom7 + tom8
                tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0]), q2_1, r2_1, q3_1, r3_1, np.zeros(q1_2.shape[0] + r1_2.shape[0]), q2_2, r2_2, q3_2, r3_2])
            elif intermediate == 7:
                # tom6 + tom7 + tom8
                tom_gt = np.concatenate([tom_gt, q1_1, r1_1, q2_1, r2_1, q3_1, r3_1, q1_2, r1_2, q2_2, r2_2, q3_2, r3_2])
        else:
            tom_gt = np.zeros(154)
        assert tom_gt.shape[0] == 154, f'expected a 154-d ToM vector, got {tom_gt.shape[0]}'
        return tom_gt