@@ -103,22 +103,21 @@ def __init__(self, log_dir, update_freq='epoch', **kwargs):
103
103
104
104
super (PruningSummaries , self ).__init__ (
105
105
log_dir = log_dir , update_freq = update_freq , ** kwargs )
106
+ if not compat .is_v1_apis (): # TF 2.X
107
+ log_dir = self .log_dir + '/metrics'
108
+ self ._file_writer = tf .summary .create_file_writer (log_dir )
106
109
107
110
def _log_pruning_metrics (self , logs , prefix , step ):
108
111
if compat .is_v1_apis ():
109
112
# Safely depend on TF 1.X private API given
110
113
# no more 1.X releases.
111
114
self ._write_custom_summaries (step , logs )
112
- else : # TF 2.X
113
- log_dir = self .log_dir + '/metrics'
114
-
115
- file_writer = tf .summary .create_file_writer (log_dir )
116
- file_writer .set_as_default ()
117
-
118
- for name , value in logs .items ():
119
- tf .summary .scalar (name , value , step = step )
115
+ else :
116
+ with self ._file_writer .as_default ():
117
+ for name , value in logs .items ():
118
+ tf .summary .scalar (name , value , step = step )
120
119
121
- file_writer .flush ()
120
+ self . _file_writer .flush ()
122
121
123
122
def on_epoch_begin (self , epoch , logs = None ):
124
123
if logs is not None :
0 commit comments