323 lines
11 KiB
Python
323 lines
11 KiB
Python
"""
|
|
PyTorch DNC implementation from
|
|
-->
|
|
https://github.com/ixaxaar/pytorch-dnc
|
|
<--
|
|
"""
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
import torch.nn as nn
|
|
import torch as T
|
|
from torch.autograd import Variable as var
|
|
import numpy as np
|
|
|
|
from torch.nn.utils.rnn import pad_packed_sequence as pad
|
|
from torch.nn.utils.rnn import pack_padded_sequence as pack
|
|
from torch.nn.utils.rnn import PackedSequence
|
|
|
|
from .util import *
|
|
from .memory import *
|
|
|
|
from torch.nn.init import orthogonal_, xavier_uniform_
|
|
|
|
|
|
class DNC(nn.Module):
|
|
|
|
def __init__(
|
|
self,
|
|
input_size,
|
|
hidden_size,
|
|
rnn_type='lstm',
|
|
num_layers=1,
|
|
num_hidden_layers=2,
|
|
bias=True,
|
|
batch_first=True,
|
|
dropout=0,
|
|
bidirectional=False,
|
|
nr_cells=5,
|
|
read_heads=2,
|
|
cell_size=10,
|
|
nonlinearity='tanh',
|
|
gpu_id=-1,
|
|
independent_linears=False,
|
|
share_memory=True,
|
|
debug=False,
|
|
clip=20
|
|
):
|
|
super(DNC, self).__init__()
|
|
# todo: separate weights and RNNs for the interface and output vectors
|
|
|
|
self.input_size = input_size
|
|
self.hidden_size = hidden_size
|
|
self.rnn_type = rnn_type
|
|
self.num_layers = num_layers
|
|
self.num_hidden_layers = num_hidden_layers
|
|
self.bias = bias
|
|
self.batch_first = batch_first
|
|
self.dropout = dropout
|
|
self.bidirectional = bidirectional
|
|
self.nr_cells = nr_cells
|
|
self.read_heads = read_heads
|
|
self.cell_size = cell_size
|
|
self.nonlinearity = nonlinearity
|
|
self.gpu_id = gpu_id
|
|
self.independent_linears = independent_linears
|
|
self.share_memory = share_memory
|
|
self.debug = debug
|
|
self.clip = clip
|
|
|
|
self.w = self.cell_size
|
|
self.r = self.read_heads
|
|
|
|
self.read_vectors_size = self.r * self.w
|
|
self.output_size = self.hidden_size
|
|
|
|
self.nn_input_size = self.input_size + self.read_vectors_size
|
|
self.nn_output_size = self.output_size + self.read_vectors_size
|
|
|
|
self.rnns = []
|
|
self.memories = []
|
|
|
|
for layer in range(self.num_layers):
|
|
if self.rnn_type.lower() == 'rnn':
|
|
self.rnns.append(nn.RNN((self.nn_input_size if layer == 0 else self.nn_output_size), self.output_size,
|
|
bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
|
|
elif self.rnn_type.lower() == 'gru':
|
|
self.rnns.append(nn.GRU((self.nn_input_size if layer == 0 else self.nn_output_size),
|
|
self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
|
|
if self.rnn_type.lower() == 'lstm':
|
|
self.rnns.append(nn.LSTM((self.nn_input_size if layer == 0 else self.nn_output_size),
|
|
self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
|
|
setattr(self, self.rnn_type.lower() + '_layer_' + str(layer), self.rnns[layer])
|
|
|
|
# memories for each layer
|
|
if not self.share_memory:
|
|
self.memories.append(
|
|
Memory(
|
|
input_size=self.output_size,
|
|
mem_size=self.nr_cells,
|
|
cell_size=self.w,
|
|
read_heads=self.r,
|
|
gpu_id=self.gpu_id,
|
|
independent_linears=self.independent_linears
|
|
)
|
|
)
|
|
setattr(self, 'rnn_layer_memory_' + str(layer), self.memories[layer])
|
|
|
|
# only one memory shared by all layers
|
|
if self.share_memory:
|
|
self.memories.append(
|
|
Memory(
|
|
input_size=self.output_size,
|
|
mem_size=self.nr_cells,
|
|
cell_size=self.w,
|
|
read_heads=self.r,
|
|
gpu_id=self.gpu_id,
|
|
independent_linears=self.independent_linears
|
|
)
|
|
)
|
|
setattr(self, 'rnn_layer_memory_shared', self.memories[0])
|
|
|
|
# final output layer
|
|
self.output = nn.Linear(self.nn_output_size, self.output_size)
|
|
orthogonal_(self.output.weight)
|
|
|
|
if self.gpu_id != -1:
|
|
[x.cuda(self.gpu_id) for x in self.rnns]
|
|
[x.cuda(self.gpu_id) for x in self.memories]
|
|
self.output.cuda()
|
|
|
|
def _init_hidden(self, hx, batch_size, reset_experience):
|
|
# create empty hidden states if not provided
|
|
if hx is None:
|
|
hx = (None, None, None)
|
|
(chx, mhx, last_read) = hx
|
|
|
|
# initialize hidden state of the controller RNN
|
|
if chx is None:
|
|
h = cuda(T.zeros(self.num_hidden_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
|
|
xavier_uniform_(h)
|
|
|
|
chx = [ (h, h) if self.rnn_type.lower() == 'lstm' else h for x in range(self.num_layers)]
|
|
|
|
# Last read vectors
|
|
if last_read is None:
|
|
last_read = cuda(T.zeros(batch_size, self.w * self.r), gpu_id=self.gpu_id)
|
|
|
|
# memory states
|
|
if mhx is None:
|
|
if self.share_memory:
|
|
mhx = self.memories[0].reset(batch_size, erase=reset_experience)
|
|
else:
|
|
mhx = [m.reset(batch_size, erase=reset_experience) for m in self.memories]
|
|
else:
|
|
if self.share_memory:
|
|
mhx = self.memories[0].reset(batch_size, mhx, erase=reset_experience)
|
|
else:
|
|
mhx = [m.reset(batch_size, h, erase=reset_experience) for m, h in zip(self.memories, mhx)]
|
|
|
|
return chx, mhx, last_read
|
|
|
|
def _debug(self, mhx, debug_obj):
|
|
if not debug_obj:
|
|
debug_obj = {
|
|
'memory': [],
|
|
'link_matrix': [],
|
|
'precedence': [],
|
|
'read_weights': [],
|
|
'write_weights': [],
|
|
'usage_vector': [],
|
|
}
|
|
|
|
debug_obj['memory'].append(mhx['memory'][0].data.cpu().numpy())
|
|
debug_obj['link_matrix'].append(mhx['link_matrix'][0][0].data.cpu().numpy())
|
|
debug_obj['precedence'].append(mhx['precedence'][0].data.cpu().numpy())
|
|
debug_obj['read_weights'].append(mhx['read_weights'][0].data.cpu().numpy())
|
|
debug_obj['write_weights'].append(mhx['write_weights'][0].data.cpu().numpy())
|
|
debug_obj['usage_vector'].append(mhx['usage_vector'][0].unsqueeze(0).data.cpu().numpy())
|
|
return debug_obj
|
|
|
|
def _layer_forward(self, input, layer, hx=(None, None), pass_through_memory=True):
|
|
(chx, mhx) = hx
|
|
|
|
# pass through the controller layer
|
|
input, chx = self.rnns[layer](input.unsqueeze(1), chx)
|
|
input = input.squeeze(1)
|
|
|
|
# clip the controller output
|
|
if self.clip != 0:
|
|
output = T.clamp(input, -self.clip, self.clip)
|
|
else:
|
|
output = input
|
|
|
|
# the interface vector
|
|
ξ = output
|
|
|
|
# pass through memory
|
|
if pass_through_memory:
|
|
if self.share_memory:
|
|
read_vecs, mhx = self.memories[0](ξ, mhx)
|
|
else:
|
|
read_vecs, mhx = self.memories[layer](ξ, mhx)
|
|
# the read vectors
|
|
read_vectors = read_vecs.view(-1, self.w * self.r)
|
|
else:
|
|
read_vectors = None
|
|
|
|
return output, (chx, mhx, read_vectors)
|
|
|
|
def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
|
|
# handle packed data
|
|
is_packed = type(input) is PackedSequence
|
|
if is_packed:
|
|
input, lengths = pad(input)
|
|
max_length = lengths[0]
|
|
else:
|
|
max_length = input.size(1) if self.batch_first else input.size(0)
|
|
lengths = [input.size(1)] * max_length if self.batch_first else [input.size(0)] * max_length
|
|
|
|
batch_size = input.size(0) if self.batch_first else input.size(1)
|
|
|
|
if not self.batch_first:
|
|
input = input.transpose(0, 1)
|
|
# make the data time-first
|
|
|
|
controller_hidden, mem_hidden, last_read = self._init_hidden(hx, batch_size, reset_experience)
|
|
|
|
# concat input with last read (or padding) vectors
|
|
inputs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]
|
|
|
|
# batched forward pass per element / word / etc
|
|
if self.debug:
|
|
viz = None
|
|
|
|
outs = [None] * max_length
|
|
read_vectors = None
|
|
rv = [None] * max_length
|
|
# pass through time
|
|
for time in range(max_length):
|
|
# pass thorugh layers
|
|
for layer in range(self.num_layers):
|
|
# this layer's hidden states
|
|
chx = controller_hidden[layer]
|
|
m = mem_hidden if self.share_memory else mem_hidden[layer]
|
|
# pass through controller
|
|
outs[time], (chx, m, read_vectors) = \
|
|
self._layer_forward(inputs[time], layer, (chx, m), pass_through_memory)
|
|
|
|
# debug memory
|
|
if self.debug:
|
|
viz = self._debug(m, viz)
|
|
|
|
# store the memory back (per layer or shared)
|
|
if self.share_memory:
|
|
mem_hidden = m
|
|
else:
|
|
mem_hidden[layer] = m
|
|
controller_hidden[layer] = chx
|
|
|
|
if read_vectors is not None:
|
|
# the controller output + read vectors go into next layer
|
|
outs[time] = T.cat([outs[time], read_vectors], 1)
|
|
if layer == self.num_layers - 1:
|
|
rv[time] = read_vectors.reshape(batch_size, self.r, self.w)
|
|
else:
|
|
outs[time] = T.cat([outs[time], last_read], 1)
|
|
inputs[time] = outs[time]
|
|
|
|
if self.debug:
|
|
viz = {k: np.array(v) for k, v in viz.items()}
|
|
viz = {k: v.reshape(v.shape[0], v.shape[1] * v.shape[2]) for k, v in viz.items()}
|
|
|
|
# pass through final output layer
|
|
inputs = [self.output(i) for i in inputs]
|
|
outputs = T.stack(inputs, 1 if self.batch_first else 0)
|
|
|
|
if is_packed:
|
|
outputs = pack(output, lengths)
|
|
|
|
if self.debug:
|
|
return outputs, (controller_hidden, mem_hidden, read_vectors), rv, viz
|
|
else:
|
|
return outputs, (controller_hidden, mem_hidden, read_vectors), rv
|
|
|
|
def __repr__(self):
|
|
s = "\n----------------------------------------\n"
|
|
s += '{name}({input_size}, {hidden_size}'
|
|
if self.rnn_type != 'lstm':
|
|
s += ', rnn_type={rnn_type}'
|
|
if self.num_layers != 1:
|
|
s += ', num_layers={num_layers}'
|
|
if self.num_hidden_layers != 2:
|
|
s += ', num_hidden_layers={num_hidden_layers}'
|
|
if self.bias != True:
|
|
s += ', bias={bias}'
|
|
if self.batch_first != True:
|
|
s += ', batch_first={batch_first}'
|
|
if self.dropout != 0:
|
|
s += ', dropout={dropout}'
|
|
if self.bidirectional != False:
|
|
s += ', bidirectional={bidirectional}'
|
|
if self.nr_cells != 5:
|
|
s += ', nr_cells={nr_cells}'
|
|
if self.read_heads != 2:
|
|
s += ', read_heads={read_heads}'
|
|
if self.cell_size != 10:
|
|
s += ', cell_size={cell_size}'
|
|
if self.nonlinearity != 'tanh':
|
|
s += ', nonlinearity={nonlinearity}'
|
|
if self.gpu_id != -1:
|
|
s += ', gpu_id={gpu_id}'
|
|
if self.independent_linears != False:
|
|
s += ', independent_linears={independent_linears}'
|
|
if self.share_memory != True:
|
|
s += ', share_memory={share_memory}'
|
|
if self.debug != False:
|
|
s += ', debug={debug}'
|
|
if self.clip != 20:
|
|
s += ', clip={clip}'
|
|
|
|
s += ")\n" + super(DNC, self).__repr__() + \
|
|
"\n----------------------------------------\n"
|
|
return s.format(name=self.__class__.__name__, **self.__dict__)
|