# mtomnet/boss/test.py
# 2025-01-10 15:39:20 +01:00
#
# 181 lines
# No EOL
# 8.1 KiB
# Python

import argparse
import torch
from tqdm import tqdm
import csv
import os
from torch.utils.data import DataLoader
from dataloader import DataTest
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
from models.tom_implicit import ImplicitToMnet
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
from models.tom_sl import SLToMnet
from models.single_mindnet import SingleMindNet
from utils import tried_once, tried_twice, tried_thrice, friends, strangers, get_classification_accuracy, pad_collate, get_input_dim
def _resolve_test_frame_ids(test_frames):
    """Return the frame ids for the requested test split.

    Args:
        test_frames: ``'friends'``, ``'strangers'``, ``'once'``,
            ``'twice_thrice'`` or ``None``.

    Returns:
        The matching id collection imported from ``utils``. For
        ``'twice_thrice'`` this is ``tried_twice + tried_thrice``. ``None`` is
        returned unchanged and passed through to ``DataTest`` (presumably
        meaning "no filtering" -- confirm against DataTest).

    Raises:
        NameError: If ``test_frames`` is not one of the names above. The
            exception type is kept from the original implementation.
    """
    if test_frames is None:
        return None
    if test_frames == 'friends':
        return friends
    if test_frames == 'strangers':
        return strangers
    if test_frames == 'once':
        return tried_once
    if test_frames == 'twice_thrice':
        return tried_twice + tried_thrice
    raise NameError(f'Unknown --test_frames value: {test_frames!r}')


def _build_model(args, device):
    """Instantiate the network named by ``args.model_type`` and move it to ``device``.

    The ResNet baselines are built from ``get_input_dim(args.mods)``. The ToM
    variants are built from ``hidden_dim``, ``use_resnet``, ``dropout`` and
    ``mods``, plus ``aggr`` (tom_cm / tom_impl / tom_cm_xl) or ``tom_weight``
    (tom_sl).

    Raises:
        NotImplementedError: If ``args.model_type`` is not a known model type.
    """
    model_type = args.model_type
    # All baselines share the constructor signature (input_dim, device).
    baselines = {
        'resnet': ResNet,
        'gru': ResNetGRU,
        'lstm': ResNetLSTM,
        'conv1d': ResNetConv1D,
    }
    if model_type in baselines:
        model = baselines[model_type](get_input_dim(args.mods), device)
    elif model_type == 'tom_cm':
        model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    elif model_type == 'tom_impl':
        model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    elif model_type == 'tom_sl':
        model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods)
    elif model_type == 'tom_single':
        model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods)
    elif model_type == 'tom_cm_xl':
        model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    else:
        raise NotImplementedError(f'Unsupported --model_type: {model_type!r}')
    return model.to(device)


def _save_batch_predictions(folder_path, j, batch_acc, reprs, pred_left_labels, pred_right_labels, labels):
    """Dump the representations and per-frame predictions of sample ``j``.

    Writes two files into ``folder_path``:

    - ``{j}.pt``: the model's representation tensors, moved to CPU.
    - ``{j}_{batch_acc:.2f}.csv``: one row per frame, containing the frame
      index, the argmax of the left/right score rows, and the left/right labels.
    """
    torch.save([r.cpu() for r in reprs], os.path.join(folder_path, f"{j}.pt"))
    # Do one device->host copy per tensor, rather than four per frame.
    left_preds = torch.argmax(pred_left_labels, dim=1).cpu()
    right_preds = torch.argmax(pred_right_labels, dim=1).cpu()
    labels_cpu = labels.cpu()
    rows = [
        (
            i,
            left_preds[i].numpy(),
            right_preds[i].numpy(),
            labels_cpu[i, 0].numpy(),
            labels_cpu[i, 1].numpy(),
        )
        for i in range(len(labels_cpu))
    ]
    header = ['frame', 'left_pred', 'right_pred', 'left_label', 'right_label']
    with open(os.path.join(folder_path, f'{j}_{batch_acc:.2f}.csv'), mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(header)
        writer.writerows(rows)


def test(args):
    """Evaluate a trained checkpoint on the BOSS test set and report accuracy.

    Steps:

    1. Build a ``DataTest`` dataset and a ``DataLoader`` over it (batch size 1,
       no shuffling).
    2. Construct the model selected by ``args.model_type``.
    3. Load the weights from ``args.load_model_path``.
    4. Run inference under ``torch.no_grad()``.

    Accuracy is accumulated via ``get_classification_accuracy``, printed, and
    written to ``test_stats.txt`` in the checkpoint's directory. When
    ``args.save_preds`` is set, per-sample representations and prediction
    CSVs are written under ``predictions/<checkpoint dir name>_<test_frames>/``.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).

    Raises:
        AssertionError: If ``args.load_model_path`` is None.
        NameError: If ``args.test_frames`` is not a known split.
        NotImplementedError: If ``args.model_type`` is not a known model.
        RuntimeError: If no predictions were evaluated (empty test set).
    """
    # Fail fast, before the potentially slow dataset construction.
    assert args.load_model_path is not None, '--load_model_path is required'
    test_frame_ids = _resolve_test_frame_ids(args.test_frames)
    # NOTE(review): only the presence of --median matters. Its integer value
    # is ignored and a fixed (240, False) is handed to DataTest. Confirm intent.
    median = (240, False) if args.median is not None else None
    # The ToM models use flatten_dim=2; the ResNet baselines use flatten_dim=1.
    flatten_dim = 2 if args.model_type in ('tom_cm', 'tom_sl', 'tom_impl', 'tom_cm_xl', 'tom_single') else 1

    # Load the dataset.
    test_dataset = DataTest(
        args.test_frame_path,
        args.label_path,
        args.test_pose_path,
        args.test_gaze_path,
        args.test_bbox_path,
        args.ocr_graph_path,
        args.presaved,
        test_frame_ids,
        median,
        flatten_dim=flatten_dim
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=pad_collate,
        pin_memory=args.pin_memory
    )

    # 'cuda:None' is not a valid device string, so fall back to the current
    # CUDA device when --gpu_id is omitted.
    if torch.cuda.is_available():
        device = torch.device('cuda' if args.gpu_id is None else f'cuda:{args.gpu_id}')
    else:
        device = torch.device('cpu')

    model = _build_model(args, device)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    model.load_state_dict(torch.load(args.load_model_path, map_location=device))
    model.device = device
    model.eval()

    if args.save_preds:
        checkpoint_dir = os.path.basename(os.path.dirname(args.load_model_path))
        folder_path = os.path.join('predictions', f'{checkpoint_dir}_{args.test_frames}')
        os.makedirs(folder_path, exist_ok=True)
        print(f'Saving predictions in {folder_path}.')

    print('Testing...')
    num_correct = 0
    cnt = 0
    with torch.no_grad():
        # Wrap the loader (not enumerate) so tqdm can show the total length.
        for j, batch in enumerate(tqdm(test_dataloader)):
            frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
            # Modalities may be None; those are passed to the model untouched.
            frames, poses, gazes, bboxes, ocr_graphs = (
                None if t is None else t.to(device, non_blocking=True)
                for t in (frames, poses, gazes, bboxes, ocr_graphs)
            )
            pred_left_labels, pred_right_labels, reprs = model(frames, poses, gazes, bboxes, ocr_graphs)
            # Flatten to one row of 27 scores per frame. Each labels row holds
            # (left, right) for one frame.
            pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
            pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
            labels = torch.reshape(labels, (-1, 2)).to(device)
            batch_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
                pred_left_labels, pred_right_labels, labels, sequence_lengths
            )
            cnt += batch_num_pred
            num_correct += batch_num_correct
            if args.save_preds:
                _save_batch_predictions(
                    folder_path, j, batch_acc, reprs, pred_left_labels, pred_right_labels, labels
                )

    if cnt == 0:
        raise RuntimeError('No predictions were evaluated; is the test set empty?')
    test_acc = num_correct / cnt
    print("Test accuracy: {}".format(test_acc))
    # os.path also handles a checkpoint path without a directory component.
    stats_path = os.path.join(os.path.dirname(args.load_model_path), 'test_stats.txt')
    with open(stats_path, 'w') as f:
        f.write(str(test_acc))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Hardware / data-loading options.
    parser.add_argument('--gpu_id', type=int)
    parser.add_argument('--presaved', type=int, default=128)
    # NOTE: --non_blocking is accepted but not read by test(); host->device
    # copies there always pass non_blocking=True.
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--pin_memory', action='store_true')
    # Model selection and hyper-parameters (these should match the checkpoint).
    parser.add_argument(
        '--model_type', type=str, required=True,
        choices=['resnet', 'gru', 'lstm', 'conv1d', 'tom_cm', 'tom_impl', 'tom_sl', 'tom_single', 'tom_cm_xl']
    )
    parser.add_argument('--aggr', type=str, default='mult', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
    # Data locations.
    parser.add_argument('--test_frame_path', type=str, default='/scratch/bortoletto/data/boss/test/frame')
    parser.add_argument('--test_pose_path', type=str, default='/scratch/bortoletto/data/boss/test/pose')
    parser.add_argument('--test_gaze_path', type=str, default='/scratch/bortoletto/data/boss/test/gaze')
    parser.add_argument('--test_bbox_path', type=str, default='/scratch/bortoletto/data/boss/test/new_bbox/labels')
    parser.add_argument('--ocr_graph_path', type=str, default='')
    parser.add_argument('--label_path', type=str, default='outfile')
    # NOTE: --save_path is accepted but not read by test().
    parser.add_argument('--save_path', type=str, default='experiments/')
    parser.add_argument(
        '--test_frames', type=str, default=None,
        choices=['friends', 'strangers', 'once', 'twice_thrice']
    )
    parser.add_argument('--median', type=int, default=None)
    parser.add_argument('--load_model_path', type=str, required=True)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--save_preds', action='store_true')
    # Parse the command-line arguments.
    args = parser.parse_args()
    # Every model constructed with args.aggr needs a non-empty --aggr.
    if args.model_type in ('tom_cm', 'tom_impl', 'tom_cm_xl') and not args.aggr:
        parser.error("The chosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The chosen --model_type requires --tom_weight")
    test(args)