Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions torchchat/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ def encode_header(self, role) -> List[int]:

def encode_message(self, message) -> List[int]:
tokens = self.encode_header(message["role"])
if type(message["content"]) is str:
if isinstance(message["content"], str):
tokens.extend(
self.tokenizer.encode(message["content"], bos=False, eos=False)
)
elif type(message["content"]) is list:
elif isinstance(message["content"], list):
for content in message["content"]:
if content["type"] == "text":
tokens.extend(
Expand Down Expand Up @@ -190,7 +190,7 @@ def from_args(cls, args):
for image_prompt in image_prompts
if (not os.path.exists(image_prompt))
]
if len(non_existent_image_prompts):
if non_existent_image_prompts:
raise RuntimeError(
f"Image prompt {non_existent_image_prompts} does not exist"
)
Expand Down Expand Up @@ -238,7 +238,7 @@ def __init__(
draft_quantize: bool,
):
torch._inductor.config.coordinate_descent_tuning = (
False if builder_args.device == "cpu" else True
builder_args.device != "cpu"
)
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
Expand Down Expand Up @@ -1002,11 +1002,8 @@ def chat(
max_seq_length,
)

max_seq_length = (
max_seq_length + self.speculative_builder_args.speculate_k + 1
if self.draft_model is not None
else max_seq_length
)
if self.draft_model is not None:
max_seq_length += self.speculative_builder_args.speculate_k + 1

aggregate_metrics = {
"tokens_per_sec": [],
Expand Down
2 changes: 1 addition & 1 deletion torchchat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
self.encoder = encoder
self.decoder = decoder

# esclate the embedding layer outside decoder llava model need to fuse
# Escalate the embedding layer outside the decoder: the LLaVA model needs to
# fuse the text and image embeddings together before passing them to the decoder.
self.tok_embeddings = getattr(self.decoder, token_embedding_name)

Expand Down
Loading