"""Evaluate a trained ToMnet variant on the TBD test split.

Loads a checkpoint, runs the model over the test set, prints the five
F1 scores returned by ``compute_f1_scores`` and writes them to
``test_stats.txt`` next to the checkpoint. With ``--save_preds`` the
per-batch representations and predictions are also dumped to disk.
"""
import argparse
import csv
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from tbd_dataloader import TBDDataset, collate_fn_test
from models.common_mind import CommonMindToMnet
from models.sl import SLToMnet
from models.implicit import ImplicitToMnet
from utils.helpers import compute_f1_scores


# Order of the prediction heads; it matches both the model's output order
# and the column order of ``labels[:, k]`` used in ``test``.
_TASKS = ('m1', 'm2', 'm12', 'm21', 'mc')


def _select_device(gpu_id):
    """Return the torch device to evaluate on (CPU when CUDA is unavailable)."""
    if not torch.cuda.is_available():
        return torch.device('cpu')
    if gpu_id is None:
        # ``--gpu_id`` has no default; "cuda:None" is not a valid device
        # string, so fall back to the current default CUDA device.
        return torch.device('cuda')
    return torch.device("cuda:{}".format(gpu_id))


def _build_model(args, device):
    """Instantiate the network selected by ``args.model_type`` on ``device``."""
    if args.model_type == 'tom_cm':
        model = CommonMindToMnet(
            args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods
        )
    elif args.model_type == 'tom_sl':
        model = SLToMnet(
            args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods
        )
    elif args.model_type == 'tom_impl':
        model = ImplicitToMnet(
            args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods
        )
    else:
        raise NotImplementedError(
            "Unknown --model_type {!r}; expected 'tom_cm', 'tom_sl' or 'tom_impl'".format(
                args.model_type
            )
        )
    return model.to(device)


def _to_device(tensor, device, non_blocking):
    """Move ``tensor`` to ``device``; ``None`` inputs are passed through."""
    if tensor is None:
        return None
    return tensor.to(device, non_blocking=non_blocking)


def _save_batch_outputs(folder_path, batch_idx, representations, preds, targets,
                        false_belief, num_rows):
    """Write one batch's representations (``.pt``) and predictions (``.csv``)."""
    torch.save(
        [r.cpu() for r in representations],
        os.path.join(folder_path, f"{batch_idx}.pt"),
    )
    # One device-to-host transfer per column instead of one per element.
    pred_cols = [p.argmax(dim=1).tolist() for p in preds]
    label_cols = [t.tolist() for t in targets]
    rows = [
        (
            i,
            *(col[i] for col in pred_cols),
            *(col[i] for col in label_cols),
            false_belief[i],
        )
        for i in range(num_rows)
    ]
    header = ['frame', 'm1_pred', 'm2_pred', 'm12_pred', 'm21_pred', 'mc_pred',
              'm1_label', 'm2_label', 'm12_label', 'm21_label', 'mc_label',
              'false_belief']
    with open(os.path.join(folder_path, f'{batch_idx}.csv'), mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(header)  # Write the header row
        writer.writerows(rows)   # Write the data rows


def _write_test_stats(stats_path, test_items, scores):
    """Write the evaluated test items and one F1 score per line to ``stats_path``."""
    with open(stats_path, 'w') as f:
        f.write(f"Test data:\n {test_items}\n")
        for task, score in zip(_TASKS, scores):
            f.write(f"{task} f1: {score}\n")


def test(args):
    """Run the evaluation described by the parsed command-line ``args``.

    Args:
        args: ``argparse.Namespace`` produced by the parser at the bottom of
            this file. Reads ``data_path``, ``batch_size``, ``num_workers``,
            ``pin_memory``, ``non_blocking``, ``gpu_id``, ``model_type``,
            the model hyper-parameters, ``load_model_path`` and
            ``save_preds``.

    Raises:
        NotImplementedError: if ``args.model_type`` is not one of
            ``'tom_cm'``, ``'tom_sl'`` or ``'tom_impl'``.

    Side effects:
        Prints the F1 scores, writes ``test_stats.txt`` into the checkpoint
        directory and, when ``args.save_preds`` is set, writes ``<j>.pt`` and
        ``<j>.csv`` per batch under ``predictions/<checkpoint dir name>``.
    """
    test_dataset = TBDDataset(
        path=args.data_path,
        mode="test",
        use_preprocessed_img=True
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        collate_fn=collate_fn_test
    )

    device = _select_device(args.gpu_id)

    # model
    model = _build_model(args, device)
    # NOTE(review): torch.load unpickles the checkpoint; only load files from
    # a trusted source.
    model.load_state_dict(torch.load(args.load_model_path, map_location=device))
    model.device = device
    model.eval()

    checkpoint_dir = os.path.dirname(args.load_model_path)

    folder_path = None
    if args.save_preds:
        # Define the output folder, named after the checkpoint's directory
        folder_path = f'predictions/{os.path.basename(checkpoint_dir)}'
        os.makedirs(folder_path, exist_ok=True)
        print(f'Saving predictions in {folder_path}.')

    print('Testing...')
    pred_lists = [[] for _ in _TASKS]
    label_lists = [[] for _ in _TASKS]
    with torch.no_grad():
        # Wrapping the loader (not the enumerate object) lets tqdm see its
        # length and display a total / ETA.
        for j, batch in enumerate(tqdm(test_dataloader)):
            (img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox,
             tracker_id, gaze, labels, exp_id, timestep, false_belief) = batch

            img_3rd_pov = _to_device(img_3rd_pov, device, args.non_blocking)
            img_tracker = _to_device(img_tracker, device, args.non_blocking)
            img_battery = _to_device(img_battery, device, args.non_blocking)
            pose1 = _to_device(pose1, device, args.non_blocking)
            pose2 = _to_device(pose2, device, args.non_blocking)
            bbox = _to_device(bbox, device, args.non_blocking)
            gaze = _to_device(gaze, device, args.non_blocking)

            # The model returns the five head outputs followed by its
            # internal representations.
            *head_outputs, representations = model(
                img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze
            )
            m1_pred, m2_pred, m12_pred, m21_pred, mc_pred = head_outputs

            batch_preds = [
                p.reshape(-1, 4)
                for p in (m1_pred, m2_pred, m12_pred, m21_pred, mc_pred)
            ]
            batch_labels = [
                labels[:, k].reshape(-1).to(device) for k in range(len(_TASKS))
            ]

            for acc, pred in zip(pred_lists, batch_preds):
                acc.append(pred)
            for acc, label in zip(label_lists, batch_labels):
                acc.append(label)

            if args.save_preds:
                _save_batch_outputs(
                    folder_path, j, representations, batch_preds, batch_labels,
                    false_belief, len(labels),
                )

    # compute_f1_scores takes interleaved (pred, label) pairs in task order.
    f1_inputs = []
    for preds, targets in zip(pred_lists, label_lists):
        f1_inputs.append(torch.cat(preds))
        f1_inputs.append(torch.cat(targets))
    test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1 = compute_f1_scores(*f1_inputs)

    print("Test m1 F1: {}".format(test_m1_f1))
    print("Test m2 F1: {}".format(test_m2_f1))
    print("Test m12 F1: {}".format(test_m12_f1))
    print("Test m21 F1: {}".format(test_m21_f1))
    print("Test mc F1: {}".format(test_mc_f1))

    # os.path.dirname/join also works for a checkpoint given as a bare
    # filename (stats then go to the current directory).
    _write_test_stats(
        os.path.join(checkpoint_dir, 'test_stats.txt'),
        [data[1] for data in test_dataset.data],
        (test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1),
    )


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # Define the command-line arguments
    parser.add_argument('--gpu_id', type=int)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--presaved', type=int, default=128)
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--pin_memory', action='store_true')
    parser.add_argument('--model_type', type=str)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--aggr', type=str, default='concat', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
    parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
    parser.add_argument('--save_path', type=str, default='experiments/')
    parser.add_argument('--test_frames', type=str, default=None)
    parser.add_argument('--median', type=int, default=None)
    parser.add_argument('--load_model_path', type=str)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--save_preds', action='store_true')

    # Parse the command-line arguments
    args = parser.parse_args()

    if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
        if not args.aggr:
            parser.error("The choosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The choosen --model_type requires --tom_weight")

    # Seed every RNG the pipeline may touch so evaluation is reproducible.
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    print('###########################################################################')
    print('TESTING: MAKE SURE YOU ARE USING THE SAME RANDOM SEED USED DURING TRAINING!')
    print('###########################################################################')

    test(args)