mtomnet/tbd/test.py
2025-01-10 15:39:20 +01:00

196 lines
No EOL
8.5 KiB
Python

import torch
import csv
import argparse
from tqdm import tqdm
from torch.utils.data import DataLoader
import random
import os
import numpy as np
from tbd_dataloader import TBDDataset, collate_fn_test
from models.common_mind import CommonMindToMnet
from models.sl import SLToMnet
from models.implicit import ImplicitToMnet
from utils.helpers import compute_f1_scores
# Order of the five prediction heads / label columns used throughout this script.
_MIND_NAMES = ('m1', 'm2', 'm12', 'm21', 'mc')


def _resolve_device(gpu_id):
    """Return the torch device to evaluate on.

    Uses ``cuda:<gpu_id>`` when CUDA is available; when no GPU index was given
    it falls back to the default ``cuda`` device (instead of building the
    invalid device string ``cuda:None``). Without CUDA the CPU is used.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    if gpu_id is None:
        return torch.device('cuda')
    return torch.device(f'cuda:{gpu_id}')


def _build_model(args, device):
    """Instantiate the network selected by ``args.model_type`` and move it to ``device``.

    Raises:
        NotImplementedError: if ``args.model_type`` is not ``'tom_cm'``,
            ``'tom_sl'`` or ``'tom_impl'``.
    """
    if args.model_type == 'tom_cm':
        model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    elif args.model_type == 'tom_sl':
        model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods)
    elif args.model_type == 'tom_impl':
        model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    else:
        raise NotImplementedError(f'Unknown model_type: {args.model_type!r}')
    return model.to(device)


def _to_device(tensor, device, non_blocking):
    """Move ``tensor`` to ``device``; ``None`` (a modality absent from the batch) passes through."""
    return None if tensor is None else tensor.to(device, non_blocking=non_blocking)


def _save_batch_predictions(folder_path, batch_idx, representations, preds, labels, false_belief, num_rows):
    """Write one batch's outputs into ``folder_path``.

    Saves ``representations`` (moved to CPU) to ``<batch_idx>.pt``. Also writes
    ``<batch_idx>.csv`` with one row per sample. Each row holds the argmax
    prediction and the label for each of the five heads, plus the sample's
    ``false_belief`` entry.

    Args:
        folder_path: existing output directory.
        batch_idx: index of the batch, used as the file stem.
        representations: iterable of tensors returned by the model.
        preds: five ``(N, 4)`` logit tensors ordered as ``_MIND_NAMES``.
        labels: five 1-D label tensors ordered as ``_MIND_NAMES``.
        false_belief: per-sample indexable taken from the batch.
        num_rows: number of rows to write (the batch size).
    """
    torch.save([r.cpu() for r in representations], os.path.join(folder_path, f"{batch_idx}.pt"))
    # One argmax + host transfer per head instead of one per sample per head.
    pred_classes = [torch.argmax(p, dim=-1).cpu().numpy() for p in preds]
    label_values = [lab.cpu().numpy() for lab in labels]
    header = ['frame', 'm1_pred', 'm2_pred', 'm12_pred', 'm21_pred', 'mc_pred',
              'm1_label', 'm2_label', 'm12_label', 'm21_label', 'mc_label', 'false_belief']
    rows = [
        [i, *(p[i] for p in pred_classes), *(lab[i] for lab in label_values), false_belief[i]]
        for i in range(num_rows)
    ]
    with open(os.path.join(folder_path, f'{batch_idx}.csv'), mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(header)
        writer.writerows(rows)


def _write_test_stats(model_path, test_dataset, f1_scores):
    """Write ``test_stats.txt`` next to the checkpoint at ``model_path``.

    The file lists the second field of every entry in ``test_dataset.data``,
    then one ``<head> f1: <score>`` line per head. When ``model_path`` has no
    directory part, the file goes to the current directory.
    """
    stats_path = os.path.join(os.path.dirname(model_path), 'test_stats.txt')
    with open(stats_path, 'w') as f:
        f.write(f"Test data:\n {[entry[1] for entry in test_dataset.data]}\n")
        for name, score in zip(_MIND_NAMES, f1_scores):
            f.write(f"{name} f1: {score}\n")


def test(args):
    """Evaluate a trained model on the TBD test split and report per-head F1 scores.

    Steps:

    - Build the test dataloader.
    - Load the state dict at ``args.load_model_path`` into the model chosen by
      ``args.model_type``.
    - Run inference without gradients.
    - Print the F1 score of each head (m1, m2, m12, m21, mc) and write the
      scores to ``test_stats.txt`` next to the checkpoint.

    With ``args.save_preds``, each batch's representations and predictions are
    also saved under ``predictions/<checkpoint directory name>/``.

    Raises:
        NotImplementedError: unknown ``args.model_type``.
        RuntimeError: the test dataloader yields no batches.
    """
    test_dataset = TBDDataset(
        path=args.data_path,
        mode="test",
        use_preprocessed_img=True
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        collate_fn=collate_fn_test
    )
    device = _resolve_device(args.gpu_id)
    model = _build_model(args, device)
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(args.load_model_path, map_location=device))
    model.device = device
    model.eval()

    folder_path = None
    if args.save_preds:
        # Name the output folder after the directory that contains the checkpoint.
        run_name = os.path.basename(os.path.dirname(args.load_model_path))
        folder_path = os.path.join('predictions', run_name)
        os.makedirs(folder_path, exist_ok=True)
        print(f'Saving predictions in {folder_path}.')

    print('Testing...')
    pred_lists = [[] for _ in _MIND_NAMES]
    label_lists = [[] for _ in _MIND_NAMES]
    with torch.no_grad():
        # enumerate(tqdm(...)) rather than tqdm(enumerate(...)) so tqdm can see len() and show a total.
        for j, batch in enumerate(tqdm(test_dataloader)):
            (img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox,
             tracker_id, gaze, labels, exp_id, timestep, false_belief) = batch
            img_3rd_pov = _to_device(img_3rd_pov, device, args.non_blocking)
            img_tracker = _to_device(img_tracker, device, args.non_blocking)
            img_battery = _to_device(img_battery, device, args.non_blocking)
            pose1 = _to_device(pose1, device, args.non_blocking)
            pose2 = _to_device(pose2, device, args.non_blocking)
            bbox = _to_device(bbox, device, args.non_blocking)
            gaze = _to_device(gaze, device, args.non_blocking)

            m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, representations = model(
                img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze)
            # Flatten every head to rows of 4 logits (presumably 4 classes).
            preds = [p.reshape(-1, 4) for p in (m1_pred, m2_pred, m12_pred, m21_pred, mc_pred)]
            # Column k of ``labels`` holds the target for head k, in _MIND_NAMES order.
            head_labels = [labels[:, k].reshape(-1).to(device) for k in range(len(_MIND_NAMES))]

            for store, pred in zip(pred_lists, preds):
                store.append(pred)
            for store, label in zip(label_lists, head_labels):
                store.append(label)

            if folder_path is not None:
                _save_batch_predictions(folder_path, j, representations, preds, head_labels,
                                        false_belief, len(labels))

    if not pred_lists[0]:
        raise RuntimeError('The test dataloader yielded no batches; cannot compute F1 scores.')

    # compute_f1_scores expects (pred, label) pairs interleaved in head order.
    f1_args = []
    for head_preds, head_targets in zip(pred_lists, label_lists):
        f1_args.extend((torch.cat(head_preds), torch.cat(head_targets)))
    test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1 = compute_f1_scores(*f1_args)
    f1_scores = (test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1)

    for name, score in zip(_MIND_NAMES, f1_scores):
        print("Test {} F1: {}".format(name, score))
    _write_test_stats(args.load_model_path, test_dataset, f1_scores)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate a trained model on the TBD test split.')
    # Define the command-line arguments
    # Default to GPU 0 so the device string is never the invalid "cuda:None".
    parser.add_argument('--gpu_id', type=int, default=0,
                        help='CUDA device index (ignored when CUDA is unavailable).')
    parser.add_argument('--seed', type=int, default=1)
    # NOTE: --presaved, --save_path, --test_frames and --median are accepted for CLI
    # compatibility but are not read by test() in this file.
    parser.add_argument('--presaved', type=int, default=128)
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--pin_memory', action='store_true')
    parser.add_argument('--model_type', type=str, required=True,
                        choices=['tom_cm', 'tom_sl', 'tom_impl'])
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--aggr', type=str, default='concat', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
    parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
    parser.add_argument('--save_path', type=str, default='experiments/')
    parser.add_argument('--test_frames', type=str, default=None)
    parser.add_argument('--median', type=int, default=None)
    parser.add_argument('--load_model_path', type=str, required=True,
                        help='Path to the saved model state_dict to evaluate.')
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--save_preds', action='store_true')
    # Parse the command-line arguments
    args = parser.parse_args()

    if args.model_type in ('tom_cm', 'tom_impl') and not args.aggr:
        parser.error("The chosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The chosen --model_type requires --tom_weight")

    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here does not
    # change str hashing in this process; it only reaches subprocesses started later.
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    print('###########################################################################')
    print('TESTING: MAKE SURE YOU ARE USING THE SAME RANDOM SEED USED DURING TRAINING!')
    print('###########################################################################')
    test(args)