import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
from torch.distributed import ReduceOp

from dataloader.data_load import PlanningDataset
from model import diffusion_no_mask as diffusion
from model import temporal_act
from model.helpers import AverageMeter
from utils.args import get_args
from utils.training import Trainer


def accuracy2(output, target, topk=(1,), max_traj_len=0):
    """Compute step accuracy, trajectory success rate, and two MIoU variants."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        # Side-by-side (pred, gt) matrix, one row per trajectory, for offline inspection.
        comparison = torch.cat(
            (pred.view(-1, max_traj_len), target.view(-1, max_traj_len)), dim=1
        ).cpu().numpy()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        # Accuracy of the first and the last action of each trajectory.
        correct_a = correct[:1].view(-1, max_traj_len)
        correct_a0 = correct_a[:, 0].reshape(-1).float().mean().mul_(100.0)
        correct_aT = correct_a[:, -1].reshape(-1).float().mean().mul_(100.0)

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))

        correct_1 = correct[:1]

        # Success rate: a trajectory counts only if every step in it is correct.
        trajectory_success = torch.all(correct_1.view(correct_1.shape[1] // max_traj_len, -1), dim=1)
        trajectory_success_rate = trajectory_success.sum() * 100.0 / trajectory_success.shape[0]

        # MIoU1: IoU between the sets of predicted and ground-truth step tokens,
        # pooled over the whole batch.
        _, pred_token = output.topk(1, 1, True, True)
        pred_inst = pred_token.view(correct_1.shape[1], -1)
        pred_inst_set = set()
        target_inst = target.view(correct_1.shape[1], -1)
        target_inst_set = set()
        for i in range(pred_inst.shape[0]):
            pred_inst_set.add(tuple(pred_inst[i].tolist()))
            target_inst_set.add(tuple(target_inst[i].tolist()))
        MIoU1 = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(
            pred_inst_set.union(target_inst_set))

        # MIoU2: per-trajectory IoU of predicted vs. ground-truth action sets,
        # averaged over the batch.
        batch_size = batch_size // max_traj_len
        pred_inst = pred_token.view(batch_size, -1)  # [bs, T]
        pred_inst_set = set()
        target_inst = target.view(batch_size, -1)  # [bs, T]
        target_inst_set = set()
        MIoU_sum = 0
        for i in range(pred_inst.shape[0]):
            pred_inst_set.update(pred_inst[i].tolist())
            target_inst_set.update(target_inst[i].tolist())
            MIoU_current = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(
                pred_inst_set.union(target_inst_set))
            MIoU_sum += MIoU_current
            pred_inst_set.clear()
            target_inst_set.clear()
        MIoU2 = MIoU_sum / batch_size

        return res[0], trajectory_success_rate, MIoU1, MIoU2, correct_a0, correct_aT, comparison
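

# Shape sanity note for accuracy2 (hypothetical numbers, not from the repo): with
# max_traj_len=3 and two trajectories per batch, `output` is a [6, action_dim]
# logit matrix and `target` a [6] vector of action indices; the function returns
# top-1 accuracy, trajectory success rate, the two MIoU variants, first/last-step
# accuracies, and the (pred, gt) array saved to CSV by test() below.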


def test(val_loader, model, args):
    model.eval()
    acc_top1 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()
    A0_acc = AverageMeter()
    AT_acc = AverageMeter()
    pred_gt_total = []

    for i_batch, sample_batch in enumerate(val_loader):
        global_img_tensors = sample_batch[0].cuda().contiguous()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            # Condition the sampler on the start and goal observations.
            cond[0] = global_img_tensors[:, 0, :].float()
            cond[T - 1] = global_img_tensors[:, -1, :].float()

            # One-hot task conditioning, broadcast along the horizon. Built directly
            # on the label device to avoid a CPU/GPU indexing mismatch.
            task_onehot = torch.zeros((task_class.size(0), args.class_dim), device=task_class.device)
            ind = torch.arange(task_class.size(0), device=task_class.device)
            task_onehot[ind, task_class] = 1.0
            task_class_ = task_onehot.unsqueeze(1).repeat(1, T, 1)
            cond['task'] = task_class_

            video_label_reshaped = video_label.view(-1)

            # Note: img_tensors is assembled in the [class | action | observation]
            # layout but is not consumed by the jump-sampling call below; it is
            # kept for parity with the training code.
            img_tensors = torch.zeros(
                (batch_size_current, T, args.class_dim + args.action_dim + args.observation_dim))
            img_tensors[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
            img_tensors[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
            img_tensors[:, :, :args.class_dim] = task_class_

            output = model(cond, if_jump=True)
            # Slice the action logits out of the [class | action | observation] layout.
            actions_pred = output.contiguous()[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)

            (acc1, trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc,
             pred_gt) = accuracy2(actions_pred.cpu(), video_label_reshaped.cpu(),
                                  topk=(1,), max_traj_len=args.horizon)
            pred_gt_total.append(pred_gt)

        acc_top1.update(acc1.item(), batch_size_current)
        trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
        MIoU1_meter.update(MIoU1, batch_size_current)
        MIoU2_meter.update(MIoU2, batch_size_current)
        A0_acc.update(a0_acc, batch_size_current)
        AT_acc.update(aT_acc, batch_size_current)

    np.savetxt(f"pred_gt_{args.dataset}{args.horizon}.csv",
               np.concatenate(pred_gt_total), delimiter=",")

    return (torch.tensor(acc_top1.avg),
            torch.tensor(trajectory_success_rate_meter.avg),
            torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg),
            torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg))


def reduce_tensor(tensor):
    # Average a scalar metric across all distributed ranks.
    rt = tensor.clone()
    dist.all_reduce(rt, op=ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt


def main():
    args = get_args()
    if args.verbose:
        print(args)
    if args.seed is not None:
        os.environ['PYTHONHASHSEED'] = str(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
    ngpus_per_node = 1  # torch.cuda.device_count()
    print('ngpus_per_node:', ngpus_per_node)
    if args.multiprocessing_distributed:
        # One spawned process per GPU; scale world_size accordingly.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)
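

# Hypothetical single-GPU launch (flag names assumed to match utils.args.get_args;
# verify against the repo before use):
#   python test_script.py --dataset crosstask --horizon 3 --attn WithAttention \
#       --resume --checkpoint_diff checkpoints/epoch_best.pth.tar --gpu 0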


def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            # Split batch sizes and reader threads evenly across the processes.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
            args.num_thread_reader = int(args.num_thread_reader / ngpus_per_node)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    # Test data loading code
    test_dataset = PlanningDataset(
        args.root,
        args=args,
        is_val=True,
        model=None,
    )
    if args.distributed:
        test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
        test_sampler.shuffle = False
    else:
        test_sampler = None
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size_val,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_thread_reader,
        sampler=test_sampler,
    )

    # Create the temporal U-Net backbone. NIV uses a narrower base width with one
    # extra downsampling level; the other datasets use a wider, shallower net.
    in_dim = args.action_dim + args.observation_dim + args.class_dim
    if args.dataset == 'NIV':
        dim, dim_mults = 256, (1, 2, 4, 8)
    else:
        dim, dim_mults = 512, (1, 2, 4)
    if args.attn == 'NoAttention':
        temporal_model = temporal_act.TemporalUnetNoAttn(
            in_dim, args.action_dim, dim=dim, dim_mults=dim_mults)
    elif args.attn == 'WithAttention':
        temporal_model = temporal_act.TemporalUnet(
            in_dim, args.action_dim, dim=dim, dim_mults=dim_mults)
    else:
        raise ValueError('Unknown attention mode: {}'.format(args.attn))

    diffusion_model = diffusion.GaussianDiffusion(
        temporal_model,
        args.horizon,
        args.observation_dim,
        args.action_dim,
        args.class_dim,
        args.n_diffusion_steps,
        loss_type='Weighted_MSE',
        clip_denoised=True,
    )
    model = Trainer(diffusion_model, None, args.ema_decay, args.lr,
                    args.gradient_accumulate_every, args.step_start_ema,
                    args.update_ema_every, args.log_freq)

    if args.pretrain_cnn_path:
        net_data = torch.load(args.pretrain_cnn_path)
        model.model.load_state_dict(net_data)
        model.ema_model.load_state_dict(net_data)

    if args.distributed:
        if args.gpu is not None:
            model.model.cuda(args.gpu)
            model.ema_model.cuda(args.gpu)
            model.model = torch.nn.parallel.DistributedDataParallel(
                model.model, device_ids=[args.gpu], find_unused_parameters=True)
            model.ema_model = torch.nn.parallel.DistributedDataParallel(
                model.ema_model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.model.cuda()
            model.ema_model.cuda()
            model.model = torch.nn.parallel.DistributedDataParallel(
                model.model, find_unused_parameters=True)
            model.ema_model = torch.nn.parallel.DistributedDataParallel(
                model.ema_model, find_unused_parameters=True)
    elif args.gpu is not None:
        model.model = model.model.cuda(args.gpu)
        model.ema_model = model.ema_model.cuda(args.gpu)
    else:
        model.model = torch.nn.DataParallel(model.model).cuda()
        model.ema_model = torch.nn.DataParallel(model.ema_model).cuda()

    if args.resume:
        checkpoint_path = args.checkpoint_diff
        if checkpoint_path:
            print("=> loading checkpoint '{}'".format(checkpoint_path), args)
            checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(args.rank))
            args.start_epoch = checkpoint["epoch"]
            model.model.load_state_dict(checkpoint["model"], strict=True)
            model.ema_model.load_state_dict(checkpoint["ema"], strict=True)
            model.step = checkpoint["step"]
        else:
            raise ValueError('args.resume is set but args.checkpoint_diff is empty')

    if args.cudnn_benchmark:
        cudnn.benchmark = True

    acc_top1_reduced_sum = []
    trajectory_success_rate_meter_reduced_sum = []
    MIoU1_meter_reduced_sum = []
    MIoU2_meter_reduced_sum = []
    acc_a0_reduced_sum = []
    acc_aT_reduced_sum = []

    # Diffusion sampling is stochastic: evaluate 10 times under different seeds
    # to obtain a spread of metrics.
    test_times = 10
    for epoch in range(0, test_times):
        random.seed(epoch)
        np.random.seed(epoch)
        torch.manual_seed(epoch)
        torch.cuda.manual_seed_all(epoch)

        acc_top1, trajectory_success_rate_meter, MIoU1_meter, MIoU2_meter, acc_a0, acc_aT = \
            test(test_loader, model.ema_model, args)

        acc_top1_reduced = reduce_tensor(acc_top1.cuda()).item()
        trajectory_success_rate_meter_reduced = reduce_tensor(trajectory_success_rate_meter.cuda()).item()
        MIoU1_meter_reduced = reduce_tensor(MIoU1_meter.cuda()).item()
        MIoU2_meter_reduced = reduce_tensor(MIoU2_meter.cuda()).item()
        acc_a0_reduced = reduce_tensor(acc_a0.cuda()).item()
        acc_aT_reduced = reduce_tensor(acc_aT.cuda()).item()

        acc_top1_reduced_sum.append(acc_top1_reduced)
        trajectory_success_rate_meter_reduced_sum.append(trajectory_success_rate_meter_reduced)
        MIoU1_meter_reduced_sum.append(MIoU1_meter_reduced)
        MIoU2_meter_reduced_sum.append(MIoU2_meter_reduced)
        acc_a0_reduced_sum.append(acc_a0_reduced)
        acc_aT_reduced_sum.append(acc_aT_reduced)
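    # Report the seeded run with the best trajectory success rate; Acc@1 and
    # MIoU2 are taken from that same run for consistency.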
    if args.rank == 0:
        max_v = max(trajectory_success_rate_meter_reduced_sum)
        max_ind = trajectory_success_rate_meter_reduced_sum.index(max_v)
        print('Val/EpochAcc@1', acc_top1_reduced_sum[max_ind])
        print('Val/Traj_Success_Rate', max_v)
        print('Val/MIoU2', MIoU2_meter_reduced_sum[max_ind])


if __name__ == "__main__":
    main()