Skip to content

Commit 101cc41

Browse files
committed
new text_dataset_from_directory()
1 parent 3c40170 commit 101cc41

File tree

4 files changed

+196
-0
lines changed

4 files changed

+196
-0
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ export(shape)
579579
export(skipgrams)
580580
export(tensorboard)
581581
export(test_on_batch)
582+
export(text_dataset_from_directory)
582583
export(text_hashing_trick)
583584
export(text_one_hot)
584585
export(text_to_word_sequence)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
- `layer_stacked_rnn_cells()`
1414
To learn more, including how to make a custom cell layer, see the new vignette:
1515
"Working with RNNs".
16+
17+
- New dataset loader `text_dataset_from_directory()`.
1618

1719
- New layers:
1820
- `layer_additive_attention()`

R/preprocessing.R

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,3 +1050,99 @@ image_dataset_from_directory <- function(
10501050
class(out) <- c("tf_dataset", class(out))
10511051
out
10521052
}
1053+
1054+
#' Generate a `tf.data.Dataset` from text files in a directory
1055+
#'
1056+
#' @details
1057+
#' If your directory structure is:
1058+
#'
1059+
#' ```
1060+
#' main_directory/
1061+
#' ...class_a/
1062+
#' ......a_text_1.txt
1063+
#' ......a_text_2.txt
1064+
#' ...class_b/
1065+
#' ......b_text_1.txt
1066+
#' ......b_text_2.txt
1067+
#' ```
1068+
#'
1069+
#' Then calling `text_dataset_from_directory(main_directory, labels = 'inferred')`
1070+
#' will return a `tf.data.Dataset` that yields batches of texts from
1071+
#' the subdirectories `class_a` and `class_b`, together with labels
1072+
#' 0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).
1073+
#'
1074+
#' Only `.txt` files are supported at this time.
1075+
#'
1076+
#' @param directory Directory where the data is located.
1077+
#' If `labels` is "inferred", it should contain
1078+
#' subdirectories, each containing text files for a class.
1079+
#' Otherwise, the directory structure is ignored.
1080+
#'
1081+
#' @param labels Either "inferred"
1082+
#' (labels are generated from the directory structure),
1083+
#' NULL (no labels),
1084+
#' or a list of integer labels of the same size as the number of
1085+
#' text files found in the directory. Labels should be sorted according
1086+
#' to the alphanumeric order of the text file paths
1087+
#' (obtained via `os.walk(directory)` in Python).
1088+
#'
1089+
#' @param label_mode - `'int'`: means that the labels are encoded as integers
1090+
#' (e.g. for `sparse_categorical_crossentropy` loss).
1091+
#' - `'categorical'` means that the labels are
1092+
#' encoded as a categorical vector
1093+
#' (e.g. for `categorical_crossentropy` loss).
1094+
#' - `'binary'` means that the labels (there can be only 2)
1095+
#' are encoded as `float32` scalars with values 0 or 1
1096+
#' (e.g. for `binary_crossentropy`).
1097+
#' - `NULL` (no labels).
1098+
#'
1099+
#' @param class_names Only valid if `labels` is `"inferred"`. This is the explicit
1100+
#' list of class names (must match names of subdirectories). Used
1101+
#' to control the order of the classes
1102+
#' (otherwise alphanumerical order is used).
1103+
#'
1104+
#' @param batch_size Size of the batches of data. Default: `32`.
1105+
#'
1106+
#' @param max_length Maximum size of a text string. Texts longer than this will
1107+
#' be truncated to `max_length`.
1108+
#'
1109+
#' @param shuffle Whether to shuffle the data. Default: `TRUE`.
1110+
#' If set to `FALSE`, sorts the data in alphanumeric order.
1111+
#'
1112+
#' @param seed Optional random seed for shuffling and transformations.
1113+
#'
1114+
#' @param validation_split Optional float between 0 and 1,
1115+
#' fraction of data to reserve for validation.
1116+
#'
1117+
#' @param subset One of "training" or "validation".
1118+
#' Only used if `validation_split` is set.
1119+
#'
1120+
#' @param follow_links Whether to visits subdirectories pointed to by symlinks.
1121+
#' Defaults to `FALSE`.
1122+
#'
1123+
#' @param ... For future compatibility (unused presently).
1124+
#'
1125+
#' @seealso
1126+
#' + <https://www.tensorflow.org/api_docs/python/tf/keras/utils/text_dataset_from_directory>
1127+
#'
1128+
#' @export
1129+
text_dataset_from_directory <-
1130+
function(directory,
1131+
labels = "inferred",
1132+
label_mode = "int",
1133+
class_names = NULL,
1134+
batch_size = 32L,
1135+
max_length = NULL,
1136+
shuffle = TRUE,
1137+
seed = NULL,
1138+
validation_split = NULL,
1139+
subset = NULL,
1140+
follow_links = FALSE
1141+
)
1142+
{
1143+
args <- capture_args(match.call(),
1144+
list(batch_size = as.integer,
1145+
max_length = as_nullable_integer,
1146+
seed = as_nullable_integer))
1147+
do.call(keras$preprocessing$text_dataset_from_directory, args)
1148+
}

man/text_dataset_from_directory.Rd

Lines changed: 97 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)