diff --git a/install/install_requirements.sh b/install/install_requirements.sh index 43110da4d..b698315ff 100755 --- a/install/install_requirements.sh +++ b/install/install_requirements.sh @@ -53,7 +53,7 @@ PYTORCH_NIGHTLY_VERSION=dev20240814 VISION_NIGHTLY_VERSION=dev20240814 # Nightly version for torchtune -TUNE_NIGHTLY_VERSION=dev20240910 +TUNE_NIGHTLY_VERSION=dev20240916 # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same diff --git a/torchchat/generate.py b/torchchat/generate.py index 30490d396..37493761f 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -727,11 +727,13 @@ def chat( if generator_args.image_prompts is not None: print("Image prompts", generator_args.image_prompts) + # Support for just the first image prompt for now + images = [Image.open(generator_args.image_prompts[0])] messages = [ Message( role="user", content=[ - {"type": "image"}, + {"type": "image", "content": images[0]}, {"type": "text", "content": generator_args.prompt}, ], eot=True, @@ -739,10 +741,8 @@ def chat( Message(role="assistant", content=""), ] - images = [Image.open(generator_args.image_prompts[0])] transform = flamingo_transform(str(self.tokenizer_args.tokenizer_path)) - - data = transform({"images": images, "messages": messages}, inference=True) + data = transform({"messages": messages}, inference=True) batch = padded_collate([data], self.builder_args.device) batch.pop("mask") encoded = batch["tokens"]