first commit
commit 8f8cf48929
2819 changed files with 33143 additions and 0 deletions

utils/accuracy.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import torch


def accuracy(output, target, topk=(1,), max_traj_len=0):
    """Top-k accuracy plus planning metrics (success rate, MIoU, first/last-step accuracy).

    output: [bs*T, action_dim] prediction scores; target: [bs*T] ground-truth action ids;
    max_traj_len: trajectory length T used to regroup the flattened batch.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # [k, bs*T]
        correct = pred.eq(target.view(1, -1).expand_as(pred))  # [k, bs*T]

        # accuracy of the first and the last action of each trajectory
        correct_a = correct[:1].view(-1, max_traj_len)  # [bs, T]
        correct_a0 = correct_a[:, 0].reshape(-1).float().mean().mul_(100.0)
        correct_aT = correct_a[:, -1].reshape(-1).float().mean().mul_(100.0)

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))

        correct_1 = correct[:1]  # [1, bs*T]

        # Success Rate: a trajectory counts as a success only if every step is correct
        trajectory_success = torch.all(correct_1.view(correct_1.shape[1] // max_traj_len, -1), dim=1)
        trajectory_success_rate = trajectory_success.sum() * 100.0 / trajectory_success.shape[0]

        # MIoU1: IoU between the sets of predicted and ground-truth actions over the whole batch
        _, pred_token = output.topk(1, 1, True, True)  # [bs*T, 1]
        pred_inst = pred_token.view(correct_1.shape[1], -1)  # [bs*T, 1]
        pred_inst_set = set()
        target_inst = target.view(correct_1.shape[1], -1)  # [bs*T, 1]
        target_inst_set = set()
        for i in range(pred_inst.shape[0]):
            pred_inst_set.add(tuple(pred_inst[i].tolist()))
            target_inst_set.add(tuple(target_inst[i].tolist()))
        MIoU1 = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(pred_inst_set.union(target_inst_set))

        # MIoU2: per-trajectory IoU, averaged over the trajectories of the batch
        batch_size = batch_size // max_traj_len
        pred_inst = pred_token.view(batch_size, -1)  # [bs, T]
        pred_inst_set = set()
        target_inst = target.view(batch_size, -1)  # [bs, T]
        target_inst_set = set()
        MIoU_sum = 0
        for i in range(pred_inst.shape[0]):
            pred_inst_set.update(pred_inst[i].tolist())
            target_inst_set.update(target_inst[i].tolist())
            MIoU_current = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(
                pred_inst_set.union(target_inst_set))
            MIoU_sum += MIoU_current
            pred_inst_set.clear()
            target_inst_set.clear()

        MIoU2 = MIoU_sum / batch_size
        return res, trajectory_success_rate, MIoU1, MIoU2, correct_a0, correct_aT


def similarity_score(pred, act_emb, metric='cos'):
    # pred: [bs*T, 512], where 512 is the action embedding dimension
    sim_score = torch.zeros((pred.shape[0], act_emb.shape[0])).cuda()
    if metric == 'cos':
        cos = torch.nn.CosineSimilarity(dim=0, eps=1e-8)
        for i in range(sim_score.shape[0]):
            for j in range(act_emb.shape[0]):
                sim_score[i][j] = cos(pred[i], act_emb[j].cuda())

    sim_score = torch.abs(sim_score)
    return sim_score.cpu()
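
Usage sketch (not part of the commit): a minimal call of accuracy() on dummy logits, assuming two trajectories of horizon 3 over five action classes; shapes follow the [bs*T, action_dim] convention the function expects.

import torch
from utils.accuracy import accuracy

logits = torch.randn(6, 5)           # [bs*T, action_dim] with bs=2, T=3
labels = torch.randint(0, 5, (6,))   # [bs*T] ground-truth action ids
(top1, top5), sr, miou1, miou2, a0, aT = accuracy(logits, labels, topk=(1, 5), max_traj_len=3)
print(top1.item(), sr.item(), miou1, miou2)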

utils/args.py (new file, 191 lines)
@@ -0,0 +1,191 @@
import argparse


def get_args(description='whl'):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--act_emb_path', type=str, default='dataset/coin/steps_info.pickle',
                        help='action embedding path')
    parser.add_argument('--checkpoint_mlp', type=str, default='',
                        help='checkpoint path for task prediction model')
    parser.add_argument('--checkpoint_diff', type=str, default='',
                        help='checkpoint path for diffusion model')
    parser.add_argument('--mask_type', type=str, default='multi_add',
                        help='action embedding mask type: single_add or multi_add')
    parser.add_argument('--attn', type=str, default='attention',
                        help='WithAttention: unet with attention. NoAttention: unet without attention.')
    # store_true instead of type=bool: argparse would treat any non-empty string as True
    parser.add_argument('--infer_avg_mask', action='store_true',
                        help='use the average mask for inference')
    parser.add__argument = None  # noqa placeholder removed
    parser.add_argument('--use_cls_mask', action='store_true',
                        help='use the class label in the diffusion mask')
    parser.add_argument('--checkpoint_root', type=str, default='checkpoint',
                        help='checkpoint dir root')
    parser.add_argument('--log_root', type=str, default='log',
                        help='log dir root')
    parser.add_argument('--checkpoint_dir', type=str, default='',
                        help='checkpoint model folder')
    parser.add_argument('--optimizer', type=str, default='adam',
                        help='optimization algorithm')
    parser.add_argument('--num_thread_reader', type=int, default=40,
                        help='number of dataloader worker threads')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=1024,
                        help='batch size for evaluation')
    parser.add_argument('--pretrain_cnn_path', type=str, default='',
                        help='path to a pretrained CNN backbone')
    parser.add_argument('--momemtum', type=float, default=0.9,
                        help='SGD momentum')
    parser.add_argument('--log_freq', type=int, default=500,
                        help='log every this many steps')
    parser.add_argument('--save_freq', type=int, default=1,
                        help='save every this many epochs')
    parser.add_argument('--gradient_accumulate_every', type=int, default=1,
                        help='gradient accumulation steps')
    parser.add_argument('--ema_decay', type=float, default=0.995,
                        help='EMA decay rate')
    parser.add_argument('--step_start_ema', type=int, default=400,
                        help='step at which EMA updates start')
    parser.add_argument('--update_ema_every', type=int, default=10,
                        help='EMA update interval in steps')
    parser.add_argument('--crop_only', type=int, default=1,
                        help='use random crop only')
    parser.add_argument('--centercrop', type=int, default=0,
                        help='use center crop')
    parser.add_argument('--random_flip', type=int, default=1,
                        help='use random horizontal flip')
    parser.add_argument('--verbose', type=int, default=1,
                        help='verbosity level')
    parser.add_argument('--fps', type=int, default=1,
                        help='frames per second')
    parser.add_argument('--cudnn_benchmark', type=int, default=0,
                        help='enable cudnn benchmark mode')
    parser.add_argument('--horizon', type=int, default=3,
                        help='planning horizon T')
    parser.add_argument('--dataset', type=str, default='coin',
                        help='dataset')
    parser.add_argument('--action_dim', type=int, default=778,
                        help='number of action classes')
    parser.add_argument('--observation_dim', type=int, default=1536,
                        help='dimension of the visual observation features')
    parser.add_argument('--class_dim', type=int, default=180,
                        help='number of task classes')
    parser.add_argument('--n_diffusion_steps', type=int, default=200,
                        help='number of diffusion steps')
    parser.add_argument('--n_train_steps', type=int, default=200,
                        help='training steps per epoch')
    parser.add_argument('--root', type=str, default='',
                        help='root path of dataset crosstask')
    parser.add_argument('--json_path_train', type=str, default='dataset/coin/train_split_T4.json',
                        help='path of the generated json file for train')
    parser.add_argument('--json_path_val', type=str, default='dataset/coin/coin_mlp_T4.json',
                        help='path of the generated json file for val')
    parser.add_argument('--epochs', default=800, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--lr', '--learning-rate', default=1e-5, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--resume', dest='resume', action='store_true',
                        help='resume training from the last checkpoint')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--pin_memory', dest='pin_memory', action='store_true',
                        help='use pin_memory in the dataloader')
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--dist-file', default='dist-file', type=str,
                        help='file used to set up distributed training')
    parser.add_argument('--dist-url', default='tcp://localhost:20000', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--seed', default=217, type=int,
                        help='seed for initializing training')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use')
    parser.add_argument('--multiprocessing-distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')
    args = parser.parse_args()
    return args
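
Usage sketch (not part of the commit): get_args() parses sys.argv, so a typical entry point calls it once and threads the namespace through the rest of the pipeline.

from utils.args import get_args

if __name__ == '__main__':
    args = get_args()
    # the diffusion state vector is the concatenation of task, action, and observation blocks
    print(args.dataset, args.horizon, args.class_dim + args.action_dim + args.observation_dim)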

utils/eval.py (new file, 376 lines)
@@ -0,0 +1,376 @@
from .accuracy import *
from model.helpers import AverageMeter


def validate_act_noise(val_loader, model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()

    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        # compute output
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            temp = task_onehot.unsqueeze(1)
            task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]

            cond['task'] = task_class_

            video_label_reshaped = video_label.view(-1)

            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            x_start = torch.zeros((batch_size_current, T, args.class_dim + args.action_dim + args.observation_dim))
            x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
            x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]

            x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
            x_start[:, :, :args.class_dim] = task_class_

            output = model(cond, act_emb, task_class, if_jump=False, if_avg_mask=args.infer_avg_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)


def validate_act(val_loader, model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()

    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        # compute output
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            temp = task_onehot.unsqueeze(1)
            task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]
            if args.use_cls_mask:
                cond['task'] = task_class_

            video_label_reshaped = video_label.view(-1)

            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            if args.use_cls_mask:
                x_start = torch.zeros((batch_size_current, T, args.class_dim + args.action_dim + args.observation_dim))
                x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]

                x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
                x_start[:, :, :args.class_dim] = task_class_
            else:
                x_start = torch.zeros((batch_size_current, T, args.action_dim + args.observation_dim))
                x_start[:, 0, args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.action_dim:] = global_img_tensors[:, -1, :]

                x_start[:, :, :args.action_dim] = action_label_onehot

            output = model(cond, act_emb, task_class, if_jump=False, if_avg_mask=args.infer_avg_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)


def validate(val_loader, model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()

    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        # compute output
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            temp = task_onehot.unsqueeze(1)
            task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]
            if args.use_cls_mask:
                cond['task'] = task_class_

            video_label_reshaped = video_label.view(-1)

            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            if args.use_cls_mask:
                x_start = torch.zeros((batch_size_current, T, args.class_dim + args.action_dim + args.observation_dim))
                x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]

                x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
                x_start[:, :, :args.class_dim] = task_class_
            else:
                x_start = torch.zeros((batch_size_current, T, args.action_dim + args.observation_dim))
                x_start[:, 0, args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.action_dim:] = global_img_tensors[:, -1, :]

                x_start[:, :, :args.action_dim] = action_label_onehot

            output = model(cond, task_class, act_emb, if_jump=False, if_avg_mask=args.infer_avg_mask, cond_type=args.use_cls_mask)

            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            if args.use_cls_mask:
                actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
                actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]
            else:
                actions_pred = actions_pred[:, :, :args.action_dim].contiguous()
                actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)


def validate_mlp(val_loader, model, act_model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()

    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        # compute output
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            # observations are reshaped to [bs, 32, 48] grids for the mlp variant
            cond[0] = global_img_tensors[:, 0, :].view((batch_size_current, 32, 48))
            cond[T - 1] = global_img_tensors[:, -1, :].view((batch_size_current, 32, 48))

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            temp = task_onehot.unsqueeze(1)
            task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]

            video_label_reshaped = video_label.view(-1)

            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            x_start = torch.zeros((batch_size_current, T, 32, 64))
            x_start[:, 0, :, :48] = global_img_tensors[:, 0, :].view((batch_size_current, 32, 48))
            x_start[:, -1, :, :48] = global_img_tensors[:, -1, :].view((batch_size_current, 32, 48))

            # the last 16 columns of each step carry the 512-d action embedding, reshaped to [32, 16]
            for i in range(video_label.shape[0]):
                for j in range(video_label.shape[1]):
                    x_start[i][j][:, 48:] = act_emb[video_label[i][j]].view(32, 16)

            output = model(cond, task_class, act_emb, if_jump=False, if_avg_mask=args.infer_avg_mask, cond_type=args.use_cls_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            actions_pred = output[:, :, :, 48:].contiguous().view(batch_size_current * args.horizon, 512)
            # score predicted embeddings against the action-embedding table and rank them
            sim_act_pred = similarity_score(actions_pred, act_emb)

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(sim_act_pred, video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
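
For reference (not part of the commit): eval.py and the trainers import AverageMeter from model.helpers, which this diff does not show. A minimal sketch of the count-weighted running average they appear to assume:

class AverageMeter(object):
    """Keeps the latest value and a count-weighted running average."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count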

utils/training.py (new file, 143 lines)
@@ -0,0 +1,143 @@
import copy
from model.helpers import AverageMeter
from .accuracy import *
import numpy as np


def cycle(dl):
    while True:
        for data in dl:
            yield data


class EMA():
    """
    exponential moving average of model weights
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


class Trainer(object):
    def __init__(
            self,
            diffusion_model,
            datasetloader,
            ema_decay=0.995,
            train_lr=1e-5,
            gradient_accumulate_every=1,
            step_start_ema=400,
            update_ema_every=10,
            log_freq=100,
    ):
        super().__init__()
        self.model = diffusion_model
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.model)
        self.update_ema_every = update_ema_every

        self.step_start_ema = step_start_ema
        self.log_freq = log_freq
        self.gradient_accumulate_every = gradient_accumulate_every

        self.dataloader = cycle(datasetloader)
        self.optimizer = torch.optim.AdamW(diffusion_model.parameters(), lr=train_lr, weight_decay=0.0)
        # self.optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, diffusion_model.parameters()), lr=train_lr, weight_decay=0.0)

        self.reset_parameters()
        self.step = 0

    def reset_parameters(self):
        self.ema_model.load_state_dict(self.model.state_dict())

    def step_ema(self):
        if self.step < self.step_start_ema:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.model)

    # -----------------------------------------------------------------------------#
    # ------------------------------------ api ------------------------------------#
    # -----------------------------------------------------------------------------#

    def train(self, n_train_steps, if_calculate_acc, args, scheduler):
        self.model.train()
        self.ema_model.train()
        losses = AverageMeter()
        self.optimizer.zero_grad()

        for step in range(n_train_steps):
            for i in range(self.gradient_accumulate_every):
                batch = next(self.dataloader)
                bs, T = batch[1].shape  # batch[1]: [bs, T] action labels
                global_img_tensors = batch[0].cuda().contiguous().float()
                img_tensors = torch.zeros((bs, T, args.class_dim + args.action_dim + args.observation_dim))
                img_tensors[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                img_tensors[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
                img_tensors = img_tensors.cuda()
                # dump first/last-frame features and labels for the conformal experiments
                np.save('conformal/X_train.npy', global_img_tensors[:, 0, :].cpu().detach().numpy())
                np.save('conformal/X_end_train.npy', global_img_tensors[:, -1, :].cpu().detach().numpy())

                video_label = batch[1].view(-1).cuda()  # [bs*T]
                np.save('conformal/label_train.npy', video_label.cpu().detach().numpy())
                task_class = batch[2].view(-1).cuda()  # [bs]

                action_label_onehot = torch.zeros((video_label.size(0), self.model.module.action_dim))
                # [bs*T, action_dim]
                ind = torch.arange(0, len(video_label))
                action_label_onehot[ind, video_label] = 1.
                action_label_onehot = action_label_onehot.reshape(bs, T, -1).cuda()
                img_tensors[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot

                task_onehot = torch.zeros((task_class.size(0), args.class_dim))
                # [bs, class_dim]
                ind = torch.arange(0, len(task_class))
                task_onehot[ind, task_class] = 1.
                task_onehot = task_onehot.cuda()
                temp = task_onehot.unsqueeze(1)
                task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]
                img_tensors[:, :, :args.class_dim] = task_class_

                cond = {0: global_img_tensors[:, 0, :].float(), T - 1: global_img_tensors[:, -1, :].float(),
                        'task': task_class_}
                x = img_tensors.float()
                loss = self.model.module.loss(x, cond)
                loss = loss / self.gradient_accumulate_every
                loss.backward()
                losses.update(loss.item(), bs)

            self.optimizer.step()
            self.optimizer.zero_grad()
            scheduler.step()

            if self.step % self.update_ema_every == 0:
                self.step_ema()
            self.step += 1

        if if_calculate_acc:
            with torch.no_grad():
                output = self.ema_model(cond)
                actions_pred = output[:, :, args.class_dim:args.class_dim + self.model.module.action_dim] \
                    .contiguous().view(-1, self.model.module.action_dim)  # [bs*T, action_dim]

                (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                    accuracy(actions_pred.cpu(), video_label.cpu(), topk=(1, 5),
                             max_traj_len=self.model.module.horizon)

                return torch.tensor(losses.avg), acc1, acc5, torch.tensor(trajectory_success_rate), \
                    torch.tensor(MIoU1), torch.tensor(MIoU2), a0_acc, aT_acc

        else:
            return torch.tensor(losses.avg)
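
Usage sketch (not part of the commit): a hypothetical driver, assuming model is the diffusion model wrapped in DistributedDataParallel (Trainer reaches into model.module) and loader yields (frames, action_labels, task_labels) batches.

import torch
from utils.training import Trainer

trainer = Trainer(model, loader, ema_decay=args.ema_decay, train_lr=args.lr,
                  gradient_accumulate_every=args.gradient_accumulate_every,
                  step_start_ema=args.step_start_ema, update_ema_every=args.update_ema_every)
# scheduler.step() fires once per optimizer step inside train(), so size T_max in steps
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(trainer.optimizer,
                                                       T_max=args.epochs * args.n_train_steps)
for epoch in range(args.epochs):
    avg_loss = trainer.train(args.n_train_steps, if_calculate_acc=False, args=args, scheduler=scheduler)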

utils/training_act.py (new file, 170 lines)
@@ -0,0 +1,170 @@
import copy
from model.helpers import AverageMeter
from .accuracy import *
import numpy as np


def cycle(dl):
    while True:
        for data in dl:
            yield data


class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


class Trainer(object):
    def __init__(
            self,
            diffusion_model,
            datasetloader,
            ema_decay=0.995,
            train_lr=1e-5,
            gradient_accumulate_every=1,
            step_start_ema=400,
            update_ema_every=10,
            log_freq=100,
            act_emd=None,
    ):
        super().__init__()
        self.model = diffusion_model
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.model)
        self.update_ema_every = update_ema_every

        self.step_start_ema = step_start_ema
        self.log_freq = log_freq
        self.gradient_accumulate_every = gradient_accumulate_every

        self.act_emd = act_emd

        self.dataloader = cycle(datasetloader)
        self.optimizer = torch.optim.AdamW(diffusion_model.parameters(), lr=train_lr, weight_decay=0.0)

        self.reset_parameters()
        self.step = 0

    def reset_parameters(self):
        self.ema_model.load_state_dict(self.model.state_dict())

    def step_ema(self):
        if self.step < self.step_start_ema:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.model)

    def get_noise_mask(self, action_label, args, img_tensors):
        # start from pure Gaussian noise, then write action embeddings into random positions
        output_act_emb = torch.randn_like(img_tensors).cuda()
        self.act_emd = self.act_emd.cuda()
        np.random.seed(args.seed)
        if args.mask_type == 'single_add':
            for i in range(action_label.shape[0]):
                for j in range(action_label.shape[1]):
                    if args.dataset == 'crosstask' or args.dataset == 'NIV':
                        # NIV and CrossTask: the 512-d embedding may land anywhere in the vector
                        rnd_idx = np.random.randint(0, args.class_dim + args.action_dim + args.observation_dim - 512)
                    if args.dataset == 'coin':
                        rnd_idx = np.random.randint(args.class_dim, args.class_dim + args.action_dim - 512)
                    output_act_emb[i][j][rnd_idx:rnd_idx + 512] = self.act_emd[action_label[i][j]]
            return output_act_emb.cuda()

        if args.mask_type == 'multi_add':
            for i in range(action_label.shape[0]):
                for j in range(action_label.shape[1]):
                    if args.dataset == 'crosstask' or args.dataset == 'NIV':
                        # sample action_dim coordinates of the embedding and accumulate them over steps
                        rnd_idx = np.random.randint(0, 512, args.action_dim)
                        if j == 0:
                            output_act_emb[i][j][args.class_dim:args.class_dim + args.action_dim] = \
                                self.act_emd[action_label[i][j], rnd_idx]
                        else:
                            output_act_emb[i][j][args.class_dim:args.class_dim + args.action_dim] = \
                                output_act_emb[i][j - 1][args.class_dim:args.class_dim + args.action_dim] + \
                                self.act_emd[action_label[i][j], rnd_idx]
                    if args.dataset == 'coin':
                        rnd_idx = np.random.randint(args.class_dim, args.class_dim + args.action_dim - 512)
                        if j == 0:
                            output_act_emb[i][j][rnd_idx:rnd_idx + 512] = self.act_emd[action_label[i][j]]
                        else:
                            output_act_emb[i][j][rnd_idx:rnd_idx + 512] = \
                                output_act_emb[i][j - 1][rnd_idx:rnd_idx + 512] + self.act_emd[action_label[i][j]]

            return output_act_emb.cuda()

    def train(self, n_train_steps, if_calculate_acc, args, scheduler):
        self.model.train()
        self.ema_model.train()
        losses = AverageMeter()
        self.optimizer.zero_grad()

        for step in range(n_train_steps):
            for i in range(self.gradient_accumulate_every):
                batch = next(self.dataloader)
                bs, T = batch[1].shape  # batch[1]: [bs, T] action labels
                global_img_tensors = batch[0].cuda().contiguous().float()
                img_tensors = torch.zeros((bs, T, args.class_dim + args.action_dim + args.observation_dim))
                img_tensors[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                img_tensors[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
                img_tensors = img_tensors.cuda()

                video_label = batch[1].view(-1).cuda()  # [bs*T]

                task_class = batch[2].view(-1).cuda()  # [bs]

                action_label = batch[1]

                act_emb_noise = self.get_noise_mask(action_label, args, img_tensors)

                action_label_onehot = torch.zeros((video_label.size(0), self.model.module.action_dim))
                # [bs*T, action_dim]
                ind = torch.arange(0, len(video_label))
                action_label_onehot[ind, video_label] = 1.
                action_label_onehot = action_label_onehot.reshape(bs, T, -1).cuda()
                img_tensors[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot

                task_onehot = torch.zeros((task_class.size(0), args.class_dim))
                # [bs, class_dim]
                ind = torch.arange(0, len(task_class))
                task_onehot[ind, task_class] = 1.
                task_onehot = task_onehot.cuda()
                temp = task_onehot.unsqueeze(1)
                task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]
                img_tensors[:, :, :args.class_dim] = task_class_

                cond = {0: global_img_tensors[:, 0, :].float(), T - 1: global_img_tensors[:, -1, :].float(),
                        'task': task_class_}
                x = img_tensors.float()

                loss = self.model.module.loss(x, cond, act_emb_noise, task_class)
                loss = loss / self.gradient_accumulate_every
                loss.backward()
                losses.update(loss.item(), bs)

            self.optimizer.step()
            self.optimizer.zero_grad()
            scheduler.step()

            if self.step % self.update_ema_every == 0:
                self.step_ema()
            self.step += 1

        if if_calculate_acc:
            with torch.no_grad():
                output = self.ema_model(cond, self.act_emd, task_class, if_jump=False, if_avg_mask=args.infer_avg_mask)
                actions_pred = output[:, :, args.class_dim:args.class_dim + self.model.module.action_dim] \
                    .contiguous().view(-1, self.model.module.action_dim)  # [bs*T, action_dim]

                (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                    accuracy(actions_pred.cpu(), video_label.cpu(), topk=(1, 5),
                             max_traj_len=self.model.module.horizon)

                return torch.tensor(losses.avg), acc1, acc5, torch.tensor(trajectory_success_rate), \
                    torch.tensor(MIoU1), torch.tensor(MIoU2), a0_acc, aT_acc

        else:
            return torch.tensor(losses.avg)
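
Usage sketch (not part of the commit): the act variant additionally takes the table of 512-d step embeddings via act_emd; the layout of steps_info.pickle is an assumption here.

import pickle
import torch
from utils.training_act import Trainer

with open(args.act_emb_path, 'rb') as f:
    act_emb = pickle.load(f)  # assumed: an [n_actions, 512] tensor of step embeddings
trainer = Trainer(model, loader, act_emd=act_emb)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(trainer.optimizer,
                                                       T_max=args.epochs * args.n_train_steps)
for epoch in range(args.epochs):
    avg_loss = trainer.train(args.n_train_steps, if_calculate_acc=False, args=args, scheduler=scheduler)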