This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 4e9771c

add typing to added functions

1 parent 54d895b commit 4e9771c

1 file changed: +3 -3 lines changed

dist_run.py

Lines changed: 3 additions & 3 deletions
@@ -122,14 +122,14 @@ def _load_model_weights(stage_module, hf_model_name, device, model_config):
     if num_missing_weights > 0:
         raise ValueError(f"Missing {num_missing_weights} weights")

-def _encode_string(string, tokenizer, bos=True, device="cuda", dtype=torch.int64)-> torch.Tensor:
+def _encode_string(string: str, tokenizer, bos: bool =True, device: str ="cuda", dtype=torch.int64)-> torch.Tensor:
     """Encode a prompt string into a tensor of token ids."""
     tokens = tokenizer.encode(string)
     if bos:
         tokens = [tokenizer.bos_id()] + tokens
     return torch.tensor(tokens, dtype=dtype, device=device)

-def _create_padded_prompt(input_ids, tokenizer, seqlen, start_pos, device) -> Tuple[torch.Tensor, int]:
+def _create_padded_prompt(input_ids: torch.Tensor, tokenizer, seqlen: int, start_pos: int, device: str) -> Tuple[torch.Tensor, int]:
     """Create a padded tensor for the encoded input prompt. Returns the padded tensor and the prompt length."""
     prompt_len = input_ids.size(0)
     max_new_tokens = min(seqlen, seqlen - start_pos - prompt_len)
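
For context, a minimal self-contained sketch of how the two newly typed helpers fit together. The StubTokenizer, the zero-padding inside _create_padded_prompt, and the "cpu" default device are assumptions made so the sketch runs standalone; the real tokenizer and padding logic live in dist_run.py itself.

import torch
from typing import Tuple

class StubTokenizer:
    """Hypothetical stand-in for the real tokenizer used by dist_run.py."""
    def encode(self, s: str) -> list:
        # Fake token ids, for illustration only.
        return [ord(c) % 256 for c in s]
    def bos_id(self) -> int:
        return 1

def _encode_string(string: str, tokenizer, bos: bool = True,
                   device: str = "cpu", dtype=torch.int64) -> torch.Tensor:
    """Encode a prompt string into a tensor of token ids (signature as
    committed, except the default device is "cpu" here so the sketch runs
    without a GPU)."""
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=dtype, device=device)

def _create_padded_prompt(input_ids: torch.Tensor, tokenizer, seqlen: int,
                          start_pos: int, device: str) -> Tuple[torch.Tensor, int]:
    """Pad the encoded prompt out to seqlen. The zero-padding below is an
    assumption: the diff truncates before the actual padding logic."""
    prompt_len = input_ids.size(0)
    padded = torch.zeros(seqlen, dtype=input_ids.dtype, device=device)
    padded[:prompt_len] = input_ids
    return padded, prompt_len

tok = StubTokenizer()
ids = _encode_string("What is the capital of France?", tok)
padded, prompt_len = _create_padded_prompt(ids, tok, seqlen=64, start_pos=0, device="cpu")
print(padded.shape, prompt_len)  # torch.Size([64]) 31

The Tuple[torch.Tensor, int] return annotation makes the two-value return explicit at call sites, which is the main payoff of this commit.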
@@ -251,7 +251,7 @@ def main():
     if len(cpu_tensors) > 0:
         raise ValueError("Found cpu tensors in stage")

-    prompt = "What is snow?"
+    prompt = "What is the capital of France?"
     start_pos = 0

     # encode the prompt
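
The hunk ends just before the encode step. A plausible shape of that call, using the two helpers above; the tokenizer, seqlen, and device names are assumptions inferred from the signatures, not taken from the diff:

# Hypothetical call site in main(); variable names are assumed.
input_ids = _encode_string(prompt, tokenizer, bos=True, device=device)
padded_sequence, prompt_len = _create_padded_prompt(input_ids, tokenizer, seqlen, start_pos, device)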
