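"""Convert a BMTrain-style checkpoint to a HuggingFace-format checkpoint.

Reads the first ``*.pt`` file found in ``--in_path``, remaps its parameter
names onto the HuggingFace LLaMA/Mistral layout, and writes the converted
model (plus any tokenizer files found next to the checkpoint) to
``--output_path``, using the model config at ``--original_mistral_path``.
"""
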
import os
import shutil
import argparse
from collections import OrderedDict

import torch
from transformers import AutoConfig, AutoModelForCausalLM


def transform_to_hf(bmt_model, model_size):
    model_hf = OrderedDict()

    # Checkpoints saved from a wrapper module prefix every key with "LLM.".
    prefix = "" if "input_embedding.weight" in bmt_model else "LLM."

    model_hf["model.embed_tokens.weight"] = bmt_model[f"{prefix}input_embedding.weight"].contiguous().float()
    model_hf["model.norm.weight"] = bmt_model[f"{prefix}encoder.output_layernorm.weight"].contiguous().float()
    if f"{prefix}output_projection.weight" in bmt_model:
        model_hf["lm_head.weight"] = bmt_model[f"{prefix}output_projection.weight"].contiguous().float()
    else:
        # No separate output projection: the embedding is tied to the LM head.
        model_hf["lm_head.weight"] = bmt_model[f"{prefix}input_embedding.weight"].contiguous().float()

    # Number of transformer layers for each supported model size.
    if model_size == "7b":
        layernum = 32
    elif model_size in ("13b", "13b-2"):
        layernum = 40
    elif model_size == "65b":
        layernum = 80
    else:
        raise ValueError(f"unsupported model size: {model_size}")

    for lnum in range(layernum):
        hf_pfx = f"model.layers.{lnum}"
        bmt_pfx = f"{prefix}encoder.layers.{lnum}"

        # Attention block: pre-attention norm and Q/K/V/O projections.
        model_hf[f"{hf_pfx}.input_layernorm.weight"] = bmt_model[f"{bmt_pfx}.self_att.layernorm_before_attention.weight"].contiguous().float()
        model_hf[f"{hf_pfx}.self_attn.q_proj.weight"] = bmt_model[f"{bmt_pfx}.self_att.self_attention.project_q.weight"].contiguous().float()
        model_hf[f"{hf_pfx}.self_attn.k_proj.weight"] = bmt_model[f"{bmt_pfx}.self_att.self_attention.project_k.weight"].contiguous().float()
        model_hf[f"{hf_pfx}.self_attn.v_proj.weight"] = bmt_model[f"{bmt_pfx}.self_att.self_attention.project_v.weight"].contiguous().float()
        model_hf[f"{hf_pfx}.self_attn.o_proj.weight"] = bmt_model[f"{bmt_pfx}.self_att.self_attention.attention_out.weight"].contiguous().float()

        # Feed-forward block: pre-FFN norm and gated MLP (gate/up/down).
        model_hf[f"{hf_pfx}.post_attention_layernorm.weight"] = bmt_model[f"{bmt_pfx}.ffn.layernorm_before_ffn.weight"].contiguous().float()
        model_hf[f"{hf_pfx}.mlp.gate_proj.weight"] = bmt_model[f"{bmt_pfx}.ffn.ffn.w_in.w_0.weight"].contiguous().float()
        model_hf[f"{hf_pfx}.mlp.up_proj.weight"] = bmt_model[f"{bmt_pfx}.ffn.ffn.w_in.w_1.weight"].contiguous().float()
        model_hf[f"{hf_pfx}.mlp.down_proj.weight"] = bmt_model[f"{bmt_pfx}.ffn.ffn.w_out.weight"].contiguous().float()

    # Emit the final state dict in bfloat16.
    for key in model_hf:
        model_hf[key] = model_hf[key].bfloat16()
    return model_hf

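
# Optional sanity check, not part of the original conversion flow: build an
# empty model from the reference config and diff its expected parameter keys
# against the converted ones. The helper name `check_state_dict` is
# illustrative; building a throwaway model just for its key names is slow but
# simple.
def check_state_dict(hf_state_dict, original_model_path):
    config = AutoConfig.from_pretrained(original_model_path)
    expected = set(AutoModelForCausalLM.from_config(config).state_dict().keys())
    converted = set(hf_state_dict.keys())
    missing = expected - converted
    unexpected = converted - expected
    if missing or unexpected:
        raise RuntimeError(f"missing keys: {sorted(missing)}; unexpected keys: {sorted(unexpected)}")
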
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in_path", type=str)
    parser.add_argument("--output_path", type=str)
    parser.add_argument("--original_mistral_path", type=str)

    args = parser.parse_args()
    os.makedirs(args.output_path, exist_ok=True)
    print("transforming " + args.in_path)

    # This entry point only handles the 7B layout.
    model_size = "7b"

    # Load the first checkpoint shard found in the input directory; sorting
    # makes the choice deterministic when several .pt files are present, and
    # map_location="cpu" avoids requiring a GPU for the conversion.
    ckpt = sorted(name for name in os.listdir(args.in_path) if name.endswith(".pt"))
    bmt_model = torch.load(os.path.join(args.in_path, ckpt[0]), map_location="cpu")

    hf_state_dict = transform_to_hf(bmt_model, model_size)
    print(f"start saving to {args.output_path}")

    # Instantiate an empty model from the original config and load the
    # converted weights into it.
    model_config = AutoConfig.from_pretrained(args.original_mistral_path)
    model = AutoModelForCausalLM.from_config(model_config)
    model.load_state_dict(hf_state_dict)

    # Cast every parameter to bfloat16 before serialization.
    for param in model.parameters():
        param.data = param.data.to(torch.bfloat16)

    model.save_pretrained(args.output_path, safe_serialization=False)

    # Copy tokenizer files alongside the converted weights when they exist.
    for file_name in ["tokenizer_config.json", "special_tokens_map.json", "tokenizer.model", "tokenizer.json"]:
        if os.path.exists(os.path.join(args.in_path, file_name)):
            shutil.copy(os.path.join(args.in_path, file_name), os.path.join(args.output_path, file_name))

    print("saved huggingface checkpoint")