Skip to content

Commit d669bfb

Browse files
committed
add test for get_vocabulary()
1 parent 45ca0fc commit d669bfb

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

R/layers-preprocessing.R

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2337,11 +2337,6 @@ function (object, max_tokens = NULL, standardize = "lower_and_strip_punctuation"
23372337
}
23382338

23392339

2340-
2341-
# TODO: add tests/ confirm that `get_vocabulary()` returns an R character
2342-
# vector. In older TF versions it used to return python byte objects, which
2343-
# needed `[x.decode("UTF-8") for x in vocab]`
2344-
23452340
#' @param include_special_tokens If TRUE, the returned vocabulary will include
23462341
#' the padding and OOV tokens, and a term's index in the vocabulary will equal
23472342
#' the term's index when calling the layer. If FALSE, the returned vocabulary

tests/testthat/test-layer-text_vectorization.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,18 @@ test_call_succeeds("can create a tf-idf layer", {
114114
expect_s3_class(x, "tensorflow.tensor")
115115

116116
})
117+
118+
119+
120+
# Regression test: after adapting a text-vectorization layer on a tiny corpus,
# `get_vocabulary()` must hand back a plain R character vector (older TF
# versions returned Python byte objects that had to be decoded by hand).
test_call_succeeds("get_vocabulary() returns R character vector", {
121+
122+
text_vectorization <- layer_text_vectorization()
123+
# NOTE(review): adapt() is run under an explicit CPU device scope --
# presumably to avoid device-placement issues with string ops; confirm.
with(tf$device("/cpu:0"), {
124+
text_vectorization %>% adapt(c("hello world", "hello"))
125+
})
126+
vocab <- get_vocabulary(text_vectorization)
127+
128+
# The result must be an R character vector, not a list of Python objects.
expect_type(vocab, "character")
129+
# Both tokens seen during adapt() must be present in the vocabulary.
expect_contains(vocab, c("hello", "world"))
130+
131+
})

0 commit comments

Comments
 (0)