# File-viewer metadata from the original paste: 474 lines, no EOL, 22 KiB, Python.
import torch
|
|
import os
|
|
import argparse
|
|
import numpy as np
|
|
import random
|
|
import datetime
|
|
import wandb
|
|
from tqdm import tqdm
|
|
from torch.utils.data import DataLoader
|
|
import torch.nn as nn
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
|
|
from tbd_dataloader import TBDDataset, collate_fn
|
|
from models.common_mind import CommonMindToMnet
|
|
from models.sl import SLToMnet
|
|
from models.implicit import ImplicitToMnet
|
|
from utils.helpers import count_parameters, compute_f1_scores
|
|
|
|
|
|
# Names of the five prediction heads. The model returns their outputs in this
# order, and labels[:, k] holds the target for _HEADS[k].
_HEADS = ('m1', 'm2', 'm12', 'm21', 'mc')


def _build_model(args, device):
    """Instantiate the network selected by ``args.model_type`` on ``device``.

    Raises:
        NotImplementedError: for any ``model_type`` other than ``'tom_cm'``,
            ``'tom_sl'`` or ``'tom_impl'``.
    """
    if args.model_type == 'tom_cm':
        model = CommonMindToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    elif args.model_type == 'tom_sl':
        model = SLToMnet(args.hidden_dim, device, args.tom_weight, args.use_resnet, args.dropout, args.mods)
    elif args.model_type == 'tom_impl':
        model = ImplicitToMnet(args.hidden_dim, device, args.use_resnet, args.dropout, args.aggr, args.mods)
    else:
        raise NotImplementedError
    return model.to(device)


def _build_criteria(args):
    """Return a dict mapping each head name to its classification loss.

    ``tom_sl`` is trained with ``nn.NLLLoss`` (presumably because that model
    already emits log-probabilities -- TODO confirm); every other model uses
    ``nn.CrossEntropyLoss`` with ``args.label_smoothing``.
    """
    if args.model_type == 'tom_sl':
        return {head: nn.NLLLoss() for head in _HEADS}
    return {head: nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) for head in _HEADS}


def _unpack_batch(batch, device, non_blocking):
    """Split a batch produced by ``collate_fn`` into model inputs and labels.

    Every optional input tensor that is not ``None`` is moved to ``device``;
    ``tracker_id`` is passed through untouched and ``labels`` is returned as-is
    (it is moved to ``device`` per head in ``_run_epoch``).

    Returns:
        (inputs, labels): ``inputs`` is the positional-argument tuple for the
        model's forward call.
    """
    (img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox,
     tracker_id, gaze, labels, _exp_id, _timestep) = batch

    def move(tensor):
        return None if tensor is None else tensor.to(device, non_blocking=non_blocking)

    inputs = (move(img_3rd_pov), move(img_tracker), move(img_battery),
              move(pose1), move(pose2), move(bbox), tracker_id, move(gaze))
    return inputs, labels


def _run_epoch(model, loader, criteria, device, args, epoch, optimizer=None):
    """Run one pass over ``loader``.

    With an ``optimizer`` this is a training pass (one optimizer step per
    batch); without one it is an evaluation pass under no-grad. Per-batch losses
    are printed and, when ``args.logger`` is set, logged to wandb.

    Returns:
        (mean_losses, f1): dicts keyed by head name. ``mean_losses[h]`` is the
        sum of head ``h``'s batch losses divided by ``len(loader)``; ``f1`` holds
        the scores returned by ``compute_f1_scores`` on the predictions collected
        over the whole pass.
    """
    training = optimizer is not None
    phase = 'training' if training else 'validation'
    model.train(training)  # model.train(False) is equivalent to model.eval()

    loss_sums = dict.fromkeys(_HEADS, 0.0)
    preds = {head: [] for head in _HEADS}
    targets = {head: [] for head in _HEADS}

    with torch.set_grad_enabled(training):
        for j, batch in enumerate(tqdm(loader)):
            inputs, labels = _unpack_batch(batch, device, args.non_blocking)
            m1, m2, m12, m21, mc, _ = model(*inputs)

            loss = 0
            for k, (head, logits) in enumerate(zip(_HEADS, (m1, m2, m12, m21, mc))):
                pred = logits.reshape(-1, 4)  # one row of 4 class scores per sample
                target = labels[:, k].reshape(-1).to(device)
                head_loss = criteria[head](pred, target)
                loss = loss + head_loss
                loss_sums[head] += head_loss.item()
                # Detach so the stored predictions do not keep autograd history
                # referenced for the rest of the epoch.
                preds[head].append(pred.detach())
                targets[head].append(target)

            if training:
                optimizer.zero_grad()
                loss.backward()
                # Clip only after backward() has produced the gradients; clipping
                # before it (on freshly zeroed gradients) has no effect.
                if args.clip_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                optimizer.step()
                if args.logger:
                    wandb.log({
                        'batch_train_loss': loss.item(),
                        'lr': optimizer.param_groups[-1]['lr']
                    })
            elif args.logger:
                wandb.log({'batch_val_loss': loss.item()})

            print("Epoch {}/{} batch {}/{} {} done with loss={}".format(
                epoch + 1, args.num_epoch, j + 1, len(loader), phase, loss.item()))

    n_batches = len(loader)
    mean_losses = {head: loss_sums[head] / n_batches for head in _HEADS}
    # compute_f1_scores takes (pred, label) pairs for m1, m2, m12, m21, mc in order.
    scores = compute_f1_scores(*[
        tensor
        for head in _HEADS
        for tensor in (torch.cat(preds[head]), torch.cat(targets[head]))
    ])
    return mean_losses, dict(zip(_HEADS, scores))


def _record_epoch(stats, split, phase, mean_losses, f1, epoch, args):
    """Print one epoch's summary, append it to ``stats[split]`` and optionally log it.

    ``split`` is ``'train'`` or ``'val'``: it is both the ``stats`` key and the
    wandb key prefix. ``phase`` is the word used in the printed summary.
    """
    print("Epoch {}/{} OVERALL {} m1_loss={}, m2_loss={}, m12_loss={}, m21_loss={}, mc_loss={}, "
          "m1_f1={}, m2_f1={}, m12_f1={}, m21_f1={}, mc_f1={}.\n".format(
              epoch + 1, args.num_epoch, phase,
              *(mean_losses[head] for head in _HEADS),
              *(f1[head] for head in _HEADS)))

    for head in _HEADS:
        stats[split]['loss_' + head].append(mean_losses[head])
        stats[split][head + '_f1'].append(f1[head])

    if args.logger:
        record = {f'{split}_{head}_loss': mean_losses[head] for head in _HEADS}
        record[f'{split}_loss'] = sum(mean_losses[head] for head in _HEADS)
        record.update({f'{split}_{head}_f1_score': f1[head] for head in _HEADS})
        wandb.log(record)


def main(args):
    """Train and validate a ToM model on the TBD dataset with early stopping.

    Each epoch runs one training pass and one validation pass. Whenever the sum
    of the five validation F1 scores is >= the best sum seen so far, the model's
    ``state_dict`` is saved to ``<experiment_save_path>/model``; after
    ``args.patience`` consecutive epochs without that, training stops. Finally
    ``CFG``, the train/val data lists, the per-epoch ``stats`` and the best epoch
    are written to ``<experiment_save_path>/log.txt``.

    NOTE: ``experiment_save_path`` and ``CFG`` are module-level globals assigned
    in the ``if __name__ == '__main__':`` block, so main() works only when the
    file is run as a script (or when a caller sets those globals first).
    """
    train_dataset = TBDDataset(path=args.data_path, mode="train", use_preprocessed_img=True)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        collate_fn=collate_fn
    )
    val_dataset = TBDDataset(path=args.data_path, mode="val", use_preprocessed_img=True)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        collate_fn=collate_fn
    )

    # Element 1 of every dataset entry is recorded in the wandb config and log file.
    train_data = [data[1] for data in train_dataset.data]
    val_data = [data[1] for data in val_dataset.data]
    if args.logger:
        wandb.config.update({"train_data": train_data})
        wandb.config.update({"val_data": val_data})

    if torch.cuda.is_available():
        # --gpu_id is optional; fall back to the default CUDA device instead of
        # building the invalid device string "cuda:None".
        device = torch.device('cuda' if args.gpu_id is None else 'cuda:{}'.format(args.gpu_id))
    else:
        device = torch.device('cpu')

    model = _build_model(args, device)
    if args.resume_from_checkpoint is not None:
        model.load_state_dict(torch.load(args.resume_from_checkpoint, map_location=device))

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Any non-None --scheduler value enables cosine annealing.
    scheduler = None if args.scheduler is None else CosineAnnealingLR(optimizer, T_max=100, eta_min=3e-5)
    criteria = _build_criteria(args)

    stat_keys = ['loss_' + head for head in _HEADS] + [head + '_f1' for head in _HEADS]
    stats = {split: {key: [] for key in stat_keys} for split in ('train', 'val')}
    max_val_f1 = 0
    max_val_classification_epoch = None
    counter = 0  # consecutive epochs without a new best validation F1 sum

    print(f'Number of parameters: {count_parameters(model)}')

    for i in range(args.num_epoch):
        print('Training for epoch {}/{}...'.format(i + 1, args.num_epoch))
        train_losses, train_f1 = _run_epoch(model, train_dataloader, criteria, device, args, i, optimizer)
        if scheduler is not None:
            scheduler.step()
        _record_epoch(stats, 'train', 'train', train_losses, train_f1, i, args)

        print('Validation for epoch {}/{}...'.format(i + 1, args.num_epoch))
        val_losses, val_f1 = _run_epoch(model, val_dataloader, criteria, device, args, i)
        _record_epoch(stats, 'val', 'validation', val_losses, val_f1, i, args)

        # check for best stat/model using the sum of the validation F1 scores
        val_f1_sum = sum(stats['val'][head + '_f1'][-1] for head in _HEADS)
        if val_f1_sum >= max_val_f1:
            max_val_f1 = val_f1_sum
            max_val_classification_epoch = i + 1
            torch.save(model.state_dict(), os.path.join(experiment_save_path, 'model'))
            counter = 0
        else:
            counter += 1
            print(f'EarlyStopping counter: {counter} out of {args.patience}.')
            if counter >= args.patience:
                break

    # The "acc" wording below is historical; the value is the best validation F1 sum.
    with open(os.path.join(experiment_save_path, 'log.txt'), 'w') as f:
        f.write('{}\n'.format(CFG))
        f.write('{}\n'.format(train_data))
        f.write('{}\n'.format(val_data))
        f.write('{}\n'.format(stats))
        f.write('Max val classification acc: epoch {}, {}\n'.format(max_val_classification_epoch, max_val_f1))

    print(f'Results saved in {experiment_save_path}')
if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # Define the command-line arguments
    parser.add_argument('--gpu_id', type=int)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--logger', action='store_true')
    # --presaved, --use_mixup and --mixup_alpha are only recorded in CFG;
    # main() does not read them.
    parser.add_argument('--presaved', type=int, default=128)
    parser.add_argument('--clip_grad_norm', type=float, default=0.5)
    parser.add_argument('--use_mixup', action='store_true')
    parser.add_argument('--mixup_alpha', type=float, default=0.3, required=False)
    parser.add_argument('--non_blocking', action='store_true')
    parser.add_argument('--patience', type=int, default=99)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--pin_memory', action='store_true')
    parser.add_argument('--num_epoch', type=int, default=300)
    parser.add_argument('--lr', type=float, default=4e-4)
    parser.add_argument('--scheduler', type=str, default=None)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--weight_decay', type=float, default=0.005)
    parser.add_argument('--label_smoothing', type=float, default=0.1)
    # Validate up front: main() only knows these three model types and would
    # otherwise fail with NotImplementedError after loading both datasets.
    parser.add_argument('--model_type', type=str, required=True, choices=['tom_cm', 'tom_sl', 'tom_impl'])
    parser.add_argument('--aggr', type=str, default='concat', required=False)
    parser.add_argument('--use_resnet', action='store_true')
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--tom_weight', type=float, default=2.0, required=False)
    parser.add_argument('--mods', nargs='+', type=str, default=['rgb_3', 'rgb_1', 'pose', 'gaze', 'bbox'])
    parser.add_argument('--data_path', type=str, default='/scratch/bortoletto/data/tbd')
    parser.add_argument('--save_path', type=str, default='experiments/')
    parser.add_argument('--resume_from_checkpoint', type=str, default=None)

    # Parse the command-line arguments
    args = parser.parse_args()

    if args.use_mixup and not args.mixup_alpha:
        parser.error("--use_mixup requires --mixup_alpha")
    if args.model_type in ('tom_cm', 'tom_impl') and not args.aggr:
        parser.error("The chosen --model_type requires --aggr")
    if args.model_type == 'tom_sl' and not args.tom_weight:
        parser.error("The chosen --model_type requires --tom_weight")

    # get experiment ID; makedirs also creates args.save_path if it is missing
    experiment_id = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_train'
    experiment_save_path = os.path.join(args.save_path, experiment_id)
    os.makedirs(experiment_save_path, exist_ok=True)

    # CFG and experiment_save_path are module-level globals read by main().
    CFG = {
        'use_ocr_custom_loss': 0,
        'presaved': args.presaved,
        'batch_size': args.batch_size,
        'num_epoch': args.num_epoch,
        'lr': args.lr,
        'scheduler': args.scheduler,
        'weight_decay': args.weight_decay,
        'model_type': args.model_type,
        'use_resnet': args.use_resnet,
        'hidden_dim': args.hidden_dim,
        'tom_weight': args.tom_weight,
        'dropout': args.dropout,
        'label_smoothing': args.label_smoothing,
        'clip_grad_norm': args.clip_grad_norm,
        'use_mixup': args.use_mixup,
        'mixup_alpha': args.mixup_alpha,
        'non_blocking_tensors': args.non_blocking,
        'patience': args.patience,
        'pin_memory': args.pin_memory,
        'resume_from_checkpoint': args.resume_from_checkpoint,
        'aggr': args.aggr,
        'mods': args.mods,
        'save_path': experiment_save_path,
        'seed': args.seed
    }

    print(CFG)
    print(f'Saving results in {experiment_save_path}')

    if args.logger:
        wandb.init(project="tbd", config=CFG)

    # set seed values
    # NOTE: setting PYTHONHASHSEED at runtime does not change hash randomization
    # of this already-running interpreter; it only affects subprocesses started
    # afterwards.
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    main(args)