|
24 | 24 |
|
25 | 25 | from PIL import Image |
26 | 26 |
|
| 27 | +# torchtune model definition dependencies |
| 28 | +from torchtune.data import Message, padded_collate_tiled_images_and_mask |
| 29 | + |
| 30 | +from torchtune.generation import sample as tune_sample |
| 31 | +from torchtune.models.llama3 import llama3_tokenizer |
| 32 | + |
| 33 | +from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform |
| 34 | +from torchtune.training import set_default_dtype |
| 35 | + |
27 | 36 | from torchchat.cli.builder import ( |
28 | 37 | _initialize_model, |
29 | 38 | _initialize_tokenizer, |
|
34 | 43 | from torchchat.utils.build_utils import device_sync, set_precision |
35 | 44 | from torchchat.utils.device_info import get_device_info |
36 | 45 |
|
37 | | -# torchtune model definition dependencies |
38 | | -from torchtune.data import Message, padded_collate_tiled_images_and_mask |
39 | | - |
40 | | -from torchtune.generation import sample as tune_sample |
41 | | -from torchtune.models.llama3 import llama3_tokenizer |
42 | | - |
43 | | -from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform |
44 | | -from torchtune.training import set_default_dtype |
45 | | - |
46 | 46 |
|
47 | 47 | class _ChatFormatter(ABC): |
48 | 48 | def __init__(self, tokenizer): |
@@ -1164,13 +1164,9 @@ def callback(x, *, done_generating=False): |
1164 | 1164 | print( |
1165 | 1165 | f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds" |
1166 | 1166 | ) |
1167 | | - aggregate_metrics["tokens_per_sec_jit_compile"] = tokens_sec |
1168 | | - # Don't continue here.... because we need to report and reset |
1169 | | - # continue |
1170 | | - else: |
1171 | | - aggregate_metrics["tokens_per_sec"].append(tokens_sec) |
1172 | | - aggregate_metrics["first_token_per_sec"].append(first_token_sec) |
1173 | | - aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) |
| 1167 | + aggregate_metrics["tokens_per_sec"].append(tokens_sec) |
| 1168 | + aggregate_metrics["first_token_per_sec"].append(first_token_sec) |
| 1169 | + aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) |
1174 | 1170 |
|
1175 | 1171 | logging.info( |
1176 | 1172 | f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\ |
|
0 commit comments