Skip to content

Commit e78ae37

Browse files
authored
Merge pull request #895 from dfalbel/feature/text_vectorization
Initial implementation of text_vectorization layers
2 parents 5d46e9e + 93652bd commit e78ae37

File tree

8 files changed

+389
-1
lines changed

8 files changed

+389
-1
lines changed

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ export(activation_softmax)
3131
export(activation_softplus)
3232
export(activation_softsign)
3333
export(activation_tanh)
34+
export(adapt)
3435
export(application_densenet)
3536
export(application_densenet121)
3637
export(application_densenet169)
@@ -106,6 +107,7 @@ export(get_layer)
106107
export(get_output_at)
107108
export(get_output_mask_at)
108109
export(get_output_shape_at)
110+
export(get_vocabulary)
109111
export(get_weights)
110112
export(hdf5_matrix)
111113
export(image_array_resize)
@@ -355,6 +357,7 @@ export(layer_spatial_dropout_1d)
355357
export(layer_spatial_dropout_2d)
356358
export(layer_spatial_dropout_3d)
357359
export(layer_subtract)
360+
export(layer_text_vectorization)
358361
export(layer_upsampling_1d)
359362
export(layer_upsampling_2d)
360363
export(layer_upsampling_3d)
@@ -438,6 +441,7 @@ export(save_model_weights_tf)
438441
export(save_text_tokenizer)
439442
export(sequences_to_matrix)
440443
export(serialize_model)
444+
export(set_vocabulary)
441445
export(set_weights)
442446
export(shape)
443447
export(skipgrams)

R/layer-methods.R

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,24 @@ as_node_index <- function(node_index) {
156156
as.integer(node_index-1)
157157
}
158158

159-
159+
#' Fits the state of the preprocessing layer to the data being passed.
160+
#'
161+
#' @param object Preprocessing layer object
162+
#' @param data The data to train on. It can be passed either as a tf.data Dataset,
163+
#' or as an R array.
164+
#' @param reset_state Optional argument specifying whether to clear the state of
165+
#' the layer at the start of the call to `adapt`, or whether to start from
166+
#' the existing state. Subclasses may choose to throw if `reset_state` is set
167+
#' to `FALSE`. `NULL` means the layer's own default is used.
168+
#'
169+
#' @export
170+
adapt <- function(object, data, reset_state = NULL) {
  # An explicitly supplied `reset_state` is forwarded to the layer unchanged.
  if (!is.null(reset_state)) {
    return(object$adapt(data, reset_state))
  }

  # Otherwise leave the argument out entirely, so that each layer falls back
  # on whatever default its own `adapt` method declares.
  object$adapt(data)
}
160177

161178

162179

R/layer-text_vectorization.R

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
#' Text vectorization layer
2+
#'
3+
#' This layer has basic options for managing text in a Keras model. It
4+
#' transforms a batch of strings (one sample = one string) into either a list of
5+
#' token indices (one sample = 1D tensor of integer token indices) or a dense
6+
#' representation (one sample = 1D tensor of float values representing data about
7+
#' the sample's tokens).
8+
#'
9+
#' The processing of each sample contains the following steps:
10+
#'
11+
#' 1) standardize each sample (usually lowercasing + punctuation stripping)
12+
#' 2) split each sample into substrings (usually words)
13+
#' 3) recombine substrings into tokens (usually ngrams)
14+
#' 4) index tokens (associate a unique int value with each token)
15+
#' 5) transform each sample using this index, either into a vector of ints or
16+
#' a dense float vector.
17+
#'
18+
#' @inheritParams layer_dense
19+
#' @param max_tokens The maximum size of the vocabulary for this layer. If `NULL`,
20+
#' there is no cap on the size of the vocabulary.
21+
#' @param standardize Optional specification for standardization to apply to the
22+
#' input text. Values can be `NULL` (no standardization),
23+
#' `"lower_and_strip_punctuation"` (lowercase and remove punctuation) or a
24+
#' Callable. Default is `"lower_and_strip_punctuation"`.
25+
#' @param split Optional specification for splitting the input text. Values can be
26+
#' `NULL` (no splitting), `"whitespace"` (split on ASCII whitespace), or
27+
#' a Callable. Default is `"whitespace"`.
28+
#' @param ngrams Optional specification for ngrams to create from the possibly-split
29+
#' input text. Values can be `NULL`, an integer or a list of integers; passing
30+
#' an integer will create ngrams up to that integer, and passing a list of
31+
#' integers will create ngrams for the specified values in the list. Passing
32+
#' `NULL` means that no ngrams will be created.
33+
#' @param output_mode Optional specification for the output of the layer. Values can
34+
#' be `"int"`, `"binary"`, `"count"` or `"tfidf"`, which control the outputs as follows:
35+
#' * "int": Outputs integer indices, one integer index per split string token.
36+
#' * "binary": Outputs a single int array per batch, of either vocab_size or
37+
#' `max_tokens` size, containing 1s in all elements where the token mapped
38+
#' to that index exists at least once in the batch item.
39+
#' * "count": As "binary", but the int array contains a count of the number of
40+
#' times the token at that index appeared in the batch item.
41+
#' * "tfidf": As "binary", but the TF-IDF algorithm is applied to find the value
42+
#' in each token slot.
43+
#' @param output_sequence_length Only valid in "int" mode. If set, the output will have
44+
#' its time dimension padded or truncated to exactly `output_sequence_length`
45+
#' values, resulting in a tensor of shape (batch_size, output_sequence_length) regardless
46+
#' of how many tokens resulted from the splitting step. Defaults to `NULL`.
47+
#' @param pad_to_max_tokens Only valid in "binary", "count", and "tfidf" modes. If `TRUE`,
48+
#' the output will have its feature axis padded to `max_tokens` even if the
49+
#' number of unique tokens in the vocabulary is less than max_tokens,
50+
#' resulting in a tensor of shape (batch_size, max_tokens) regardless of
51+
#' vocabulary size. Defaults to `TRUE`.
52+
#' @param ... Not used.
53+
#'
54+
#' @export
55+
layer_text_vectorization <- function(object, max_tokens = NULL, standardize = "lower_and_strip_punctuation",
                                     split = "whitespace", ngrams = NULL,
                                     output_mode = c("int", "binary", "count", "tfidf"),
                                     output_sequence_length = NULL, pad_to_max_tokens = TRUE,
                                     ...) {

  if (tensorflow::tf_version() < "2.1") {
    stop("Text Vectorization requires TensorFlow version >= 2.1", call. = FALSE)
  }

  # A single integer means "all ngrams up to n", whereas a collection means
  # "exactly these sizes". A list is always treated as a collection, even when
  # it has length one, so `list(3)` is not silently collapsed to the scalar 3.
  # NOTE(review): relies on `as_integer_tuple()` accepting a length-1 list the
  # same way it already accepts longer ones -- confirm against the helper.
  if (is.list(ngrams) || length(ngrams) > 1) {
    ngrams <- as_integer_tuple(ngrams)
  } else {
    ngrams <- as_nullable_integer(ngrams)
  }

  # Validates the choice and collapses the default vector to "int".
  output_mode <- match.arg(output_mode)

  args <- list(
    max_tokens = as_nullable_integer(max_tokens),
    ngrams = ngrams,
    output_mode = output_mode,
    output_sequence_length = as_nullable_integer(output_sequence_length),
    pad_to_max_tokens = pad_to_max_tokens
  )

  # `standardize` and `split` are only passed on when they differ from the
  # defaults; see https://github.com/tensorflow/tensorflow/pull/34420
  #
  # They are assigned with `args[name] <- list(value)` rather than
  # `args$name <- value`: assigning `NULL` through `$<-` removes the element
  # from the list, which would drop an explicit `NULL` ("no standardization" /
  # "no splitting") and leave the layer on its own default instead.
  if (!identical(standardize, "lower_and_strip_punctuation")) {
    args["standardize"] <- list(standardize)
  }

  if (!identical(split, "whitespace")) {
    args["split"] <- list(split)
  }

  create_layer(resolve_text_vectorization_module(), object, args)
}
88+
89+
#' Get the vocabulary for text vectorization layers
90+
#'
91+
#' @param object a text vectorization layer
92+
#'
93+
#' @seealso [set_vocabulary()]
94+
#' @export
95+
get_vocabulary <- function(object) {
  # Delegate to the layer's own `get_vocabulary` method and hand its result
  # back to the caller unchanged.
  vocabulary <- object$get_vocabulary()
  vocabulary
}
98+
99+
#' Sets vocabulary (and optionally document frequency) data for the layer
100+
#'
101+
#' This method sets the vocabulary and DF data for this layer directly, instead
102+
#' of analyzing a dataset through [adapt()]. It should be used whenever the `vocab`
103+
#' (and optionally document frequency) information is already known. If
104+
#' vocabulary data is already present in the layer, this method will either
105+
#' replace it (if `append` is set to `FALSE`) or append to it (if `append` is set
106+
#' to `TRUE`).
107+
#'
108+
#' @inheritParams get_vocabulary
109+
#' @param vocab An array of string tokens.
110+
#' @param df_data An array of document frequency data. Only necessary if the layer
111+
#' output_mode is "tfidf".
112+
#' @param oov_df_value The document frequency of the OOV token. Only necessary if
113+
#' output_mode is "tfidf". OOV data is optional when appending additional
114+
#' data in "tfidf" mode; if an OOV value is supplied it will overwrite the
115+
#' existing OOV value.
116+
#' @param append If `TRUE`, append to any existing vocabulary data; if `FALSE` (the default), overwrite it.
117+
#'
118+
#' @seealso [get_vocabulary()]
119+
#'
120+
#' @export
121+
set_vocabulary <- function(object, vocab, df_data = NULL, oov_df_value = NULL,
                           append = FALSE) {
  # `oov_df_value` defaults to `NULL` ("not supplied") rather than `FALSE`.
  # The OOV document frequency is documented as optional, and a supplied value
  # overwrites the existing one, so a `FALSE` default would always be forwarded
  # as an explicit value even when the caller never asked for one.
  #
  # Arguments are forwarded positionally, in the order
  # (vocab, df_data, oov_df_value, append).
  object$set_vocabulary(vocab, df_data, oov_df_value, append)
}
125+
126+
127+
resolve_text_vectorization_module <- function() {
  # Fail early when the installed Keras is older than the minimum version.
  if (keras_version() < "2.2.4") {
    stop("Keras >= 2.2.4 is required", call. = FALSE)
  }

  # Return the Python class that backs `layer_text_vectorization()`.
  keras$layers$experimental$preprocessing$TextVectorization
}

man/adapt.Rd

Lines changed: 22 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/get_vocabulary.Rd

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/layer_text_vectorization.Rd

Lines changed: 76 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/set_vocabulary.Rd

Lines changed: 35 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)