# ActionDiffusion_WACV2025/utils/eval.py
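"""Evaluation utilities for ActionDiffusion (WACV 2025).

Each validate_* function runs the trained diffusion planner over a validation
loader and returns averaged metrics: loss, top-1/top-5 action accuracy,
trajectory success rate, two MIoU variants, and first-/last-step action
accuracies (A0/AT).
"""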
import torch

from .accuracy import *  # expected to provide accuracy() and similarity_score() used below
from model.helpers import AverageMeter


def validate_act_noise(val_loader, model, args, act_emb):
    """Validate the action-embedding-noise variant; task conditioning is always on."""
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()
    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        # Unpack a batch: visual features, per-step action labels, task class.
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            # Condition on the first and last observations.
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            # One-hot task class, repeated along the time axis: [bs, T, class_dim].
            task_onehot = torch.zeros((task_class.size(0), args.class_dim))
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            task_class_ = task_onehot.unsqueeze(1).repeat(1, T, 1)  # [bs, T, class_dim]
            cond['task'] = task_class_

            # One-hot action labels: [bs, T, action_dim].
            video_label_reshaped = video_label.view(-1)
            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            # Ground-truth target x_start, laid out as
            # [task one-hot | action one-hot | observation] along the feature axis.
            x_start = torch.zeros((batch_size_current, T, args.class_dim + args.action_dim + args.observation_dim))
            x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
            x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
            x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
            x_start[:, :, :args.class_dim] = task_class_

            output = model(cond, act_emb, task_class, if_jump=False, if_avg_mask=args.infer_avg_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            # Score only the action-logits slice of the prediction.
            actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)


def validate_act(val_loader, model, args, act_emb):
    """Validate action prediction; task conditioning is optional via args.use_cls_mask."""
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()
    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            # Condition on the first and last observations.
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            # One-hot task class, repeated along the time axis: [bs, T, class_dim].
            task_onehot = torch.zeros((task_class.size(0), args.class_dim))
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            task_class_ = task_onehot.unsqueeze(1).repeat(1, T, 1)  # [bs, T, class_dim]
            if args.use_cls_mask:
                cond['task'] = task_class_

            # One-hot action labels: [bs, T, action_dim].
            video_label_reshaped = video_label.view(-1)
            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            # Ground-truth target: with the class mask the layout is
            # [task one-hot | action one-hot | observation]; otherwise the task slice is dropped.
            if args.use_cls_mask:
                x_start = torch.zeros((batch_size_current, T, args.class_dim + args.action_dim + args.observation_dim))
                x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
                x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
                x_start[:, :, :args.class_dim] = task_class_
            else:
                x_start = torch.zeros((batch_size_current, T, args.action_dim + args.observation_dim))
                x_start[:, 0, args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.action_dim:] = global_img_tensors[:, -1, :]
                x_start[:, :, :args.action_dim] = action_label_onehot

            output = model(cond, act_emb, task_class, if_jump=False, if_avg_mask=args.infer_avg_mask)
            # output = model(cond, act_emb, task_class, if_jump=True, if_avg_mask=args.infer_avg_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            # Slice out the action logits; the offset depends on whether the task slice is present.
            if args.use_cls_mask:
                actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            else:
                actions_pred = actions_pred[:, :, :args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)


def validate(val_loader, model, args, act_emb):
    """Validate via the cond_type interface; act_emb is required by the model call."""
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()
    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            # Condition on the first and last observations.
            cond[0] = global_img_tensors[:, 0, :]
            cond[T - 1] = global_img_tensors[:, -1, :]

            # One-hot task class, repeated along the time axis: [bs, T, class_dim].
            task_onehot = torch.zeros((task_class.size(0), args.class_dim))
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            task_class_ = task_onehot.unsqueeze(1).repeat(1, T, 1)  # [bs, T, class_dim]
            if args.use_cls_mask:
                cond['task'] = task_class_

            # One-hot action labels: [bs, T, action_dim].
            video_label_reshaped = video_label.view(-1)
            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            # Ground-truth target; the task slice is included only with the class mask.
            if args.use_cls_mask:
                x_start = torch.zeros((batch_size_current, T, args.class_dim + args.action_dim + args.observation_dim))
                x_start[:, 0, args.class_dim + args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.class_dim + args.action_dim:] = global_img_tensors[:, -1, :]
                x_start[:, :, args.class_dim:args.class_dim + args.action_dim] = action_label_onehot
                x_start[:, :, :args.class_dim] = task_class_
            else:
                x_start = torch.zeros((batch_size_current, T, args.action_dim + args.observation_dim))
                x_start[:, 0, args.action_dim:] = global_img_tensors[:, 0, :]
                x_start[:, -1, args.action_dim:] = global_img_tensors[:, -1, :]
                x_start[:, :, :args.action_dim] = action_label_onehot

            output = model(cond, task_class, act_emb, if_jump=False, if_avg_mask=args.infer_avg_mask, cond_type=args.use_cls_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            # Slice out the action logits; the offset depends on whether the task slice is present.
            if args.use_cls_mask:
                actions_pred = actions_pred[:, :, args.class_dim:args.class_dim + args.action_dim].contiguous()
            else:
                actions_pred = actions_pred[:, :, :args.action_dim].contiguous()
            actions_pred = actions_pred.view(-1, args.action_dim)  # [bs*T, action_dim]

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(actions_pred.cpu(), video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)

    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)


def validate_mlp(val_loader, model, act_model, args, act_emb):
    """Validate the embedding-space variant: observations and 512-d action embeddings
    are packed into a [bs, T, 32, 64] grid, and predicted embeddings are decoded by
    similarity against act_emb (the act_model decoding path is disabled)."""
    model.eval()
    losses = AverageMeter()
    acc_top1 = AverageMeter()
    acc_top5 = AverageMeter()
    trajectory_success_rate_meter = AverageMeter()
    MIoU1_meter = AverageMeter()
    MIoU2_meter = AverageMeter()
    A0_acc = AverageMeter()
    AT_acc = AverageMeter()

    for i_batch, sample_batch in enumerate(val_loader):
        global_img_tensors = sample_batch[0].cuda().contiguous().float()
        video_label = sample_batch[1].cuda()
        batch_size_current, T = video_label.size()
        task_class = sample_batch[2].view(-1).cuda()
        cond = {}

        with torch.no_grad():
            # Observations reshaped to a 32x48 grid for the 2D model input.
            cond[0] = global_img_tensors[:, 0, :].view((batch_size_current, 32, 48))
            cond[T - 1] = global_img_tensors[:, -1, :].view((batch_size_current, 32, 48))

            # One-hot task class, repeated along the time axis: [bs, T, class_dim].
            task_onehot = torch.zeros((task_class.size(0), args.class_dim))
            ind = torch.arange(0, len(task_class))
            task_onehot[ind, task_class] = 1.
            task_onehot = task_onehot.cuda()
            task_class_ = task_onehot.unsqueeze(1).repeat(1, T, 1)  # [bs, T, class_dim]

            # One-hot action labels: [bs, T, action_dim].
            video_label_reshaped = video_label.view(-1)
            action_label_onehot = torch.zeros((video_label_reshaped.size(0), args.action_dim))
            ind = torch.arange(0, len(video_label_reshaped))
            action_label_onehot[ind, video_label_reshaped] = 1.
            action_label_onehot = action_label_onehot.reshape(batch_size_current, T, -1).cuda()

            # Target grid: columns [:48] hold the observation, columns [48:] the
            # ground-truth 512-d action embedding reshaped to 32x16.
            x_start = torch.zeros((batch_size_current, T, 32, 64))
            x_start[:, 0, :, :48] = global_img_tensors[:, 0, :].view((batch_size_current, 32, 48))
            x_start[:, -1, :, :48] = global_img_tensors[:, -1, :].view((batch_size_current, 32, 48))
            for i in range(video_label.shape[0]):
                for j in range(video_label.shape[1]):
                    x_start[i, j, :, 48:] = act_emb[video_label[i][j]].view(32, 16)

            output = model(cond, task_class, act_emb, if_jump=False, if_avg_mask=args.infer_avg_mask, cond_type=args.use_cls_mask)
            actions_pred = output.contiguous()
            loss = model.module.loss_fn(actions_pred, x_start.cuda())

            # Decode predicted embeddings (32x16 -> 512) by similarity to act_emb;
            # similarity_score is expected to come in via the wildcard import above.
            actions_pred = output[:, :, :, 48:].contiguous().view(batch_size_current * args.horizon, 512)
            sim_act_pred = similarity_score(actions_pred, act_emb)

            (acc1, acc5), trajectory_success_rate, MIoU1, MIoU2, a0_acc, aT_acc = \
                accuracy(sim_act_pred, video_label_reshaped.cpu(), topk=(1, 5), max_traj_len=args.horizon)

            losses.update(loss.item(), batch_size_current)
            acc_top1.update(acc1.item(), batch_size_current)
            acc_top5.update(acc5.item(), batch_size_current)
            trajectory_success_rate_meter.update(trajectory_success_rate.item(), batch_size_current)
            MIoU1_meter.update(MIoU1, batch_size_current)
            MIoU2_meter.update(MIoU2, batch_size_current)
            A0_acc.update(a0_acc, batch_size_current)
            AT_acc.update(aT_acc, batch_size_current)
    return torch.tensor(losses.avg), torch.tensor(acc_top1.avg), torch.tensor(acc_top5.avg), \
        torch.tensor(trajectory_success_rate_meter.avg), \
        torch.tensor(MIoU1_meter.avg), torch.tensor(MIoU2_meter.avg), \
        torch.tensor(A0_acc.avg), torch.tensor(AT_acc.avg)
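
# Hypothetical usage sketch (illustration only): how one of these validators is
# typically wired into an evaluation run. `build_model`, `val_loader`, and the
# embedding path below are assumptions, not part of this repo's verified API;
# the DataParallel wrap matters because the validators access model.module.loss_fn,
# and a CUDA device is required since inputs are moved with .cuda().
#
#     model = torch.nn.DataParallel(build_model(args)).cuda()
#     act_emb = torch.load('act_embeddings.pt').cuda()  # [action_dim, 512]
#     loss, top1, top5, sr, miou1, miou2, a0, aT = validate_act(val_loader, model, args, act_emb)
#     print(f'top1={top1:.2f} top5={top5:.2f} SR={sr:.2f}')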