initial commit
This commit is contained in:
commit
a82bbc593e
129 changed files with 33981 additions and 0 deletions
71
models/backbones/bert/builder.py
Normal file
71
models/backbones/bert/builder.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel
|
||||
|
||||
|
||||
def build_bert(model_config, pretrain, checkpoint, expert_type, modality_type='text'):
|
||||
"""build text encoder.
|
||||
|
||||
Args:
|
||||
model_config (dict): model config.
|
||||
pretrain (bool): Whether to do pretrain or finetuning.
|
||||
checkpoint (bool): whether to do gradient_checkpointing.
|
||||
|
||||
Returns: TODO
|
||||
|
||||
"""
|
||||
bert_size = model_config['expert_size']
|
||||
bert_config = BertConfig.from_json_file(model_config[f'bert_config_{bert_size}'])
|
||||
# bert_config.encoder_width = model_config.vision_encoder.d_model
|
||||
bert_config.gradient_checkpointing = checkpoint
|
||||
bert_config.num_hidden_layers = model_config['num_layers_{}_expert'.format(expert_type)]
|
||||
if expert_type=='modality':
|
||||
if modality_type == 'vis':
|
||||
bert_config.cross_attention_freq = 2
|
||||
else:
|
||||
bert_config.cross_attention_freq = -1
|
||||
else:
|
||||
bert_config.cross_attention_freq = 1
|
||||
|
||||
if pretrain:
|
||||
text_encoder, loading_info = BertForMaskedLM.from_pretrained(
|
||||
f'bert-{bert_size}-uncased',
|
||||
config=bert_config,
|
||||
output_loading_info=True,
|
||||
)
|
||||
else:
|
||||
text_encoder, loading_info = BertModel.from_pretrained(
|
||||
f'bert-{bert_size}-uncased',
|
||||
config=bert_config,
|
||||
add_pooling_layer=True,
|
||||
output_loading_info=True,
|
||||
)
|
||||
|
||||
return text_encoder
|
||||
|
||||
|
||||
def build_bert_decoder(model_config, checkpoint):
|
||||
"""build text decoder the same as the multimodal encoder.
|
||||
|
||||
Args:
|
||||
model_config (dict): model config.
|
||||
pretrain (bool): Whether to do pretrain or finetuning.
|
||||
checkpoint (bool): whether to do gradient_checkpointing.
|
||||
|
||||
Returns: TODO
|
||||
|
||||
"""
|
||||
bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
|
||||
bert_config.encoder_width = model_config.vision_encoder.d_model
|
||||
bert_config.gradient_checkpointing = checkpoint
|
||||
|
||||
bert_config.fusion_layer = 0
|
||||
bert_config.num_hidden_layers = (
|
||||
bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer
|
||||
)
|
||||
|
||||
text_decoder, loading_info = BertLMHeadModel.from_pretrained(
|
||||
model_config.text_encoder.pretrained,
|
||||
config=bert_config,
|
||||
output_loading_info=True,
|
||||
)
|
||||
|
||||
return text_decoder
|
Loading…
Add table
Add a link
Reference in a new issue