"""Evaluate a trained model on the test split and report classification accuracy.

The script loads a checkpoint, runs it over the test set one sequence at a
time, optionally dumps per-sequence predictions, and writes the overall
accuracy to ``test_stats.txt`` next to the checkpoint.
"""

import argparse
import csv
import os

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataloader import DataTest
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
from models.tom_implicit import ImplicitToMnet
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
from models.tom_sl import SLToMnet
from models.single_mindnet import SingleMindNet
from utils import (
    tried_once,
    tried_twice,
    tried_thrice,
    friends,
    strangers,
    get_classification_accuracy,
    pad_collate,
    get_input_dim,
)

# Model types for which the dataset is built with ``flatten_dim=2``.
_TOM_MODEL_TYPES = frozenset(
    {'tom_cm', 'tom_sl', 'tom_impl', 'tom_cm_xl', 'tom_single'}
)

# Model types whose constructor receives ``args.aggr``.
_AGGR_MODEL_TYPES = frozenset({'tom_cm', 'tom_impl', 'tom_cm_xl'})


def _select_test_frame_ids(test_frames):
    """Map the ``--test_frames`` value to the matching frame-id subset from ``utils``.

    Returns ``None`` when ``test_frames`` is ``None``. Raises ``NameError``
    for any unrecognised value.
    """
    if test_frames is None:
        return None
    if test_frames == 'twice_thrice':
        # Concatenated lazily so the other subsets never pay for it.
        return tried_twice + tried_thrice
    subsets = {'friends': friends, 'strangers': strangers, 'once': tried_once}
    if test_frames not in subsets:
        raise NameError(f"Unknown --test_frames value: {test_frames!r}")
    return subsets[test_frames]


def _build_model(args, device):
    """Instantiate the network selected by ``args.model_type`` and move it to ``device``.

    Raises ``NotImplementedError`` for an unknown model type.
    """
    model_type = args.model_type
    resnet_family = {
        'resnet': ResNet,
        'gru': ResNetGRU,
        'lstm': ResNetLSTM,
        'conv1d': ResNetConv1D,
    }
    aggr_family = {
        'tom_cm': CommonMindToMnet,
        'tom_impl': ImplicitToMnet,
        'tom_cm_xl': CommonMindToMnetXL,
    }
    if model_type in resnet_family:
        model = resnet_family[model_type](get_input_dim(args.mods), device)
    elif model_type in aggr_family:
        model = aggr_family[model_type](
            args.hidden_dim, device, args.use_resnet, args.dropout,
            args.aggr, args.mods,
        )
    elif model_type == 'tom_sl':
        model = SLToMnet(
            args.hidden_dim, device, args.tom_weight, args.use_resnet,
            args.dropout, args.mods,
        )
    elif model_type == 'tom_single':
        model = SingleMindNet(
            args.hidden_dim, device, args.use_resnet, args.dropout, args.mods,
        )
    else:
        raise NotImplementedError(f"Unknown --model_type: {model_type!r}")
    return model.to(device)


def _to_device(tensor, device):
    """Move ``tensor`` to ``device``; ``None`` (a missing modality) passes through."""
    if tensor is None:
        return None
    return tensor.to(device, non_blocking=True)


def _save_batch_predictions(folder_path, batch_idx, batch_acc, representations,
                            pred_left_labels, pred_right_labels, labels):
    """Write one batch's representations (``.pt``) and predictions (``.csv``).

    The CSV has one row per frame: frame index, arg-max of the left and right
    predictions, and the two label columns.
    """
    torch.save(
        [r.cpu() for r in representations],
        os.path.join(folder_path, f"{batch_idx}.pt"),
    )
    # One arg-max and one host transfer per tensor instead of one per frame.
    left_pred = pred_left_labels.argmax(dim=1).cpu().numpy()
    right_pred = pred_right_labels.argmax(dim=1).cpu().numpy()
    labels_np = labels.cpu().numpy()
    rows = [
        (i, left, right, left_label, right_label)
        for i, (left, right, (left_label, right_label))
        in enumerate(zip(left_pred, right_pred, labels_np))
    ]
    header = ['frame', 'left_pred', 'right_pred', 'left_label', 'right_label']
    csv_path = os.path.join(folder_path, f'{batch_idx}_{batch_acc:.2f}.csv')
    with open(csv_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(header)
        writer.writerows(rows)


def test(args: argparse.Namespace) -> None:
    """Run the checkpoint at ``args.load_model_path`` over the test set.

    Prints the overall accuracy and writes it to ``test_stats.txt`` in the
    checkpoint's directory. With ``args.save_preds`` set, per-batch
    representations and predictions are also stored under ``predictions/``.

    Raises:
        ValueError: if ``args.load_model_path`` is not given.
        NameError: if ``args.test_frames`` is not a known subset name.
        NotImplementedError: if ``args.model_type`` is unknown.
        RuntimeError: if the test set yields no predictions.
    """
    # Fail before the (potentially slow) dataset construction.
    if args.load_model_path is None:
        raise ValueError("--load_model_path is required for testing")

    test_frame_ids = _select_test_frame_ids(args.test_frames)

    # NOTE(review): only the presence of --median is checked; its numeric
    # value is not used and the setting is fixed to (240, False). Confirm
    # this is intended.
    median = (240, False) if args.median is not None else None
    flatten_dim = 2 if args.model_type in _TOM_MODEL_TYPES else 1

    # load datasets
    test_dataset = DataTest(
        args.test_frame_path,
        args.label_path,
        args.test_pose_path,
        args.test_gaze_path,
        args.test_bbox_path,
        args.ocr_graph_path,
        args.presaved,
        test_frame_ids,
        median,
        flatten_dim=flatten_dim,
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=pad_collate,
        pin_memory=args.pin_memory,
    )

    if torch.cuda.is_available():
        # Without --gpu_id fall back to the default CUDA device rather than
        # building the invalid device string "cuda:None".
        device_name = 'cuda' if args.gpu_id is None else f'cuda:{args.gpu_id}'
    else:
        device_name = 'cpu'
    device = torch.device(device_name)

    model = _build_model(args, device)
    model.load_state_dict(torch.load(args.load_model_path, map_location=device))
    model.device = device
    model.eval()

    model_dir = os.path.dirname(args.load_model_path)

    folder_path = None
    if args.save_preds:
        # Predictions go to predictions/<checkpoint dir name>_<test_frames>.
        folder_path = f'predictions/{model_dir.split(os.path.sep)[-1]}_{args.test_frames}'
        os.makedirs(folder_path, exist_ok=True)
        print(f'Saving predictions in {folder_path}.')

    print('Testing...')
    num_correct = 0
    cnt = 0
    with torch.no_grad():
        # Wrapping the loader (not the enumerate) lets tqdm show the total.
        for j, batch in enumerate(tqdm(test_dataloader)):
            frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
            frames = _to_device(frames, device)
            poses = _to_device(poses, device)
            gazes = _to_device(gazes, device)
            bboxes = _to_device(bboxes, device)
            ocr_graphs = _to_device(ocr_graphs, device)

            pred_left_labels, pred_right_labels, representations = model(
                frames, poses, gazes, bboxes, ocr_graphs
            )
            # Flatten to one row per frame: 27 scores per side, 2 labels.
            pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
            pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
            labels = torch.reshape(labels, (-1, 2)).to(device)

            batch_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
                pred_left_labels, pred_right_labels, labels, sequence_lengths
            )
            cnt += batch_num_pred
            num_correct += batch_num_correct

            if args.save_preds:
                _save_batch_predictions(
                    folder_path, j, batch_acc, representations,
                    pred_left_labels, pred_right_labels, labels,
                )

    if cnt == 0:
        raise RuntimeError(
            "No predictions were evaluated; the test set appears to be empty"
        )

    test_acc = num_correct / cnt
    print("Test accuracy: {}".format(test_acc))
    # os.path handles checkpoint paths without a directory component, which
    # the previous rsplit('/') approach turned into an unusable path.
    with open(os.path.join(model_dir, 'test_stats.txt'), 'w') as f:
        f.write(str(test_acc))


def _parse_args() -> argparse.Namespace:
    """Define, parse and validate the command-line arguments."""
    parser = argparse.ArgumentParser()

    # Define the command-line arguments
    parser.add_argument('--gpu_id', type=int)
    parser.add_argument('--presaved', type=int, default=128)
    # --non_blocking and --save_path are accepted but not read by test().
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--pin_memory', action='store_true')
    parser.add_argument('--model_type', type=str)
    parser.add_argument('--aggr', type=str, default='mult', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
    parser.add_argument('--test_frame_path', type=str, default='/scratch/bortoletto/data/boss/test/frame')
    parser.add_argument('--test_pose_path', type=str, default='/scratch/bortoletto/data/boss/test/pose')
    parser.add_argument('--test_gaze_path', type=str, default='/scratch/bortoletto/data/boss/test/gaze')
    parser.add_argument('--test_bbox_path', type=str, default='/scratch/bortoletto/data/boss/test/new_bbox/labels')
    parser.add_argument('--ocr_graph_path', type=str, default='')
    parser.add_argument('--label_path', type=str, default='outfile')
    parser.add_argument('--save_path', type=str, default='experiments/')
    parser.add_argument('--test_frames', type=str, default=None)
    parser.add_argument('--median', type=int, default=None)
    parser.add_argument('--load_model_path', type=str)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--save_preds', action='store_true')

    # Parse the command-line arguments
    args = parser.parse_args()

    # Every model that consumes --aggr is checked, including tom_cm_xl.
    if args.model_type in _AGGR_MODEL_TYPES and not args.aggr:
        parser.error("The choosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The choosen --model_type requires --tom_weight")
    return args


if __name__ == '__main__':
    test(_parse_args())