196 lines
8.5 KiB
Python
196 lines
8.5 KiB
Python
|
import torch
|
||
|
import csv
|
||
|
import argparse
|
||
|
from tqdm import tqdm
|
||
|
from torch.utils.data import DataLoader
|
||
|
import random
|
||
|
import os
|
||
|
import numpy as np
|
||
|
|
||
|
from tbd_dataloader import TBDDataset, collate_fn_test
|
||
|
from models.common_mind import CommonMindToMnet
|
||
|
from models.sl import SLToMnet
|
||
|
from models.implicit import ImplicitToMnet
|
||
|
from utils.helpers import compute_f1_scores
|
||
|
|
||
|
|
||
|
def test(args):
    """Evaluate a trained ToM model on the TBD test split and report per-head F1.

    Builds the model selected by ``args.model_type``, loads the weights stored at
    ``args.load_model_path`` and runs it over the test split without gradients.
    One F1 score per prediction head (m1, m2, m12, m21, mc) is printed and
    written, one per line, to ``test_stats.txt`` in the checkpoint's directory.

    When ``args.save_preds`` is set, for every batch ``j`` it also writes, under
    ``predictions/<name of the checkpoint's parent directory>/``:
      * ``<j>.pt``  - the model's sixth return value moved to CPU
        (presumably learned representations - TODO confirm), and
      * ``<j>.csv`` - per row: predicted class and label for each head, plus the
        batch's ``false_belief`` entry.

    Args:
        args: argparse namespace produced by this script's command-line parser.

    Raises:
        NotImplementedError: if ``args.model_type`` is not ``'tom_cm'``,
            ``'tom_sl'`` or ``'tom_impl'``.
        RuntimeError: if the test dataloader yields no batches.
    """
    # Prediction heads, in the order the model returns them; label column k belongs to heads[k].
    heads = ('m1', 'm2', 'm12', 'm21', 'mc')

    test_dataset = TBDDataset(
        path=args.data_path,
        mode="test",
        use_preprocessed_img=True
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        collate_fn=collate_fn_test
    )

    if torch.cuda.is_available():
        # --gpu_id may be omitted (None); "cuda:None" is not a valid device
        # string, so fall back to the current CUDA device in that case.
        device = torch.device('cuda' if args.gpu_id is None else "cuda:{}".format(args.gpu_id))
    else:
        device = torch.device('cpu')

    # model
    if args.model_type == 'tom_cm':
        model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    elif args.model_type == 'tom_sl':
        model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods)
    elif args.model_type == 'tom_impl':
        model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    else:
        raise NotImplementedError(f"Unknown --model_type: {args.model_type!r}")
    model = model.to(device)
    model.load_state_dict(torch.load(args.load_model_path, map_location=device))
    model.device = device
    model.eval()

    if args.save_preds:
        # Predictions are grouped by the name of the checkpoint's parent directory.
        folder_path = f'predictions/{os.path.basename(os.path.dirname(args.load_model_path))}'
        os.makedirs(folder_path, exist_ok=True)
        print(f'Saving predictions in {folder_path}.')

    def to_device(t):
        # Inputs may be None (presumably modalities absent from the batch); pass those through.
        return None if t is None else t.to(device, non_blocking=args.non_blocking)

    print('Testing...')
    pred_lists = {h: [] for h in heads}
    label_lists = {h: [] for h in heads}
    with torch.no_grad():
        # Wrap the loader itself (not enumerate) so tqdm can read its length and show a total.
        for j, batch in enumerate(tqdm(test_dataloader)):
            (img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox,
             tracker_id, gaze, labels, exp_id, timestep, false_belief) = batch
            img_3rd_pov = to_device(img_3rd_pov)
            img_tracker = to_device(img_tracker)
            img_battery = to_device(img_battery)
            pose1 = to_device(pose1)
            pose2 = to_device(pose2)
            bbox = to_device(bbox)
            gaze = to_device(gaze)

            m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, reprs = model(
                img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze
            )
            # Flatten every head to rows of 4 class scores and pair it with its label column.
            preds = {
                h: p.reshape(-1, 4)
                for h, p in zip(heads, (m1_pred, m2_pred, m12_pred, m21_pred, mc_pred))
            }
            batch_labels = {h: labels[:, k].reshape(-1).to(device) for k, h in enumerate(heads)}
            for h in heads:
                pred_lists[h].append(preds[h])
                label_lists[h].append(batch_labels[h])

            if args.save_preds:
                torch.save([r.cpu() for r in reprs], os.path.join(folder_path, f"{j}.pt"))
                # Arg-max once per head (one device->host copy) instead of once per row.
                pred_cls = {h: preds[h].argmax(dim=-1).cpu().numpy() for h in heads}
                true_cls = {h: batch_labels[h].cpu().numpy() for h in heads}
                data = [
                    (i,
                     *(pred_cls[h][i] for h in heads),
                     *(true_cls[h][i] for h in heads),
                     false_belief[i])
                    for i in range(len(labels))
                ]
                header = ['frame', 'm1_pred', 'm2_pred', 'm12_pred', 'm21_pred', 'mc_pred', 'm1_label', 'm2_label', 'm12_label', 'm21_label', 'mc_label', 'false_belief']
                with open(os.path.join(folder_path, f'{j}.csv'), mode='w', newline='') as file:
                    writer = csv.writer(file)
                    writer.writerow(header)  # Write the header row
                    writer.writerows(data)  # Write the data rows

    if not pred_lists[heads[0]]:
        raise RuntimeError('The test dataloader yielded no batches; nothing to evaluate.')

    # compute_f1_scores takes (pred, label) pairs interleaved in head order.
    f1_inputs = []
    for h in heads:
        f1_inputs.extend((torch.cat(pred_lists[h]), torch.cat(label_lists[h])))
    test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1 = compute_f1_scores(*f1_inputs)
    f1_scores = (test_m1_f1, test_m2_f1, test_m12_f1, test_m21_f1, test_mc_f1)

    for h, score in zip(heads, f1_scores):
        print("Test {} F1: {}".format(h, score))

    # os.path handles checkpoint paths with no directory part (writes to the CWD).
    stats_path = os.path.join(os.path.dirname(args.load_model_path), 'test_stats.txt')
    with open(stats_path, 'w') as f:
        f.write(f"Test data:\n {[d[1] for d in test_dataset.data]}\n")
        for h, score in zip(heads, f1_scores):
            f.write(f"{h} f1: {score}\n")
|
||
|
|
||
|
|
||
|
def build_parser():
    """Build the command-line parser for this test script.

    Returns:
        argparse.ArgumentParser: parser with every option the script accepts.
        ``--model_type`` and ``--load_model_path`` are required, because the
        script cannot run without them. ``--presaved``, ``--save_path``,
        ``--test_frames`` and ``--median`` are not read anywhere in this script.
        They are presumably kept so that training command lines can be reused
        as-is (TODO confirm).
    """
    parser = argparse.ArgumentParser()

    # Define the command-line arguments
    parser.add_argument('--gpu_id', type=int)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--presaved', type=int, default=128)
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--num_workers', type=int, default=16)
    parser.add_argument('--pin_memory', action='store_true')
    parser.add_argument('--model_type', type=str, required=True, choices=['tom_cm', 'tom_sl', 'tom_impl'])
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--aggr', type=str, default='concat', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
    parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
    parser.add_argument('--save_path', type=str, default='experiments/')
    parser.add_argument('--test_frames', type=str, default=None)
    parser.add_argument('--median', type=int, default=None)
    parser.add_argument('--load_model_path', type=str, required=True)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--save_preds', action='store_true')
    return parser


if __name__ == '__main__':
    parser = build_parser()

    # Parse the command-line arguments
    args = parser.parse_args()

    # Model-specific options; parser.error() prints usage and exits with status 2.
    if args.model_type in ('tom_cm', 'tom_impl') and not args.aggr:
        parser.error("The chosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The chosen --model_type requires --tom_weight")

    # NOTE: setting PYTHONHASHSEED at runtime does not change this interpreter's
    # hash seed (fixed at startup); it only affects subprocesses started later.
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    print('###########################################################################')
    print('TESTING: MAKE SURE YOU ARE USING THE SAME RANDOM SEED USED DURING TRAINING!')
    print('###########################################################################')

    test(args)
|