From a93f4b5bee004985e2a1c78bafc429b1806b0a0c Mon Sep 17 00:00:00 2001
From: chenxwh
Date: Mon, 14 Oct 2024 22:36:21 +0000
Subject: [PATCH] replicate demo

---
 README.md  |   8 +--
 cog.yaml   |  28 +++++++++++
 predict.py | 142 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 175 insertions(+), 3 deletions(-)
 create mode 100644 cog.yaml
 create mode 100644 predict.py

diff --git a/README.md b/README.md
index dfd1c32..2924ac6 100644
--- a/README.md
+++ b/README.md
@@ -7,8 +7,10 @@

-Experience the CogView3-Plus-3B model online on 🤗 Huggingface Space
-

+Try out the CogView3-Plus-3B model online on 🤗 Huggingface Space and
+
+ Replicate
+

📚 Check out the paper

@@ -21,7 +23,7 @@ Experience the CogView3-Plus-3B model online on
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,28 @@
+  python_packages:
+    - accelerate>=1.0.0
+    - sentencepiece>=0.2.0
+    - torch
+    - transformers>=4.45.2
+    - openai>=1.51.2
+  run:
+    - pip install git+https://github.com/huggingface/diffusers.git
+    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
+
+# predict.py defines how predictions are run on your model
+predict: "predict.py:Predictor"
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..302b72f
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,142 @@
+# Prediction interface for Cog ⚙️
+# https://cog.run/python
+
+import os
+import subprocess
+import time
+from cog import BasePredictor, Input, Path
+
+from diffusers import CogView3PlusPipeline
+import torch
+
+
+MODEL_URL = "https://weights.replicate.delivery/default/THUDM/CogView3/model_cache.tar"
+MODEL_CACHE = "model_cache"
+
+
+def download_weights(url, dest):
+    start = time.time()
+    print("downloading url: ", url)
+    print("downloading to: ", dest)
+    subprocess.check_call(["pget", "-x", url, dest], close_fds=False)
+    print("downloading took: ", time.time() - start)
+
+
+class Predictor(BasePredictor):
+    def setup(self) -> None:
+        """Load the model into memory to make running multiple predictions efficient"""
+
+        if not os.path.exists(MODEL_CACHE):
+            download_weights(MODEL_URL, MODEL_CACHE)
+
+        self.pipe = CogView3PlusPipeline.from_pretrained(
+            MODEL_CACHE, torch_dtype=torch.bfloat16  # from THUDM/CogView3-Plus-3B
+        ).to("cuda")
+        self.pipe.enable_model_cpu_offload()
+        self.pipe.vae.enable_slicing()
+        self.pipe.vae.enable_tiling()
+
+    def predict(
+        self,
+        prompt: str = Input(
+            description="Input prompt",
+            default="a photo of an astronaut riding a horse on mars",
+        ),
+        negative_prompt: str = Input(
+            description="Specify things to not see in the output",
+            default="",
+        ),
+        width: int = Input(
+            description="Width of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
+            choices=[
+                512,
+                576,
+                640,
+                704,
+                768,
+                832,
+                896,
+                960,
+                1024,
+                1088,
+                1152,
+                1216,
+                1280,
+                1344,
+                1408,
+                1472,
+                1536,
+                1600,
+                1664,
+                1728,
+                1792,
+                1856,
+                1920,
+                1984,
+                2048,
+            ],
+            default=1024,
+        ),
+        height: int = Input(
+            description="Height of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
+            choices=[
+                512,
+                576,
+                640,
+                704,
+                768,
+                832,
+                896,
+                960,
+                1024,
+                1088,
+                1152,
+                1216,
+                1280,
+                1344,
+                1408,
+                1472,
+                1536,
+                1600,
+                1664,
+                1728,
+                1792,
+                1856,
+                1920,
+                1984,
+                2048,
+            ],
+            default=1024,
+        ),
+        num_inference_steps: int = Input(
+            description="Number of denoising steps", ge=1, le=500, default=50
+        ),
+        guidance_scale: float = Input(
+            description="Scale for classifier-free guidance", ge=1, le=20, default=7
+        ),
+        seed: int = Input(
+            description="Random seed. Leave blank to randomize the seed", default=None
+        ),
+    ) -> Path:
+        """Run a single prediction on the model"""
+        if seed is None:
+            seed = int.from_bytes(os.urandom(2), "big")
+        print(f"Using seed: {seed}")
+
+        image = self.pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            guidance_scale=guidance_scale,
+            num_images_per_prompt=1,
+            num_inference_steps=num_inference_steps,
+            width=width,
+            height=height,
+            generator=torch.Generator().manual_seed(seed),
+        ).images[0]
+
+        torch.cuda.empty_cache()
+
+        out_path = "/tmp/out.png"
+
+        image.save(out_path)
+        return Path(out_path)
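
For reference, a minimal local smoke test of the new Predictor (outside of Cog's HTTP server) could look like the sketch below. It is not part of the patch: it assumes a CUDA GPU is available, that the weights tarball at MODEL_URL is reachable, and it simply calls the class added in predict.py with explicit values for every input. In normal use, Cog drives these methods itself, e.g. `cog predict -i prompt="..."`.

    # smoke_test.py -- hypothetical helper, not included in the patch above
    from predict import Predictor

    predictor = Predictor()
    predictor.setup()  # downloads the weights on first run, then loads the pipeline onto the GPU

    out_path = predictor.predict(
        prompt="a photo of an astronaut riding a horse on mars",
        negative_prompt="",
        width=1024,
        height=1024,
        num_inference_steps=50,
        guidance_scale=7.0,
        seed=42,
    )
    print(out_path)  # /tmp/out.png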