Skip to content

Commit 2d8dce9

Browse files
committed
apply suggestions from review
1 parent 61831bd commit 2d8dce9

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Callable, Dict, List, Optional, Tuple, Union
1919

2020
import torch
21+
import PIL
2122
from transformers import T5EncoderModel, T5Tokenizer
2223

2324
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -431,6 +432,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
431432

432433
def check_inputs(
433434
self,
435+
image,
434436
prompt,
435437
height,
436438
width,
@@ -441,6 +443,16 @@ def check_inputs(
441443
prompt_embeds=None,
442444
negative_prompt_embeds=None,
443445
):
446+
if (
447+
not isinstance(image, torch.Tensor)
448+
and not isinstance(image, PIL.Image.Image)
449+
and not isinstance(image, list)
450+
):
451+
raise ValueError(
452+
"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
453+
f" {type(image)}"
454+
)
455+
444456
if height % 8 != 0 or width % 8 != 0:
445457
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
446458

@@ -659,6 +671,7 @@ def __call__(
659671

660672
# 1. Check inputs. Raise error if not correct
661673
self.check_inputs(
674+
image,
662675
prompt,
663676
height,
664677
width,

0 commit comments

Comments
 (0)