import pickle
import shutil
import argparse

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from networks import ActionDemo2Predicate

print('torch version: ', torch.__version__)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

torch.manual_seed(256)

class train_dataset(Dataset):
    """In-memory dataset wrapping pre-loaded feature arrays and their labels."""

    def __init__(self, x, label):
        self.x = x
        self.labels = label

    def __getitem__(self, index):
        x = self.x[index]
        label = self.labels[index]
        return x, label  # , img_idx

    def __len__(self):
        return len(self.labels)

class test_dataset(Dataset):
    """Same structure as train_dataset; intended for held-out data (unused in this script)."""

    def __init__(self, x, label):
        self.x = x
        self.labels = label

    def __getitem__(self, index):
        x = self.x[index]
        label = self.labels[index]
        return x, label  # , img_idx

    def __len__(self):
        return len(self.labels)

def train_model(model, train_dataloader, criterion, optimizer, num_epochs, DEVICE, path, resume):
    """Train a single model and keep the checkpoint with the best training accuracy.

    Note: `resume` is accepted but checkpoint resuming is not implemented here.
    """
    best_train_acc = 0
    best_train_loss = float('inf')
    start_epoch = 0

    model.to(DEVICE)
    model.train()
    for epoch in range(start_epoch, num_epochs):
        train_acc = []
        epoch_loss = 0
        for batch_idx, (x, labels) in enumerate(train_dataloader):
            # The DataLoader already yields tensors; just move them to the target device.
            x = x.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            logps = model(x)
            loss = criterion(logps, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.detach().item()
            # Per-batch accuracy (%) from the argmax over class logits.
            argmax_Y = torch.max(logps, 1)[1].view(-1, 1)
            train_acc.append((labels.float().view(-1, 1) == argmax_Y.float()).sum().item()
                             / len(labels.float().view(-1, 1)) * 100)

        epoch_loss /= (batch_idx + 1)
        train_acc = np.mean(np.array(train_acc))
        print('Epoch {}, train loss {:.4f}, train acc {:.4f}'.format(epoch, epoch_loss, train_acc))

        is_best_acc = train_acc > best_train_acc
        best_train_acc = max(train_acc, best_train_acc)

        is_best_train_loss = epoch_loss < best_train_loss
        best_train_loss = min(epoch_loss, best_train_loss)

        if is_best_acc:
            model.save(path + '_model_best.ckpt')
        model.save(path + '_checkpoint.ckpt')
        # scheduler.step()

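# The script defines test_dataset but never evaluates on held-out data. A minimal
# evaluation helper could look like the sketch below; it assumes the same
# (features, labels) batch format that train_model consumes. It is an illustrative
# addition and is not called anywhere in this script.
def evaluate_model(model, test_dataloader, DEVICE):
    model.to(DEVICE)
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, labels in test_dataloader:
            x = x.to(DEVICE)
            labels = labels.to(DEVICE)
            logps = model(x)
            preds = torch.max(logps, 1)[1]
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    # Return accuracy in percent over the whole test set.
    return 100.0 * correct / total
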
def save_checkpoint(state, is_best, path, filename='_checkpoint.pth.tar'):
    """Save a full training state dict; copy it aside when it is the best so far."""
    torch.save(state, path + filename)
    if is_best:
        shutil.copyfile(path + filename, path + '_model_best.pth.tar')

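# save_checkpoint is not called above (train_model saves via model.save instead), but a
# state dict written by it could be restored roughly as follows. This is a minimal sketch;
# the 'epoch', 'state_dict', and 'optimizer' keys are illustrative assumptions, not part
# of the original script.
#
#     checkpoint = torch.load(path + '_checkpoint.pth.tar', map_location=DEVICE)
#     start_epoch = checkpoint['epoch']
#     model.load_state_dict(checkpoint['state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer'])
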
def main():
    # parse command-line parameters
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--resume', action='store_true', help='resume training')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-1, help='learning rate')
    parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
    parser.add_argument('--hidden_size', type=int, default=256, help='hidden_size')
    parser.add_argument('--epochs', type=int, default=100, help='training epochs')
    parser.add_argument('--dataset_path', type=str, default='dataset/strategy_dataset/', help='dataset path')
    parser.add_argument('--weight_decay', type=float, default=0.9, help='weight decay for Adam optimizer')
    parser.add_argument('--demo_hidden', type=int, default=512, help='demo_hidden')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/', help='checkpoints path')

    args = parser.parse_args()

    # create checkpoints path
    from pathlib import Path
    path = args.checkpoint + args.model_type + '_bs_' + str(args.batch_size) \
        + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size)
    Path(path).mkdir(parents=True, exist_ok=True)
    print('total epochs for training: ', args.epochs)

    # build one dataloader, model, loss, and optimizer per task
    train_loader = []
    loss_funcs = []
    optimizers = []
    models = []
    for i in range(7):  # 7 tasks
        # training data: pickled feature and label arrays; labels are shifted by -1 so classes start at 0
        with open(args.dataset_path + 'train_data_' + str(i) + '.pkl', 'rb') as f:
            data_x = pickle.load(f)
        with open(args.dataset_path + 'train_label_' + str(i) + '.pkl', 'rb') as f:
            data_y = pickle.load(f)
        train_set = train_dataset(np.array(data_x), np.array(data_y) - 1)
        train_loader.append(DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4))
        print('task', str(i), 'train data size: ', len(train_set))

        net = ActionDemo2Predicate(args)
        models.append(net)
        loss_funcs.append(nn.CrossEntropyLoss())
        optimizers.append(optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay))

    # train each task-specific model independently
    for i in range(7):
        path_save = path + '/task' + str(i)
        print('checkpoint save path: ', path_save)
        train_model(models[i], train_loader[i], loss_funcs[i], optimizers[i],
                    args.epochs, DEVICE, path_save, args.resume)

if __name__ == '__main__':
    main()
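
# Example invocation (hyperparameters follow the argparse defaults above; adjust paths to
# your setup). The file name 'train.py' is only a placeholder for wherever this module
# is saved:
#
#     python train.py --model_type lstmlast --batch_size 32 --lr 0.1 \
#         --hidden_size 256 --epochs 100 \
#         --dataset_path dataset/strategy_dataset/ --checkpoint checkpoints/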