@@ -36,6 +36,12 @@ def generate():
3636 gen_config = config .generator
3737 model_path = config .trainer .hf_assets_path
3838
39+ # Patch model_spec to use the RL-specific parallelize function.
40+ # TODO: Switch to canonical Qwen3 parallel plan
41+ from torchtitan .experiments .rl .unified .models .parallelize import parallelize_qwen3
42+
43+ config .model_spec .parallelize_fn = parallelize_qwen3
44+
3945 # Register TorchTitan model with vLLM before engine creation
4046 from torchtitan .experiments .rl .unified .plugin import (
4147 register_model_to_vllm_model_registry ,
@@ -52,7 +58,7 @@ def generate():
5258 )
5359
5460 # Create EngineArgs from config
55- engine_args = EngineArgs (
61+ engine_kwargs = dict (
5662 # Model configuration
5763 model = model_path ,
5864 trust_remote_code = True ,
@@ -65,11 +71,12 @@ def generate():
6571 # Memory and performance
6672 gpu_memory_utilization = gen_config .gpu_memory_limit ,
6773 enforce_eager = gen_config .enforce_eager ,
68- # Seed
69- seed = gen_config .seed ,
7074 # HuggingFace overrides
7175 hf_overrides = {"architectures" : [VLLM_MODEL_NAME ]},
7276 )
77+ if gen_config .seed is not None :
78+ engine_kwargs ["seed" ] = gen_config .seed
79+ engine_args = EngineArgs (** engine_kwargs )
7380
7481 logger .debug ("Initializing LLMEngine from EngineArgs..." )
7582 engine = LLMEngine .from_engine_args (engine_args )
0 commit comments