first commit

Lei Shi 2024-12-02 15:42:58 +01:00
commit 8f8cf48929
2819 changed files with 33143 additions and 0 deletions

utils/accuracy.py (new file, +73 lines)

@@ -0,0 +1,73 @@
import torch


def accuracy(output, target, topk=(1,), max_traj_len=0):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # [k, bs*T]
        correct = pred.eq(target.view(1, -1).expand_as(pred))  # [k, bs*T]

        # Accuracy of the first (a0) and last (aT) action of each trajectory.
        correct_a = correct[:1].view(-1, max_traj_len)  # [bs, T]
        correct_a0 = correct_a[:, 0].reshape(-1).float().mean().mul_(100.0)
        correct_aT = correct_a[:, -1].reshape(-1).float().mean().mul_(100.0)

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))

        correct_1 = correct[:1]  # [1, bs*T]

        # Success Rate: a trajectory counts as successful only if every
        # step in it is predicted correctly.
        trajectory_success = torch.all(correct_1.view(correct_1.shape[1] // max_traj_len, -1), dim=1)
        trajectory_success_rate = trajectory_success.sum() * 100.0 / trajectory_success.shape[0]

        # MIoU1: IoU between the sets of predicted and ground-truth actions,
        # pooled over the whole batch.
        _, pred_token = output.topk(1, 1, True, True)  # [bs*T, 1]
        pred_inst = pred_token.view(correct_1.shape[1], -1)  # [bs*T, 1]
        pred_inst_set = set()
        target_inst = target.view(correct_1.shape[1], -1)  # [bs*T, 1]
        target_inst_set = set()
        for i in range(pred_inst.shape[0]):
            pred_inst_set.add(tuple(pred_inst[i].tolist()))
            target_inst_set.add(tuple(target_inst[i].tolist()))
        MIoU1 = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(pred_inst_set.union(target_inst_set))

        # MIoU2: IoU computed per trajectory, then averaged over trajectories.
        batch_size = batch_size // max_traj_len
        pred_inst = pred_token.view(batch_size, -1)  # [bs, T]
        pred_inst_set = set()
        target_inst = target.view(batch_size, -1)  # [bs, T]
        target_inst_set = set()
        MIoU_sum = 0
        for i in range(pred_inst.shape[0]):
            pred_inst_set.update(pred_inst[i].tolist())
            target_inst_set.update(target_inst[i].tolist())
            MIoU_current = 100.0 * len(pred_inst_set.intersection(target_inst_set)) / len(
                pred_inst_set.union(target_inst_set))
            MIoU_sum += MIoU_current
            pred_inst_set.clear()
            target_inst_set.clear()
        MIoU2 = MIoU_sum / batch_size

    return res, trajectory_success_rate, MIoU1, MIoU2, correct_a0, correct_aT
def similarity_score(pred, act_emb, metric='cos'):
    # pred: [bs*T, 512], where 512 is the action embedding dimension.
    # Returns the pairwise similarity between each prediction and each
    # entry of the action embedding table.
    sim_score = torch.zeros((pred.shape[0], act_emb.shape[0])).cuda()
    if metric == 'cos':
        cos = torch.nn.CosineSimilarity(dim=0, eps=1e-8)
        for i in range(sim_score.shape[0]):
            for j in range(act_emb.shape[0]):
                sim_score[i][j] = cos(pred[i], act_emb[j].cuda())
        sim_score = torch.abs(sim_score)
    return sim_score.cpu()
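For orientation, a minimal sketch of how accuracy expects its inputs to be shaped; the dummy logits and labels below are illustrative only, not part of the commit:

import torch
from utils.accuracy import accuracy

bs, T, n_actions = 4, 3, 10                      # 4 trajectories, 3 steps each
logits = torch.randn(bs * T, n_actions)          # per-step action logits, flattened over time
labels = torch.randint(0, n_actions, (bs * T,))  # ground-truth action ids
(top1, top5), sr, miou1, miou2, a0, aT = accuracy(logits, labels, topk=(1, 5), max_traj_len=T)
print(top1.item(), sr.item(), miou1, miou2)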

utils/args.py (new file, +191 lines)

@@ -0,0 +1,191 @@
import argparse


def get_args(description='whl'):
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('--act_emb_path', type=str, default='dataset/coin/steps_info.pickle',
                        help='action embedding path')
    parser.add_argument('--checkpoint_mlp', type=str, default='',
                        help='checkpoint path for task prediction model')
    parser.add_argument('--checkpoint_diff', type=str, default='',
                        help='checkpoint path for diffusion model')
    parser.add_argument('--mask_type', type=str, default='multi_add',
                        help="action embedding mask type: 'single_add' or 'multi_add'")
    parser.add_argument('--attn', type=str, default='attention',
                        help='WithAttention: unet with attention. NoAttention: unet without attention.')
    # These two flags were declared with type=bool, which argparse evaluates as
    # bool(str) and is therefore True for any non-empty value; store_true flags
    # give the intended behavior with the same False default.
    parser.add_argument('--infer_avg_mask', action='store_true',
                        help='use average mask for inference')
    parser.add_argument('--use_cls_mask', action='store_true',
                        help='use class label in the diffusion mask')
    parser.add_argument('--checkpoint_root', type=str, default='checkpoint',
                        help='checkpoint dir root')
    parser.add_argument('--log_root', type=str, default='log',
                        help='log dir root')
    parser.add_argument('--checkpoint_dir', type=str, default='',
                        help='checkpoint model folder')
    parser.add_argument('--optimizer', type=str, default='adam',
                        help='optimization algorithm')
    parser.add_argument('--num_thread_reader', type=int, default=40,
                        help='number of dataloader workers')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch size')
    parser.add_argument('--batch_size_val', type=int, default=1024,
                        help='batch size for evaluation')
    parser.add_argument('--pretrain_cnn_path', type=str, default='', help='')
    parser.add_argument('--momemtum', type=float, default=0.9,
                        help='SGD momentum')
    parser.add_argument('--log_freq', type=int, default=500,
                        help='log every N steps')
    parser.add_argument('--save_freq', type=int, default=1,
                        help='save a checkpoint every N epochs')
    parser.add_argument('--gradient_accumulate_every', type=int, default=1,
                        help='gradient accumulation steps')
    parser.add_argument('--ema_decay', type=float, default=0.995,
                        help='EMA decay rate')
    parser.add_argument('--step_start_ema', type=int, default=400,
                        help='step at which EMA updates start')
    parser.add_argument('--update_ema_every', type=int, default=10,
                        help='EMA update interval in steps')
    parser.add_argument('--crop_only', type=int, default=1,
                        help='crop without resizing')
    parser.add_argument('--centercrop', type=int, default=0,
                        help='use center crop')
    parser.add_argument('--random_flip', type=int, default=1,
                        help='use random horizontal flip')
    parser.add_argument('--verbose', type=int, default=1, help='')
    parser.add_argument('--fps', type=int, default=1, help='')
    parser.add_argument('--cudnn_benchmark', type=int, default=0,
                        help='enable torch.backends.cudnn.benchmark')
    parser.add_argument('--horizon', type=int, default=3,
                        help='planning horizon T')
    parser.add_argument('--dataset', type=str, default='coin',
                        help='dataset')
    parser.add_argument('--action_dim', type=int, default=778,
                        help='number of action classes')
    parser.add_argument('--observation_dim', type=int, default=1536,
                        help='observation feature dimension')
    parser.add_argument('--class_dim', type=int, default=180,
                        help='number of task classes')
    parser.add_argument('--n_diffusion_steps', type=int, default=200,
                        help='number of diffusion steps')
    parser.add_argument('--n_train_steps', type=int, default=200,
                        help='training steps per epoch')
    parser.add_argument('--root', type=str, default='',
                        help='root path of dataset crosstask')
    parser.add_argument('--json_path_train', type=str, default='dataset/coin/train_split_T4.json',
                        help='path of the generated json file for train')
    parser.add_argument('--json_path_val', type=str, default='dataset/coin/coin_mlp_T4.json',
                        help='path of the generated json file for val')
    parser.add_argument('--epochs', default=800, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--lr', '--learning-rate', default=1e-5, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--resume', dest='resume', action='store_true',
                        help='resume training from last checkpoint')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--pin_memory', dest='pin_memory', action='store_true',
                        help='use pin_memory in the dataloader')
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--dist-file', default='dist-file', type=str,
                        help='file used to set up distributed training')
    parser.add_argument('--dist-url', default='tcp://localhost:20000', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--seed', default=217, type=int,
                        help='seed for initializing training')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use')
    parser.add_argument('--multiprocessing-distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')
    args = parser.parse_args()
    return args
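A quick sketch of how these flags are consumed; the script name below is hypothetical:

from utils.args import get_args

# e.g. python train_coin.py --use_cls_mask --horizon 3 --batch_size 256
args = get_args()
print(args.dataset, args.horizon, args.class_dim + args.action_dim + args.observation_dim)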

utils/eval.py (new file, +376 lines)

@@ -0,0 +1,376 @@
import torch

from .accuracy import *
from model.helpers import AverageMeter


def validate_act_noise(val_loader, model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()
    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            # Condition on the first and last visual observations.
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            temp = task_onehot.unsqueeze(1)
            task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]
            cond['task'] = task_class_

            video_label_reshaped = video_label.view(-1)
            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            # Ground truth layout per step: [task one-hot | action one-hot | observation].
            x_start = torch.zeros((batch_size_current, T, args.class_dim + args.action_dim + args.observation_dim))
            x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
            x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
            x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
            x_start[:, :, :args.class_dim] = task_class_

            output = model(cond, act_emb, task_class, if_jump=False, if_avg_mask=args.infer_avg_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
def validate_act(val_loader, model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()
    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            temp = task_onehot.unsqueeze(1)
            task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]
            if args.use_cls_mask:
                cond['task'] = task_class_

            video_label_reshaped = video_label.view(-1)
            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            if args.use_cls_mask:
                x_start = torch.zeros((batch_size_current, T, args.class_dim + args.action_dim + args.observation_dim))
                x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
                x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
                x_start[:, :, :args.class_dim] = task_class_
            else:
                x_start = torch.zeros((batch_size_current, T, args.action_dim + args.observation_dim))
                x_start[:, 0, args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.action_dim:] = global_img_tensors[:, -1, :]
                x_start[:, :, :args.action_dim] = action_label_onehot

            output = model(cond, act_emb, task_class, if_jump=False, if_avg_mask=args.infer_avg_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
def validate(val_loader, model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()
    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            temp = task_onehot.unsqueeze(1)
            task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]
            if args.use_cls_mask:
                cond['task'] = task_class_

            video_label_reshaped = video_label.view(-1)
            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            if args.use_cls_mask:
                x_start = torch.zeros((batch_size_current, T, args.class_dim + args.action_dim + args.observation_dim))
                x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
                x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
                x_start[:, :, :args.class_dim] = task_class_
            else:
                x_start = torch.zeros((batch_size_current, T, args.action_dim + args.observation_dim))
                x_start[:, 0, args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.action_dim:] = global_img_tensors[:, -1, :]
                x_start[:, :, :args.action_dim] = action_label_onehot

            output = model(cond, task_class, act_emb, if_jump=False, if_avg_mask=args.infer_avg_mask, cond_type=args.use_cls_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            if args.use_cls_mask:
                actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            else:
                actions_pred = actions_pred[:, :, :args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
def validate_mlp(val_loader, model, act_model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()
    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            # Observations are reshaped from 1536-d vectors into 32x48 maps.
            cond[0] = global_img_tensors[:, 0, :].view((batch_size_current, 32, 48))
            cond[T - 1] = global_img_tensors[:, -1, :].view((batch_size_current, 32, 48))

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            temp = task_onehot.unsqueeze(1)
            task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]

            video_label_reshaped = video_label.view(-1)
            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            # x_start: [bs, T, 32, 64]; columns 0..47 hold the observation,
            # columns 48..63 hold the 512-d action embedding reshaped to 32x16.
            x_start = torch.zeros((batch_size_current, T, 32, 64))
            x_start[:, 0, :, :48] = global_img_tensors[:, 0, :].view((batch_size_current, 32, 48))
            x_start[:, -1, :, :48] = global_img_tensors[:, -1, :].view((batch_size_current, 32, 48))
            for i in range(video_label.shape[0]):
                for j in range(video_label.shape[1]):
                    x_start[i][j][:, 48:] = act_emb[video_label[i][j]].view(32, 16)

            output = model(cond, task_class, act_emb, if_jump=False, if_avg_mask=args.infer_avg_mask, cond_type=args.use_cls_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            # Recover per-step action predictions from the embedding slot and
            # score them against the action embedding table.
            actions_pred = output[:, :, :, 48:].contiguous().view(batch_size_current * args.horizon, 512)
            sim_act_pred = similarity_score(actions_pred, act_emb)

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(sim_act_pred, video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)
    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
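All four validators return the same 8-tuple of scalar tensors; a hedged sketch of consuming it (val_loader, model, args and act_emb are assumed to exist):

loss, top1, top5, sr, miou1, miou2, a0, aT = validate_act(val_loader, model, args, act_emb)
print(f'loss {loss.item():.4f} | top1 {top1.item():.2f} | SR {sr.item():.2f} | MIoU2 {miou2.item():.2f}')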

utils/training.py (new file, +143 lines)

@@ -0,0 +1,143 @@
import copy

import numpy as np
import torch

from model.helpers import AverageMeter
from .accuracy import *


def cycle(dl):
    while True:
        for data in dl:
            yield data


class EMA():
    """
    Exponential moving average of model weights.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


class Trainer(object):
    def __init__(
            self,
            diffusion_model,
            datasetloader,
            ema_decay=0.995,
            train_lr=1e-5,
            gradient_accumulate_every=1,
            step_start_ema=400,
            update_ema_every=10,
            log_freq=100,
    ):
        super().__init__()
        self.model = diffusion_model
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.model)
        self.update_ema_every = update_ema_every
        self.step_start_ema = step_start_ema
        self.log_freq = log_freq
        self.gradient_accumulate_every = gradient_accumulate_every
        self.dataloader = cycle(datasetloader)
        self.optimizer = torch.optim.AdamW(diffusion_model.parameters(), lr=train_lr, weight_decay=0.0)
        # self.optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, diffusion_model.parameters()), lr=train_lr, weight_decay=0.0)
        self.reset_parameters()
        self.step = 0

    def reset_parameters(self):
        self.ema_model.load_state_dict(self.model.state_dict())

    def step_ema(self):
        if self.step < self.step_start_ema:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.model)

    # -----------------------------------------------------------------------------#
    # ------------------------------------ api ------------------------------------#
    # -----------------------------------------------------------------------------#

    def train(self, n_train_steps, if_calculate_acc, args, scheduler):
        self.model.train()
        self.ema_model.train()
        losses = AverageMeter()
        self.optimizer.zero_grad()

        for step in range(n_train_steps):
            for i in range(self.gradient_accumulate_every):
                batch = next(self.dataloader)
                bs, T = batch[1].shape  # labels: [bs, T]
                global_img_tensors = batch[0].cuda().contiguous().float()

                # x layout per step: [task one-hot | action one-hot | observation];
                # only the first and last observations are filled in.
                img_tensors = torch.zeros((bs, T, args.class_dim + args.action_dim + args.observation_dim))
                img_tensors[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                img_tensors[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
                img_tensors = img_tensors.cuda()

                # Dump the first/last observations and labels of the current
                # batch (overwritten every step) for offline analysis.
                np.save('conformal/X_train.npy', global_img_tensors[:, 0, :].cpu().detach().numpy())
                np.save('conformal/X_end_train.npy', global_img_tensors[:, -1, :].cpu().detach().numpy())
                video_label = batch[1].view(-1).cuda()  # [bs*T]
                np.save('conformal/label_train.npy', video_label.cpu().detach().numpy())
                task_class = batch[2].view(-1).cuda()  # [bs]

                action_label_onehot = torch.zeros((video_label.size(0), self.model.module.action_dim))  # [bs*T, action_dim]
                ind = torch.arange(0, len(video_label))
                action_label_onehot[ind, video_label] = 1.
                action_label_onehot = action_label_onehot.reshape(bs, T, -1).cuda()
                img_tensors[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot

                task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
                ind = torch.arange(0, len(task_class))
                task_onehot[ind, task_class] = 1.
                task_onehot = task_onehot.cuda()
                temp = task_onehot.unsqueeze(1)
                task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]
                img_tensors[:, :, :args.class_dim] = task_class_

                cond = {0: global_img_tensors[:, 0, :].float(), T - 1: global_img_tensors[:, -1, :].float(),
                        'task': task_class_}
                x = img_tensors.float()
                loss = self.model.module.loss(x, cond)
                loss = loss / self.gradient_accumulate_every
                loss.backward()
                losses.update(loss.item(), bs)

            self.optimizer.step()
            self.optimizer.zero_grad()
            scheduler.step()

            if self.step % self.update_ema_every == 0:
                self.step_ema()
            self.step += 1

        if if_calculate_acc:
            with torch.no_grad():
                output = self.ema_model(cond)
                actions_pred = output[:, :, args.class_dim:args.class_dim + self.model.module.action_dim] \
                    .contiguous().view(-1, self.model.module.action_dim)  # [bs*T, action_dim]
                (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                    accuracy(actions_pred.cpu(), video_label.cpu(), topk=(1, 5),
                             max_traj_len=self.model.module.horizon)
            return torch.tensor(losses.avg), acc1, acc5, torch.tensor(trajectory_success_rate), \
                torch.tensor(MIoU1), torch.tensor(MIoU2), a0_acc, aT_acc
        else:
            return torch.tensor(losses.avg)
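A hedged sketch of wiring the Trainer up; diffusion_model is assumed to already be wrapped in DistributedDataParallel (the loop accesses self.model.module), and train_loader and the scheduler choice are assumptions, not shown in this commit:

import torch
from utils.training import Trainer

trainer = Trainer(diffusion_model, train_loader,
                  ema_decay=args.ema_decay, train_lr=args.lr,
                  gradient_accumulate_every=args.gradient_accumulate_every,
                  step_start_ema=args.step_start_ema, update_ema_every=args.update_ema_every)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(trainer.optimizer, T_max=args.epochs)
for epoch in range(args.epochs):
    if_calculate_acc = (epoch % args.save_freq == 0)   # evaluate on the last batch now and then
    metrics = trainer.train(args.n_train_steps, if_calculate_acc, args, scheduler)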

utils/training_act.py (new file, +170 lines)

@@ -0,0 +1,170 @@
import copy

import numpy as np
import torch

from model.helpers import AverageMeter
from .accuracy import *


def cycle(dl):
    while True:
        for data in dl:
            yield data


class EMA():
    """
    Exponential moving average of model weights.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


class Trainer(object):
    def __init__(
            self,
            diffusion_model,
            datasetloader,
            ema_decay=0.995,
            train_lr=1e-5,
            gradient_accumulate_every=1,
            step_start_ema=400,
            update_ema_every=10,
            log_freq=100,
            act_emd=None,
    ):
        super().__init__()
        self.model = diffusion_model
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.model)
        self.update_ema_every = update_ema_every
        self.step_start_ema = step_start_ema
        self.log_freq = log_freq
        self.gradient_accumulate_every = gradient_accumulate_every
        self.act_emd = act_emd
        self.dataloader = cycle(datasetloader)
        self.optimizer = torch.optim.AdamW(diffusion_model.parameters(), lr=train_lr, weight_decay=0.0)
        self.reset_parameters()
        self.step = 0

    def reset_parameters(self):
        self.ema_model.load_state_dict(self.model.state_dict())

    def step_ema(self):
        if self.step < self.step_start_ema:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.model)

    def get_noise_mask(self, action_label, args, img_tensors):
        # Start from Gaussian noise and write (sums of) action embeddings into it.
        output_act_emb = torch.randn_like(img_tensors).cuda()
        self.act_emd = self.act_emd.cuda()
        np.random.seed(args.seed)
        if args.mask_type == 'single_add':
            # Each step carries only its own action embedding.
            for i in range(action_label.shape[0]):
                for j in range(action_label.shape[1]):
                    if args.dataset == 'crosstask' or args.dataset == 'NIV':
                        # for NIV and crosstask
                        rnd_idx = np.random.randint(0, args.class_dim + args.action_dim + args.observation_dim - 512)
                    if args.dataset == 'coin':
                        rnd_idx = np.random.randint(args.class_dim, args.class_dim + args.action_dim - 512)
                    output_act_emb[i][j][rnd_idx:rnd_idx + 512] = self.act_emd[action_label[i][j]]
            return output_act_emb.cuda()
        if args.mask_type == 'multi_add':
            # Each step carries the cumulative sum of the embeddings of all
            # actions up to and including that step.
            for i in range(action_label.shape[0]):
                for j in range(action_label.shape[1]):
                    if args.dataset == 'crosstask' or args.dataset == 'NIV':
                        rnd_idx = np.random.randint(0, 512, args.action_dim)
                        if j == 0:
                            output_act_emb[i][j][args.class_dim:args.class_dim + args.action_dim] = self.act_emd[action_label[i][j], rnd_idx]
                        else:
                            output_act_emb[i][j][args.class_dim:args.class_dim + args.action_dim] = output_act_emb[i][j - 1][args.class_dim:args.class_dim + args.action_dim] + self.act_emd[action_label[i][j], rnd_idx]
                    if args.dataset == 'coin':
                        rnd_idx = np.random.randint(args.class_dim, args.class_dim + args.action_dim - 512)
                        if j == 0:
                            output_act_emb[i][j][rnd_idx:rnd_idx + 512] = self.act_emd[action_label[i][j]]
                        else:
                            output_act_emb[i][j][rnd_idx:rnd_idx + 512] = output_act_emb[i][j - 1][rnd_idx:rnd_idx + 512] + self.act_emd[action_label[i][j]]
            return output_act_emb.cuda()

    def train(self, n_train_steps, if_calculate_acc, args, scheduler):
        self.model.train()
        self.ema_model.train()
        losses = AverageMeter()
        self.optimizer.zero_grad()

        for step in range(n_train_steps):
            for i in range(self.gradient_accumulate_every):
                batch = next(self.dataloader)
                bs, T = batch[1].shape  # labels: [bs, T]
                global_img_tensors = batch[0].cuda().contiguous().float()

                img_tensors = torch.zeros((bs, T, args.class_dim + args.action_dim + args.observation_dim))
                img_tensors[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                img_tensors[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
                img_tensors = img_tensors.cuda()

                video_label = batch[1].view(-1).cuda()  # [bs*T]
                task_class = batch[2].view(-1).cuda()  # [bs]
                action_label = batch[1]
                act_emb_noise = self.get_noise_mask(action_label, args, img_tensors)

                action_label_onehot = torch.zeros((video_label.size(0), self.model.module.action_dim))  # [bs*T, action_dim]
                ind = torch.arange(0, len(video_label))
                action_label_onehot[ind, video_label] = 1.
                action_label_onehot = action_label_onehot.reshape(bs, T, -1).cuda()
                img_tensors[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot

                task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
                ind = torch.arange(0, len(task_class))
                task_onehot[ind, task_class] = 1.
                task_onehot = task_onehot.cuda()
                temp = task_onehot.unsqueeze(1)
                task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]
                img_tensors[:, :, :args.class_dim] = task_class_

                cond = {0: global_img_tensors[:, 0, :].float(), T - 1: global_img_tensors[:, -1, :].float(),
                        'task': task_class_}
                x = img_tensors.float()
                loss = self.model.module.loss(x, cond, act_emb_noise, task_class)
                loss = loss / self.gradient_accumulate_every
                loss.backward()
                losses.update(loss.item(), bs)

            self.optimizer.step()
            self.optimizer.zero_grad()
            scheduler.step()

            if self.step % self.update_ema_every == 0:
                self.step_ema()
            self.step += 1

        if if_calculate_acc:
            with torch.no_grad():
                output = self.ema_model(cond, self.act_emd, task_class, if_jump=False, if_avg_mask=args.infer_avg_mask)
                actions_pred = output[:, :, args.class_dim:args.class_dim + self.model.module.action_dim] \
                    .contiguous().view(-1, self.model.module.action_dim)  # [bs*T, action_dim]
                (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                    accuracy(actions_pred.cpu(), video_label.cpu(), topk=(1, 5),
                             max_traj_len=self.model.module.horizon)
            return torch.tensor(losses.avg), acc1, acc5, torch.tensor(trajectory_success_rate), \
                torch.tensor(MIoU1), torch.tensor(MIoU2), a0_acc, aT_acc
        else:
            return torch.tensor(losses.avg)
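The 'multi_add' mask accumulates action embeddings along the trajectory, so step j carries the sum of the embeddings for steps 0 through j; a toy illustration of just that cumulative rule, with made-up shapes and without the random offset used above:

import torch

act_emd = torch.randn(10, 512)    # embedding table: 10 actions, 512-d each
labels = torch.tensor([2, 5, 7])  # one trajectory of 3 steps
mask = torch.zeros(3, 512)
for j in range(labels.shape[0]):
    # step j = embedding of action j plus everything accumulated before it
    mask[j] = act_emd[labels[j]] if j == 0 else mask[j - 1] + act_emd[labels[j]]
torch.testing.assert_close(mask[2], act_emd[2] + act_emd[5] + act_emd[7])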