# Int-HRL/agent/metacontroller.py

from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchmetrics import Accuracy

BATCH_SIZE = 32
TRAIN_HIST_SIZE = 10000
P_DROPOUT = 0.5

class MetaNN(nn.Module):
    def __init__(self, device, input_shape=(4, 84, 84), n_goals=4, hidden_nodes=512):
        super(MetaNN, self).__init__()
        self.device = device

        ### Setup model architecture ###
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Dropout(p=P_DROPOUT),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Dropout(p=P_DROPOUT),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Dropout(p=P_DROPOUT),
        )

        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, hidden_nodes),  # TODO: copy initialization from meta_net_il.py
            nn.ReLU(),
            nn.Dropout(p=P_DROPOUT),
            nn.Linear(hidden_nodes, n_goals),
            nn.Softmax(dim=1),
        )

    def _get_conv_out(self, shape):
        # Infer the flattened size of the conv output with a dummy forward pass.
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        x = x.to(self.device)
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.fc(conv_out)
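
# Minimal shape sanity check (illustrative sketch only): with the default
# (4, 84, 84) stacked-frame input and n_goals=4, the forward pass yields one
# softmax probability per goal:
#   net = MetaNN(torch.device('cpu'))
#   net(torch.zeros(1, 4, 84, 84)).shape  # -> torch.Size([1, 4]), rows sum to 1
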
class MetaController():
    def __init__(self, device, args, input_shape=(4, 84, 84), n_goals=4, hidden_nodes=512) -> None:
        self.device = device
        self.model = MetaNN(device, input_shape, n_goals, hidden_nodes).to(self.device)

        print('Saving init state of MetaNN')
        torch.save(self.model.state_dict(), f"runs/meta_init-{args.model_name}.pt")

        self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.00025, alpha=0.95, eps=1e-08, weight_decay=0.0)
        # Imitation objective: regress the softmax output onto the one-hot expert goal.
        self.loss_fn = nn.MSELoss()
        # NOTE: torchmetrics < 0.11 API; newer versions additionally require task='multiclass'.
        self.accuracy_fn = Accuracy(num_classes=n_goals, average='macro').to(device)

        # Replay buffer of (state, one-hot goal) pairs; newest samples are appended left.
        self.replay_hist = deque([None], maxlen=TRAIN_HIST_SIZE)
        self.ind = 0
        self.count = 0
        self.meta_steps = 0

        self.input_shape = input_shape
        self.n_goals = n_goals
        self.verbose = args.verbose
        self.name = args.model_name
    def reset(self) -> None:
        'Load initial state dictionary from file and reset optimizer'
        self.model.load_state_dict(torch.load(f"runs/meta_init-{self.name}.pt"))
        self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.00025, alpha=0.95, eps=1e-08, weight_decay=0.0)

    def check_training_clock(self) -> bool:
        'Only train every <BATCH_SIZE> meta controller steps, i.e. <BATCH_SIZE> new samples in replay buffer.'
        if BATCH_SIZE:
            return (self.meta_steps % BATCH_SIZE == 0)
        else:
            return (self.meta_steps % 20 == 0)
    def collect(self, processed, expert_a) -> None:
        'Collect sample consisting of state (4, 84, 84) and one-hot vector of goal'
        if processed is not None:
            self.replay_hist.appendleft(tuple([processed, expert_a]))
            self.meta_steps += 1
    def train(self):
        # If TRAIN_HIST_SIZE has not been reached yet, only use the samples collected so far.
        num_samples = min(self.meta_steps, TRAIN_HIST_SIZE)
        inputs = torch.stack([self.replay_hist[i][0] for i in range(num_samples)], 0).to(self.device)
        labels = torch.stack([torch.tensor(self.replay_hist[i][1], dtype=torch.float32) for i in range(num_samples)], 0).to(self.device)

        if self.verbose >= 2.0:
            print(f'\nMetacontroller collected {self.meta_steps} samples')
            if len(labels) == TRAIN_HIST_SIZE:
                print(f'Reached TRAIN_HIST_SIZE = {TRAIN_HIST_SIZE}')

            print('Dataset Distribution:')
            # Summing the one-hot labels over the batch dimension yields per-goal sample counts.
            goal_counts = labels.squeeze(1).sum(dim=0)
            for goal, count in enumerate(goal_counts.tolist()):
                print(f'\nNumber of samples for goal {goal}: {int(count)}')
                print(f'--> {count / len(labels):.2%}')
            print()

        if BATCH_SIZE and num_samples >= BATCH_SIZE:
            # train one epoch of mini-batches --> noisy convergence more likely to find broader minimum
            accumulated_loss = []
            accumulated_acc = []
            for index in range(len(labels) // BATCH_SIZE):
                b_inputs = inputs[index * BATCH_SIZE: (index + 1) * BATCH_SIZE]
                b_labels = labels[index * BATCH_SIZE: (index + 1) * BATCH_SIZE]

                # zero the parameter gradients
                self.optimizer.zero_grad()
                outputs = self.model(b_inputs)
                loss = self.loss_fn(outputs, b_labels.squeeze(1))
                loss.backward()
                self.optimizer.step()

                # Detach before logging so the per-batch computation graphs are not kept alive.
                accumulated_loss.append(loss.detach())
                accumulated_acc.append(self.accuracy_fn(outputs, b_labels.squeeze(1).type(torch.uint8)))

            loss = torch.stack(accumulated_loss).mean()
            accuracy = torch.stack(accumulated_acc).mean()
        else:
            # run once over all samples --> smooth convergence to a deep local minimum
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, labels.squeeze(1))
            accuracy = self.accuracy_fn(outputs, labels.squeeze(1).type(torch.uint8))
            loss.backward()
            self.optimizer.step()

        self.count = 0  # reset the count clock
        return loss, accuracy
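
    # Typical call pattern (sketch; `processed_state` and `one_hot_goal` are
    # placeholders for the caller's preprocessed observation and expert label):
    #   controller.collect(processed_state, one_hot_goal)
    #   if controller.check_training_clock():
    #       loss, accuracy = controller.train()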
    def predict(self, state, batch_size=1) -> np.ndarray:
        'Predict probability distribution of goals with metacontroller model and return as ndarray for sampling'
        return self.model(state.unsqueeze(0)).squeeze(0).detach().cpu().numpy()
    def sample(self, prob_vec, temperature=0.1) -> int:
        'Sample a goal index from the predicted probabilities, sharpened by a low-temperature softmax'
        prob_pred = np.log(prob_vec) / temperature
        dist = np.exp(prob_pred) / np.sum(np.exp(prob_pred))
        choices = range(len(prob_pred))
        return np.random.choice(choices, p=dist)
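

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch. Everything below is illustrative only and not
# part of the training pipeline: `args` is a stand-in namespace for the real
# CLI arguments (only `model_name` and `verbose` are used here), and random
# tensors replace preprocessed Atari frames and expert goal labels.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import os
    from types import SimpleNamespace

    os.makedirs('runs', exist_ok=True)  # MetaController saves its init state here
    args = SimpleNamespace(model_name='smoke-test', verbose=0)
    controller = MetaController(torch.device('cpu'), args)

    # Fill the buffer with BATCH_SIZE dummy (state, one-hot goal) samples.
    for _ in range(BATCH_SIZE):
        state = torch.rand(4, 84, 84)
        goal = np.eye(controller.n_goals, dtype=np.float32)[[np.random.randint(controller.n_goals)]]
        controller.collect(state, goal)

    if controller.check_training_clock():
        loss, accuracy = controller.train()
        print(f'loss: {loss.item():.4f}, accuracy: {accuracy.item():.4f}')

    probs = controller.predict(torch.rand(4, 84, 84))
    print('sampled goal:', controller.sample(probs))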