181 lines
8.1 KiB
Python
181 lines
8.1 KiB
Python
|
import argparse
|
||
|
import torch
|
||
|
from tqdm import tqdm
|
||
|
import csv
|
||
|
import os
|
||
|
from torch.utils.data import DataLoader
|
||
|
|
||
|
from dataloader import DataTest
|
||
|
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
|
||
|
from models.tom_implicit import ImplicitToMnet
|
||
|
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
|
||
|
from models.tom_sl import SLToMnet
|
||
|
from models.single_mindnet import SingleMindNet
|
||
|
from utils import tried_once, tried_twice, tried_thrice, friends, strangers, get_classification_accuracy, pad_collate, get_input_dim
|
||
|
|
||
|
|
||
|
# Model types whose dataset inputs are built with flatten_dim=2; all other
# (ResNet-based) model types use flatten_dim=1.
_TOM_MODEL_TYPES = ('tom_cm', 'tom_sl', 'tom_impl', 'tom_cm_xl', 'tom_single')

# Width of each prediction row: model outputs are reshaped to (-1, 27) before scoring.
_NUM_LABEL_CLASSES = 27


def _select_test_frame_ids(test_frames):
    """Map the ``--test_frames`` split name to the frame-id list passed to ``DataTest``.

    Returns None when ``test_frames`` is None. Raises NameError for an unknown name
    (NameError is kept, rather than ValueError, for compatibility with earlier versions).
    """
    if test_frames is None:
        return None
    if test_frames == 'friends':
        return friends
    if test_frames == 'strangers':
        return strangers
    if test_frames == 'once':
        return tried_once
    if test_frames == 'twice_thrice':
        return tried_twice + tried_thrice
    raise NameError(f"Unknown --test_frames value: {test_frames!r}")


def _build_model(args, device):
    """Instantiate the architecture named by ``args.model_type`` and move it to ``device``.

    Raises NotImplementedError for an unknown model type.
    """
    # The ResNet baselines share a constructor signature: (input_dim, device).
    resnet_variants = {
        'resnet': ResNet,
        'gru': ResNetGRU,
        'lstm': ResNetLSTM,
        'conv1d': ResNetConv1D,
    }
    model_type = args.model_type
    if model_type in resnet_variants:
        model = resnet_variants[model_type](get_input_dim(args.mods), device)
    elif model_type == 'tom_cm':
        model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    elif model_type == 'tom_impl':
        model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    elif model_type == 'tom_sl':
        model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods)
    elif model_type == 'tom_single':
        model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods)
    elif model_type == 'tom_cm_xl':
        model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    else:
        raise NotImplementedError(f"Unknown --model_type: {model_type!r}")
    return model.to(device)


def _save_batch_predictions(folder_path, batch_idx, batch_acc, representations,
                            pred_left_labels, pred_right_labels, labels):
    """Write one batch's representations (``<idx>.pt``) and per-frame predictions (``<idx>_<acc>.csv``)."""
    torch.save([r.cpu() for r in representations], os.path.join(folder_path, f"{batch_idx}.pt"))

    # Vectorized argmax over each row; torch.argmax returns the first maximal index on ties,
    # exactly like the per-row argmax it replaces. .tolist() yields plain Python numbers,
    # which the csv module renders identically to the former 0-d numpy arrays.
    left_pred = pred_left_labels.argmax(dim=1).tolist()
    right_pred = pred_right_labels.argmax(dim=1).tolist()
    label_rows = labels.tolist()
    rows = [
        (i, left_pred[i], right_pred[i], label_rows[i][0], label_rows[i][1])
        for i in range(len(label_rows))
    ]

    header = ['frame', 'left_pred', 'right_pred', 'left_label', 'right_label']
    csv_path = os.path.join(folder_path, f'{batch_idx}_{batch_acc:.2f}.csv')
    with open(csv_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(header)
        writer.writerows(rows)


def test(args):
    """Evaluate a trained checkpoint on the test split and report classification accuracy.

    Builds a ``DataTest`` dataset and a ``DataLoader`` (batch size 1, no shuffling),
    instantiates the model named by ``args.model_type``, loads its weights from
    ``args.load_model_path`` and runs inference under ``torch.no_grad()``.
    Accuracy is the sum of per-batch correct counts divided by the sum of per-batch
    prediction counts, both as returned by ``get_classification_accuracy``. It is
    printed and written to ``test_stats.txt`` in the checkpoint's directory.

    When ``args.save_preds`` is set, each batch's representations (``<j>.pt``) and
    per-frame argmax predictions alongside labels (``<j>_<acc>.csv``) are written to
    ``predictions/<checkpoint-dir-name>_<test_frames>/``.

    Raises:
        ValueError: ``args.load_model_path`` is None.
        NameError: ``args.test_frames`` is not a known split name.
        NotImplementedError: ``args.model_type`` is not a known model.
        RuntimeError: no predictions were counted (e.g. an empty test selection).
    """
    # Validate up front so a missing checkpoint is reported before any data is loaded.
    if args.load_model_path is None:
        raise ValueError('--load_model_path is required for testing.')

    test_frame_ids = _select_test_frame_ids(args.test_frames)

    # NOTE(review): only the presence of --median matters; its integer value is ignored and
    # the first element is fixed at 240. Preserved as-is; confirm whether
    # (args.median, False) was intended.
    median = (240, False) if args.median is not None else None

    flatten_dim = 2 if args.model_type in _TOM_MODEL_TYPES else 1

    # load datasets
    test_dataset = DataTest(
        args.test_frame_path,
        args.label_path,
        args.test_pose_path,
        args.test_gaze_path,
        args.test_bbox_path,
        args.ocr_graph_path,
        args.presaved,
        test_frame_ids,
        median,
        flatten_dim=flatten_dim,
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=pad_collate,
        pin_memory=args.pin_memory,
    )

    # args.gpu_id may be None (e.g. --gpu_id not given). Fall back to GPU 0 rather than
    # building the invalid device string "cuda:None".
    gpu_id = 0 if args.gpu_id is None else args.gpu_id
    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else 'cpu')

    model = _build_model(args, device)
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(args.load_model_path, map_location=device))
    model.device = device
    model.eval()

    folder_path = None
    if args.save_preds:
        # Name the output folder after the directory containing the checkpoint.
        run_name = os.path.basename(os.path.dirname(args.load_model_path))
        folder_path = f'predictions/{run_name}_{args.test_frames}'
        os.makedirs(folder_path, exist_ok=True)
        print(f'Saving predictions in {folder_path}.')

    print('Testing...')
    num_correct = 0
    cnt = 0
    with torch.no_grad():
        # Wrap the dataloader itself (not enumerate) so tqdm can show a total.
        for j, batch in enumerate(tqdm(test_dataloader)):
            frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
            # Any input modality may be None; move only the present ones to the device.
            frames, poses, gazes, bboxes, ocr_graphs = (
                None if t is None else t.to(device, non_blocking=True)
                for t in (frames, poses, gazes, bboxes, ocr_graphs)
            )

            pred_left_labels, pred_right_labels, representations = model(frames, poses, gazes, bboxes, ocr_graphs)
            pred_left_labels = torch.reshape(pred_left_labels, (-1, _NUM_LABEL_CLASSES))
            pred_right_labels = torch.reshape(pred_right_labels, (-1, _NUM_LABEL_CLASSES))
            labels = torch.reshape(labels, (-1, 2)).to(device)

            batch_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
                pred_left_labels, pred_right_labels, labels, sequence_lengths
            )
            cnt += batch_num_pred
            num_correct += batch_num_correct

            if args.save_preds:
                _save_batch_predictions(folder_path, j, batch_acc, representations,
                                        pred_left_labels, pred_right_labels, labels)

    if cnt == 0:
        raise RuntimeError('No predictions were evaluated; check the test data paths and --test_frames.')
    test_acc = num_correct / cnt
    print("Test accuracy: {}".format(test_acc))

    # Portable "same directory as the checkpoint"; a bare filename resolves to the CWD.
    stats_path = os.path.join(os.path.dirname(args.load_model_path), 'test_stats.txt')
    with open(stats_path, 'w') as f:
        f.write(str(test_acc))
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':

    # Every value accepted by test() for --model_type / --test_frames.
    model_types = ('resnet', 'gru', 'lstm', 'conv1d',
                   'tom_cm', 'tom_impl', 'tom_sl', 'tom_single', 'tom_cm_xl')
    test_frame_splits = ('friends', 'strangers', 'once', 'twice_thrice')

    parser = argparse.ArgumentParser()

    # Define the command-line arguments
    # Default to GPU 0: without a default the device string would be "cuda:None".
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--presaved', type=int, default=128)
    # NOTE(review): not read by test(); device transfers always use non_blocking=True.
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--pin_memory', action='store_true')
    parser.add_argument('--model_type', type=str, required=True, choices=model_types)
    parser.add_argument('--aggr', type=str, default='mult', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
    parser.add_argument('--test_frame_path', type=str, default='/scratch/bortoletto/data/boss/test/frame')
    parser.add_argument('--test_pose_path', type=str, default='/scratch/bortoletto/data/boss/test/pose')
    parser.add_argument('--test_gaze_path', type=str, default='/scratch/bortoletto/data/boss/test/gaze')
    parser.add_argument('--test_bbox_path', type=str, default='/scratch/bortoletto/data/boss/test/new_bbox/labels')
    parser.add_argument('--ocr_graph_path', type=str, default='')
    parser.add_argument('--label_path', type=str, default='outfile')
    # NOTE(review): not read by test(); outputs go next to the checkpoint and under predictions/.
    parser.add_argument('--save_path', type=str, default='experiments/')
    # Omitted -> None, which test() passes through to DataTest as the frame-id selection.
    parser.add_argument('--test_frames', type=str, default=None, choices=test_frame_splits)
    # Only presence matters to test(): any value enables the (240, False) median setting.
    parser.add_argument('--median', type=int, default=None)
    parser.add_argument('--load_model_path', type=str, required=True)
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--save_preds', action='store_true')

    # Parse the command-line arguments
    args = parser.parse_args()

    # All models constructed with args.aggr need a non-empty aggregation setting.
    if args.model_type in ('tom_cm', 'tom_impl', 'tom_cm_xl') and not args.aggr:
        parser.error("The chosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The chosen --model_type requires --tom_weight")

    test(args)
|