Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -1 +1,20 @@
# The .dockerignore file excludes files from the container build process.
#
# https://docs.docker.com/engine/reference/builder/#dockerignore-file

# Exclude Git files
.git
.github
.gitignore

# Exclude Python cache files
__pycache__
.mypy_cache
.pytest_cache
.ruff_cache

# Exclude weights
diffusers-cache

# Exclude Python virtual environment
/venv
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
.cog/
__pycache__/
diffusers-cache/
diffusers-cache/
.cog/
16 changes: 9 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

This is an implementation of the [Diffusers Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) as a Cog model. [Cog packages machine learning models as standard containers.](https://github.com/replicate/cog)

First, download the pre-trained weights:

cog run script/download-weights

Then, you can run predictions:

cog predict -i prompt="monkey scuba diving"
Make single prediction:
```bash
cog predict -i prompt="monkey scuba diving"
```

Run HTTP API for making predictions:
```bash
cog run -p 5000
```
3 changes: 3 additions & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,7 @@ build:
- "accelerate==0.15.0"
- "huggingface-hub==0.13.2"

run:
- curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.3.0/pget" && chmod +x /usr/local/bin/pget

predict: "predict.py:Predictor"
27 changes: 21 additions & 6 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,48 @@
import torch
from cog import BasePredictor, Input, Path
from diffusers import (
StableDiffusionPipeline,
PNDMScheduler,
LMSDiscreteScheduler,
DDIMScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)

from weights_downloader import WeightsDownloader

# MODEL_ID refers to a diffusers-compatible model on HuggingFace
# e.g. prompthero/openjourney-v2, wavymulder/Analog-Diffusion, etc
MODEL_ID = "stabilityai/stable-diffusion-2-1"
MODEL_CACHE = "diffusers-cache"

SD_MODEL_CACHE = os.path.join(MODEL_CACHE, "models--stabilityai--stable-diffusion-2-1")
MODEL_ID = "stabilityai/stable-diffusion-2-1"
SD_URL = "https://weights.replicate.delivery/default/stable-diffusion/stable-diffusion-2-1.tar"

SAFETY_CACHE = os.path.join(
MODEL_CACHE, "models--CompVis--stable-diffusion-safety-checker"
)
SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"
SAFETY_URL = "https://weights.replicate.delivery/default/stable-diffusion/stable-diffusion-safety-checker.tar"


class Predictor(BasePredictor):
def setup(self):
"""Load the model into memory to make running multiple predictions efficient"""

print("Loading pipeline...")
WeightsDownloader.download_if_not_exists(SAFETY_URL, SAFETY_CACHE)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably include a revision, same as https://github.com/replicate/cog-deepfloyd-if/pull/2

SAFETY_MODEL_ID,
cache_dir=MODEL_CACHE,
local_files_only=True,
)

WeightsDownloader.download_if_not_exists(SD_URL, SD_MODEL_CACHE)
self.pipe = StableDiffusionPipeline.from_pretrained(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should probably include a revision

MODEL_ID,
safety_checker=safety_checker,
Expand Down
28 changes: 0 additions & 28 deletions script/download-weights

This file was deleted.

18 changes: 18 additions & 0 deletions weights_downloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os
import subprocess
import time


class WeightsDownloader:
@staticmethod
def download_if_not_exists(url, dest):
if not os.path.exists(dest):
WeightsDownloader.download(url, dest)

@staticmethod
def download(url, dest):
start = time.time()
print("downloading url: ", url)
print("downloading to: ", dest)
subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
print("downloading took: ", time.time() - start)