diff --git a/README.md b/README.md
index 92c68be..e3d046a 100644
--- a/README.md
+++ b/README.md
@@ -12,7 +12,7 @@
 🤖ModelScope Space 🛠️ZhipuAI MaaS(Faster)
-👋 WeChat Community 📚 CogView3 Paper
+👋 WeChat Community 📚 CogView3 Paper 🤖 Replicate
 
 ![showcase.png](resources/showcase.png)
diff --git a/README_ja.md b/README_ja.md
index 659ff34..cb89e1d 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -12,7 +12,7 @@
 🤖ModelScope Space 🛠️ZhipuAI MaaS(Faster)
-👋 WeChat Community 📚 CogView3 Paper
+👋 WeChat Community 📚 CogView3 Paper 🤖 Replicate
 
diff --git a/README_zh.md b/README_zh.md
index 66a7783..dfa49bd 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -14,7 +14,7 @@
 👋 微信社区 📚 CogView3 论文
-
+ 🤖 Replicate
 
 ![showcase.png](resources/showcase.png)
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000..24688df
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,21 @@
+# Configuration for Cog ⚙️
+# Reference: https://cog.run/yaml
+
+build:
+  gpu: true
+  cuda: "12.1"
+  python_version: "3.11"
+  python_packages:
+    - "torch==2.4"
+    - "git+https://github.com/huggingface/diffusers.git@24c062aaa19f5626d03d058daf8afffa2dfd49f7"
+    - "transformers==4.49.0"
+    - "accelerate==1.4.0"
+    - "safetensors==0.5.3"
+    - "pillow==10.1.0"
+    - "numpy<2"
+
+  run:
+    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.9.1/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
+
+# predict.py defines how predictions are run on your model
+predict: "predict.py:Predictor"
diff --git a/predict.py b/predict.py
new file mode 100644
index 0000000..ad9ab4d
--- /dev/null
+++ b/predict.py
@@ -0,0 +1,108 @@
+# Prediction interface for Cog ⚙️
+# https://cog.run/python
+
+import os
+import time
+import torch
+import subprocess
+from diffusers import CogView4Pipeline
+from cog import BasePredictor, Input, Path
+
+MODEL_CACHE = "checkpoints"
+MODEL_URL = "https://weights.replicate.delivery/default/THUDM/CogView4-6B/model.tar"
+
+def download_weights(url, dest):
+    start = time.time()
+    print("downloading url: ", url)
+    print("downloading to: ", dest)
+    subprocess.check_call(["pget", "-xf", url, dest], close_fds=False)
+    print("downloading took: ", time.time() - start)
+
+class Predictor(BasePredictor):
+    def setup(self) -> None:
+        """Load the model into memory to make running multiple predictions efficient"""
+        # Download weights if they don't exist
+        if not os.path.exists(MODEL_CACHE):
+            download_weights(MODEL_URL, MODEL_CACHE)
+
+        # Load CogView4-6B model with bfloat16 precision as recommended
+        self.pipe = CogView4Pipeline.from_pretrained(
+            MODEL_CACHE,
+            torch_dtype=torch.bfloat16
+        )
+
+        # Enable optimizations to reduce GPU memory usage and improve speed
+        self.pipe.enable_model_cpu_offload()
+        self.pipe.vae.enable_slicing()
+        self.pipe.vae.enable_tiling()
+
+
+    def predict(
+        self,
+        prompt: str = Input(
+            description="Text prompt to generate an image from"
+        ),
+        negative_prompt: str = Input(
+            description="Negative prompt to guide image generation away from certain concepts",
+            default=None
+        ),
+        width: int = Input(
+            description="Width of the generated image (must be between 512 and 2048, divisible by 32)",
+            default=1024,
+            ge=512,
+            le=2048
+        ),
+        height: int = Input(
+            description="Height of the generated image (must be between 512 and 2048, divisible by 32)",
+            default=1024,
+            ge=512,
+            le=2048
+        ),
+        num_inference_steps: int = Input(
+            description="Number of denoising steps",
+            default=50,
+            ge=1,
+            le=100
+        ),
+        guidance_scale: float = Input(
+            description="Guidance scale for classifier-free guidance",
+            default=3.5,
+            ge=0.0,
+            le=20.0
+        ),
+        seed: int = Input(
+            description="Random seed for reproducible image generation",
+            default=None
+        )
+    ) -> Path:
+        """Run a single prediction on the model"""
+        # Validate dimensions
+        if width % 32 != 0 or height % 32 != 0:
+            raise ValueError("Width and height must be divisible by 32")
+        if width * height > 2**21:
+            raise ValueError(f"Resolution {width}x{height} exceeds maximum allowed pixels (2^21)")
+
+        # Set seed for reproducibility
+        generator = None
+        if seed is None:
+            seed = int.from_bytes(os.urandom(3), "big")
+        generator = torch.Generator().manual_seed(seed)
+        print("Using seed: ", seed)
+
+        # Generate image(s)
+        images = self.pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            width=width,
+            height=height,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            generator=generator,
+            num_images_per_prompt=1,
+        ).images
+
+        # Save the first generated image
+        output_path = Path("/tmp/output.png")
+        images[0].save(output_path)
+        return output_path
+
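
The predictor above follows the upstream diffusers recipe for CogView4-6B. For reference, a minimal standalone sketch of the same generation path, assuming the public Hugging Face checkpoint `THUDM/CogView4-6B` is used instead of the pre-baked weights tarball, with an illustrative prompt:

```python
import torch
from diffusers import CogView4Pipeline

# Same pipeline and memory optimizations as predict.py, loaded from the Hub
pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

image = pipe(
    prompt="A vibrant cherry red sports car, photorealistic",  # illustrative prompt
    width=1024,
    height=1024,
    num_inference_steps=50,
    guidance_scale=3.5,
    generator=torch.Generator().manual_seed(42),
    num_images_per_prompt=1,
).images[0]
image.save("output.png")
```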
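
Once pushed to Replicate, the `Input` fields defined in `predict.py` become the model's API parameters. A sketch of calling it through the Replicate Python client, assuming a hypothetical model slug (the actual identifier depends on where the model is published):

```python
import replicate

# "thudm/cogview4" is a placeholder slug; substitute the real Replicate
# identifier once the model is pushed.
output = replicate.run(
    "thudm/cogview4",
    input={
        "prompt": "A vibrant cherry red sports car, photorealistic",
        "width": 1024,
        "height": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 3.5,
    },
)
print(output)  # reference to the generated PNG
```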