
Error in Demo_Video_Audio_MAE.ipynb opened in Colab #16

@snapfinger


The following error occurs when I execute the cell below. The provided installation command had already failed when I tried to install the required Python packages, so I pip-installed them without the versions pinned in requirements.txt. That said, I don't think this installation problem is the root cause of the error here.

%load_ext autoreload
%autoreload 2

from demos import MAE_model
model = MAE_model()

Error

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/usr/local/lib/python3.10/dist-packages/timm/models/registry.py:4: DeprecationWarning: Importing from timm.models.registry is deprecated, please import via timm.models
  warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-20-13cab5655236> in <cell line: 5>()
      3 
      4 from demos import MAE_model
----> 5 model = MAE_model()

2 frames
/content/TVLT/demos.py in MAE_model(model_path)
     25 def MAE_model(model_path=''):
     26     config = MAE_config()
---> 27     model = getattr(tvlt, 'mae_vit_base_patch16_dec512d8b')(
     28         config=config).float().eval()
     29     ckpt_path = load_from_hub(repo_id="TVLT/models", filename="TVLT.ckpt")

/content/TVLT/model/modules/tvlt.py in mae_vit_base_patch16_dec512d8b(**kwargs)
    577 @register_model
    578 def mae_vit_base_patch16_dec512d8b(**kwargs):
--> 579     model = TVLT(
    580         patch_size=16, audio_patch_size=[16, 16], embed_dim=768, depth=12, num_heads=12,
    581         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,

/content/TVLT/model/modules/tvlt.py in __init__(self, img_size, in_chans, patch_size, audio_patch_size, embed_dim, depth, num_heads, decoder_embed_dim, decoder_depth, decoder_num_heads, mlp_ratio, norm_layer, eps, config)
    245             self.matching_score.apply(objectives.init_weights)
    246 
--> 247             if config["loss_names"]["vatr"] > 0:
    248                 import copy
    249                 self.rank_output = copy.deepcopy(self.matching_score)

KeyError: 'vatr'
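The traceback shows that TVLT.__init__ indexes config["loss_names"]["vatr"] directly, so the constructor fails whenever the config built by MAE_config() has no "vatr" entry. Below is a minimal sketch of that failure mode and two possible workarounds, assuming the config behaves like a plain dict. The "mae" key and the uses_loss helper are hypothetical, and treating a missing "vatr" as disabled (0) is an assumption, not a confirmed fix from the TVLT authors.

```python
# Hypothetical stand-in for the dict returned by demos.MAE_config();
# the real set of loss names is not shown in this issue.
config = {"loss_names": {"mae": 1}}


def uses_loss(config, name):
    """Hypothetical helper: True if loss `name` is enabled; a missing key counts as disabled."""
    return config["loss_names"].get(name, 0) > 0


# Direct indexing, as in tvlt.py line 247, raises KeyError when "vatr" is absent.
try:
    config["loss_names"]["vatr"] > 0
except KeyError as err:
    print("KeyError:", err)  # prints: KeyError: 'vatr'

# Workaround 1 (assumption): a defensive lookup treats the missing loss as disabled.
print(uses_loss(config, "vatr"))  # False

# Workaround 2 (assumption): add the missing key before constructing the model.
config["loss_names"].setdefault("vatr", 0)
```

Workaround 2 leaves tvlt.py untouched and only patches the demo config, whereas workaround 1 would require editing the library code at line 247.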
