@@ -153,7 +153,9 @@ def _mock_list_hyperparameters(
153
153
):
154
154
return self ._hyperparameters
155
155
156
- def _experiment_from_metadata (self , * , include_metrics = True ):
156
+ def _experiment_from_metadata (
157
+ self , * , include_metrics = True , hparams_limit = None
158
+ ):
157
159
"""Calls the expected operations for generating an Experiment proto."""
158
160
ctxt = backend_context .Context (self ._mock_tb_context )
159
161
request_ctx = context .RequestContext ()
@@ -162,7 +164,10 @@ def _experiment_from_metadata(self, *, include_metrics=True):
162
164
"123" ,
163
165
include_metrics ,
164
166
ctxt .hparams_metadata (request_ctx , "123" ),
165
- ctxt .hparams_from_data_provider (request_ctx , "123" , limit = None ),
167
+ ctxt .hparams_from_data_provider (
168
+ request_ctx , "123" , limit = hparams_limit
169
+ ),
170
+ hparams_limit ,
166
171
)
167
172
168
173
def test_experiment_with_experiment_tag (self ):
@@ -897,6 +902,178 @@ def test_experiment_from_data_provider_old_response_type(self):
897
902
"""
898
903
self .assertProtoEquals (expected_exp , actual_exp )
899
904
905
+ def test_experiment_from_tags_with_hparams_limit_no_differed_hparams (self ):
906
+ experiment = """
907
+ name: 'Test experiment'
908
+ hparam_infos: {
909
+ name: 'batch_size'
910
+ type: DATA_TYPE_FLOAT64
911
+ differs: false
912
+ }
913
+ hparam_infos: {
914
+ name: 'lr'
915
+ type: DATA_TYPE_FLOAT64
916
+ differs: false
917
+ }
918
+ hparam_infos: {
919
+ name: 'use_batch_norm'
920
+ type: DATA_TYPE_BOOL
921
+ differs: false
922
+ }
923
+ hparam_infos: {
924
+ name: 'model_type'
925
+ type: DATA_TYPE_STRING
926
+ differs: false
927
+ }
928
+ """
929
+ t = provider .TensorTimeSeries (
930
+ max_step = 0 ,
931
+ max_wall_time = 0 ,
932
+ plugin_content = self ._serialized_plugin_data (
933
+ DATA_TYPE_EXPERIMENT , experiment
934
+ ),
935
+ description = "" ,
936
+ display_name = "" ,
937
+ )
938
+ self ._mock_tb_context .data_provider .list_tensors .side_effect = None
939
+ self ._mock_tb_context .data_provider .list_tensors .return_value = {
940
+ "train" : {metadata .EXPERIMENT_TAG : t }
941
+ }
942
+ expected_exp = """
943
+ name: 'Test experiment'
944
+ hparam_infos: {
945
+ name: 'batch_size'
946
+ type: DATA_TYPE_FLOAT64
947
+ differs: false
948
+ }
949
+ hparam_infos: {
950
+ name: 'lr'
951
+ type: DATA_TYPE_FLOAT64
952
+ differs: false
953
+ }
954
+ """
955
+ actual_exp = self ._experiment_from_metadata (
956
+ include_metrics = False , hparams_limit = 2
957
+ )
958
+ self .assertProtoEquals (expected_exp , actual_exp )
959
+
960
+ def test_experiment_from_tags_with_hparams_limit_returns_differed_hparams_first (
961
+ self ,
962
+ ):
963
+ experiment = """
964
+ name: 'Test experiment'
965
+ hparam_infos: {
966
+ name: 'batch_size'
967
+ type: DATA_TYPE_FLOAT64
968
+ differs: false
969
+ }
970
+ hparam_infos: {
971
+ name: 'lr'
972
+ type: DATA_TYPE_FLOAT64
973
+ differs: true
974
+ }
975
+ hparam_infos: {
976
+ name: 'use_batch_norm'
977
+ type: DATA_TYPE_BOOL
978
+ differs: false
979
+ }
980
+ hparam_infos: {
981
+ name: 'model_type'
982
+ type: DATA_TYPE_STRING
983
+ differs: true
984
+ }
985
+ """
986
+ t = provider .TensorTimeSeries (
987
+ max_step = 0 ,
988
+ max_wall_time = 0 ,
989
+ plugin_content = self ._serialized_plugin_data (
990
+ DATA_TYPE_EXPERIMENT , experiment
991
+ ),
992
+ description = "" ,
993
+ display_name = "" ,
994
+ )
995
+ self ._mock_tb_context .data_provider .list_tensors .side_effect = None
996
+ self ._mock_tb_context .data_provider .list_tensors .return_value = {
997
+ "train" : {metadata .EXPERIMENT_TAG : t }
998
+ }
999
+ expected_exp = """
1000
+ name: 'Test experiment'
1001
+ hparam_infos: {
1002
+ name: 'lr'
1003
+ type: DATA_TYPE_FLOAT64
1004
+ differs: true
1005
+ },
1006
+ hparam_infos: {
1007
+ name: 'model_type'
1008
+ type: DATA_TYPE_STRING
1009
+ differs: true
1010
+ }
1011
+ """
1012
+ actual_exp = self ._experiment_from_metadata (
1013
+ include_metrics = False , hparams_limit = 2
1014
+ )
1015
+ self .assertProtoEquals (expected_exp , actual_exp )
1016
+
1017
+ def test_experiment_from_tags_sorts_differed_hparams_first (self ):
1018
+ experiment = """
1019
+ name: 'Test experiment'
1020
+ hparam_infos: {
1021
+ name: 'batch_size'
1022
+ type: DATA_TYPE_FLOAT64
1023
+ differs: false
1024
+ }
1025
+ hparam_infos: {
1026
+ name: 'lr'
1027
+ type: DATA_TYPE_FLOAT64
1028
+ differs: true
1029
+ }
1030
+ hparam_infos: {
1031
+ name: 'use_batch_norm'
1032
+ type: DATA_TYPE_BOOL
1033
+ differs: false
1034
+ }
1035
+ hparam_infos: {
1036
+ name: 'model_type'
1037
+ type: DATA_TYPE_STRING
1038
+ differs: true
1039
+ }
1040
+ """
1041
+ t = provider .TensorTimeSeries (
1042
+ max_step = 0 ,
1043
+ max_wall_time = 0 ,
1044
+ plugin_content = self ._serialized_plugin_data (
1045
+ DATA_TYPE_EXPERIMENT , experiment
1046
+ ),
1047
+ description = "" ,
1048
+ display_name = "" ,
1049
+ )
1050
+ self ._mock_tb_context .data_provider .list_tensors .side_effect = None
1051
+ self ._mock_tb_context .data_provider .list_tensors .return_value = {
1052
+ "train" : {metadata .EXPERIMENT_TAG : t }
1053
+ }
1054
+ expected_exp = """
1055
+ name: 'Test experiment'
1056
+ hparam_infos: {
1057
+ name: 'lr'
1058
+ type: DATA_TYPE_FLOAT64
1059
+ differs: true
1060
+ }
1061
+ hparam_infos: {
1062
+ name: 'model_type'
1063
+ type: DATA_TYPE_STRING
1064
+ differs: true
1065
+ }
1066
+ hparam_infos: {
1067
+ name: 'batch_size'
1068
+ type: DATA_TYPE_FLOAT64
1069
+ differs: false
1070
+ }
1071
+ """
1072
+ actual_exp = self ._experiment_from_metadata (
1073
+ include_metrics = False , hparams_limit = None
1074
+ )
1075
+ self .assertProtoEquals (expected_exp , actual_exp )
1076
+
900
1077
def _serialized_plugin_data (self , data_oneof_field , text_protobuffer ):
901
1078
oneof_type_dict = {
902
1079
DATA_TYPE_EXPERIMENT : api_pb2 .Experiment ,
0 commit comments