|
13 | 13 | # permissions and limitations under the License. |
14 | 14 | """Comprehensive test for parameter resolution and flow in serving.""" |
15 | 15 |
|
16 | | -from typing import Any, Dict, List, Optional |
17 | 16 | from unittest.mock import MagicMock |
18 | 17 |
|
19 | 18 | import pytest |
20 | | -from pydantic import BaseModel |
21 | 19 |
|
22 | 20 | from zenml.deployers.server import runtime |
23 | 21 |
|
24 | 22 |
|
25 | | -class WeatherRequest(BaseModel): |
26 | | - """Mock WeatherRequest for testing.""" |
27 | | - |
28 | | - city: str |
29 | | - activities: List[str] |
30 | | - extra: Optional[Dict[str, Any]] = None |
31 | | - |
32 | | - |
33 | | -class TestParameterResolution: |
34 | | - """Test parameter resolution in serving context.""" |
35 | | - |
36 | | - @pytest.fixture(autouse=True) |
37 | | - def setup_serving_state(self): |
38 | | - """Set up deployment state for each test.""" |
39 | | - runtime.stop() # Ensure clean state |
40 | | - yield |
41 | | - runtime.stop() # Clean up after test |
42 | | - |
43 | | - def test_get_parameter_override_direct_only(self): |
44 | | - """Test that only direct parameters are returned (no nested extraction).""" |
45 | | - # Set up deployment state with WeatherRequest |
46 | | - request_obj = WeatherRequest( |
47 | | - city="munich", |
48 | | - activities=["sightseeing", "eating"], |
49 | | - extra={"budget": 500}, |
50 | | - ) |
51 | | - |
52 | | - snapshot = MagicMock() |
53 | | - snapshot.id = "test-snapshot" |
54 | | - |
55 | | - runtime.start( |
56 | | - request_id="test-request", |
57 | | - snapshot=snapshot, |
58 | | - parameters={ |
59 | | - "request": request_obj, |
60 | | - "country": "Germany", |
61 | | - }, |
62 | | - ) |
63 | | - |
64 | | - # Direct parameter only |
65 | | - assert runtime.get_parameter_override("country") == "Germany" |
66 | | - # Nested attributes are not extracted automatically |
67 | | - assert runtime.get_parameter_override("city") is None |
68 | | - assert runtime.get_parameter_override("activities") is None |
69 | | - assert runtime.get_parameter_override("extra") is None |
70 | | - |
71 | | - # Removed precedence test: nested extraction no longer supported |
72 | | - |
73 | | - def test_inactive_deployment_context(self): |
74 | | - """Test parameter resolution when serving is not active.""" |
75 | | - # Don't start serving context |
76 | | - assert runtime.get_parameter_override("city") is None |
77 | | - |
78 | | - def test_empty_pipeline_parameters(self): |
79 | | - """Test parameter resolution with empty pipeline parameters.""" |
80 | | - snapshot = MagicMock() |
81 | | - snapshot.id = "test-snapshot" |
82 | | - |
83 | | - runtime.start( |
84 | | - request_id="test-request", snapshot=snapshot, parameters={} |
85 | | - ) |
86 | | - |
87 | | - # Should return None when no parameters are available |
88 | | - assert runtime.get_parameter_override("city") is None |
89 | | - |
90 | | - # Removed complex object extraction test: not supported |
91 | | - |
92 | | - |
93 | | -class TestCompleteParameterFlow: |
94 | | - """Test complete parameter flow from request to step execution.""" |
95 | | - |
96 | | - @pytest.fixture(autouse=True) |
97 | | - def setup_serving_state(self): |
98 | | - """Set up deployment state for each test.""" |
99 | | - runtime.stop() |
100 | | - yield |
101 | | - runtime.stop() |
102 | | - |
103 | | - @pytest.fixture |
104 | | - def mock_pipeline_class(self): |
105 | | - """Mock pipeline class with WeatherRequest signature.""" |
106 | | - |
107 | | - class MockWeatherPipeline: |
108 | | - @staticmethod |
109 | | - def entrypoint( |
110 | | - request: WeatherRequest = WeatherRequest( |
111 | | - city="London", |
112 | | - activities=["walking", "reading"], |
113 | | - extra={"temperature": 20}, |
114 | | - ), |
115 | | - country: str = "UK", |
116 | | - ) -> str: |
117 | | - return f"Weather for {request.city} in {country}" |
118 | | - |
119 | | - return MockWeatherPipeline |
120 | | - |
121 | | - @pytest.fixture |
122 | | - def mock_snapshot(self, mock_pipeline_class): |
123 | | - """Mock snapshot with WeatherRequest defaults.""" |
124 | | - snapshot = MagicMock() |
125 | | - snapshot.id = "test-snapshot-id" |
126 | | - snapshot.pipeline_spec = MagicMock() |
127 | | - snapshot.pipeline_spec.source = "mock.pipeline.source" |
128 | | - snapshot.pipeline_spec.parameters = { |
129 | | - "request": { |
130 | | - "city": "London", |
131 | | - "activities": ["walking", "reading"], |
132 | | - "extra": {"temperature": 20}, |
133 | | - }, |
134 | | - "country": "UK", |
135 | | - } |
136 | | - return snapshot |
137 | | - |
138 | | - def test_weather_pipeline_scenario(self): |
139 | | - """Test the exact scenario from the weather pipeline.""" |
140 | | - # This simulates the exact case: |
141 | | - # @pipeline |
142 | | - # def weather_agent_pipeline(request: WeatherRequest = ..., country: str = "UK"): |
143 | | - # weather_data = get_weather(city=request.city, country=country) |
144 | | - |
145 | | - request_obj = WeatherRequest( |
146 | | - city="munich", activities=["whatever"], extra=None |
147 | | - ) |
148 | | - |
149 | | - snapshot = MagicMock() |
150 | | - snapshot.id = "test-snapshot" |
151 | | - |
152 | | - runtime.start( |
153 | | - request_id="test-request", |
154 | | - snapshot=snapshot, |
155 | | - parameters={ |
156 | | - "request": request_obj, |
157 | | - "country": "Germany", |
158 | | - }, |
159 | | - ) |
160 | | - |
161 | | - # Simulate the get_weather step trying to resolve its parameters |
162 | | - request_param = runtime.get_parameter_override("request") |
163 | | - country_param = runtime.get_parameter_override("country") |
164 | | - |
165 | | - # These should be the values that get passed to get_weather() |
166 | | - assert isinstance(request_param, WeatherRequest) |
167 | | - assert request_param.city == "munich" |
168 | | - assert country_param == "Germany" |
169 | | - |
170 | | - # This is exactly what should happen in the serving pipeline: |
171 | | - # get_weather(city="munich", country="Germany") |
172 | | - # instead of the compiled defaults: get_weather(city="London", country="UK") |
173 | | - |
174 | | - |
175 | 23 | class TestOutputRecording: |
176 | 24 | """Test output recording and retrieval functionality.""" |
177 | 25 |
|
|
0 commit comments