
Commit 020176b

feat(log): Add better logging in model and generate
In generate, there were a number of commented-out log lines. These are safe to leave in as long as lazy string interpolation is used.

Branch: GraniteCodeSupport
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent: ac94d1c
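The "lazy string interpolation" the message refers to is the stdlib logging idiom of passing format arguments to the logger instead of pre-rendering them with an f-string: the record is only formatted if a handler actually emits it, so debug lines that stringify large tensors cost almost nothing when DEBUG is off. A minimal sketch (the tensor argument is illustrative, not from this commit):

    import logging

    logger = logging.getLogger(__name__)

    def sample_with_logging(logits):
        # Eager: the f-string renders the whole tensor to a string on
        # every call, even when DEBUG logging is disabled.
        # logger.debug(f"Logits: {logits}")

        # Lazy: "%s" is interpolated inside logging, and only if the
        # record passes the level check, so this is cheap on the hot path.
        logger.debug("Logits: %s", logits)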

File tree

2 files changed: +20, -9 lines

torchchat/generate.py

Lines changed: 14 additions & 8 deletions
@@ -45,6 +45,9 @@
 from torchchat.utils.device_info import get_device_info


+logger = logging.getLogger(__name__)
+
+
 class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
@@ -292,7 +295,7 @@ def __init__(
         if self.is_llama3_model:
             self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
             if generator_args.chat_mode:
-                logging.debug(
+                logger.debug(
                     "Llama3 model detected in chat mode. Using updated sentence schemas"
                 )
         elif self.tokenizer_args.is_hf_tokenizer:
@@ -354,10 +357,12 @@ def sample(
         temperature: float = 0,
         top_k: Optional[int] = None,
     ):
+        logits = logits[0, -1]
+        logger.debug("Logits: %s", logits)
         if temperature == 0 and not need_probs:
-            _, idx_next = torch.topk(logits[0, -1], k=1, dim=-1)
+            _, idx_next = torch.topk(logits, k=1, dim=-1)
             return (idx_next, None)
-        probs = self.logits_to_probs(logits[0, -1], temperature, top_k)
+        probs = self.logits_to_probs(logits, temperature, top_k)
         idx_next = self.multinomial_sample_one_no_sync(probs)
         return idx_next, probs

@@ -371,7 +376,7 @@ def prefill(
         sequential_prefill=True,
         **sampling_kwargs,
     ) -> torch.Tensor:
-        # logging.debug(f"x: {x}, input_pos: {input_pos}")
+        logger.debug("x: %s, input_pos: %s", x, input_pos)
         width = x.size(1)
         assert input_pos.size(0) == width

@@ -407,7 +412,7 @@ def prefill(
         elif sequential_prefill:
             for i in range(width):
                 x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
-                # logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
+                logger.debug("<sliced> x: %s, input_pos: %s", x_sliced, ip_sliced)
                 logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
         else:
             # input_pos: [B, S]
@@ -740,7 +745,7 @@ def encode_tokens(self, string, bos=True, device="cpu"):
         tokens = self.tokenizer.encode(string)
         if bos:
             tokens = [self.tokenizer.bos_id()] + tokens
-        logging.debug(f"Size after encode_tokens: {len(tokens)}")
+        logger.debug("Size after encode_tokens: %d", len(tokens))
         return torch.tensor(tokens, dtype=torch.int, device=device)

     def _callback(self, x, *, buffer, done_generating):
@@ -798,7 +803,7 @@ def _gen_model_input(
             tokens, dtype=torch.int, device=self.builder_args.device
         )

-        logging.debug(encoded)
+        logger.debug(encoded)
         return encoded, None

         # Llama 3.2 11B
@@ -913,7 +918,7 @@ def _gen_model_input(
             value=0,
         )

-        logging.debug(encoded)
+        logger.debug(encoded)
         return encoded, batch

     def chat(
@@ -1244,6 +1249,7 @@ def main(args):
     speculative_builder_args = BuilderArgs.from_speculative_args(args)
     tokenizer_args = TokenizerArgs.from_args(args)
     generator_args = GeneratorArgs.from_args(args)
+    logger.debug("GeneratorArgs: %s", generator_args)
     if not builder_args.distributed:
         gen = Generator(
             builder_args,
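These logger.debug calls stay silent unless logging is configured for DEBUG; the commit does not add a CLI flag for that, so one stdlib-only way to surface them (assuming nothing else has already configured logging) is:

    import logging

    # basicConfig is a no-op if the root logger already has handlers,
    # so run this before anything else touches logging.
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(name)s %(levelname)s: %(message)s",
    )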

torchchat/model.py

Lines changed: 6 additions & 1 deletion
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import json
+import logging
 import os
 import warnings
 from abc import ABC, abstractmethod
@@ -48,6 +49,8 @@

 config_path = Path(f"{str(Path(__file__).parent)}/model_params")

+logger = logging.getLogger(__name__)
+

 class QuickGELUActivation(nn.Module):
     """
@@ -477,7 +480,9 @@ def build_model(self) -> nn.Module:
         for name, module_class in recipe.modules.items():
             config_args = self.config.transformer_args[name]
             if module_class == Transformer:
-                modules[name] = module_class(TransformerArgs.from_params(config_args))
+                transformer_args = TransformerArgs.from_params(config_args)
+                logger.debug("Transformer Args: %s", transformer_args)
+                modules[name] = module_class(transformer_args)
             else:
                 modules[name] = module_class(**config_args)

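Because both files obtain their logger via logging.getLogger(__name__), the logger names mirror the module paths, so debug output can be enabled per file rather than globally; a sketch, assuming the modules are imported as torchchat.model and torchchat.generate:

    import logging

    logging.basicConfig(level=logging.INFO)

    # Surface only model.py's output (e.g. the TransformerArgs dump),
    # leaving generate.py's per-token logging at INFO.
    logging.getLogger("torchchat.model").setLevel(logging.DEBUG)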
