@@ -122,14 +122,14 @@ def _load_model_weights(stage_module, hf_model_name, device, model_config):
122122 if num_missing_weights > 0 :
123123 raise ValueError (f"Missing { num_missing_weights } weights" )
124124
125- def _encode_string (string , tokenizer , bos = True , device = "cuda" , dtype = torch .int64 )-> torch .Tensor :
125+ def _encode_string (string : str , tokenizer , bos : bool = True , device : str = "cuda" , dtype = torch .int64 )-> torch .Tensor :
126126 """Encode a prompt string into a tensor of token ids."""
127127 tokens = tokenizer .encode (string )
128128 if bos :
129129 tokens = [tokenizer .bos_id ()] + tokens
130130 return torch .tensor (tokens , dtype = dtype , device = device )
131131
132- def _create_padded_prompt (input_ids , tokenizer , seqlen , start_pos , device ) -> Tuple [torch .Tensor , int ]:
132+ def _create_padded_prompt (input_ids : torch . Tensor , tokenizer , seqlen : int , start_pos : int , device : str ) -> Tuple [torch .Tensor , int ]:
133133 """Create a padded tensor for the encoded input prompt. Returns the padded tensor and the prompt length."""
134134 prompt_len = input_ids .size (0 )
135135 max_new_tokens = min (seqlen , seqlen - start_pos - prompt_len )
@@ -251,7 +251,7 @@ def main():
251251 if len (cpu_tensors ) > 0 :
252252 raise ValueError ("Found cpu tensors in stage" )
253253
254- prompt = "What is snow ?"
254+ prompt = "What is the capital of France ?"
255255 start_pos = 0
256256
257257 # encode the prompt
0 commit comments