mtomnet/boss/test.py

181 lines
8.1 KiB
Python
Raw Normal View History

2025-01-10 15:39:20 +01:00
import argparse
import torch
from tqdm import tqdm
import csv
import os
from torch.utils.data import DataLoader
from dataloader import DataTest
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
from models.tom_implicit import ImplicitToMnet
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
from models.tom_sl import SLToMnet
from models.single_mindnet import SingleMindNet
from utils import tried_once, tried_twice, tried_thrice, friends, strangers, get_classification_accuracy, pad_collate, get_input_dim
def test(args):
if args.test_frames == 'friends':
test_frame_ids = friends
elif args.test_frames == 'strangers':
test_frame_ids = strangers
elif args.test_frames == 'once':
test_frame_ids = tried_once
elif args.test_frames == 'twice_thrice':
test_frame_ids = tried_twice + tried_thrice
elif args.test_frames is None:
test_frame_ids = None
else: raise NameError
if args.median is not None:
median = (240, False)
else:
median = None
if args.model_type == 'tom_cm' or args.model_type == 'tom_sl' or args.model_type == 'tom_impl' or args.model_type == 'tom_cm_xl' or args.model_type == 'tom_single':
flatten_dim = 2
else:
flatten_dim = 1
# load datasets
test_dataset = DataTest(
args.test_frame_path,
args.label_path,
args.test_pose_path,
args.test_gaze_path,
args.test_bbox_path,
args.ocr_graph_path,
args.presaved,
test_frame_ids,
median,
flatten_dim=flatten_dim
)
test_dataloader = DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=args.num_workers,
collate_fn=pad_collate,
pin_memory=args.pin_memory
)
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
assert args.load_model_path is not None
if args.model_type == 'resnet':
inp_dim = get_input_dim(args.mods)
model = ResNet(inp_dim, device).to(device)
elif args.model_type == 'gru':
inp_dim = get_input_dim(args.mods)
model = ResNetGRU(inp_dim, device).to(device)
elif args.model_type == 'lstm':
inp_dim = get_input_dim(args.mods)
model = ResNetLSTM(inp_dim, device).to(device)
elif args.model_type == 'conv1d':
inp_dim = get_input_dim(args.mods)
model = ResNetConv1D(inp_dim, device).to(device)
elif args.model_type == 'tom_cm':
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_impl':
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
elif args.model_type == 'tom_sl':
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_single':
model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
elif args.model_type == 'tom_cm_xl':
model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
else: raise NotImplementedError
model.load_state_dict(torch.load(args.load_model_path, map_location=device))
model.device = device
model.eval()
if args.save_preds:
# Define the output file path
folder_path = f'predictions/{os.path.dirname(args.load_model_path).split(os.path.sep)[-1]}_{args.test_frames}'
if not os.path.exists(folder_path):
os.makedirs(folder_path)
print(f'Saving predictions in {folder_path}.')
print('Testing...')
num_correct = 0
cnt = 0
with torch.no_grad():
for j, batch in tqdm(enumerate(test_dataloader)):
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
if frames is not None: frames = frames.to(device, non_blocking=True)
if poses is not None: poses = poses.to(device, non_blocking=True)
if gazes is not None: gazes = gazes.to(device, non_blocking=True)
if bboxes is not None: bboxes = bboxes.to(device, non_blocking=True)
if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=True)
pred_left_labels, pred_right_labels, repr = model(frames, poses, gazes, bboxes, ocr_graphs)
pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
labels = torch.reshape(labels, (-1, 2)).to(device)
batch_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
pred_left_labels, pred_right_labels, labels, sequence_lengths
)
cnt += batch_num_pred
num_correct += batch_num_correct
if args.save_preds:
torch.save([r.cpu() for r in repr], os.path.join(folder_path, f"{j}.pt"))
data = [(
i,
torch.argmax(pred_left_labels[i]).cpu().numpy(),
torch.argmax(pred_right_labels[i]).cpu().numpy(),
labels[:, 0][i].cpu().numpy(),
labels[:, 1][i].cpu().numpy()) for i in range(len(labels))
]
header = ['frame', 'left_pred', 'right_pred', 'left_label', 'right_label']
with open(os.path.join(folder_path, f'{j}_{batch_acc:.2f}.csv'), mode='w', newline='') as file:
writer = csv.writer(file)
writer.writerow(header) # Write the header row
writer.writerows(data) # Write the data rows
test_acc = num_correct / cnt
print("Test accuracy: {}".format(num_correct / cnt))
with open(args.load_model_path.rsplit('/', 1)[0]+'/test_stats.txt', 'w') as f:
f.write(str(test_acc))
f.close()
if __name__ == '__main__':
    # Command-line entry point: parse the options, validate the combinations
    # that depend on the chosen model, then run the evaluation.
    parser = argparse.ArgumentParser()

    # Define the command-line arguments
    # Default to GPU 0: without a default, "cuda:None" is built when CUDA is
    # available and --gpu_id is omitted, which is not a valid device string.
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--presaved', type=int, default=128)
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--pin_memory', action='store_true')
    # Required and restricted to the models test() can build, so a missing or
    # misspelled value is rejected here instead of after the dataset is loaded.
    parser.add_argument(
        '--model_type',
        type=str,
        required=True,
        choices=[
            'resnet', 'gru', 'lstm', 'conv1d',
            'tom_cm', 'tom_impl', 'tom_sl', 'tom_single', 'tom_cm_xl',
        ],
    )
    parser.add_argument('--aggr', type=str, default='mult', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
    parser.add_argument('--test_frame_path', type=str, default='/scratch/bortoletto/data/boss/test/frame')
    parser.add_argument('--test_pose_path', type=str, default='/scratch/bortoletto/data/boss/test/pose')
    parser.add_argument('--test_gaze_path', type=str, default='/scratch/bortoletto/data/boss/test/gaze')
    parser.add_argument('--test_bbox_path', type=str, default='/scratch/bortoletto/data/boss/test/new_bbox/labels')
    parser.add_argument('--ocr_graph_path', type=str, default='')
    parser.add_argument('--label_path', type=str, default='outfile')
    parser.add_argument('--save_path', type=str, default='experiments/')
    # Omitting the option (None) keeps the dataset unfiltered; any other value
    # must be one of the subsets test() knows how to resolve.
    parser.add_argument(
        '--test_frames',
        type=str,
        default=None,
        choices=['friends', 'strangers', 'once', 'twice_thrice'],
    )
    parser.add_argument('--median', type=int, default=None)
    # test() cannot run without a checkpoint, so require it up front.
    parser.add_argument('--load_model_path', type=str, required=True)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--save_preds', action='store_true')

    # Parse the command-line arguments
    args = parser.parse_args()

    # Every model that receives args.aggr in test() needs a non-empty value;
    # 'tom_cm_xl' is included because test() passes args.aggr to it as well.
    if args.model_type in ('tom_cm', 'tom_impl', 'tom_cm_xl'):
        if not args.aggr:
            parser.error("The choosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The choosen --model_type requires --tom_weight")

    test(args)