I'm trying to create a biLSTM model where one feature has many categories, so rather than one-hot encoding it, I'd like to use an embedding layer. The other features are one-hot encoded (OHE). When I try to combine both inputs with `keras3::layer_concatenate()`, I get an error; it only works with `keras::layer_concatenate()`.
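For reference, here is a minimal base-R sketch (no Keras involved) of the intended feature layout, using the sizes from the reprex: the day (7 classes) and hour (24 classes) context one-hot encoded into 31 columns, plus a random stand-in for the 4-dimensional station embedding, all concatenated along the last axis. `one_hot_3d()` and `concat_last()` are hypothetical helpers written only for this illustration.

```r
# Minimal base-R sketch (no Keras): the feature layout that concatenating
# along the last axis should produce. one_hot_3d() and concat_last() are
# hypothetical helpers written for this illustration only.
set.seed(123)
n_samples <- 10
n_timesteps <- 5
embedding_dim <- 4
n_day_classes <- 7
n_hour_classes <- 24

# One-hot encode an integer matrix (samples x timesteps) with values in
# 1..n_classes into an array of shape (samples x timesteps x n_classes)
one_hot_3d <- function(idx, n_classes) {
  out <- array(0, dim = c(dim(idx), n_classes))
  for (k in seq_len(n_classes)) {
    out[, , k] <- as.numeric(idx == k)
  }
  out
}

# Concatenate 3-D arrays (same first two dims) along the last axis.
# R arrays are column-major, so the last dimension varies slowest and
# unlist() keeps each part's feature slices contiguous and in order.
concat_last <- function(...) {
  parts <- list(...)
  d <- dim(parts[[1]])
  n_features <- sum(vapply(parts, function(p) dim(p)[3], integer(1)))
  array(unlist(parts, use.names = FALSE), dim = c(d[1], d[2], n_features))
}

day  <- matrix(sample(1:n_day_classes,  n_samples * n_timesteps, replace = TRUE), n_samples)
hour <- matrix(sample(1:n_hour_classes, n_samples * n_timesteps, replace = TRUE), n_samples)

# Random stand-in for the embedding output (samples x timesteps x embedding_dim)
embedded <- array(rnorm(n_samples * n_timesteps * embedding_dim),
                  dim = c(n_samples, n_timesteps, embedding_dim))

context  <- concat_last(one_hot_3d(day, n_day_classes), one_hot_3d(hour, n_hour_classes))
combined <- concat_last(embedded, context)

dim(context)   # 10 5 31
dim(combined)  # 10 5 35
```

Per sample this gives 5 time steps by 4 + 31 = 35 features, which is the shape (plus a leading batch dimension) that the `layer_concatenate()` call in the reprex should return.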
Here's a reprex:
```r
suppressPackageStartupMessages(library(tidyverse))
library(keras3)
library(tensorflow)

# -------------------------------
# 1) Create tiny random data
# -------------------------------
set.seed(123)
n_samples <- 10
n_timesteps <- 5
n_stations_classes <- 12
embedding_dim <- 4
n_day_classes <- 7
n_hour_classes <- 24
context_dim <- n_day_classes + n_hour_classes # 31 features

# Station IDs = integers (1..n_stations_classes); 0 is left free for padding
X_station_train <- array(
  sample(1:n_stations_classes, n_samples * n_timesteps, replace = TRUE),
  dim = c(n_samples, n_timesteps)
)
X_station_val <- array(
  sample(1:n_stations_classes, n_samples * n_timesteps, replace = TRUE),
  dim = c(n_samples, n_timesteps)
)

# Context = random values in [0, 1), standing in for the one-hot day/hour features
X_context_train <- array(
  runif(n_samples * n_timesteps * context_dim),
  dim = c(n_samples, n_timesteps, context_dim)
)
X_context_val <- array(
  runif(n_samples * n_timesteps * context_dim),
  dim = c(n_samples, n_timesteps, context_dim)
)

cat("station dtype:", typeof(X_station_train), " dim:", dim(X_station_train), "\n")
cat("context dtype:", typeof(X_context_train), " dim:", dim(X_context_train), "\n")

# -------------------------------
# 2) Define model (tests concatenate)
# -------------------------------
input_station <- layer_input(
  shape = c(n_timesteps),
  dtype = "int32",
  name = "station_id_input"
)

embedded_station <- input_station %>%
  layer_embedding(
    # IDs are 1-based and mask_zero = TRUE reserves index 0,
    # so the table needs n_stations_classes + 1 rows
    input_dim = n_stations_classes + 1,
    output_dim = embedding_dim,
    mask_zero = TRUE,
    input_length = n_timesteps # deprecated in Keras 3
  )

input_context <- layer_input(
  shape = c(n_timesteps, context_dim),
  dtype = "float32",
  name = "context_input"
)

# This is the operation we're testing:
combined_features <- layer_concatenate(
  list(embedded_station, input_context),
  name = "concatenate_features"
)
```
This gives:

```
Error in py_call_impl(callable, call_args$unnamed, call_args$named) :
  KeyError: 0
Run `reticulate::py_last_error()` for details.
```
It only works if I call `keras::layer_concatenate()` instead.
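Separately from the concatenate error: with `mask_zero = TRUE`, index 0 is reserved for padding, so the embedding's `input_dim` has to be larger than the largest station ID. 1-based IDs up to `n_stations_classes` therefore need `input_dim = n_stations_classes + 1`. A minimal base-R sketch of a range check, using a hypothetical helper `check_embedding_ids()` that is not part of keras3:

```r
# Hypothetical helper (not part of keras3): check that integer IDs index
# into an embedding table with input_dim rows. Valid indices are
# 0..input_dim - 1; with mask_zero = TRUE, index 0 means "padding".
check_embedding_ids <- function(ids, input_dim) {
  ids <- as.vector(ids)
  out_of_range <- ids < 0 | ids >= input_dim
  if (any(out_of_range)) {
    stop(sprintf(
      "IDs outside [0, %d): %s", input_dim,
      paste(sort(unique(ids[out_of_range])), collapse = ", ")
    ))
  }
  invisible(TRUE)
}

n_stations_classes <- 12
ids <- matrix(c(0, 1, 5, 12), nrow = 2) # 0 = padding, 1..12 = stations

# Passes: 12 stations + the reserved padding index need 13 rows
check_embedding_ids(ids, input_dim = n_stations_classes + 1)

# Fails: with only 12 rows, ID 12 is out of range
tryCatch(
  check_embedding_ids(ids, input_dim = n_stations_classes),
  error = function(e) message(conditionMessage(e)) # prints: IDs outside [0, 12): 12
)
```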
Here's my session info:
```
> sessionInfo()
R version 4.5.0 (2025-04-11)
Platform: aarch64-apple-darwin20
Running under: macOS Sequoia 15.6.1

Matrix products: default
BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRlapack.dylib; LAPACK version 3.12.1

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: Europe/London
tzcode source: internal

attached base packages:
[1] stats graphics grDevices utils datasets methods base

other attached packages:
[1] caret_7.0-1 lattice_0.22-7 patchwork_1.3.2 glue_1.8.0 readxl_1.4.5 tensorflow_2.20.0 keras3_1.4.0 RJDBC_0.2-10 rJava_1.0-11
[10] DBI_1.2.3 aws.s3_0.3.22 lubridate_1.9.4 forcats_1.0.0 stringr_1.5.2 dplyr_1.1.4 purrr_1.1.0 readr_2.1.5 tidyr_1.3.1
[19] tibble_3.3.0 ggplot2_4.0.0 tidyverse_2.0.0

loaded via a namespace (and not attached):
[1] tidyselect_1.2.1 timeDate_4051.111 farver_2.1.2 S7_0.2.0 fastmap_1.2.0 pROC_1.19.0.1 digest_0.6.37 rpart_4.1.24
[9] timechange_0.3.0 lifecycle_1.0.4 survival_3.8-3 magrittr_2.0.4 compiler_4.5.0 rlang_1.1.6 tools_4.5.0 data.table_1.17.8
[17] knitr_1.50 labeling_0.4.3 curl_7.0.0 here_1.0.2 reticulate_1.43.0 aws.signature_0.6.0 plyr_1.8.9 xml2_1.4.0
[25] RColorBrewer_1.1-3 keras_2.16.0 withr_3.0.2 stats4_4.5.0 nnet_7.3-20 grid_4.5.0 e1071_1.7-16 future_1.67.0
[33] globals_0.18.0 scales_1.4.0 iterators_1.0.14 MASS_7.3-65 zeallot_0.2.0 cli_3.6.5 generics_0.1.4 rstudioapi_0.17.1
[41] future.apply_1.20.0 reshape2_1.4.4 httr_1.4.7 tzdb_0.5.0 tfruns_1.5.4 proxy_0.4-27 splines_4.5.0 parallel_4.5.0
[49] cellranger_1.1.0 base64enc_0.1-3 vctrs_0.6.5 hardhat_1.4.2 Matrix_1.7-4 jsonlite_2.0.0 hms_1.1.3 listenv_0.9.1
[57] foreach_1.5.2 gower_1.0.2 recipes_1.3.1 parallelly_1.45.1 codetools_0.2-20 stringi_1.8.7 gtable_0.3.6 pillar_1.11.1
[65] ipred_0.9-15 lava_1.8.2 R6_2.6.1 rprojroot_2.1.1 evaluate_1.0.5 png_0.1-8 class_7.3-23 Rcpp_1.1.0
[73] dotty_0.1.0 nlme_3.1-168 prodlim_2025.04.28 mgcv_1.9-3 whisker_0.4.1 xfun_0.53 ModelMetrics_1.2.2.2 pkgconfig_2.0.3
```