71 lines
2.3 KiB
Python
71 lines
2.3 KiB
Python
from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel
|
|
|
|
|
|
def build_bert(model_config, pretrain, checkpoint, expert_type, modality_type='text'):
    """Build a BERT-based expert (text encoder or modality expert).

    A ``BertConfig`` is read from the JSON file stored under
    ``model_config[f'bert_config_{expert_size}']``. It is adjusted for the
    requested expert (gradient checkpointing, layer count, cross-attention
    frequency). Then the ``bert-{expert_size}-uncased`` pretrained weights are
    loaded into it.

    Args:
        model_config (dict): model config. Must provide the keys
            ``'expert_size'``, ``f'bert_config_{expert_size}'`` and
            ``f'num_layers_{expert_type}_expert'``.
        pretrain (bool): Whether to do pretrain or finetuning. When True a
            ``BertForMaskedLM`` is built; otherwise a ``BertModel`` with
            ``add_pooling_layer=True``.
        checkpoint (bool): whether to do gradient_checkpointing.
        expert_type (str): expert name. It selects the layer-count key, and
            ``'modality'`` is special-cased for the cross-attention frequency.
            Any other value uses a frequency of 1.
        modality_type (str): only consulted when ``expert_type == 'modality'``.
            ``'vis'`` sets the cross-attention frequency to 2; any other value
            sets it to -1.

    Returns:
        The model returned by ``from_pretrained``: a ``BertForMaskedLM`` when
        ``pretrain`` is True, else a ``BertModel``. The loading info is
        requested from ``from_pretrained`` but discarded.
    """
    bert_size = model_config['expert_size']
    bert_config = BertConfig.from_json_file(model_config[f'bert_config_{bert_size}'])
    bert_config.gradient_checkpointing = checkpoint
    bert_config.num_hidden_layers = model_config[f'num_layers_{expert_type}_expert']

    if expert_type == 'modality':
        # NOTE(review): -1 presumably disables cross-attention inside xbert
        # for non-visual modalities; confirm against BertModel's layer setup.
        bert_config.cross_attention_freq = 2 if modality_type == 'vis' else -1
    else:
        bert_config.cross_attention_freq = 1

    pretrained_name = f'bert-{bert_size}-uncased'
    if pretrain:
        text_encoder, _loading_info = BertForMaskedLM.from_pretrained(
            pretrained_name,
            config=bert_config,
            output_loading_info=True,
        )
    else:
        text_encoder, _loading_info = BertModel.from_pretrained(
            pretrained_name,
            config=bert_config,
            add_pooling_layer=True,
            output_loading_info=True,
        )

    return text_encoder
|
|
|
|
|
|
def build_bert_decoder(model_config, checkpoint):
    """Build a text decoder with the same layout as the multimodal encoder.

    The decoder's ``num_hidden_layers`` is the BERT config's layer count minus
    ``model_config.text_encoder.fusion_layer``, and its ``fusion_layer`` is
    set to 0. Weights are then loaded from
    ``model_config.text_encoder.pretrained`` into a ``BertLMHeadModel``.

    Args:
        model_config: model config accessed by attribute (e.g. an
            attribute-style dict). Must provide ``text_encoder.config`` (path
            to a BERT JSON config), ``text_encoder.fusion_layer``,
            ``text_encoder.pretrained`` and ``vision_encoder.d_model``.
        checkpoint (bool): whether to do gradient_checkpointing.

    Returns:
        The ``BertLMHeadModel`` returned by ``from_pretrained``. The loading
        info is requested but discarded.

    Raises:
        ValueError: if ``text_encoder.fusion_layer`` is larger than the BERT
            config's ``num_hidden_layers``, which would give the decoder a
            negative number of layers.
    """
    bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
    # NOTE(review): presumably the width of the encoder states the decoder
    # cross-attends to; confirm against xbert.
    bert_config.encoder_width = model_config.vision_encoder.d_model
    bert_config.gradient_checkpointing = checkpoint

    # NOTE(review): fusion_layer = 0 presumably makes every decoder layer a
    # fusion (cross-attention) layer; confirm against xbert.
    bert_config.fusion_layer = 0
    encoder_fusion_layer = model_config.text_encoder.fusion_layer
    num_decoder_layers = bert_config.num_hidden_layers - encoder_fusion_layer
    if num_decoder_layers < 0:
        raise ValueError(
            f'text_encoder.fusion_layer ({encoder_fusion_layer}) exceeds '
            f'num_hidden_layers ({bert_config.num_hidden_layers}) of the BERT config'
        )
    bert_config.num_hidden_layers = num_decoder_layers

    text_decoder, _loading_info = BertLMHeadModel.from_pretrained(
        model_config.text_encoder.pretrained,
        config=bert_config,
        output_loading_info=True,
    )

    return text_decoder
|