Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/helm/clients/huggingface_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any, Dict, List, Optional, TypedDict

from helm.common.cache import CacheConfig
-from helm.common.hierarchical_logger import hexception, htrack_block, hlog, hwarn
+from helm.common.hierarchical_logger import hexception, htrack_block, hlog
from helm.common.optional_dependencies import handle_module_not_found_error
from helm.common.request import (
wrap_request_time,
Expand Down Expand Up @@ -101,6 +101,7 @@ def serve_request(self, raw_request: HuggingFaceRequest) -> Dict:
encoded_input = tokenizer(raw_request["prompt"], return_tensors="pt", return_token_type_ids=False).to(
0 if self.device is None else self.device
)
+pad_token_id = tokenizer.eos_token_id
stopping_criteria: Optional[StoppingCriteriaList] = None
optional_args = {}
if len(raw_request["stop_sequences"]) > 0:
Expand Down Expand Up @@ -140,6 +141,7 @@ def serve_request(self, raw_request: HuggingFaceRequest) -> Dict:
output_scores=True,
**optional_args,
stopping_criteria=stopping_criteria,
+pad_token_id=pad_token_id,
)
sequences = output.sequences
scores = output.scores
Expand Down Expand Up @@ -274,7 +276,7 @@ def __init__(
else:
with self._wrapped_tokenizer as hf_tokenizer:
self._apply_chat_template = bool(hf_tokenizer.chat_template)
-hwarn(
+hlog(
f"Automatically set `apply_chat_template` to {self._apply_chat_template} based on "
"whether the tokenizer has a chat template. "
"If this is incorrect, please explicitly set `apply_chat_template`."
Expand Down
6 changes: 4 additions & 2 deletions src/helm/clients/huggingface_pipeline_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from helm.clients.client import CachingClient
from helm.common.cache import CacheConfig
-from helm.common.hierarchical_logger import htrack_block, hwarn
+from helm.common.hierarchical_logger import hlog, htrack_block
from helm.common.request import GeneratedOutput, Request, RequestResult, wrap_request_time
from helm.proxy.retry import NonRetriableException

Expand Down Expand Up @@ -64,7 +64,8 @@ def __init__(
# e.g. Qwen2, Qwen 2.5.
# For these models, the `apply_chat_template` arg should be explicitly set to false.
self._apply_chat_template = bool(self._pipeline.tokenizer.chat_template)
-hwarn(
+
+hlog(
f"Automatically set `apply_chat_template` to {self._apply_chat_template} based on "
"whether the tokenizer has a chat template. "
"If this is incorrect, please explicitly set `apply_chat_template`."
Expand Down Expand Up @@ -104,6 +105,7 @@ def make_request(self, request: Request) -> RequestResult:
"top_k": request.top_k_per_token if do_sample else None,
"do_sample": do_sample,
"return_dict_in_generate": True,
+"pad_token_id": self._pipeline.tokenizer.eos_token_id,
}
if request.stop_sequences:
stop_sequence_ids = self._pipeline.tokenizer(
Expand Down