I've been experimenting with supporting weights from dreambooth in this model:
diff --git a/predict.py b/predict.py
index 5630646..2d23a87 100644
--- a/predict.py
+++ b/predict.py
@@ -10,6 +10,7 @@ from diffusers import (
     StableDiffusionPipeline,
 )
 
+USE_WEIGHTS = os.path.exists("weights")
 MODEL_ID = "stabilityai/stable-diffusion-2-1"
 MODEL_CACHE = "diffusers-cache"
 
@@ -18,11 +19,18 @@ class Predictor(BasePredictor):
     def setup(self):
         """Load the model into memory to make running multiple predictions efficient"""
         print("Loading pipeline...")
-        self.pipe = StableDiffusionPipeline.from_pretrained(
-            MODEL_ID,
-            cache_dir=MODEL_CACHE,
-            local_files_only=True,
-        ).to("cuda")
+        if USE_WEIGHTS:
+            self.pipe = StableDiffusionPipeline.from_pretrained(
+                "weights",
+                safety_checker=None,
+                torch_dtype=torch.float16,
+            ).to("cuda")
+        else:
+            self.pipe = StableDiffusionPipeline.from_pretrained(
+                MODEL_ID,
+                cache_dir=MODEL_CACHE,
+                local_files_only=True,
+            ).to("cuda")
 
     @torch.inference_mode()
     def predict(
The only other major difference between this and dreambooth-template is that it has a hardcoded scheduler:
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)
The default scheduler seems to work, although I don't know whether the "magic numbers" in dreambooth-template's DDIMScheduler are there to maximize the quality of the dreambooth generations.
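To make those numbers less opaque, here is a minimal sketch of what I understand beta_schedule="scaled_linear" to compute: betas spaced evenly in square-root space between beta_start and beta_end, then squared. The helper name scaled_linear_betas is made up for illustration, and the 1000-step count is an assumption based on the scheduler's usual default. It is a plain-Python approximation, not the diffusers implementation. As far as I can tell, these values match the noise schedule in the Stable Diffusion scheduler config rather than anything dreambooth-specific, but I have not confirmed that.

```python
def scaled_linear_betas(beta_start=0.00085, beta_end=0.012, num_train_timesteps=1000):
    """Return betas spaced linearly in sqrt-space, then squared.

    Hypothetical helper for illustration; num_train_timesteps=1000 is assumed.
    """
    lo, hi = beta_start ** 0.5, beta_end ** 0.5
    step = (hi - lo) / (num_train_timesteps - 1)
    return [(lo + i * step) ** 2 for i in range(num_train_timesteps)]


betas = scaled_linear_betas()
# First and last entries recover beta_start and beta_end (up to float rounding).
print(len(betas), betas[0], betas[-1])
```

The schedule starts at beta_start and ends at beta_end, so the two numbers set the noise level at the first and last training timesteps.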
With the above patch, all you have to do is unzip the weights generated by this API, https://replicate.com/replicate/dreambooth, into cog-stable-diffusion and run cog build.
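The unzip step could be scripted roughly as below. This is a sketch under assumptions: the helper name extract_weights and the archive path are hypothetical, and it assumes the archive's files belong directly in a directory named weights (the path the patch checks). If the archive already contains a top-level folder, the destination would need adjusting. The model_index.json check reflects the file a diffusers pipeline directory normally contains.

```python
import os
import zipfile


def extract_weights(zip_path, dest="weights"):
    """Unpack a weights archive into dest and report whether it looks loadable.

    Hypothetical helper; assumes the archive holds the pipeline files at its root.
    """
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest)
    # A diffusers pipeline directory normally has model_index.json at its root.
    return os.path.exists(os.path.join(dest, "model_index.json"))
```

Calling extract_weights("output.zip") from the cog-stable-diffusion checkout (with "output.zip" standing in for wherever the downloaded archive lives) should leave a weights directory that makes USE_WEIGHTS true the next time predict.py is imported.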
