# InferringIntention/keyboard_and_mouse/train.py
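"""Train per-task intention classifiers from keyboard-and-mouse demonstrations.

For each of the 7 tasks, the pickled training sequences and labels are loaded,
an ActionDemo2Predicate network is built, and the best-accuracy checkpoint is
saved under the configured checkpoints path.
"""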

import argparse
import pickle
import shutil
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from networks import ActionDemo2Predicate

print('torch version: ', torch.__version__)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)
torch.manual_seed(256)  # fix the seed for reproducibility


class train_dataset(Dataset):
    def __init__(self, x, label):
        self.x = x
        self.labels = label

    def __getitem__(self, index):
        x = self.x[index]
        label = self.labels[index]
        return x, label  # , img_idx

    def __len__(self):
        return len(self.labels)


class test_dataset(Dataset):
    def __init__(self, x, label):
        self.x = x
        self.labels = label

    def __getitem__(self, index):
        x = self.x[index]
        label = self.labels[index]
        return x, label  # , img_idx

    def __len__(self):
        return len(self.labels)
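
# NOTE: train_dataset and test_dataset are identical index-based wrappers over
# pre-loaded arrays; only the training split is actually used in main() below.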


def train_model(model, train_dataloader, criterion, optimizer, num_epochs, DEVICE, path, resume):
    is_best_acc = False
    best_train_acc = 0
    best_train_loss = 10  # start from an arbitrarily high loss
    start_epoch = 0
    model.to(DEVICE)
    model.train()
    for epoch in range(start_epoch, num_epochs):
        epoch_losses = []
        train_acc = []
        epoch_loss = 0
        for batch_idx, (x, labels) in enumerate(train_dataloader):
            # the DataLoader already yields tensors; just move them to the device
            x = x.to(DEVICE)
            labels = labels.to(DEVICE)
            optimizer.zero_grad()
            logps = model(x)
            loss = criterion(logps, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.detach().item()
            # per-batch accuracy: fraction of argmax predictions matching the labels
            argmax_Y = torch.max(logps, 1)[1].view(-1, 1)
            train_acc.append((labels.float().view(-1, 1) == argmax_Y.float()).sum().item()
                             / len(labels.float().view(-1, 1)) * 100)
        epoch_loss /= (batch_idx + 1)
        epoch_losses.append(epoch_loss)
        train_acc = np.mean(np.array(train_acc))
        print('Epoch {}, train loss {:.4f}, train acc {:.4f}'.format(epoch, epoch_loss, train_acc))
        is_best_acc = train_acc > best_train_acc
        best_train_acc = max(train_acc, best_train_acc)
        is_best_train_loss = epoch_loss < best_train_loss
        best_train_loss = min(epoch_loss, best_train_loss)
        # keep the best-accuracy model plus a rolling checkpoint of the latest epoch
        if is_best_acc:
            model.save(path + '_model_best.ckpt')
        model.save(path + '_checkpoint.ckpt')
        # scheduler.step()
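

# NOTE: the `resume` argument is accepted by train_model but not acted on.
# Resuming would presumably reload the rolling '_checkpoint.ckpt'; a minimal
# sketch, assuming ActionDemo2Predicate pairs its save() with a load():
#   if resume:
#       model.load(path + '_checkpoint.ckpt')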


def save_checkpoint(state, is_best, path, filename='_checkpoint.pth.tar'):
    torch.save(state, path + filename)
    if is_best:
        shutil.copyfile(path + filename, path + '_model_best.pth.tar')
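

# NOTE: save_checkpoint is defined but never called in this script; train_model
# saves via model.save() instead. If wired in, it would expect the usual PyTorch
# checkpoint dict (a sketch; the exact keys here are an assumption):
#   save_checkpoint({'epoch': epoch,
#                    'state_dict': model.state_dict(),
#                    'optimizer': optimizer.state_dict()},
#                   is_best_acc, path)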


def main():
    # parse command-line parameters
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--resume', action='store_true', help='resume training')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-1, help='learning rate')
    parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
    parser.add_argument('--hidden_size', type=int, default=256, help='hidden size')
    parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')
    parser.add_argument('--dataset_path', type=str, default='dataset/strategy_dataset/', help='dataset path')
    parser.add_argument('--weight_decay', type=float, default=0.9, help='weight decay for the Adam optimizer')
    parser.add_argument('--demo_hidden', type=int, default=512, help='demo hidden size')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/', help='checkpoints path')
    args = parser.parse_args()
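    # Example invocation (the values shown are just the argparse defaults):
    #   python train.py --model_type lstmlast --batch_size 32 --lr 0.1 --epochs 100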
    # create the checkpoints directory
    path = args.checkpoint + args.model_type + '_bs_' + str(args.batch_size) \
        + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size)
    Path(path).mkdir(parents=True, exist_ok=True)
    print('total epochs for training: ', args.epochs)
    # read the per-task datasets and build one model/loss/optimizer per task
    train_loader = []
    test_loader = []
    loss_funcs = []
    optimizers = []
    models = []
    for i in range(7):  # 7 tasks
        # train data
        with open(args.dataset_path + 'train_data_' + str(i) + '.pkl', 'rb') as f:
            data_x = pickle.load(f)
        with open(args.dataset_path + 'train_label_' + str(i) + '.pkl', 'rb') as f:
            data_y = pickle.load(f)
        # shift labels down by one so classes start at 0, as nn.CrossEntropyLoss expects
        train_set = train_dataset(np.array(data_x), np.array(data_y) - 1)
        train_loader.append(DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4))
        print('task', str(i), 'train data size: ', len(train_set))
        net = ActionDemo2Predicate(args)
        models.append(net)
        loss_funcs.append(nn.CrossEntropyLoss())
        optimizers.append(optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay))
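    # NOTE: the 7 tasks are trained completely independently; no parameters are
    # shared across models, so the loop below is just sequential per-task training.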
    for i in range(7):
        path_save = path + '/task' + str(i)
        print('checkpoint save path: ', path_save)
        train_model(models[i], train_loader[i], loss_funcs[i], optimizers[i], args.epochs, DEVICE, path_save, args.resume)


if __name__ == '__main__':
    main()