import torch

from .accuracy import *
from model.helpers import AverageMeter


def validate_act_noise(val_loader, model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()
    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        # compute output
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            # Condition on the first and last visual observations.
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            # Task one-hot, repeated along the horizon.
            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            task_class_ = task_onehot.unsqueeze(1).repeat(1, T, 1)  # [bs, T, class_dim]
            cond['task'] = task_class_

            # Per-step action one-hot.
            video_label_reshaped = video_label.view(-1)
            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            # Ground-truth target x_start, laid out as
            # [task one-hot | action one-hot | observation] along the feature axis;
            # observations are filled at the first and last timesteps only.
            x_start = torch.zeros((batch_size_current, T,
                                   args.class_dim + args.action_dim + args.observation_dim))
            x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
            x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
            x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
            x_start[:, :, :args.class_dim] = task_class_

            output = model(cond, act_emb, task_class, if_jump=False, if_avg_mask=args.infer_avg_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            # Score only the action slice of the prediction.
            actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(),
                         topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
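
# Note: the manual index-fill one-hot construction used throughout this file is
# equivalent to torch.nn.functional.one_hot. A minimal illustration (not used by
# the code in this module):
#
#   import torch.nn.functional as F
#   labels = torch.tensor([2, 0, 1])
#   F.one_hot(labels, num_classes=4).float()
#   # tensor([[0., 0., 1., 0.],
#   #         [1., 0., 0., 0.],
#   #         [0., 1., 0., 0.]])
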
def validate_act(val_loader, model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()
    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        # compute output
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            task_class_ = task_onehot.unsqueeze(1).repeat(1, T, 1)  # [bs, T, class_dim]
            if args.use_cls_mask:
                cond['task'] = task_class_

            video_label_reshaped = video_label.view(-1)
            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            # The task slice is prepended to x_start only when the class mask is used.
            if args.use_cls_mask:
                x_start = torch.zeros((batch_size_current, T,
                                       args.class_dim + args.action_dim + args.observation_dim))
                x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
                x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
                x_start[:, :, :args.class_dim] = task_class_
            else:
                x_start = torch.zeros((batch_size_current, T, args.action_dim + args.observation_dim))
                x_start[:, 0, args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.action_dim:] = global_img_tensors[:, -1, :]
                x_start[:, :, :args.action_dim] = action_label_onehot

            output = model(cond, act_emb, task_class, if_jump=False, if_avg_mask=args.infer_avg_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(),
                         topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
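
# `accuracy` (imported from .accuracy) is assumed to return, in this order:
# a (top-1, top-5) accuracy pair over all bs*T predicted steps, the trajectory
# success rate (a plan counts as successful only if every step is correct),
# two mIoU variants, and the accuracies of the first (a0) and final (aT)
# actions of each plan.
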
def validate(val_loader, model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()
    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        # compute output
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            task_class_ = task_onehot.unsqueeze(1).repeat(1, T, 1)  # [bs, T, class_dim]
            if args.use_cls_mask:
                cond['task'] = task_class_

            video_label_reshaped = video_label.view(-1)
            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            if args.use_cls_mask:
                x_start = torch.zeros((batch_size_current, T,
                                       args.class_dim + args.action_dim + args.observation_dim))
                x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
                x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
                x_start[:, :, :args.class_dim] = task_class_
            else:
                x_start = torch.zeros((batch_size_current, T, args.action_dim + args.observation_dim))
                x_start[:, 0, args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.action_dim:] = global_img_tensors[:, -1, :]
                x_start[:, :, :args.action_dim] = action_label_onehot

            output = model(cond, task_class, act_emb, if_jump=False,
                           if_avg_mask=args.infer_avg_mask, cond_type=args.use_cls_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            # The action slice starts at class_dim only when the task slice is present.
            if args.use_cls_mask:
                actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            else:
                actions_pred = actions_pred[:, :, :args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(),
                         topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
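
# `similarity_score`, used by validate_mlp below, is expected to come from the
# star import of .accuracy. If it were unavailable, a cosine-similarity decoder
# along these lines would match how it is used (a sketch, not the repository's
# implementation):
#
#   def similarity_score(pred_emb, act_emb):
#       # [bs*T, 512] x [action_dim, 512] -> [bs*T, action_dim] class scores
#       pred = torch.nn.functional.normalize(pred_emb, dim=-1)
#       emb = torch.nn.functional.normalize(act_emb, dim=-1)
#       return (pred @ emb.t()).cpu()
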
def validate_mlp(val_loader, model, act_model, args, act_emb):
    # act_model is currently unused: predicted embeddings are decoded to action
    # classes with similarity_score against act_emb instead.
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()
    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        # compute output
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            # Observations are reshaped into (32, 48) token grids for conditioning.
            cond[0] = global_img_tensors[:, 0, :].view((batch_size_current, 32, 48))
            cond[T - 1] = global_img_tensors[:, -1, :].view((batch_size_current, 32, 48))

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            task_class_ = task_onehot.unsqueeze(1).repeat(1, T, 1)  # [bs, T, class_dim]

            video_label_reshaped = video_label.view(-1)
            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            # Target per step is a 32x64 grid: columns [0:48) hold the observation
            # (first and last timesteps only), columns [48:64) hold the action
            # embedding reshaped to (32, 16), i.e. 32 * 16 = 512 dimensions.
            x_start = torch.zeros((batch_size_current, T, 32, 64))
            x_start[:, 0, :, :48] = global_img_tensors[:, 0, :].view((batch_size_current, 32, 48))
            x_start[:, -1, :, :48] = global_img_tensors[:, -1, :].view((batch_size_current, 32, 48))
            for i in range(video_label.shape[0]):
                for j in range(video_label.shape[1]):
                    x_start[i, j, :, 48:] = act_emb[video_label[i][j]].view(32, 16)

            output = model(cond, task_class, act_emb, if_jump=False,
                           if_avg_mask=args.infer_avg_mask, cond_type=args.use_cls_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            # Flatten the predicted embedding columns back to 512-d vectors and
            # decode them to class scores by similarity against act_emb.
            actions_pred = output[:, :, :, 48:].contiguous().view(batch_size_current * args.horizon, 512)
            sim_act_pred = similarity_score(actions_pred, act_emb)

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(sim_act_pred, video_label_reshaped.cpu(),
                         topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)
    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
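
# Example usage (a sketch; the loader, model, and argument namespace are
# placeholders supplied by the surrounding training code, and each val_loader
# batch is expected to yield (observation features, action labels, task class)):
#
#   model = torch.nn.DataParallel(diffusion_model).cuda()  # .module.loss_fn must exist
#   act_emb = act_emb.cuda()  # precomputed action embeddings, [action_dim, emb_dim]
#   (loss_avg, top1, top5, success_rate,
#    miou1, miou2, a0_acc, aT_acc) = validate_act(val_loader, model, args, act_emb)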