|
1 | | - |
2 | | - |
3 | | -KerasMetricsCallback <- R6::R6Class("KerasMetricsCallback", |
4 | | - |
| 1 | +KerasMetricsCallback <- R6::R6Class( |
| 2 | + "KerasMetricsCallback", |
| 3 | + |
5 | 4 | inherit = KerasCallback, |
6 | 5 |
|
7 | 6 | public = list( |
@@ -87,7 +86,7 @@ KerasMetricsCallback <- R6::R6Class("KerasMetricsCallback", |
87 | 86 |
|
88 | 87 | # convert keras history to metrics data frame suitable for plotting |
89 | 88 | as_metrics_df = function(history) { |
90 | | - |
| 89 | + |
91 | 90 | # create metrics data frame |
92 | 91 | df <- as.data.frame(history$metrics) |
93 | 92 |
|
@@ -128,8 +127,146 @@ KerasMetricsCallback <- R6::R6Class("KerasMetricsCallback", |
128 | 127 | }, error = function(e) { |
129 | 128 | warning("Unable to log model info: ", e$message, call. = FALSE) |
130 | 129 | }) |
131 | | - |
| 130 | + |
| 131 | + } |
| 132 | + ) |
| 133 | +) |
| 134 | + |
| 135 | +KerasMetricsCallbackV2 <- R6::R6Class( |
| 136 | + "KerasMetricsCallbackV2", |
| 137 | + |
| 138 | + inherit = KerasCallback, |
| 139 | + |
| 140 | + public = list( |
| 141 | + |
| 142 | + # instance data |
| 143 | + metrics = list(), |
| 144 | + metrics_viewer = NULL, |
| 145 | + view_metrics = FALSE, |
| 146 | + |
| 147 | + initialize = function(view_metrics = FALSE) { |
| 148 | + self$view_metrics <- view_metrics |
| 149 | + }, |
| 150 | + |
| 151 | + on_train_begin = function(logs = NULL) { |
| 152 | + if (tfruns::is_run_active()) { |
| 153 | + self$write_params(self$params) |
| 154 | + self$write_model_info(self$model) |
| 155 | + } |
| 156 | + }, |
| 157 | + |
| 158 | + on_epoch_end = function(epoch, logs = NULL) { |
| 159 | + |
| 160 | + if (epoch == 0) { |
| 161 | + |
| 162 | + metric_names <- names(logs) |
| 163 | + for (metric in metric_names) |
| 164 | + self$metrics[[metric]] <- numeric() |
| 165 | + |
| 166 | + sleep <- 0.5 |
| 167 | + } else { |
| 168 | + |
| 169 | + sleep <- 0.1 |
| 170 | + |
| 171 | + } |
| 172 | + |
| 173 | + # handle metrics |
| 174 | + self$on_metrics(logs, sleep) |
| 175 | + |
| 176 | + }, |
| 177 | + |
| 178 | + on_metrics = function(logs, sleep) { |
| 179 | + |
| 180 | + # record metrics |
| 181 | + for (metric in names(self$metrics)) { |
| 182 | + # guard against metrics not yet available by using NA |
| 183 | + # when a named metrics isn't passed in 'logs' |
| 184 | + value <- logs[[metric]] |
| 185 | + if (is.null(value)) |
| 186 | + value <- NA |
| 187 | + else |
| 188 | + value <- mean(value) |
| 189 | + |
| 190 | + self$metrics[[metric]] <- c(self$metrics[[metric]], value) |
| 191 | + } |
| 192 | + |
| 193 | + # create history object and convert to metrics data frame |
| 194 | + |
| 195 | + history <- keras_training_history(self$params, self$metrics) |
| 196 | + metrics <- self$as_metrics_df(history) |
| 197 | + |
| 198 | + # view metrics if requested |
| 199 | + if (self$view_metrics) { |
| 200 | + |
| 201 | + # create the metrics_viewer or update if we already have one |
| 202 | + if (is.null(self$metrics_viewer)) { |
| 203 | + self$metrics_viewer <- tfruns::view_run_metrics(metrics) |
| 204 | + } else { |
| 205 | + tfruns::update_run_metrics(self$metrics_viewer, metrics) |
| 206 | + } |
| 207 | + |
| 208 | + # pump events |
| 209 | + Sys.sleep(sleep) |
| 210 | + } |
| 211 | + |
| 212 | + # record metrics |
| 213 | + tfruns::write_run_metadata("metrics", metrics) |
| 214 | + |
| 215 | + }, |
| 216 | + |
| 217 | + # convert keras history to metrics data frame suitable for plotting |
| 218 | + as_metrics_df = function(history) { |
| 219 | + # create metrics data frame |
| 220 | + df <- as.data.frame(history$metrics) |
| 221 | + |
| 222 | + # pad to epochs if necessary |
| 223 | + pad <- history$params$epochs - nrow(df) |
| 224 | + pad_data <- list() |
| 225 | + |
| 226 | + if (tensorflow::tf_version() < "2.2") |
| 227 | + metric_names <- history$params$metrics |
| 228 | + else |
| 229 | + metric_names <- names(history$metrics) |
| 230 | + |
| 231 | + for (metric in metric_names) |
| 232 | + pad_data[[metric]] <- rep_len(NA, pad) |
| 233 | + |
| 234 | + df <- rbind(df, pad_data) |
| 235 | + |
| 236 | + # return df |
| 237 | + df |
| 238 | + }, |
| 239 | + |
| 240 | + write_params = function(params) { |
| 241 | + properties <- list() |
| 242 | + properties$samples <- params$samples |
| 243 | + properties$validation_samples <- params$validation_samples |
| 244 | + properties$epochs <- params$epochs |
| 245 | + properties$batch_size <- params$batch_size |
| 246 | + tfruns::write_run_metadata("properties", properties) |
| 247 | + }, |
| 248 | + |
| 249 | + write_model_info = function(model) { |
| 250 | + tryCatch({ |
| 251 | + model_info <- list() |
| 252 | + model_info$model <- py_str(model, line_length = 80L) |
| 253 | + if (is.character(model$loss)) |
| 254 | + model_info$loss_function <- model$loss |
| 255 | + else if (inherits(model$loss, "python.builtin.function")) |
| 256 | + model_info$loss_function <- model$loss$`__name__` |
| 257 | + optimizer <- model$optimizer |
| 258 | + if (!is.null(optimizer)) { |
| 259 | + model_info$optimizer <- py_str(optimizer) |
| 260 | + model_info$learning_rate <- k_eval(optimizer$lr) |
| 261 | + } |
| 262 | + tfruns::write_run_metadata("properties", model_info) |
| 263 | + }, error = function(e) { |
| 264 | + warning("Unable to log model info: ", e$message, call. = FALSE) |
| 265 | + }) |
| 266 | + |
132 | 267 | } |
133 | 268 | ) |
134 | 269 | ) |
135 | 270 |
|
| 271 | + |
| 272 | + |
0 commit comments