import argparse
import pickle
import shutil
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Project-specific model; assumed to be an nn.Module subclass that also
# exposes a save(path) helper, as used in train_model() below.
from networks import ActionDemo2Predicate

print('torch version:', torch.__version__)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
torch.manual_seed(256)


class train_dataset(Dataset):
    """Wraps pre-loaded feature/label arrays for one task."""

    def __init__(self, x, label):
        self.x = x
        self.labels = label

    def __getitem__(self, index):
        return self.x[index], self.labels[index]

    def __len__(self):
        return len(self.labels)


class test_dataset(Dataset):
    """Identical to train_dataset; kept separate in case test-time
    preprocessing diverges later. Currently unused by main()."""

    def __init__(self, x, label):
        self.x = x
        self.labels = label

    def __getitem__(self, index):
        return self.x[index], self.labels[index]

    def __len__(self):
        return len(self.labels)


def train_model(model, train_dataloader, criterion, optimizer, num_epochs,
                DEVICE, path, resume):
    # `resume` is accepted for interface compatibility, but checkpoint
    # resumption is not implemented here.
    best_train_acc = 0.0
    best_train_loss = float('inf')

    model.to(DEVICE)
    model.train()
    for epoch in range(num_epochs):
        train_acc = []
        epoch_loss = 0.0
        for batch_idx, (x, labels) in enumerate(train_dataloader):
            # The default collate_fn already yields tensors, so a plain
            # .to(DEVICE) suffices (torch.tensor(x) on a tensor copies and warns).
            x = x.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            logps = model(x)
            loss = criterion(logps, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.detach().item()
            # Batch accuracy in percent: fraction of argmax predictions
            # that match the labels.
            argmax_Y = torch.max(logps, 1)[1]
            train_acc.append((labels == argmax_Y).float().mean().item() * 100)

        epoch_loss /= (batch_idx + 1)
        train_acc = float(np.mean(train_acc))
        print('Epoch {}, train loss {:.4f}, train acc {:.4f}'.format(
            epoch, epoch_loss, train_acc))

        is_best_acc = train_acc > best_train_acc
        best_train_acc = max(train_acc, best_train_acc)
        best_train_loss = min(epoch_loss, best_train_loss)

        # Keep the best-accuracy weights separately; always refresh the
        # latest checkpoint.
        if is_best_acc:
            model.save(path + '_model_best.ckpt')
        model.save(path + '_checkpoint.ckpt')
        # scheduler.step()


def save_checkpoint(state, is_best, path, filename='_checkpoint.pth.tar'):
    # Currently unused: train_model() saves via model.save() instead.
    torch.save(state, path + filename)
    if is_best:
        shutil.copyfile(path + filename, path + '_model_best.pth.tar')


def main():
    # Parse training parameters.
    parser = argparse.ArgumentParser(description='Train one predicate classifier per task.')
    # Note: argparse's type=bool treats any non-empty string (even "False") as
    # True, so a store_true flag is the correct way to expose a boolean switch.
    parser.add_argument('--resume', action='store_true', help='resume training')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-1, help='learning rate')
    parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
    parser.add_argument('--hidden_size', type=int, default=256, help='hidden size')
    parser.add_argument('--epochs', type=int, default=100, help='training epochs')
    parser.add_argument('--dataset_path', type=str, default='dataset/strategy_dataset/', help='dataset path')
    parser.add_argument('--weight_decay', type=float, default=0.9, help='weight decay for Adam optimizer')
    parser.add_argument('--demo_hidden', type=int, default=512, help='demo hidden size')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/', help='checkpoints path')
    args = parser.parse_args()

    # Create the checkpoints directory, named after the key hyperparameters.
    path = (args.checkpoint + args.model_type + '_bs_' + str(args.batch_size)
            + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size))
    Path(path).mkdir(parents=True, exist_ok=True)
    print('total epochs for training:', args.epochs)

    # Build one dataloader, model, loss, and optimizer per task.
    train_loader = []
    loss_funcs = []
    optimizers = []
    models = []
    for i in range(7):  # 7 tasks
        with open(args.dataset_path + 'train_data_' + str(i) + '.pkl', 'rb') as f:
            data_x = pickle.load(f)
        with open(args.dataset_path + 'train_label_' + str(i) + '.pkl', 'rb') as f:
            data_y = pickle.load(f)
        # Labels are stored 1-indexed; shift to 0-indexed for CrossEntropyLoss.
        train_set = train_dataset(np.array(data_x), np.array(data_y) - 1)
        train_loader.append(DataLoader(train_set, batch_size=args.batch_size,
                                       shuffle=True, num_workers=4))
        print('task', str(i), 'train data size:', len(train_set))

        net = ActionDemo2Predicate(args)
        models.append(net)
        loss_funcs.append(nn.CrossEntropyLoss())
        optimizers.append(optim.Adam(net.parameters(), lr=args.lr,
                                     weight_decay=args.weight_decay))

    # Train each task's model independently.
    for i in range(7):
        path_save = path + '/task' + str(i)
        print('checkpoint save path:', path_save)
        train_model(models[i], train_loader[i], loss_funcs[i], optimizers[i],
                    args.epochs, DEVICE, path_save, args.resume)


if __name__ == '__main__':
    main()
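
# Usage sketch (the script filename below is an assumption, not from the
# source; flag values shown are just the parser defaults made explicit):
#
#   python train_strategy.py --model_type lstmlast --batch_size 32 \
#       --lr 0.1 --hidden_size 256 --epochs 100 \
#       --dataset_path dataset/strategy_dataset/ --checkpoint checkpoints/
#
# With these settings, checkpoints land under
# checkpoints/lstmlast_bs_32_lr_0.1_hidden_size_256/, e.g.
# .../task0_model_best.ckpt and .../task0_checkpoint.ckpt for task 0.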