# Training script for BOSS belief-classification models (ToM and baseline variants).
import torch
|
|
import os
|
|
import argparse
|
|
import numpy as np
|
|
import random
|
|
import datetime
|
|
import wandb
|
|
from tqdm import tqdm
|
|
from torch.utils.data import DataLoader
|
|
import torch.nn as nn
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
|
|
from dataloader import Data
|
|
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
|
|
from models.tom_implicit import ImplicitToMnet
|
|
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
|
|
from models.tom_sl import SLToMnet
|
|
from models.tom_tf import TFToMnet
|
|
from models.single_mindnet import SingleMindNet
|
|
from utils import pad_collate, get_classification_accuracy, mixup, get_classification_accuracy_mixup, count_parameters, get_input_dim
|
|
|
|
|
|
def train(args):
|
|
|
|
if args.model_type == 'tom_cm' or args.model_type == 'tom_sl' or args.model_type == 'tom_impl' or args.model_type == 'tom_tf' or args.model_type == 'tom_cm_xl' or args.model_type == 'tom_single':
|
|
flatten_dim = 2
|
|
else:
|
|
flatten_dim = 1
|
|
|
|
train_dataset = Data(
|
|
args.train_frame_path,
|
|
args.label_path,
|
|
args.train_pose_path,
|
|
args.train_gaze_path,
|
|
args.train_bbox_path,
|
|
args.ocr_graph_path,
|
|
presaved=args.presaved,
|
|
flatten_dim=flatten_dim
|
|
)
|
|
train_dataloader = DataLoader(
|
|
train_dataset,
|
|
batch_size=args.batch_size,
|
|
shuffle=True,
|
|
num_workers=args.num_workers,
|
|
collate_fn=pad_collate,
|
|
pin_memory=args.pin_memory
|
|
)
|
|
val_dataset = Data(
|
|
args.val_frame_path,
|
|
args.label_path,
|
|
args.val_pose_path,
|
|
args.val_gaze_path,
|
|
args.val_bbox_path,
|
|
args.ocr_graph_path,
|
|
presaved=args.presaved,
|
|
flatten_dim=flatten_dim
|
|
)
|
|
val_dataloader = DataLoader(
|
|
val_dataset,
|
|
batch_size=args.batch_size,
|
|
shuffle=False,
|
|
num_workers=args.num_workers,
|
|
collate_fn=pad_collate,
|
|
pin_memory=args.pin_memory
|
|
)
|
|
device = torch.device("cuda:{}".format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
|
|
if args.model_type == 'resnet':
|
|
inp_dim = get_input_dim(args.mods)
|
|
model = ResNet(inp_dim, device).to(device)
|
|
elif args.model_type == 'gru':
|
|
inp_dim = get_input_dim(args.mods)
|
|
model = ResNetGRU(inp_dim, device).to(device)
|
|
elif args.model_type == 'lstm':
|
|
inp_dim = get_input_dim(args.mods)
|
|
model = ResNetLSTM(inp_dim, device).to(device)
|
|
elif args.model_type == 'conv1d':
|
|
inp_dim = get_input_dim(args.mods)
|
|
model = ResNetConv1D(inp_dim, device).to(device)
|
|
elif args.model_type == 'tom_cm':
|
|
model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
|
elif args.model_type == 'tom_cm_xl':
|
|
model = CommonMindToMnetXL(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
|
elif args.model_type == 'tom_impl':
|
|
model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods).to(device)
|
|
elif args.model_type == 'tom_sl':
|
|
model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods).to(device)
|
|
elif args.model_type == 'tom_single':
|
|
model = SingleMindNet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
|
|
elif args.model_type == 'tom_tf':
|
|
model = TFToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.mods).to(device)
|
|
else: raise NotImplementedError
|
|
if args.resume_from_checkpoint is not None:
|
|
model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
|
if args.scheduler == None:
|
|
scheduler = None
|
|
else:
|
|
scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)
|
|
if args.model_type == 'tom_sl': cross_entropy_loss = nn.NLLLoss(ignore_index=-1)
|
|
else: cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing, ignore_index=-1).to(device)
|
|
stats = {'train': {'cls_loss': [], 'cls_acc': []}, 'val': {'cls_loss': [], 'cls_acc': []}}
|
|
|
|
max_val_classification_acc = 0
|
|
max_val_classification_epoch = None
|
|
counter = 0
|
|
|
|
print(f'Number of parameters: {count_parameters(model)}')
|
|
|
|
for i in range(args.num_epoch):
|
|
# training
|
|
print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
|
|
temp_train_classification_loss = []
|
|
epoch_num_correct = 0
|
|
epoch_cnt = 0
|
|
model.train()
|
|
for j, batch in tqdm(enumerate(train_dataloader)):
|
|
if args.use_mixup:
|
|
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = mixup(batch, args.mixup_alpha, 27)
|
|
else:
|
|
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
|
|
if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking)
|
|
if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking)
|
|
if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking)
|
|
if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking)
|
|
if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking)
|
|
pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs)
|
|
pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
|
|
pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
|
|
if args.use_mixup:
|
|
labels = torch.reshape(labels, (-1, 54)).to(device)
|
|
batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy_mixup(
|
|
pred_left_labels, pred_right_labels, labels, sequence_lengths
|
|
)
|
|
loss = cross_entropy_loss(pred_left_labels, labels[:, :27]) + cross_entropy_loss(pred_right_labels, labels[:, 27:])
|
|
else:
|
|
labels = torch.reshape(labels, (-1, 2)).to(device)
|
|
batch_train_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
|
|
pred_left_labels, pred_right_labels, labels, sequence_lengths
|
|
)
|
|
loss = cross_entropy_loss(pred_left_labels, labels[:, 0]) + cross_entropy_loss(pred_right_labels, labels[:, 1])
|
|
epoch_cnt += batch_num_pred
|
|
epoch_num_correct += batch_num_correct
|
|
temp_train_classification_loss.append(loss.data.item() * batch_num_pred / 2)
|
|
|
|
optimizer.zero_grad()
|
|
if args.clip_grad_norm is not None:
|
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
|
|
loss.backward()
|
|
optimizer.step()
|
|
|
|
if args.logger: wandb.log({'batch_train_acc': batch_train_acc, 'batch_train_loss': loss.data.item(), 'lr': optimizer.param_groups[-1]['lr']})
|
|
print("Epoch {}/{} batch {}/{} training done with cls loss={}, cls accuracy={}.".format(
|
|
i+1, args.num_epoch, j+1, len(train_dataloader), loss.data.item(), batch_train_acc)
|
|
)
|
|
|
|
if scheduler: scheduler.step()
|
|
|
|
print("Epoch {}/{} OVERALL train cls loss={}, cls accuracy={}.\n".format(
|
|
i+1, args.num_epoch, sum(temp_train_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt)
|
|
)
|
|
stats['train']['cls_loss'].append(sum(temp_train_classification_loss) * 2 / epoch_cnt)
|
|
stats['train']['cls_acc'].append(epoch_num_correct / epoch_cnt)
|
|
|
|
# validation
|
|
print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
|
|
temp_val_classification_loss = []
|
|
epoch_num_correct = 0
|
|
epoch_cnt = 0
|
|
model.eval()
|
|
with torch.no_grad():
|
|
for j, batch in tqdm(enumerate(val_dataloader)):
|
|
frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
|
|
if frames is not None: frames = frames.to(device, non_blocking=args.non_blocking)
|
|
if poses is not None: poses = poses.to(device, non_blocking=args.non_blocking)
|
|
if gazes is not None: gazes = gazes.to(device, non_blocking=args.non_blocking)
|
|
if bboxes is not None: bboxes = bboxes.to(device, non_blocking=args.non_blocking)
|
|
if ocr_graphs is not None: ocr_graphs = ocr_graphs.to(device, non_blocking=args.non_blocking)
|
|
pred_left_labels, pred_right_labels, _ = model(frames, poses, gazes, bboxes, ocr_graphs)
|
|
pred_left_labels = torch.reshape(pred_left_labels, (-1, 27))
|
|
pred_right_labels = torch.reshape(pred_right_labels, (-1, 27))
|
|
labels = torch.reshape(labels, (-1, 2)).to(device)
|
|
batch_val_acc, batch_num_correct, batch_num_pred = get_classification_accuracy(
|
|
pred_left_labels, pred_right_labels, labels, sequence_lengths
|
|
)
|
|
epoch_cnt += batch_num_pred
|
|
epoch_num_correct += batch_num_correct
|
|
loss = cross_entropy_loss(pred_left_labels, labels[:,0]) + cross_entropy_loss(pred_right_labels, labels[:,1])
|
|
temp_val_classification_loss.append(loss.data.item() * batch_num_pred / 2)
|
|
|
|
if args.logger: wandb.log({'batch_val_acc': batch_val_acc, 'batch_val_loss': loss.data.item()})
|
|
print("Epoch {}/{} batch {}/{} validation done with cls loss={}, cls accuracy={}.".format(
|
|
i+1, args.num_epoch, j+1, len(val_dataloader), loss.data.item(), batch_val_acc)
|
|
)
|
|
|
|
print("Epoch {}/{} OVERALL validation cls loss={}, cls accuracy={}.\n".format(
|
|
i+1, args.num_epoch, sum(temp_val_classification_loss) * 2 / epoch_cnt, epoch_num_correct / epoch_cnt)
|
|
)
|
|
|
|
cls_loss = sum(temp_val_classification_loss) * 2 / epoch_cnt
|
|
cls_acc = epoch_num_correct / epoch_cnt
|
|
stats['val']['cls_loss'].append(cls_loss)
|
|
stats['val']['cls_acc'].append(cls_acc)
|
|
if args.logger: wandb.log({'cls_loss': cls_loss, 'cls_acc': cls_acc, 'epoch': i})
|
|
|
|
# check for best stat/model using validation accuracy
|
|
if stats['val']['cls_acc'][-1] >= max_val_classification_acc:
|
|
max_val_classification_acc = stats['val']['cls_acc'][-1]
|
|
max_val_classification_epoch = i+1
|
|
torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
|
|
counter = 0
|
|
else:
|
|
counter += 1
|
|
print(f'EarlyStopping counter: {counter} out of {args.patience}.')
|
|
if counter >= args.patience:
|
|
break
|
|
|
|
with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
|
|
f.write('{}\n'.format(CFG))
|
|
f.write('{}\n'.format(stats))
|
|
f.write('Max val classification acc: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_classification_acc))
|
|
f.close()
|
|
|
|
print(f'Results saved in {experiment_save_path}')
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
# Define the command-line arguments
|
|
parser.add_argument('--gpu_id', type=int)
|
|
parser.add_argument('--seed', type=int, default=1)
|
|
parser.add_argument('--logger', action='store_true')
|
|
parser.add_argument('--presaved', type=int, default=128)
|
|
parser.add_argument('--clip_grad_norm', type=float, default=0.5)
|
|
parser.add_argument('--use_mixup', action='store_true')
|
|
parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
|
|
parser.add_argument('--non_blocking', action='store_true')
|
|
parser.add_argument('--patience', type=int, default=99)
|
|
parser.add_argument('--batch_size', type=int, default=4)
|
|
parser.add_argument('--num_workers', type=int, default=4)
|
|
parser.add_argument('--pin_memory', action='store_true')
|
|
parser.add_argument('--num_epoch', type=int, default=300)
|
|
parser.add_argument('--lr', type=float, default=4e-4)
|
|
parser.add_argument('--scheduler', type=str, default=None)
|
|
parser.add_argument('--dropout', type=float, default=0.1)
|
|
parser.add_argument('--weight_decay', type=float, default=0.005)
|
|
parser.add_argument('--label_smoothing', type=float, default=0.1)
|
|
parser.add_argument('--model_type', type=str)
|
|
parser.add_argument('--aggr', type=str, default='mult', required=False)
|
|
parser.add_argument('--use_resnet', action='store_true')
|
|
parser.add_argument('--hidden_dim', type=int, default=64)
|
|
parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
|
|
parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
|
|
parser.add_argument('--train_frame_path', type=str, default='/scratch/bortoletto/data/boss/train/frame')
|
|
parser.add_argument('--train_pose_path', type=str, default='/scratch/bortoletto/data/boss/train/pose')
|
|
parser.add_argument('--train_gaze_path', type=str, default='/scratch/bortoletto/data/boss/train/gaze')
|
|
parser.add_argument('--train_bbox_path', type=str, default='/scratch/bortoletto/data/boss/train/new_bbox/labels')
|
|
parser.add_argument('--val_frame_path', type=str, default='/scratch/bortoletto/data/boss/val/frame')
|
|
parser.add_argument('--val_pose_path', type=str, default='/scratch/bortoletto/data/boss/val/pose')
|
|
parser.add_argument('--val_gaze_path', type=str, default='/scratch/bortoletto/data/boss/val/gaze')
|
|
parser.add_argument('--val_bbox_path', type=str, default='/scratch/bortoletto/data/boss/val/new_bbox/labels')
|
|
parser.add_argument('--ocr_graph_path', type=str, default='')
|
|
parser.add_argument('--label_path', type=str, default='outfile')
|
|
parser.add_argument('--save_path', type=str, default='experiments/')
|
|
parser.add_argument('--resume_from_checkpoint', type=str, default=None)
|
|
|
|
# Parse the command-line arguments
|
|
args = parser.parse_args()
|
|
|
|
if args.use_mixup and not args.mixup_alpha:
|
|
parser.error("--use_mixup requires --mixup_alpha")
|
|
if args.model_type == 'tom_cm' or args.model_type == 'tom_impl':
|
|
if not args.aggr:
|
|
parser.error("The choosen --model_type requires --aggr")
|
|
if args.model_type == 'tom_sl' and not args.tom_weight:
|
|
parser.error("The choosen --model_type requires --tom_weight")
|
|
|
|
# get experiment ID
|
|
experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
|
|
if not os.path.exists(args.save_path):
|
|
os.makedirs(args.save_path, exist_ok=True)
|
|
experiment_save_path = os.path.join(args.save_path, experiment_id)
|
|
os.makedirs(experiment_save_path, exist_ok=True)
|
|
|
|
CFG = {
|
|
'use_ocr_custom_loss': 0,
|
|
'presaved': args.presaved,
|
|
'batch_size': args.batch_size,
|
|
'num_epoch': args.num_epoch,
|
|
'lr': args.lr,
|
|
'scheduler': args.scheduler,
|
|
'weight_decay': args.weight_decay,
|
|
'model_type': args.model_type,
|
|
'use_resnet': args.use_resnet,
|
|
'hidden_dim': args.hidden_dim,
|
|
'tom_weight': args.tom_weight,
|
|
'dropout': args.dropout,
|
|
'label_smoothing': args.label_smoothing,
|
|
'clip_grad_norm': args.clip_grad_norm,
|
|
'use_mixup': args.use_mixup,
|
|
'mixup_alpha': args.mixup_alpha,
|
|
'non_blocking_tensors': args.non_blocking,
|
|
'patience': args.patience,
|
|
'pin_memory': args.pin_memory,
|
|
'resume_from_checkpoint': args.resume_from_checkpoint,
|
|
'aggr': args.aggr,
|
|
'mods': args.mods,
|
|
'save_path': experiment_save_path ,
|
|
'seed': args.seed
|
|
}
|
|
|
|
print(CFG)
|
|
print(f'Saving results in {experiment_save_path}')
|
|
|
|
# set seed values
|
|
if args.logger:
|
|
wandb.init(project="boss", config=CFG)
|
|
os.environ['PYTHONHASHSEED'] = str(args.seed)
|
|
torch.manual_seed(args.seed)
|
|
np.random.seed(args.seed)
|
|
random.seed(args.seed)
|
|
|
|
train(args)
|