@@ -87,7 +87,7 @@ def setUp(self):
87
87
self ._test_trial = trial_module .Trial (
88
88
hyperparameters = self ._test_hyperparameters ,
89
89
trial_id = "1" ,
90
- status = trial_module .TrialStatus ,
90
+ status = trial_module .TrialStatus . RUNNING ,
91
91
)
92
92
# TODO(b/170687807) Switch from using "{}".format() to f-string
93
93
self ._job_id = "{}_{}" .format (self ._study_id , self ._test_trial .trial_id )
@@ -253,7 +253,8 @@ def test_create_trial_after_early_stopping(self):
253
253
self .assertEqual (trial .hyperparameters .values , {})
254
254
self .assertEqual (trial .status , trial_module .TrialStatus .STOPPED )
255
255
256
- def test_update_trial (self ):
256
+ @mock .patch .object (oracle_module .Oracle , "update_trial" , auto_spec = True )
257
+ def test_update_trial (self , mock_super_update_trial ):
257
258
self ._tuner_with_hparams ()
258
259
259
260
self .mock_client .should_trial_stop .return_value = True
@@ -277,6 +278,9 @@ def test_update_trial(self):
277
278
)
278
279
self .mock_client .should_trial_stop .assert_called_once_with ("1" )
279
280
self .assertEqual (status , trial_module .TrialStatus .STOPPED )
281
+ mock_super_update_trial .assert_called_once_with (
282
+ "1" , {"val_acc" : 0.8 }, 3
283
+ )
280
284
281
285
def test_end_trial_success (self ):
282
286
self ._tuner_with_hparams ()
@@ -285,17 +289,29 @@ def test_end_trial_success(self):
285
289
"state" : "COMPLETED" ,
286
290
"parameters" : [{"parameter" : "learning_rate" , "floatValue" : 0.01 }],
287
291
"finalMeasurement" : {
288
- "stepCount" : 3 ,
292
+ "stepCount" : "3" ,
289
293
"metrics" : [{"metric" : "val_acc" , "value" : 0.7 }],
290
294
},
291
295
"trial_infeasible" : False ,
292
296
"infeasible_reason" : None ,
293
297
}
294
-
298
+ mock_save_trial = mock .Mock ()
299
+ self .tuner .oracle ._save_trial = mock_save_trial
295
300
self .tuner .oracle .ongoing_trials = {"tuner_0" : self ._test_trial }
301
+ expected_trial = trial_module .Trial (
302
+ hyperparameters = self ._test_hyperparameters ,
303
+ trial_id = "1" ,
304
+ status = trial_module .TrialStatus .COMPLETED ,
305
+ )
306
+ expected_trial .best_step = 3
307
+ expected_trial .score = 0.7
308
+
296
309
self .tuner .oracle .end_trial (trial_id = "1" )
310
+
297
311
self .mock_client .complete_trial .assert_called_once_with (
298
312
"1" , False , None )
313
+ self .assertEqual (repr (mock_save_trial .call_args [0 ][0 ].get_state ()),
314
+ repr (expected_trial .get_state ()))
299
315
300
316
def test_end_trial_infeasible_trial (self ):
301
317
self ._tuner_with_hparams ()
@@ -319,35 +335,6 @@ def test_end_trial_invalid_status(self):
319
335
with self .assertRaises (ValueError ):
320
336
self .tuner .oracle .end_trial (trial_id = "1" , status = "FOO" )
321
337
322
- def test_get_trial_success (self ):
323
- self ._tuner_with_hparams ()
324
- self .mock_client .get_trial .return_value = {
325
- "name" : "1" ,
326
- "state" : "COMPLETED" ,
327
- "parameters" : [{"parameter" : "learning_rate" , "floatValue" : 0.01 }],
328
- "finalMeasurement" : {
329
- "stepCount" : 3 ,
330
- "metrics" : [{"metric" : "val_acc" , "value" : 0.7 }],
331
- },
332
- "trial_infeasible" : False ,
333
- "infeasible_reason" : None ,
334
- }
335
- trial = self .tuner .oracle .get_trial (trial_id = "1" )
336
- self .mock_client .get_trial .assert_called_once_with ("1" )
337
- self .assertEqual (trial .trial_id , "1" )
338
- self .assertEqual (trial .score , 0.7 )
339
- self .assertEqual (trial .status , trial_module .TrialStatus .COMPLETED )
340
- self .assertEqual (trial .hyperparameters .values , {"learning_rate" : 0.01 })
341
-
342
- def test_get_trial_failed (self ):
343
- self ._tuner_with_hparams ()
344
- self .mock_client .get_trial .return_value = {
345
- "name" : "1" ,
346
- "state" : "FOO"
347
- }
348
- with self .assertRaises (ValueError ):
349
- self .tuner .oracle .get_trial (trial_id = "1" )
350
-
351
338
def test_get_best_trials (self ):
352
339
self ._tuner_with_hparams ()
353
340
@@ -358,7 +345,7 @@ def test_get_best_trials(self):
358
345
"parameters" :
359
346
[{"parameter" : "learning_rate" , "floatValue" : 0.01 }],
360
347
"finalMeasurement" : {
361
- "stepCount" : 3 ,
348
+ "stepCount" : "3" ,
362
349
"metrics" : [{"metric" : "val_acc" , "value" : 0.7 }],
363
350
},
364
351
"trial_infeasible" : False ,
@@ -370,7 +357,7 @@ def test_get_best_trials(self):
370
357
"parameters" :
371
358
[{"parameter" : "learning_rate" , "floatValue" : 0.001 }],
372
359
"finalMeasurement" : {
373
- "stepCount" : 3 ,
360
+ "stepCount" : "3" ,
374
361
"metrics" : [{"metric" : "val_acc" , "value" : 0.9 }],
375
362
},
376
363
"trial_infeasible" : False ,
@@ -425,7 +412,7 @@ def test_get_best_trials_multi_tuners(self):
425
412
"parameters" :
426
413
[{"parameter" : "learning_rate" , "floatValue" : 0.01 }],
427
414
"finalMeasurement" : {
428
- "stepCount" : 3 ,
415
+ "stepCount" : "3" ,
429
416
"metrics" : [{"metric" : "val_acc" , "value" : 0.7 }],
430
417
},
431
418
"trial_infeasible" : False ,
@@ -437,7 +424,7 @@ def test_get_best_trials_multi_tuners(self):
437
424
"parameters" :
438
425
[{"parameter" : "learning_rate" , "floatValue" : 0.001 }],
439
426
"finalMeasurement" : {
440
- "stepCount" : 3 ,
427
+ "stepCount" : "3" ,
441
428
"metrics" : [{"metric" : "val_acc" , "value" : 0.9 }],
442
429
},
443
430
"trial_infeasible" : False ,
@@ -458,6 +445,11 @@ def test_get_best_trials_multi_tuners(self):
458
445
self .assertEqual (best_trials_1 [0 ].score , 0.9 )
459
446
self .assertEqual (best_trials_1 [0 ].best_step , 3 )
460
447
448
+ def test_get_single_objective (self ):
449
+ self ._tuner_with_hparams ()
450
+ self .assertEqual ([self .tuner .oracle .objective ],
451
+ self .tuner .oracle ._get_objective ())
452
+
461
453
@mock .patch .object (super_tuner .Tuner , "__init__" , auto_spec = True )
462
454
@mock .patch .object (tf .summary , "create_file_writer" , auto_spec = True )
463
455
@mock .patch .object (hparams_api , "hparams" , auto_spec = True )
0 commit comments