77import pytest
88
99from vllm_omni .config .stage_config import (
10+ ModelPipeline ,
1011 StageConfig ,
1112 StageConfigFactory ,
12- StageTopology ,
1313 StageType ,
1414)
1515
@@ -103,8 +103,8 @@ def test_to_omegaconf_with_runtime_overrides(self):
103103 assert omega_config .runtime .max_batch_size == 64
104104
105105
106- class TestStageTopology :
107- """Tests for StageTopology class."""
106+ class TestModelPipeline :
107+ """Tests for ModelPipeline class."""
108108
109109 def test_valid_linear_dag (self ):
110110 """Test validation of a valid linear DAG."""
@@ -113,8 +113,8 @@ def test_valid_linear_dag(self):
113113 StageConfig (stage_id = 1 , model_stage = "talker" , input_sources = [0 ]),
114114 StageConfig (stage_id = 2 , model_stage = "code2wav" , input_sources = [1 ]),
115115 ]
116- topology = StageTopology (model_type = "test" , stages = stages )
117- errors = topology . validate_topology ()
116+ pipeline = ModelPipeline (model_type = "test" , stages = stages )
117+ errors = pipeline . validate_pipeline ()
118118 assert errors == [], f"Unexpected errors: { errors } "
119119
120120 def test_valid_branching_dag (self ):
@@ -124,8 +124,8 @@ def test_valid_branching_dag(self):
124124 StageConfig (stage_id = 1 , model_stage = "branch_a" , input_sources = [0 ]),
125125 StageConfig (stage_id = 2 , model_stage = "branch_b" , input_sources = [0 ]),
126126 ]
127- topology = StageTopology (model_type = "test" , stages = stages )
128- errors = topology . validate_topology ()
127+ pipeline = ModelPipeline (model_type = "test" , stages = stages )
128+ errors = pipeline . validate_pipeline ()
129129 assert errors == [], f"Unexpected errors: { errors } "
130130
131131 def test_missing_entry_point (self ):
@@ -134,8 +134,8 @@ def test_missing_entry_point(self):
134134 StageConfig (stage_id = 0 , model_stage = "stage_a" , input_sources = [1 ]),
135135 StageConfig (stage_id = 1 , model_stage = "stage_b" , input_sources = [0 ]),
136136 ]
137- topology = StageTopology (model_type = "test" , stages = stages )
138- errors = topology . validate_topology ()
137+ pipeline = ModelPipeline (model_type = "test" , stages = stages )
138+ errors = pipeline . validate_pipeline ()
139139 assert any ("entry point" in e .lower () for e in errors )
140140
141141 def test_missing_dependency (self ):
@@ -144,8 +144,8 @@ def test_missing_dependency(self):
144144 StageConfig (stage_id = 0 , model_stage = "input" , input_sources = []),
145145 StageConfig (stage_id = 1 , model_stage = "output" , input_sources = [99 ]), # Invalid
146146 ]
147- topology = StageTopology (model_type = "test" , stages = stages )
148- errors = topology . validate_topology ()
147+ pipeline = ModelPipeline (model_type = "test" , stages = stages )
148+ errors = pipeline . validate_pipeline ()
149149 assert any ("non-existent" in e .lower () for e in errors )
150150
151151 def test_duplicate_stage_ids (self ):
@@ -154,8 +154,8 @@ def test_duplicate_stage_ids(self):
154154 StageConfig (stage_id = 0 , model_stage = "stage_a" , input_sources = []),
155155 StageConfig (stage_id = 0 , model_stage = "stage_b" , input_sources = []), # Duplicate
156156 ]
157- topology = StageTopology (model_type = "test" , stages = stages )
158- errors = topology . validate_topology ()
157+ pipeline = ModelPipeline (model_type = "test" , stages = stages )
158+ errors = pipeline . validate_pipeline ()
159159 assert any ("duplicate" in e .lower () for e in errors )
160160
161161 def test_self_reference (self ):
@@ -164,8 +164,8 @@ def test_self_reference(self):
164164 StageConfig (stage_id = 0 , model_stage = "entry" , input_sources = []),
165165 StageConfig (stage_id = 1 , model_stage = "self_ref" , input_sources = [1 ]), # Self
166166 ]
167- topology = StageTopology (model_type = "test" , stages = stages )
168- errors = topology . validate_topology ()
167+ pipeline = ModelPipeline (model_type = "test" , stages = stages )
168+ errors = pipeline . validate_pipeline ()
169169 assert any ("itself" in e .lower () for e in errors )
170170
171171 def test_get_stage_by_id (self ):
@@ -174,19 +174,19 @@ def test_get_stage_by_id(self):
174174 StageConfig (stage_id = 0 , model_stage = "thinker" , input_sources = []),
175175 StageConfig (stage_id = 1 , model_stage = "talker" , input_sources = [0 ]),
176176 ]
177- topology = StageTopology (model_type = "test" , stages = stages )
177+ pipeline = ModelPipeline (model_type = "test" , stages = stages )
178178
179- stage = topology .get_stage (1 )
179+ stage = pipeline .get_stage (1 )
180180 assert stage is not None
181181 assert stage .model_stage == "talker"
182182
183- missing = topology .get_stage (99 )
183+ missing = pipeline .get_stage (99 )
184184 assert missing is None
185185
186- def test_empty_topology (self ):
187- """Test validation of empty topology ."""
188- topology = StageTopology (model_type = "test" , stages = [])
189- errors = topology . validate_topology ()
186+ def test_empty_pipeline (self ):
187+ """Test validation of empty pipeline ."""
188+ pipeline = ModelPipeline (model_type = "test" , stages = [])
189+ errors = pipeline . validate_pipeline ()
190190 assert any ("no stages" in e .lower () for e in errors )
191191
192192
@@ -281,43 +281,43 @@ def test_per_stage_override_excludes_internal_keys(self):
281281 assert "model" not in overrides
282282 assert "batch_timeout" not in overrides
283283
284- def test_all_topology_files_exist (self ):
285- """Test that every entry in TOPOLOGY_FILES has an actual YAML file."""
286- from vllm_omni .model_executor . stage_topologies import get_topology_path
284+ def test_all_pipeline_files_exist (self ):
285+ """Test that every entry in PIPELINE_DIRS has an actual pipeline.yaml file."""
286+ from vllm_omni .model_pipelines import get_pipeline_path
287287
288- for model_type , filename in StageConfigFactory .TOPOLOGY_FILES .items ():
289- path = get_topology_path ( filename )
290- assert path .exists (), f"Missing topology file for { model_type } : { path } "
288+ for model_type , dir_name in StageConfigFactory .PIPELINE_DIRS .items ():
289+ path = get_pipeline_path ( dir_name , "pipeline.yaml" )
290+ assert path .exists (), f"Missing pipeline file for { model_type } : { path } "
291291
292- @pytest .mark .parametrize ("model_type" , list (StageConfigFactory .TOPOLOGY_FILES .keys ()))
293- def test_parse_real_topology_files (self , model_type ):
294- """Test that each shipped topology YAML parses and validates correctly."""
295- from vllm_omni .model_executor . stage_topologies import get_topology_path
292+ @pytest .mark .parametrize ("model_type" , list (StageConfigFactory .PIPELINE_DIRS .keys ()))
293+ def test_parse_real_pipeline_files (self , model_type ):
294+ """Test that each shipped pipeline YAML parses and validates correctly."""
295+ from vllm_omni .model_pipelines import get_pipeline_path
296296
297- filename = StageConfigFactory .TOPOLOGY_FILES [model_type ]
298- path = get_topology_path ( filename )
299- topology = StageConfigFactory ._parse_topology_yaml (path , model_type )
297+ dir_name = StageConfigFactory .PIPELINE_DIRS [model_type ]
298+ path = get_pipeline_path ( dir_name , "pipeline.yaml" )
299+ pipeline = StageConfigFactory ._parse_pipeline_yaml (path , model_type )
300300
301301 # Basic structure
302- assert topology .model_type == model_type
303- assert len (topology .stages ) >= 1
302+ assert pipeline .model_type == model_type
303+ assert len (pipeline .stages ) >= 1
304304
305305 # Must pass validation
306- errors = topology . validate_topology ()
306+ errors = pipeline . validate_pipeline ()
307307 assert errors == [], f"{ model_type } : { errors } "
308308
309309 # Every stage must have required fields
310- for stage in topology .stages :
310+ for stage in pipeline .stages :
311311 assert isinstance (stage .stage_id , int )
312312 assert isinstance (stage .model_stage , str )
313313 assert isinstance (stage .stage_type , StageType )
314314
315315
316- class TestTopologyYamlParsing :
317- """Tests for stage topology YAML file parsing (@ZJY0516)."""
316+ class TestPipelineYamlParsing :
317+ """Tests for model pipeline YAML file parsing (@ZJY0516)."""
318318
319319 def test_parse_qwen3_omni_moe_yaml (self , tmp_path ):
320- """Test parsing the qwen3_omni_moe topology YAML."""
320+ """Test parsing the qwen3_omni_moe pipeline YAML."""
321321 yaml_content = """\
322322 model_type: qwen3_omni_moe
323323
@@ -356,13 +356,13 @@ def test_parse_qwen3_omni_moe_yaml(self, tmp_path):
356356 yaml_file = tmp_path / "qwen3_omni_moe.yaml"
357357 yaml_file .write_text (yaml_content )
358358
359- topology = StageConfigFactory ._parse_topology_yaml (yaml_file , "qwen3_omni_moe" )
359+ pipeline = StageConfigFactory ._parse_pipeline_yaml (yaml_file , "qwen3_omni_moe" )
360360
361- assert topology .model_type == "qwen3_omni_moe"
362- assert len (topology .stages ) == 3
361+ assert pipeline .model_type == "qwen3_omni_moe"
362+ assert len (pipeline .stages ) == 3
363363
364364 # Stage 0: thinker
365- s0 = topology .stages [0 ]
365+ s0 = pipeline .stages [0 ]
366366 assert s0 .stage_id == 0
367367 assert s0 .model_stage == "thinker"
368368 assert s0 .stage_type == StageType .LLM
@@ -373,7 +373,7 @@ def test_parse_qwen3_omni_moe_yaml(self, tmp_path):
373373 assert s0 .is_comprehension is True
374374
375375 # Stage 1: talker
376- s1 = topology .stages [1 ]
376+ s1 = pipeline .stages [1 ]
377377 assert s1 .stage_id == 1
378378 assert s1 .input_sources == [0 ]
379379 assert s1 .custom_process_input_func == (
@@ -382,7 +382,7 @@ def test_parse_qwen3_omni_moe_yaml(self, tmp_path):
382382 assert s1 .final_output is False
383383
384384 # Stage 2: code2wav
385- s2 = topology .stages [2 ]
385+ s2 = pipeline .stages [2 ]
386386 assert s2 .stage_id == 2
387387 assert s2 .input_sources == [1 ]
388388 assert s2 .worker_type == "generation"
@@ -405,11 +405,11 @@ def test_parse_yaml_with_legacy_engine_input_source(self, tmp_path):
405405 yaml_file = tmp_path / "legacy.yaml"
406406 yaml_file .write_text (yaml_content )
407407
408- topology = StageConfigFactory ._parse_topology_yaml (yaml_file , "legacy_model" )
409- assert topology .stages [1 ].input_sources == [0 ]
408+ pipeline = StageConfigFactory ._parse_pipeline_yaml (yaml_file , "legacy_model" )
409+ assert pipeline .stages [1 ].input_sources == [0 ]
410410
411411 def test_parse_yaml_with_connectors_and_edges (self , tmp_path ):
412- """Test parsing topology with optional connectors and edges."""
412+ """Test parsing pipeline with optional connectors and edges."""
413413 yaml_content = """\
414414 model_type: test_model
415415
@@ -429,12 +429,12 @@ def test_parse_yaml_with_connectors_and_edges(self, tmp_path):
429429 yaml_file = tmp_path / "with_connectors.yaml"
430430 yaml_file .write_text (yaml_content )
431431
432- topology = StageConfigFactory ._parse_topology_yaml (yaml_file , "test_model" )
433- assert topology .connectors == {"type" : "ray" }
434- assert topology .edges == [{"from" : 0 , "to" : 1 }]
432+ pipeline = StageConfigFactory ._parse_pipeline_yaml (yaml_file , "test_model" )
433+ assert pipeline .connectors == {"type" : "ray" }
434+ assert pipeline .edges == [{"from" : 0 , "to" : 1 }]
435435
436- def test_parsed_topology_passes_validation (self , tmp_path ):
437- """Test that a well-formed YAML produces a valid topology ."""
436+ def test_parsed_pipeline_passes_validation (self , tmp_path ):
437+ """Test that a well-formed YAML produces a valid pipeline ."""
438438 yaml_content = """\
439439 model_type: valid_model
440440
@@ -453,8 +453,8 @@ def test_parsed_topology_passes_validation(self, tmp_path):
453453 yaml_file = tmp_path / "valid.yaml"
454454 yaml_file .write_text (yaml_content )
455455
456- topology = StageConfigFactory ._parse_topology_yaml (yaml_file , "valid_model" )
457- errors = topology . validate_topology ()
456+ pipeline = StageConfigFactory ._parse_pipeline_yaml (yaml_file , "valid_model" )
457+ errors = pipeline . validate_pipeline ()
458458 assert errors == [], f"Unexpected validation errors: { errors } "
459459
460460 def test_parse_diffusion_stage_type (self , tmp_path ):
@@ -473,5 +473,5 @@ def test_parse_diffusion_stage_type(self, tmp_path):
473473 yaml_file = tmp_path / "diffusion.yaml"
474474 yaml_file .write_text (yaml_content )
475475
476- topology = StageConfigFactory ._parse_topology_yaml (yaml_file , "diff_model" )
477- assert topology .stages [0 ].stage_type == StageType .DIFFUSION
476+ pipeline = StageConfigFactory ._parse_pipeline_yaml (yaml_file , "diff_model" )
477+ assert pipeline .stages [0 ].stage_type == StageType .DIFFUSION
0 commit comments