from collections import deque

import torch
import torch.nn as nn
import torch.optim as optim
from torchmetrics import Accuracy
import numpy as np

BATCH_SIZE = 32
TRAIN_HIST_SIZE = 10000
P_DROPOUT = 0.5


class MetaNN(nn.Module):
    def __init__(self, device, input_shape=(4, 84, 84), n_goals=4, hidden_nodes=512):
        super(MetaNN, self).__init__()
        self.device = device

        ### Setup model architecture ###
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Dropout(p=P_DROPOUT),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Dropout(p=P_DROPOUT),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Dropout(p=P_DROPOUT),
        )
        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, hidden_nodes),  # TODO: copy initialization from meta_net_il.py
            nn.ReLU(),
            nn.Dropout(p=P_DROPOUT),
            nn.Linear(hidden_nodes, n_goals),
            nn.Softmax(dim=1),
        )

    def _get_conv_out(self, shape):
        'Compute the flattened output size of the conv stack for a single input.'
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        x = x.to(self.device)
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.fc(conv_out)


class MetaController:
    def __init__(self, device, args, input_shape=(4, 84, 84), n_goals=4, hidden_nodes=512) -> None:
        self.device = device
        self.model = MetaNN(device, input_shape, n_goals, hidden_nodes).to(self.device)

        print('Saving init state of MetaNN')
        torch.save(self.model.state_dict(), f"runs/meta_init-{args.model_name}.pt")

        self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.00025, alpha=0.95, eps=1e-08, weight_decay=0.0)
        self.loss_fn = nn.MSELoss()
        # Note: uses the pre-0.11 torchmetrics Accuracy API; newer versions expect a `task` argument.
        self.accuracy_fn = Accuracy(num_classes=n_goals, average='macro').to(device)

        self.replay_hist = deque([None], maxlen=TRAIN_HIST_SIZE)
        self.ind = 0
        self.count = 0
        self.meta_steps = 0

        self.input_shape = input_shape
        self.n_goals = n_goals
        self.verbose = args.verbose
        self.name = args.model_name

    def reset(self) -> None:
        'Load the initial state dictionary from file and reset the optimizer.'
        self.model.load_state_dict(torch.load(f"runs/meta_init-{self.name}.pt"))
        self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.00025, alpha=0.95, eps=1e-08, weight_decay=0.0)

    def check_training_clock(self) -> bool:
        'Only train every BATCH_SIZE meta-controller steps, i.e. after BATCH_SIZE new samples in the replay buffer.'
        if BATCH_SIZE:
            return (self.meta_steps % BATCH_SIZE == 0)
        else:
            return (self.meta_steps % 20 == 0)

    def collect(self, processed, expert_a) -> None:
        'Collect a sample consisting of a state (4, 84, 84) and a one-hot vector of the goal.'
        if processed is not None:
            self.replay_hist.appendleft((processed, expert_a))
            self.meta_steps += 1

    def train(self):
        # If TRAIN_HIST_SIZE has not been reached yet, use the number of collected samples.
        num_samples = min(self.meta_steps, TRAIN_HIST_SIZE)
        inputs = torch.stack([self.replay_hist[i][0] for i in range(num_samples)], 0).to(self.device)
        labels = torch.stack([torch.tensor(self.replay_hist[i][1], dtype=torch.float32) for i in range(num_samples)], 0).to(self.device)

        if self.verbose >= 2.0:
            print(f'\nMetacontroller collected {self.meta_steps} samples')
            if len(labels) == TRAIN_HIST_SIZE:
                print(f'Reached TRAIN_HIST_SIZE = {TRAIN_HIST_SIZE}')
            print('Dataset Distribution:')
            goal_counts = labels.sum(dim=0).squeeze()
            for goal in range(self.n_goals):
                print(f'\nNumber of samples for goal {goal}: {int(goal_counts[goal])}')
                print(f'--> {goal_counts[goal] / len(labels):.2%}')
            print()

        if BATCH_SIZE and num_samples >= BATCH_SIZE:
            # Train one epoch in mini-batches --> noisy convergence, more likely to find a broader minimum.
            accumulated_loss = []
            accumulated_acc = []
            for index in range(0, len(labels) // BATCH_SIZE):
                b_inputs = inputs[index * BATCH_SIZE: (index + 1) * BATCH_SIZE]
                b_labels = labels[index * BATCH_SIZE: (index + 1) * BATCH_SIZE]

                # Zero the parameter gradients.
                self.optimizer.zero_grad()
                outputs = self.model(b_inputs)
                loss = self.loss_fn(outputs, b_labels.squeeze(1))
                loss.backward()
                self.optimizer.step()

                accumulated_loss.append(loss.detach())
                accumulated_acc.append(self.accuracy_fn(outputs, b_labels.squeeze(1).type(torch.uint8)))
            loss = torch.stack(accumulated_loss).mean()
            accuracy = torch.stack(accumulated_acc).mean()
        else:
            # Run once over all samples --> smooth convergence to a deep local minimum.
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, labels.squeeze(1))
            accuracy = self.accuracy_fn(outputs, labels.squeeze(1).type(torch.uint8))
            loss.backward()
            self.optimizer.step()

        self.count = 0  # reset the count clock
        return loss, accuracy

    def predict(self, state, batch_size=1) -> np.ndarray:
        'Predict a probability distribution over goals with the metacontroller model and return it as an ndarray for sampling.'
        return self.model(state.unsqueeze(0)).squeeze(0).detach().cpu().numpy()

    def sample(self, prob_vec, temperature=0.1) -> int:
        'Sample a goal index from a temperature-scaled softmax over the predicted probabilities.'
        # Cast to float64 so the renormalized distribution passes np.random.choice's sum-to-one check.
        prob_vec = np.asarray(prob_vec, dtype=np.float64)
        prob_pred = np.log(prob_vec) / temperature
        dist = np.exp(prob_pred) / np.sum(np.exp(prob_pred))
        choices = range(len(prob_pred))
        return np.random.choice(choices, p=dist)
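

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# `DummyArgs`, the random states, and the random one-hot expert goals below
# are hypothetical stand-ins for the real argparse namespace and for the
# preprocessed frame stacks / expert goal labels produced elsewhere in the
# project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    class DummyArgs:
        # Minimal stand-in for the argparse namespace; only the attributes
        # accessed by MetaController are provided.
        model_name = "demo"
        verbose = 0

    os.makedirs("runs", exist_ok=True)  # MetaController saves its init weights here
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    meta = MetaController(device, DummyArgs())

    # Collect BATCH_SIZE random (state, one-hot goal) pairs and train once
    # when the training clock fires.
    for _ in range(BATCH_SIZE):
        state = torch.rand(4, 84, 84)  # stands in for a preprocessed frame stack
        goal = np.eye(4, dtype=np.float32)[np.random.randint(4)][None, :]  # one-hot goal, shape (1, 4)
        meta.collect(state, goal)
        if meta.check_training_clock():
            loss, acc = meta.train()
            print(f"loss={loss.item():.4f}  acc={acc.item():.4f}")

    # Predict a goal distribution for a new state and sample from it.
    probs = meta.predict(torch.rand(4, 84, 84))
    print("sampled goal:", meta.sample(probs))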