Skip to content

Commit 8ad73f7

Browse files
authored
Add borrowed versions of segmenter types (#6395)
Fixes #5514 This does mean all the borrowed methods take the borrowed types by-move, even the larger ones, since the data then needs to be further reborrowed by the iterators.
1 parent 63f910b commit 8ad73f7

33 files changed

+818
-607
lines changed

components/segmenter/src/complex/dictionary.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,21 +138,21 @@ impl<'s> DictionaryType<'_, 's> for char {
138138

139139
pub(super) struct DictionarySegmenter<'l> {
140140
dict: &'l UCharDictionaryBreakData<'l>,
141-
grapheme: &'l RuleBreakData<'l>,
141+
grapheme: GraphemeClusterSegmenterBorrowed<'l>,
142142
}
143143

144144
impl<'l> DictionarySegmenter<'l> {
145145
pub(super) fn new(
146146
dict: &'l UCharDictionaryBreakData<'l>,
147-
grapheme: &'l RuleBreakData<'l>,
147+
grapheme: GraphemeClusterSegmenterBorrowed<'l>,
148148
) -> Self {
149149
// TODO: no way to verify trie data
150150
Self { dict, grapheme }
151151
}
152152

153153
/// Create a dictionary based break iterator for an `str` (a UTF-8 string).
154154
pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator<Item = usize> + 'l {
155-
let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_str(input, self.grapheme);
155+
let grapheme_iter = self.grapheme.segment_str(input);
156156
DictionaryBreakIterator::<char, GraphemeClusterBreakIteratorUtf8> {
157157
trie: Char16Trie::new(self.dict.trie_data.clone()),
158158
iter: input.char_indices(),
@@ -163,7 +163,7 @@ impl<'l> DictionarySegmenter<'l> {
163163

164164
/// Create a dictionary based break iterator for a UTF-16 string.
165165
pub(super) fn segment_utf16(&'l self, input: &'l [u16]) -> impl Iterator<Item = usize> + 'l {
166-
let grapheme_iter = GraphemeClusterSegmenter::new_and_segment_utf16(input, self.grapheme);
166+
let grapheme_iter = self.grapheme.segment_utf16(input);
167167
DictionaryBreakIterator::<u32, GraphemeClusterBreakIteratorUtf16> {
168168
trie: Char16Trie::new(self.dict.trie_data.clone()),
169169
iter: Utf16Indices::new(input),
@@ -177,7 +177,7 @@ impl<'l> DictionarySegmenter<'l> {
177177
#[cfg(feature = "serde")]
178178
mod tests {
179179
use super::*;
180-
use crate::{LineSegmenter, WordSegmenter};
180+
use crate::{GraphemeClusterSegmenter, LineSegmenter, WordSegmenter};
181181
use icu_provider::prelude::*;
182182

183183
#[test]
@@ -204,10 +204,8 @@ mod tests {
204204
})
205205
.unwrap();
206206
let word_segmenter = WordSegmenter::new_dictionary(Default::default());
207-
let dict_segmenter = DictionarySegmenter::new(
208-
response.payload.get(),
209-
crate::provider::Baked::SINGLETON_SEGMENTER_BREAK_GRAPHEME_CLUSTER_V1,
210-
);
207+
let dict_segmenter =
208+
DictionarySegmenter::new(response.payload.get(), GraphemeClusterSegmenter::new());
211209

212210
// Match case
213211
let s = "龟山岛龟山岛";

components/segmenter/src/complex/lstm/mod.rs

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// called LICENSE at the top level of the ICU4X source tree
33
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
44

5-
use crate::grapheme::GraphemeClusterSegmenter;
5+
use crate::grapheme::GraphemeClusterSegmenterBorrowed;
66
use crate::provider::*;
77
use alloc::vec::Vec;
88
use core::char::{decode_utf16, REPLACEMENT_CHARACTER};
@@ -14,13 +14,13 @@ use matrix::*;
1414

1515
// A word break iterator using LSTM model. Input string have to be same language.
1616

17-
struct LstmSegmenterIterator<'s> {
17+
pub(super) struct LstmSegmenterIterator<'s, 'data> {
1818
input: &'s str,
1919
pos_utf8: usize,
20-
bies: BiesIterator<'s>,
20+
bies: BiesIterator<'s, 'data>,
2121
}
2222

23-
impl Iterator for LstmSegmenterIterator<'_> {
23+
impl Iterator for LstmSegmenterIterator<'_, '_> {
2424
type Item = usize;
2525

2626
fn next(&mut self) -> Option<Self::Item> {
@@ -35,12 +35,12 @@ impl Iterator for LstmSegmenterIterator<'_> {
3535
}
3636
}
3737

38-
struct LstmSegmenterIteratorUtf16<'s> {
39-
bies: BiesIterator<'s>,
38+
pub(super) struct LstmSegmenterIteratorUtf16<'s, 'data> {
39+
bies: BiesIterator<'s, 'data>,
4040
pos: usize,
4141
}
4242

43-
impl Iterator for LstmSegmenterIteratorUtf16<'_> {
43+
impl Iterator for LstmSegmenterIteratorUtf16<'_, '_> {
4444
type Item = usize;
4545

4646
fn next(&mut self) -> Option<Self::Item> {
@@ -53,24 +53,27 @@ impl Iterator for LstmSegmenterIteratorUtf16<'_> {
5353
}
5454
}
5555

56-
pub(super) struct LstmSegmenter<'l> {
57-
dic: ZeroMapBorrowed<'l, PotentialUtf8, u16>,
58-
embedding: MatrixZero<'l, 2>,
59-
fw_w: MatrixZero<'l, 3>,
60-
fw_u: MatrixZero<'l, 3>,
61-
fw_b: MatrixZero<'l, 2>,
62-
bw_w: MatrixZero<'l, 3>,
63-
bw_u: MatrixZero<'l, 3>,
64-
bw_b: MatrixZero<'l, 2>,
65-
timew_fw: MatrixZero<'l, 2>,
66-
timew_bw: MatrixZero<'l, 2>,
67-
time_b: MatrixZero<'l, 1>,
68-
grapheme: Option<&'l RuleBreakData<'l>>,
56+
pub(super) struct LstmSegmenter<'data> {
57+
dic: ZeroMapBorrowed<'data, PotentialUtf8, u16>,
58+
embedding: MatrixZero<'data, 2>,
59+
fw_w: MatrixZero<'data, 3>,
60+
fw_u: MatrixZero<'data, 3>,
61+
fw_b: MatrixZero<'data, 2>,
62+
bw_w: MatrixZero<'data, 3>,
63+
bw_u: MatrixZero<'data, 3>,
64+
bw_b: MatrixZero<'data, 2>,
65+
timew_fw: MatrixZero<'data, 2>,
66+
timew_bw: MatrixZero<'data, 2>,
67+
time_b: MatrixZero<'data, 1>,
68+
grapheme: Option<GraphemeClusterSegmenterBorrowed<'data>>,
6969
}
7070

71-
impl<'l> LstmSegmenter<'l> {
71+
impl<'data> LstmSegmenter<'data> {
7272
/// Returns `Err` if grapheme data is required but not present
73-
pub(super) fn new(lstm: &'l LstmData<'l>, grapheme: &'l RuleBreakData<'l>) -> Self {
73+
pub(super) fn new(
74+
lstm: &'data LstmData<'data>,
75+
grapheme: GraphemeClusterSegmenterBorrowed<'data>,
76+
) -> Self {
7477
let LstmData::Float32(lstm) = lstm;
7578
let time_w = MatrixZero::from(&lstm.time_w);
7679
#[allow(clippy::unwrap_used)] // shape (2, 4, hunits)
@@ -94,14 +97,10 @@ impl<'l> LstmSegmenter<'l> {
9497
}
9598

9699
/// Create an LSTM based break iterator for an `str` (a UTF-8 string).
97-
pub(super) fn segment_str(&'l self, input: &'l str) -> impl Iterator<Item = usize> + 'l {
98-
self.segment_str_p(input)
99-
}
100-
101-
// For unit testing as we cannot inspect the opaque type's bies
102-
fn segment_str_p(&'l self, input: &'l str) -> LstmSegmenterIterator<'l> {
100+
pub(super) fn segment_str<'a>(&'a self, input: &'a str) -> LstmSegmenterIterator<'a, 'data> {
103101
let input_seq = if let Some(grapheme) = self.grapheme {
104-
GraphemeClusterSegmenter::new_and_segment_str(input, grapheme)
102+
grapheme
103+
.segment_str(input)
105104
.collect::<Vec<usize>>()
106105
.windows(2)
107106
.map(|chunk| {
@@ -139,9 +138,13 @@ impl<'l> LstmSegmenter<'l> {
139138
}
140139

141140
/// Create an LSTM based break iterator for a UTF-16 string.
142-
pub(super) fn segment_utf16(&'l self, input: &[u16]) -> impl Iterator<Item = usize> + 'l {
141+
pub(super) fn segment_utf16<'a>(
142+
&'a self,
143+
input: &[u16],
144+
) -> LstmSegmenterIteratorUtf16<'a, 'data> {
143145
let input_seq = if let Some(grapheme) = self.grapheme {
144-
GraphemeClusterSegmenter::new_and_segment_utf16(input, grapheme)
146+
grapheme
147+
.segment_utf16(input)
145148
.collect::<Vec<usize>>()
146149
.windows(2)
147150
.map(|chunk| {
@@ -189,18 +192,18 @@ impl<'l> LstmSegmenter<'l> {
189192
}
190193
}
191194

192-
struct BiesIterator<'l> {
193-
segmenter: &'l LstmSegmenter<'l>,
195+
struct BiesIterator<'l, 'data> {
196+
segmenter: &'l LstmSegmenter<'data>,
194197
input_seq: core::iter::Enumerate<alloc::vec::IntoIter<u16>>,
195198
h_bw: MatrixOwned<2>,
196199
curr_fw: MatrixOwned<1>,
197200
c_fw: MatrixOwned<1>,
198201
}
199202

200-
impl<'l> BiesIterator<'l> {
203+
impl<'l, 'data> BiesIterator<'l, 'data> {
201204
// input_seq is a sequence of id numbers that represents grapheme clusters or code points in the input line. These ids are used later
202205
// in the embedding layer of the model.
203-
fn new(segmenter: &'l LstmSegmenter<'l>, input_seq: Vec<u16>) -> Self {
206+
fn new(segmenter: &'l LstmSegmenter<'data>, input_seq: Vec<u16>) -> Self {
204207
let hunits = segmenter.fw_u.dim().1;
205208

206209
// Backward LSTM
@@ -231,13 +234,13 @@ impl<'l> BiesIterator<'l> {
231234
}
232235
}
233236

234-
impl ExactSizeIterator for BiesIterator<'_> {
237+
impl ExactSizeIterator for BiesIterator<'_, '_> {
235238
fn len(&self) -> usize {
236239
self.input_seq.len()
237240
}
238241
}
239242

240-
impl Iterator for BiesIterator<'_> {
243+
impl Iterator for BiesIterator<'_, '_> {
241244
type Item = bool;
242245

243246
fn next(&mut self) -> Option<Self::Item> {
@@ -321,6 +324,7 @@ fn compute_hc<'a>(
321324
#[cfg(test)]
322325
mod tests {
323326
use super::*;
327+
use crate::GraphemeClusterSegmenter;
324328
use icu_provider::prelude::*;
325329
use serde::Deserialize;
326330

@@ -357,10 +361,7 @@ mod tests {
357361
..Default::default()
358362
})
359363
.unwrap();
360-
let lstm = LstmSegmenter::new(
361-
lstm.payload.get(),
362-
crate::provider::Baked::SINGLETON_SEGMENTER_BREAK_GRAPHEME_CLUSTER_V1,
363-
);
364+
let lstm = LstmSegmenter::new(lstm.payload.get(), GraphemeClusterSegmenter::new());
364365

365366
// Importing the test data
366367
let test_text_data = serde_json::from_str(if lstm.grapheme.is_some() {
@@ -376,7 +377,7 @@ mod tests {
376377
// Testing
377378
for test_case in &test_text.data.testcases {
378379
let lstm_output = lstm
379-
.segment_str_p(&test_case.unseg)
380+
.segment_str(&test_case.unseg)
380381
.bies
381382
.map(|is_e| if is_e { 'e' } else { '?' })
382383
.collect::<String>();

0 commit comments

Comments
 (0)