mtomnet/tbd/train.py

474 lines
22 KiB
Python
Raw Normal View History

2025-01-10 15:39:20 +01:00
import torch
import os
import argparse
import numpy as np
import random
import datetime
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from tbd_dataloader import TBDDataset, collate_fn
from models.common_mind import CommonMindToMnet
from models.sl import SLToMnet
from models.implicit import ImplicitToMnet
from utils.helpers import count_parameters, compute_f1_scores
# Names of the five prediction heads, in the order the models return them and
# in which their targets are stored in labels[:, 0..4].
MIND_HEADS = ('m1', 'm2', 'm12', 'm21', 'mc')


def move_batch_to_device(batch, device, non_blocking=False):
    """Move the model inputs of one collated batch to ``device``.

    ``batch`` is the 11-tuple yielded by the dataloader:
    (img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id,
    gaze, labels, exp_id, timestep). Inputs that are ``None`` (disabled
    modalities) are passed through unchanged; ``tracker_id`` and ``labels``
    are returned as-is, and ``exp_id``/``timestep`` are dropped.

    Returns:
        ``(inputs, labels)`` where ``inputs`` is the 8-tuple of positional
        arguments for the model's forward call.
    """
    (img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox,
     tracker_id, gaze, labels, _exp_id, _timestep) = batch

    def to_device(tensor):
        return None if tensor is None else tensor.to(device, non_blocking=non_blocking)

    inputs = (
        to_device(img_3rd_pov),
        to_device(img_tracker),
        to_device(img_battery),
        to_device(pose1),
        to_device(pose2),
        to_device(bbox),
        tracker_id,
        to_device(gaze),
    )
    return inputs, labels


def backprop_step(loss, model, optimizer, clip_grad_norm=None):
    """Zero gradients, backpropagate ``loss``, optionally clip, then step.

    Clipping has to happen after ``loss.backward()`` has populated ``.grad``;
    clipping right after ``zero_grad()`` would act on empty gradients and
    have no effect.
    """
    optimizer.zero_grad()
    loss.backward()
    if clip_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
    optimizer.step()


def run_epoch(model, dataloader, criteria, device, args, epoch, optimizer=None):
    """Run one pass over ``dataloader``; train when ``optimizer`` is given.

    Args:
        model: network whose forward returns six values; the first five are
            the per-head outputs, each reshaped here to ``(-1, 4)``.
        dataloader: yields batches in the 11-tuple collated format.
        criteria: dict mapping every name in ``MIND_HEADS`` to a loss module.
        device: device that inputs and labels are moved to.
        args: CLI namespace (reads ``num_epoch``, ``non_blocking``,
            ``clip_grad_norm`` and ``logger``).
        epoch: zero-based epoch index, only used in progress messages.
        optimizer: when given, the model is put in train mode and updated
            after every batch (per-batch wandb keys ``batch_train_loss`` and
            ``lr``); when ``None`` the model is evaluated in eval mode under
            ``torch.no_grad()`` (per-batch wandb key ``batch_val_loss``).

    Returns:
        ``(mean_losses, f1_scores)``: dicts keyed by ``MIND_HEADS`` with the
        per-head loss averaged over batches and the per-head score returned
        by ``compute_f1_scores`` over the whole epoch.
    """
    training = optimizer is not None
    model.train(training)
    phase = 'training' if training else 'validation'
    loss_sums = dict.fromkeys(MIND_HEADS, 0.0)
    preds = {head: [] for head in MIND_HEADS}
    targets = {head: [] for head in MIND_HEADS}

    grad_context = torch.enable_grad() if training else torch.no_grad()
    with grad_context:
        for j, batch in tqdm(enumerate(dataloader)):
            inputs, labels = move_batch_to_device(batch, device, args.non_blocking)
            m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, _ = model(*inputs)
            head_outputs = (m1_pred, m2_pred, m12_pred, m21_pred, mc_pred)

            head_losses = []
            for k, (head, output) in enumerate(zip(MIND_HEADS, head_outputs)):
                pred = output.reshape(-1, 4)
                target = labels[:, k].reshape(-1).to(device)
                head_loss = criteria[head](pred, target)
                head_losses.append(head_loss)
                loss_sums[head] += head_loss.item()
                # Detach so the stored predictions do not keep the autograd
                # history of every batch alive until the end of the epoch.
                preds[head].append(pred.detach())
                targets[head].append(target)
            loss = sum(head_losses)

            if training:
                backprop_step(loss, model, optimizer, args.clip_grad_norm)
                if args.logger:
                    wandb.log({
                        'batch_train_loss': loss.item(),
                        'lr': optimizer.param_groups[-1]['lr'],
                    })
            elif args.logger:
                wandb.log({'batch_val_loss': loss.item()})
            print("Epoch {}/{} batch {}/{} {} done with loss={}".format(
                epoch + 1, args.num_epoch, j + 1, len(dataloader), phase, loss.item())
            )

    # compute_f1_scores takes (pred, label) pairs in head order.
    f1_args = []
    for head in MIND_HEADS:
        f1_args += [torch.cat(preds[head]), torch.cat(targets[head])]
    f1_scores = dict(zip(MIND_HEADS, compute_f1_scores(*f1_args)))

    num_batches = len(dataloader)
    mean_losses = {head: loss_sums[head] / num_batches for head in MIND_HEADS}
    return mean_losses, f1_scores


def log_epoch_summary(split, phase, epoch, args, mean_losses, f1_scores, stats):
    """Print, store in ``stats[split]`` and optionally wandb-log epoch metrics.

    Args:
        split: metric prefix and ``stats`` key (``'train'`` or ``'val'``).
        phase: word used in the printed summary (``'train'``/``'validation'``).
        epoch: zero-based epoch index.
        args: CLI namespace (reads ``num_epoch`` and ``logger``).
        mean_losses, f1_scores: dicts keyed by ``MIND_HEADS``.
        stats: nested history dict; lists are appended in place.
    """
    print("Epoch {}/{} OVERALL {} m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}, mc_f1={}.\n".format(
        epoch + 1,
        args.num_epoch,
        phase,
        *(mean_losses[head] for head in MIND_HEADS),
        *(f1_scores[head] for head in MIND_HEADS),
    ))
    for head in MIND_HEADS:
        stats[split]['loss_{}'.format(head)].append(mean_losses[head])
        stats[split]['{}_f1'.format(head)].append(f1_scores[head])
    if args.logger:
        metrics = {'{}_{}_loss'.format(split, head): mean_losses[head] for head in MIND_HEADS}
        metrics['{}_loss'.format(split)] = sum(mean_losses[head] for head in MIND_HEADS)
        metrics.update({'{}_{}_f1_score'.format(split, head): f1_scores[head] for head in MIND_HEADS})
        wandb.log(metrics)


def main(args):
    """Train one ToM model on the TBD dataset with per-epoch validation.

    Builds the train/val loaders and the model selected by
    ``args.model_type``. For up to ``args.num_epoch`` epochs it trains on the
    train split and then evaluates on the val split. Whenever the sum of the
    five validation F1 scores is >= the best seen so far, the model's
    ``state_dict`` is saved to ``<experiment_save_path>/model``. Training
    stops early after ``args.patience`` consecutive epochs without such an
    improvement. At the end, the config, the split contents, the per-epoch
    stats and the best epoch are written to ``<experiment_save_path>/log.txt``.

    NOTE: ``experiment_save_path`` and ``CFG`` are module-level globals
    assigned in the ``__main__`` block. ``main`` therefore only works when
    this file is run as a script, or when a caller defines both globals first.

    Args:
        args: parsed command-line namespace built in the ``__main__`` block.
    """
    train_dataset = TBDDataset(path=args.data_path, mode="train", use_preprocessed_img=True)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        collate_fn=collate_fn,
    )
    val_dataset = TBDDataset(path=args.data_path, mode="val", use_preprocessed_img=True)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        collate_fn=collate_fn,
    )
    # Second field of every dataset entry, recorded so the run documents which
    # samples were in each split (presumably sequence ids - TODO confirm).
    train_data = [data[1] for data in train_dataset.data]
    val_data = [data[1] for data in val_dataset.data]
    if args.logger:
        wandb.config.update({"train_data": train_data})
        wandb.config.update({"val_data": val_data})

    if torch.cuda.is_available():
        # --gpu_id has no default, and "cuda:None" is not a valid device
        # string, so fall back to the current CUDA device when it is omitted.
        device = torch.device('cuda' if args.gpu_id is None else 'cuda:{}'.format(args.gpu_id))
    else:
        device = torch.device('cpu')

    # model
    if args.model_type == 'tom_cm':
        model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    elif args.model_type == 'tom_sl':
        model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods)
    elif args.model_type == 'tom_impl':
        model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    else:
        raise NotImplementedError('Unknown --model_type: {}'.format(args.model_type))
    model = model.to(device)
    if args.resume_from_checkpoint is not None:
        model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Any non-None --scheduler value enables cosine annealing.
    scheduler = None if args.scheduler is None else CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)

    # One loss per head. tom_sl uses NLLLoss (presumably because SLToMnet
    # already outputs log-probabilities - TODO confirm); the other models use
    # label-smoothed CrossEntropyLoss, which applies log-softmax itself.
    if args.model_type == 'tom_sl':
        criteria = {head: nn.NLLLoss() for head in MIND_HEADS}
    else:
        criteria = {head: nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) for head in MIND_HEADS}

    # Per-split history; key order (losses first, then F1s) is what ends up in log.txt.
    stats = {
        split: {
            **{'loss_{}'.format(head): [] for head in MIND_HEADS},
            **{'{}_f1'.format(head): [] for head in MIND_HEADS},
        }
        for split in ('train', 'val')
    }
    max_val_f1 = 0
    max_val_classification_epoch = None
    counter = 0  # consecutive epochs without a validation improvement
    print(f'Number of parameters: {count_parameters(model)}')
    for i in range(args.num_epoch):
        # training
        print('Training for epoch {}/{}...'.format(i+1, args.num_epoch))
        train_losses, train_f1 = run_epoch(model, train_dataloader, criteria, device, args, i, optimizer=optimizer)
        if scheduler is not None:
            scheduler.step()
        log_epoch_summary('train', 'train', i, args, train_losses, train_f1, stats)

        # validation
        print('Validation for epoch {}/{}...'.format(i+1, args.num_epoch))
        val_losses, val_f1 = run_epoch(model, val_dataloader, criteria, device, args, i)
        log_epoch_summary('val', 'validation', i, args, val_losses, val_f1, stats)

        # check for best stat/model using the summed validation F1 scores
        val_f1_total = sum(val_f1[head] for head in MIND_HEADS)
        if val_f1_total >= max_val_f1:
            max_val_f1 = val_f1_total
            max_val_classification_epoch = i + 1
            torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
            counter = 0
        else:
            counter += 1
            print(f'EarlyStopping counter: {counter} out of {args.patience}.')
            if counter >= args.patience:
                break

    with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
        f.write('{}\n'.format(CFG))
        f.write('{}\n'.format(train_data))
        f.write('{}\n'.format(val_data))
        f.write('{}\n'.format(stats))
        f.write('Max val classification acc: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_f1))
    print(f'Results saved in {experiment_save_path}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Define the command-line arguments
    parser.add_argument('--gpu_id', type=int)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--logger', action='store_true')
    parser.add_argument('--presaved', type=int, default=128)
    parser.add_argument('--clip_grad_norm', type=float, default=0.5)
    parser.add_argument('--use_mixup', action='store_true')
    parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--patience', type=int, default=99)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--pin_memory', action='store_true')
    parser.add_argument('--num_epoch', type=int, default=300)
    parser.add_argument('--lr', type=float, default=4e-4)
    parser.add_argument('--scheduler', type=str, default=None)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--weight_decay', type=float, default=0.005)
    parser.add_argument('--label_smoothing', type=float, default=0.1)
    # main() only knows these three models; validate here so a typo fails
    # immediately instead of after the datasets have been loaded.
    parser.add_argument('--model_type', type=str, required=True, choices=['tom_cm', 'tom_sl', 'tom_impl'])
    parser.add_argument('--aggr', type=str, default='concat', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
    parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
    parser.add_argument('--save_path', type=str, default='experiments/')
    parser.add_argument('--resume_from_checkpoint', type=str, default=None)
    # Parse the command-line arguments
    args = parser.parse_args()
    if args.use_mixup and not args.mixup_alpha:
        parser.error("--use_mixup requires --mixup_alpha")
    if args.model_type in ('tom_cm', 'tom_impl') and not args.aggr:
        parser.error("The choosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The choosen --model_type requires --tom_weight")

    # get experiment ID; main() reads experiment_save_path and CFG as globals
    experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
    experiment_save_path = os.path.join(args.save_path, experiment_id)
    # makedirs also creates args.save_path itself when it does not exist yet.
    os.makedirs(experiment_save_path, exist_ok=True)
    CFG = {
        'use_ocr_custom_loss': 0,
        'presaved': args.presaved,
        'batch_size': args.batch_size,
        'num_epoch': args.num_epoch,
        'lr': args.lr,
        'scheduler': args.scheduler,
        'weight_decay': args.weight_decay,
        'model_type': args.model_type,
        'use_resnet': args.use_resnet,
        'hidden_dim': args.hidden_dim,
        'tom_weight': args.tom_weight,
        'dropout': args.dropout,
        'label_smoothing': args.label_smoothing,
        'clip_grad_norm': args.clip_grad_norm,
        'use_mixup': args.use_mixup,
        'mixup_alpha': args.mixup_alpha,
        'non_blocking_tensors': args.non_blocking,
        'patience': args.patience,
        'pin_memory': args.pin_memory,
        'resume_from_checkpoint': args.resume_from_checkpoint,
        'aggr': args.aggr,
        'mods': args.mods,
        'save_path': experiment_save_path,
        'seed': args.seed,
    }
    print(CFG)
    print(f'Saving results in {experiment_save_path}')
    if args.logger:
        wandb.init(project="tbd", config=CFG)

    # set seed values
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
    # does not change hash randomization of this already-running process; it
    # only affects Python subprocesses started afterwards.
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    main(args)