import argparse
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
import bmtrain as bmt
from functools import partial
import time
import os
import shutil
import random
import json
from model_center.model import Llama
from model_center.tokenizer import LlamaTokenizer
from dataset_wrapper import PromptIterableDataset, collator
import wandb
import csv

# limit the CUDA allocator block size to reduce memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"

import logging
import numpy as np
import math
from sentence_transformers import SentenceTransformer, util

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_tokenizer(args):
    tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'
    return tokenizer


def get_model(args):
    model = Llama.from_pretrained(args.model_name_or_path)
    if args.load_ckpt is not None:
        logger.info(f"loading model from {args.load_ckpt}")
        bmt.load(model, os.path.join(args.load_ckpt, "pytorch_model.pt"))
    return model


def get_optimizer(args, model):
    optimizer = bmt.optim.AdamOffloadOptimizer(
        model.parameters(), weight_decay=args.weight_decay, eps=1e-5, betas=(0.9, 0.95)
    )
    if args.load_ckpt is not None:
        file_name = os.path.join(args.load_ckpt, "optim.rank-{}.opt".format(bmt.rank()))
        logger.info(file_name)
        if os.path.exists(file_name):
            logger.info("start to load gradient ckpt {}".format(file_name))
            states = torch.load(file_name)
            optimizer.load_state_dict(states)
    return optimizer


def get_learning_rate_scheduler(args, optimizer):
    if args.lr_decay_iters is None:
        args.lr_decay_iters = args.train_iters
    if args.lr_decay_style == "linear":
        lr_scheduler = bmt.lr_scheduler.Linear(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )
    elif args.lr_decay_style == "cosine":
        bmt.print_rank("use cosine")

        class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
            def get_lr_warmup(self, num_iter) -> float:
                return self.start_lr * num_iter / self.warmup_iter

            def get_lr_decay(self, num_iter) -> float:
                progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
                return max(
                    self.start_lr * 0.1,
                    self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))),
                )

        lr_scheduler = Cosine(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )
    elif args.lr_decay_style == "noam":
        logger.info("use noam")
        lr_scheduler = bmt.lr_scheduler.Noam(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )
    else:
        raise NotImplementedError
    return lr_scheduler


def setup_model_and_optimizer(args):
    # get the tokenizer
    tokenizer = get_tokenizer(args)
    # get the model
    model = get_model(args)
    bmt.synchronize()
    # get the optimizer and lr_scheduler
    optimizer = get_optimizer(args, model)
    lr_scheduler = get_learning_rate_scheduler(args, optimizer)
    bmt.synchronize()
    return tokenizer, model, optimizer, lr_scheduler


def initialize():
    parser = argparse.ArgumentParser("")
    # model training arguments
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--model_name_or_path")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--max_seq_length", default=2048, type=int)
    parser.add_argument("--batch_size_per_device", default=2, type=int)
    parser.add_argument("--logging_step", default=100, type=int)
    parser.add_argument("--save_step", default=50000, type=int)
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int)
    parser.add_argument("--wandb", default=True, action="store_true")
    parser.add_argument("--with_eval", action="store_true")
    parser.add_argument("--clip_grad", type=float, default=1.0, help="gradient clipping")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay rate")
    parser.add_argument("--loss_scale", type=float, default=6553600, help="loss scale")
    parser.add_argument("--train_iters", type=int, default=2000000)

    # loss parameters
    parser.add_argument("--action_weight", type=float, help="weight of the tokens that match the action")
    parser.add_argument("--embedding_model_path", type=str, help="The path to the sentence embedding model")

    # data parameters
    parser.add_argument("--data_setting", type=str, help="MTSD or MTMD", default="MTMD")
    parser.add_argument("--data_dir", type=str, help="The path to the dataset (a JSON file for Mind2Web, a directory for MoTIF)")
    parser.add_argument("--max_train_samples", type=int, help="The maximum number of training samples")
    parser.add_argument("--cache_dir", type=str, help="The directory for cache")
    parser.add_argument("--save_dir", type=str, default="")
    parser.add_argument("--save_limit", type=int, default=None, help="maximum number of saved checkpoints")
    parser.add_argument("--warmup_iters", type=int, default=1000)
    parser.add_argument(
        "--lr_decay_style",
        type=str,
        default="cosine",
        choices=["constant", "linear", "cosine", "exponential", "noam"],
        help="learning rate decay function",
    )
    parser.add_argument("--lr_decay_iters", type=int, default=None, help="lr decay steps")
    parser.add_argument("--start_step", type=int, default=0, help="step to start or continue training")
    parser.add_argument("--load_ckpt", type=str, default=None, help="checkpoint to resume from")
    parser.add_argument("--save_processed_data", action="store_true", help="whether or not to save the processed data")
    parser.add_argument("--prompt_file", type=str, default=None, help="The file for loading the prompt")
    args = parser.parse_args()

    # init bmt
    bmt.init_distributed(seed=args.seed)
    set_seed(args.seed)

    # wandb
    if args.wandb and bmt.rank() == 0:
        wandb.init(
            project="Mistral-Interact",
            config=args,
            name=args.save_dir.split("Mistral-7b/")[1][:-1],
            save_code=True,
            settings=wandb.Settings(code_dir="."),
        )

    return args


def format_one_action(action):
    return f"- {action}\n"


def format_actions_list(actions):
    actions_str = ""
    for action in actions:
        actions_str += format_one_action(action)
    return actions_str


def read_json_file(filename):
    with open(filename, 'r') as infile:
        data = json.load(infile)
    return data


def load_Mind2Web_dataset(args, save_dataset=False):
    # read the prompt template; the query and response parts are separated by '==='
    with open(args.prompt_file, 'r') as file:
        task_description = file.read().split('===')

    raw_dataset = read_json_file(args.data_dir)

    dataset = []
    for idx, d in enumerate(raw_dataset):
        sequences = []
        input_str = (
            f"## Website:\n{d['website_en']}\n\n"
            f"## Domain:\n{d['domain_en']}\n\n"
            f"## Sub-domain:\n{d['subdomain_en']}\n\n"
            f"## Actions (Each line is one action):\n{format_actions_list(d['task_subintention'])}\n"
            f"## Sub-intentions summarised from these actions:\n{format_actions_list(d['steps'])}"
        )
        query_inputs = f"{task_description[0]}\n{input_str}{task_description[1]}\n"
        sequences.append(query_inputs)

        summary_str = d['task_description']
        summary_str = "[SUMMARY] " + summary_str[0].upper() + summary_str[1:]
        sequences.append(summary_str)

        # each instance is a (query prompt, target summary) pair
        dataset.append({"data": sequences.copy()})

    random.shuffle(dataset)
    if args.max_train_samples is not None:
        dataset = dataset[:args.max_train_samples]
    return dataset


def load_MoTIF_dataset(args, save_dataset=False):
    with open(args.prompt_file, 'r') as file:
        task_description = file.read().split('===')

    raw_dataset = []
    for filename in os.listdir(args.data_dir):
        if filename.endswith('_steps.json'):
            file_path = os.path.join(args.data_dir, filename)
            with open(file_path, 'r', encoding='utf-8') as json_file:
                try:
                    content = json.load(json_file)
                    raw_dataset.append(content)
                except json.JSONDecodeError as e:
                    raise ValueError(f"Error decoding JSON from file {filename}: {e}")

    dataset = []
    for d in raw_dataset:
        sequences = []
        input_str = (
            f"## Application:\n{d['app']}\n\n"
            f"## Actions (Each line is one action):\n{format_actions_list(d['instr'])}\n"
            f"## Sub-intentions summarised from these actions:\n{format_actions_list(d['steps'])}"
        )
        query_inputs = f"{task_description[0]}\n{input_str}{task_description[1]}\n"
        sequences.append(query_inputs)

        summary_str = d['goal']
        summary_str = "[SUMMARY] " + summary_str[0].upper() + summary_str[1:]
        sequences.append(summary_str)

        dataset.append({"data": sequences.copy()})

    random.shuffle(dataset)
    if args.max_train_samples is not None:
        dataset = dataset[:args.max_train_samples]
    return dataset


def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset):
    # frozen sentence-embedding model (loaded here but not used in the loss below)
    embedding_model = SentenceTransformer(args.embedding_model_path, device="cuda")
    for param in embedding_model.parameters():
        param.requires_grad = False

    logger.info(f"total training instance number: {len(dataset)}")

    loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")

    optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale)
    optim_manager.add_optimizer(optimizer, lr_scheduler)

    bmt.synchronize()

    avg_time_recorder = bmt.utils.AverageRecorder()
    avg_loss_recorder = bmt.utils.AverageRecorder()
    train_start_time = time.time()
    global_step = 0

    logger.info("split data for each process")
    data_per_gpu = len(dataset) // bmt.world_size()
    dataset = dataset[bmt.rank() * data_per_gpu: (bmt.rank() + 1) * data_per_gpu]
    bmt.print_rank("training on [%d, %d] of the dataset" % (bmt.rank() * data_per_gpu, (bmt.rank() + 1) * data_per_gpu))

    dataset = PromptIterableDataset(
        dataset,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        teacher_forcing=True,
        truncate_method="tail",
    )

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total_params}")
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params}")

    for epoch in range(args.epochs):
        savefolder = os.path.join(args.save_dir, f"epoch_{epoch}")
        os.makedirs(savefolder, exist_ok=True)

        dataloader = DataLoader(dataset, batch_size=args.batch_size_per_device)

        progress_bar = tqdm(range(len(dataloader)), disable=bmt.rank() != 0, desc=f"epoch {epoch}")
        logger.info(f"*******start {epoch} epoch training********")

        for step, inputs in enumerate(dataloader):
            # skip steps that were already completed when resuming training
            if global_step < args.start_step:
                global_step += 1
                progress_bar.update(1)
                continue

            st = time.time()
            with bmt.inspect.inspect_tensor() as inspector:
                for k in inputs:
                    inputs[k] = inputs[k].cuda()

                labels = inputs.pop("labels")
                weight_idxs = inputs.pop("weight_idxs")

                logits = model(**inputs).logits
                # shift so that tokens < n predict token n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # flatten the tokens
                shift_logits = shift_logits.view(-1, len(tokenizer))
                shift_labels = shift_labels.view(-1).to(shift_logits.device)
                ntp_loss = loss_func(shift_logits, shift_labels)

                # per-token weights: positions marked in weight_idxs (tokens that
                # match an action) are up-weighted by args.action_weight
                sample_specific_weights = torch.ones_like(shift_logits)
                weight_idxs = weight_idxs[:, 1:, :].contiguous()
                weight_idxs = weight_idxs.view(-1, weight_idxs.size(-1))
                assert weight_idxs.shape[0] == sample_specific_weights.shape[0], "weight_idxs must cover the same number of tokens as the shifted logits"
                sample_specific_weights[weight_idxs == 1] = args.action_weight
                sample_specific_weights = sample_specific_weights[
                    torch.arange(sample_specific_weights.size(0)), shift_labels
                ]
                ntp_loss = (ntp_loss * sample_specific_weights).mean()

                next_token_loss_item = bmt.sum_loss(ntp_loss).item()
                global_loss = next_token_loss_item

                optim_manager.backward(ntp_loss)

            if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(dataloader) - 1:
                optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=args.clip_grad)
                optim_manager.step()
                optim_manager.zero_grad()

            global_step += 1
            progress_bar.update(1)

            # record time and loss
            iteration_time = time.time() - st
            avg_time_recorder.record(iteration_time)
            if not np.isnan(global_loss):
                avg_loss_recorder.record(global_loss)

            # print time and loss
            if global_step % args.logging_step == 0:
                bmt.print_rank(
                    "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} | time: {:.4f} seconds | total_time_passed: {:.4f} minutes".format(
                        global_step,
                        global_loss,
                        avg_loss_recorder.value,
                        lr_scheduler.current_lr,
                        avg_time_recorder.value,
                        (time.time() - train_start_time) / 60,
                    )
                )
                if args.wandb and bmt.rank() == 0:
                    wandb.log({
                        "loss": global_loss,
                        "next_token_loss": next_token_loss_item,
                        "average_loss": avg_loss_recorder.value,
                        "lr": lr_scheduler.current_lr,
                    }, step=global_step)

            if global_step == args.train_iters:
                break

        bmt.save(model, os.path.join(savefolder, "pytorch_model.pt"))
        if bmt.rank() == 0:
            tokenizer.save_pretrained(savefolder)
        bmt.print_rank(f"model saved at {savefolder}")


def main():
    args = initialize()

    if "Mind2Web" in args.data_dir:
        dataset = load_Mind2Web_dataset(args, save_dataset=True)
    else:
        assert "MoTIF" in args.data_dir
        dataset = load_MoTIF_dataset(args, save_dataset=True)

    args.train_iters = min(
        args.epochs * (len(dataset) // (bmt.world_size() * args.batch_size_per_device) + 1),
        args.train_iters,
    )

    tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset)


if __name__ == "__main__":
    main()
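
# Example launch command (a sketch only: the script name, paths, and hyperparameter
# values below are placeholders, not part of this repository). bmtrain reads the
# distributed environment variables set by torchrun. Note that --save_dir must
# contain "Mistral-7b/" for the wandb run name derived in initialize() to work.
#
#   torchrun --nnodes=1 --nproc_per_node=8 train.py \
#       --model_name_or_path /path/to/Mistral-7b/ \
#       --data_dir /path/to/Mind2Web/data.json \
#       --prompt_file /path/to/prompt.txt \
#       --embedding_model_path /path/to/sentence-transformer \
#       --save_dir /path/to/ckpts/Mistral-7b/run1/ \
#       --action_weight 2.0 \
#       --epochs 3 --lr 1e-5 --batch_size_per_device 2 --gradient_accumulation_steps 4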