diff --git a/torchchat/distributed/dist_run.py b/torchchat/distributed/dist_run.py
index 389ae41c1..195c4a1de 100644
--- a/torchchat/distributed/dist_run.py
+++ b/torchchat/distributed/dist_run.py
@@ -282,16 +282,6 @@ def _cleanup():
dist.destroy_process_group()
-prompts = [
- "What is Snow?",
- # "Can you explain what is the purpose of back propagation in neural networks?",
- "Who is Santa Claus?",
- "Where does Santa live?",
- "Who is Abraham Lincoln?",
- # "How are models trained?",
-]
-
-
def main(
model_name,
builder_args,
diff --git a/torchchat/distributed/generate.py b/torchchat/distributed/generate.py
index 51c472e4a..6f3c6930e 100644
--- a/torchchat/distributed/generate.py
+++ b/torchchat/distributed/generate.py
@@ -21,6 +21,7 @@
from torchchat.cli.builder import BuilderArgs, TokenizerArgs
from torchchat.distributed.dist_run import NAME_TO_DISTRIBUTION_AND_DTYPE
from torchchat.distributed.logging_utils import SingletonLogger
+from torchchat.utils.generator import Generator, GeneratorArgs
logger = SingletonLogger.get_logger()
@@ -194,19 +195,19 @@ def step(self) -> List[Output]:
return outputs
-class DistributedGenerator(object):
+class DistributedGenerator(Generator):
def __init__(
self,
# TODO: switch this to torchchat method
model_name: str,
builder_args: BuilderArgs,
tokenizer_args: TokenizerArgs,
- # TODO: move GeneratorArgs into a different module
- generator_args,
+ generator_args: GeneratorArgs,
profile: Optional[Path],
quantize: bool,
draft_quantize: bool,
):
+ super().__init__(builder_args, tokenizer_args, generator_args)
self.model_name = model_name
self.builder_args = builder_args
self.generate_args = generator_args
diff --git a/torchchat/generate.py b/torchchat/generate.py
index dd423b58a..cfc257403 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -3,39 +3,24 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-import argparse
-import base64
+
import itertools
import logging
-import os
import textwrap
import time
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from io import BytesIO
-from os import PathLike
from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Optional, Tuple
+from typing_extensions import override
import torch
import torch._dynamo.config
import torch._inductor.config
-from PIL import Image
-
-# torchtune model definition dependencies
-from torchtune.data import Message, padded_collate_tiled_images_and_mask
-
from torchtune.generation import sample as tune_sample
-from torchtune.models.llama3 import llama3_tokenizer
-
-from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
-from torchtune.training import set_default_dtype
from torchchat.cli.builder import (
_initialize_model,
- _initialize_tokenizer,
BuilderArgs,
TokenizerArgs,
)
@@ -43,178 +28,10 @@
from torchchat.model import Model, ModelType
from torchchat.utils.build_utils import device_sync, set_precision
from torchchat.utils.device_info import get_device_info
+from torchchat.utils.generator import Generator, GeneratorArgs, E_INST, B_INST, E_SYS, B_SYS
-class _ChatFormatter(ABC):
- def __init__(self, tokenizer):
- self.tokenizer = tokenizer
-
- @abstractmethod
- def encode_dialog_prompt(self, dialog) -> List[int]:
- raise NotImplementedError()
-
-
-class Llama3ChatFormatter(_ChatFormatter):
- """Format a chat prompt using special tokens to demarcate roles and messages.
-
- Refer to the LLaMA3 documentation for more details https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
-
- """
-
- def encode_header(self, role) -> List[int]:
- tokens = []
- tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
- tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
- tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
- tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
- return tokens
-
- def encode_message(self, message) -> List[int]:
- tokens = self.encode_header(message["role"])
- if isinstance(message["content"], str):
- tokens.extend(
- self.tokenizer.encode(message["content"], bos=False, eos=False)
- )
- elif isinstance(message["content"], list):
- for content in message["content"]:
- if content["type"] == "text":
- tokens.extend(
- self.tokenizer.encode(content["text"], bos=False, eos=False)
- )
-
- tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
- return tokens
-
- def encode_dialog_prompt(self, dialog) -> List[int]:
- tokens = []
- tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
- for message in dialog:
- tokens.extend(self.encode_message(message))
- # Add the start of an assistant message for the model to complete.
- tokens.extend(self.encode_header("assistant")) # Pass role directly as a string
- return tokens
-
-
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
-
-
-class Llama2ChatFormatter(_ChatFormatter):
- def encode_dialog_prompt(self, dialog) -> List[int]:
- tokens = self.tokenizer.encode(f"{B_INST} ")
- first_message = True # Bool to handle placing the B_INST token. Behavior is weird - the system prompt should have the B_INST, but not the first user message. All following user messages *should* have it. Also, if there is no system prompt, then the user message should have it.
- for message in dialog:
- if isinstance(message["content"], list):
- content = message["content"][0]["text"]
- else:
- content = message["content"]
- content = content.strip()
- if message["role"] == "system":
- encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
- first_message = False
- elif message["role"] == "user":
- encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
- f"{B_INST if first_message else ''} {content} {E_INST} "
- )
- first_message = True
- elif message["role"] == "assistant":
- encoded = self.tokenizer.encode(f"{content}\n\n") + [
- self.tokenizer.eos_id()
- ]
- tokens += encoded
- return tokens
-
-
-@dataclass
-class GeneratorArgs:
- prompt: Optional[str] = (
- None # When passed into the Generator, this will be used as the system prompt
- )
- encoded_prompt: Optional[torch.Tensor] = None
- image_prompts: Optional[Sequence[Union[str, PathLike, bytes]]] = (
- None # string or Path to an image file or the raw base64 bytes of an image
- )
- chat_mode: bool = False
- gui_mode: bool = False
- num_samples: int = 1
- max_new_tokens: int = 200
- top_k: int = 200
- temperature: float = 0.0 # deterministic argmax if 0.0
- compile: bool = False
- compile_prefill: bool = False
- speculate_k: int = 5
- sequential_prefill: bool = False
- max_autotune: bool = False
- # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
- is_torchtune_model: bool = False
-
- def __post_init__(self):
- if self.compile_prefill and self.sequential_prefill:
- raise RuntimeError("prefill compilation requires parallel prefill")
-
- def validate_build(
- self, builder_args: BuilderArgs, model_description: str = "model"
- ):
- reason = ""
- model_type = ""
- if not self.sequential_prefill:
- reason = "parallel prefill"
- if self.compile_prefill:
- reason = "model compilation for prefill"
- if self.compile:
- reason = "model compilation"
- if builder_args.aoti_package_path:
- model_type = "PT2"
- if builder_args.dso_path:
- model_type = "DSO"
- if builder_args.pte_path:
- model_type = "PTE"
- if model_type and reason:
- raise RuntimeError(
- f"cannot perform {reason} because a {model_type} {model_description} is used"
- )
-
- @classmethod
- def from_args(cls, args):
- dso_path = getattr(args, "dso_path", None)
- pte_path = getattr(args, "pte_path", None)
- aoti_package_path = getattr(args, "aoti_package_path", None)
- sequential_prefill = (
- args.sequential_prefill or bool(aoti_package_path) or bool(pte_path) or bool(dso_path)
- )
-
- # Validate that all image prompts exist before expensive model load
- if image_prompts := getattr(args, "image_prompts", None):
- non_existent_image_prompts = [
- image_prompt
- for image_prompt in image_prompts
- if (not os.path.exists(image_prompt))
- ]
- if non_existent_image_prompts:
- raise RuntimeError(
- f"Image prompt {non_existent_image_prompts} does not exist"
- )
-
- return cls(
- prompt=getattr(args, "prompt", ""),
- encoded_prompt=None,
- image_prompts=image_prompts,
- chat_mode=args.chat,
- gui_mode=args.gui,
- num_samples=getattr(args, "num_samples", 1),
- max_new_tokens=args.max_new_tokens,
- top_k=args.top_k,
- temperature=args.temperature,
- compile=args.compile,
- compile_prefill=args.compile_prefill,
- speculate_k=args.speculate_k,
- sequential_prefill=sequential_prefill,
- max_autotune=args.max_autotune,
- is_torchtune_model=args.model and args.model.endswith("tune"),
- )
-
-
-class Generator:
+class LocalGenerator(Generator):
"""
Generates text samples based on a pre-trained Transformer model and tokenizer.
Args:
@@ -237,6 +54,7 @@ def __init__(
quantize: bool,
draft_quantize: bool,
):
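+        # The Generator base class sets up the tokenizer, chat formatter, and dtype.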
+ super().__init__(builder_args, tokenizer_args, generator_args)
torch._inductor.config.coordinate_descent_tuning = (
builder_args.device != "cpu"
)
@@ -245,12 +63,10 @@ def __init__(
self.builder_args = builder_args
self.speculative_builder_args = speculative_builder_args
- self.tokenizer_args = tokenizer_args
self.profile = profile
self.quantize = quantize
self.draft_quantize = draft_quantize
self.is_torchtune_model = generator_args.is_torchtune_model
- self.dtype = builder_args.precision
self.rank: Optional[int] = None
@@ -273,21 +89,6 @@ def __init__(
))
# fmt: on
self.system_prompt = generator_args.prompt
- self.tokenizer = _initialize_tokenizer(self.tokenizer_args)
-
- # Right now the assumption is only llama3 uses tiktokenizer and it
- # must use tiktokenizer.
- # Piggy backing off of this flag then for now to identify llama3
- # without prompting user.
- self.is_llama3_model = self.tokenizer_args.is_tiktoken
- if self.is_llama3_model:
- self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
- if generator_args.chat_mode:
- logging.debug(
- "Llama3 model detected in chat mode. Using updated sentence schemas"
- )
- else:
- self.chat_formatter = Llama2ChatFormatter(self.tokenizer)
self.builder_args.setup_caches = False
self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer)
@@ -606,7 +407,7 @@ def generate(
torch.manual_seed(seed)
is_speculative = draft_model is not None
- device, dtype = prompt.device, prompt.dtype
+ device = prompt.device
if len(prompt.shape) > 1:
prompt = prompt.squeeze(0)
@@ -721,13 +522,6 @@ def generate(
}
yield None, generate_stats
- def encode_tokens(self, string, bos=True, device="cpu"):
- tokens = self.tokenizer.encode(string)
- if bos:
- tokens = [self.tokenizer.bos_id()] + tokens
- logging.debug(f"Size after encode_tokens: {len(tokens)}")
- return torch.tensor(tokens, dtype=torch.int, device=device)
-
def _callback(self, x, *, buffer, done_generating):
# TODO: Refactor this callback to only include basic functionality & remove print statements
period_id = self.tokenizer.encode(".")[0]
@@ -747,159 +541,6 @@ def _callback(self, x, *, buffer, done_generating):
buffer.clear()
# print(, end='', flush=True)
- def _gen_model_input(
- self,
- prompt: Union[str | List[Any]],
- image_prompts: Optional[List[str | Image.Image]] = None,
- max_new_tokens: Optional[int] = None,
- max_seq_len: Optional[int] = 2048,
- ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
- """
- Convert prompt and image prompts into consumable model input args.
-
- When prompt is a list, the anticipated format is OpenAI API Inspired:
- [ ..., {"role": message["role"], "content": message["content"]}, ...]
-
- Args:
- prompt (Union[str, List[Any]]): Prompt or list of dialog.
- image_prompts (Optional[List[str | Image.Image]]): List of image prompts. Used only with Llama 3.2 11B.
- max_new_tokens (Optional[int]): Maximum number of new tokens to generate. Used only with Llama 3.2 11B.
-
- Returns:
- Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
- """
-
- # Text-Only model
- if self.model.config.model_type != ModelType.Flamingo:
- # Single String prompt
- if isinstance(prompt, str):
- encoded = self.encode_tokens(
- prompt, bos=True, device=self.builder_args.device
- )
- # List of dialog
- else:
- tokens = self.chat_formatter.encode_dialog_prompt(prompt)
- encoded = torch.tensor(
- tokens, dtype=torch.int, device=self.builder_args.device
- )
-
- logging.debug(encoded)
- return encoded, None
-
- # Llama 3.2 11B
- assert (
- image_prompts is None or len(image_prompts) == 1
- ), "At most one image is supported at the moment"
-
- if image_prompts and isinstance(image_prompts[0], str):
- images = [Image.open(image_prompts[0])]
- else:
- images = None
-
- assert (
- max_new_tokens is not None
- ), "max_new_tokens must be specified for Flamingo models"
-
- # Wrap string prompts into a list
- if isinstance(prompt, str):
- prompt = [{"role": "user", "content": prompt}]
-
- image_found = False
- messages = []
- for message in prompt:
- if isinstance(message["content"], str):
- if not image_found and image_prompts:
- messages.append(
- Message(
- role=message["role"],
- content=[
- {"type": "image", "content": images[0]},
- {"type": "text", "content": message["content"]},
- ],
- )
- )
- image_found = True
- else:
- messages.append(Message(**message))
-
- elif isinstance(message["content"], list):
- images = None
- for content_dict in message["content"]:
- if content_dict["type"] == "text":
- prompt_arg = content_dict["text"]
- elif content_dict["type"] == "image_url":
- assert (
- images is None
- ), "At most one image is supported at the moment"
-
- base64_decoded = base64.b64decode(
- content_dict["image_url"].split(";base64,")[1]
- )
- images = [Image.open(BytesIO(base64_decoded))]
- image_found = True
-
- is_multimodal = images is not None
- content = [{"type": "text", "content": prompt_arg}]
-
- if is_multimodal:
- content = [{"type": "image", "content": images[0]}] + content
-
- messages.append(
- Message(
- role=message["role"],
- content=content,
- )
- )
-
- messages.append(
- Message(
- role="assistant",
- content="",
- )
- )
-
- transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
-
- device = torch.device(device=self.builder_args.device)
-
- with device, set_default_dtype(self.dtype):
- data = transform({"messages": messages}, inference=True)
-
- if image_found:
- batch = padded_collate_tiled_images_and_mask(
- [data], pad_direction="left", pad_max_images=1
- )
- encoded = batch.pop("tokens").to(device).view(-1)
- seq_len = encoded.size(0)
- batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
- batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
- self.dtype
- )
-
- else:
- encoded = torch.tensor(data["tokens"], device=device).view(-1)
- seq_len = encoded.size(0)
- batch = {}
-
- total_response_length = seq_len + max_new_tokens
- batch["causal_mask"] = torch.nn.functional.pad(
- torch.tril(
- torch.ones(
- size=(total_response_length, total_response_length),
- dtype=torch.bool,
- )
- ),
- (
- 0,
- max_seq_len - total_response_length,
- 0,
- max_seq_len - total_response_length,
- ),
- value=0,
- )
-
- logging.debug(encoded)
- return encoded, batch
def chat(
self,
@@ -1087,8 +728,6 @@ def callback(x, *, done_generating=False):
)
if self.profile:
- from torch._inductor import config as inductor_config
-
torch._inductor.config.profiler_mark_wrapper_call = True
torch._inductor.config.cpp.enable_kernel_profile = True
if (i != generator_args.num_samples - 1 or not self.profile) or (
@@ -1205,15 +844,10 @@ def callback(x, *, done_generating=False):
)
if torch.cuda.is_available():
print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
-
-
-def _launch_distributed_inference(
- builder_args: BuilderArgs,
-):
- from torch.distributed import launcher
- from torch.distributed.elastic.utils.distributed import get_free_port
-
- print("Launching distributed inference within generator")
+
+ @override
+ def is_text_only(self) -> bool:
+ return self.model.config.model_type != ModelType.Flamingo
def main(args):
@@ -1222,7 +856,7 @@ def main(args):
tokenizer_args = TokenizerArgs.from_args(args)
generator_args = GeneratorArgs.from_args(args)
if not builder_args.distributed:
- gen = Generator(
+ gen = LocalGenerator(
builder_args,
speculative_builder_args,
tokenizer_args,
diff --git a/torchchat/usages/openai_api.py b/torchchat/usages/openai_api.py
index 99fd82fe8..a6e480ab4 100644
--- a/torchchat/usages/openai_api.py
+++ b/torchchat/usages/openai_api.py
@@ -24,7 +24,7 @@
from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
from torchchat.cli.download import is_model_downloaded, load_model_configs
-from torchchat.generate import Generator, GeneratorArgs
+from torchchat.generate import LocalGenerator, GeneratorArgs
from torchchat.model import FlamingoModel
from torchchat.utils.build_utils import device_sync
@@ -267,7 +267,7 @@ class CompletionResponseChunk:
usage: Optional[UsageStats] = None
-class OpenAiApiGenerator(Generator):
+class OpenAiApiGenerator(LocalGenerator):
"""A wrapper over the Generator class to interface with the OpenAI API.
Implements endpoints for completion requests, both chunked and non-chunked using the dataclasses
diff --git a/torchchat/utils/generator.py b/torchchat/utils/generator.py
new file mode 100644
index 000000000..f7896b8a5
--- /dev/null
+++ b/torchchat/utils/generator.py
@@ -0,0 +1,399 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import base64
+import logging
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from io import BytesIO
+from os import PathLike
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+from PIL import Image
+
+from torchchat.cli.builder import (
+ _initialize_tokenizer,
+ BuilderArgs,
+ TokenizerArgs,
+)
+
+# torchtune model definition dependencies
+from torchtune.data import Message, padded_collate_tiled_images_and_mask
+from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform
+from torchtune.training import set_default_dtype
+
+
+class _ChatFormatter(ABC):
+ def __init__(self, tokenizer):
+ self.tokenizer = tokenizer
+
+ @abstractmethod
+ def encode_dialog_prompt(self, dialog) -> List[int]:
+ raise NotImplementedError()
+
+
+class Llama3ChatFormatter(_ChatFormatter):
+ """Format a chat prompt using special tokens to demarcate roles and messages.
+
+ Refer to the LLaMA3 documentation for more details https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3
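+
+    For a single user message, the encoded layout is roughly:
+
+        <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
+        {content}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n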
+
+ """
+
+ def encode_header(self, role) -> List[int]:
+ tokens = []
+ tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
+ tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
+ tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
+ tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
+ return tokens
+
+ def encode_message(self, message) -> List[int]:
+ tokens = self.encode_header(message["role"])
+ if isinstance(message["content"], str):
+ tokens.extend(
+ self.tokenizer.encode(message["content"], bos=False, eos=False)
+ )
+ elif isinstance(message["content"], list):
+ for content in message["content"]:
+ if content["type"] == "text":
+ tokens.extend(
+ self.tokenizer.encode(content["text"], bos=False, eos=False)
+ )
+
+ tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
+ return tokens
+
+ def encode_dialog_prompt(self, dialog) -> List[int]:
+ tokens = []
+ tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
+ for message in dialog:
+ tokens.extend(self.encode_message(message))
+ # Add the start of an assistant message for the model to complete.
+ tokens.extend(self.encode_header("assistant")) # Pass role directly as a string
+ return tokens
+
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
+
+
+class Llama2ChatFormatter(_ChatFormatter):
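+    """Format a chat dialog using the Llama 2 ``[INST]``/``<<SYS>>`` convention."""
+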
+ def encode_dialog_prompt(self, dialog) -> List[int]:
+ tokens = self.tokenizer.encode(f"{B_INST} ")
+        # Bool to handle placing the B_INST token. Behavior is weird: the system
+        # prompt should have the B_INST, but not the first user message. All
+        # following user messages *should* have it. Also, if there is no system
+        # prompt, then the user message should have it.
+        first_message = True
+ for message in dialog:
+ if isinstance(message["content"], list):
+ content = message["content"][0]["text"]
+ else:
+ content = message["content"]
+ content = content.strip()
+ if message["role"] == "system":
+ encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
+ first_message = False
+ elif message["role"] == "user":
+ encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
+ f"{B_INST if first_message else ''} {content} {E_INST} "
+ )
+ first_message = True
+ elif message["role"] == "assistant":
+ encoded = self.tokenizer.encode(f"{content}\n\n") + [
+ self.tokenizer.eos_id()
+ ]
+ tokens += encoded
+ return tokens
+
+
+@dataclass
+class GeneratorArgs:
+ prompt: Optional[str] = (
+ None # When passed into the Generator, this will be used as the system prompt
+ )
+ encoded_prompt: Optional[torch.Tensor] = None
+ image_prompts: Optional[Sequence[Union[str, PathLike, bytes]]] = (
+ None # string or Path to an image file or the raw base64 bytes of an image
+ )
+ chat_mode: bool = False
+ gui_mode: bool = False
+ num_samples: int = 1
+ max_new_tokens: int = 200
+ top_k: int = 200
+ temperature: float = 0.0 # deterministic argmax if 0.0
+ compile: bool = False
+ compile_prefill: bool = False
+ speculate_k: int = 5
+ sequential_prefill: bool = False
+ max_autotune: bool = False
+ # (Misnomer) See Issue: https://github.com/pytorch/torchchat/issues/1273
+ is_torchtune_model: bool = False
+
+ def __post_init__(self):
+ if self.compile_prefill and self.sequential_prefill:
+ raise RuntimeError("prefill compilation requires parallel prefill")
+
+ def validate_build(
+ self, builder_args: BuilderArgs, model_description: str = "model"
+ ):
+ reason = ""
+ model_type = ""
+ if not self.sequential_prefill:
+ reason = "parallel prefill"
+ if self.compile_prefill:
+ reason = "model compilation for prefill"
+ if self.compile:
+ reason = "model compilation"
+ if builder_args.aoti_package_path:
+ model_type = "PT2"
+ if builder_args.dso_path:
+ model_type = "DSO"
+ if builder_args.pte_path:
+ model_type = "PTE"
+ if model_type and reason:
+ raise RuntimeError(
+ f"cannot perform {reason} because a {model_type} {model_description} is used"
+ )
+
+ @classmethod
+ def from_args(cls, args):
+ dso_path = getattr(args, "dso_path", None)
+ pte_path = getattr(args, "pte_path", None)
+ aoti_package_path = getattr(args, "aoti_package_path", None)
+ sequential_prefill = (
+ args.sequential_prefill or bool(aoti_package_path) or bool(pte_path) or bool(dso_path)
+ )
+
+ # Validate that all image prompts exist before expensive model load
+ if image_prompts := getattr(args, "image_prompts", None):
+ non_existent_image_prompts = [
+ image_prompt
+ for image_prompt in image_prompts
+ if (not os.path.exists(image_prompt))
+ ]
+ if non_existent_image_prompts:
+ raise RuntimeError(
+ f"Image prompt {non_existent_image_prompts} does not exist"
+ )
+
+ return cls(
+ prompt=getattr(args, "prompt", ""),
+ encoded_prompt=None,
+ image_prompts=image_prompts,
+ chat_mode=args.chat,
+ gui_mode=args.gui,
+ num_samples=getattr(args, "num_samples", 1),
+ max_new_tokens=args.max_new_tokens,
+ top_k=args.top_k,
+ temperature=args.temperature,
+ compile=args.compile,
+ compile_prefill=args.compile_prefill,
+ speculate_k=args.speculate_k,
+ sequential_prefill=sequential_prefill,
+ max_autotune=args.max_autotune,
+ is_torchtune_model=args.model and args.model.endswith("tune"),
+ )
+
+
+class Generator(object):
+ """
+ Base class for generators that can be used to generate text samples based on a pre-trained Transformer model and tokenizer.
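+
+    Subclasses are expected to implement ``is_text_only`` and to set up their own
+    model; this base class initializes the tokenizer, chat formatter, and common
+    generation arguments.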
+ """
+
+ def __init__(
+ self,
+ builder_args: BuilderArgs,
+ tokenizer_args: TokenizerArgs,
+ generator_args: GeneratorArgs,
+ ):
+ self.builder_args = builder_args
+ self.tokenizer_args = tokenizer_args
+ self.generate_args = generator_args
+
+ self.dtype = builder_args.precision
+
+ self.tokenizer = _initialize_tokenizer(self.tokenizer_args)
+
+        # Right now the assumption is that only llama3 uses tiktoken, and that
+        # it must use tiktoken. Piggyback off of this flag for now to identify
+        # llama3 models without prompting the user.
+ self.is_llama3_model = self.tokenizer_args.is_tiktoken
+ if self.is_llama3_model:
+ self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
+ if generator_args.chat_mode:
+ logging.debug(
+ "Llama3 model detected in chat mode. Using updated sentence schemas"
+ )
+ else:
+ self.chat_formatter = Llama2ChatFormatter(self.tokenizer)
+
+ @abstractmethod
+ def is_text_only(self) -> bool:
+ """
+ Returns True if the model is text-only, False otherwise.
+ """
+ raise NotImplementedError()
+
+ def _gen_model_input(
+ self,
+        prompt: Union[str, List[Any]],
+ image_prompts: Optional[List[str | Image.Image]] = None,
+ max_new_tokens: Optional[int] = None,
+ max_seq_len: Optional[int] = 2048,
+ ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
+ """
+ Convert prompt and image prompts into consumable model input args.
+
+ When prompt is a list, the anticipated format is OpenAI API Inspired:
+ [ ..., {"role": message["role"], "content": message["content"]}, ...]
+
+ Args:
+ prompt (Union[str, List[Any]]): Prompt or list of dialog.
+ image_prompts (Optional[List[str | Image.Image]]): List of image prompts. Used only with Llama 3.2 11B.
+ max_new_tokens (Optional[int]): Maximum number of new tokens to generate. Used only with Llama 3.2 11B.
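+            max_seq_len (Optional[int]): Sequence length used to pad the causal mask for multimodal models. Defaults to 2048.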
+
+ Returns:
+ Tuple[torch.Tensor, Optional[Dict[str, Any]]]: Encoded prompt and batch config for multimodal models.
+ """
+
+ # Text-Only model
+ if self.is_text_only():
+ # Single String prompt
+ if isinstance(prompt, str):
+ encoded = self.encode_tokens(
+ prompt, bos=True, device=self.builder_args.device
+ )
+ # List of dialog
+ else:
+ tokens = self.chat_formatter.encode_dialog_prompt(prompt)
+ encoded = torch.tensor(
+ tokens, dtype=torch.int, device=self.builder_args.device
+ )
+
+ logging.debug(encoded)
+ return encoded, None
+
+ # Llama 3.2 11B
+ assert (
+ image_prompts is None or len(image_prompts) == 1
+ ), "At most one image is supported at the moment"
+
+ if image_prompts and isinstance(image_prompts[0], str):
+ images = [Image.open(image_prompts[0])]
+ else:
+ images = None
+
+ assert (
+ max_new_tokens is not None
+ ), "max_new_tokens must be specified for Flamingo models"
+
+ # Wrap string prompts into a list
+ if isinstance(prompt, str):
+ prompt = [{"role": "user", "content": prompt}]
+
+ image_found = False
+ messages = []
+ for message in prompt:
+ if isinstance(message["content"], str):
+ if not image_found and image_prompts:
+ messages.append(
+ Message(
+ role=message["role"],
+ content=[
+ {"type": "image", "content": images[0]},
+ {"type": "text", "content": message["content"]},
+ ],
+ )
+ )
+ image_found = True
+ else:
+ messages.append(Message(**message))
+
+ elif isinstance(message["content"], list):
+ images = None
+ for content_dict in message["content"]:
+ if content_dict["type"] == "text":
+ prompt_arg = content_dict["text"]
+ elif content_dict["type"] == "image_url":
+ assert (
+ images is None
+ ), "At most one image is supported at the moment"
+
+ base64_decoded = base64.b64decode(
+ content_dict["image_url"].split(";base64,")[1]
+ )
+ images = [Image.open(BytesIO(base64_decoded))]
+ image_found = True
+
+ is_multimodal = images is not None
+ content = [{"type": "text", "content": prompt_arg}]
+
+ if is_multimodal:
+ content = [{"type": "image", "content": images[0]}] + content
+
+ messages.append(
+ Message(
+ role=message["role"],
+ content=content,
+ )
+ )
+
+ messages.append(
+ Message(
+ role="assistant",
+ content="",
+ )
+ )
+
+ transform = llama3_2_vision_transform(str(self.tokenizer_args.tokenizer_path))
+
+ device = torch.device(device=self.builder_args.device)
+
+ with device, set_default_dtype(self.dtype):
+ data = transform({"messages": messages}, inference=True)
+
+ if image_found:
+ batch = padded_collate_tiled_images_and_mask(
+ [data], pad_direction="left", pad_max_images=1
+ )
+ encoded = batch.pop("tokens").to(device).view(-1)
+ seq_len = encoded.size(0)
+ batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
+ batch["encoder_input"]["images"] = batch["encoder_input"]["images"].to(
+ self.dtype
+ )
+
+ else:
+ encoded = torch.tensor(data["tokens"], device=device).view(-1)
+ seq_len = encoded.size(0)
+ batch = {}
+
+ total_response_length = seq_len + max_new_tokens
+ batch["causal_mask"] = torch.nn.functional.pad(
+ torch.tril(
+ torch.ones(
+ size=(total_response_length, total_response_length),
+ dtype=torch.bool,
+ )
+ ),
+ (
+ 0,
+ max_seq_len - total_response_length,
+ 0,
+ max_seq_len - total_response_length,
+ ),
+ value=0,
+ )
+
+ logging.debug(encoded)
+ return encoded, batch
+
+ def encode_tokens(self, string, bos=True, device="cpu"):
+ tokens = self.tokenizer.encode(string)
+ if bos:
+ tokens = [self.tokenizer.bos_id()] + tokens
+ logging.debug(f"Size after encode_tokens: {len(tokens)}")
+ return torch.tensor(tokens, dtype=torch.int, device=device)
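+
+
+# Illustrative sketch (not part of torchchat): a concrete generator only needs to
+# implement ``is_text_only`` and provide its own model setup; prompt encoding and
+# chat formatting are inherited from ``Generator``.
+#
+#     class MyTextGenerator(Generator):
+#         def is_text_only(self) -> bool:
+#             return True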