
can't get keras3::layer_concatenate to work #1525

@ChristelSwift

I'm trying to build a biLSTM model in which one feature has many categories, so rather than one-hot encoding it, I'd like to use an embedding layer. The other features are one-hot encoded (OHE). When I try to combine both inputs with keras3::layer_concatenate, I get an error. It only works with keras::layer_concatenate.
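For context on the embedding-vs-OHE choice: conceptually, an embedding layer is a trainable lookup table, and looking up a row gives the same result as multiplying a one-hot row by that table, without ever materialising the wide one-hot matrix. The base-R sketch below illustrates this under stated assumptions: the weight matrix `E` and the IDs are made up (in Keras the weights are learned), and R rows are 1-based whereas Keras embedding indices are 0-based.

```r
set.seed(1)
n_categories  <- 12  # number of distinct category IDs (illustrative)
embedding_dim <- 4   # length of each embedding vector (illustrative)

# Hypothetical weight matrix: one row per category. Keras would learn these.
E <- matrix(rnorm(n_categories * embedding_dim),
            nrow = n_categories, ncol = embedding_dim)

ids <- c(3L, 7L, 3L, 12L)  # example category IDs (1-based R rows)

# Embedding route: select the rows of E for each ID
emb_lookup <- E[ids, , drop = FALSE]

# One-hot route: a 12-wide 0/1 row per ID, multiplied by E
one_hot    <- diag(n_categories)[ids, , drop = FALSE]
emb_onehot <- one_hot %*% E

dim(one_hot)     # width grows with the number of categories
dim(emb_lookup)  # width is embedding_dim, however many categories there are
all.equal(emb_lookup, emb_onehot)
```

The lookup avoids the wide one-hot intermediate entirely, which is why embeddings are preferred for high-cardinality features.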

Here's a reprex:

suppressPackageStartupMessages(library(tidyverse))
library(keras3)
library(tensorflow)

# -------------------------------
# 1) Create tiny random data
# -------------------------------

set.seed(123)

n_samples <- 10
n_timesteps <- 5
n_stations_classes <- 12
embedding_dim <- 4
n_day_classes <- 7
n_hour_classes <- 24

context_dim <- n_day_classes + n_hour_classes  # 31 features

# Station IDs = integers (1..n_stations_classes)
X_station_train <- array(
  sample(1:n_stations_classes, n_samples * n_timesteps, replace = TRUE),
  dim = c(n_samples, n_timesteps)
)

X_station_val <- array(
  sample(1:n_stations_classes, n_samples * n_timesteps, replace = TRUE),
  dim = c(n_samples, n_timesteps)
)

# Context = random numeric (double) array standing in for the one-hot features (actual one-hot values are not required for this test)
X_context_train <- array(
  runif(n_samples * n_timesteps * context_dim),
  dim = c(n_samples, n_timesteps, context_dim)
)

X_context_val <- array(
  runif(n_samples * n_timesteps * context_dim),
  dim = c(n_samples, n_timesteps, context_dim)
)

cat("station dtype:", typeof(X_station_train), " dim:", dim(X_station_train), "\n")
cat("context dtype:", typeof(X_context_train), " dim:", dim(X_context_train), "\n")


# -------------------------------
# 2) Define model (tests concatenate)
# -------------------------------

input_station <- layer_input(
  shape = c(n_timesteps),
  dtype = "int32",
  name = "station_id_input"
)

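embedded_station <- input_station %>%
  layer_embedding(
    # IDs run 1..n_stations_classes and mask_zero = TRUE reserves index 0
    # for padding, so the table needs n_stations_classes + 1 rows
    input_dim = n_stations_classes + 1,
    output_dim = embedding_dim,
    mask_zero = TRUE
    # input_length is deprecated in Keras 3, so it is omitted here
  )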

input_context <- layer_input(
  shape = c(n_timesteps, context_dim),
  dtype = "float32",
  name = "context_input"
)

# This is the operation we're testing:
combined_features <- layer_concatenate(
  list(embedded_station, input_context),
  name = "concatenate_features"
)

This gives:

Error in py_call_impl(callable, call_args$unnamed, call_args$named) : 
  KeyError: 0
Run `reticulate::py_last_error()` for details.

It only works if I use keras::layer_concatenate instead.
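For reference, the Keras Concatenate layer joins its inputs along the last axis by default (axis = -1), so the (batch, 5, 4) embedding output and the (batch, 5, 31) context input are expected to combine into (batch, 5, 35). The base-R sketch below mimics that operation on plain arrays. `emb` and `ctx` are stand-ins, not Keras tensors, and the sketch does not diagnose the KeyError itself; it only shows the intended result shape.

```r
set.seed(42)
n_samples     <- 10
n_timesteps   <- 5
embedding_dim <- 4
context_dim   <- 31

# Stand-ins for the embedding output and the context input
emb <- array(runif(n_samples * n_timesteps * embedding_dim),
             dim = c(n_samples, n_timesteps, embedding_dim))
ctx <- array(runif(n_samples * n_timesteps * context_dim),
             dim = c(n_samples, n_timesteps, context_dim))

# R arrays are column-major, so the last index varies slowest: appending the
# two data vectors and setting the new dims joins them along the last axis.
combined <- array(c(emb, ctx),
                  dim = c(n_samples, n_timesteps, embedding_dim + context_dim))

dim(combined)  # 10 5 35

# The first embedding_dim features come from emb, the remaining ones from ctx
stopifnot(
  identical(combined[, , 1:embedding_dim], emb),
  identical(combined[, , embedding_dim + seq_len(context_dim)], ctx)
)
```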

Here's my session info:

> sessionInfo()
R version 4.5.0 (2025-04-11)
Platform: aarch64-apple-darwin20
Running under: macOS Sequoia 15.6.1

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.1

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: Europe/London
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] caret_7.0-1       lattice_0.22-7    patchwork_1.3.2   glue_1.8.0        readxl_1.4.5      tensorflow_2.20.0 keras3_1.4.0      RJDBC_0.2-10      rJava_1.0-11     
[10] DBI_1.2.3         aws.s3_0.3.22     lubridate_1.9.4   forcats_1.0.0     stringr_1.5.2     dplyr_1.1.4       purrr_1.1.0       readr_2.1.5       tidyr_1.3.1      
[19] tibble_3.3.0      ggplot2_4.0.0     tidyverse_2.0.0  

loaded via a namespace (and not attached):
 [1] tidyselect_1.2.1     timeDate_4051.111    farver_2.1.2         S7_0.2.0             fastmap_1.2.0        pROC_1.19.0.1        digest_0.6.37        rpart_4.1.24        
 [9] timechange_0.3.0     lifecycle_1.0.4      survival_3.8-3       magrittr_2.0.4       compiler_4.5.0       rlang_1.1.6          tools_4.5.0          data.table_1.17.8   
[17] knitr_1.50           labeling_0.4.3       curl_7.0.0           here_1.0.2           reticulate_1.43.0    aws.signature_0.6.0  plyr_1.8.9           xml2_1.4.0          
[25] RColorBrewer_1.1-3   keras_2.16.0         withr_3.0.2          stats4_4.5.0         nnet_7.3-20          grid_4.5.0           e1071_1.7-16         future_1.67.0       
[33] globals_0.18.0       scales_1.4.0         iterators_1.0.14     MASS_7.3-65          zeallot_0.2.0        cli_3.6.5            generics_0.1.4       rstudioapi_0.17.1   
[41] future.apply_1.20.0  reshape2_1.4.4       httr_1.4.7           tzdb_0.5.0           tfruns_1.5.4         proxy_0.4-27         splines_4.5.0        parallel_4.5.0      
[49] cellranger_1.1.0     base64enc_0.1-3      vctrs_0.6.5          hardhat_1.4.2        Matrix_1.7-4         jsonlite_2.0.0       hms_1.1.3            listenv_0.9.1       
[57] foreach_1.5.2        gower_1.0.2          recipes_1.3.1        parallelly_1.45.1    codetools_0.2-20     stringi_1.8.7        gtable_0.3.6         pillar_1.11.1       
[65] ipred_0.9-15         lava_1.8.2           R6_2.6.1             rprojroot_2.1.1      evaluate_1.0.5       png_0.1-8            class_7.3-23         Rcpp_1.1.0          
[73] dotty_0.1.0          nlme_3.1-168         prodlim_2025.04.28   mgcv_1.9-3           whisker_0.4.1        xfun_0.53            ModelMetrics_1.2.2.2 pkgconfig_2.0.3    
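A related check, separate from the concatenate error: with mask_zero = TRUE, index 0 of the embedding table is reserved for padding, so an embedding with input_dim = V covers indices 0..V-1 and real IDs must fall in 1..V-1. Station IDs drawn from 1..12 therefore need input_dim of at least 13. The helper `ids_fit_embedding()` below is hypothetical (it is not a keras3 function); it is a plain-R sketch for checking the data before fitting.

```r
# Hypothetical helper, not part of keras3: do all IDs fit an embedding table
# with input_dim rows? Valid indices are 0..(input_dim - 1); with
# mask_zero = TRUE, 0 means padding, so real IDs should be 1..(input_dim - 1).
ids_fit_embedding <- function(ids, input_dim, mask_zero = TRUE) {
  lowest_real_id <- if (mask_zero) 1 else 0
  all(ids >= lowest_real_id & ids <= input_dim - 1)
}

n_stations_classes <- 12
station_ids <- 1:n_stations_classes  # every possible station ID

ids_fit_embedding(station_ids, input_dim = n_stations_classes)      # ID 12 does not fit
ids_fit_embedding(station_ids, input_dim = n_stations_classes + 1)  # all IDs fit
```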
