@@ -59,6 +59,14 @@ def ensure_async_inference_works(user, create_endpoint_request, inference_payloa
5959 ensure_all_async_tasks_success (task_ids , user , return_pickled )
6060
6161
@retry(stop=stop_after_attempt(3), wait=wait_fixed(20))
def ensure_endpoint_updated(create_endpoint_request, update_endpoint_request, user):
    """Assert that an endpoint's reported state matches the update request.

    Looks up the endpoint named ``create_endpoint_request["name"]`` for
    ``user`` via ``get_model_endpoint`` and compares three reported values
    against the same-named keys of ``update_endpoint_request``: ``cpus`` and
    ``memory`` under ``resource_state``, and ``max_workers`` under
    ``deployment_state``.

    The whole lookup-and-compare is wrapped in ``@retry`` (stop after 3
    attempts, fixed wait of 20 between attempts), so a mismatch on an early
    attempt triggers a fresh lookup rather than an immediate failure.
    NOTE(review): presumably tenacity's ``retry`` with the wait in seconds --
    confirm against the module imports.

    Raises:
        AssertionError: on a mismatch within an attempt (how it surfaces
            after the final attempt depends on the ``retry`` decorator).
    """
    live_state = get_model_endpoint(create_endpoint_request["name"], user)

    # (state section, field name) pairs, verified in this order.
    expected_fields = (
        ("resource_state", "cpus"),
        ("resource_state", "memory"),
        ("deployment_state", "max_workers"),
    )
    for section, field in expected_fields:
        assert live_state[section][field] == update_endpoint_request[field]
68+
69+
6270@pytest .mark .parametrize (
6371 "create_endpoint_request,update_endpoint_request,inference_requests" ,
6472 [
@@ -99,13 +107,7 @@ def test_async_model_endpoint(
99107 ensure_n_ready_endpoints_short (1 , user )
100108
101109 print ("Checking endpoint state..." )
102- endpoint = get_model_endpoint (create_endpoint_request ["name" ], user )
103- assert endpoint ["resource_state" ]["cpus" ] == update_endpoint_request ["cpus" ]
104- assert endpoint ["resource_state" ]["memory" ] == update_endpoint_request ["memory" ]
105- assert (
106- endpoint ["deployment_state" ]["max_workers" ]
107- == update_endpoint_request ["max_workers" ]
108- )
110+ ensure_endpoint_updated (create_endpoint_request , update_endpoint_request , user )
109111
110112 time .sleep (20 )
111113
0 commit comments