initial commit

This commit is contained in:
Andreas Bulling 2025-06-24 08:38:09 +02:00
commit a82bbc593e
129 changed files with 33981 additions and 0 deletions

View file

@ -0,0 +1,71 @@
from .xbert import BertConfig, BertForMaskedLM, BertLMHeadModel, BertModel
def build_bert(model_config, pretrain, checkpoint, expert_type, modality_type='text'):
"""build text encoder.
Args:
model_config (dict): model config.
pretrain (bool): Whether to do pretrain or finetuning.
checkpoint (bool): whether to do gradient_checkpointing.
Returns: TODO
"""
bert_size = model_config['expert_size']
bert_config = BertConfig.from_json_file(model_config[f'bert_config_{bert_size}'])
# bert_config.encoder_width = model_config.vision_encoder.d_model
bert_config.gradient_checkpointing = checkpoint
bert_config.num_hidden_layers = model_config['num_layers_{}_expert'.format(expert_type)]
if expert_type=='modality':
if modality_type == 'vis':
bert_config.cross_attention_freq = 2
else:
bert_config.cross_attention_freq = -1
else:
bert_config.cross_attention_freq = 1
if pretrain:
text_encoder, loading_info = BertForMaskedLM.from_pretrained(
f'bert-{bert_size}-uncased',
config=bert_config,
output_loading_info=True,
)
else:
text_encoder, loading_info = BertModel.from_pretrained(
f'bert-{bert_size}-uncased',
config=bert_config,
add_pooling_layer=True,
output_loading_info=True,
)
return text_encoder
def build_bert_decoder(model_config, checkpoint):
"""build text decoder the same as the multimodal encoder.
Args:
model_config (dict): model config.
pretrain (bool): Whether to do pretrain or finetuning.
checkpoint (bool): whether to do gradient_checkpointing.
Returns: TODO
"""
bert_config = BertConfig.from_json_file(model_config.text_encoder.config)
bert_config.encoder_width = model_config.vision_encoder.d_model
bert_config.gradient_checkpointing = checkpoint
bert_config.fusion_layer = 0
bert_config.num_hidden_layers = (
bert_config.num_hidden_layers - model_config.text_encoder.fusion_layer
)
text_decoder, loading_info = BertLMHeadModel.from_pretrained(
model_config.text_encoder.pretrained,
config=bert_config,
output_loading_info=True,
)
return text_decoder