import logging

import torch

from models.utils import (interpolate_pos_relative_bias_beit,
                          load_temp_embed_with_mismatch)

logger = logging.getLogger(__name__)


def interpolate_pos_embed_beit(state_dict, new_model):
    """Interpolate the positional embeddings.

    The spatial positional embeddings are relative and the temporal
    positional embeddings are absolute; additional temporal positions
    are padded with zeros.

    Args:
        state_dict (dict): The state_dict.
        new_model (nn.Module): The created model.

    Returns:
        dict: The state_dict with updated positional embeddings.
    """
    state_dict = interpolate_pos_relative_bias_beit(
        state_dict_old=state_dict,
        state_dict_new=new_model.state_dict(),
        patch_shape_new=new_model.beit.embeddings.patch_embeddings.patch_shape,
    )
    # absolute temporal position embeddings
    temporal_pe_key = "beit.embeddings.temporal_position_embeddings"
    if temporal_pe_key in state_dict:
        logger.info(f"interpolate temporal positional embeddings: {temporal_pe_key}")
        state_dict[temporal_pe_key] = load_temp_embed_with_mismatch(
            temp_embed_old=state_dict[temporal_pe_key],
            temp_embed_new=new_model.state_dict()[temporal_pe_key],
        )
    return state_dict


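# A minimal sketch of the temporal-embedding padding described in the docstring
# above (illustrative only; the actual behavior lives in
# models.utils.load_temp_embed_with_mismatch):
#
#     padded = torch.zeros(1, num_frames_new, hidden_dim)
#     padded[:, :num_frames_old] = temp_embed_old  # extra frames stay zero

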
def extract_beit_from_vindlu(vindlu_state_dict):
    """Extract the BEiT vision-encoder weights from a VindLU checkpoint.

    Keeps parameters prefixed with 'vision_encoder.' (excluding temporal
    'temp_model' parameters) and strips the prefix so the keys match a
    standalone BEiT model.
    """
    beit_state_dict = {}
    beit_param_names = [
        k for k in vindlu_state_dict
        if k.startswith('vision_encoder.') and 'temp_model' not in k
    ]
    for param_name in beit_param_names:
        new_name = param_name.replace('vision_encoder.', '')
        beit_state_dict[new_name] = vindlu_state_dict[param_name]

    return beit_state_dict


def build_beit(model_config, image_res, checkpoint=False):
    """Build BEiT with the given configuration.

    Args:
        model_config (dict): The configs for BEiT, including the VindLU
            checkpoint path and the nested BEiT config.
        image_res (int): The image resolution.
        checkpoint (bool): Whether to enable gradient checkpointing.

    Returns:
        nn.Module: The BEiT vision encoder.
    """
    from .st_beit import BeitConfig as config_cls
    from .st_beit import BeitModel as model_cls

    vindlu_state_dict = torch.load(model_config['vindlu_path'])['model']
    state_dict = extract_beit_from_vindlu(vindlu_state_dict)
    model_config = model_config['beit_config_json']
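    # from here on, `model_config` refers to the nested BEiT config dict
    # (it provides the 'pretrained' key used below)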

    logger.info(
        f"Loading vit pre-trained weights from huggingface {model_config['pretrained']}."
    )
    # BEiT uses average pooled tokens instead of [CLS] used by other models
    aux_kwargs = {"add_pooling_layer": True}
    # tmp_model = model_cls.from_pretrained(model_config['beit_pretrained'], **aux_kwargs)
    # tmp_model = model_cls.from_pretrained(model_config['pretrained'], **aux_kwargs)
    # state_dict = tmp_model.state_dict()
    # del tmp_model

    logger.info(f"Init new model with new image size {image_res}, and load weights.")

    # other_cfg = model_config.temporal_modeling
    other_cfg = {}

    vit_config = config_cls.from_pretrained(
        model_config['pretrained'], image_size=image_res, **other_cfg
    )
    # vit_config.update(model_config)

    model = model_cls(config=vit_config, **aux_kwargs)
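    # note: `model` is freshly initialized here; the pre-trained weights extracted
    # from the VindLU checkpoint above are loaded via load_state_dict below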

    if checkpoint:
        model.gradient_checkpointing_enable()
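        # trades extra backward-pass compute for lower activation memory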

    # interpolate relative pos bias
    state_dict = interpolate_pos_relative_bias_beit(
        state_dict_old=state_dict,
        state_dict_new=model.state_dict(),
        patch_shape_new=model.embeddings.patch_embeddings.patch_shape,
    )

    # del prompt_bias_table
    for k in list(state_dict.keys()):
        if "prompt_bias_table" in k:
            del state_dict[k]

    msg = model.load_state_dict(state_dict, strict=False)
    logger.info(msg)
    return model
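

# Minimal usage sketch (illustrative only): the checkpoint path and the HuggingFace
# model id below are assumptions, not values shipped with this file.
#
#     model_config = {
#         "vindlu_path": "/path/to/vindlu_checkpoint.pth",  # assumed path
#         "beit_config_json": {
#             "pretrained": "microsoft/beit-base-patch16-224-pt22k-ft22k",  # assumed HF id
#         },
#     }
#     beit = build_beit(model_config, image_res=224, checkpoint=True)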