Guanhua Zhang 2025-04-10 20:14:17 +02:00
commit 04c4625cfe
11 changed files with 1330 additions and 0 deletions

hf_bmt/hf_2_bmtrain.py Normal file

@@ -0,0 +1,108 @@
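"""Convert a Hugging Face Mistral/Llama checkpoint into the BMTrain layout:
writes config.json and a single fp16 pytorch_model.pt with remapped parameter
names, and copies the tokenizer files into the output directory."""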
from transformers import LlamaConfig
import torch, os
import json
from collections import OrderedDict
import shutil
import argparse

ver = "7b"

def initialize():
    # Parse command-line arguments.
    parser = argparse.ArgumentParser("")
    # Output directory for the BMTrain weights.
    parser.add_argument("--out_path", type=str, default=f"/Mistral-{ver}-bmtrain")
    # Path of the downloaded Mistral-7B Hugging Face snapshot.
    parser.add_argument("--in_path", type=str, default="/Mistral-7B-v0.1/snapshots/26bca36bde8333b5d7f72e9ed20ccda6a618af24")
    args = parser.parse_args()
    return args

def convert_weights(args):
    hf_config = LlamaConfig.from_pretrained(args.in_path)
    # Translate the Hugging Face config into the BMTrain config fields.
    config = {
        'dim_model': hf_config.hidden_size,
        'dim_ff': hf_config.intermediate_size,
        'num_layers': hf_config.num_hidden_layers,
        'num_heads': hf_config.num_attention_heads,
        'num_heads_kv': hf_config.num_key_value_heads,
        'dim_head': hf_config.hidden_size // hf_config.num_attention_heads,
        'norm_eps': hf_config.rms_norm_eps,
    }
    os.makedirs(args.out_path, exist_ok=True)
    with open(os.path.join(args.out_path, "config.json"), 'w') as f:
        json.dump(config, f)

    layernum = config['num_layers']
    model_hf = OrderedDict()
    # Mistral-7B-v0.1 ships sharded .bin checkpoints; newer releases use .safetensors.
    ckpt_num = None
    if 'v0.1' in args.in_path:
        prefix = "pytorch_model-"
        endtext = ".bin"
    else:
        prefix = "model-"
        endtext = ".safetensors"
    # Shard filenames look like "<prefix>00001-of-0000N<endtext>"; recover N.
    for name in os.listdir(args.in_path):
        if name.startswith(prefix) and name.endswith(endtext):
            ckpt_num = int(name.split(endtext)[0].split('-')[-1])
    # Load every shard and merge it into one flat state dict.
    for i in range(1, ckpt_num + 1):
        if 'v0.1' in args.in_path:
            part = torch.load(os.path.join(args.in_path, f"pytorch_model-{i:05d}-of-{ckpt_num:05d}.bin"))
        else:
            from safetensors import safe_open
            # Load on CPU so the saved checkpoint is device-agnostic.
            with safe_open(os.path.join(args.in_path, f"model-{i:05d}-of-{ckpt_num:05d}.safetensors"), framework="pt", device="cpu") as f:
                part = {k: f.get_tensor(k) for k in f.keys()}
        model_hf.update(part)
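    # Remap Hugging Face parameter names onto the BMTrain module layout:
    # embeddings, final norm and lm_head first, then each layer's attention and FFN weights.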
    out = OrderedDict()
    out["input_embedding.weight"] = model_hf['model.embed_tokens.weight'].contiguous()
    out["encoder.output_layernorm.weight"] = model_hf['model.norm.weight'].contiguous()
    out['output_projection.weight'] = model_hf['lm_head.weight'].contiguous()
    for lnum in range(layernum):
        hf_pfx = f"model.layers.{lnum}"
        bmt_pfx = f"encoder.layers.{lnum}"
        out[f"{bmt_pfx}.self_att.layernorm_before_attention.weight"] = model_hf[f"{hf_pfx}.input_layernorm.weight"].contiguous()
        out[f"{bmt_pfx}.self_att.self_attention.project_q.weight"] = model_hf[f"{hf_pfx}.self_attn.q_proj.weight"].contiguous()
        out[f"{bmt_pfx}.self_att.self_attention.project_k.weight"] = model_hf[f"{hf_pfx}.self_attn.k_proj.weight"].contiguous()
        out[f"{bmt_pfx}.self_att.self_attention.project_v.weight"] = model_hf[f"{hf_pfx}.self_attn.v_proj.weight"].contiguous()
        out[f"{bmt_pfx}.self_att.self_attention.attention_out.weight"] = model_hf[f"{hf_pfx}.self_attn.o_proj.weight"].contiguous()
        out[f"{bmt_pfx}.ffn.layernorm_before_ffn.weight"] = model_hf[f"{hf_pfx}.post_attention_layernorm.weight"].contiguous()
        out[f"{bmt_pfx}.ffn.ffn.w_in.w_0.weight"] = model_hf[f"{hf_pfx}.mlp.gate_proj.weight"].contiguous()
        out[f"{bmt_pfx}.ffn.ffn.w_in.w_1.weight"] = model_hf[f"{hf_pfx}.mlp.up_proj.weight"].contiguous()
        out[f"{bmt_pfx}.ffn.ffn.w_out.weight"] = model_hf[f"{hf_pfx}.mlp.down_proj.weight"].contiguous()
    # Cast everything to fp16 before saving.
    for key in out:
        out[key] = out[key].half()
    torch.save(out, os.path.join(args.out_path, "pytorch_model.pt"))
    # Copy the tokenizer files next to the converted weights.
    for file_name in ["tokenizer_config.json", "special_tokens_map.json", "tokenizer.model", "tokenizer.json"]:
        if os.path.exists(os.path.join(args.in_path, file_name)):
            shutil.copy(os.path.join(args.in_path, file_name), os.path.join(args.out_path, file_name))
    print("BMT weights created successfully")
def main():
    args = initialize()
    convert_weights(args)

if __name__ == "__main__":
    main()
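# Example invocation (the paths below are placeholders; point them at your local
# Hugging Face snapshot and the desired BMTrain output directory):
#   python hf_bmt/hf_2_bmtrain.py \
#       --in_path /path/to/Mistral-7B-v0.1/snapshots/<snapshot-hash> \
#       --out_path /path/to/Mistral-7b-bmtrain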