Skip to content

Commit f399414

Browse files
committed
add op_exp2() and op_inner()
1 parent 89bfe53 commit f399414

File tree

3 files changed

+83
-14
lines changed

3 files changed

+83
-14
lines changed

.tether/man/keras.ops.numpy.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ einsum(subscripts, *operands)
131131
empty(shape, dtype=None)
132132
equal(x1, x2)
133133
exp(x)
134+
exp2(x)
134135
expand_dims(x, axis)
135136
expm1(x)
136137
eye(
@@ -163,6 +164,7 @@ histogram(
163164
hstack(xs)
164165
identity(n, dtype=None)
165166
imag(x)
167+
inner(x1, x2)
166168
isclose(
167169
x1,
168170
x2,

R/ops-numpy.R

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,3 +436,52 @@ keras$ops$saturate_cast(x, dtype)
436436
op_trunc <-
437437
function (x)
438438
keras$ops$trunc(x)
439+
440+
441+
#' Calculate the base-2 exponential of all elements in the input tensor.
442+
#'
443+
#' @returns
444+
#' Output tensor, element-wise base-2 exponential of `x`.
445+
#'
446+
#' @param x
447+
#' Input tensor.
448+
#'
449+
#' @export
450+
#' @family numpy ops
451+
#' @family ops
452+
#' @tether keras.ops.exp2
453+
op_exp2 <-
454+
function (x)
455+
keras$ops$exp2(x)
456+
457+
458+
#' Return the inner product of two tensors.
459+
#'
460+
#' @description
461+
#' Ordinary inner product of vectors for 1-D tensors
462+
#' (without complex conjugation), in higher dimensions
463+
#' a sum product over the last axes.
464+
#'
465+
#' Multidimensional arrays are treated as vectors by flattening
466+
#' all but their last axes. The resulting dot product is performed
467+
#' over their last axes.
468+
#'
469+
#' @returns
470+
#' Output tensor. The shape of the output is determined by
471+
#' broadcasting the shapes of `x1` and `x2` after removing
472+
#' their last axes.
473+
#'
474+
#' @param x1
475+
#' First input tensor.
476+
#'
477+
#' @param x2
478+
#' Second input tensor. The last dimension of `x1` and `x2`
479+
#' must match.
480+
#'
481+
#' @export
482+
#' @family numpy ops
483+
#' @family ops
484+
#' @tether keras.ops.inner
485+
op_inner <-
486+
function (x1, x2)
487+
keras$ops$inner(x1, x2)

tools/retether.R

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -80,27 +80,36 @@ if(FALSE) {
8080
mk_export("keras.activations.squareplus")$dump |> cat_cb()
8181
mk_export("keras.activations.tanh_shrink")$dump |> cat_cb()
8282
mk_export("keras.config.is_flash_attention_enabled")$dump |> cat_cb()
83+
mk_export("keras.losses.circle")$dump |> cat_cb()
84+
mk_export("keras.initializers.STFT")$dump |> cat_cb()
85+
mk_export("keras.metrics.ConcordanceCorrelation")$dump |> cat_cb()
86+
mk_export("keras.metrics.PearsonCorrelation")$dump |> cat_cb()
8387

8488

8589
mk_export("keras.Model.set_state_tree")$dump |> cat_cb()
8690
mk_export("keras.layers.Solarization")$dump |> cat_cb()
8791

8892
catched <- character()
8993
catch <- function(...) catched <<- c(catched, "\n\n", ...)
90-
mk_export("keras.ops.bitwise_and")$dump |> catch()
91-
mk_export("keras.ops.bitwise_invert")$dump |> catch()
92-
mk_export("keras.ops.bitwise_left_shift")$dump |> catch()
93-
mk_export("keras.ops.bitwise_not")$dump |> catch()
94-
mk_export("keras.ops.bitwise_or")$dump |> catch()
95-
mk_export("keras.ops.bitwise_right_shift")$dump |> catch()
96-
mk_export("keras.ops.bitwise_xor")$dump |> catch()
97-
mk_export("keras.ops.dot_product_attention")$dump |> catch()
98-
mk_export("keras.ops.histogram")$dump |> catch()
99-
mk_export("keras.ops.left_shift")$dump |> catch()
100-
mk_export("keras.ops.right_shift")$dump |> catch()
101-
mk_export("keras.ops.logdet")$dump |> catch()
102-
mk_export("keras.ops.saturate_cast")$dump |> catch()
103-
mk_export("keras.ops.trunc")$dump |> catch()
94+
mk_export("keras.ops.exp2")$dump |> catch()
95+
mk_export("keras.ops.inner")$dump |> catch()
96+
97+
# mk_export("keras.ops.glu")$dump |> catch()
98+
# mk_export("keras.ops.hard_shrink")$dump |> catch()
99+
# mk_export("keras.ops.hard_tanh")$dump |> catch()
100+
# mk_export("keras.ops.soft_shrink")$dump |> catch()
101+
# mk_export("keras.ops.squareplus")$dump |> catch()
102+
# mk_export("keras.ops.tanh_shrink")$dump |> catch()
103+
# mk_export("keras.ops.celu")$dump |> catch()
104+
105+
#
106+
# mk_export("keras.ops.dot_product_attention")$dump |> catch()
107+
# mk_export("keras.ops.histogram")$dump |> catch()
108+
# mk_export("keras.ops.left_shift")$dump |> catch()
109+
# mk_export("keras.ops.right_shift")$dump |> catch()
110+
# mk_export("keras.ops.logdet")$dump |> catch()
111+
# mk_export("keras.ops.saturate_cast")$dump |> catch()
112+
# mk_export("keras.ops.trunc")$dump |> catch()
104113

105114

106115
catched |> str_flatten_and_compact_lines(roxygen = TRUE) |> cat_cb()
@@ -120,3 +129,12 @@ view_vignette_adaptation_diff <- function(rmd_file) {
120129
}
121130

122131
# view_vignette_adaptation_diff("vignettes-src/writing_a_training_loop_from_scratch.Rmd")
132+
133+
add 7 new nn ops
134+
op_glu()
135+
op_hard_shrink()
136+
op_hard_tanh()
137+
op_soft_shrink()
138+
op_squareplus()
139+
op_tanh_shrink()
140+
op_celu()

0 commit comments

Comments
 (0)