diff --git a/src/helm/clients/huggingface_client.py b/src/helm/clients/huggingface_client.py
index 1665ea85d8..cfc3b862ce 100644
--- a/src/helm/clients/huggingface_client.py
+++ b/src/helm/clients/huggingface_client.py
@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Optional, TypedDict
 
 from helm.common.cache import CacheConfig
-from helm.common.hierarchical_logger import hexception, htrack_block, hlog, hwarn
+from helm.common.hierarchical_logger import hexception, htrack_block, hlog
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.common.request import (
     wrap_request_time,
@@ -101,6 +101,7 @@ def serve_request(self, raw_request: HuggingFaceRequest) -> Dict:
             encoded_input = tokenizer(raw_request["prompt"], return_tensors="pt", return_token_type_ids=False).to(
                 0 if self.device is None else self.device
             )
+            pad_token_id = tokenizer.eos_token_id
         stopping_criteria: Optional[StoppingCriteriaList] = None
         optional_args = {}
         if len(raw_request["stop_sequences"]) > 0:
@@ -140,6 +141,7 @@ def serve_request(self, raw_request: HuggingFaceRequest) -> Dict:
                 output_scores=True,
                 **optional_args,
                 stopping_criteria=stopping_criteria,
+                pad_token_id=pad_token_id,
             )
         sequences = output.sequences
         scores = output.scores
@@ -274,7 +276,7 @@ def __init__(
         else:
             with self._wrapped_tokenizer as hf_tokenizer:
                 self._apply_chat_template = bool(hf_tokenizer.chat_template)
-            hwarn(
+            hlog(
                 f"Automatically set `apply_chat_template` to {self._apply_chat_template} based on "
                 "whether the tokenizer has a chat template. "
                 "If this is incorrect, please explicitly set `apply_chat_template`."
diff --git a/src/helm/clients/huggingface_pipeline_client.py b/src/helm/clients/huggingface_pipeline_client.py
index a486e47ec3..85e2a2ae71 100644
--- a/src/helm/clients/huggingface_pipeline_client.py
+++ b/src/helm/clients/huggingface_pipeline_client.py
@@ -5,7 +5,7 @@
 
 from helm.clients.client import CachingClient
 from helm.common.cache import CacheConfig
-from helm.common.hierarchical_logger import htrack_block, hwarn
+from helm.common.hierarchical_logger import hlog, htrack_block
 from helm.common.request import GeneratedOutput, Request, RequestResult, wrap_request_time
 from helm.proxy.retry import NonRetriableException
 
@@ -64,7 +64,8 @@ def __init__(
         # e.g. Qwen2, Qwen 2.5.
         # For these models, the `apply_chat_template` arg should be explicitly set to false.
         self._apply_chat_template = bool(self._pipeline.tokenizer.chat_template)
-        hwarn(
+
+        hlog(
             f"Automatically set `apply_chat_template` to {self._apply_chat_template} based on "
             "whether the tokenizer has a chat template. "
             "If this is incorrect, please explicitly set `apply_chat_template`."
@@ -104,6 +105,7 @@ def make_request(self, request: Request) -> RequestResult:
             "top_k": request.top_k_per_token if do_sample else None,
             "do_sample": do_sample,
             "return_dict_in_generate": True,
+            "pad_token_id": self._pipeline.tokenizer.eos_token_id,
         }
         if request.stop_sequences:
             stop_sequence_ids = self._pipeline.tokenizer(