Int-HRL/agent/ddqn_agent.py

308 lines
11 KiB
Python

from collections import namedtuple
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchinfo import summary
import numpy as np
from tqdm import tqdm
from replay_buffer import PrioritizedReplayBuffer, LinearSchedule, TanHSchedule
# --- Prioritized experience replay hyper-parameters (module-wide) ---
prioritized_replay_alpha = 0.6      # alpha passed to the prioritized replay buffer
prioritized_replay_beta0 = 0.4      # starting value of the beta schedule below
prioritized_replay_eps = 1e-6       # added to |TD error| so no priority is exactly zero
max_timesteps = 1_000_000

# Beta schedule used when sampling from the replay buffer: starts at beta0 and
# ends at 1.0 after half of max_timesteps (assumes LinearSchedule interpolates
# linearly between initial_p and final_p over that many steps — confirm).
prioritized_replay_beta_iters = max_timesteps / 2
beta_schedule = LinearSchedule(
    prioritized_replay_beta_iters,
    initial_p=prioritized_replay_beta0,
    final_p=1.0,
)

# One stored transition (s, a, r, s', done).
# NOTE(review): not referenced by the code visible in this file.
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done'])
# Deep Q Network
class DQN(nn.Module):
    """Plain DQN: three conv layers followed by a two-layer MLP Q-value head.

    Args:
        device: device that ``forward`` moves its inputs to.
        input_shape: observation shape ``(C, H, W)``; ``C`` sets the input
            channels of the first convolution.
        n_actions: number of Q-values produced per observation.
        hidden_nodes: width of the hidden fully-connected layer.
        goal_feature: if true, a 4-dim goal vector is concatenated to the
            flattened conv features before the MLP head.
    """

    def __init__(self, device, input_shape, n_actions, hidden_nodes=512, goal_feature=False):
        super(DQN, self).__init__()
        self.device = device
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 16, kernel_size=5, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        conv_out_size = self._get_conv_out(input_shape)
        if goal_feature:
            # room for the 4-dim goal vector appended in forward()
            conv_out_size += 4
        # Standard DQN head; outputs raw Q-values (no softmax).
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, hidden_nodes),
            nn.LeakyReLU(),
            nn.Linear(hidden_nodes, n_actions),
        )

    def _get_conv_out(self, shape):
        """Return the flattened conv output size for a single input of ``shape``.

        The probe runs under ``no_grad`` with the conv stack temporarily in
        eval mode, so the dummy zero batch does not update the BatchNorm
        running statistics or ``num_batches_tracked``.
        """
        was_training = self.conv.training
        self.conv.eval()
        try:
            with torch.no_grad():
                o = self.conv(torch.zeros(1, *shape))
        finally:
            self.conv.train(was_training)
        return int(np.prod(o.size()))

    def forward(self, x, goal=None):
        """Return Q-values of shape ``(batch, n_actions)`` for observations ``x``.

        ``goal`` (only when the net was built with ``goal_feature``) is squeezed
        on dim 1 and must then be ``(batch, 4)``; it is concatenated to the
        flattened conv features.
        """
        x = x.to(self.device)
        # .float() lets integer (e.g. uint8) frames through Conv2d; no-op for float32.
        # NOTE(review): unlike DuelingDQN, no /256 input scaling is applied here.
        conv_out = self.conv(x.float()).view(x.size(0), -1)
        if goal is not None:
            # goal may arrive on another device (e.g. straight from select_action)
            conv_out = torch.cat((conv_out, goal.to(self.device).squeeze(1)), dim=1)
        return self.fc(conv_out)
# Dueling Deep Q Network
class DuelingDQN(nn.Module):
    """Dueling DQN: shared conv trunk with separate value and advantage heads.

    The heads are combined as ``Q(s, a) = V(s) + A(s, a) - mean_k A(s, k)``.

    Args:
        device: device that ``forward`` moves its inputs to.
        input_shape: observation shape ``(C, H, W)``; ``C`` sets the input
            channels of the first convolution.
        n_actions: number of Q-values produced per observation.
        hidden_nodes: width of the hidden layer of each head.
        goal_feature: if true, a 4-dim goal vector is concatenated to the
            flattened conv features before both heads.
    """

    def __init__(self, device, input_shape, n_actions, hidden_nodes=512, goal_feature=False):
        super(DuelingDQN, self).__init__()
        self.device = device
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 16, kernel_size=5, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        conv_out_size = self._get_conv_out(input_shape)
        if goal_feature:
            # room for the 4-dim goal vector appended in forward()
            conv_out_size += 4
        # advantage of actions A(s, a)
        self.fc_adv = nn.Sequential(
            nn.Linear(conv_out_size, hidden_nodes),
            nn.LeakyReLU(),
            nn.Linear(hidden_nodes, n_actions)
        )
        # value of states V(s)
        self.fc_val = nn.Sequential(
            nn.Linear(conv_out_size, hidden_nodes),
            nn.LeakyReLU(),
            nn.Linear(hidden_nodes, 1)
        )

    def _get_conv_out(self, shape):
        """Return the flattened conv output size for a single input of ``shape``.

        The probe runs under ``no_grad`` with the conv stack temporarily in
        eval mode, so the dummy zero batch does not update the BatchNorm
        running statistics or ``num_batches_tracked``.
        """
        was_training = self.conv.training
        self.conv.eval()
        try:
            with torch.no_grad():
                o = self.conv(torch.zeros(1, *shape))
        finally:
            self.conv.train(was_training)
        return int(np.prod(o.size()))

    def forward(self, x, goal=None):
        """Return Q-values of shape ``(batch, n_actions)`` for observations ``x``.

        ``goal`` (only when the net was built with ``goal_feature``) is squeezed
        on dim 1 and must then be ``(batch, 4)``.
        """
        x = x.to(self.device)
        # NOTE(review): presumably x holds raw 0-255 pixel intensities — confirm.
        fx = x.float() / 256
        conv_out = self.conv(fx).view(fx.size(0), -1)
        if goal is not None:
            # goal may arrive on another device (e.g. straight from select_action)
            conv_out = torch.cat((conv_out, goal.to(self.device).squeeze(1)), dim=1)
        val = self.fc_val(conv_out)
        adv = self.fc_adv(conv_out)
        # subtracting the mean advantage makes the expected advantage zero,
        # so V and A are identifiable
        return val + (adv - adv.mean(dim=1, keepdim=True))
class MaskedHuberLoss(nn.Module):
    """Huber loss restricted to the entries selected by a mask.

    The element-wise Huber loss between ``inputs`` and ``targets`` is multiplied
    by ``mask`` and summed over the last dimension, giving one value per row.

    Args:
        delta: Huber threshold between the quadratic and linear regimes.
    """

    def __init__(self, delta=1.0):
        super().__init__()
        # reduction='none' keeps element-wise losses so the mask can be applied
        self.criterion = nn.HuberLoss(reduction='none', delta=delta)

    def forward(self, inputs, targets, mask):
        """Return ``sum(huber(inputs, targets) * mask, dim=-1)``."""
        per_element = self.criterion(inputs, targets)
        # in-place multiply keeps the dtype of the loss tensor
        per_element.mul_(mask)
        return per_element.sum(dim=-1)
# Agent that performs, remembers and learns actions
class Agent():
    """Double/dueling DQN agent trained from a prioritized replay buffer.

    Holds a policy network and a target network (``DuelingDQN`` when
    ``dueling`` is true, otherwise ``DQN``), an Adam optimizer, a
    ``PrioritizedReplayBuffer`` and a linear epsilon schedule for
    epsilon-greedy exploration.

    Args:
        device: torch device the networks live on.
        goal: identifier of the goal this agent is trained for; used in the
            saved-model file name and in log messages.
        input_shape: observation shape ``(C, H, W)`` given to the network.
        n_actions: size of the discrete action space.
        random_play_steps: number of initial steps that take uniformly random
            actions and perform no learning.
        args: namespace providing ``hidden_nodes``, ``learning_rate``,
            ``experience_buffer``, ``epsilon_steps``, ``epsilon_start``,
            ``epsilon_end``, ``batch_size``, ``gamma``, ``verbose`` and
            ``model_name``.
        dueling: use the dueling architecture.
        goal_feature: feed a goal vector to the networks alongside the frames.
    """

    def __init__(self, device, goal, input_shape, n_actions, random_play_steps, args, dueling=True, goal_feature=False):
        self.device = device
        self.n_actions = n_actions
        self.goal_feature = goal_feature
        net_cls = DuelingDQN if dueling else DQN
        self.policy_net = net_cls(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device)
        self.target_net = net_cls(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device)
        # Start the target as an exact copy so early bootstrap targets come from
        # the network being trained rather than an unrelated random one.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=args.learning_rate)
        self.memory = PrioritizedReplayBuffer(args.experience_buffer, alpha=prioritized_replay_alpha)
        # Linear epsilon decay; reset_epsilon_schedule() only works with a TanHSchedule.
        self.exploration = LinearSchedule(schedule_timesteps=args.epsilon_steps,
                                          initial_p=args.epsilon_start,
                                          final_p=args.epsilon_end)
        self.loss_fn = MaskedHuberLoss(delta=1.0)
        # select_action() increments before use, so the first action is step 0
        self.steps_done = -1
        self.batch_size = args.batch_size
        self.gamma = args.gamma
        self.goal = goal
        self.enable_double_dqn = True
        # copy policy -> target whenever steps_done is a multiple of this
        self.hard_update = 1000
        self.training_done = False
        # scale the TD loss by the replay buffer's importance-sampling weights
        self.use_importance_weights = True
        self.verbose = args.verbose
        self.random_play_steps = random_play_steps
        self.args = args

    def remember(self, state, action, reward, next_state, done, goal):
        """Store one transition (with its goal) in the prioritized replay memory."""
        self.memory.add(state, action, reward, next_state, done, goal)

    def load_model(self, path):
        """Load policy weights from ``path`` and switch to evaluation behaviour.

        Sets the policy net to eval mode, marks training as done and installs a
        short low-epsilon exploration schedule. ``torch.load`` unpickles the
        file, so only load trusted checkpoints.
        """
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine (and vice versa)
        self.policy_net.load_state_dict(torch.load(path, map_location=self.device))
        self.policy_net.eval()
        self.random_play_steps = 100
        self.training_done = True
        self.exploration = LinearSchedule(schedule_timesteps=1000, initial_p=0.1, final_p=0.005)
        print('Loaded DDQN policy network weights from ', path)

    def get_epsilon(self):
        """Current exploration rate, counted from the end of the random-play phase."""
        return self.exploration.value(self.steps_done - self.random_play_steps)

    def reset_epsilon_schedule(self):
        """Restart the exploration schedule (only supported for TanHSchedule)."""
        if isinstance(self.exploration, TanHSchedule):
            self.exploration.restart()
        else:
            print('Exploration schedule is linear and can not be reset!')

    def select_action(self, state, goal=None):
        """Choose an action epsilon-greedily; returns a ``(1, 1)`` long tensor.

        During the first ``random_play_steps`` steps the action is always
        uniformly random.
        """
        self.steps_done += 1
        # short-circuit: epsilon is only drawn once random play is over
        explore = (self.steps_done < self.random_play_steps
                   or random.random() < self.get_epsilon())
        if explore:
            return torch.tensor([[random.randrange(0, self.n_actions)]], device=self.device, dtype=torch.long)
        # NOTE(review): policy_net stays in train mode while training, so its
        # BatchNorm layers use this single observation's statistics and update
        # their running stats here — confirm this is intended.
        with torch.no_grad():
            return self.policy_net(state, goal).max(1)[1].view(1, 1)

    def sample(self, state, temperature=0.1, goal=None):
        """Pick an action by Boltzmann sampling over the predicted Q-values.

        Args:
            state: observation batch with a single element.
            temperature: softmax temperature; lower values are greedier.
            goal: optional goal vector forwarded to the policy net.

        Returns:
            ``(1, 1)`` long tensor with the sampled action. Falls back to a
            uniformly random action (and logs a warning) when the probabilities
            are invalid, e.g. for ``temperature == 0``.
        """
        with torch.no_grad():
            q_values = self.policy_net(state, goal)
            # softmax subtracts the row max before exponentiating, so large
            # Q/temperature no longer overflows to inf (inf/inf -> nan)
            sample_prob = torch.softmax(q_values / temperature, dim=-1)
        try:
            sample_idx = sample_prob.multinomial(num_samples=1)  # random choice with probabilities
        except RuntimeError as e:
            # RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
            import logging
            logging.basicConfig(filename=f'runs/{self.args.model_name}sample_prob.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')
            logging.warning(f'[Agent {self.goal} - step {self.steps_done}] {e}: {sample_prob}')
            sample_idx = torch.tensor([[random.randrange(0, self.n_actions)]], device=self.device, dtype=torch.long)
        return sample_idx.view(1, 1)

    def finish_training(self, model_name):
        """Save the policy weights, stop learning and free the target network."""
        torch.save(self.policy_net.state_dict(), f"runs/run{model_name}-agent{self.goal}.pt")
        self.training_done = True
        # Free memory by deleting target net (no longer needed)
        del self.target_net

    def optimize_model(self):
        """Run one (double) DQN gradient step on a prioritized mini-batch.

        Skipped (returns ``None``) while the memory holds fewer than
        ``batch_size`` transitions, during random play, and after
        ``finish_training``.

        Returns:
            Tensor of shape ``(batch_size,)`` with the unweighted per-sample
            Huber loss of the taken actions, or ``None`` when no update ran.
        """
        if (len(self.memory) < self.batch_size
                or self.steps_done < self.random_play_steps
                or self.training_done):
            return None

        # Sample from prioritized replay memory
        (states, actions, rewards, next_states, dones, goals,
         importance_v, idx_vector) = self.memory.sample(
            self.batch_size, beta=beta_schedule.value(self.steps_done))
        states_v = states.to(self.device)
        next_states_v = next_states.to(self.device)
        # NOTE(review): actions appear to be (batch, 1, 1); squeeze(-1) yields
        # the 2-D index that gather()/scatter_() need — confirm against the buffer.
        action_idx = actions.to(self.device).squeeze(-1)
        rewards_v = rewards.to(self.device)
        # bool() guarantees mask semantics below even if dones are stored as 0/1 numbers
        done_mask = dones.to(self.device).bool()
        goals_v = goals.to(self.device) if self.goal_feature else None

        # predicted Q(s, .) and Q(s, a) for the taken actions, shape (batch, 1)
        q_values = self.policy_net(states_v, goals_v)
        state_action_values = q_values.gather(1, action_idx)

        # Bootstrap targets carry no gradient; avoid building autograd graphs for them
        with torch.no_grad():
            if self.enable_double_dqn:
                # Q'(s_t+1, argmax_a Q(s_t+1, a))
                next_state_actions = self.policy_net(next_states_v, goals_v).max(1)[1]
                next_state_values = self.target_net(next_states_v, goals_v).gather(
                    1, next_state_actions.unsqueeze(-1)).squeeze(-1)
            else:
                # max_a Q'(s_t+1, a)
                next_state_values = self.target_net(next_states_v, goals_v).max(1)[0]
            # terminal transitions have no future return
            next_state_values[done_mask] = 0.0
            # r_t + gamma * max_a Q'(s_t+1, a)
            expected_state_action_values = (next_state_values * self.gamma + rewards_v).to(q_values.dtype)

            # one-hot mask over the taken action; target holds y only at that action
            masks = torch.zeros_like(q_values).scatter_(1, action_idx, 1.0)
            targets = masks * expected_state_action_values.unsqueeze(1)
            # TD error of the taken action
            td_errors = expected_state_action_values - state_action_values.squeeze(1)

        # per-sample loss: only the taken action contributes
        loss = self.loss_fn(q_values, targets, masks)

        # Update priorities in replay buffer
        new_priorities = np.abs(td_errors.cpu().numpy()) + prioritized_replay_eps
        self.memory.update_priorities(idx_vector, new_priorities)

        if self.use_importance_weights:
            # importance-sampling weights correct the bias of non-uniform sampling
            weights = torch.as_tensor(importance_v, dtype=loss.dtype, device=loss.device).reshape(-1)
            batch_loss = (weights * loss).mean()
        else:
            batch_loss = loss.mean()

        self.optimizer.zero_grad()
        batch_loss.backward()
        # clamp each gradient element to [-1, 1]; parameters without a grad are skipped
        nn.utils.clip_grad_value_(self.policy_net.parameters(), 1.0)
        self.optimizer.step()

        # Update the target network
        if self.steps_done % self.hard_update == 0:
            if self.verbose >= 1:
                print('\nUpdating target network...\n')
            self.target_net.load_state_dict(self.policy_net.state_dict())
        return loss