SummAct/hf_bmt/hf_2_bmtrain.py

import argparse
import json
import os
import shutil
from collections import OrderedDict

import torch
from transformers import LlamaConfig
# Mistral model version tag used in the default paths below.
ver = "7b"


def initialize():
    # Parse command-line arguments.
    parser = argparse.ArgumentParser("")
    # Output directory for the converted BMTrain weights.
    parser.add_argument("--out_path", type=str, default=f"/Mistral-{ver}-bmtrain")
    # Path to the downloaded Mistral-7B Hugging Face snapshot.
    parser.add_argument("--in_path", type=str,
                        default="/Mistral-7B-v0.1/snapshots/26bca36bde8333b5d7f72e9ed20ccda6a618af24")
    args = parser.parse_args()
    return args
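
# Example invocation (paths are illustrative; point them at your local copies):
#   python hf_2_bmtrain.py \
#       --in_path /path/to/Mistral-7B-v0.1/snapshots/26bca36bde8333b5d7f72e9ed20ccda6a618af24 \
#       --out_path /path/to/Mistral-7b-bmtrain
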
def convert_weights(args):
    # Read the Hugging Face config and translate it into the BMTrain config keys.
    hf_config = LlamaConfig.from_pretrained(args.in_path)
    config = {
        'dim_model': hf_config.hidden_size,
        'dim_ff': hf_config.intermediate_size,
        'num_layers': hf_config.num_hidden_layers,
        'num_heads': hf_config.num_attention_heads,
        'num_heads_kv': hf_config.num_key_value_heads,
        'dim_head': hf_config.hidden_size // hf_config.num_attention_heads,
        'norm_eps': hf_config.rms_norm_eps,
    }
    os.makedirs(args.out_path, exist_ok=True)
    with open(os.path.join(args.out_path, "config.json"), 'w') as f:
        json.dump(config, f)
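    # For the Mistral-7B-v0.1 snapshot, the resulting config.json is expected to hold
    # roughly these values (taken from the published HF config; verify against your copy):
    #   dim_model=4096, dim_ff=14336, num_layers=32,
    #   num_heads=32, num_heads_kv=8, dim_head=128, norm_eps=1e-05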
    layernum = config['num_layers']
    model_hf = OrderedDict()
    ckpt_num = None
    # v0.1 snapshots ship sharded .bin files; later snapshots ship .safetensors shards.
    if 'v0.1' in args.in_path:
        prefix = "pytorch_model-"
        endtext = ".bin"
    else:
        prefix = "model-"
        endtext = ".safetensors"
    # Infer the total shard count from filenames like "<prefix>00001-of-00002<endtext>".
    for name in os.listdir(args.in_path):
        if name.startswith(prefix) and name.endswith(endtext):
            ckpt_num = int(name.split(endtext)[0].split('-')[-1])
    if ckpt_num is None:
        raise FileNotFoundError(f"no checkpoint shards matching {prefix}*{endtext} found in {args.in_path}")
    # Load every shard on CPU and merge the tensors into a single state dict.
    for i in range(1, ckpt_num + 1):
        if 'v0.1' in args.in_path:
            part = torch.load(os.path.join(args.in_path, f"pytorch_model-{i:05d}-of-{ckpt_num:05d}.bin"),
                              map_location="cpu")
        else:
            from safetensors import safe_open
            with safe_open(os.path.join(args.in_path, f"model-{i:05d}-of-{ckpt_num:05d}.safetensors"),
                           framework="pt", device="cpu") as f:  # keep tensors on CPU; no GPU needed
                part = {}
                for k in f.keys():
                    part[k] = f.get_tensor(k)
        model_hf.update(part)
    # Remap Hugging Face parameter names onto the BMTrain module layout.
    out = OrderedDict()
    out["input_embedding.weight"] = model_hf['model.embed_tokens.weight'].contiguous()
    out["encoder.output_layernorm.weight"] = model_hf['model.norm.weight'].contiguous()
    out['output_projection.weight'] = model_hf['lm_head.weight'].contiguous()
    for lnum in range(layernum):
        hf_pfx = f"model.layers.{lnum}"
        bmt_pfx = f"encoder.layers.{lnum}"
        out[f"{bmt_pfx}.self_att.layernorm_before_attention.weight"] = model_hf[f"{hf_pfx}.input_layernorm.weight"].contiguous()
        out[f"{bmt_pfx}.self_att.self_attention.project_q.weight"] = model_hf[f"{hf_pfx}.self_attn.q_proj.weight"].contiguous()
        out[f"{bmt_pfx}.self_att.self_attention.project_k.weight"] = model_hf[f"{hf_pfx}.self_attn.k_proj.weight"].contiguous()
        out[f"{bmt_pfx}.self_att.self_attention.project_v.weight"] = model_hf[f"{hf_pfx}.self_attn.v_proj.weight"].contiguous()
        out[f"{bmt_pfx}.self_att.self_attention.attention_out.weight"] = model_hf[f"{hf_pfx}.self_attn.o_proj.weight"].contiguous()
        out[f"{bmt_pfx}.ffn.layernorm_before_ffn.weight"] = model_hf[f"{hf_pfx}.post_attention_layernorm.weight"].contiguous()
        out[f"{bmt_pfx}.ffn.ffn.w_in.w_0.weight"] = model_hf[f"{hf_pfx}.mlp.gate_proj.weight"].contiguous()
        out[f"{bmt_pfx}.ffn.ffn.w_in.w_1.weight"] = model_hf[f"{hf_pfx}.mlp.up_proj.weight"].contiguous()
        out[f"{bmt_pfx}.ffn.ffn.w_out.weight"] = model_hf[f"{hf_pfx}.mlp.down_proj.weight"].contiguous()
    # Cast everything to fp16 before saving.
    for key in out:
        out[key] = out[key].half()
    # args.out_path already exists (created above), so just save the merged checkpoint.
    torch.save(out, os.path.join(args.out_path, "pytorch_model.pt"))
    # Copy the tokenizer files alongside the converted weights.
    for file_name in ["tokenizer_config.json", "special_tokens_map.json", "tokenizer.model", "tokenizer.json"]:
        if os.path.exists(os.path.join(args.in_path, file_name)):
            shutil.copy(os.path.join(args.in_path, file_name), os.path.join(args.out_path, file_name))
    print("BMT weights created successfully")


def main():
    args = initialize()
    convert_weights(args)


if __name__ == "__main__":
    main()