Skip to content

Commit ec001ad

Browse files
authored
Merge pull request #164 from pytti-tools/test - MMC support
2 parents fb6054a + 2d4e970 commit ec001ad

File tree

2 files changed

+118
-7
lines changed

2 files changed

+118
-7
lines changed

src/pytti/Notebook.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
from pytti import Perceptor
2020

21+
from pytti.Perceptor import CLIP_PERCEPTORS
22+
2123

2224
# https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
2325
def is_notebook():
@@ -248,11 +250,33 @@ def load_clip(params):
248250
if params.get(config_name):
249251
CLIP_MODEL_NAMES.append(clip_name)
250252

251-
if last_names != CLIP_MODEL_NAMES or Perceptor.CLIP_PERCEPTORS is None:
252-
if CLIP_MODEL_NAMES == []:
253+
if not params.get("use_mmc"):
254+
if last_names != CLIP_MODEL_NAMES or Perceptor.CLIP_PERCEPTORS is None:
255+
if CLIP_MODEL_NAMES == []:
256+
Perceptor.free_clip()
257+
raise RuntimeError("Please select at least one CLIP model")
253258
Perceptor.free_clip()
254-
raise RuntimeError("Please select at least one CLIP model")
255-
Perceptor.free_clip()
256-
logger.debug("Loading CLIP...")
257-
Perceptor.init_clip(CLIP_MODEL_NAMES)
258-
logger.debug("CLIP loaded.")
259+
logger.debug("Loading CLIP...")
260+
Perceptor.init_clip(CLIP_MODEL_NAMES)
261+
logger.debug("CLIP loaded.")
262+
else:
263+
logger.debug("attempting to use mmc to load perceptors")
264+
import mmc
265+
from mmc.registry import REGISTRY
266+
import mmc.loaders # force trigger model registrations
267+
from mmc.mock.openai import MockOpenaiClip
268+
269+
CLIP_PERCEPTORS = (
270+
[]
271+
) # this will be fine because we'll use it to overwrite the module object
272+
for item in params.mmc_models:
273+
logger.debug(item)
274+
model_loaders = REGISTRY.find(**item)
275+
logger.debug(model_loaders)
276+
for model_loader in model_loaders:
277+
logger.debug(model_loader)
278+
model = model_loader.load()
279+
model = MockOpenaiClip(model)
280+
CLIP_PERCEPTORS.append(model)
281+
logger.debug(CLIP_PERCEPTORS)
282+
Perceptor.CLIP_PERCEPTORS = CLIP_PERCEPTORS # weird that htis works, but fine.

tests/test_mmc_loaders.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import pytest
2+
3+
from hydra import initialize, compose
4+
from loguru import logger
5+
from pytti.workhorse import _main as render_frames
6+
from omegaconf import OmegaConf, open_dict
7+
8+
9+
CONFIG_BASE_PATH = "config"
10+
CONFIG_DEFAULTS = "default.yaml"
11+
12+
13+
cfg_yaml_openai = """# @package _global_
14+
scenes: a photograph of an apple
15+
use_mmc: true
16+
mmc_models:
17+
- architecture: clip
18+
publisher: openai
19+
id: RN50
20+
- architecture: clip
21+
publisher: openai
22+
id: ViT-B/32
23+
"""
24+
25+
26+
def test_mmc_openai_models():
27+
28+
with initialize(config_path=CONFIG_BASE_PATH):
29+
cfg_base = compose(
30+
config_name=CONFIG_DEFAULTS,
31+
overrides=[f"conf=_empty"],
32+
)
33+
cfg_mmc = OmegaConf.create(cfg_yaml_openai)
34+
35+
with open_dict(cfg_base) as cfg:
36+
cfg = OmegaConf.merge(cfg_base, cfg_mmc)
37+
render_frames(cfg)
38+
39+
40+
cfg_yaml_mlf = """# @package _global_
41+
scenes: a photograph of an apple
42+
use_mmc: true
43+
mmc_models:
44+
- architecture: clip
45+
publisher: mlfoundations
46+
id: RN50--yfcc15m
47+
- architecture: clip
48+
publisher: mlfoundations
49+
id: ViT-B-32--laion400m_avg
50+
"""
51+
52+
53+
def test_mmc_mlf_models():
54+
55+
with initialize(config_path=CONFIG_BASE_PATH):
56+
cfg_base = compose(
57+
config_name=CONFIG_DEFAULTS,
58+
overrides=[f"conf=_empty"],
59+
)
60+
cfg_mmc = OmegaConf.create(cfg_yaml_mlf)
61+
62+
with open_dict(cfg_base) as cfg:
63+
cfg = OmegaConf.merge(cfg_base, cfg_mmc)
64+
render_frames(cfg)
65+
66+
67+
cfg_yaml_all_cloob = """# @package _global_
68+
scenes: a photograph of an apple
69+
use_mmc: true
70+
mmc_models:
71+
- architecture: cloob
72+
id: cloob_laion_400m_vit_b_16_32_epochs
73+
"""
74+
75+
76+
def test_mmc_all_cloob_models():
77+
78+
with initialize(config_path=CONFIG_BASE_PATH):
79+
cfg_base = compose(
80+
config_name=CONFIG_DEFAULTS,
81+
overrides=[f"conf=_empty"],
82+
)
83+
cfg_mmc = OmegaConf.create(cfg_yaml_all_cloob)
84+
85+
with open_dict(cfg_base) as cfg:
86+
cfg = OmegaConf.merge(cfg_base, cfg_mmc)
87+
render_frames(cfg)

0 commit comments

Comments
 (0)