SummAct/train/train.py

import argparse
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
import bmtrain as bmt
from functools import partial
import time
import os, pdb, shutil
import random
import json
from model_center.model import Llama
from model_center.tokenizer import LlamaTokenizer
from dataset_wrapper import PromptIterableDataset, collator
import wandb
import csv
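# Cap the CUDA caching allocator's split size to reduce memory fragmentation.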
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
import logging
import numpy as np
import math
from sentence_transformers import SentenceTransformer, util
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
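
# Seed every RNG in use (PyTorch CPU/CUDA, NumPy, Python) and force deterministic cuDNN behaviour.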
def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
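
# Llama tokenizer with EOS used as the padding token and left-side padding.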
def get_tokenizer(args):
tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
return tokenizer
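
# Load the base Llama model; optionally restore weights from a BMTrain checkpoint.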
def get_model(args):
model = Llama.from_pretrained(args.model_name_or_path)
if args.load_ckpt is not None:
logger.info(f"loading model from {args.load_ckpt}")
bmt.load(model, os.path.join(args.load_ckpt, "pytorch_model.pt"))
return model
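
# Adam with CPU-offloaded optimizer states (BMTrain); per-rank optimizer state is restored when resuming.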
def get_optimizer(args, model):
optimizer = bmt.optim.AdamOffloadOptimizer(
model.parameters(),
weight_decay=args.weight_decay,
eps=1e-5,
betas=(0.9, 0.95)
)
if args.load_ckpt is not None:
file_name = os.path.join(args.load_ckpt, "optim.rank-{}.opt".format(bmt.rank()))
logger.info(file_name)
if os.path.exists(file_name):
logger.info("start to load gradient ckpt {}".format(file_name))
states = torch.load(file_name)
optimizer.load_state_dict(states)
return optimizer
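
# Build the LR schedule: linear, a custom cosine that decays to a floor of 10% of the peak LR, or Noam.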
def get_learning_rate_scheduler(args, optimizer):
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
if args.lr_decay_style == "linear":
lr_scheduler = bmt.lr_scheduler.Linear(
optimizer,
start_lr=args.lr,
warmup_iter=int(args.warmup_iters),
end_iter=args.lr_decay_iters,
num_iter=args.start_step,
)
elif args.lr_decay_style == "cosine":
bmt.print_rank("use cosine")
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
def get_lr_warmup(self, num_iter) -> float:
return self.start_lr * num_iter / self.warmup_iter
def get_lr_decay(self, num_iter) -> float:
progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
lr_scheduler = Cosine(
optimizer,
start_lr=args.lr,
warmup_iter=int(args.warmup_iters),
end_iter=args.lr_decay_iters,
num_iter=args.start_step,
)
elif args.lr_decay_style == "noam":
logger.info("use noam")
lr_scheduler = bmt.lr_scheduler.Noam(
optimizer,
start_lr=args.lr,
warmup_iter=int(args.warmup_iters),
end_iter=args.lr_decay_iters,
num_iter=args.start_step,
)
else:
raise NotImplementedError
return lr_scheduler
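
# Assemble tokenizer, model, optimizer and LR scheduler, synchronising ranks in between.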
def setup_model_and_optimizer(args):
# get the tokenizer
tokenizer = get_tokenizer(args)
# get the model
model = get_model(args)
bmt.synchronize()
# get the optimizer and lr_scheduler
optimizer = get_optimizer(args, model)
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
bmt.synchronize()
return tokenizer, model, optimizer, lr_scheduler
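
# Parse CLI arguments, initialise BMTrain's distributed backend, fix seeds, and start wandb on rank 0.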
def initialize():
parser = argparse.ArgumentParser("")
# model training arguments
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--model_name_or_path")
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--max_seq_length", default=2048, type=int)
parser.add_argument("--batch_size_per_device", default=2, type=int)
parser.add_argument("--logging_step", default=100, type=int)
parser.add_argument("--save_step", default=50000, type=int)
parser.add_argument("--gradient_accumulation_steps", default=1, type=int)
parser.add_argument("--wandb", default= True ,action="store_true")
parser.add_argument("--with_eval", action="store_true")
parser.add_argument("--clip_grad", type=float, default=1.0, help="gradient clipping")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay rate")
parser.add_argument("--loss_scale", type=float, default=6553600, help="loss scale")
parser.add_argument("--train_iters", type=int, default=2000000)
# loss parameters
parser.add_argument("--action_weight", type=float, help="weight of the tokens that match the action")
parser.add_argument("--embedding_model_path", type=str, help="The path to the sentence embedding model")
# data parameters
    parser.add_argument('--data_setting', type=str, help='MTSD or MTMD', default="MTMD")
parser.add_argument('--data_dir', type=str, help='The directory for saving the dataset')
parser.add_argument('--max_train_samples', type=int, help='The maximum number of training samples')
parser.add_argument('--cache_dir', type=str, help='The directory for cache')
parser.add_argument("--save_dir", type=str, default="")
parser.add_argument("--save_limit", type=int, default=None, help="ckpt saved limit number")
parser.add_argument("--warmup_iters", type=int, default=1000)
parser.add_argument(
"--lr_decay_style",
type=str,
default="cosine",
choices=["constant", "linear", "cosine", "exponential", "noam"],
help="learning rate decay function",
)
parser.add_argument("--lr_decay_iters", type=int, default=None, help="lr decay steps")
parser.add_argument("--start_step", type=int, default=0, help="step to start or continue training")
parser.add_argument("--load_ckpt", type=str, default=None, help="resumed ckpt")
parser.add_argument("--save_processed_data", action='store_true', help="wheather or no save the processed data")
parser.add_argument("--prompt_file", type=str, default=None, help="The file for loading the prompt")
args = parser.parse_args()
# init bmt
bmt.init_distributed(seed=args.seed)
set_seed(args.seed)
# wandb
if args.wandb and bmt.rank() == 0:
wandb.init(project='Mistral-Interact', config=args, name=args.save_dir.split('Mistral-7b/')[1][:-1], save_code=True, settings=wandb.Settings(code_dir="."))
return args
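
# Render a list of actions as one bullet line per action for the prompt.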
def format_one_action(action):
return f"- {action}\n"
def format_actions_list(actions):
actions_str = ""
for action in actions:
actions_str += format_one_action(action)
return actions_str
def read_json_file(filename):
with open(filename, 'r') as infile:
data = json.load(infile)
return data
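
# Build (prompt, summary) training pairs from a Mind2Web-style JSON file; the prompt
# template is read from args.prompt_file and split into prefix/suffix on '==='.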
def load_Mind2Web_dataset(args, save_dataset=False):
# read text from a file (file name is args.prompt_file)
with open(args.prompt_file, 'r') as file:
task_description = file.read().split('===')
raw_dataset = read_json_file(args.data_dir)
    dataset = []
for idx, d in enumerate(raw_dataset):
sequences = []
input_str = f"## Website:\n{d['website_en']}\n\n## Domain:\n{d['domain_en']}\n\n## Sub-domain:\n{d['subdomain_en']}\n\n## Actions (Each line is one action):\n{format_actions_list(d['task_subintention'])}\n## Sub-intentions summarised from these actions:\n{format_actions_list(d['steps'])}"
query_inputs = f"{task_description[0]}\n{input_str}{task_description[1]}\n"
sequences.append(query_inputs)
summary_str = d['task_description']
summary_str = "[SUMMARY] " + summary_str[0].upper() + summary_str[1:]
sequences.append(summary_str)
dataset.append({"data": sequences.copy()})
random.shuffle(dataset)
if args.max_train_samples is not None:
dataset = dataset[:args.max_train_samples]
return dataset
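
# Build (prompt, summary) training pairs from MoTIF '*_steps.json' files in args.data_dir,
# using the same '==='-delimited prompt template.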
def load_MoTIF_dataset(args, save_dataset=False):
with open(args.prompt_file, 'r') as file:
task_description = file.read().split('===')
raw_dataset = []
for filename in os.listdir(args.data_dir):
if filename.endswith('_steps.json'):
file_path = os.path.join(args.data_dir, filename)
with open(file_path, 'r', encoding='utf-8') as json_file:
try:
content = json.load(json_file)
raw_dataset.append(content)
except json.JSONDecodeError as e:
raise ValueError(f"Error decoding JSON from file {filename}: {e}")
    dataset = []
for d in raw_dataset:
sequences = []
input_str = f"## Application:\n{d['app']}\n\n## Actions (Each line is one action):\n{format_actions_list(d['instr'])}\n## Sub-intentions summarised from these actions:\n{format_actions_list(d['steps'])}"
query_inputs = f"{task_description[0]}\n{input_str}{task_description[1]}\n"
sequences.append(query_inputs)
summary_str = d['goal']
summary_str = "[SUMMARY] " + summary_str[0].upper() + summary_str[1:]
sequences.append(summary_str)
dataset.append({"data": sequences.copy()})
random.shuffle(dataset)
if args.max_train_samples is not None:
dataset = dataset[:args.max_train_samples]
return dataset
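
# Fine-tune with a weighted next-token-prediction loss: tokens flagged in weight_idxs
# (those matching the actions) are up-weighted by args.action_weight. The data is
# sharded across ranks and wrapped in PromptIterableDataset; a frozen sentence-embedding
# model is also loaded on GPU.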
def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset):
embedding_model = SentenceTransformer(args.embedding_model_path, device="cuda")
for param in embedding_model.parameters():
param.requires_grad = False
logger.info(f"total training instance number: {len(dataset)}")
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale)
optim_manager.add_optimizer(optimizer, lr_scheduler)
bmt.synchronize()
avg_time_recorder = bmt.utils.AverageRecorder()
avg_loss_recorder = bmt.utils.AverageRecorder()
train_start_time = time.time()
global_step = 0
logger.info("split data for each process")
data_per_gpu = len(dataset) // bmt.world_size()
dataset = dataset[bmt.rank() * data_per_gpu: (bmt.rank() + 1) * data_per_gpu]
bmt.print_rank("training on [%d, %d] of the dataset" % (bmt.rank() * data_per_gpu, (bmt.rank() + 1) * data_per_gpu))
dataset = PromptIterableDataset(
dataset,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
teacher_forcing=True,
truncate_method="tail",
)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params}")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")
for epoch in range(args.epochs):
savefolder = os.path.join(args.save_dir, f"epoch_{epoch}")
os.makedirs(savefolder, exist_ok=True)
dataloader = DataLoader(dataset, batch_size=args.batch_size_per_device)
        progress_bar = tqdm(range(len(dataloader)), disable=bmt.rank() != 0, desc=f"epoch {epoch}")
logger.info(f"*******start {epoch} epoch training********")
for step, inputs in enumerate(dataloader):
if global_step < args.start_step:
global_step += 1
progress_bar.update(1)
continue
st = time.time()
with bmt.inspect.inspect_tensor() as inspector:
for k in inputs:
inputs[k] = inputs[k].cuda()
labels = inputs.pop("labels")
weight_idxs = inputs.pop('weight_idxs')
logits = model(**inputs).logits
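                # Standard causal-LM shift: positions up to t predict token t+1, so the last
                # logit and the first label are dropped before flattening.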
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
shift_logits = shift_logits.view(-1, len(tokenizer))
shift_labels = shift_labels.view(-1).to(shift_logits.device)
ntp_loss = loss_func(shift_logits, shift_labels)
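                # Per-token loss weights: entries flagged in weight_idxs are set to
                # args.action_weight and gathered at the gold-label index; all other
                # tokens keep weight 1.0.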
sample_specific_weights = torch.ones_like(shift_logits)
weight_idxs = weight_idxs[:, 1:, :].contiguous()
weight_idxs = weight_idxs.view(-1, weight_idxs.size(-1))
                assert weight_idxs.shape[0] == sample_specific_weights.shape[0], "weight_idxs must align with the flattened logits"
sample_specific_weights[weight_idxs==1] = args.action_weight
sample_specific_weights = sample_specific_weights[torch.arange(sample_specific_weights.size(0)), shift_labels]
ntp_loss = (ntp_loss * sample_specific_weights).mean()
next_token_loss_item = bmt.sum_loss(ntp_loss).item()
global_loss = next_token_loss_item
optim_manager.backward(ntp_loss)
if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(dataloader) - 1:
optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=args.clip_grad)
optim_manager.step()
optim_manager.zero_grad()
global_step += 1
progress_bar.update(1)
# record time and loss
iteration_time = time.time() - st
avg_time_recorder.record(iteration_time)
if not np.isnan(global_loss):
avg_loss_recorder.record(global_loss)
# print time and loss
if global_step % args.logging_step == 0:
bmt.print_rank(
"| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} | time: {:.4f} seconds | total_time_passed: {:.4f} minutes".format(
global_step,
global_loss,
avg_loss_recorder.value,
lr_scheduler.current_lr,
avg_time_recorder.value,
(time.time() - train_start_time) / 60
)
)
if args.wandb and bmt.rank() == 0:
wandb.log({
"loss": global_loss,
"next_token_loss": next_token_loss_item,
"average_loss": avg_loss_recorder.value,
"lr": lr_scheduler.current_lr,
}, step=global_step)
if global_step == args.train_iters:
break
bmt.save(model, os.path.join(savefolder, "pytorch_model.pt"))
if bmt.rank() == 0:
tokenizer.save_pretrained(savefolder)
bmt.print_rank(f"model saved at {savefolder}")
def main():
args = initialize()
if "Mind2Web" in args.data_dir:
dataset = load_Mind2Web_dataset(args, save_dataset=True)
else:
assert "MoTIF" in args.data_dir
dataset = load_MoTIF_dataset(args, save_dataset=True)
args.train_iters = min(args.epochs * (len(dataset) // (bmt.world_size() * args.batch_size_per_device) + 1), args.train_iters)
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset)
if __name__ == "__main__":
main()