# mtomnet/boss/train.py
# (Captured from a repository web view: 325 lines, 16 KiB, Python;
#  last modified 2025-01-10 15:39:20 +01:00.)
import torch
import os
import argparse
import numpy as np
import random
import datetime
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from dataloader import Data
from models.resnet import ResNet, ResNetGRU, ResNetLSTM, ResNetConv1D
from models.tom_implicit import ImplicitToMnet
from models.tom_common_mind import CommonMindToMnet, CommonMindToMnetXL
from models.tom_sl import SLToMnet
from models.tom_tf import TFToMnet
from models.single_mindnet import SingleMindNet
from utils import pad_collate, get_classification_accuracy, mixup, get_classification_accuracy_mixup, count_parameters, get_input_dim
# Number of columns each prediction head is reshaped to before computing the loss.
_NUM_CLASSES = 27

# Model types whose ``Data`` samples are built with ``flatten_dim=2``; all others use 1.
_FLATTEN_DIM_2_MODELS = ('tom_cm', 'tom_sl', 'tom_impl', 'tom_tf', 'tom_cm_xl', 'tom_single')


def _make_dataloader(args, frame_path, pose_path, gaze_path, bbox_path, flatten_dim, shuffle):
    """Build the ``Data`` dataset for one split and wrap it in a ``DataLoader``.

    The label path, OCR-graph path and loader settings are shared by every split
    and are read from ``args``.
    """
    dataset = Data(
        frame_path,
        args.label_path,
        pose_path,
        gaze_path,
        bbox_path,
        args.ocr_graph_path,
        presaved=args.presaved,
        flatten_dim=flatten_dim,
    )
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        collate_fn=pad_collate,
        pin_memory=args.pin_memory,
    )


def _build_model(args, device):
    """Instantiate the network named by ``args.model_type`` and move it to ``device``.

    Raises:
        NotImplementedError: if ``args.model_type`` is not a known model type.
    """
    baselines = {'resnet': ResNet, 'gru': ResNetGRU, 'lstm': ResNetLSTM, 'conv1d': ResNetConv1D}
    aggr_models = {'tom_cm': CommonMindToMnet, 'tom_cm_xl': CommonMindToMnetXL, 'tom_impl': ImplicitToMnet}
    plain_models = {'tom_single': SingleMindNet, 'tom_tf': TFToMnet}
    model_type = args.model_type
    if model_type in baselines:
        model = baselines[model_type](get_input_dim(args.mods), device)
    elif model_type in aggr_models:
        model = aggr_models[model_type](
            args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods
        )
    elif model_type == 'tom_sl':
        model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods)
    elif model_type in plain_models:
        model = plain_models[model_type](args.hidden_dim, device, args.use_resnet, args.dropout, args.mods)
    else:
        raise NotImplementedError(f'Unknown model_type: {model_type!r}')
    return model.to(device)


def _predict(model, frames, poses, gazes, bboxes, ocr_graphs, device, non_blocking):
    """Move the inputs to ``device``, run ``model`` and flatten both heads to ``(-1, _NUM_CLASSES)``.

    ``None`` inputs are passed to the model unchanged.
    """
    inputs = [
        x.to(device, non_blocking=non_blocking) if x is not None else None
        for x in (frames, poses, gazes, bboxes, ocr_graphs)
    ]
    pred_left, pred_right, _ = model(*inputs)
    return (
        torch.reshape(pred_left, (-1, _NUM_CLASSES)),
        torch.reshape(pred_right, (-1, _NUM_CLASSES)),
    )


def _run_train_epoch(model, dataloader, optimizer, loss_fn, args, device, epoch):
    """Run one optimisation pass of ``model`` over ``dataloader``.

    Args:
        epoch: zero-based epoch index, used only in progress messages.

    Returns:
        ``(loss, accuracy)`` for the epoch. Batch losses are weighted by the
        ``batch_num_pred`` reported by the accuracy helper; accuracy is total
        correct over total predictions.
    """
    loss_sum = 0.0
    num_correct = 0
    num_pred = 0
    model.train()
    for j, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        if args.use_mixup:
            batch = mixup(batch, args.mixup_alpha, _NUM_CLASSES)
        frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
        pred_left, pred_right = _predict(
            model, frames, poses, gazes, bboxes, ocr_graphs, device, args.non_blocking
        )
        if args.use_mixup:
            # Mixed-up labels carry 2 * _NUM_CLASSES columns: left-head targets
            # first, then right-head targets.
            labels = torch.reshape(labels, (-1, 2 * _NUM_CLASSES)).to(device)
            batch_acc, batch_correct, batch_pred = get_classification_accuracy_mixup(
                pred_left, pred_right, labels, sequence_lengths
            )
            loss = loss_fn(pred_left, labels[:, :_NUM_CLASSES]) + loss_fn(pred_right, labels[:, _NUM_CLASSES:])
        else:
            # Column 0 holds the left-head target, column 1 the right-head target.
            labels = torch.reshape(labels, (-1, 2)).to(device)
            batch_acc, batch_correct, batch_pred = get_classification_accuracy(
                pred_left, pred_right, labels, sequence_lengths
            )
            loss = loss_fn(pred_left, labels[:, 0]) + loss_fn(pred_right, labels[:, 1])
        num_pred += batch_pred
        num_correct += batch_correct
        loss_sum += loss.item() * batch_pred / 2

        optimizer.zero_grad()
        loss.backward()
        # Clip only after backward(): before it the gradients were just zeroed,
        # so clipping there had no effect on the update.
        if args.clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
        optimizer.step()

        if args.logger:
            wandb.log({
                'batch_train_acc': batch_acc,
                'batch_train_loss': loss.item(),
                'lr': optimizer.param_groups[-1]['lr'],
            })
        print("Epoch {}/{} batch {}/{} training done with cls loss={}, cls accuracy={}.".format(
            epoch + 1, args.num_epoch, j + 1, len(dataloader), loss.item(), batch_acc)
        )
    return loss_sum * 2 / num_pred, num_correct / num_pred


def _run_val_epoch(model, dataloader, loss_fn, args, device, epoch):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Mixup is never applied here; labels are always the two-column hard labels.

    Returns:
        ``(loss, accuracy)`` aggregated exactly as in ``_run_train_epoch``.
    """
    loss_sum = 0.0
    num_correct = 0
    num_pred = 0
    model.eval()
    with torch.no_grad():
        for j, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
            frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = batch
            pred_left, pred_right = _predict(
                model, frames, poses, gazes, bboxes, ocr_graphs, device, args.non_blocking
            )
            labels = torch.reshape(labels, (-1, 2)).to(device)
            batch_acc, batch_correct, batch_pred = get_classification_accuracy(
                pred_left, pred_right, labels, sequence_lengths
            )
            loss = loss_fn(pred_left, labels[:, 0]) + loss_fn(pred_right, labels[:, 1])
            num_pred += batch_pred
            num_correct += batch_correct
            loss_sum += loss.item() * batch_pred / 2
            if args.logger:
                wandb.log({'batch_val_acc': batch_acc, 'batch_val_loss': loss.item()})
            print("Epoch {}/{} batch {}/{} validation done with cls loss={}, cls accuracy={}.".format(
                epoch + 1, args.num_epoch, j + 1, len(dataloader), loss.item(), batch_acc)
            )
    return loss_sum * 2 / num_pred, num_correct / num_pred


def train(args, save_dir=None, cfg=None):
    """Train the model selected by ``args``, early-stopping on validation accuracy.

    The best checkpoint (by validation accuracy) is saved to ``<save_dir>/model``,
    and the config, per-epoch stats and best epoch are written to
    ``<save_dir>/log.txt`` when training ends.

    Args:
        args: parsed command-line namespace (see the ``__main__`` argument parser).
        save_dir: output directory. Defaults to the module-level
            ``experiment_save_path`` created by the ``__main__`` block.
        cfg: configuration written at the top of ``log.txt``. Defaults to the
            module-level ``CFG`` created by the ``__main__`` block.

    Raises:
        NotImplementedError: if ``args.model_type`` is unknown.
    """
    # Fall back to the globals set up by ``__main__`` so the original
    # ``train(args)`` call keeps working; resolve them up front to fail fast.
    if save_dir is None:
        save_dir = experiment_save_path
    if cfg is None:
        cfg = CFG

    flatten_dim = 2 if args.model_type in _FLATTEN_DIM_2_MODELS else 1
    train_dataloader = _make_dataloader(
        args, args.train_frame_path, args.train_pose_path, args.train_gaze_path,
        args.train_bbox_path, flatten_dim, shuffle=True,
    )
    val_dataloader = _make_dataloader(
        args, args.val_frame_path, args.val_pose_path, args.val_gaze_path,
        args.val_bbox_path, flatten_dim, shuffle=False,
    )

    if torch.cuda.is_available():
        # "cuda:None" is not a valid device string; without an explicit id use the current CUDA device.
        device = torch.device('cuda' if args.gpu_id is None else 'cuda:{}'.format(args.gpu_id))
    else:
        device = torch.device('cpu')

    model = _build_model(args, device)
    if args.resume_from_checkpoint is not None:
        # NOTE(review): torch.load unpickles the file -- only resume from trusted checkpoints.
        model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Any non-None --scheduler value enables cosine annealing with fixed T_max / eta_min.
    scheduler = None if args.scheduler is None else CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)
    if args.model_type == 'tom_sl':
        # NLLLoss implies SLToMnet emits log-probabilities -- TODO confirm against the model.
        cross_entropy_loss = nn.NLLLoss(ignore_index=-1)
    else:
        cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing, ignore_index=-1).to(device)

    stats = {'train': {'cls_loss': [], 'cls_acc': []}, 'val': {'cls_loss': [], 'cls_acc': []}}
    max_val_classification_acc = 0
    max_val_classification_epoch = None
    counter = 0  # epochs since the last validation-accuracy improvement
    print(f'Number of parameters: {count_parameters(model)}')

    for i in range(args.num_epoch):
        # training
        print('Training for epoch {}/{}...'.format(i + 1, args.num_epoch))
        train_loss, train_acc = _run_train_epoch(
            model, train_dataloader, optimizer, cross_entropy_loss, args, device, i
        )
        # The learning-rate schedule advances once per epoch, after the training pass.
        if scheduler is not None:
            scheduler.step()
        print("Epoch {}/{} OVERALL train cls loss={}, cls accuracy={}.\n".format(
            i + 1, args.num_epoch, train_loss, train_acc)
        )
        stats['train']['cls_loss'].append(train_loss)
        stats['train']['cls_acc'].append(train_acc)

        # validation
        print('Validation for epoch {}/{}...'.format(i + 1, args.num_epoch))
        cls_loss, cls_acc = _run_val_epoch(model, val_dataloader, cross_entropy_loss, args, device, i)
        print("Epoch {}/{} OVERALL validation cls loss={}, cls accuracy={}.\n".format(
            i + 1, args.num_epoch, cls_loss, cls_acc)
        )
        stats['val']['cls_loss'].append(cls_loss)
        stats['val']['cls_acc'].append(cls_acc)
        if args.logger:
            wandb.log({'cls_loss': cls_loss, 'cls_acc': cls_acc, 'epoch': i})

        # Keep the checkpoint with the best validation accuracy (ties favour the later epoch).
        if cls_acc >= max_val_classification_acc:
            max_val_classification_acc = cls_acc
            max_val_classification_epoch = i + 1
            torch.save(model.state_dict(), os.path.join(save_dir, 'model'))
            counter = 0
        else:
            counter += 1
            print(f'EarlyStopping counter: {counter} out of {args.patience}.')
            if counter >= args.patience:
                break

    with open(os.path.join(save_dir, 'log.txt'), 'w') as f:
        f.write('{}\n'.format(cfg))
        f.write('{}\n'.format(stats))
        f.write('Max val classification acc: epoch {}, {}\n'.format(
            max_val_classification_epoch, max_val_classification_acc))
    print(f'Results saved in {save_dir}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Define the command-line arguments
    # Default to GPU 0: without a default, train() would build the invalid device "cuda:None".
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--logger', action='store_true')
    parser.add_argument('--presaved', type=int, default=128)
    parser.add_argument('--clip_grad_norm', type=float, default=0.5)
    parser.add_argument('--use_mixup', action='store_true')
    parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--patience', type=int, default=99)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--pin_memory', action='store_true')
    parser.add_argument('--num_epoch', type=int, default=300)
    parser.add_argument('--lr', type=float, default=4e-4)
    # Any value other than the default None enables cosine annealing in train().
    parser.add_argument('--scheduler', type=str, default=None)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--weight_decay', type=float, default=0.005)
    parser.add_argument('--label_smoothing', type=float, default=0.1)
    # Validate up front: an unknown or missing type previously failed only after
    # the datasets and the experiment directory had already been created.
    parser.add_argument(
        '--model_type', type=str, required=True,
        choices=['resnet', 'gru', 'lstm', 'conv1d', 'tom_cm', 'tom_cm_xl',
                 'tom_impl', 'tom_sl', 'tom_single', 'tom_tf'],
    )
    parser.add_argument('--aggr', type=str, default='mult', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str, default=['rgb', 'pose', 'gaze', 'ocr', 'bbox'])
    parser.add_argument('--train_frame_path', type=str, default='/scratch/bortoletto/data/boss/train/frame')
    parser.add_argument('--train_pose_path', type=str, default='/scratch/bortoletto/data/boss/train/pose')
    parser.add_argument('--train_gaze_path', type=str, default='/scratch/bortoletto/data/boss/train/gaze')
    parser.add_argument('--train_bbox_path', type=str, default='/scratch/bortoletto/data/boss/train/new_bbox/labels')
    parser.add_argument('--val_frame_path', type=str, default='/scratch/bortoletto/data/boss/val/frame')
    parser.add_argument('--val_pose_path', type=str, default='/scratch/bortoletto/data/boss/val/pose')
    parser.add_argument('--val_gaze_path', type=str, default='/scratch/bortoletto/data/boss/val/gaze')
    parser.add_argument('--val_bbox_path', type=str, default='/scratch/bortoletto/data/boss/val/new_bbox/labels')
    parser.add_argument('--ocr_graph_path', type=str, default='')
    parser.add_argument('--label_path', type=str, default='outfile')
    parser.add_argument('--save_path', type=str, default='experiments/')
    parser.add_argument('--resume_from_checkpoint', type=str, default=None)
    # Parse the command-line arguments
    args = parser.parse_args()

    # Cross-argument checks that argparse cannot express on its own.
    if args.use_mixup and not args.mixup_alpha:
        parser.error("--use_mixup requires --mixup_alpha")
    # Every model constructed with args.aggr needs a non-empty value (tom_cm_xl included).
    if args.model_type in ('tom_cm', 'tom_cm_xl', 'tom_impl') and not args.aggr:
        parser.error("The choosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The choosen --model_type requires --tom_weight")

    # get experiment ID
    experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
    experiment_save_path = os.path.join(args.save_path, experiment_id)
    # makedirs also creates args.save_path itself if it does not exist yet.
    os.makedirs(experiment_save_path, exist_ok=True)

    # Module-level globals: train() reads experiment_save_path and CFG from here.
    CFG = {
        'use_ocr_custom_loss': 0,
        'presaved': args.presaved,
        'batch_size': args.batch_size,
        'num_epoch': args.num_epoch,
        'lr': args.lr,
        'scheduler': args.scheduler,
        'weight_decay': args.weight_decay,
        'model_type': args.model_type,
        'use_resnet': args.use_resnet,
        'hidden_dim': args.hidden_dim,
        'tom_weight': args.tom_weight,
        'dropout': args.dropout,
        'label_smoothing': args.label_smoothing,
        'clip_grad_norm': args.clip_grad_norm,
        'use_mixup': args.use_mixup,
        'mixup_alpha': args.mixup_alpha,
        'non_blocking_tensors': args.non_blocking,
        'patience': args.patience,
        'pin_memory': args.pin_memory,
        'resume_from_checkpoint': args.resume_from_checkpoint,
        'aggr': args.aggr,
        'mods': args.mods,
        'save_path': experiment_save_path,
        'seed': args.seed,
    }
    print(CFG)
    print(f'Saving results in {experiment_save_path}')

    if args.logger:
        wandb.init(project="boss", config=CFG)

    # set seed values
    # NOTE: PYTHONHASHSEED set at runtime does not change hash randomisation of
    # this interpreter; it only affects Python processes spawned afterwards.
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    train(args)