Skip to content

Commit 29ddffd

Browse files
authored
Allow HF and sentence-transformer models (#63)
* Fix Bug to allow both HF and sentence-transformer models * Add tests
1 parent 0ecc5d3 commit 29ddffd

File tree

5 files changed

+81
-4
lines changed

5 files changed

+81
-4
lines changed

crossfit/backend/torch/hf/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from crossfit.backend.torch.model import Model
2727
from crossfit.dataset.home import CF_HOME
28+
from crossfit.utils.model_adapter import adapt_model_input
2829

2930

3031
class HFModel(Model):
@@ -96,7 +97,7 @@ def fit_memory_estimate_curve(self, model=None):
9697
}
9798

9899
try:
99-
_ = model(**batch)
100+
_ = adapt_model_input(model, batch)
100101
memory_used = torch.cuda.max_memory_allocated() / (1024**2) # Convert to MB
101102
X.append([batch_size, seq_len, seq_len**2])
102103
y.append(memory_used)

crossfit/backend/torch/loader.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from crossfit.data.array.conversion import convert_array
2323
from crossfit.data.array.dispatch import crossarray
2424
from crossfit.data.dataframe.dispatch import CrossFrame
25+
from crossfit.utils.model_adapter import adapt_model_input
2526

2627
DEFAULT_BATCH_SIZE = 512
2728

@@ -70,7 +71,7 @@ def __next__(self):
7071
self.current_idx += self.batch_size
7172

7273
for fn in self._to_map:
73-
batch = fn(batch)
74+
batch = adapt_model_input(fn, batch)
7475

7576
if self.progress_bar is not None:
7677
self.progress_bar.update(batch_size)
@@ -141,7 +142,7 @@ def __next__(self):
141142
batch = {key: val[:, :clip_len] for key, val in batch.items()}
142143

143144
for fn in self._to_map:
144-
batch = fn(batch)
145+
batch = adapt_model_input(fn, batch)
145146

146147
break
147148
except torch.cuda.OutOfMemoryError:

crossfit/backend/torch/op/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def call(self, data, partition_info=None):
7474
for output in loader.map(self.model.get_model(self.get_worker())):
7575
if isinstance(output, dict):
7676
if self.model_output_col not in output:
77-
raise ValueError(f"Column '{self.model_outupt_col}' not found in model output.")
77+
raise ValueError(f"Column '{self.model_output_col}' not found in model output.")
7878
output = output[self.model_output_col]
7979

8080
if self.post is not None:

crossfit/utils/model_adapter.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import Any, Callable
2+
3+
4+
def adapt_model_input(model: Callable, encoded_input: dict) -> Any:
    """Invoke *model* on *encoded_input*, whichever calling convention it accepts.

    Hugging Face models take the tokenizer output expanded as keyword
    arguments, while Sentence Transformers models take the whole encoded
    dict as one positional argument.  The keyword form is attempted first;
    a ``TypeError`` (signature mismatch) triggers the positional fallback.

    Args:
        model: The model callable to apply.
        encoded_input: Encoded/tokenized input to feed to the model.

    Returns:
        Whatever the model returns.
    """
    try:
        # Standard Hugging Face convention: model(**inputs).
        return model(**encoded_input)
    except TypeError:
        # Fallback for models that want a single dict argument,
        # e.g. Sentence Transformers: model(inputs).
        return model(encoded_input)

tests/test_utils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2024 NVIDIA CORPORATION
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from crossfit.utils.model_adapter import adapt_model_input
18+
19+
torch = pytest.importorskip("torch")
20+
sentence_transformers = pytest.importorskip("sentence_transformers")
21+
transformers = pytest.importorskip("transformers")
22+
23+
24+
def test_adapt_model_input_hf():
    """adapt_model_input must match a direct kwargs call for HF models."""
    from transformers import AutoTokenizer, DistilBertModel

    with torch.no_grad():
        # Load a small HF encoder and its matching tokenizer.
        model = DistilBertModel.from_pretrained("distilbert-base-uncased")
        tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
        encoded = tok("Hello, my dog is cute", return_tensors="pt")

        # Reference: the standard Hugging Face kwargs calling convention.
        expected = model(**encoded)
        actual = adapt_model_input(model, encoded)
        assert torch.equal(actual.last_hidden_state, expected.last_hidden_state)
36+
37+
38+
def test_adapt_model_input_sentence_transformers():
    """adapt_model_input must match a direct single-dict call for ST models."""
    from transformers import AutoTokenizer

    with torch.no_grad():
        # Sentence Transformers models expect the encoded dict as one argument.
        model = sentence_transformers.SentenceTransformer("all-MiniLM-L6-v2").to("cpu")
        tok = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

        encoded = tok(
            ["Hello", "my dog is cute"], return_tensors="pt", padding=True, truncation=True
        )
        # Reference: direct positional call, as SentenceTransformer defines it.
        expected = model(encoded)
        actual = adapt_model_input(model, encoded)

        assert torch.equal(actual.sentence_embedding, expected.sentence_embedding)

0 commit comments

Comments (0)