@@ -899,22 +899,23 @@ def _get_next_run_config_ensemble(self,
899899 Test that get_next_run_config() creates a proper RunConfig for ensemble
900900
901901 Sets up a case where the coordinate is [1,2,4,5], which corresponds to
902- - composing model 1 max_batch_size = 2
903- - composing model 1 instance_count = 3
904- - composing model 1 concurrency = 2*3*2 = 12
905- - composing model 2 max_batch_size = 16
906- - composing model 2 instance_count = 6
907- - composing model 2 concurrency = 16*6*2 = 192
902+ - composing model A max_batch_size = 2
903+ - composing model A instance_count = 3
904+ - composing model A concurrency = 2*3*2 = 12
905+ - composing model B max_batch_size = 16
906+ - composing model B instance_count = 6
907+ - composing model B concurrency = 16*6*2 = 192
908908 - ensemble model concurrency = 12 (minimum value of [12, 192])
909909
910910 Also,
911- - sequence batching should be on for model 1
912- - dynamic batching should be on for model 2
911+ - sequence batching should be on for model A
912+ - dynamic batching should be on for model B
913+ - cpu_only should be set for model B
913914 - existing values from the base model config should persist if they aren't overwritten
914915 - existing values for perf-analyzer config should persist if they aren't overwritten
915916 """
916917
917- additional_args = []
918+ additional_args = ['--cpu-only-composing-models' , 'fake_model_B' ]
918919 if max_concurrency :
919920 additional_args .append ('--run-config-search-max-concurrency' )
920921 additional_args .append (f'{ max_concurrency } ' )
@@ -923,7 +924,7 @@ def _get_next_run_config_ensemble(self,
923924 additional_args .append (f'{ min_concurrency } ' )
924925
925926 #yapf: disable
926- expected_model_config0 = {
927+ expected_model_A_config_0 = {
927928 'cpu_only' : False ,
928929 'instanceGroup' : [{
929930 'count' : 3 ,
@@ -939,12 +940,12 @@ def _get_next_run_config_ensemble(self,
939940 }]
940941 }
941942
942- expected_model_config1 = {
943- 'cpu_only' : False ,
943+ expected_model_B_config_0 = {
944+ 'cpu_only' : True ,
944945 'dynamicBatching' : {},
945946 'instanceGroup' : [{
946947 'count' : 6 ,
947- 'kind' : 'KIND_GPU ' ,
948+ 'kind' : 'KIND_CPU ' ,
948949 }],
949950 'maxBatchSize' : 16 ,
950951 'name' : 'fake_model_B_config_0' ,
@@ -1004,15 +1005,15 @@ def _get_next_run_config_ensemble(self,
10041005
10051006 model_config = run_config .model_run_configs ()[0 ].model_config ()
10061007 perf_config = run_config .model_run_configs ()[0 ].perf_config ()
1007- composing_model_config0 = run_config .model_run_configs (
1008+ composing_model_A_config_0 = run_config .model_run_configs (
10081009 )[0 ].composing_configs ()[0 ]
1009- composing_model_config1 = run_config .model_run_configs (
1010+ composing_model_B_config_0 = run_config .model_run_configs (
10101011 )[0 ].composing_configs ()[1 ]
10111012
1012- self .assertEqual (composing_model_config0 .to_dict (),
1013- expected_model_config0 )
1014- self .assertEqual (composing_model_config1 .to_dict (),
1015- expected_model_config1 )
1013+ self .assertEqual (composing_model_A_config_0 .to_dict (),
1014+ expected_model_A_config_0 )
1015+ self .assertEqual (composing_model_B_config_0 .to_dict (),
1016+ expected_model_B_config_0 )
10161017
10171018 if max_concurrency :
10181019 self .assertEqual (perf_config ['concurrency-range' ], max_concurrency )
0 commit comments