Uploaded
commit 04c4625cfe
11 changed files with 1330 additions and 0 deletions
374  train/train.py  Normal file

@@ -0,0 +1,374 @@
import argparse
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
import bmtrain as bmt
from functools import partial
import time
import os, pdb, shutil
import random
import json
from model_center.model import Llama
from model_center.tokenizer import LlamaTokenizer
from dataset_wrapper import PromptIterableDataset, collator
import wandb
import csv

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"

import logging
import numpy as np
import math
from sentence_transformers import SentenceTransformer, util

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def get_tokenizer(args):
    tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'
    return tokenizer


def get_model(args):
    model = Llama.from_pretrained(args.model_name_or_path)
    if args.load_ckpt is not None:
        logger.info(f"loading model from {args.load_ckpt}")
        bmt.load(model, os.path.join(args.load_ckpt, "pytorch_model.pt"))
    return model

def get_optimizer(args, model):
    optimizer = bmt.optim.AdamOffloadOptimizer(
        model.parameters(),
        weight_decay=args.weight_decay,
        eps=1e-5,
        betas=(0.9, 0.95)
    )
    if args.load_ckpt is not None:
        file_name = os.path.join(args.load_ckpt, "optim.rank-{}.opt".format(bmt.rank()))
        logger.info(file_name)
        if os.path.exists(file_name):
            logger.info("start to load optimizer ckpt {}".format(file_name))
            states = torch.load(file_name)
            optimizer.load_state_dict(states)
    return optimizer

def get_learning_rate_scheduler(args, optimizer):
    if args.lr_decay_iters is None:
        args.lr_decay_iters = args.train_iters
    if args.lr_decay_style == "linear":
        lr_scheduler = bmt.lr_scheduler.Linear(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )
    elif args.lr_decay_style == "cosine":
        bmt.print_rank("use cosine")

        class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
            # linear warmup followed by cosine decay down to 10% of the peak lr
            def get_lr_warmup(self, num_iter) -> float:
                return self.start_lr * num_iter / self.warmup_iter

            def get_lr_decay(self, num_iter) -> float:
                progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
                return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))

        lr_scheduler = Cosine(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )

    elif args.lr_decay_style == "noam":
        logger.info("use noam")
        lr_scheduler = bmt.lr_scheduler.Noam(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )
    else:
        raise NotImplementedError
    return lr_scheduler

def setup_model_and_optimizer(args):
    # get the tokenizer
    tokenizer = get_tokenizer(args)
    # get the model
    model = get_model(args)
    bmt.synchronize()
    # get the optimizer and lr_scheduler
    optimizer = get_optimizer(args, model)
    lr_scheduler = get_learning_rate_scheduler(args, optimizer)
    bmt.synchronize()
    return tokenizer, model, optimizer, lr_scheduler

def initialize():
    parser = argparse.ArgumentParser("")
    # model training arguments
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--model_name_or_path")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--max_seq_length", default=2048, type=int)
    parser.add_argument("--batch_size_per_device", default=2, type=int)
    parser.add_argument("--logging_step", default=100, type=int)
    parser.add_argument("--save_step", default=50000, type=int)
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int)
    parser.add_argument("--wandb", default=True, action="store_true")
    parser.add_argument("--with_eval", action="store_true")
    parser.add_argument("--clip_grad", type=float, default=1.0, help="gradient clipping")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay rate")
    parser.add_argument("--loss_scale", type=float, default=6553600, help="loss scale")
    parser.add_argument("--train_iters", type=int, default=2000000)

    # loss parameters
    parser.add_argument("--action_weight", type=float, help="weight of the tokens that match the action")
    parser.add_argument("--embedding_model_path", type=str, help="The path to the sentence embedding model")

    # data parameters
    parser.add_argument('--data_setting', type=str, help='MTSD or MTMD', default="MTMD")
    parser.add_argument('--data_dir', type=str, help='The directory for saving the dataset')
    parser.add_argument('--max_train_samples', type=int, help='The maximum number of training samples')

    parser.add_argument('--cache_dir', type=str, help='The directory for cache')
    parser.add_argument("--save_dir", type=str, default="")

    parser.add_argument("--save_limit", type=int, default=None, help="maximum number of saved ckpts")

    parser.add_argument("--warmup_iters", type=int, default=1000)
    parser.add_argument(
        "--lr_decay_style",
        type=str,
        default="cosine",
        choices=["constant", "linear", "cosine", "exponential", "noam"],
        help="learning rate decay function",
    )
    parser.add_argument("--lr_decay_iters", type=int, default=None, help="lr decay steps")
    parser.add_argument("--start_step", type=int, default=0, help="step to start or continue training")
    parser.add_argument("--load_ckpt", type=str, default=None, help="resumed ckpt")
    parser.add_argument("--save_processed_data", action='store_true', help="whether or not to save the processed data")
    parser.add_argument("--prompt_file", type=str, default=None, help="The file for loading the prompt")
    args = parser.parse_args()

    # init bmt
    bmt.init_distributed(seed=args.seed)
    set_seed(args.seed)

    # wandb
    if args.wandb and bmt.rank() == 0:
        wandb.init(
            project='Mistral-Interact',
            config=args,
            name=args.save_dir.split('Mistral-7b/')[1][:-1],
            save_code=True,
            settings=wandb.Settings(code_dir="."),
        )
    return args

def format_one_action(action):
    return f"- {action}\n"


def format_actions_list(actions):
    actions_str = ""
    for action in actions:
        actions_str += format_one_action(action)
    return actions_str


def read_json_file(filename):
    with open(filename, 'r') as infile:
        data = json.load(infile)
    return data

def load_Mind2Web_dataset(args, save_dataset=False):
    # read the prompt template from args.prompt_file; '===' separates the prefix and suffix
    with open(args.prompt_file, 'r') as file:
        task_description = file.read().split('===')
    raw_dataset = read_json_file(args.data_dir)

    dataset = []
    for idx, d in enumerate(raw_dataset):
        sequences = []
        input_str = f"## Website:\n{d['website_en']}\n\n## Domain:\n{d['domain_en']}\n\n## Sub-domain:\n{d['subdomain_en']}\n\n## Actions (Each line is one action):\n{format_actions_list(d['task_subintention'])}\n## Sub-intentions summarised from these actions:\n{format_actions_list(d['steps'])}"

        query_inputs = f"{task_description[0]}\n{input_str}{task_description[1]}\n"
        sequences.append(query_inputs)
        summary_str = d['task_description']
        summary_str = "[SUMMARY] " + summary_str[0].upper() + summary_str[1:]
        sequences.append(summary_str)
        dataset.append({"data": sequences.copy()})

    random.shuffle(dataset)

    if args.max_train_samples is not None:
        dataset = dataset[:args.max_train_samples]
    return dataset

def load_MoTIF_dataset(args, save_dataset=False):
    with open(args.prompt_file, 'r') as file:
        task_description = file.read().split('===')

    raw_dataset = []
    for filename in os.listdir(args.data_dir):
        if filename.endswith('_steps.json'):
            file_path = os.path.join(args.data_dir, filename)
            with open(file_path, 'r', encoding='utf-8') as json_file:
                try:
                    content = json.load(json_file)
                    raw_dataset.append(content)
                except json.JSONDecodeError as e:
                    raise ValueError(f"Error decoding JSON from file {filename}: {e}")

    dataset = []
    for d in raw_dataset:
        sequences = []
        input_str = f"## Application:\n{d['app']}\n\n## Actions (Each line is one action):\n{format_actions_list(d['instr'])}\n## Sub-intentions summarised from these actions:\n{format_actions_list(d['steps'])}"
        query_inputs = f"{task_description[0]}\n{input_str}{task_description[1]}\n"
        sequences.append(query_inputs)
        summary_str = d['goal']
        summary_str = "[SUMMARY] " + summary_str[0].upper() + summary_str[1:]
        sequences.append(summary_str)
        dataset.append({"data": sequences.copy()})

    random.shuffle(dataset)

    if args.max_train_samples is not None:
        dataset = dataset[:args.max_train_samples]
    return dataset

def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset):
    embedding_model = SentenceTransformer(args.embedding_model_path, device="cuda")
    for param in embedding_model.parameters():
        param.requires_grad = False

    logger.info(f"total training instance number: {len(dataset)}")
    loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
    optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale)
    optim_manager.add_optimizer(optimizer, lr_scheduler)
    bmt.synchronize()

    avg_time_recorder = bmt.utils.AverageRecorder()
    avg_loss_recorder = bmt.utils.AverageRecorder()
    train_start_time = time.time()
    global_step = 0

    logger.info("split data for each process")
    data_per_gpu = len(dataset) // bmt.world_size()
    dataset = dataset[bmt.rank() * data_per_gpu: (bmt.rank() + 1) * data_per_gpu]
    bmt.print_rank("training on [%d, %d] of the dataset" % (bmt.rank() * data_per_gpu, (bmt.rank() + 1) * data_per_gpu))
    dataset = PromptIterableDataset(
        dataset,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        teacher_forcing=True,
        truncate_method="tail",
    )
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total_params}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params}")

    for epoch in range(args.epochs):
        savefolder = os.path.join(args.save_dir, f"epoch_{epoch}")
        os.makedirs(savefolder, exist_ok=True)

        dataloader = DataLoader(dataset, batch_size=args.batch_size_per_device)

        progress_bar = tqdm(range(len(dataloader)), disable=bmt.rank() != 0, desc=f"epoch {epoch}")
        logger.info(f"*******start {epoch} epoch training********")
        for step, inputs in enumerate(dataloader):
            if global_step < args.start_step:
                global_step += 1
                progress_bar.update(1)
                continue
            st = time.time()

            with bmt.inspect.inspect_tensor() as inspector:
                for k in inputs:
                    inputs[k] = inputs[k].cuda()

                labels = inputs.pop("labels")
                weight_idxs = inputs.pop('weight_idxs')
                logits = model(**inputs).logits

                # shift so that tokens < n predict token n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()

                # Flatten the tokens
                shift_logits = shift_logits.view(-1, len(tokenizer))
                shift_labels = shift_labels.view(-1).to(shift_logits.device)
                ntp_loss = loss_func(shift_logits, shift_labels)

                # up-weight the tokens flagged as action tokens by the dataset wrapper
                sample_specific_weights = torch.ones_like(shift_logits)
                weight_idxs = weight_idxs[:, 1:, :].contiguous()
                weight_idxs = weight_idxs.view(-1, weight_idxs.size(-1))
                assert weight_idxs.shape[0] == sample_specific_weights.shape[0], "weight_idxs must align with the flattened logits"
                sample_specific_weights[weight_idxs == 1] = args.action_weight
                sample_specific_weights = sample_specific_weights[torch.arange(sample_specific_weights.size(0)), shift_labels]

                ntp_loss = (ntp_loss * sample_specific_weights).mean()
                next_token_loss_item = bmt.sum_loss(ntp_loss).item()

                global_loss = next_token_loss_item
                optim_manager.backward(ntp_loss)

            if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(dataloader) - 1:
                optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=args.clip_grad)
                optim_manager.step()
                optim_manager.zero_grad()

            global_step += 1
            progress_bar.update(1)

            # record time and loss
            iteration_time = time.time() - st

            avg_time_recorder.record(iteration_time)
            if not np.isnan(global_loss):
                avg_loss_recorder.record(global_loss)

            # print time and loss
            if global_step % args.logging_step == 0:
                bmt.print_rank(
                    "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} | time: {:.4f} seconds | total_time_passed: {:.4f} minutes".format(
                        global_step,
                        global_loss,
                        avg_loss_recorder.value,
                        lr_scheduler.current_lr,
                        avg_time_recorder.value,
                        (time.time() - train_start_time) / 60
                    )
                )
                if args.wandb and bmt.rank() == 0:
                    wandb.log({
                        "loss": global_loss,
                        "next_token_loss": next_token_loss_item,
                        "average_loss": avg_loss_recorder.value,
                        "lr": lr_scheduler.current_lr,
                    }, step=global_step)

            if global_step == args.train_iters:
                break

        bmt.save(model, os.path.join(savefolder, "pytorch_model.pt"))
        if bmt.rank() == 0:
            tokenizer.save_pretrained(savefolder)
        bmt.print_rank(f"model saved at {savefolder}")

def main():
    args = initialize()
    if "Mind2Web" in args.data_dir:
        dataset = load_Mind2Web_dataset(args, save_dataset=True)
    else:
        assert "MoTIF" in args.data_dir
        dataset = load_MoTIF_dataset(args, save_dataset=True)
    args.train_iters = min(args.epochs * (len(dataset) // (bmt.world_size() * args.batch_size_per_device) + 1), args.train_iters)
    tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset)


if __name__ == "__main__":
    main()
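For reference, load_Mind2Web_dataset() reads a single JSON file from --data_dir, while load_MoTIF_dataset() reads a directory of *_steps.json files; in both cases the prompt file is split on '===' into a prefix and a suffix wrapped around the formatted record, and the training target is the capitalised task description prefixed with "[SUMMARY] ". A hypothetical Mind2Web-style record, with field names taken from the loader and all values invented for illustration, could look like this:

# Hypothetical record consumed by load_Mind2Web_dataset (keys from the code, values invented)
example_record = {
    "website_en": "example-shop.com",
    "domain_en": "Shopping",
    "subdomain_en": "Electronics",
    "task_subintention": ["Click the search box", "Type 'usb-c cable'", "Press Enter"],
    "steps": ["Search for a usb-c cable"],
    "task_description": "buy a usb-c cable on example-shop.com",
}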
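The core of finetune() is a per-token weighted next-token loss: positions flagged in weight_idxs contribute args.action_weight times as much as ordinary tokens. Below is a minimal, self-contained sketch of that weighting on toy tensors, assuming weight_idxs is a 0/1 mask of shape (batch, seq_len, vocab); torch.nn.functional.cross_entropy stands in for bmt.loss.FusedCrossEntropy, and all sizes and values are invented for illustration.

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 8, 32      # toy sizes, not the real model's
action_weight = 2.0                   # plays the role of --action_weight

logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
weight_idxs = torch.zeros(batch, seq_len, vocab)
weight_idxs[0, 3, labels[0, 3]] = 1   # pretend the gold token at position 3 of sample 0 is an "action" token

# shift so that tokens < n predict token n, then flatten, mirroring finetune()
shift_logits = logits[:, :-1, :].reshape(-1, vocab)
shift_labels = labels[:, 1:].reshape(-1)
per_token_loss = F.cross_entropy(shift_logits, shift_labels, reduction="none")

# build per-position weights and read them off at the gold labels
weights = torch.ones_like(shift_logits)
mask = weight_idxs[:, 1:, :].reshape(-1, vocab)
weights[mask == 1] = action_weight
weights = weights[torch.arange(weights.size(0)), shift_labels]

weighted_loss = (per_token_loss * weights).mean()
print(weighted_loss.item())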
44  train/train.sh  Executable file

@@ -0,0 +1,44 @@
#!/bin/bash
MASTER_ADDR=localhost
MASTER_PORT=12345
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=2
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
                  --nnodes $NNODES \
                  --node_rank $NODE_RANK \
                  --master_addr $MASTER_ADDR \
                  --master_port $MASTER_PORT"

PROJECT_PATH="your-project-path"

OPTS=""
# model config
MAXSEQLEN=1024
OPTS+=" --max_seq_length ${MAXSEQLEN}"
OPTS+=" --model_name_or_path ${PROJECT_PATH}/Mistral-7b-bmtrain"
# training config
OPTS+=" --logging_step 4"
BATCHSIZE=16
OPTS+=" --batch_size_per_device ${BATCHSIZE}"
OPTS+=" --save_step 500"
OPTS+=" --epochs 15"
LR=1e-6
OPTS+=" --lr ${LR}"
OPTS+=" --warmup_iters 0"
OPTS+=" --start_step 0"
OPTS+=" --loss_scale 6400"
ACTIONWEIGHT=2
OPTS+=" --action_weight ${ACTIONWEIGHT}"
EMBEDDING_MODEL_PATH="${PROJECT_PATH}/sentence-transformer/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/e4ce9877abf3edfe10b0d82785e83bdcb973e22e"
OPTS+=" --embedding_model_path ${EMBEDDING_MODEL_PATH}"

OPTS+=" --prompt_file ${PROJECT_PATH}/prompts/summarisation/summarisation_prompt.txt"
OPTS+=" --save_dir ${PROJECT_PATH}/ckpts/experiment"

CMD="torchrun ${DISTRIBUTED_ARGS} train.py ${OPTS}"

echo "-------final CMD is------"
echo "${CMD}"
echo "-------final CMD end------"
${CMD}
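Note that the script as committed does not set --data_dir, which main() in train.py uses both to locate the data and to choose between the Mind2Web and MoTIF loaders. A hypothetical addition before CMD is assembled (the path is invented for illustration) would be:

# Hypothetical: a path containing "Mind2Web" selects load_Mind2Web_dataset;
# a directory of *_steps.json files with "MoTIF" in its path selects load_MoTIF_dataset.
OPTS+=" --data_dir ${PROJECT_PATH}/data/Mind2Web/train.json"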