
Commit 8e675d8

Support models on HuggingFace Inference Providers (#4086)
1 parent 7967e0d commit 8e675d8

File tree

2 files changed: +142 −0 lines changed

src/helm/benchmark/model_deployment_registry.py

Lines changed: 21 additions & 0 deletions
@@ -160,6 +160,27 @@ def auto_generate_model_deployment(name: str) -> ModelDeployment:
             tokenizer_name=name,
             max_sequence_length=max_sequence_length,
         )
+    elif model_deployment_base == "huggingface-inference-providers":
+        from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
+
+        pretrained_model_name_or_path = "/".join(name_parts[1:])
+        pretrained_model_name_or_path = pretrained_model_name_or_path.split(":")[0]
+        with HuggingFaceTokenizer.create_tokenizer(pretrained_model_name_or_path) as tokenizer:
+            max_sequence_length = tokenizer.model_max_length
+        if max_sequence_length > 1_000_000_000:
+            hwarn(
+                f"Hugging Face model {pretrained_model_name_or_path} does not have a configured model_max_length; "
+                "input truncation may not work correctly; errors may result from exceeding the model's max length"
+            )
+        return ModelDeployment(
+            name=name,
+            model_name=model_name,
+            client_spec=ClientSpec(
+                "helm.clients.huggingface_inference_providers_client.HuggingFaceInferenceProvidersClient"
+            ),
+            tokenizer_name=f"huggingface/{pretrained_model_name_or_path}",
+            max_sequence_length=max_sequence_length,
+        )
     else:
         raise NotImplementedError(f"Unknown model deployment base {model_deployment_base}")
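
Aside (not part of the diff): the new branch handles deployment names of the form huggingface-inference-providers/<org>/<model>, stripping anything after a ":" (presumably a provider suffix) before loading the tokenizer. A minimal sketch of that parsing, using a hypothetical deployment name:

# Illustrative only; mirrors the string handling in the hunk above.
# The example name and the ":provider" reading of the suffix are assumptions.
name = "huggingface-inference-providers/meta-llama/Llama-3.1-8B-Instruct:novita"
name_parts = name.split("/")
model_deployment_base = name_parts[0]  # "huggingface-inference-providers"
pretrained_model_name_or_path = "/".join(name_parts[1:])  # "meta-llama/Llama-3.1-8B-Instruct:novita"
pretrained_model_name_or_path = pretrained_model_name_or_path.split(":")[0]  # "meta-llama/Llama-3.1-8B-Instruct"

The stripped repo id is then used both to load the tokenizer and to build the deployment's tokenizer_name ("huggingface/<repo id>").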

src/helm/clients/huggingface_inference_providers_client.py

Lines changed: 121 additions & 0 deletions

@@ -0,0 +1,121 @@
import dataclasses
from typing import Any, Dict, List, Optional, TypedDict

from huggingface_hub import ChatCompletionOutput, InferenceClient

from helm.common.cache import CacheConfig
from helm.common.request import (
    Thinking,
    wrap_request_time,
    Request,
    RequestResult,
    GeneratedOutput,
)
from helm.clients.client import CachingClient


class HuggingFaceInferenceProvidersChatCompletionRequest(TypedDict):
    model: str
    messages: List[Dict]
    frequency_penalty: Optional[float]
    logprobs: Optional[bool]
    max_tokens: Optional[int]
    n: Optional[int]
    presence_penalty: Optional[float]
    # TODO: Support JSON Schema response format
    # response_format: Optional[ChatCompletionInputGrammarType] = None
    stop: Optional[List[str]]
    temperature: Optional[float]
    top_p: Optional[float]


class HuggingFaceInferenceProvidersClient(CachingClient):

    def __init__(
        self,
        cache_config: CacheConfig,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        huggingface_model_name: Optional[str] = None,
    ):
        super().__init__(cache_config=cache_config)
        self._client = InferenceClient(api_key=api_key, base_url=base_url)
        self._huggingface_model_name = huggingface_model_name

    def _make_raw_request(self, request: Request) -> HuggingFaceInferenceProvidersChatCompletionRequest:
        input_messages: List[Dict[str, Any]]

        if request.multimodal_prompt:
            raise ValueError("`multimodal_prompt` is not supported by `HuggingFaceInferenceProvidersClient`")

        if request.prompt and request.messages:
            raise ValueError("More than one of `prompt` and `messages` was set in request")

        if request.messages is not None:
            # Checks that all messages have a role and some content
            for message in request.messages:
                if not message.get("role") or not message.get("content"):
                    raise ValueError("All messages must have a role and content")
            # Checks that the last role is "user"
            if request.messages[-1]["role"] != "user":
                raise ValueError("Last message must have role 'user'")
            input_messages = request.messages
        else:
            input_messages = [{"role": "user", "content": request.prompt}]

        return {
            "model": self._huggingface_model_name or request.model,
            "messages": input_messages,
            "frequency_penalty": request.frequency_penalty,
            "logprobs": False,
            "max_tokens": request.max_tokens,
            "n": request.num_completions,
            "presence_penalty": request.presence_penalty,
            "stop": request.stop_sequences,
            "temperature": request.temperature,
            "top_p": request.top_p,
        }

    def make_request(self, request: Request) -> RequestResult:
        if request.echo_prompt:
            raise NotImplementedError("`echo_prompt` is not supported")
        if request.embedding:
            raise NotImplementedError("`embedding` is not supported")

        raw_request = self._make_raw_request(request)

        def do_it() -> Dict[str, Any]:
            hf_raw_response = self._client.chat_completion(**raw_request)
            assert isinstance(hf_raw_response, ChatCompletionOutput)
            return dataclasses.asdict(hf_raw_response)

        cache_key = CachingClient.make_cache_key(raw_request, request)
        response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
        request_time = response["request_time"]
        del response["request_time"]
        request_datetime = response["request_datetime"]
        del response["request_datetime"]

        completions: list[GeneratedOutput] = []
        for choice in response["choices"]:
            thinking = Thinking(text=choice["message"]["reasoning"]) if choice["message"]["reasoning"] else None
            output_text = choice["message"]["content"]
            if output_text is None:
                raise ValueError("Response content was `None`, possibly due to content blocking")
            completion = GeneratedOutput(
                text=output_text,
                logprob=0.0,
                tokens=[],
                finish_reason={"reason": choice["finish_reason"]},
                thinking=thinking,
            )
            completions.append(completion)

        return RequestResult(
            success=True,
            cached=cached,
            request_time=request_time,
            request_datetime=request_datetime,
            completions=completions,
            embedding=[],
        )
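
For orientation, a rough usage sketch of the new client, outside of this commit: it wires the client up by hand with a non-persistent cache and a plain-prompt Request. The BlackHoleCacheConfig class, the Request fields shown, and the HF_TOKEN variable are assumptions about the surrounding HELM code and environment rather than something this diff introduces; within HELM the client is normally constructed from the ClientSpec registered in model_deployment_registry.py.

import os

from helm.common.cache import BlackHoleCacheConfig  # assumed non-persistent cache config
from helm.common.request import Request
from helm.clients.huggingface_inference_providers_client import HuggingFaceInferenceProvidersClient

# Hypothetical wiring; HELM normally builds this via the registered ClientSpec.
client = HuggingFaceInferenceProvidersClient(
    cache_config=BlackHoleCacheConfig(),
    api_key=os.environ.get("HF_TOKEN"),  # placeholder Hugging Face access token
    huggingface_model_name="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical repo id
)

request = Request(
    prompt="Write one sentence about benchmark contamination.",
    max_tokens=64,
    temperature=0.7,
    num_completions=1,
)
result = client.make_request(request)
print(result.completions[0].text)

Because the prompt-only Request has no messages, _make_raw_request wraps it as a single "user" message before calling chat_completion, and the response choices are converted back into GeneratedOutput objects as shown in the diff above.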
