add Int-HRL agent scripts
This commit is contained in:
parent
d410acde26
commit
e6f9ace2e3
7 changed files with 1803 additions and 1 deletions
158
agent/metacontroller.py
Normal file
158
agent/metacontroller.py
Normal file
|
@ -0,0 +1,158 @@
|
|||
from collections import deque
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torchmetrics import Accuracy
|
||||
import numpy as np
|
||||
|
||||
|
||||
BATCH_SIZE = 32
|
||||
TRAIN_HIST_SIZE = 10000
|
||||
P_DROPOUT = 0.5
|
||||
|
||||
class MetaNN(nn.Module):
|
||||
|
||||
def __init__(self, device, input_shape=(4,84,84), n_goals=4, hidden_nodes=512):
|
||||
super(MetaNN, self).__init__()
|
||||
|
||||
self.device = device
|
||||
|
||||
### Setup model architecture ###
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(p=P_DROPOUT),
|
||||
|
||||
nn.Conv2d(32, 64, kernel_size=4, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(p=P_DROPOUT),
|
||||
|
||||
nn.Conv2d(64, 64, kernel_size=3, stride=1),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(p=P_DROPOUT),
|
||||
)
|
||||
|
||||
conv_out_size = self._get_conv_out(input_shape)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(conv_out_size, hidden_nodes), #TODO: copy initialization from meta_net_il.py
|
||||
nn.ReLU(),
|
||||
nn.Dropout(p=P_DROPOUT),
|
||||
nn.Linear(hidden_nodes, n_goals),
|
||||
nn.Softmax(dim=1)
|
||||
)
|
||||
|
||||
def _get_conv_out(self, shape):
|
||||
o = self.conv(torch.zeros(1, *shape))
|
||||
return int(np.prod(o.size()))
|
||||
|
||||
def forward(self, x):
|
||||
x = x.to(self.device)
|
||||
conv_out = self.conv(x).view(x.size()[0], -1)
|
||||
return self.fc(conv_out)
|
||||
|
||||
|
||||
|
||||
class MetaController():
|
||||
def __init__(self, device, args, input_shape=(4, 84, 84), n_goals=4, hidden_nodes=512) -> None:
|
||||
self.device = device
|
||||
self.model = MetaNN(device, input_shape, n_goals, hidden_nodes).to(self.device)
|
||||
print('Saving init state of MetaNN')
|
||||
torch.save(self.model.state_dict(), f"runs/meta_init-{args.model_name}.pt")
|
||||
|
||||
self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.00025, alpha=0.95, eps=1e-08, weight_decay=0.0)
|
||||
self.loss_fn = nn.MSELoss()
|
||||
self.accuracy_fn = Accuracy(num_classes=n_goals, average='macro').to(device)
|
||||
|
||||
self.replay_hist = deque([None], maxlen=TRAIN_HIST_SIZE)
|
||||
|
||||
self.ind = 0
|
||||
self.count = 0
|
||||
self.meta_steps = 0
|
||||
|
||||
self.input_shape = input_shape
|
||||
self.n_goals = n_goals
|
||||
self.verbose = args.verbose
|
||||
self.name = args.model_name
|
||||
|
||||
def reset(self) -> None:
|
||||
'Load initial state dictionary from file and reset optimizer'
|
||||
self.model.load_state_dict(torch.load(f"runs/meta_init-{self.name}.pt"))
|
||||
self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.00025, alpha=0.95, eps=1e-08, weight_decay=0.0)
|
||||
|
||||
def check_training_clock(self) -> bool:
|
||||
'Only train every <BATCH_SIZE> meta controller steps, i.e. <BATCH_SIZE> new samples in replay buffer.'
|
||||
if BATCH_SIZE:
|
||||
return (self.meta_steps % BATCH_SIZE == 0)
|
||||
else:
|
||||
return (self.meta_steps % 20 == 0)
|
||||
|
||||
def collect(self, processed, expert_a) -> None:
|
||||
'Collect sample consisting of state (4, 84, 84) and one-hot vector of goal'
|
||||
if processed is not None:
|
||||
self.replay_hist.appendleft(tuple([processed, expert_a]))
|
||||
self.meta_steps += 1
|
||||
|
||||
def train(self):
|
||||
# if not reached TRAIN_HIST_SIZE yet, then get the number of samples
|
||||
num_samples = min(self.meta_steps, TRAIN_HIST_SIZE)
|
||||
|
||||
inputs = torch.stack([self.replay_hist[i][0] for i in range(num_samples)], 0).to(self.device)
|
||||
labels = torch.stack([torch.tensor(self.replay_hist[i][1], dtype=torch.float32) for i in range(num_samples)], 0).to(self.device)
|
||||
|
||||
if self.verbose >= 2.0:
|
||||
print(f'\nMetacontroller collected {self.meta_steps} samples')
|
||||
if len(labels) == TRAIN_HIST_SIZE:
|
||||
print(f'Reached TRAIN_HIST_SIZE = {TRAIN_HIST_SIZE}')
|
||||
print('Dataset Distribution:')
|
||||
for goal in labels.unique():
|
||||
goal = goal.cpu().detach().numpy()
|
||||
print(f'\nNumber of samples for goal {goal}: {sum(labels).squeeze()[goal]}' )
|
||||
print(f'--> {sum(labels).squeeze()[goal] / len(labels):.2%}')
|
||||
print()
|
||||
|
||||
if BATCH_SIZE and num_samples >= BATCH_SIZE:
|
||||
# train one epoch --> noisy convergence more likely to find broader minimum
|
||||
accumulated_loss = []
|
||||
accumulated_acc = []
|
||||
|
||||
for index in range(0, len(labels) // BATCH_SIZE):
|
||||
b_inputs, b_labels = inputs[index * BATCH_SIZE: (index + 1) * BATCH_SIZE], labels[index * BATCH_SIZE: (index + 1) * BATCH_SIZE]
|
||||
|
||||
# zero the parameter gradients
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
outputs = self.model(b_inputs)
|
||||
loss = self.loss_fn(outputs, b_labels.squeeze(1))
|
||||
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
accumulated_loss.append(loss)
|
||||
accumulated_acc.append(self.accuracy_fn(outputs, b_labels.squeeze(1).type(torch.uint8)))
|
||||
loss = torch.stack(accumulated_loss).mean()
|
||||
accuracy = torch.stack(accumulated_acc).mean()
|
||||
else:
|
||||
# run once over all samples --> smooth convergence to a deep local minimum
|
||||
self.optimizer.zero_grad()
|
||||
outputs = self.model(inputs)
|
||||
loss = self.loss_fn(outputs, labels.squeeze(1))
|
||||
accuracy = self.accuracy_fn(outputs, labels.squeeze(1).type(torch.uint8))
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
self.count = 0 # reset the count clock
|
||||
return loss, accuracy
|
||||
|
||||
def predict(self, state, batch_size=1) -> np.ndarray:
|
||||
'Predict probability distribution of goals with metacontroller model and return as ndarray for sampling'
|
||||
return self.model.forward(state.unsqueeze(0)).squeeze(0).detach().cpu().numpy()
|
||||
|
||||
def sample(self, prob_vec, temperature=0.1) -> int:
|
||||
prob_pred = np.log(prob_vec) / temperature
|
||||
dist = np.exp(prob_pred)/np.sum(np.exp(prob_pred))
|
||||
choices = range(len(prob_pred))
|
||||
return np.random.choice(choices, p=dist)
|
||||
|
||||
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue