diff --git a/torchchat/generate.py b/torchchat/generate.py
index be6a2e819..dd423b58a 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -71,11 +71,11 @@ def encode_header(self, role) -> List[int]:
 
     def encode_message(self, message) -> List[int]:
         tokens = self.encode_header(message["role"])
-        if type(message["content"]) is str:
+        if isinstance(message["content"], str):
             tokens.extend(
                 self.tokenizer.encode(message["content"], bos=False, eos=False)
             )
-        elif type(message["content"]) is list:
+        elif isinstance(message["content"], list):
            for content in message["content"]:
                 if content["type"] == "text":
                     tokens.extend(
@@ -190,7 +190,7 @@ def from_args(cls, args):
                 for image_prompt in image_prompts
                 if (not os.path.exists(image_prompt))
             ]
-            if len(non_existent_image_prompts):
+            if non_existent_image_prompts:
                 raise RuntimeError(
                     f"Image prompt {non_existent_image_prompts} does not exist"
                 )
@@ -238,7 +238,7 @@ def __init__(
         draft_quantize: bool,
     ):
         torch._inductor.config.coordinate_descent_tuning = (
-            False if builder_args.device == "cpu" else True
+            builder_args.device != "cpu"
         )
         torch._inductor.config.triton.unique_kernel_names = True
         torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
@@ -1002,11 +1002,8 @@ def chat(
             max_seq_length,
         )
 
-        max_seq_length = (
-            max_seq_length + self.speculative_builder_args.speculate_k + 1
-            if self.draft_model is not None
-            else max_seq_length
-        )
+        if self.draft_model is not None:
+            max_seq_length += self.speculative_builder_args.speculate_k + 1
 
         aggregate_metrics = {
             "tokens_per_sec": [],
diff --git a/torchchat/model.py b/torchchat/model.py
index 11f3dc167..2a3b9f12f 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -94,7 +94,7 @@ def __init__(
         self.encoder = encoder
         self.decoder = decoder
 
-        # esclate the embedding layer outside decoder llava model need to fuse
+        # escalate the embedding layer outside the decoder; the llava model needs to fuse
         # the text and image embedding together before passing to decoder.
         self.tok_embeddings = getattr(self.decoder, token_embedding_name)
 
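
For reviewers, a minimal standalone sketch (illustrative only, not torchchat code; `PromptText` is a hypothetical subclass) of why `isinstance` is the safer idiom than a `type(...) is` identity check, and of the empty-list truthiness that the `non_existent_image_prompts` change relies on:

```python
# Standalone illustration (not part of torchchat): isinstance() accepts
# subclasses, while an exact-type identity check does not.

class PromptText(str):
    """Hypothetical str subclass standing in for any wrapped string type."""

content = PromptText("describe this image")

print(type(content) is str)       # False: identity check rejects the subclass
print(isinstance(content, str))   # True: isinstance still matches

# The from_args() change relies on empty-sequence truthiness: an empty
# list is falsy, so `if non_existent_image_prompts:` behaves exactly like
# `if len(non_existent_image_prompts):`.
print(bool([]))         # False
print(bool(["a.jpg"]))  # True
```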
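
Likewise, a quick self-contained check (assumed values: `speculate_k = 8` is illustrative; torchchat reads it from the speculative builder args) that the two control-flow simplifications preserve the original semantics:

```python
# Standalone equivalence check (not torchchat code).

# `False if cond else True` is just `not cond`; here cond is device == "cpu".
for device in ("cpu", "cuda", "mps"):
    assert (False if device == "cpu" else True) == (device != "cpu")

# The chat() rewrite: a ternary whose else-branch returns the value unchanged
# collapses into a guarded augmented assignment.
speculate_k = 8  # hypothetical; taken from speculative_builder_args in torchchat
for draft_model in (None, object()):
    max_seq_length = 128
    expected = (
        max_seq_length + speculate_k + 1
        if draft_model is not None
        else max_seq_length
    )
    if draft_model is not None:
        max_seq_length += speculate_k + 1
    assert max_seq_length == expected
```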