Skip to content

Commit 3592976

Browse files
authored
Merge pull request #1430 from rstudio/fix-application_preprocess_inputs
convert input to a writeable numpy array in preprocess_inputs
2 parents 778e39f + 6944a25 commit 3592976

File tree

5 files changed

+140
-1
lines changed

5 files changed

+140
-1
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ export(image_from_array)
201201
export(image_load)
202202
export(image_smart_resize)
203203
export(image_to_array)
204+
export(imagenet_decode_predictions)
205+
export(imagenet_preprocess_input)
204206
export(initializer_constant)
205207
export(initializer_glorot_normal)
206208
export(initializer_glorot_uniform)

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ New functions:
6060

6161
- `layer_lstm()` and `layer_gru()` gain arg `use_cudnn`, default `'auto'`.
6262

63+
- Fixed an issue where `application_preprocess_inputs()` would error if supplied
64+
an R array as input.
65+
6366
- Doc improvements.
6467

6568
# keras3 0.1.0

R/applications.R

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3953,7 +3953,88 @@ list_model_names <- function() {
39533953
}
39543954

39553955
set_preprocessing_attributes <- function(object, module) {
3956-
attr(object, "preprocess_input") <- module$preprocess_input
3956+
.preprocess_input <- r_to_py(module)$preprocess_input
3957+
3958+
attr(object, "preprocess_input") <-
3959+
as.function.default(c(formals(.preprocess_input), bquote({
3960+
args <- capture_args(list(
3961+
x = function(x) {
3962+
if (!is_py_object(x))
3963+
x <- np_array(x)
3964+
if (inherits(x, "numpy.ndarray") &&
3965+
!py_bool(x$flags$writeable))
3966+
x <- x$copy()
3967+
x
3968+
}
3969+
))
3970+
do.call(.(.preprocess_input), args)
3971+
})), envir = parent.env(environment()))
3972+
39573973
attr(object, "decode_predictions") <- module$decode_predictions
39583974
object
39593975
}
3976+
3977+
3978+
#' Decodes the prediction of an ImageNet model.
3979+
#'
3980+
#' @param preds Tensor encoding a batch of predictions.
3981+
#' @param top integer, how many top-guesses to return.
3982+
#'
3983+
#' @return List of data frames with variables `class_name`, `class_description`,
3984+
#' and `score` (one data frame per sample in batch input).
3985+
#'
3986+
#' @export
3987+
#' @keywords internal
3988+
imagenet_decode_predictions <- function(preds, top = 5) {
3989+
3990+
# decode predictions
3991+
decoded <- keras$applications$imagenet_utils$decode_predictions(
3992+
preds = preds,
3993+
top = as.integer(top)
3994+
)
3995+
3996+
# convert to a list of data frames
3997+
lapply(decoded, function(x) {
3998+
m <- t(sapply(1:length(x), function(n) x[[n]]))
3999+
data.frame(class_name = as.character(m[,1]),
4000+
class_description = as.character(m[,2]),
4001+
score = as.numeric(m[,3]),
4002+
stringsAsFactors = FALSE)
4003+
})
4004+
}
4005+
4006+
4007+
#' Preprocesses a tensor or array encoding a batch of images.
4008+
#'
4009+
#' @param x Input Numpy or symbolic tensor, 3D or 4D.
4010+
#' @param data_format Data format of the image tensor/array.
4011+
#' @param mode One of "caffe", "tf", or "torch"
4012+
#' - caffe: will convert the images from RGB to BGR,
4013+
#' then will zero-center each color channel with
4014+
#' respect to the ImageNet dataset,
4015+
#' without scaling.
4016+
#' - tf: will scale pixels between -1 and 1, sample-wise.
4017+
#' - torch: will scale pixels between 0 and 1 and then
4018+
#' will normalize each channel with respect to the
4019+
#' ImageNet dataset.
4020+
#'
4021+
#' @return Preprocessed tensor or array.
4022+
#'
4023+
#' @export
4024+
#' @keywords internal
4025+
imagenet_preprocess_input <- function(x, data_format = NULL, mode = "caffe") {
4026+
args <- capture_args(list(
4027+
x = function(x) {
4028+
if (!is_py_object(x))
4029+
x <- np_array(x)
4030+
if (inherits(x, "numpy.ndarray") &&
4031+
!py_bool(x$flags$writeable))
4032+
x <- x$copy()
4033+
x
4034+
}
4035+
))
4036+
4037+
preprocess_input <- r_to_py(keras$applications$imagenet_utils)$preprocess_input
4038+
do.call(preprocess_input, args)
4039+
}
4040+

man/imagenet_decode_predictions.Rd

Lines changed: 21 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/imagenet_preprocess_input.Rd

Lines changed: 32 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)