import glob
import os
import random
import time
from collections import OrderedDict
import pickle

import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
from torch.distributed import ReduceOp
import torch

import utils
from dataloader.data_load import PlanningDataset
from model import diffusion_act
from model import temporal_act
# from model import temporal
# from model import unet_atten
from model.helpers import get_lr_schedule_with_warmup

from utils.training_act import Trainer
from utils.eval import validate_act as validate
from logging import log  # note: shadowed by the module-level log() helper defined below
from utils.args import get_args
import numpy as np
from model.helpers import Logger


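# Normalization helpers for the pre-computed action embeddings: scale_norm
# rescales the whole matrix linearly into [-1, 1]; z_norm standardizes each
# dimension to zero mean / unit variance (constant dimensions get std 1 to
# avoid division by zero).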
def scale_norm(feature):
    ratio = 2 / (np.max(feature) - np.min(feature))
    shift = (np.max(feature) + np.min(feature)) / 2
    return (feature - shift) * ratio


def z_norm(feature):
    std = np.std(feature, 0)
    std[std == 0] = 1
    mean = np.mean(feature, 0)
    return (feature - np.full(feature.shape, mean)) / np.full(feature.shape, std)


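# Average a scalar metric across all distributed ranks: all-reduce the sum,
# then divide by the world size so every process logs the same value.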
def reduce_tensor(tensor):
    rt = tensor.clone()
    torch.distributed.all_reduce(rt, op=ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt


def main():
    # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    args = get_args()

    os.environ['PYTHONHASHSEED'] = str(args.seed)

    if args.verbose:
        print(args)
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
    ngpus_per_node = torch.cuda.device_count()

    if args.multiprocessing_distributed:
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        main_worker(args.gpu, ngpus_per_node, args)


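# main_worker runs once per process/GPU: it builds the dataloaders, the
# diffusion model and its EMA copy, restores or creates the logging and
# checkpoint state, and then drives the training epochs.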
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
            args.num_thread_reader = int(args.num_thread_reader / ngpus_per_node)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    # Data loading code
    train_dataset = PlanningDataset(
        args.root,
        args=args,
        is_val=False,
        model=None,
    )
    # Test data loading code
    test_dataset = PlanningDataset(
        args.root,
        args=args,
        is_val=True,
        model=None,
    )
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
    else:
        train_sampler = None
        test_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.num_thread_reader,
        pin_memory=args.pin_memory,
        sampler=train_sampler,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size_val,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_thread_reader,
        sampler=test_sampler,
    )

    # read action embeddings
    if args.dataset == 'crosstask' or args.dataset == 'NIV':
        with open(args.act_emb_path, 'rb') as f:
            act_emb = pickle.load(f)

        ordered_act = dict(sorted(act_emb.items()))
        feature = []
        for i in ordered_act.keys():
            feature.append(ordered_act[i])

        feature = np.array(feature)
        feature = scale_norm(feature)
        act_emb = torch.tensor(feature)

    if args.dataset == 'coin':
        with open(args.act_emb_path, 'rb') as f:
            act_emb = pickle.load(f)
        ordered_act = dict(sorted(act_emb['steps_to_embeddings'].items()))
        feature = []
        for i in ordered_act.keys():
            feature.append(ordered_act[i])

        feature = np.array(feature)
        feature = scale_norm(feature)
        act_emb = torch.tensor(feature)

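    # The temporal U-Net denoises sequences whose per-step feature is the
    # concatenation of action, observation and task-class dimensions (hence the
    # input width below); NIV uses a 256-wide backbone with dim_mults
    # (1, 2, 4, 8), the other datasets a 512-wide backbone with (1, 2, 4).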
    # create model
    if args.dataset == 'NIV':
        if args.attn == 'NoAttention':
            temporal_model = temporal_act.TemporalUnetNoAttn(args.action_dim + args.observation_dim + args.class_dim, args.action_dim, dim=256, dim_mults=(1, 2, 4, 8), )
        if args.attn == 'WithAttention':
            temporal_model = temporal_act.TemporalUnet(args.action_dim + args.observation_dim + args.class_dim, args.action_dim, dim=256, dim_mults=(1, 2, 4, 8), )
    else:
        if args.attn == 'NoAttention':
            temporal_model = temporal_act.TemporalUnetNoAttn(args.action_dim + args.observation_dim + args.class_dim, args.action_dim, dim=512, dim_mults=(1, 2, 4), )
        if args.attn == 'WithAttention':
            temporal_model = temporal_act.TemporalUnet(args.action_dim + args.observation_dim + args.class_dim, args.action_dim, dim=512, dim_mults=(1, 2, 4), )

    diffusion_model = diffusion_act.GaussianDiffusion(
        temporal_model, args.horizon, args.observation_dim, args.action_dim, args.class_dim, args.n_diffusion_steps,
        loss_type='Weighted_MSE', clip_denoised=True, )

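    # Trainer bundles the diffusion model with its EMA copy, the optimizer,
    # gradient accumulation and periodic logging; act_emb carries the
    # normalized action embeddings loaded above.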
    model = Trainer(diffusion_model, train_loader, args.ema_decay, args.lr, args.gradient_accumulate_every, args.step_start_ema, args.update_ema_every, args.log_freq, act_emb)

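    # Optionally warm-start both the online and EMA models from a pretrained
    # checkpoint, then place/wrap them for the selected parallel mode:
    # DistributedDataParallel, a single GPU, or DataParallel.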
    if args.pretrain_cnn_path:
        net_data = torch.load(args.pretrain_cnn_path)
        model.model.load_state_dict(net_data)
        model.ema_model.load_state_dict(net_data)
    if args.distributed:
        if args.gpu is not None:
            model.model.cuda(args.gpu)
            model.ema_model.cuda(args.gpu)
            model.model = torch.nn.parallel.DistributedDataParallel(
                model.model, device_ids=[args.gpu], find_unused_parameters=True)
            model.ema_model = torch.nn.parallel.DistributedDataParallel(
                model.ema_model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.model.cuda()
            model.ema_model.cuda()
            model.model = torch.nn.parallel.DistributedDataParallel(model.model, find_unused_parameters=True)
            model.ema_model = torch.nn.parallel.DistributedDataParallel(model.ema_model,
                                                                        find_unused_parameters=True)

    elif args.gpu is not None:
        model.model = model.model.cuda(args.gpu)
        model.ema_model = model.ema_model.cuda(args.gpu)
    else:
        model.model = torch.nn.DataParallel(model.model).cuda()
        model.ema_model = torch.nn.DataParallel(model.ema_model).cuda()

    scheduler = get_lr_schedule_with_warmup(args, model.optimizer, int(args.n_train_steps * args.epochs))

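    # Checkpoints are written under ./checkpoint/<args.checkpoint_dir>. When
    # args.resume is set, training restarts from the newest checkpoint found
    # there (reusing its TensorBoard log directory) or, if none exists yet,
    # creates a fresh timestamped log directory; only rank 0 writes logs.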
    checkpoint_dir = os.path.join(os.path.dirname(__file__), 'checkpoint', args.checkpoint_dir)
    if args.checkpoint_dir != '' and not (os.path.isdir(checkpoint_dir)) and args.rank == 0:
        os.mkdir(checkpoint_dir)

    if args.resume:
        checkpoint_path = get_last_checkpoint(checkpoint_dir)
        if checkpoint_path:
            log("=> loading checkpoint '{}'".format(checkpoint_path), args)
            checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(args.rank))
            args.start_epoch = checkpoint["epoch"]
            model.model.load_state_dict(checkpoint["model"])
            model.ema_model.load_state_dict(checkpoint["ema"])
            model.optimizer.load_state_dict(checkpoint["optimizer"])
            model.step = checkpoint["step"]
            scheduler.load_state_dict(checkpoint["scheduler"])
            tb_logdir = checkpoint["tb_logdir"]
            if args.rank == 0:
                # create logger
                tb_logger = Logger(tb_logdir)
            log("=> loaded checkpoint '{}' (epoch {}){}".format(checkpoint_path, checkpoint["epoch"], args.gpu), args)
        else:
            time_pre = time.strftime("%Y%m%d%H%M%S", time.localtime())
            logname = args.log_root + '_' + time_pre + '_' + args.dataset
            tb_logdir = os.path.join(args.log_root, logname)
            if args.rank == 0:
                # create logger
                if not (os.path.exists(tb_logdir)):
                    os.makedirs(tb_logdir)
                tb_logger = Logger(tb_logdir)
                tb_logger.log_info(args)
            log("=> no checkpoint found at '{}'".format(args.resume), args)

    if args.cudnn_benchmark:
        cudnn.benchmark = True
    total_batch_size = args.world_size * args.batch_size
    log(
        "Starting training loop for rank: {}, total batch size: {}".format(
            args.rank, total_batch_size
        ), args
    )

    max_eva = 0
    max_acc = 0
    old_max_epoch = 0
    save_max = os.path.join(os.path.dirname(__file__), 'save_max')

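    # Main loop: every 10th epoch the training pass also computes accuracy and
    # success-rate metrics; every 5th epoch (with args.evaluate) the EMA model
    # is validated and the best checkpoint by trajectory success rate (ties
    # broken by top-1 accuracy) is kept in save_max; regular checkpoints are
    # written every args.save_freq epochs.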
    for epoch in range(args.start_epoch, args.epochs):

        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        if (epoch + 1) % 10 == 0:  # calculate on training set
            losses, acc_top1, acc_top5, trajectory_success_rate_meter, MIoU1_meter, MIoU2_meter, \
                acc_a0, acc_aT = model.train(args.n_train_steps, True, args, scheduler)
            losses_reduced = reduce_tensor(losses.cuda()).item()
            acc_top1_reduced = reduce_tensor(acc_top1.cuda()).item()
            acc_top5_reduced = reduce_tensor(acc_top5.cuda()).item()
            trajectory_success_rate_meter_reduced = reduce_tensor(trajectory_success_rate_meter.cuda()).item()
            MIoU1_meter_reduced = reduce_tensor(MIoU1_meter.cuda()).item()
            MIoU2_meter_reduced = reduce_tensor(MIoU2_meter.cuda()).item()
            acc_a0_reduced = reduce_tensor(acc_a0.cuda()).item()
            acc_aT_reduced = reduce_tensor(acc_aT.cuda()).item()

            if args.rank == 0:
                logs = OrderedDict()
                logs['Train/EpochLoss'] = losses_reduced
                logs['Train/EpochAcc@1'] = acc_top1_reduced
                logs['Train/EpochAcc@5'] = acc_top5_reduced
                logs['Train/Traj_Success_Rate'] = trajectory_success_rate_meter_reduced
                logs['Train/MIoU1'] = MIoU1_meter_reduced
                logs['Train/MIoU2'] = MIoU2_meter_reduced
                logs['Train/acc_a0'] = acc_a0_reduced
                logs['Train/acc_aT'] = acc_aT_reduced
                for key, value in logs.items():
                    tb_logger.log_scalar(value, key, epoch + 1)

                tb_logger.flush()
        else:
            losses = model.train(args.n_train_steps, False, args, scheduler).cuda()
            losses_reduced = reduce_tensor(losses).item()
            if args.rank == 0:
                print('lrs:')
                for p in model.optimizer.param_groups:
                    print(p['lr'])
                print('---------------------------------')

                logs = OrderedDict()
                logs['Train/EpochLoss'] = losses_reduced
                for key, value in logs.items():
                    tb_logger.log_scalar(value, key, epoch + 1)

                tb_logger.flush()

        if ((epoch + 1) % 5 == 0) and args.evaluate:  # or epoch > 18
            losses, acc_top1, acc_top5, \
                trajectory_success_rate_meter, MIoU1_meter, MIoU2_meter, \
                acc_a0, acc_aT = validate(test_loader, model.ema_model, args, act_emb)

            losses_reduced = reduce_tensor(losses.cuda()).item()
            acc_top1_reduced = reduce_tensor(acc_top1.cuda()).item()
            acc_top5_reduced = reduce_tensor(acc_top5.cuda()).item()
            trajectory_success_rate_meter_reduced = reduce_tensor(trajectory_success_rate_meter.cuda()).item()
            MIoU1_meter_reduced = reduce_tensor(MIoU1_meter.cuda()).item()
            MIoU2_meter_reduced = reduce_tensor(MIoU2_meter.cuda()).item()
            acc_a0_reduced = reduce_tensor(acc_a0.cuda()).item()
            acc_aT_reduced = reduce_tensor(acc_aT.cuda()).item()

            if args.rank == 0:
                logs = OrderedDict()
                logs['Val/EpochLoss'] = losses_reduced
                logs['Val/EpochAcc@1'] = acc_top1_reduced
                logs['Val/EpochAcc@5'] = acc_top5_reduced
                logs['Val/Traj_Success_Rate'] = trajectory_success_rate_meter_reduced
                logs['Val/MIoU1'] = MIoU1_meter_reduced
                logs['Val/MIoU2'] = MIoU2_meter_reduced
                logs['Val/acc_a0'] = acc_a0_reduced
                logs['Val/acc_aT'] = acc_aT_reduced
                logs['lr'] = model.optimizer.param_groups[0]['lr']
                for key, value in logs.items():
                    tb_logger.log_scalar(value, key, epoch + 1)

                tb_logger.flush()
            print(trajectory_success_rate_meter_reduced, max_eva)
            if trajectory_success_rate_meter_reduced >= max_eva:
                if not (trajectory_success_rate_meter_reduced == max_eva and acc_top1_reduced < max_acc):
                    save_checkpoint2(
                        {
                            "epoch": epoch + 1,
                            "model": model.model.state_dict(),
                            "ema": model.ema_model.state_dict(),
                            "optimizer": model.optimizer.state_dict(),
                            "step": model.step,
                            "tb_logdir": tb_logdir,
                            "scheduler": scheduler.state_dict(),
                        }, save_max, old_max_epoch, epoch + 1, args.rank
                    )
                    max_eva = trajectory_success_rate_meter_reduced
                    max_acc = acc_top1_reduced
                    old_max_epoch = epoch + 1

        if (epoch + 1) % args.save_freq == 0:
            if args.rank == 0:
                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "model": model.model.state_dict(),
                        "ema": model.ema_model.state_dict(),
                        "optimizer": model.optimizer.state_dict(),
                        "step": model.step,
                        "tb_logdir": tb_logdir,
                        "scheduler": scheduler.state_dict(),
                    }, checkpoint_dir, epoch + 1
                )


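# File-based logging and checkpoint helpers: log() appends to
# log/<checkpoint_dir>.txt, save_checkpoint()/save_checkpoint2() keep only the
# most recent (or current-best) checkpoint, and get_last_checkpoint() returns
# the newest one for resuming.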
def log(output, args):
    with open(os.path.join(os.path.dirname(__file__), 'log', args.checkpoint_dir + '.txt'), "a") as f:
        f.write(output + '\n')


def save_checkpoint(state, checkpoint_dir, epoch, n_ckpt=1):
    torch.save(state, os.path.join(checkpoint_dir, "act_epoch{:0>4d}.pth.tar".format(epoch)))
    if epoch - n_ckpt >= 0:
        oldest_ckpt = os.path.join(checkpoint_dir, "act_epoch{:0>4d}.pth.tar".format(epoch - n_ckpt))
        if os.path.isfile(oldest_ckpt):
            os.remove(oldest_ckpt)


def save_checkpoint2(state, checkpoint_dir, old_epoch, epoch, rank):
    torch.save(state, os.path.join(checkpoint_dir, "act_epoch{:0>4d}_{}.pth.tar".format(epoch, rank)))
    if old_epoch > 0:
        oldest_ckpt = os.path.join(checkpoint_dir, "act_epoch{:0>4d}_{}.pth.tar".format(old_epoch, rank))
        if os.path.isfile(oldest_ckpt):
            os.remove(oldest_ckpt)


def get_last_checkpoint(checkpoint_dir):
    # Match the 'act_epoch*.pth.tar' files written by save_checkpoint above;
    # the earlier 'epoch*' glob pattern would not match those filenames.
    all_ckpt = glob.glob(os.path.join(checkpoint_dir, 'act_epoch*.pth.tar'))
    if all_ckpt:
        all_ckpt = sorted(all_ckpt)
        return all_ckpt[-1]
    else:
        return ''


if __name__ == "__main__":
    main()