from collections import namedtuple
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchinfo import summary
from tqdm import tqdm

from replay_buffer import PrioritizedReplayBuffer, LinearSchedule, TanHSchedule

# Prioritized experience replay hyperparameters
prioritized_replay_alpha = 0.6
max_timesteps = 1000000
prioritized_replay_beta0 = 0.4
prioritized_replay_eps = 1e-6
prioritized_replay_beta_iters = max_timesteps * 0.5
beta_schedule = LinearSchedule(prioritized_replay_beta_iters,
                               initial_p=prioritized_replay_beta0,
                               final_p=1.0)

Experience = namedtuple('Experience', field_names=['state', 'action', 'reward', 'next_state', 'done'])


# Deep Q Network
class DQN(nn.Module):

    def __init__(self, device, input_shape, n_actions, hidden_nodes=512, goal_feature=False):
        super(DQN, self).__init__()
        self.device = device

        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 16, kernel_size=5, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)
        if goal_feature:
            conv_out_size += 4

        # Standard DQN head
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, hidden_nodes),
            nn.LeakyReLU(),
            nn.Linear(hidden_nodes, n_actions),
            # nn.Softmax(dim=1)  # necessary for Boltzmann sampling
        )

    def _get_conv_out(self, shape):
        # Forward a dummy tensor through the conv stack to infer the flattened feature size
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x, goal=None):
        x = x.to(self.device)
        conv_out = self.conv(x).view(x.size()[0], -1)
        if goal is not None:
            conv_out = torch.cat((conv_out, goal.squeeze(1)), dim=1)
        return self.fc(conv_out)


# Dueling Deep Q Network
class DuelingDQN(nn.Module):

    def __init__(self, device, input_shape, n_actions, hidden_nodes=512, goal_feature=False):
        super(DuelingDQN, self).__init__()
        self.device = device

        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 16, kernel_size=5, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=5, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)
        if goal_feature:
            conv_out_size += 4

        # Dueling DQN --> two heads
        # advantage of actions A(s,a)
        self.fc_adv = nn.Sequential(
            nn.Linear(conv_out_size, hidden_nodes),
            nn.LeakyReLU(),
            nn.Linear(hidden_nodes, n_actions)
        )
        # value of the state V(s)
        self.fc_val = nn.Sequential(
            nn.Linear(conv_out_size, hidden_nodes),
            nn.LeakyReLU(),
            nn.Linear(hidden_nodes, 1)
        )

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x, goal=None):
        x = x.to(self.device)
        fx = x.float() / 256
        conv_out = self.conv(fx).view(fx.size()[0], -1)
        if goal is not None:
            conv_out = torch.cat((conv_out, goal.squeeze(1)), dim=1)
        val = self.fc_val(conv_out)
        adv = self.fc_adv(conv_out)
        # Q(s,a) = V(s) + A(s,a) - 1/N * sum_k A(s,k)  --> expected value of the advantage is 0
        return val + (adv - adv.mean(dim=1, keepdim=True))


class MaskedHuberLoss(nn.Module):

    def __init__(self, delta=1.0):
        super(MaskedHuberLoss, self).__init__()
        self.criterion = nn.HuberLoss(reduction='none', delta=delta)

    def forward(self, inputs, targets, mask):
        # Element-wise Huber loss, zeroed out for every action that is not masked in
        loss = self.criterion(inputs, targets)
        loss *= mask
        return loss.sum(dim=-1)
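

# Illustrative example of the masked loss (hypothetical values, batch of 2 samples and
# 3 actions): only the entry selected by the mask contributes to each sample's loss.
#
#   criterion = MaskedHuberLoss(delta=1.0)
#   q_pred  = torch.tensor([[0.5, 0.2, 0.1], [0.3, 0.9, 0.4]])
#   targets = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.5, 0.0]])
#   mask    = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
#   loss = criterion(q_pred, targets, mask)   # shape (2,): one loss value per sample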


# Agent that performs, remembers and learns actions
class Agent():

    def __init__(self, device, goal, input_shape, n_actions, random_play_steps, args, dueling=True, goal_feature=False):
        self.device = device
        self.n_actions = n_actions
        self.goal_feature = goal_feature

        if dueling:
            self.policy_net = DuelingDQN(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device)
            self.target_net = DuelingDQN(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device)
        else:
            self.policy_net = DQN(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device)
            self.target_net = DQN(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device)

        # self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=args.learning_rate)
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=args.learning_rate)
        self.memory = PrioritizedReplayBuffer(args.experience_buffer, alpha=prioritized_replay_alpha)
        self.exploration = LinearSchedule(schedule_timesteps=args.epsilon_steps,
                                          initial_p=args.epsilon_start,
                                          final_p=args.epsilon_end)
        # self.exploration = TanHSchedule(schedule_timesteps=(args.epsilon_steps // 10),
        #                                 initial_p=args.epsilon_start,
        #                                 final_p=args.epsilon_end)
        self.loss_fn = MaskedHuberLoss(delta=1.0)
        self.steps_done = -1
        self.batch_size = args.batch_size
        self.gamma = args.gamma
        self.goal = goal
        self.enable_double_dqn = True
        self.hard_update = 1000  # number of steps between target network updates
        self.training_done = False
        self.verbose = args.verbose
        self.random_play_steps = random_play_steps
        self.args = args

    def remember(self, state, action, reward, next_state, done, goal):
        self.memory.add(state, action, reward, next_state, done, goal)

    def load_model(self, path):
        # summary(self.policy_net, (4, 84, 84))
        self.policy_net.load_state_dict(torch.load(path))
        self.policy_net.eval()
        self.random_play_steps = 100
        self.training_done = True
        self.exploration = LinearSchedule(schedule_timesteps=1000, initial_p=0.1, final_p=0.005)
        print('Loaded DDQN policy network weights from ', path)

    def get_epsilon(self):
        return self.exploration.value(self.steps_done - self.random_play_steps)

    def reset_epsilon_schedule(self):
        if isinstance(self.exploration, TanHSchedule):
            self.exploration.restart()
        else:
            print('Exploration schedule is linear and cannot be reset!')

    def select_action(self, state, goal=None):
        self.steps_done += 1

        # Select an action following an epsilon-greedy strategy
        if self.steps_done < self.random_play_steps or random.random() < self.get_epsilon():
            return torch.tensor([[random.randrange(0, self.n_actions)]], device=self.device, dtype=torch.long)
        else:
            with torch.no_grad():
                return self.policy_net(state, goal).max(1)[1].view(1, 1)

        # Alternative: select an action via Boltzmann sampling
        # if not self.training_done:
        #     if self.steps_done > self.random_play_steps:
        #         return self.sample(state, temperature=self.get_epsilon())
        #     else:
        #         return torch.tensor([[random.randrange(0, self.n_actions)]], device=self.device, dtype=torch.long)
        # else:
        #     with torch.no_grad():
        #         return self.policy_net(state).max(1)[1].view(1, 1)

    def sample(self, state, temperature=0.1, goal=None):
        """Predict the Q-value vector and use Boltzmann sampling to pick a new action."""
        with torch.no_grad():
            prob_vec = self.policy_net(state, goal)

        # Softmax with temperature: p(a) ~ exp(Q(s,a) / T)
        sample_prob = torch.exp(prob_vec / temperature)
        sample_prob /= sample_prob.sum()

        try:
            sample_idx = sample_prob.multinomial(num_samples=1)  # random choice with probabilities
        except Exception as e:
            # print(temperature, prob_vec, sample_prob)
            # RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
            import logging
            logging.basicConfig(filename=f'runs/{self.args.model_name}sample_prob.log',
                                filemode='w',
                                format='%(name)s - %(levelname)s - %(message)s')
            logging.warning(f'[Agent {self.goal} - step {self.steps_done}] {e}: {sample_prob}')
            # Fall back to a random action if the sampling probabilities are degenerate
            sample_idx = torch.tensor([[random.randrange(0, self.n_actions)]], device=self.device, dtype=torch.long)

        return sample_idx.view(1, 1)

    def finish_training(self, model_name):
        torch.save(self.policy_net.state_dict(), f"runs/run{model_name}-agent{self.goal}.pt")
        self.training_done = True
        # Free memory by deleting the target net (no longer needed)
        del self.target_net
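
    # Target used in optimize_model() below. With Double DQN enabled, the action is
    # chosen by the policy network and evaluated by the target network
    # (van Hasselt et al., 2016):
    #   y = r + gamma * Q_target(s', argmax_a Q_policy(s', a))
    # With vanilla DQN, the target network does both:
    #   y = r + gamma * max_a Q_target(s', a)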

    def optimize_model(self):
        if len(self.memory) < self.batch_size or self.steps_done < self.random_play_steps or self.training_done:
            return

        # Sample from the prioritized replay memory
        sample = self.memory.sample(self.batch_size, beta=beta_schedule.value(self.steps_done))
        states, actions, rewards, next_states, dones, goals, importance_v, idx_vector = sample

        states_v = states.to(self.device)
        next_states_v = next_states.to(self.device)
        actions_v = actions.to(self.device)
        rewards_v = rewards.to(self.device)
        done_mask = dones.to(self.device)
        goals_v = goals.to(self.device)
        if not self.goal_feature:
            goals_v = None

        # Predicted Q(s, a)
        q_values = self.policy_net(states_v, goals_v)
        state_action_values = q_values.gather(1, actions_v.squeeze(-1))

        if self.enable_double_dqn:
            # calculate Q'(s_t+1, argmax_a Q(s_t+1, a))
            next_state_actions = self.policy_net(next_states_v, goals_v).max(1)[1]  # [1] = argmax
            next_state_values = self.target_net(next_states_v, goals_v).gather(1, next_state_actions.unsqueeze(-1)).squeeze(-1)
        else:
            # calculate max_a Q'(s_t+1, a)
            next_state_values = self.target_net(next_states_v, goals_v).max(1)[0]  # [0] = max

        # Set the discounted reward to zero for all states that were terminal
        next_state_values[done_mask] = 0.0

        # Compute r_t + gamma * max_a Q'(s_t+1, a)
        expected_state_action_values = next_state_values.detach() * self.gamma + rewards_v

        targets = np.zeros((self.batch_size, self.n_actions))
        dummy_targets = np.zeros((self.batch_size,))
        masks = np.zeros((self.batch_size, self.n_actions))
        for idx, (target, mask, reward, action) in enumerate(zip(targets, masks, expected_state_action_values, actions_v)):
            target[action] = reward  # update action with estimated accumulated reward
            dummy_targets[idx] = reward
            mask[action] = 1.  # enable loss for this specific action

        # Calculate the TD (temporal difference) error
        action_batch = actions_v.squeeze().cpu().detach().numpy()
        td_errors = targets[range(self.batch_size), action_batch] - state_action_values.squeeze().cpu().detach().numpy()

        # Compute the masked loss between predicted state-action values and targets
        targets = torch.tensor(targets, dtype=torch.float32).to(self.device)
        masks = torch.tensor(masks, dtype=torch.float32).to(self.device)
        loss = self.loss_fn(q_values, targets, masks)

        # Update models and memory only when training is not done
        if not self.training_done:
            # Update priorities in the replay buffer proportionally to the absolute TD error
            new_priorities = np.abs(td_errors) + prioritized_replay_eps
            self.memory.update_priorities(idx_vector, new_priorities)

            self.optimizer.zero_grad()
            loss.mean().backward()  # https://discuss.pytorch.org/t/loss-backward-raises-error-grad-can-be-implicitly-created-only-for-scalar-outputs/12152/2
            for param in self.policy_net.parameters():
                param.grad.data.clamp_(-1, 1)
            self.optimizer.step()

            # Update the target network
            if self.steps_done % self.hard_update == 0:
                if self.verbose >= 1:
                    print('\nUpdating target network...\n')
                self.target_net.load_state_dict(self.policy_net.state_dict())

        return loss
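

# Minimal usage sketch (illustrative only, not part of the training pipeline).
# The hyperparameter values, the (4, 84, 84) input shape and the number of actions
# below are assumptions; `args` normally comes from the project's argument parser,
# and the environment loop, state preprocessing and goal handling are omitted.
if __name__ == "__main__":
    import argparse

    args = argparse.Namespace(
        hidden_nodes=512, learning_rate=1e-4, experience_buffer=100000,
        epsilon_steps=100000, epsilon_start=1.0, epsilon_end=0.05,
        batch_size=32, gamma=0.99, verbose=1, model_name='debug')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    agent = Agent(device, goal=0, input_shape=(4, 84, 84), n_actions=6,
                  random_play_steps=1000, args=args, dueling=True)

    # Select an action for a dummy stacked-frame state; with so few steps taken,
    # the epsilon-greedy policy returns a random action.
    state = torch.zeros(1, 4, 84, 84)
    action = agent.select_action(state)
    # agent.remember(state, action, 0.0, state, False, 0)
    # agent.optimize_model()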