from collections import deque

import torch
import torch.nn as nn
import torch.optim as optim
from torchmetrics import Accuracy
import numpy as np


BATCH_SIZE = 32
TRAIN_HIST_SIZE = 10000
P_DROPOUT = 0.5


class MetaNN(nn.Module):

    def __init__(self, device, input_shape=(4, 84, 84), n_goals=4, hidden_nodes=512):
        super(MetaNN, self).__init__()

        self.device = device

        ### Setup model architecture ###
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Dropout(p=P_DROPOUT),

            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Dropout(p=P_DROPOUT),

            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Dropout(p=P_DROPOUT),
        )

        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, hidden_nodes),  # TODO: copy initialization from meta_net_il.py
            nn.ReLU(),
            nn.Dropout(p=P_DROPOUT),
            nn.Linear(hidden_nodes, n_goals),
            nn.Softmax(dim=1)
        )

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        x = x.to(self.device)
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.fc(conv_out)
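
# Shape sketch (illustration only): with the default input_shape=(4, 84, 84) the
# three conv layers yield 20x20, 9x9 and 7x7 feature maps, so conv_out_size is
# 64 * 7 * 7 = 3136 and the network maps a batch of stacked frames to a softmax
# over goals, e.g.:
#
#     net = MetaNN(torch.device("cpu"))
#     out = net(torch.rand(8, 4, 84, 84))  # -> shape (8, 4), each row sums to 1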


class MetaController():

    def __init__(self, device, args, input_shape=(4, 84, 84), n_goals=4, hidden_nodes=512) -> None:
        self.device = device
        self.model = MetaNN(device, input_shape, n_goals, hidden_nodes).to(self.device)
        print('Saving init state of MetaNN')
        torch.save(self.model.state_dict(), f"runs/meta_init-{args.model_name}.pt")

        self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.00025, alpha=0.95, eps=1e-08, weight_decay=0.0)
        self.loss_fn = nn.MSELoss()
        self.accuracy_fn = Accuracy(num_classes=n_goals, average='macro').to(device)

        self.replay_hist = deque([None], maxlen=TRAIN_HIST_SIZE)

        self.ind = 0
        self.count = 0
        self.meta_steps = 0

        self.input_shape = input_shape
        self.n_goals = n_goals
        self.verbose = args.verbose
        self.name = args.model_name

    def reset(self) -> None:
        'Load initial state dictionary from file and reset optimizer'
        self.model.load_state_dict(torch.load(f"runs/meta_init-{self.name}.pt"))
        self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.00025, alpha=0.95, eps=1e-08, weight_decay=0.0)

    def check_training_clock(self) -> bool:
        'Only train every <BATCH_SIZE> meta controller steps, i.e. <BATCH_SIZE> new samples in replay buffer.'
        if BATCH_SIZE:
            return (self.meta_steps % BATCH_SIZE == 0)
        else:
            return (self.meta_steps % 20 == 0)

    def collect(self, processed, expert_a) -> None:
        'Collect sample consisting of state (4, 84, 84) and one-hot vector of goal'
        if processed is not None:
            self.replay_hist.appendleft(tuple([processed, expert_a]))
            self.meta_steps += 1

    def train(self):
        # if not reached TRAIN_HIST_SIZE yet, then get the number of samples
        num_samples = min(self.meta_steps, TRAIN_HIST_SIZE)

        inputs = torch.stack([self.replay_hist[i][0] for i in range(num_samples)], 0).to(self.device)
        labels = torch.stack([torch.tensor(self.replay_hist[i][1], dtype=torch.float32) for i in range(num_samples)], 0).to(self.device)

        if self.verbose >= 2.0:
            print(f'\nMetacontroller collected {self.meta_steps} samples')
            if len(labels) == TRAIN_HIST_SIZE:
                print(f'Reached TRAIN_HIST_SIZE = {TRAIN_HIST_SIZE}')
            print('Dataset Distribution:')
            # labels are one-hot goal vectors, so summing them gives per-goal sample counts
            for goal in range(self.n_goals):
                print(f'\nNumber of samples for goal {goal}: {sum(labels).squeeze()[goal]}')
                print(f'--> {sum(labels).squeeze()[goal] / len(labels):.2%}')
            print()

        if BATCH_SIZE and num_samples >= BATCH_SIZE:
            # train one epoch --> noisy convergence more likely to find broader minimum
            accumulated_loss = []
            accumulated_acc = []

            for index in range(0, len(labels) // BATCH_SIZE):
                b_inputs, b_labels = inputs[index * BATCH_SIZE: (index + 1) * BATCH_SIZE], labels[index * BATCH_SIZE: (index + 1) * BATCH_SIZE]

                # zero the parameter gradients
                self.optimizer.zero_grad()

                outputs = self.model(b_inputs)
                loss = self.loss_fn(outputs, b_labels.squeeze(1))

                loss.backward()
                self.optimizer.step()

                accumulated_loss.append(loss)
                accumulated_acc.append(self.accuracy_fn(outputs, b_labels.squeeze(1).type(torch.uint8)))

            loss = torch.stack(accumulated_loss).mean()
            accuracy = torch.stack(accumulated_acc).mean()
        else:
            # run once over all samples --> smooth convergence to a deep local minimum
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, labels.squeeze(1))
            accuracy = self.accuracy_fn(outputs, labels.squeeze(1).type(torch.uint8))
            loss.backward()
            self.optimizer.step()

        self.count = 0  # reset the count clock
        return loss, accuracy

    def predict(self, state, batch_size=1) -> np.ndarray:
        'Predict probability distribution of goals with metacontroller model and return as ndarray for sampling'
        return self.model.forward(state.unsqueeze(0)).squeeze(0).detach().cpu().numpy()

    def sample(self, prob_vec, temperature=0.1) -> int:
        'Sample a goal index from the predicted distribution after temperature scaling (lower temperature -> greedier)'
        prob_pred = np.log(prob_vec) / temperature
        dist = np.exp(prob_pred) / np.sum(np.exp(prob_pred))
        choices = range(len(prob_pred))
        return np.random.choice(choices, p=dist)
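

# ---------------------------------------------------------------------------
# Usage sketch (illustration only). The SimpleNamespace args object, the dummy
# random states/goals and the `runs/` directory handling below are assumptions
# for demonstration, not part of the original training pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os
    from types import SimpleNamespace

    os.makedirs("runs", exist_ok=True)  # MetaController saves its init state under runs/
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = SimpleNamespace(model_name="demo", verbose=0)  # hypothetical stand-in for CLI args

    controller = MetaController(device, args)

    # Collect BATCH_SIZE dummy (state, one-hot goal) pairs from a hypothetical expert.
    for _ in range(BATCH_SIZE):
        state = torch.rand(4, 84, 84)
        goal = torch.zeros(1, controller.n_goals)
        goal[0, np.random.randint(controller.n_goals)] = 1.0
        controller.collect(state, goal)

    # Train once the clock allows it (every BATCH_SIZE collected samples).
    if controller.check_training_clock():
        loss, accuracy = controller.train()
        print(f'loss={loss.item():.4f}, accuracy={accuracy.item():.4f}')

    # Predict a goal distribution for a fresh state and sample a goal from it.
    probs = controller.predict(torch.rand(4, 84, 84))
    print('sampled goal:', controller.sample(probs))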