@@ -1832,6 +1832,31 @@ def get_finished_kv_transfer(
1832
1832
scheduler_output .finished_req_ids )
1833
1833
return None , None
1834
1834
1835
+ def _build_attention_metadata (self , with_prefill , num_reqs , skip_attn ):
1836
+ if skip_attn :
1837
+ attn_metadata = None
1838
+ else :
1839
+ # TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
1840
+ attn_metadata = None
1841
+ return attn_metadata
1842
+
1843
+ def _generate_dummy_run_hidden_states (self , with_prefill ,
1844
+ is_torchair_compile , input_ids ,
1845
+ positions , attn_metadata , num_tokens ,
1846
+ intermediate_tensors , inputs_embeds ):
1847
+ maybe_converting_weight_acl_format (self .model , ACL_FORMAT_FRACTAL_ND )
1848
+ hidden_states = self .model (input_ids = input_ids ,
1849
+ positions = positions ,
1850
+ intermediate_tensors = intermediate_tensors ,
1851
+ inputs_embeds = inputs_embeds )
1852
+ if self .use_aux_hidden_state_outputs :
1853
+ hidden_states , _ = hidden_states
1854
+ else :
1855
+ hidden_states = hidden_states
1856
+ if self .use_spec_decode and isinstance (self .drafter , EagleProposer ):
1857
+ self .drafter .dummy_run (num_tokens )
1858
+ return hidden_states
1859
+
1835
1860
@torch .inference_mode ()
1836
1861
def _dummy_run (
1837
1862
self ,
@@ -1868,20 +1893,11 @@ def _dummy_run(
1868
1893
if self .is_kv_producer :
1869
1894
with_prefill = True
1870
1895
1871
- # NOTE: If torchair graph mode and not with_prefill,
1872
- # we can't skip_attn, it will cause graph recompile.
1873
- if self .torchair_graph_enabled and not with_prefill :
1874
- attn_metadata = self .attn_metadata_builder .build_torchair_graph_dummy (
1875
- num_reqs = num_reqs , num_actual_tokens = 1 )
1876
- elif skip_attn :
1877
- attn_metadata = None
1878
- else :
1879
- # TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
1880
- attn_metadata = None
1896
+ attn_metadata = self ._build_attention_metadata (with_prefill , num_reqs ,
1897
+ skip_attn )
1881
1898
1882
1899
with self .maybe_dummy_run_with_lora (self .lora_config ,
1883
1900
num_scheduled_tokens ):
1884
- model = self .model
1885
1901
if self .is_multimodal_model :
1886
1902
input_ids = None
1887
1903
inputs_embeds = self .inputs_embeds [:num_tokens ]
@@ -1917,61 +1933,10 @@ def _dummy_run(
1917
1933
in_profile_run = self .in_profile_run ,
1918
1934
num_actual_tokens = 0 ,
1919
1935
):
1920
- model_kwargs = {}
1921
- if self .torchair_graph_enabled and not with_prefill :
1922
- # Only mark static while compiling
1923
- if is_torchair_compile :
1924
- torch ._dynamo .mark_static (input_ids )
1925
- torch ._dynamo .mark_static (positions )
1926
- torch ._dynamo .mark_static (
1927
- attn_metadata .decode .block_table )
1928
- torch ._dynamo .mark_static (
1929
- attn_metadata .decode .input_positions )
1930
- torch ._dynamo .mark_static (
1931
- get_forward_context ().mc2_mask )
1932
- if hasattr (attn_metadata .decode , "sin" ):
1933
- torch ._dynamo .mark_static (attn_metadata .decode .sin )
1934
- torch ._dynamo .mark_static (attn_metadata .decode .cos )
1935
- torch ._dynamo .mark_static (attn_metadata .slot_mapping )
1936
- if self .speculative_config :
1937
- torch ._dynamo .mark_static (
1938
- attn_metadata .decode .attn_mask )
1939
- for kv in self .kv_caches :
1940
- assert isinstance (
1941
- kv , tuple ), "kv_cache must be a tuple"
1942
- torch ._dynamo .mark_static (kv [0 ])
1943
- torch ._dynamo .mark_static (kv [1 ])
1944
-
1945
- maybe_converting_weight_acl_format (self .model ,
1946
- ACL_FORMAT_FRACTAL_NZ )
1947
-
1948
- compiled_model = self ._get_torchair_lazy_compiled_model (
1949
- num_tokens )
1950
- model_kwargs ["kv_caches" ] = self .kv_caches
1951
- model_kwargs ["attn_metadata" ] = attn_metadata
1952
- hidden_states = compiled_model (
1953
- input_ids = input_ids ,
1954
- positions = positions ,
1955
- intermediate_tensors = intermediate_tensors ,
1956
- inputs_embeds = None ,
1957
- ** model_kwargs ,
1958
- )
1959
- else :
1960
- maybe_converting_weight_acl_format (self .model ,
1961
- ACL_FORMAT_FRACTAL_ND )
1962
-
1963
- hidden_states = model (
1964
- input_ids = input_ids ,
1965
- positions = positions ,
1966
- intermediate_tensors = intermediate_tensors ,
1967
- inputs_embeds = inputs_embeds )
1968
- if self .use_aux_hidden_state_outputs :
1969
- hidden_states , _ = hidden_states
1970
- else :
1971
- hidden_states = hidden_states
1972
- if self .use_spec_decode and isinstance (
1973
- self .drafter , EagleProposer ):
1974
- self .drafter .dummy_run (num_tokens )
1936
+ hidden_states = self ._generate_dummy_run_hidden_states (
1937
+ with_prefill , is_torchair_compile , input_ids , positions ,
1938
+ attn_metadata , num_tokens , intermediate_tensors ,
1939
+ inputs_embeds )
1975
1940
if self .speculative_config and self .speculative_config .method == "deepseek_mtp" :
1976
1941
assert isinstance (self .drafter , MtpProposer )
1977
1942
self .drafter .dummy_run (
0 commit comments