Code release
commit 09fb25e339
29 changed files with 7162 additions and 0 deletions
0
utils/__init__.py
Normal file
290
utils/data_utils.py
Normal file
@@ -0,0 +1,290 @@
import torch
from torch.autograd import Variable
import random
import pickle
import numpy as np
from copy import deepcopy


def load_pickle_lines(filename):
    data = []
    with open(filename, 'rb') as f:
        while True:
            try:
                data.append(pickle.load(f))
            except EOFError:
                break
    return data


def flatten(l):
    return [item for sublist in l for item in sublist]


def build_len_mask_batch(
    # [batch_size], []
    len_batch, max_len=None
):
    if max_len is None:
        max_len = len_batch.max().item()
    batch_size, = len_batch.shape
    # [batch_size, max_len]
    idxes_batch = torch.arange(max_len, device=len_batch.device).view(1, -1).repeat(batch_size, 1)
    # [batch_size, max_len] = [batch_size, max_len] < [batch_size, 1]
    return idxes_batch < len_batch.view(-1, 1)


def sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_range_expand = Variable(seq_range_expand)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.to(sequence_length.device)
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    return seq_range_expand < seq_length_expand

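A minimal sketch of what these two masking helpers produce, with a toy length tensor invented for illustration:

    lengths = torch.tensor([3, 1, 2])
    build_len_mask_batch(lengths)
    # tensor([[ True,  True,  True],
    #         [ True, False, False],
    #         [ True,  True, False]])
    # sequence_mask(lengths) yields the same boolean pattern via the legacy Variable API.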
def batch_iter(dataloader, params):
    for epochId in range(params['num_epochs']):
        for idx, batch in enumerate(dataloader):
            yield epochId, idx, batch


def list2tensorpad(inp_list, max_seq_len):
    inp_tensor = torch.LongTensor([inp_list])
    inp_tensor_zeros = torch.zeros(1, max_seq_len, dtype=torch.long)
    inp_tensor_zeros[0, :inp_tensor.shape[1]] = inp_tensor  # after preprocessing, inp_tensor.shape[1] must be < max_seq_len
    inp_tensor = inp_tensor_zeros
    return inp_tensor

def encode_input(utterances, start_segment, CLS, SEP, MASK, max_seq_len=256, max_sep_len=25, mask_prob=0.2):

    cur_segment = start_segment
    token_id_list = []
    segment_id_list = []
    sep_token_indices = []
    masked_token_list = []

    token_id_list.append(CLS)
    segment_id_list.append(cur_segment)
    masked_token_list.append(0)

    cur_sep_token_index = 0

    for cur_utterance in utterances:
        # add the masked tokens and keep track of them
        cur_masked_index = [1 if random.random() < mask_prob else 0 for _ in range(len(cur_utterance))]
        masked_token_list.extend(cur_masked_index)
        token_id_list.extend(cur_utterance)
        segment_id_list.extend([cur_segment] * len(cur_utterance))

        token_id_list.append(SEP)
        segment_id_list.append(cur_segment)
        masked_token_list.append(0)
        cur_sep_token_index = cur_sep_token_index + len(cur_utterance) + 1
        sep_token_indices.append(cur_sep_token_index)
        cur_segment = cur_segment ^ 1  # the current segment oscillates between 0 and 1
    start_question, end_question = sep_token_indices[-3] + 1, sep_token_indices[-2]
    assert end_question - start_question == len(utterances[-2])

    assert len(segment_id_list) == len(token_id_list) == len(masked_token_list) == sep_token_indices[-1] + 1
    # convert to tensors and pad to the maximum sequence length
    tokens = list2tensorpad(token_id_list, max_seq_len)  # [1, max_len]
    masked_tokens = list2tensorpad(masked_token_list, max_seq_len)
    masked_tokens[0, masked_tokens[0, :] == 0] = -1
    mask = masked_tokens[0, :] == 1
    masked_tokens[0, mask] = tokens[0, mask]
    tokens[0, mask] = MASK

    segment_id_list = list2tensorpad(segment_id_list, max_seq_len)
    return tokens, segment_id_list, list2tensorpad(sep_token_indices, max_sep_len), masked_tokens, start_question, end_question

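A hedged usage sketch: the special-token ids below (101/102/103 are the standard BERT [CLS]/[SEP]/[MASK] ids) and the toy utterances are assumed for illustration, not taken from this repository:

    # caption, question, answer (already tokenized); at least three utterances are
    # required because the question limits are read from the last three SEP indices.
    utterances = [[7592, 2088], [2054, 2003, 2009], [1037, 4937]]
    tokens, segments, seps, masked, q_start, q_end = encode_input(
        utterances, start_segment=0, CLS=101, SEP=102, MASK=103)
    # tokens: [1, 256] padded ids with ~20% of utterance tokens replaced by MASK;
    # masked: -1 everywhere except the original ids at masked positions;
    # q_start:q_end spans the question (utterances[-2]) inside the flat sequence.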
def encode_input_with_mask(utterances, start_segment, CLS, SEP, MASK, max_seq_len=256, max_sep_len=25, mask_prob=0.2, get_q_limits=True):

    cur_segment = start_segment
    token_id_list = []
    segment_id_list = []
    sep_token_indices = []
    masked_token_list = []
    input_mask_list = []

    token_id_list.append(CLS)
    segment_id_list.append(cur_segment)
    masked_token_list.append(0)
    input_mask_list.append(1)

    cur_sep_token_index = 0

    for cur_utterance in utterances:
        # add the masked tokens and keep track of them
        cur_masked_index = [1 if random.random() < mask_prob else 0 for _ in range(len(cur_utterance))]
        masked_token_list.extend(cur_masked_index)
        token_id_list.extend(cur_utterance)
        segment_id_list.extend([cur_segment] * len(cur_utterance))
        input_mask_list.extend([1] * len(cur_utterance))

        token_id_list.append(SEP)
        segment_id_list.append(cur_segment)
        masked_token_list.append(0)
        input_mask_list.append(1)
        cur_sep_token_index = cur_sep_token_index + len(cur_utterance) + 1
        sep_token_indices.append(cur_sep_token_index)
        cur_segment = cur_segment ^ 1  # the current segment oscillates between 0 and 1

    if get_q_limits:
        start_question, end_question = sep_token_indices[-3] + 1, sep_token_indices[-2]
        assert end_question - start_question == len(utterances[-2])
    else:
        start_question, end_question = -1, -1
    assert len(segment_id_list) == len(token_id_list) == len(masked_token_list) == len(input_mask_list) == sep_token_indices[-1] + 1
    # convert to tensors and pad to the maximum sequence length
    tokens = list2tensorpad(token_id_list, max_seq_len)
    masked_tokens = list2tensorpad(masked_token_list, max_seq_len)
    input_mask = list2tensorpad(input_mask_list, max_seq_len)
    masked_tokens[0, masked_tokens[0, :] == 0] = -1
    mask = masked_tokens[0, :] == 1
    masked_tokens[0, mask] = tokens[0, mask]
    tokens[0, mask] = MASK

    segment_id_list = list2tensorpad(segment_id_list, max_seq_len)
    return tokens, segment_id_list, list2tensorpad(sep_token_indices, max_sep_len), masked_tokens, input_mask, start_question, end_question

def encode_image_input(features, num_boxes, boxes, image_target, max_regions=37, mask_prob=0.15):
    output_label = []
    num_boxes = min(int(num_boxes), max_regions)

    mix_boxes_pad = np.zeros((max_regions, boxes.shape[-1]))
    mix_features_pad = np.zeros((max_regions, features.shape[-1]))
    mix_image_target = np.zeros((max_regions, image_target.shape[-1]))

    mix_boxes_pad[:num_boxes] = boxes[:num_boxes]
    mix_features_pad[:num_boxes] = features[:num_boxes]
    mix_image_target[:num_boxes] = image_target[:num_boxes]

    boxes = mix_boxes_pad
    features = mix_features_pad
    image_target = mix_image_target
    mask_indexes = []
    for i in range(num_boxes):
        prob = random.random()
        # mask a region with probability mask_prob (15% by default)
        if prob < mask_prob:
            prob /= mask_prob

            # 90% of the time, zero out the region features (the code uses 0.9, not the 0.8 of BERT-style masking)
            if prob < 0.9:
                features[i] = 0

            output_label.append(1)
            mask_indexes.append(i)
        else:
            # no masking (such regions will be ignored by the loss function later)
            output_label.append(-1)

    image_mask = [1] * int(num_boxes)
    while len(image_mask) < max_regions:
        image_mask.append(0)
        output_label.append(-1)

    # ensure we have at least one region being predicted
    output_label[random.randint(1, len(output_label) - 1)] = 1
    image_label = torch.LongTensor(output_label)
    image_label[0] = 0  # make sure the <IMG> token does not contribute to the masked loss
    image_mask = torch.tensor(image_mask).float()

    features = torch.tensor(features).float()
    spatials = torch.tensor(boxes).float()
    image_target = torch.tensor(image_target).float()

    return features, spatials, image_mask, image_target, image_label

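A small usage sketch with random stand-in data; the shapes follow the Faster R-CNN conventions used elsewhere in this commit (2048-d features, 5-d box coordinates, 1601 class probabilities), but the values are invented:

    feats = np.random.randn(36, 2048).astype(np.float32)
    boxes = np.random.rand(36, 5).astype(np.float32)
    target = np.random.rand(36, 1601).astype(np.float32)
    f, sp, im_mask, im_target, im_label = encode_image_input(feats, 36, boxes, target)
    # f/sp/im_target are zero-padded to max_regions=37 rows; im_label holds 1 for
    # masked regions, -1 for ignored ones, and 0 at index 0 (the <IMG> slot).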
def question_edge_masking(question_edge_indices, question_edge_attributes, mask, question_limits, mask_prob=0.4, max_len=10):
    mask = mask.squeeze().tolist()
    question_limits = question_limits.tolist()
    question_start, question_end = question_limits
    # Get the masking of the question
    mask_question = mask[question_start:question_end]
    masked_idx = np.argwhere(np.array(mask_question) > -1).squeeze().tolist()
    if isinstance(masked_idx, int):  # only one question token is masked
        masked_idx = [masked_idx]

    # get rid of all edge indices and attributes that correspond to masked tokens
    edge_attr_gt = []
    edge_idx_gt_gnn = []
    edge_idx_gt_bert = []
    for i, (question_edge_idx, question_edge_attr) in enumerate(zip(question_edge_indices, question_edge_attributes)):
        if not (question_edge_idx[0] in masked_idx or question_edge_idx[1] in masked_idx):
            # Masking
            if random.random() < mask_prob:
                edge_attr_gt.append(np.argwhere(question_edge_attr).item())
                edge_idx_gt_gnn.append(question_edge_idx)
                edge_idx_gt_bert.append([question_edge_idx[0] + question_start, question_edge_idx[1] + question_start])
                question_edge_attr = np.zeros_like(question_edge_attr)
                question_edge_attr[-1] = 1.0  # the [EDGE_MASK] special token is the last one-hot vector encoding
                question_edge_attributes[i] = question_edge_attr
        else:
            continue
    # Force masking if necessary:
    if len(edge_attr_gt) == 0:
        for i, (question_edge_idx, question_edge_attr) in enumerate(zip(question_edge_indices, question_edge_attributes)):
            if not (question_edge_idx[0] in masked_idx or question_edge_idx[1] in masked_idx):
                # Masking
                edge_attr_gt.append(np.argwhere(question_edge_attr).item())
                edge_idx_gt_gnn.append(question_edge_idx)
                edge_idx_gt_bert.append([question_edge_idx[0] + question_start, question_edge_idx[1] + question_start])
                question_edge_attr = np.zeros_like(question_edge_attr)
                question_edge_attr[-1] = 1.0  # the [EDGE_MASK] special token is the last one-hot vector encoding
                question_edge_attributes[i] = question_edge_attr
                break

    # For the rare case where the conditions for masking were not met
    if len(edge_attr_gt) == 0:
        edge_attr_gt.append(-1)
        edge_idx_gt_gnn.append([0, question_end - question_start])
        edge_idx_gt_bert.append(question_limits)

    # Pad to max_len
    while len(edge_attr_gt) < max_len:
        edge_attr_gt.append(-1)
        edge_idx_gt_gnn.append(edge_idx_gt_gnn[-1])
        edge_idx_gt_bert.append(edge_idx_gt_bert[-1])

    # Truncate if longer than max_len
    if len(edge_attr_gt) > max_len:
        edge_idx_gt_gnn = edge_idx_gt_gnn[:max_len]
        edge_idx_gt_bert = edge_idx_gt_bert[:max_len]
        edge_attr_gt = edge_attr_gt[:max_len]
    edge_idx_gt_gnn = np.array(edge_idx_gt_gnn)
    edge_idx_gt_bert = np.array(edge_idx_gt_bert)

    first_edge_node_gt_gnn = list(edge_idx_gt_gnn[:, 0])
    second_edge_node_gt_gnn = list(edge_idx_gt_gnn[:, 1])

    first_edge_node_gt_bert = list(edge_idx_gt_bert[:, 0])
    second_edge_node_gt_bert = list(edge_idx_gt_bert[:, 1])

    return question_edge_attributes, edge_attr_gt, first_edge_node_gt_gnn, second_edge_node_gt_gnn, first_edge_node_gt_bert, second_edge_node_gt_bert

def to_data_list(feats, batch_idx):
    feat_list = []
    device = feats.device
    left = 0
    right = 0
    batch_size = batch_idx.max().item() + 1
    for batch in range(batch_size):
        if batch == batch_size - 1:
            right = batch_idx.size(0)
        else:
            right = torch.argwhere(batch_idx == batch + 1)[0].item()
        idx = torch.arange(left, right).unsqueeze(-1).repeat(1, feats.size(1)).to(device)
        feat_list.append(torch.gather(feats, 0, idx))
        left = right

    return feat_list
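A toy illustration of to_data_list, assuming a PyG-style batch index that is sorted and starts at 0:

    feats = torch.arange(10.).view(5, 2)        # 5 node features of size 2
    batch_idx = torch.tensor([0, 0, 1, 1, 1])   # first graph has 2 nodes, second has 3
    chunks = to_data_list(feats, batch_idx)
    # chunks[0].shape == (2, 2) and chunks[1].shape == (3, 2)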
192
utils/image_features_reader.py
Normal file
@@ -0,0 +1,192 @@
from typing import List
import csv
import h5py
import numpy as np
import copy
import pickle
import lmdb  # install lmdb by "pip install lmdb"
import base64
import pdb
import os


class ImageFeaturesH5Reader(object):
    """
    A reader for H5 files containing pre-extracted image features. A typical
    H5 file is expected to have a column named "image_id", and another column
    named "features".

    Example of an H5 file:
    ```
    faster_rcnn_bottomup_features.h5
    |--- "image_id" [shape: (num_images, )]
    |--- "features" [shape: (num_images, num_proposals, feature_size)]
    +--- .attrs ("split", "train")
    ```
    Parameters
    ----------
    features_h5path : str
        Path to an H5 file containing COCO train / val image features.
    in_memory : bool
        Whether to load the whole H5 file in memory. Beware, these files are
        sometimes tens of GBs in size. Set this to true if you have sufficient
        RAM - trade-off between speed and memory.
    """
    def __init__(self, features_path: str, scene_graph_path: str, in_memory: bool = False):
        self.features_path = features_path
        self.scene_graph_path = scene_graph_path
        self._in_memory = in_memory

        self.env = lmdb.open(self.features_path, max_readers=1, readonly=True,
                             lock=False, readahead=False, meminit=False)

        with self.env.begin(write=False) as txn:
            self._image_ids = pickle.loads(txn.get('keys'.encode()))

        self.features = [None] * len(self._image_ids)
        self.num_boxes = [None] * len(self._image_ids)
        self.boxes = [None] * len(self._image_ids)
        self.boxes_ori = [None] * len(self._image_ids)
        self.cls_prob = [None] * len(self._image_ids)
        self.edge_indexes = [None] * len(self._image_ids)
        self.edge_attributes = [None] * len(self._image_ids)

    def __len__(self):
        return len(self._image_ids)

    def __getitem__(self, image_id):

        image_id = str(image_id).encode()
        index = self._image_ids.index(image_id)
        if self._in_memory:
            # Load features during the first epoch; they are not all loaded
            # together, as that has a slow start.
            if self.features[index] is not None:
                features = self.features[index]
                num_boxes = self.num_boxes[index]
                image_location = self.boxes[index]
                image_location_ori = self.boxes_ori[index]
                cls_prob = self.cls_prob[index]
                edge_indexes = self.edge_indexes[index]
                edge_attributes = self.edge_attributes[index]
            else:
                with self.env.begin(write=False) as txn:
                    item = pickle.loads(txn.get(image_id))
                    image_id = item['image_id']
                    image_h = int(item['image_h'])
                    image_w = int(item['image_w'])
                    num_boxes = int(item['num_boxes'])
                    features = np.frombuffer(base64.b64decode(item["features"]), dtype=np.float32).reshape(num_boxes, 2048)
                    boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape(num_boxes, 4)

                    cls_prob = np.frombuffer(base64.b64decode(item['cls_prob']), dtype=np.float32).reshape(num_boxes, 1601)
                    # add an extra row at the top for the <IMG> tokens
                    g_cls_prob = np.zeros(1601, dtype=np.float32)
                    g_cls_prob[0] = 1
                    cls_prob = np.concatenate([np.expand_dims(g_cls_prob, axis=0), cls_prob], axis=0)

                    self.cls_prob[index] = cls_prob

                    g_feat = np.sum(features, axis=0) / num_boxes
                    num_boxes = num_boxes + 1

                    features = np.concatenate([np.expand_dims(g_feat, axis=0), features], axis=0)
                    self.features[index] = features

                    image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32)
                    image_location[:, :4] = boxes
                    image_location[:, 4] = (image_location[:, 3] - image_location[:, 1]) * (image_location[:, 2] - image_location[:, 0]) / (float(image_w) * float(image_h))

                    image_location_ori = copy.deepcopy(image_location)

                    image_location[:, 0] = image_location[:, 0] / float(image_w)
                    image_location[:, 1] = image_location[:, 1] / float(image_h)
                    image_location[:, 2] = image_location[:, 2] / float(image_w)
                    image_location[:, 3] = image_location[:, 3] / float(image_h)

                    g_location = np.array([0, 0, 1, 1, 1])
                    image_location = np.concatenate([np.expand_dims(g_location, axis=0), image_location], axis=0)
                    self.boxes[index] = image_location

                    g_location_ori = np.array([0, 0, image_w, image_h, image_w * image_h])
                    image_location_ori = np.concatenate([np.expand_dims(g_location_ori, axis=0), image_location_ori], axis=0)
                    self.boxes_ori[index] = image_location_ori
                    self.num_boxes[index] = num_boxes

                    # load the scene graph data
                    pth = os.path.join(self.scene_graph_path, f'{image_id}.pkl')
                    with open(pth, 'rb') as f:
                        graph_data = pickle.load(f)
                    edge_indexes = []
                    edge_attributes = []
                    for e_idx, e_attr in graph_data:
                        edge_indexes.append(e_idx)
                        # get the one-hot encoding of the edges
                        e_attr_one_hot = np.zeros((12,), dtype=np.float32)  # 12 = 11 rels + hub-node rel
                        e_attr_one_hot[e_attr] = 1.0
                        edge_attributes.append(e_attr_one_hot)
                    edge_indexes = np.array(edge_indexes, dtype=np.float64).transpose(1, 0)
                    edge_attributes = np.stack(edge_attributes, axis=0)

                    self.edge_indexes[index] = edge_indexes
                    self.edge_attributes[index] = edge_attributes

        else:
            # Read the chunk from file every time if it is not loaded in memory.
            with self.env.begin(write=False) as txn:
                item = pickle.loads(txn.get(image_id))
                image_id = item['image_id']
                image_h = int(item['image_h'])
                image_w = int(item['image_w'])
                num_boxes = int(item['num_boxes'])
                cls_prob = np.frombuffer(base64.b64decode(item['cls_prob']), dtype=np.float32).reshape(num_boxes, 1601)
                # add an extra row at the top for the <IMG> tokens
                g_cls_prob = np.zeros(1601, dtype=np.float32)
                g_cls_prob[0] = 1
                cls_prob = np.concatenate([np.expand_dims(g_cls_prob, axis=0), cls_prob], axis=0)

                features = np.frombuffer(base64.b64decode(item["features"]), dtype=np.float32).reshape(num_boxes, 2048)
                boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape(num_boxes, 4)
                g_feat = np.sum(features, axis=0) / num_boxes
                num_boxes = num_boxes + 1
                features = np.concatenate([np.expand_dims(g_feat, axis=0), features], axis=0)

                image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32)
                image_location[:, :4] = boxes
                image_location[:, 4] = (image_location[:, 3] - image_location[:, 1]) * (image_location[:, 2] - image_location[:, 0]) / (float(image_w) * float(image_h))

                image_location_ori = copy.deepcopy(image_location)
                image_location[:, 0] = image_location[:, 0] / float(image_w)
                image_location[:, 1] = image_location[:, 1] / float(image_h)
                image_location[:, 2] = image_location[:, 2] / float(image_w)
                image_location[:, 3] = image_location[:, 3] / float(image_h)

                g_location = np.array([0, 0, 1, 1, 1])
                image_location = np.concatenate([np.expand_dims(g_location, axis=0), image_location], axis=0)

                g_location_ori = np.array([0, 0, image_w, image_h, image_w * image_h])
                image_location_ori = np.concatenate([np.expand_dims(g_location_ori, axis=0), image_location_ori], axis=0)

                # load the scene graph data
                pth = os.path.join(self.scene_graph_path, f'{image_id}.pkl')
                with open(pth, 'rb') as f:
                    graph_data = pickle.load(f)
                edge_indexes = []
                edge_attributes = []
                for e_idx, e_attr in graph_data:
                    edge_indexes.append(e_idx)
                    # get the one-hot encoding of the edges
                    e_attr_one_hot = np.zeros((12,), dtype=np.float32)  # 12 = 11 rels + hub-node rel
                    e_attr_one_hot[e_attr] = 1.0
                    edge_attributes.append(e_attr_one_hot)
                edge_indexes = np.array(edge_indexes, dtype=np.float64).transpose(1, 0)
                edge_attributes = np.stack(edge_attributes, axis=0)

        return features, num_boxes, image_location, image_location_ori, cls_prob, edge_indexes, edge_attributes

    def keys(self) -> List[int]:
        return self._image_ids

    def set_keys(self, new_ids: List[str]):
        self._image_ids = list(map(lambda _id: _id.encode('ascii'), new_ids))
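A hedged usage sketch; the LMDB path, scene-graph directory, and image id below are placeholders, not files shipped with this commit:

    reader = ImageFeaturesH5Reader('data/visdial_img_feat.lmdb',   # hypothetical paths
                                   'data/scene_graphs', in_memory=False)
    feats, n, loc, loc_ori, cls_prob, e_idx, e_attr = reader[185565]  # any id in reader.keys()
    # feats: (n, 2048) with the mean-pooled <IMG> row prepended;
    # loc: (n, 5) normalized boxes plus relative area; e_idx/e_attr: scene-graph edges.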
176
utils/init_utils.py
Normal file
@@ -0,0 +1,176 @@
import os
import os.path as osp
import random
import datetime
import itertools
import glob
import subprocess
import pyhocon
import re
import numpy as np
import glog as log
import json
import torch

import sys
sys.path.append('../')

from models import vdgr
from dataloader.dataloader_visdial import VisdialDataset

from dataloader.dataloader_visdial_dense import VisdialDenseDataset


def load_runner(config):
    if config['train_on_dense']:
        return vdgr.DenseRunner(config)
    else:
        return vdgr.SparseRunner(config)


def load_dataset(config):
    dataset_eval = None

    if config['train_on_dense']:
        dataset = VisdialDenseDataset(config)
        if config['skip_mrr_eval']:
            temp = config['num_options_dense']
            config['num_options_dense'] = config['num_options']
            dataset_eval = VisdialDenseDataset(config)
            config['num_options_dense'] = temp
        else:
            dataset_eval = VisdialDataset(config)
    else:
        dataset = VisdialDataset(config)
        if config['skip_mrr_eval']:
            dataset_eval = VisdialDenseDataset(config)

    if config['use_trainval']:
        dataset.split = 'trainval'
    else:
        dataset.split = 'train'

    if dataset_eval is not None:
        dataset_eval.split = 'val'

    return dataset, dataset_eval

def initialize_from_env(model, mode, eval_dir, model_type, tag=''):
    if "GPU" in os.environ:
        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ['GPU']
    if mode in ['train', 'debug']:
        config = pyhocon.ConfigFactory.parse_file(f"config/{model_type}.conf")[model]
    else:
        path_config = osp.join(eval_dir, 'code', f"config/{model_type}.conf")
        config = pyhocon.ConfigFactory.parse_file(path_config)[model]
        config['log_dir'] = eval_dir
        config['model_config'] = osp.join(eval_dir, 'code/config/bert_base_6layer_6conect.json')
        if config['dp_type'] == 'apex':
            config['dp_type'] = 'ddp'

    if config['dp_type'] == 'dp':
        config['stack_gr_data'] = True

    config['model_type'] = model_type
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        config['num_gpus'] = len(os.environ["CUDA_VISIBLE_DEVICES"].split(','))
    # multi-gpu setting
    if config['num_gpus'] > 1:
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '5678'

    if mode == 'debug':
        model += '_debug'

    if tag:
        model += '-' + tag
    if mode in ['train', 'debug']:
        config['log_dir'] = os.path.join(config["log_dir"], model)
    if not os.path.exists(config["log_dir"]):
        os.makedirs(config["log_dir"])
    config['visdial_output_dir'] = osp.join(config['log_dir'], config['visdial_output_dir'])

    config['timestamp'] = datetime.datetime.now().strftime('%m%d-%H%M%S')

    # add the bert config
    config['bert_config'] = json.load(open(config['model_config'], 'r'))
    if mode in ['predict', 'eval']:
        if (not config['loads_start_path']) and (not config['loads_best_ckpt']):
            config['loads_best_ckpt'] = True
            print('Setting loads_best_ckpt=True under predict or eval mode')
        if config['num_options_dense'] < 100:
            config['num_options_dense'] = 100
            print('Setting num_options_dense=100 under predict or eval mode')
        if config['visdial_version'] == 0.9:
            config['skip_ndcg_eval'] = True

    return config


def set_log_file(fname, file_only=False):
    # If fname already exists, find all log files under the log dir
    # and name the current log file with a new number.
    if osp.exists(fname):
        prefix, suffix = osp.splitext(fname)
        log_files = glob.glob(prefix + '*' + suffix)
        count = 0
        for log_file in log_files:
            num = re.search(r'(\d+)', log_file)
            if num is not None:
                num = int(num.group(0))
                count = max(num, count)
        fname = fname.replace(suffix, str(count + 1) + suffix)
    # Set the log file.
    # Simple tricks for duplicating the logging destination in the logging module, such as
    #   logging.getLogger().addHandler(logging.FileHandler(filename))
    # do NOT work well here, because Python traceback messages (which do not go through the
    # logging module) would not reach the file. The following solution (copied from
    # https://stackoverflow.com/questions/616645) is a little more complicated, but it
    # simulates exactly the "tee" command of a Linux shell and redirects everything.
    if file_only:
        # We only output messages to the file; stdout/stderr receive nothing.
        # This feature is designed for executing the script via ssh:
        # ssh uses a windowing kind of flow control, i.e., if the controller does not read
        # data from an ssh channel and its buffer fills up, the executing machine cannot
        # write anything into the channel and the process is set to sleeping (S) status
        # until someone reads all data from the channel.
        # That is not desired, since we do not want to read stdout/stderr from the controller
        # machine. So we use a simple solution here: disable output to stdout/stderr and only
        # write messages to the log file.
        log.logger.handlers[0].stream = log.handler.stream = sys.stdout = sys.stderr = f = open(fname, 'w', buffering=1)
    else:
        # We output messages to both the file and stdout/stderr.
        tee = subprocess.Popen(['tee', fname], stdin=subprocess.PIPE)
        os.dup2(tee.stdin.fileno(), sys.stdout.fileno())
        os.dup2(tee.stdin.fileno(), sys.stderr.fileno())


def copy_file_to_log(log_dir):
    dirs_to_cp = ['.', 'config', 'dataloader', 'models', 'utils']
    files_to_cp = ['*.py', '*.json', '*.sh', '*.conf']
    for dir_name in dirs_to_cp:
        dir_name = osp.join(log_dir, 'code', dir_name)
        if not osp.exists(dir_name):
            os.makedirs(dir_name)
    for dir_name, file_name in itertools.product(dirs_to_cp, files_to_cp):
        filename = osp.join(dir_name, file_name)
        if len(glob.glob(filename)) > 0:
            os.system(f'cp {filename} {osp.join(log_dir, "code", dir_name)}')
    log.info(f'Files copied to {osp.join(log_dir, "code")}')


def set_random_seed(random_seed):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)


def set_training_steps(config, num_samples):
    if config['parallel'] and config['dp_type'] == 'dp':
        config['num_iter_per_epoch'] = int(np.ceil(num_samples / config['batch_size']))
    else:
        config['num_iter_per_epoch'] = int(np.ceil(num_samples / (config['batch_size'] * config['num_gpus'])))
    if 'train_steps' not in config:
        config['train_steps'] = config['num_iter_per_epoch'] * config['num_epochs']
    if 'warmup_steps' not in config:
        config['warmup_steps'] = int(config['train_steps'] * config['warmup_ratio'])
    return config
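A worked example of the step arithmetic in set_training_steps, with invented config values:

    config = {'parallel': True, 'dp_type': 'ddp', 'batch_size': 8, 'num_gpus': 4,
              'num_epochs': 20, 'warmup_ratio': 0.1}
    config = set_training_steps(config, num_samples=123287)
    # num_iter_per_epoch = ceil(123287 / (8 * 4)) = 3853
    # train_steps = 3853 * 20 = 77060; warmup_steps = int(77060 * 0.1) = 7706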
456
utils/model_utils.py
Normal file
@@ -0,0 +1,456 @@
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


def truncated_normal_(tensor, mean=0, std=1):
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)


def init_params(module, initializer='normal'):

    if isinstance(module, nn.Linear):
        if initializer == 'kaiming_normal':
            nn.init.kaiming_normal_(module.weight.data)
        elif initializer == 'normal':
            nn.init.normal_(module.weight.data, std=0.02)
        elif initializer == 'truncated_normal':
            truncated_normal_(module.weight.data, std=0.02)

        if module.bias is not None:
            nn.init.zeros_(module.bias.data)

    elif isinstance(module, nn.Embedding):
        if initializer == 'kaiming_normal':
            nn.init.kaiming_normal_(module.weight.data)
        elif initializer == 'normal':
            nn.init.normal_(module.weight.data, std=0.02)
        elif initializer == 'truncated_normal':
            truncated_normal_(module.weight.data, std=0.02)

    elif isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv1d):
        nn.init.kaiming_normal_(module.weight, mode='fan_out')

    elif isinstance(module, nn.RNNBase) or isinstance(module, nn.LSTMCell) or isinstance(module, nn.GRUCell):
        for name, param in module.named_parameters():
            if 'weight' in name:
                nn.init.orthogonal_(param.data)
            elif 'bias' in name:
                nn.init.normal_(param.data)

    elif isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d):
        module.weight.data.normal_(1.0, 0.02)


def TensorboardWriter(save_path):
    from torch.utils.tensorboard import SummaryWriter
    return SummaryWriter(save_path, comment="Unmt")


DEFAULT_EPS = 1e-8
PADDED_Y_VALUE = -1

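init_params is written to be used with nn.Module.apply; a minimal sketch with a throwaway model:

    model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 1))
    model.apply(lambda m: init_params(m, initializer='truncated_normal'))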
def listMLE(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
    """
    ListMLE loss introduced in "Listwise Approach to Learning to Rank - Theory and Algorithm".
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: loss value, a torch.Tensor
    """
    # shuffle for randomised tie resolution
    random_indices = torch.randperm(y_pred.shape[-1])
    y_pred_shuffled = y_pred[:, random_indices]
    y_true_shuffled = y_true[:, random_indices]

    y_true_sorted, indices = y_true_shuffled.sort(descending=True, dim=-1)

    mask = y_true_sorted == padded_value_indicator

    preds_sorted_by_true = torch.gather(y_pred_shuffled, dim=1, index=indices)
    preds_sorted_by_true[mask] = float("-inf")

    max_pred_values, _ = preds_sorted_by_true.max(dim=1, keepdim=True)

    preds_sorted_by_true_minus_max = preds_sorted_by_true - max_pred_values

    cumsums = torch.cumsum(preds_sorted_by_true_minus_max.exp().flip(dims=[1]), dim=1).flip(dims=[1])

    observation_loss = torch.log(cumsums + eps) - preds_sorted_by_true_minus_max

    observation_loss[mask] = 0.0

    return torch.mean(torch.sum(observation_loss, dim=1))

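These allRank-style losses share the same [batch_size, slate_length] calling convention; a quick smoke test with invented scores:

    y_pred = torch.tensor([[0.5, 2.0, 1.0]], requires_grad=True)
    y_true = torch.tensor([[1.0, 2.0, 0.0]])  # higher means more relevant; -1 would mark padding
    loss = listMLE(y_pred, y_true)
    loss.backward()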
def approxNDCGLoss(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE, alpha=1.):
    """
    Loss based on approximate NDCG introduced in "A General Approximation Framework for Direct Optimization of
    Information Retrieval Measures". Please note that this method does not implement any kind of truncation.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param alpha: score difference weight used in the sigmoid function
    :return: loss value, a torch.Tensor
    """
    device = y_pred.device
    y_pred = y_pred.clone()
    y_true = y_true.clone()

    padded_mask = y_true == padded_value_indicator
    y_pred[padded_mask] = float("-inf")
    y_true[padded_mask] = float("-inf")

    # Here we sort the true and predicted relevancy scores.
    y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
    y_true_sorted, _ = y_true.sort(descending=True, dim=-1)

    # After sorting, we can mask out the pairs of indices (i, j) containing the index of a padded element.
    true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
    true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
    padded_pairs_mask = torch.isfinite(true_diffs)
    padded_pairs_mask.diagonal(dim1=-2, dim2=-1).zero_()

    # Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs)
    true_sorted_by_preds.clamp_(min=0.)
    y_true_sorted.clamp_(min=0.)

    # Here we find the gains, discounts and ideal DCGs per slate.
    pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
    D = torch.log2(1. + pos_idxs.float())[None, :]
    maxDCGs = torch.sum((torch.pow(2, y_true_sorted) - 1) / D, dim=-1).clamp(min=eps)
    G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]

    # Here we approximate the ranking positions according to Eqs 19-20 and later approximate NDCG (Eq 21)
    scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :])
    scores_diffs[~padded_pairs_mask] = 0.
    approx_pos = 1. + torch.sum(padded_pairs_mask.float() * (torch.sigmoid(-alpha * scores_diffs).clamp(min=eps)),
                                dim=-1)
    approx_D = torch.log2(1. + approx_pos)
    approx_NDCG = torch.sum((G / approx_D), dim=-1)

    return -torch.mean(approx_NDCG)

def listNet(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
    """
    ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach".
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: loss value, a torch.Tensor
    """
    y_pred = y_pred.clone()
    y_true = y_true.clone()

    mask = y_true == padded_value_indicator
    y_pred[mask] = float('-inf')
    y_true[mask] = float('-inf')

    preds_smax = F.softmax(y_pred, dim=1)
    true_smax = F.softmax(y_true, dim=1)

    preds_smax = preds_smax + eps
    preds_log = torch.log(preds_smax)

    return torch.mean(-torch.sum(true_smax * preds_log, dim=1))

def deterministic_neural_sort(s, tau, mask):
    """
    Deterministic neural sort.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param s: values to sort, shape [batch_size, slate_length]
    :param tau: temperature for the final softmax function
    :param mask: mask indicating padded elements
    :return: approximate permutation matrices of shape [batch_size, slate_length, slate_length]
    """
    dev = s.device

    n = s.size()[1]
    one = torch.ones((n, 1), dtype=torch.float32, device=dev)
    s = s.masked_fill(mask[:, :, None], -1e8)
    A_s = torch.abs(s - s.permute(0, 2, 1))
    A_s = A_s.masked_fill(mask[:, :, None] | mask[:, None, :], 0.0)

    B = torch.matmul(A_s, torch.matmul(one, torch.transpose(one, 0, 1)))

    temp = [n - m + 1 - 2 * (torch.arange(n - m, device=dev) + 1) for m in mask.squeeze(-1).sum(dim=1)]
    temp = [t.type(torch.float32) for t in temp]
    temp = [torch.cat((t, torch.zeros(n - len(t), device=dev))) for t in temp]
    scaling = torch.stack(temp).type(torch.float32).to(dev)  # type: ignore

    s = s.masked_fill(mask[:, :, None], 0.0)
    C = torch.matmul(s, scaling.unsqueeze(-2))

    P_max = (C - B).permute(0, 2, 1)
    P_max = P_max.masked_fill(mask[:, :, None] | mask[:, None, :], -np.inf)
    P_max = P_max.masked_fill(mask[:, :, None] & mask[:, None, :], 1.0)
    sm = torch.nn.Softmax(-1)
    P_hat = sm(P_max / tau)
    return P_hat

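A toy call, noting that deterministic_neural_sort expects scores of shape [batch_size, n, 1] and a boolean padding mask of shape [batch_size, n] (values invented):

    s = torch.tensor([[[0.1], [2.0], [1.0]]])
    mask = torch.zeros(1, 3, dtype=torch.bool)  # nothing padded
    P = deterministic_neural_sort(s, tau=0.1, mask=mask)
    # at low temperature, P approaches the hard descending-sort permutation:
    # torch.matmul(P, s) is approximately [[2.0], [1.0], [0.1]]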
def sample_gumbel(samples_shape, device, eps=1e-10) -> torch.Tensor:
    """
    Sampling from the Gumbel distribution.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param samples_shape: shape of the output samples tensor
    :param device: device of the output samples tensor
    :param eps: epsilon for the logarithm function
    :return: Gumbel samples tensor of shape samples_shape
    """
    U = torch.rand(samples_shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)

def apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator=PADDED_Y_VALUE):
    mask = y_true == padding_indicator

    y_pred[mask] = float('-inf')
    y_true[mask] = 0.0

    _, indices = y_pred.sort(descending=True, dim=-1)
    return torch.gather(y_true, dim=1, index=indices)

def dcg(y_pred, y_true, ats=None, gain_function=lambda x: torch.pow(2, x) - 1, padding_indicator=PADDED_Y_VALUE):
    """
    Discounted Cumulative Gain at k.
    Compute DCG at ranks given by ats, or at the maximum rank if ats is None.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param ats: optional list of ranks for DCG evaluation; if None, the maximum rank is used
    :param gain_function: callable, gain function for the ground truth labels, e.g. torch.pow(2, x) - 1
    :param padding_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: DCG values for each slate and evaluation position, shape [batch_size, len(ats)]
    """
    y_true = y_true.clone()
    y_pred = y_pred.clone()

    actual_length = y_true.shape[1]

    if ats is None:
        ats = [actual_length]
    ats = [min(at, actual_length) for at in ats]

    true_sorted_by_preds = apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator)

    discounts = (torch.tensor(1) / torch.log2(torch.arange(true_sorted_by_preds.shape[1], dtype=torch.float) + 2.0)).to(
        device=true_sorted_by_preds.device)

    gains = gain_function(true_sorted_by_preds)

    discounted_gains = (gains * discounts)[:, :np.max(ats)]

    cum_dcg = torch.cumsum(discounted_gains, dim=1)

    ats_tensor = torch.tensor(ats, dtype=torch.long) - torch.tensor(1)

    dcg = cum_dcg[:, ats_tensor]

    return dcg

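A worked dcg example with a tiny slate (values invented):

    y_pred = torch.tensor([[1.0, 3.0, 2.0]])
    y_true = torch.tensor([[0.0, 2.0, 1.0]])
    dcg(y_pred, y_true, ats=[1, 3])
    # ranking by y_pred puts relevances (2, 1, 0) first-to-last, so
    # DCG@1 = 2^2 - 1 = 3.0 and DCG@3 = 3 + 1/log2(3) + 0 ≈ 3.63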
def sinkhorn_scaling(mat, mask=None, tol=1e-6, max_iter=50):
    """
    Sinkhorn scaling procedure.
    :param mat: a tensor of square matrices of shape N x M x M, where N is batch size
    :param mask: a tensor of masks of shape N x M
    :param tol: Sinkhorn scaling tolerance
    :param max_iter: maximum number of iterations of the Sinkhorn scaling
    :return: a tensor of (approximately) doubly stochastic matrices
    """
    if mask is not None:
        mat = mat.masked_fill(mask[:, None, :] | mask[:, :, None], 0.0)
        mat = mat.masked_fill(mask[:, None, :] & mask[:, :, None], 1.0)

    for _ in range(max_iter):
        mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=DEFAULT_EPS)
        mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=DEFAULT_EPS)

        if torch.max(torch.abs(mat.sum(dim=2) - 1.)) < tol and torch.max(torch.abs(mat.sum(dim=1) - 1.)) < tol:
            break

    if mask is not None:
        mat = mat.masked_fill(mask[:, None, :] | mask[:, :, None], 0.0)

    return mat

def stochastic_neural_sort(s, n_samples, tau, mask, beta=1.0, log_scores=True, eps=1e-10):
    """
    Stochastic neural sort. Please note that memory complexity grows by a factor of n_samples.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param s: values to sort, shape [batch_size, slate_length]
    :param n_samples: number of samples (approximations) for each permutation matrix
    :param tau: temperature for the final softmax function
    :param mask: mask indicating padded elements
    :param beta: scale parameter for the Gumbel distribution
    :param log_scores: whether to apply the logarithm function to scores prior to Gumbel perturbation
    :param eps: epsilon for the logarithm function
    :return: approximate permutation matrices of shape [n_samples, batch_size, slate_length, slate_length]
    """
    dev = s.device

    batch_size = s.size()[0]
    n = s.size()[1]
    s_positive = s + torch.abs(s.min())
    samples = beta * sample_gumbel([n_samples, batch_size, n, 1], device=dev)
    if log_scores:
        s_positive = torch.log(s_positive + eps)

    s_perturb = (s_positive + samples).view(n_samples * batch_size, n, 1)
    mask_repeated = mask.repeat_interleave(n_samples, dim=0)

    P_hat = deterministic_neural_sort(s_perturb, tau, mask_repeated)
    P_hat = P_hat.view(n_samples, batch_size, n, n)
    return P_hat

def neuralNDCG(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1., powered_relevancies=True, k=None,
               stochastic=False, n_samples=32, beta=0.1, log_scores=True):
    """
    NeuralNDCG loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply the 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for the NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for the NeuralSort algorithm, used if stochastic == True
    :return: loss value, a torch.Tensor
    """
    dev = y_pred.device

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)
    # Choose the deterministic/stochastic variant
    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    # Perform Sinkhorn scaling to obtain doubly stochastic permutation matrices
    P_hat = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * P_hat.shape[1], P_hat.shape[2], P_hat.shape[3]),
                             mask.repeat_interleave(P_hat.shape[0], dim=0), tol=1e-6, max_iter=50)
    P_hat = P_hat.view(int(P_hat.shape[0] / y_pred.shape[0]), y_pred.shape[0], P_hat.shape[1], P_hat.shape[2])

    # Mask P_hat and apply it to the true labels, i.e. approximately sort them
    P_hat = P_hat.masked_fill(mask[None, :, :, None] | mask[None, :, None, :], 0.)
    y_true_masked = y_true.masked_fill(mask, 0.).unsqueeze(-1).unsqueeze(0)
    if powered_relevancies:
        y_true_masked = torch.pow(2., y_true_masked) - 1.

    ground_truth = torch.matmul(P_hat, y_true_masked).squeeze(-1)
    discounts = (torch.tensor(1.) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
    discounted_gains = ground_truth * discounts

    if powered_relevancies:
        idcg = dcg(y_true, y_true, ats=[k]).permute(1, 0)
    else:
        idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).permute(1, 0)

    discounted_gains = discounted_gains[:, :, :k]
    ndcg = discounted_gains.sum(dim=-1) / (idcg + DEFAULT_EPS)
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask.repeat(ndcg.shape[0], 1), 0.)

    assert (ndcg < 0.).sum() == 0, "every ndcg should be non-negative"
    if idcg_mask.all():
        return torch.tensor(0.)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 because we want to maximize NDCG

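A minimal smoke test of the loss (scores invented; -1 marks padding):

    y_pred = torch.rand(2, 5, requires_grad=True)
    y_true = torch.tensor([[0., 1., 2., 0., -1.],
                           [1., 0., 0., -1., -1.]])
    loss = neuralNDCG(y_pred, y_true, temperature=1.)
    loss.backward()  # the loss is -meanNDCG, so minimizing it maximizes NDCG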
def neuralNDCG_transposed(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1.,
                          powered_relevancies=True, k=None, stochastic=False, n_samples=32, beta=0.1, log_scores=True,
                          max_iter=50, tol=1e-6):
    """
    NeuralNDCG Transposed loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply the 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for the NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for the NeuralSort algorithm, used if stochastic == True
    :param max_iter: maximum iteration count for Sinkhorn scaling
    :param tol: tolerance for Sinkhorn scaling
    :return: loss value, a torch.Tensor
    """
    dev = y_pred.device

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)

    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    # Perform Sinkhorn scaling to obtain doubly stochastic permutation matrices
    P_hat_masked = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * y_pred.shape[0], y_pred.shape[1], y_pred.shape[1]),
                                    mask.repeat_interleave(P_hat.shape[0], dim=0), tol=tol, max_iter=max_iter)
    P_hat_masked = P_hat_masked.view(P_hat.shape[0], y_pred.shape[0], y_pred.shape[1], y_pred.shape[1])
    discounts = (torch.tensor(1) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)

    # This takes care of the @k metric truncation - if something is beyond rank k, it is useless and gets a 0.0 discount
    discounts[k:] = 0.
    discounts = discounts[None, None, :, None]

    # Here the discounts become expected discounts
    discounts = torch.matmul(P_hat_masked.permute(0, 1, 3, 2), discounts).squeeze(-1)
    if powered_relevancies:
        gains = torch.pow(2., y_true) - 1
        discounted_gains = gains.unsqueeze(0) * discounts
        idcg = dcg(y_true, y_true, ats=[k]).squeeze()
    else:
        gains = y_true
        discounted_gains = gains.unsqueeze(0) * discounts
        idcg = dcg(y_true, y_true, ats=[k]).squeeze()

    ndcg = discounted_gains.sum(dim=2) / (idcg + DEFAULT_EPS)
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask, 0.)

    assert (ndcg < 0.).sum() == 0, "every ndcg should be non-negative"
    if idcg_mask.all():
        return torch.tensor(0.)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 because we want to maximize NDCG
41
utils/modules.py
Normal file
@@ -0,0 +1,41 @@
from collections import Counter, defaultdict
import logging
from typing import Union, List, Dict, Any
import torch
from torch import nn


class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class Reshaper(nn.Module):
    def __init__(self, *output_shape):
        super().__init__()

        self.output_shape = output_shape

    def forward(self, input: torch.Tensor):
        return input.view(*self.output_shape)


class Normalizer(nn.Module):
    def __init__(self, target_norm=1.):
        super().__init__()
        self.target_norm = target_norm

    def forward(self, input: torch.Tensor):
        return input * self.target_norm / input.norm(p=2, dim=1, keepdim=True)


class Squeezer(nn.Module):
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        return torch.squeeze(input, dim=self.dim)
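These small wrappers exist so that shape plumbing can live inside nn.Sequential; an illustrative (invented) pipeline:

    head = nn.Sequential(
        nn.Linear(768, 1),
        Squeezer(dim=-1),              # [B, 1] -> [B]
    )
    norm = Normalizer(target_norm=1.)  # rescales each row to unit L2 norm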
389
utils/optim_utils.py
Normal file
@@ -0,0 +1,389 @@
import logging
import math
import numpy as np
import random
import functools
import glog as log

import torch
from torch import nn, optim
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler, ConstantLR
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from pytorch_transformers.optimization import AdamW


class WarmupLinearScheduleNonZero(_LRScheduler):
    """ Linear warmup and then linear decay.
    Linearly increases the learning rate from 0 to max_lr over `warmup_steps` training steps.
    Linearly decays the learning rate to min_lr over the remaining `t_total - warmup_steps` steps.
    """
    def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.min_lr = min_lr
        super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            lr_factor = float(step) / float(max(1, self.warmup_steps))
        else:
            lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))

        return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr for base_lr in self.base_lrs]

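A sketch of the warmup/decay shape with a toy optimizer (numbers invented):

    opt = optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
    sched = WarmupLinearScheduleNonZero(opt, warmup_steps=100, t_total=1000, min_lr=1e-6)
    for step in range(1000):
        opt.step()
        sched.step()
    # lr ramps from min_lr=1e-6 up to 1e-4 over the first 100 steps (hence "NonZero"),
    # then decays linearly back toward min_lr.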
def init_optim(model, config):
    optimizer_grouped_parameters = []

    gnn_params = []

    encoder_params_with_decay = []
    encoder_params_without_decay = []

    exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    for module_name, module in model.named_children():
        for param_name, param in module.named_parameters():
            if param.requires_grad:
                if "gnn" in param_name:
                    gnn_params.append(param)
                elif module_name == 'encoder':
                    if any(ex in param_name for ex in exclude_from_weight_decay):
                        encoder_params_without_decay.append(param)
                    else:
                        encoder_params_with_decay.append(param)

    optimizer_grouped_parameters = [
        {
            'params': gnn_params,
            'weight_decay': config.gnn_weight_decay,
            'lr': config['learning_rate_gnn'] if config.use_diff_lr_gnn else config['learning_rate_bert']
        }
    ]

    optimizer_grouped_parameters.extend(
        [
            {
                'params': encoder_params_without_decay,
                'weight_decay': 0,
                'lr': config['learning_rate_bert']
            },
            {
                'params': encoder_params_with_decay,
                'weight_decay': 0.01,
                'lr': config['learning_rate_bert']
            }
        ]
    )
    optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate_gnn'])
    scheduler = WarmupLinearScheduleNonZero(
        optimizer,
        warmup_steps=config['warmup_steps'],
        t_total=config['train_steps'],
        min_lr=config['min_lr']
    )

    return optimizer, scheduler

def build_torch_optimizer(model, config):
    """Builds the PyTorch optimizer.

    We use the default parameters for Adam that are suggested by
    the original paper https://arxiv.org/pdf/1412.6980.pdf
    These values are also used by other established implementations,
    e.g. https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
    and https://keras.io/optimizers/
    The paper "Attention is all you need"
    (https://arxiv.org/pdf/1706.03762.pdf) used slightly different
    values, notably beta2=0.98; however, beta2=0.999 is still arguably
    the more established value, so we use it here as well.

    Args:
        model: The model to optimize.
        config: The dictionary of options.

    Returns:
        A ``torch.optim.Optimizer`` instance.
    """
    betas = [0.9, 0.999]
    exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    # Split parameters into encoder ('bert') and task-specific groups;
    # biases and LayerNorm parameters are excluded from weight decay.
    params = {'bert': [], 'task': []}
    for module_name, module in model.named_children():
        if module_name == 'encoder':
            param_type = 'bert'
        else:
            param_type = 'task'
        for param_name, param in module.named_parameters():
            if param.requires_grad:
                if any(ex in param_name for ex in exclude_from_weight_decay):
                    params[param_type] += [
                        {
                            "params": [param],
                            "weight_decay": 0
                        }
                    ]
                else:
                    params[param_type] += [
                        {
                            "params": [param],
                            "weight_decay": 0.01
                        }
                    ]
    if config['task_optimizer'] == 'adamw':
        log.info('Using AdamW as task optimizer')
        task_optimizer = AdamWeightDecay(params['task'],
                                         lr=config["learning_rate_task"],
                                         betas=betas,
                                         eps=1e-6)
    elif config['task_optimizer'] == 'adam':
        log.info('Using Adam as task optimizer')
        task_optimizer = optim.Adam(params['task'],
                                    lr=config["learning_rate_task"],
                                    betas=betas,
                                    eps=1e-6)
    else:
        raise ValueError(f"Unknown task optimizer: {config['task_optimizer']}")
    if len(params['bert']) > 0:
        bert_optimizer = AdamWeightDecay(params['bert'],
                                         lr=config["learning_rate_bert"],
                                         betas=betas,
                                         eps=1e-6)
        optimizer = MultipleOptimizer([bert_optimizer, task_optimizer])
    else:
        optimizer = task_optimizer

    return optimizer
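
# Illustrative only: with a model whose ``encoder`` child is a BERT module,
# build_torch_optimizer wraps one optimizer for the encoder and one for the
# task parameters (config keys are the ones used above).
#
#     config = {'task_optimizer': 'adam', 'learning_rate_task': 1e-4,
#               'learning_rate_bert': 2e-5}
#     optimizer = build_torch_optimizer(model, config)
#     isinstance(optimizer, MultipleOptimizer)  # True if encoder params exist
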
def make_learning_rate_decay_fn(decay_method, train_steps, **kwargs):
    """Returns the learning rate decay function from options."""
    if decay_method == "linear":
        return functools.partial(
            linear_decay,
            global_steps=train_steps,
            **kwargs)
    elif decay_method == "exp":
        return functools.partial(
            exp_decay,
            global_steps=train_steps,
            **kwargs)
    else:
        raise ValueError(f'{decay_method} not found')

def linear_decay(step, global_steps, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kwargs):
    if step < warmup_steps:
        return initial_learning_rate * step / warmup_steps
    else:
        return (initial_learning_rate - end_learning_rate) * \
            (1 - (step - warmup_steps) / (global_steps - warmup_steps)) + \
            end_learning_rate

def exp_decay(step, global_steps, decay_exp=1, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kwargs):
    if step < warmup_steps:
        return initial_learning_rate * step / warmup_steps
    else:
        return (initial_learning_rate - end_learning_rate) * \
            ((1 - (step - warmup_steps) / (global_steps - warmup_steps)) ** decay_exp) + \
            end_learning_rate
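
# Worked example of the two schedules (values assumed for illustration):
# with warmup_steps=100, global_steps=1000, initial_learning_rate=1 and
# end_learning_rate=0:
#     linear_decay(50, 1000)             -> 0.5   (warmup: 50/100)
#     linear_decay(100, 1000)            -> 1.0   (warmup just finished)
#     linear_decay(550, 1000)            -> 0.5   (halfway through decay)
#     exp_decay(550, 1000, decay_exp=2)  -> 0.25  (same point, squared)
# These return scaling factors; OptimizerBase below multiplies them by the
# base learning rates.
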
class MultipleOptimizer(object):
    """Implements multiple optimizers stepped together as one
    (e.g. as needed for sparse adam)."""

    def __init__(self, op):
        """Stores the list of wrapped optimizers."""
        self.optimizers = op

    @property
    def param_groups(self):
        param_groups = []
        for optimizer in self.optimizers:
            param_groups.extend(optimizer.param_groups)
        return param_groups

    def zero_grad(self):
        """Zeroes the gradients of all wrapped optimizers."""
        for op in self.optimizers:
            op.zero_grad()

    def step(self):
        """Steps all wrapped optimizers."""
        for op in self.optimizers:
            op.step()

    @property
    def state(self):
        """Merged per-parameter state of all wrapped optimizers."""
        return {k: v for op in self.optimizers for k, v in op.state.items()}

    def state_dict(self):
        """Returns a list with one state dict per wrapped optimizer."""
        return [op.state_dict() for op in self.optimizers]

    def load_state_dict(self, state_dicts):
        """Loads one state dict per wrapped optimizer, in order."""
        assert len(state_dicts) == len(self.optimizers)
        for i in range(len(state_dicts)):
            self.optimizers[i].load_state_dict(state_dicts[i])
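
# Illustrative only (hypothetical parameter lists): two optimizers with
# different learning rates driven in lockstep.
#
#     opt = MultipleOptimizer([torch.optim.Adam(encoder_params, lr=2e-5),
#                              torch.optim.Adam(task_params, lr=1e-4)])
#     opt.zero_grad()
#     loss.backward()
#     opt.step()
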
class OptimizerBase(object):
    """
    Controller class for optimization. Mostly a thin
    wrapper for `optim`, but also useful for implementing
    rate scheduling beyond what is currently available.
    Also implements necessary methods for training RNNs such
    as grad manipulations.
    """

    def __init__(self,
                 optimizer,
                 learning_rate,
                 learning_rate_decay_fn=None,
                 max_grad_norm=None):
        """Initializes the controller.

        Args:
            optimizer: A ``torch.optim.Optimizer`` instance.
            learning_rate: The initial learning rate.
            learning_rate_decay_fn: An optional callable taking the current
                step as argument and returning a learning rate scaling factor.
            max_grad_norm: Clip gradients to this global norm.
        """
        self._optimizer = optimizer
        self._learning_rate = learning_rate
        self._learning_rate_decay_fn = learning_rate_decay_fn
        self._max_grad_norm = max_grad_norm or 0
        self._training_step = 1
        self._decay_step = 1

    @classmethod
    def from_opt(cls, model, config, checkpoint=None):
        """Builds the optimizer from options.

        Args:
            cls: The ``Optimizer`` class to instantiate.
            model: The model to optimize.
            config: The dict of user options.
            checkpoint: An optional checkpoint to load states from.

        Returns:
            An ``Optimizer`` instance.
        """
        optim_opt = config
        optim_state_dict = None

        if config["loads_ckpt"] and checkpoint is not None:
            optim = checkpoint['optim']
            ckpt_opt = checkpoint['opt']
            ckpt_state_dict = {}
            if isinstance(optim, Optimizer):  # Backward compatibility.
                ckpt_state_dict['training_step'] = optim._step + 1
                ckpt_state_dict['decay_step'] = optim._step + 1
                ckpt_state_dict['optimizer'] = optim.optimizer.state_dict()
            else:
                ckpt_state_dict = optim

            if config["reset_optim"] == 'none':
                # Load everything from the checkpoint.
                optim_opt = ckpt_opt
                optim_state_dict = ckpt_state_dict
            elif config["reset_optim"] == 'all':
                # Build everything from scratch.
                pass
            elif config["reset_optim"] == 'states':
                # Reset optimizer states, keep checkpoint options.
                optim_opt = ckpt_opt
                optim_state_dict = ckpt_state_dict
                del optim_state_dict['optimizer']
            elif config["reset_optim"] == 'keep_states':
                # Reset options, keep optimizer states.
                optim_state_dict = ckpt_state_dict

        learning_rates = [
            optim_opt["learning_rate_bert"],
            optim_opt["learning_rate_gnn"]
        ]
        decay_fn = [
            make_learning_rate_decay_fn(optim_opt['decay_method_bert'],
                                        optim_opt['train_steps'],
                                        warmup_steps=optim_opt['warmup_steps'],
                                        decay_exp=optim_opt['decay_exp']),
            make_learning_rate_decay_fn(optim_opt['decay_method_gnn'],
                                        optim_opt['train_steps'],
                                        warmup_steps=optim_opt['warmup_steps'],
                                        decay_exp=optim_opt['decay_exp']),
        ]
        optimizer = cls(
            build_torch_optimizer(model, optim_opt),
            learning_rates,
            learning_rate_decay_fn=decay_fn,
            max_grad_norm=optim_opt["max_grad_norm"])
        if optim_state_dict:
            optimizer.load_state_dict(optim_state_dict)
        return optimizer

    @property
    def training_step(self):
        """The current training step."""
        return self._training_step

    def learning_rate(self):
        """Returns the current learning rate."""
        if self._learning_rate_decay_fn is None:
            return self._learning_rate
        return [decay_fn(self._decay_step) * learning_rate
                for decay_fn, learning_rate in
                zip(self._learning_rate_decay_fn, self._learning_rate)]

    def state_dict(self):
        return {
            'training_step': self._training_step,
            'decay_step': self._decay_step,
            'optimizer': self._optimizer.state_dict()
        }

    def load_state_dict(self, state_dict):
        self._training_step = state_dict['training_step']
        # State can be partially restored.
        if 'decay_step' in state_dict:
            self._decay_step = state_dict['decay_step']
        if 'optimizer' in state_dict:
            self._optimizer.load_state_dict(state_dict['optimizer'])

    def zero_grad(self):
        """Zero the gradients of optimized parameters."""
        self._optimizer.zero_grad()

    def backward(self, loss):
        """Wrapper for the backward pass. Some optimizers require ownership
        of the backward pass."""
        loss.backward()

    def step(self):
        """Update the model parameters based on current gradients.

        Optionally, will employ gradient modification or update learning
        rate.
        """
        learning_rate = self.learning_rate()

        if isinstance(self._optimizer, MultipleOptimizer):
            optimizers = self._optimizer.optimizers
        else:
            optimizers = [self._optimizer]
        for lr, op in zip(learning_rate, optimizers):
            for group in op.param_groups:
                group['lr'] = lr
                if self._max_grad_norm > 0:
                    clip_grad_norm_(group['params'], self._max_grad_norm)
        self._optimizer.step()
        self._decay_step += 1
        self._training_step += 1
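
# Sketch of the intended training loop (loader/loss names assumed), mirroring
# how the controller is assembled by ``from_opt`` above:
#
#     optim = OptimizerBase.from_opt(model, config)
#     for batch in loader:
#         optim.zero_grad()
#         loss = compute_loss(model, batch)
#         optim.backward(loss)
#         optim.step()  # sets per-group lrs, clips grads, advances counters
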

322
utils/visdial_metrics.py
Normal file

@@ -0,0 +1,322 @@
"""
|
||||
A Metric observes output of certain model, for example, in form of logits or
|
||||
scores, and accumulates a particular metric with reference to some provided
|
||||
targets. In context of VisDial, we use Recall (@ 1, 5, 10), Mean Rank, Mean
|
||||
Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG).
|
||||
|
||||
Each ``Metric`` must atleast implement three methods:
|
||||
- ``observe``, update accumulated metric with currently observed outputs
|
||||
and targets.
|
||||
- ``retrieve`` to return the accumulated metric., an optionally reset
|
||||
internally accumulated metric (this is commonly done between two epochs
|
||||
after validation).
|
||||
- ``reset`` to explicitly reset the internally accumulated metric.
|
||||
|
||||
Caveat, if you wish to implement your own class of Metric, make sure you call
|
||||
``detach`` on output tensors (like logits), else it will cause memory leaks.
|
||||
"""
import torch
import torch.distributed as dist
import numpy as np

def scores_to_ranks(scores: torch.Tensor):
    """Convert model output scores into ranks."""
    batch_size, num_rounds, num_options = scores.size()
    scores = scores.view(-1, num_options)

    # sort in descending order - largest score gets highest rank
    sorted_ranks, ranked_idx = scores.sort(1, descending=True)

    # The i-th position in ranked_idx specifies which score shall take this
    # position, but we want the i-th position to hold the rank of the score
    # at that position; do this conversion.
    ranks = ranked_idx.clone().fill_(0)
    for i in range(ranked_idx.size(0)):
        for j in range(num_options):
            ranks[i][ranked_idx[i][j]] = j
    # convert from 0-99 ranks to 1-100 ranks
    ranks += 1
    ranks = ranks.view(batch_size, num_rounds, num_options)
    return ranks
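
# A vectorized equivalent of the double loop above (an optional optimization,
# not part of the original code): ``ranks`` is just the inverse permutation
# of ``ranked_idx``, which argsort computes directly.
#
#     ranks = ranked_idx.argsort(dim=1) + 1  # same 1-based ranks, no loops
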
class SparseGTMetrics(object):
    """
    A class to accumulate all metrics with sparse ground truth annotations.
    These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank.
    """

    def __init__(self):
        self._rank_list = []
        self._rank_list_rnd = []
        self.num_rounds = None

    def observe(
        self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor
    ):
        predicted_scores = predicted_scores.detach()

        # shape: (batch_size, num_rounds, num_options)
        predicted_ranks = scores_to_ranks(predicted_scores)
        batch_size, num_rounds, num_options = predicted_ranks.size()
        self.num_rounds = num_rounds
        # collapse batch dimension
        predicted_ranks = predicted_ranks.view(
            batch_size * num_rounds, num_options
        )

        # shape: (batch_size * num_rounds, )
        target_ranks = target_ranks.view(batch_size * num_rounds).long()

        # shape: (batch_size * num_rounds, )
        predicted_gt_ranks = predicted_ranks[
            torch.arange(batch_size * num_rounds), target_ranks
        ]
        self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy()))

        # keep per-round ground-truth ranks as well
        predicted_gt_ranks_rnd = predicted_gt_ranks.view(batch_size, num_rounds)
        self._rank_list_rnd.append(predicted_gt_ranks_rnd.cpu().numpy())

    def retrieve(self, reset: bool = True):
        num_examples = len(self._rank_list)
        if num_examples > 0:
            # convert to a float tensor for easy calculation.
            __rank_list = torch.tensor(self._rank_list).float()
            metrics = {
                "r@1": torch.mean((__rank_list <= 1).float()).item(),
                "r@5": torch.mean((__rank_list <= 5).float()).item(),
                "r@10": torch.mean((__rank_list <= 10).float()).item(),
                "mean": torch.mean(__rank_list).item(),
                "mrr": torch.mean(__rank_list.reciprocal()).item()
            }
            # add per-round metrics
            _rank_list_rnd = np.concatenate(self._rank_list_rnd)
            _rank_list_rnd = _rank_list_rnd.astype(float)
            r_1_rnd = np.mean(_rank_list_rnd <= 1, axis=0)
            r_5_rnd = np.mean(_rank_list_rnd <= 5, axis=0)
            r_10_rnd = np.mean(_rank_list_rnd <= 10, axis=0)
            mean_rnd = np.mean(_rank_list_rnd, axis=0)
            mrr_rnd = np.mean(np.reciprocal(_rank_list_rnd), axis=0)

            for rnd in range(1, self.num_rounds + 1):
                metrics["r_1" + "_round_" + str(rnd)] = r_1_rnd[rnd - 1]
                metrics["r_5" + "_round_" + str(rnd)] = r_5_rnd[rnd - 1]
                metrics["r_10" + "_round_" + str(rnd)] = r_10_rnd[rnd - 1]
                metrics["mean" + "_round_" + str(rnd)] = mean_rnd[rnd - 1]
                metrics["mrr" + "_round_" + str(rnd)] = mrr_rnd[rnd - 1]
        else:
            metrics = {}

        if reset:
            self.reset()
        return metrics

    def reset(self):
        self._rank_list = []
        self._rank_list_rnd = []
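
# Illustrative usage (not from the original code): scores for a batch of
# 2 dialogs x 10 rounds x 100 answer options, with random ground-truth
# answer indices.
#
#     metric = SparseGTMetrics()
#     scores = torch.randn(2, 10, 100)
#     gt = torch.randint(0, 100, (2, 10))
#     metric.observe(scores, gt)
#     print(metric.retrieve())  # r@1/r@5/r@10, mean, mrr, plus per-round keys
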
class NDCG(object):
    def __init__(self):
        self._ndcg_numerator = 0.0
        self._ndcg_denominator = 0.0

    def observe(
        self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor
    ):
        """
        Observe model output scores and target ground truth relevance and
        accumulate NDCG metric.

        Parameters
        ----------
        predicted_scores: torch.Tensor
            A tensor of shape (batch_size, num_options), because dense
            annotations are available for 1 randomly picked round out of 10.
        target_relevance: torch.Tensor
            A tensor of the same shape as predicted_scores, indicating the
            ground truth relevance of each answer option for a particular
            round.
        """
        predicted_scores = predicted_scores.detach()

        # shape: (batch_size, 1, num_options)
        predicted_scores = predicted_scores.unsqueeze(1)
        predicted_ranks = scores_to_ranks(predicted_scores)

        # shape: (batch_size, num_options)
        predicted_ranks = predicted_ranks.squeeze(1)
        batch_size, num_options = predicted_ranks.size()

        # number of options with non-zero relevance per instance
        k = torch.sum(target_relevance != 0, dim=-1)

        # shape: (batch_size, num_options)
        _, rankings = torch.sort(predicted_ranks, dim=-1)
        # Sort relevance in descending order so highest relevance gets top rank.
        _, best_rankings = torch.sort(
            target_relevance, dim=-1, descending=True
        )

        # shape: (batch_size, )
        batch_ndcg = []
        for batch_index in range(batch_size):
            num_relevant = k[batch_index]
            dcg = self._dcg(
                rankings[batch_index][:num_relevant],
                target_relevance[batch_index],
            )
            best_dcg = self._dcg(
                best_rankings[batch_index][:num_relevant],
                target_relevance[batch_index],
            )
            batch_ndcg.append(dcg / best_dcg)

        self._ndcg_denominator += batch_size
        self._ndcg_numerator += sum(batch_ndcg)

    def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor):
        # DCG@k = sum_i rel_i / log2(i + 1), with positions i starting at 1.
        sorted_relevance = relevance[rankings].cpu().float()
        discounts = torch.log2(torch.arange(len(rankings)).float() + 2)
        return torch.sum(sorted_relevance / discounts, dim=-1)

    def retrieve(self, reset: bool = True):
        if self._ndcg_denominator > 0:
            metrics = {
                "ndcg": float(self._ndcg_numerator / self._ndcg_denominator)
            }
        else:
            metrics = {}

        if reset:
            self.reset()
        return metrics

    def reset(self):
        self._ndcg_numerator = 0.0
        self._ndcg_denominator = 0.0
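
# Worked example (illustrative): one instance with two relevant options,
# relevance = [0, 1, 0, 1], so k = 2. If the model ranks options 1 and 3 in
# the top two positions, DCG = 1/log2(2) + 1/log2(3) ~= 1.631, which equals
# the ideal DCG here, so NDCG = 1.0.
#
#     ndcg = NDCG()
#     scores = torch.tensor([[0.1, 0.9, 0.2, 0.8]])
#     relevance = torch.tensor([[0., 1., 0., 1.]])
#     ndcg.observe(scores, relevance)
#     print(ndcg.retrieve())  # {'ndcg': 1.0}
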
class SparseGTMetricsParallel(object):
    """
    A class to accumulate all metrics with sparse ground truth annotations.
    These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank.
    """

    def __init__(self, gpu_rank):
        self.rank_1 = 0
        self.rank_5 = 0
        self.rank_10 = 0
        self.ranks = 0
        self.reciprocal = 0
        self.count = 0
        self.gpu_rank = gpu_rank
        self.img_ids = []

    def observe(
        self, img_id: list, predicted_scores: torch.Tensor, target_ranks: torch.Tensor
    ):
        # skip images this process has already scored
        if img_id in self.img_ids:
            return
        else:
            self.img_ids.append(img_id)

        predicted_scores = predicted_scores.detach()

        # shape: (batch_size, num_rounds, num_options)
        predicted_ranks = scores_to_ranks(predicted_scores)
        batch_size, num_rounds, num_options = predicted_ranks.size()
        self.num_rounds = num_rounds
        # collapse batch dimension
        predicted_ranks = predicted_ranks.view(
            batch_size * num_rounds, num_options
        )

        # shape: (batch_size * num_rounds, )
        target_ranks = target_ranks.view(batch_size * num_rounds).long()

        # shape: (batch_size * num_rounds, )
        predicted_gt_ranks = predicted_ranks[
            torch.arange(batch_size * num_rounds), target_ranks
        ]

        self.rank_1 += (predicted_gt_ranks <= 1).sum().item()
        self.rank_5 += (predicted_gt_ranks <= 5).sum().item()
        self.rank_10 += (predicted_gt_ranks <= 10).sum().item()
        self.ranks += predicted_gt_ranks.sum().item()
        self.reciprocal += predicted_gt_ranks.float().reciprocal().sum().item()
        self.count += batch_size * num_rounds

    def retrieve(self):
        if self.count > 0:
            # Gather results from all GPUs: each process holds partial sums,
            # and all_reduce makes every process see the global totals.
            t = torch.tensor(
                [self.rank_1, self.rank_5, self.rank_10, self.ranks,
                 self.reciprocal, self.count],
                dtype=torch.float32, device=f'cuda:{self.gpu_rank}')
            dist.barrier()  # synchronizes all processes
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
            t = t.tolist()
            self.rank_1, self.rank_5, self.rank_10, self.ranks, self.reciprocal, self.count = t

            metrics = {
                "r@1": self.rank_1 / self.count,
                "r@5": self.rank_5 / self.count,
                "r@10": self.rank_10 / self.count,
                "mean": self.ranks / self.count,
                "mrr": self.reciprocal / self.count,
                "tot_rnds": self.count,
            }
        else:
            metrics = {}

        return metrics

    def get_count(self):
        return int(self.count)
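
# Illustrative distributed usage (assumes torch.distributed has been
# initialized, e.g. with init_process_group('nccl')). retrieve() must be
# called on every rank, since barrier/all_reduce are collective operations.
#
#     metric = SparseGTMetricsParallel(gpu_rank=local_rank)
#     metric.observe(img_ids, scores, gt_indices)  # per batch, on each rank
#     metrics = metric.retrieve()                  # same totals on all ranks
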
class NDCGParallel(NDCG):
    def __init__(self, gpu_rank):
        super(NDCGParallel, self).__init__()
        self.gpu_rank = gpu_rank
        self.img_ids = []
        self.count = 0

    def observe(
        self, img_id: int, predicted_scores: torch.Tensor, target_relevance: torch.Tensor
    ):
        """
        Observe model output scores and target ground truth relevance and
        accumulate NDCG metric.

        Parameters
        ----------
        predicted_scores: torch.Tensor
            A tensor of shape (batch_size, num_options), because dense
            annotations are available for 1 randomly picked round out of 10.
        target_relevance: torch.Tensor
            A tensor of the same shape as predicted_scores, indicating the
            ground truth relevance of each answer option for a particular
            round.
        """
        # skip images this process has already scored
        if img_id in self.img_ids:
            return
        else:
            self.img_ids.append(img_id)
            self.count += 1

        super(NDCGParallel, self).observe(predicted_scores, target_relevance)

    def retrieve(self):
        if self._ndcg_denominator > 0:
            # Gather partial sums from all GPUs so every process sees the
            # global numerator, denominator and count.
            t = torch.tensor(
                [self._ndcg_numerator, self._ndcg_denominator, self.count],
                dtype=torch.float32, device=f'cuda:{self.gpu_rank}')
            dist.barrier()  # synchronizes all processes
            dist.all_reduce(t, op=dist.ReduceOp.SUM)
            t = t.tolist()
            self._ndcg_numerator, self._ndcg_denominator, self.count = t
            metrics = {
                "ndcg": float(self._ndcg_numerator / self._ndcg_denominator)
            }
        else:
            metrics = {}
        return metrics

    def get_count(self):
        return int(self.count)