99
1010import socket
1111import signal
12- from kepler_model .estimate .model_server_connector import make_request
12+ from kepler_model .estimate .model_server_connector import make_request , is_model_server_enabled
1313from kepler_model .estimate .archived_model import get_achived_model
1414from kepler_model .estimate .model .model import load_downloaded_model
1515from kepler_model .util .loader import get_download_output_path
@@ -64,15 +64,19 @@ def handle_request(data):
6464 if output_type .name not in loaded_model :
6565 loaded_model [output_type .name ] = dict ()
6666 output_path = ""
67- request_trainer = False
68- if power_request .trainer_name is not None :
69- if output_type .name in loaded_model and power_request .energy_source in loaded_model [output_type .name ]:
70- current_trainer = loaded_model [output_type .name ][power_request .energy_source ].trainer_name
71- request_trainer = current_trainer != power_request .trainer_name
72- if request_trainer :
73- logger .info (f"try obtaining the requesting trainer { power_request .trainer_name } (current: { current_trainer } )" )
74- if power_request .energy_source not in loaded_model [output_type .name ] or request_trainer :
67+ mismatch_trainer = False
68+ if is_model_server_enabled ():
69+ if power_request .trainer_name is not None and power_request .trainer_name :
70+ if output_type .name in loaded_model and power_request .energy_source in loaded_model [output_type .name ]:
71+ current_trainer = loaded_model [output_type .name ][power_request .energy_source ].trainer_name
72+ mismatch_trainer = current_trainer != power_request .trainer_name
73+ if mismatch_trainer :
74+ logger .info (f"try obtaining the requesting trainer { power_request .trainer_name } (current: { current_trainer } )" )
75+ if power_request .energy_source not in loaded_model [output_type .name ] or mismatch_trainer :
7576 output_path = get_download_output_path (download_path , power_request .energy_source , output_type )
77+ if mismatch_trainer and os .path .exists (output_path ):
78+ # remove existing model if mismatch
79+ shutil .rmtree (output_path )
7680 if not os .path .exists (output_path ):
7781 # try connecting to model server
7882 output_path = make_request (power_request )
@@ -87,12 +91,11 @@ def handle_request(data):
8791 logger .info (f"load model from config: { output_path } " )
8892 else :
8993 logger .info (f"load model from model server: { output_path } " )
94+
9095 loaded_item = load_downloaded_model (power_request .energy_source , output_type )
9196 if loaded_item is not None and loaded_item .estimator is not None :
9297 loaded_model [output_type .name ][power_request .energy_source ] = loaded_item
9398 logger .info (f"set model { loaded_item .model_name } for { output_type .name } ({ power_request .energy_source } )" )
94- # remove loaded model
95- shutil .rmtree (output_path )
9699
97100 model = loaded_model [output_type .name ][power_request .energy_source ]
98101 powers , msg = model .get_power (power_request .datapoint )
0 commit comments