Commit bc5dd4f
[Bugfix] Fix broken GritLM model and tests (missing pooling_metadata) (#16631)
Signed-off-by: Pooya Davoodi <[email protected]>
1 parent dbb036c commit bc5dd4f

2 files changed: +13, -11 lines

tests/models/embedding/language/test_gritlm.py

Lines changed: 11 additions & 10 deletions
@@ -57,24 +57,25 @@ def test_find_array(monkeypatch: pytest.MonkeyPatch):
 def server_embedding():
     # GritLM embedding implementation is only supported by XFormers backend.
     args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
-    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
-        yield remote_server
+    with pytest.MonkeyPatch.context() as m:
+        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
+        with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+            yield remote_server
 
 
 @pytest.fixture(scope="module")
 def server_generate():
     args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
-    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
-        yield remote_server
+    with pytest.MonkeyPatch.context() as m:
+        m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
+        with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+            yield remote_server
 
 
 @pytest_asyncio.fixture
-async def client_embedding(monkeypatch: pytest.MonkeyPatch,
-                           server_embedding: RemoteOpenAIServer):
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
-        async with server_embedding.get_async_client() as async_client:
-            yield async_client
+async def client_embedding(server_embedding: RemoteOpenAIServer):
+    async with server_embedding.get_async_client() as async_client:
+        yield async_client
 
 
 @pytest_asyncio.fixture
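
Context for the fixture change above: the attention-backend environment variable has to be in place before RemoteOpenAIServer launches the server process, and the module-scoped server fixtures cannot request pytest's function-scoped `monkeypatch` fixture, so the diff opens a `pytest.MonkeyPatch.context()` directly. A minimal standalone sketch of that pattern follows; the fixture name is illustrative and not from this repository.

```python
# Minimal sketch (illustrative names): a module-scoped fixture opens its own
# MonkeyPatch context so the env var is set before any server is launched.
import os

import pytest


@pytest.fixture(scope="module")
def server_env():
    with pytest.MonkeyPatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
        # Anything started inside this block inherits the patched environment.
        yield os.environ["VLLM_ATTENTION_BACKEND"]
    # On exit, the variable is restored to its previous value (or unset).
```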

vllm/model_executor/models/gritlm.py

Lines changed: 2 additions & 1 deletion
@@ -170,7 +170,8 @@ def forward(
         mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
             1)
 
-        pooled_data = self.head(mean_embeddings)
+        pooled_data = self.head(mean_embeddings,
+                                pooling_metadata=pooling_metadata)
 
         pooled_outputs = [
             PoolingSequenceGroupOutput(data) for data in pooled_data
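
The model-side breakage is that the pooler head's call signature expects `pooling_metadata`, while GritLM's forward path was calling it with the embeddings alone. A tiny self-contained sketch of that failure mode is below; `DemoPoolerHead` is an assumed, simplified stand-in, not vLLM's actual head class.

```python
# Simplified stand-in showing why the old call broke: a head whose forward()
# requires pooling_metadata raises a TypeError when called with embeddings only.
import torch
import torch.nn as nn
import torch.nn.functional as F


class DemoPoolerHead(nn.Module):
    def forward(self, pooled_data: torch.Tensor, pooling_metadata) -> torch.Tensor:
        # The metadata could drive per-request options; here we just normalize.
        return F.normalize(pooled_data, p=2, dim=-1)


head = DemoPoolerHead()
mean_embeddings = torch.randn(2, 8)

head(mean_embeddings, pooling_metadata=None)  # works, as in the fixed call
# head(mean_embeddings)                       # TypeError: missing 'pooling_metadata'
```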
