@@ -78,7 +78,8 @@ and we query results from `self$metrics` at the end to retrieve their current va
7878CustomModel <- new_model_class(
7979 " CustomModel" ,
8080 train_step = function (data ) {
81- c(x , y , sample_weight ) %<- % unpack_x_y_sample_weight(data )
81+ # unpack data into x, y, and sample_weight
82+ c(x , y = NULL , sample_weight = NULL ) %<- % data
8283
8384 with(tf $ GradientTape() %as % tape , {
8485 y_pred <- self(x , training = TRUE )
@@ -163,7 +164,8 @@ CustomModel <- new_model_class(
163164 self $ loss_fn <- loss_mean_squared_error()
164165 },
165166 train_step = function (data ) {
166- c(x , y , sample_weight ) %<- % unpack_x_y_sample_weight(data )
167+ # unpack data into x, y, and sample_weight
168+ c(x , y = NULL , sample_weight = NULL ) %<- % data
167169
168170 with(tf $ GradientTape() %as % tape , {
169171 y_pred <- self(x , training = TRUE )
@@ -235,7 +237,8 @@ it manually if you don't rely on `compile()` for losses & metrics)
235237CustomModel <- new_model_class(
236238 " CustomModel" ,
237239 train_step = function (data ) {
238- c(x , y , sample_weight ) %<- % unpack_x_y_sample_weight(data )
240+ # unpack data into x, y, and sample_weight
241+ c(x , y = NULL , sample_weight = NULL ) %<- % data
239242
240243 with(tf $ GradientTape() %as % tape , {
241244 y_pred <- self(x , training = TRUE )
@@ -300,7 +303,7 @@ CustomModel <- new_model_class(
300303 " CustomModel" ,
301304 test_step = function (data ) {
302305 # Unpack the data
303- c(x , y , sw ) %<- % unpack_x_y_sample_weight( data )
306+ c(x , y = NULL , sw = NULL ) %<- % data
304307 # Compute predictions
305308 y_pred = self(x , training = FALSE )
306309 # Updates the metrics tracking the loss
0 commit comments