mtomnet/tbd/test.py

196 lines
8.5 KiB
Python
Raw Normal View History

2025-01-10 15:39:20 +01:00
import torch
import csv
import argparse
from tqdm import tqdm
from torch.utils.data import DataLoader
import random
import os
import numpy as np
from tbd_dataloader import TBDDataset, collate_fn_test
from models.common_mind import CommonMindToMnet
from models.sl import SLToMnet
from models.implicit import ImplicitToMnet
from utils.helpers import compute_f1_scores
# Prediction heads in the order the model returns them (see the 6-way unpacking
# of the model output in test()); label column k of `labels` belongs to head k.
_HEADS = ('m1', 'm2', 'm12', 'm21', 'mc')
# Every head's logits are reshaped to (-1, _NUM_CLASSES) before scoring.
_NUM_CLASSES = 4


def _build_model(args, device):
    """Instantiate the model selected by ``args.model_type`` on *device*.

    Raises:
        NotImplementedError: if ``args.model_type`` is not 'tom_cm',
            'tom_sl' or 'tom_impl'.
    """
    if args.model_type == 'tom_cm':
        model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    elif args.model_type == 'tom_sl':
        model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods)
    elif args.model_type == 'tom_impl':
        model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    else:
        raise NotImplementedError(f"Unsupported model_type: {args.model_type!r}")
    return model.to(device)


def _to_device(tensor, device, non_blocking):
    """Move *tensor* to *device*; ``None`` (a modality absent from the batch) passes through."""
    if tensor is None:
        return None
    return tensor.to(device, non_blocking=non_blocking)


def _save_batch_predictions(folder_path, batch_idx, reprs, head_preds, head_labels, false_belief, n_rows):
    """Dump one batch's representations to ``{batch_idx}.pt`` and its predictions to ``{batch_idx}.csv``.

    Args:
        folder_path: existing output directory.
        batch_idx: index of the batch, used as the file stem.
        reprs: iterable of tensors (the model's last output), saved on CPU.
        head_preds: per-head logits, each already reshaped to (-1, _NUM_CLASSES), in ``_HEADS`` order.
        head_labels: per-head 1-D label tensors, in ``_HEADS`` order.
        false_belief: per-sample values from the collate function, indexed by row.
        n_rows: number of CSV rows to write (the caller passes ``len(labels)``).
    """
    torch.save([r.cpu() for r in reprs], os.path.join(folder_path, f"{batch_idx}.pt"))
    # One argmax and one device->host copy per head, instead of one per row and head.
    pred_ids = [pred.argmax(dim=-1).cpu().numpy() for pred in head_preds]
    label_vals = [lab.cpu().numpy() for lab in head_labels]
    rows = [
        (i, *(ids[i] for ids in pred_ids), *(vals[i] for vals in label_vals), false_belief[i])
        for i in range(n_rows)
    ]
    header = (['frame']
              + [f'{head}_pred' for head in _HEADS]
              + [f'{head}_label' for head in _HEADS]
              + ['false_belief'])
    with open(os.path.join(folder_path, f'{batch_idx}.csv'), mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(header)
        writer.writerows(rows)


def _write_test_stats(load_model_path, test_dataset, f1_scores):
    """Write the test-sample identifiers and per-head F1 scores to ``test_stats.txt`` beside the checkpoint."""
    stats_path = os.path.join(os.path.dirname(load_model_path), 'test_stats.txt')
    with open(stats_path, 'w') as f:
        # Element [1] of every entry in test_dataset.data; its meaning is defined by TBDDataset.
        f.write(f"Test data:\n {[item[1] for item in test_dataset.data]}\n")
        for head, score in zip(_HEADS, f1_scores):
            f.write(f"{head} f1: {score}\n")


def test(args):
    """Evaluate a trained ToM model on the TBD test split and report per-head F1 scores.

    Builds the model chosen by ``args.model_type``, loads weights from
    ``args.load_model_path``, runs it over the test set without gradients,
    prints the F1 score of each head (m1, m2, m12, m21, mc) and writes them to
    ``test_stats.txt`` in the checkpoint's directory. With ``args.save_preds``,
    per-batch representations (``.pt``) and argmax predictions with labels
    (``.csv``) are also written under ``predictions/<checkpoint dir name>/``.

    Raises:
        NotImplementedError: for an unsupported ``args.model_type``.
        RuntimeError: if the test dataloader yields no batches.
    """
    test_dataset = TBDDataset(
        path=args.data_path,
        mode="test",
        use_preprocessed_img=True
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        collate_fn=collate_fn_test
    )
    # Without this fallback an omitted --gpu_id would produce the invalid device "cuda:None".
    gpu_id = 0 if args.gpu_id is None else args.gpu_id
    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else 'cpu')

    model = _build_model(args, device)
    model.load_state_dict(torch.load(args.load_model_path, map_location=device))
    model.device = device
    model.eval()

    folder_path = None
    if args.save_preds:
        # Name the output folder after the directory that holds the checkpoint.
        folder_path = f'predictions/{os.path.basename(os.path.dirname(args.load_model_path))}'
        os.makedirs(folder_path, exist_ok=True)
        print(f'Saving predictions in {folder_path}.')

    print('Testing...')
    pred_lists = {head: [] for head in _HEADS}
    label_lists = {head: [] for head in _HEADS}
    non_blocking = args.non_blocking
    with torch.no_grad():
        # Wrap the loader itself (not enumerate) so tqdm knows the total length.
        for j, batch in enumerate(tqdm(test_dataloader)):
            (img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox,
             tracker_id, gaze, labels, _exp_id, _timestep, false_belief) = batch
            img_3rd_pov = _to_device(img_3rd_pov, device, non_blocking)
            img_tracker = _to_device(img_tracker, device, non_blocking)
            img_battery = _to_device(img_battery, device, non_blocking)
            pose1 = _to_device(pose1, device, non_blocking)
            pose2 = _to_device(pose2, device, non_blocking)
            bbox = _to_device(bbox, device, non_blocking)
            gaze = _to_device(gaze, device, non_blocking)

            m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, reprs = model(
                img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
            head_preds = [p.reshape(-1, _NUM_CLASSES) for p in (m1_pred, m2_pred, m12_pred, m21_pred, mc_pred)]
            head_labels = [labels[:, k].reshape(-1).to(device) for k in range(len(_HEADS))]
            for head, pred, lab in zip(_HEADS, head_preds, head_labels):
                pred_lists[head].append(pred)
                label_lists[head].append(lab)

            if args.save_preds:
                _save_batch_predictions(folder_path, j, reprs, head_preds, head_labels,
                                        false_belief, len(labels))

    if not pred_lists[_HEADS[0]]:
        raise RuntimeError('The test dataloader yielded no batches; nothing to evaluate.')

    # compute_f1_scores expects (pred, label) pairs interleaved in head order.
    f1_args = []
    for head in _HEADS:
        f1_args += [torch.cat(pred_lists[head]), torch.cat(label_lists[head])]
    test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1 = compute_f1_scores(*f1_args)
    f1_scores = (test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1)

    for head, score in zip(_HEADS, f1_scores):
        print(f"Test {head} F1: {score}")
    _write_test_stats(args.load_model_path, test_dataset, f1_scores)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate a trained ToM model on the TBD test split.')
    # Define the command-line arguments.
    # --presaved, --save_path, --test_frames and --median are not read by test()
    # in this file; presumably kept so training command lines can be reused — verify.
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--presaved', type=int, default=128)
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--pin_memory', action='store_true')
    parser.add_argument('--model_type', type=str, required=True, choices=['tom_cm', 'tom_sl', 'tom_impl'])
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--aggr', type=str, default='concat', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
    parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
    parser.add_argument('--save_path', type=str, default='experiments/')
    parser.add_argument('--test_frames', type=str, default=None)
    parser.add_argument('--median', type=int, default=None)
    parser.add_argument('--load_model_path', type=str, required=True)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--save_preds', action='store_true')
    # Parse the command-line arguments
    args = parser.parse_args()
    # --aggr and --tom_weight have defaults, so these only fire when an empty/zero value is passed explicitly.
    if args.model_type in ('tom_cm', 'tom_impl') and not args.aggr:
        parser.error("The chosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The chosen --model_type requires --tom_weight")
    # PYTHONHASHSEED is only read at interpreter start-up: this assignment does not
    # change hashing in the running process, only in interpreters started afterwards.
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    print('###########################################################################')
    print('TESTING: MAKE SURE YOU ARE USING THE SAME RANDOM SEED USED DURING TRAINING!')
    print('###########################################################################')
    test(args)