 from torchchat.utils.device_info import get_device_info
 
 
+logger = logging.getLogger(__name__)
+
+
 class _ChatFormatter(ABC):
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
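This hunk introduces the module-level logger that the rest of the patch switches to. A minimal standalone sketch of the pattern (the `encode` function here is hypothetical, purely for illustration):

```python
import logging

# One logger per module, named after the module itself.
logger = logging.getLogger(__name__)

def encode(tokens):
    # %s/%d arguments are formatted lazily, only if DEBUG is enabled;
    # an f-string would be evaluated on every call regardless of level.
    logger.debug("Size after encode: %d", len(tokens))
    return tokens
```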
@@ -292,7 +295,7 @@ def __init__(
         if self.is_llama3_model:
             self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
             if generator_args.chat_mode:
-                logging.debug(
+                logger.debug(
                     "Llama3 model detected in chat mode. Using updated sentence schemas"
                 )
         elif self.tokenizer_args.is_hf_tokenizer:
@@ -354,10 +357,12 @@ def sample(
         temperature: float = 0,
         top_k: Optional[int] = None,
     ):
+        logits = logits[0, -1]
+        logger.debug("Logits: %s", logits)
         if temperature == 0 and not need_probs:
-            _, idx_next = torch.topk(logits[0, -1], k=1, dim=-1)
+            _, idx_next = torch.topk(logits, k=1, dim=-1)
             return (idx_next, None)
-        probs = self.logits_to_probs(logits[0, -1], temperature, top_k)
+        probs = self.logits_to_probs(logits, temperature, top_k)
         idx_next = self.multinomial_sample_one_no_sync(probs)
         return idx_next, probs
 
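The `sample` hunk hoists the `logits[0, -1]` slice so the greedy path and the sampling path reuse it. A small self-contained check (tensor shape is an illustrative assumption) that the hoisted slice behaves identically to the previous inline slicing:

```python
import torch

logits = torch.randn(1, 8, 32000)  # assumed (batch, seq_len, vocab) shape
last = logits[0, -1]               # hoisted once, as in the patch

# Greedy pick over the hoisted slice matches the old inline slice.
assert torch.equal(
    torch.topk(last, k=1, dim=-1).indices,
    torch.topk(logits[0, -1], k=1, dim=-1).indices,
)
```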
@@ -371,7 +376,7 @@ def prefill(
         sequential_prefill=True,
         **sampling_kwargs,
     ) -> torch.Tensor:
-        # logging.debug(f"x: {x}, input_pos: {input_pos}")
+        logger.debug("x: %s, input_pos: %s", x, input_pos)
         width = x.size(1)
         assert input_pos.size(0) == width
 
@@ -407,7 +412,7 @@ def prefill(
         elif sequential_prefill:
             for i in range(width):
                 x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
-                # logging.debug(f"<sliced> x: {x_sliced}, input_pos: {ip_sliced}")
+                logger.debug("<sliced> x: %s, input_pos: %s", x_sliced, ip_sliced)
                 logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
         else:
             # input_pos: [B, S]
@@ -740,7 +745,7 @@ def encode_tokens(self, string, bos=True, device="cpu"):
         tokens = self.tokenizer.encode(string)
         if bos:
             tokens = [self.tokenizer.bos_id()] + tokens
-        logging.debug(f"Size after encode_tokens: {len(tokens)}")
+        logger.debug("Size after encode_tokens: %d", len(tokens))
         return torch.tensor(tokens, dtype=torch.int, device=device)
 
     def _callback(self, x, *, buffer, done_generating):
@@ -798,7 +803,7 @@ def _gen_model_input(
             tokens, dtype=torch.int, device=self.builder_args.device
         )
 
-        logging.debug(encoded)
+        logger.debug(encoded)
         return encoded, None
 
         # Llama 3.2 11B
@@ -913,7 +918,7 @@ def _gen_model_input(
             value=0,
         )
 
-        logging.debug(encoded)
+        logger.debug(encoded)
         return encoded, batch
 
     def chat(
@@ -1244,6 +1249,7 @@ def main(args):
     speculative_builder_args = BuilderArgs.from_speculative_args(args)
     tokenizer_args = TokenizerArgs.from_args(args)
     generator_args = GeneratorArgs.from_args(args)
+    logger.debug("GeneratorArgs: %s", generator_args)
     if not builder_args.distributed:
         gen = Generator(
             builder_args,
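Since `logger.debug` output is suppressed at Python's default WARNING level, actually seeing the new messages presumably requires raising the verbosity, e.g.:

```python
import logging

# Route DEBUG-level records (including this module's logger) to stderr;
# by default only WARNING and above are printed.
logging.basicConfig(level=logging.DEBUG)
```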