Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/api/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ Zero-shot Learning
tl.slide_caption


Virtual staining
Generative Modeling
~~~~~~~~~~~~~~~~~~~

.. currentmodule:: lazyslide
Expand All @@ -78,3 +78,4 @@ Virtual staining
:nosignatures:

tl.virtual_stain
tl.image_generation
3 changes: 3 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
"pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
"anndata": ("https://anndata.readthedocs.io/en/latest/", None),
"PIL": ("https://pillow.readthedocs.io/en/stable/", None),
}


Expand Down Expand Up @@ -244,13 +245,15 @@ def generate_models_rst(app, config):
set(),
),
"style_transfer": ("Style transfer models", "style_transfer", set()),
"image_generation": ("Image generation models", "image_generation", set()),
"base": (
"Base model class",
"base",
[
mb.ModelBase,
mb.ImageModel,
mb.ImageTextModel,
mb.ImageGenerationModel,
mb.SegmentationModel,
mb.SlideEncoderModel,
mb.TilePredictionModel,
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ model = [
"musk",
"fairscale>=0.4.13",
"sentencepiece>=0.2.1",
"diffusers>=0.36.0",
"accelerate>=1.12.0",
]


Expand Down
14 changes: 8 additions & 6 deletions src/lazyslide/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from . import (
image_generation,
multimodal,
segmentation,
style_transfer,
Expand All @@ -25,6 +26,7 @@
"style_transfer",
"tile_prediction",
"vision",
"image_generation",
"MODEL_REGISTRY",
"register",
"ImageModel",
Expand Down Expand Up @@ -57,7 +59,7 @@ def list_models(task: ModelTask | str = None):
"""
if task is None:
return list(MODEL_REGISTRY.keys())
if task is not None:
else:
try:
task = ModelTask(task)
except ValueError:
Expand All @@ -66,10 +68,10 @@ def list_models(task: ModelTask | str = None):
f"Available tasks are: {', '.join([t.value for t in ModelTask])}."
)
models = []
for name, model in MODEL_REGISTRY.items():
model_task = model.task
if isinstance(model_task, ModelTask):
model_task = [model_task]
if task in model_task:
for name, model_cls in MODEL_REGISTRY.items():
model_tasks = getattr(model_cls, "task", [])
if isinstance(model_tasks, ModelTask):
model_tasks = [model_tasks]
if task in model_tasks:
models.append(name)
return models
14 changes: 1 addition & 13 deletions src/lazyslide/models/_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import textwrap
import warnings
from collections.abc import MutableMapping
from enum import Enum
from typing import TYPE_CHECKING, Dict, Iterator

import pandas as pd
Expand All @@ -12,7 +11,7 @@
from ._repr import model_doc, model_registry_repr_html

if TYPE_CHECKING:
from .base import ModelBase
from .base import ModelBase, ModelTask


class ModelRegistry(MutableMapping):
Expand Down Expand Up @@ -111,17 +110,6 @@ def _repr_html_(self) -> str:
MODEL_REGISTRY = ModelRegistry()


class ModelTask(Enum):
vision = "vision"
segmentation = "segmentation"
multimodal = "multimodal"
slide_encoder = "slide_encoder"
tile_prediction = "tile_prediction"
feature_prediction = "feature_prediction"
style_transfer = "style_transfer"
cv_feature = "cv_feature"


def register(
key: str | list[str],
task: ModelTask | list[ModelTask] = None,
Expand Down
9 changes: 9 additions & 0 deletions src/lazyslide/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ModelTask(Enum):
feature_prediction = "feature_prediction"
style_transfer = "style_transfer"
cv_feature = "cv_feature"
image_generation = "image_generation"


class ModelBase(ABC):
Expand Down Expand Up @@ -198,3 +199,11 @@ def predict(self, image):
@abstractmethod
def get_channel_names(self):
raise NotImplementedError


class ImageGenerationModel(ModelBase):
    """Base class for models that generate images.

    Subclasses override :meth:`generate` for unconditional generation and
    :meth:`generate_conditionally` for generation driven by a conditioning
    input. Both default implementations raise ``NotImplementedError``.
    """

    def generate(self, *args, **kwargs):
        """Generate images without conditioning.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        # Name the concrete class so the failure points at the model that
        # lacks the implementation rather than at this base class.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement `generate`."
        )

    def generate_conditionally(self, *args, **kwargs):
        """Generate images guided by a conditioning input.

        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement `generate_conditionally`."
        )
3 changes: 3 additions & 0 deletions src/lazyslide/models/image_generation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .cytosyn import CytoSyn

__all__ = ["CytoSyn"]
59 changes: 59 additions & 0 deletions src/lazyslide/models/image_generation/cytosyn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from importlib import import_module
from importlib.util import find_spec

import torch

from .._model_registry import register
from .._utils import hf_access
from ..base import ImageGenerationModel, ModelTask


@register(
    key="cytosyn",
    task=ModelTask.image_generation,
    is_gated=True,
    license="CC BY-NC-ND 4.0",
    description="A REPA-E Histopathology Image Generation Model",
    commercial=False,
    # NOTE(review): this URL points at the GigaTIME repository, not a CytoSyn
    # one; it looks like a copy-paste leftover. Confirm the intended link.
    github_url="https://github.com/prov-gigatime/GigaTIME",
    paper_url="https://www.owkin.com/blogs-case-studies/"
    "cytosyn-a-state-of-the-art-diffusion-model-for-histopathology-image-generation",
    param_size="766M",
)
class CytoSyn(ImageGenerationModel):
    """CytoSyn histopathology image generator wrapped as a diffusers pipeline.

    The pipeline is loaded from the ``Owkin-Bioptimus/CytoSyn`` Hugging Face
    repository using the custom pipeline code shipped with that repository.

    Parameters
    ----------
    model_path : optional
        Currently unused; the weights are always loaded from the Hugging Face
        repository. NOTE(review): either honour this argument or document why
        it is ignored.
    token : str, optional
        Hugging Face access token forwarded to ``from_pretrained``. The
        repository is registered as gated, so a token may be required.

    Raises
    ------
    ModuleNotFoundError
        If the ``diffusers`` package is not installed.
    """

    def __init__(self, model_path=None, token=None):
        # diffusers is an optional dependency, so it is probed and imported
        # lazily instead of at module import time.
        if find_spec("diffusers") is None:
            raise ModuleNotFoundError(
                "Please install diffusers to use CytoSyn: `pip install diffusers`"
            )

        DiffusionPipeline = import_module("diffusers.pipelines").DiffusionPipeline
        repo_id = "Owkin-Bioptimus/CytoSyn"
        with hf_access(repo_id):
            self.model = DiffusionPipeline.from_pretrained(
                repo_id,
                # The pipeline class itself lives in the model repository.
                custom_pipeline=repo_id,
                trust_remote_code=True,
                torch_dtype=torch.float32,
                # Previously accepted by __init__ but silently dropped.
                token=token,
            )

    def generate(self, *args, **kwargs):
        """Generate images unconditionally.

        Positional arguments are accepted for interface compatibility and
        ignored. Keyword arguments override the defaults below and are passed
        to the underlying pipeline call.

        Returns
        -------
        The ``"images"`` entry of the pipeline output.
        """
        opts = dict(
            num_images_per_prompt=1,
            num_inference_steps=250,
            guidance_scale=1.0,  # No guidance for unconditional
        )
        opts.update(kwargs)
        return self.model(**opts)["images"]

    def generate_conditionally(self, h0_mini_embeds, **kwargs):
        """Generate images conditioned on H0-mini embeddings.

        Parameters
        ----------
        h0_mini_embeds
            Conditioning embeddings, passed to the pipeline under the
            ``h0_mini_embeds`` keyword.
        **kwargs
            Override the default sampling options below.

        Returns
        -------
        The ``"images"`` entry of the pipeline output.
        """
        opts = dict(
            h0_mini_embeds=h0_mini_embeds,
            num_images_per_prompt=1,
            num_inference_steps=250,
            guidance_scale=2.5,
            guidance_low=0.0,
            guidance_high=0.75,
        )
        opts.update(kwargs)
        return self.model(**opts)["images"]
2 changes: 2 additions & 0 deletions src/lazyslide/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._domain import spatial_domain, tile_shaper
from ._features import feature_aggregation, feature_extraction
from ._image_generation import image_generation
from ._signatures import RNALinker
from ._spatial_features import spatial_features
from ._text_annotate import text_embedding, text_image_similarity
Expand All @@ -13,6 +14,7 @@
"tile_shaper",
"feature_extraction",
"feature_aggregation",
"image_generation",
"RNALinker",
"spatial_features",
"text_embedding",
Expand Down
107 changes: 107 additions & 0 deletions src/lazyslide/tools/_image_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from contextlib import nullcontext
from typing import List

import torch
from PIL import Image
from wsidata import WSIData

from lazyslide import _api
from lazyslide.models import MODEL_REGISTRY
from lazyslide.models.base import ImageGenerationModel


def image_generation(
    wsi: WSIData = None,
    model: str | ImageGenerationModel = "cytosyn",
    prompt_tiles: slice = None,
    tile_key: str = "tiles",
    device: str = None,
    amp: bool = None,
    autocast_dtype: torch.dtype = None,
    num_images_per_tiles: int = 2,
    seed: int = 0,
    **kwargs,
) -> List[Image.Image]:
    """
    Generate tile images, either unconditionally or conditioned on existing tiles.

    Currently only the cytosyn model is supported; conditional generation
    relies on H0-mini features.

    Parameters
    ----------
    wsi : :class:`WSIData <wsidata.WSIData>`
        The WSIData object to work on. Required for conditional generation,
        not needed for unconditional generation.
    model : str, default: "cytosyn"
        The image generation model.
    prompt_tiles : slice, default: None
        The tiles to generate images for, please use index to select tiles.
        If None, unconditional generation is performed.
    tile_key : str, default: "tiles"
        Which tile table to use.
    device : str, optional
        The device to use for inference. If not provided, the device will be automatically selected.
    amp : bool, optional
        Whether to use automatic mixed precision. If not provided, the
        configured default is used.
    autocast_dtype : torch.dtype, optional
        The dtype for automatic mixed precision. If not provided, the
        configured default is used.
    num_images_per_tiles : int, default: 2
        The number of images to generate for each tile if conditional generation is used.
        Otherwise, it's the total number of images to generate if unconditional generation is used.
    seed : int, default: 0
        The random seed to ensure reproducible image generation (May not work for all models).
    kwargs : dict, optional
        Please refer to the documentation of the specific model for additional parameters.

    Returns
    -------
    list of :class:`PIL.Image.Image`
        The function returns a list of generated images in PIL format.

    Raises
    ------
    NotImplementedError
        If ``model`` is passed as a model instance instead of a registry key.
    ValueError
        If ``prompt_tiles`` is given without ``wsi``.
    KeyError
        If H0-mini features are missing for conditional generation.

    Examples
    --------

    >>> import lazyslide as zs
    >>> # Unconditional generation
    >>> imgs = zs.tl.image_generation()
    >>> # Conditional generation
    >>> wsi = zs.datasets.sample()
    >>> zs.tl.feature_extraction(wsi, "h0-mini")
    >>> imgs = zs.tl.image_generation(wsi, prompt_tiles=slice(0, 2))  # Generate images for the first two tiles

    """
    import warnings

    device = _api.default_value("device", device)
    amp = _api.default_value("amp", amp)
    autocast_dtype = _api.default_value("autocast_dtype", autocast_dtype)

    if isinstance(model, ImageGenerationModel):
        raise NotImplementedError("Currently only supports cytosyn model.")

    # Validate the conditional inputs before the model is instantiated, so a
    # missing slide or missing features fails fast instead of after loading.
    feature_key = None
    if prompt_tiles is not None:
        if wsi is None:
            raise ValueError(
                "`wsi` must be provided when `prompt_tiles` is set "
                "(conditional generation)."
            )
        try:
            feature_key = wsi._check_feature_key("h0-mini", tile_key)
        except KeyError as err:
            raise KeyError(
                "H0-mini features are needed for image generation with cytosyn model."
            ) from err

    generation_model: ImageGenerationModel = MODEL_REGISTRY[model]()
    try:
        generation_model.to(device)
    except Exception as err:
        # Moving the model stays best-effort, but the failure is surfaced
        # instead of being silently discarded (and KeyboardInterrupt /
        # SystemExit are no longer swallowed).
        warnings.warn(
            f"Could not move the model to device {device!r}: {err}",
            RuntimeWarning,
            stacklevel=2,
        )

    if amp:
        # torch.autocast expects a bare device type ("cuda"), not an indexed
        # device string ("cuda:0") or a torch.device object.
        device_type = torch.device(device).type
        amp_ctx = torch.autocast(device_type, autocast_dtype)
    else:
        amp_ctx = nullcontext()

    opts = dict(
        num_images_per_prompt=num_images_per_tiles,
        seed=seed,
    )
    opts.update(kwargs)

    with amp_ctx, torch.inference_mode():
        # Unconditional generation
        if prompt_tiles is None:
            return generation_model.generate(**opts)

        # Conditional generation: only the first 768 feature columns are used
        # (presumably the CLS-token part of the H0-mini embedding — confirm).
        cls_tokens = wsi[feature_key][prompt_tiles].X[:, :768]
        cls_tokens = torch.tensor(cls_tokens, dtype=torch.float32)
        return generation_model.generate_conditionally(cls_tokens, **opts)
30 changes: 30 additions & 0 deletions tests/models/test_models_image_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest
import torch
from huggingface_hub.errors import GatedRepoError

from lazyslide.models import MODEL_REGISTRY, list_models
from lazyslide.models.base import ImageGenerationModel

# List of image generation models to test
IMAGE_GENERATION_MODELS = list_models(task="image_generation")


@pytest.mark.large_runner
@pytest.mark.parametrize("model_name", IMAGE_GENERATION_MODELS)
def test_image_generation_model(model_name):
    """Smoke-test that a registered image generation model loads and generates."""
    # Initialize the model; gated repositories that cannot be accessed are
    # skipped. pytest.skip raises, so no explicit return is needed after it.
    try:
        model = MODEL_REGISTRY[model_name]()
    except GatedRepoError:
        pytest.skip(f"{model_name} is not available.")

    # Test 1: Model initialization
    assert isinstance(model, ImageGenerationModel)

    # Test 2: Unconditional generation runs and yields a result
    with torch.inference_mode():
        images = model.generate()
    assert images is not None

    # Explicitly delete the model to free memory
    del model
Loading
Loading