import torch

from .accuracy import *
from model.helpers import AverageMeter
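

# All of the validators below build one-hot tensors with the same
# zeros-then-scatter pattern (allocate zeros, index-assign 1.0). For
# reference, an equivalent helper using torch.nn.functional.one_hot --
# an illustrative sketch, not called by the original code:
def _one_hot_rows(labels, num_classes):
    """Return a float one-hot matrix of shape [len(labels), num_classes]."""
    return torch.nn.functional.one_hot(labels.long(), num_classes).float()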


def validate_act_noise(val_loader, model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()

    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        # Unpack the batch: visual features, per-step action labels, task class.
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            # Condition on the observed start and goal frames only.
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            temp = task_onehot.unsqueeze(1)
            task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]

            cond['task'] = task_class_

            video_label_reshaped = video_label.view(-1)

            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            # Ground-truth target: [task one-hot | action one-hot | observation].
            x_start = torch.zeros((batch_size_current, T, args.class_dim + args.action_dim + args.observation_dim))
            x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
            x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]

            x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
            x_start[:, :, :args.class_dim] = task_class_

            # Older call signature kept for reference:
            # output = model(cond, task_class, act_emb, if_jump=False, if_avg_mask=args.infer_avg_mask, cond_type=args.use_cls_mask)
            output = model(cond, act_emb, task_class, if_jump=False, if_avg_mask=args.infer_avg_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            # Keep only the action slice for metric computation.
            actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
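

# Feature-axis layout of x_start in the masked (use_cls_mask) configuration,
# per timestep:
#
#   [0, class_dim)                        task one-hot
#   [class_dim, class_dim + action_dim)   action one-hot
#   [class_dim + action_dim, ...)         visual observation features
#
# Only the first and last timesteps carry observation features; intermediate
# steps stay zero and are recovered by the diffusion model. Without
# use_cls_mask the task slot is dropped and the action slice starts at index 0.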


def validate_act(val_loader, model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()

    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            temp = task_onehot.unsqueeze(1)
            task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]
            if args.use_cls_mask:
                cond['task'] = task_class_

            video_label_reshaped = video_label.view(-1)

            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            if args.use_cls_mask:
                x_start = torch.zeros((batch_size_current, T, args.class_dim + args.action_dim + args.observation_dim))
                x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]

                x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
                x_start[:, :, :args.class_dim] = task_class_
            else:
                x_start = torch.zeros((batch_size_current, T, args.action_dim + args.observation_dim))
                x_start[:, 0, args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.action_dim:] = global_img_tensors[:, -1, :]

                x_start[:, :, :args.action_dim] = action_label_onehot

            # Older call signatures kept for reference:
            # output = model(cond, task_class, act_emb, if_jump=False, if_avg_mask=args.infer_avg_mask, cond_type=args.use_cls_mask)
            # output = model(cond, act_emb, task_class, if_jump=True, if_avg_mask=args.infer_avg_mask)
            output = model(cond, act_emb, task_class, if_jump=False, if_avg_mask=args.infer_avg_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            # Caution: this slice assumes the model output always carries the
            # task slot at the front; unlike validate() below, there is no
            # use_cls_mask branch here.
            actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
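

# AverageMeter comes from model.helpers; based on how it is used here
# (.update(value, n) and .avg), a minimal equivalent sketch -- an assumption
# about its internals, kept under a private name so nothing is shadowed:
class _AverageMeterSketch:
    """Tracks a running sum and count and exposes their ratio as .avg."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count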


def validate(val_loader, model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()

    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            temp = task_onehot.unsqueeze(1)
            task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]
            if args.use_cls_mask:
                cond['task'] = task_class_

            video_label_reshaped = video_label.view(-1)

            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            if args.use_cls_mask:
                x_start = torch.zeros((batch_size_current, T, args.class_dim + args.action_dim + args.observation_dim))
                x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]

                x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
                x_start[:, :, :args.class_dim] = task_class_
            else:
                x_start = torch.zeros((batch_size_current, T, args.action_dim + args.observation_dim))
                x_start[:, 0, args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.action_dim:] = global_img_tensors[:, -1, :]

                x_start[:, :, :args.action_dim] = action_label_onehot

            output = model(cond, task_class, act_emb, if_jump=False, if_avg_mask=args.infer_avg_mask, cond_type=args.use_cls_mask)

            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            # Slice out the action predictions; their offset depends on
            # whether the task slot is present.
            if args.use_cls_mask:
                actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            else:
                actions_pred = actions_pred[:, :, :args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
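

# similarity_score (used in validate_mlp below) is pulled in via the wildcard
# import from .accuracy. A plausible minimal sketch, assuming it scores each
# predicted embedding against every row of the action-embedding table by
# cosine similarity -- an assumption, not the confirmed implementation:
def _similarity_score_sketch(pred_emb, act_emb):
    # pred_emb: [N, D], act_emb: [num_actions, D] -> scores: [N, num_actions]
    pred = torch.nn.functional.normalize(pred_emb, dim=-1)
    table = torch.nn.functional.normalize(act_emb, dim=-1)
    return pred @ table.t()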


def validate_mlp(val_loader, model, act_model, args, act_emb):
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()

    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            # Observations are reshaped to a [32, 48] grid per frame.
            cond[0] = global_img_tensors[:, 0, :].view((batch_size_current, 32, 48))
            cond[T - 1] = global_img_tensors[:, -1, :].view((batch_size_current, 32, 48))

            task_onehot = torch.zeros((task_class.size(0), args.class_dim))  # [bs, class_dim]
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            temp = task_onehot.unsqueeze(1)
            task_class_ = temp.repeat(1, T, 1)  # [bs, T, args.class_dim]

            video_label_reshaped = video_label.view(-1)

            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            # Per step: a [32, 64] grid whose columns [:48] hold observation
            # features and whose columns [48:] hold the 512-d (32 x 16)
            # ground-truth action embedding.
            x_start = torch.zeros((batch_size_current, T, 32, 64))
            x_start[:, 0, :, :48] = global_img_tensors[:, 0, :].view((batch_size_current, 32, 48))
            x_start[:, -1, :, :48] = global_img_tensors[:, -1, :].view((batch_size_current, 32, 48))

            # Fill the action-embedding slot step by step (a vectorized
            # equivalent would be act_emb[video_label].view(bs, T, 32, 16)).
            for i in range(video_label.shape[0]):
                for j in range(video_label.shape[1]):
                    x_start[i, j, :, 48:] = act_emb[video_label[i, j]].view(32, 16)

            output = model(cond, task_class, act_emb, if_jump=False, if_avg_mask=args.infer_avg_mask, cond_type=args.use_cls_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            # Flatten the predicted embedding slot back to [bs*T, 512] and
            # score it against the action-embedding table.
            actions_pred = output[:, :, :, 48:].contiguous().view(batch_size_current * args.horizon, 512)
            # act_model is unused here; a decoding variant kept for reference:
            # actions = act_model(actions_pred)
            sim_act_pred = similarity_score(actions_pred, act_emb)

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(sim_act_pred, video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
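

# Typical driver-side usage (hypothetical names; the real training script
# lives elsewhere in the repo):
#
#   model = torch.nn.DataParallel(diffusion_model).cuda()
#   (val_loss, top1, top5, success_rate,
#    miou1, miou2, a0, aT) = validate_act(val_loader, model, args, act_emb)
#
# The model must be wrapped (e.g. in DataParallel) because every validator
# reaches through model.module to call loss_fn.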