|
1 | 1 | """Test Perplexity Chat API wrapper.""" |
2 | 2 |
|
3 | 3 | import os |
| 4 | +from typing import Any, Dict, List, Optional |
| 5 | +from unittest.mock import MagicMock |
4 | 6 |
|
5 | 7 | import pytest |
| 8 | +from langchain_core.messages import AIMessageChunk, BaseMessageChunk |
| 9 | +from pytest_mock import MockerFixture |
6 | 10 |
|
7 | 11 | from langchain_community.chat_models import ChatPerplexity |
8 | 12 |
|
@@ -40,3 +44,58 @@ def test_perplexity_initialization() -> None: |
40 | 44 | ]: |
41 | 45 | assert model.request_timeout == 1 |
42 | 46 | assert model.pplx_api_key == "test" |
| 47 | + |
| 48 | + |
| 49 | +@pytest.mark.requires("openai") |
| 50 | +def test_perplexity_stream_includes_citations(mocker: MockerFixture) -> None: |
| 51 | + """Test that the stream method includes citations in the additional_kwargs.""" |
| 52 | + llm = ChatPerplexity( |
| 53 | + model="test", |
| 54 | + timeout=30, |
| 55 | + verbose=True, |
| 56 | + ) |
| 57 | + mock_chunk_0 = { |
| 58 | + "choices": [ |
| 59 | + { |
| 60 | + "delta": { |
| 61 | + "content": "Hello ", |
| 62 | + }, |
| 63 | + "finish_reason": None, |
| 64 | + } |
| 65 | + ], |
| 66 | + "citations": ["example.com", "example2.com"], |
| 67 | + } |
| 68 | + mock_chunk_1 = { |
| 69 | + "choices": [ |
| 70 | + { |
| 71 | + "delta": { |
| 72 | + "content": "Perplexity", |
| 73 | + }, |
| 74 | + "finish_reason": None, |
| 75 | + } |
| 76 | + ], |
| 77 | + "citations": ["example.com", "example2.com"], |
| 78 | + } |
| 79 | + mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1] |
| 80 | + mock_stream = MagicMock() |
| 81 | + mock_stream.__iter__.return_value = mock_chunks |
| 82 | + patcher = mocker.patch.object( |
| 83 | + llm.client.chat.completions, "create", return_value=mock_stream |
| 84 | + ) |
| 85 | + stream = llm.stream("Hello langchain") |
| 86 | + full: Optional[BaseMessageChunk] = None |
| 87 | + for i, chunk in enumerate(stream): |
| 88 | + full = chunk if full is None else full + chunk |
| 89 | + assert chunk.content == mock_chunks[i]["choices"][0]["delta"]["content"] |
| 90 | + if i == 0: |
| 91 | + assert chunk.additional_kwargs["citations"] == [ |
| 92 | + "example.com", |
| 93 | + "example2.com", |
| 94 | + ] |
| 95 | + else: |
| 96 | + assert "citations" not in chunk.additional_kwargs |
| 97 | + assert isinstance(full, AIMessageChunk) |
| 98 | + assert full.content == "Hello Perplexity" |
| 99 | + assert full.additional_kwargs == {"citations": ["example.com", "example2.com"]} |
| 100 | + |
| 101 | + patcher.assert_called_once() |
0 commit comments