2
2
// called LICENSE at the top level of the ICU4X source tree
3
3
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4
4
5
- use crate :: grapheme:: GraphemeClusterSegmenter ;
5
+ use crate :: grapheme:: GraphemeClusterSegmenterBorrowed ;
6
6
use crate :: provider:: * ;
7
7
use alloc:: vec:: Vec ;
8
8
use core:: char:: { decode_utf16, REPLACEMENT_CHARACTER } ;
@@ -14,13 +14,13 @@ use matrix::*;
14
14
15
15
// A word break iterator using an LSTM model. The whole input string has to be in the same language.
16
16
17
- struct LstmSegmenterIterator < ' s > {
17
+ pub ( super ) struct LstmSegmenterIterator < ' s , ' data > {
18
18
input : & ' s str ,
19
19
pos_utf8 : usize ,
20
- bies : BiesIterator < ' s > ,
20
+ bies : BiesIterator < ' s , ' data > ,
21
21
}
22
22
23
- impl Iterator for LstmSegmenterIterator < ' _ > {
23
+ impl Iterator for LstmSegmenterIterator < ' _ , ' _ > {
24
24
type Item = usize ;
25
25
26
26
fn next ( & mut self ) -> Option < Self :: Item > {
@@ -35,12 +35,12 @@ impl Iterator for LstmSegmenterIterator<'_> {
35
35
}
36
36
}
37
37
38
- struct LstmSegmenterIteratorUtf16 < ' s > {
39
- bies : BiesIterator < ' s > ,
38
+ pub ( super ) struct LstmSegmenterIteratorUtf16 < ' s , ' data > {
39
+ bies : BiesIterator < ' s , ' data > ,
40
40
pos : usize ,
41
41
}
42
42
43
- impl Iterator for LstmSegmenterIteratorUtf16 < ' _ > {
43
+ impl Iterator for LstmSegmenterIteratorUtf16 < ' _ , ' _ > {
44
44
type Item = usize ;
45
45
46
46
fn next ( & mut self ) -> Option < Self :: Item > {
@@ -53,24 +53,27 @@ impl Iterator for LstmSegmenterIteratorUtf16<'_> {
53
53
}
54
54
}
55
55
56
- pub ( super ) struct LstmSegmenter < ' l > {
57
- dic : ZeroMapBorrowed < ' l , PotentialUtf8 , u16 > ,
58
- embedding : MatrixZero < ' l , 2 > ,
59
- fw_w : MatrixZero < ' l , 3 > ,
60
- fw_u : MatrixZero < ' l , 3 > ,
61
- fw_b : MatrixZero < ' l , 2 > ,
62
- bw_w : MatrixZero < ' l , 3 > ,
63
- bw_u : MatrixZero < ' l , 3 > ,
64
- bw_b : MatrixZero < ' l , 2 > ,
65
- timew_fw : MatrixZero < ' l , 2 > ,
66
- timew_bw : MatrixZero < ' l , 2 > ,
67
- time_b : MatrixZero < ' l , 1 > ,
68
- grapheme : Option < & ' l RuleBreakData < ' l > > ,
56
+ pub ( super ) struct LstmSegmenter < ' data > {
57
+ dic : ZeroMapBorrowed < ' data , PotentialUtf8 , u16 > ,
58
+ embedding : MatrixZero < ' data , 2 > ,
59
+ fw_w : MatrixZero < ' data , 3 > ,
60
+ fw_u : MatrixZero < ' data , 3 > ,
61
+ fw_b : MatrixZero < ' data , 2 > ,
62
+ bw_w : MatrixZero < ' data , 3 > ,
63
+ bw_u : MatrixZero < ' data , 3 > ,
64
+ bw_b : MatrixZero < ' data , 2 > ,
65
+ timew_fw : MatrixZero < ' data , 2 > ,
66
+ timew_bw : MatrixZero < ' data , 2 > ,
67
+ time_b : MatrixZero < ' data , 1 > ,
68
+ grapheme : Option < GraphemeClusterSegmenterBorrowed < ' data > > ,
69
69
}
70
70
71
- impl < ' l > LstmSegmenter < ' l > {
71
+ impl < ' data > LstmSegmenter < ' data > {
72
72
/// Creates an `LstmSegmenter` from the LSTM model data and a grapheme cluster segmenter.
73
- pub ( super ) fn new ( lstm : & ' l LstmData < ' l > , grapheme : & ' l RuleBreakData < ' l > ) -> Self {
73
+ pub ( super ) fn new (
74
+ lstm : & ' data LstmData < ' data > ,
75
+ grapheme : GraphemeClusterSegmenterBorrowed < ' data > ,
76
+ ) -> Self {
74
77
let LstmData :: Float32 ( lstm) = lstm;
75
78
let time_w = MatrixZero :: from ( & lstm. time_w ) ;
76
79
#[ allow( clippy:: unwrap_used) ] // shape (2, 4, hunits)
@@ -94,14 +97,10 @@ impl<'l> LstmSegmenter<'l> {
94
97
}
95
98
96
99
/// Create an LSTM based break iterator for an `str` (a UTF-8 string).
97
- pub ( super ) fn segment_str ( & ' l self , input : & ' l str ) -> impl Iterator < Item = usize > + ' l {
98
- self . segment_str_p ( input)
99
- }
100
-
101
- // For unit testing as we cannot inspect the opaque type's bies
102
- fn segment_str_p ( & ' l self , input : & ' l str ) -> LstmSegmenterIterator < ' l > {
100
+ pub ( super ) fn segment_str < ' a > ( & ' a self , input : & ' a str ) -> LstmSegmenterIterator < ' a , ' data > {
103
101
let input_seq = if let Some ( grapheme) = self . grapheme {
104
- GraphemeClusterSegmenter :: new_and_segment_str ( input, grapheme)
102
+ grapheme
103
+ . segment_str ( input)
105
104
. collect :: < Vec < usize > > ( )
106
105
. windows ( 2 )
107
106
. map ( |chunk| {
@@ -139,9 +138,13 @@ impl<'l> LstmSegmenter<'l> {
139
138
}
140
139
141
140
/// Create an LSTM based break iterator for a UTF-16 string.
142
- pub ( super ) fn segment_utf16 ( & ' l self , input : & [ u16 ] ) -> impl Iterator < Item = usize > + ' l {
141
+ pub ( super ) fn segment_utf16 < ' a > (
142
+ & ' a self ,
143
+ input : & [ u16 ] ,
144
+ ) -> LstmSegmenterIteratorUtf16 < ' a , ' data > {
143
145
let input_seq = if let Some ( grapheme) = self . grapheme {
144
- GraphemeClusterSegmenter :: new_and_segment_utf16 ( input, grapheme)
146
+ grapheme
147
+ . segment_utf16 ( input)
145
148
. collect :: < Vec < usize > > ( )
146
149
. windows ( 2 )
147
150
. map ( |chunk| {
@@ -189,18 +192,18 @@ impl<'l> LstmSegmenter<'l> {
189
192
}
190
193
}
191
194
192
- struct BiesIterator < ' l > {
193
- segmenter : & ' l LstmSegmenter < ' l > ,
195
+ struct BiesIterator < ' l , ' data > {
196
+ segmenter : & ' l LstmSegmenter < ' data > ,
194
197
input_seq : core:: iter:: Enumerate < alloc:: vec:: IntoIter < u16 > > ,
195
198
h_bw : MatrixOwned < 2 > ,
196
199
curr_fw : MatrixOwned < 1 > ,
197
200
c_fw : MatrixOwned < 1 > ,
198
201
}
199
202
200
- impl < ' l > BiesIterator < ' l > {
203
+ impl < ' l , ' data > BiesIterator < ' l , ' data > {
201
204
// input_seq is a sequence of id numbers that represents grapheme clusters or code points in the input line. These ids are used later
202
205
// in the embedding layer of the model.
203
- fn new ( segmenter : & ' l LstmSegmenter < ' l > , input_seq : Vec < u16 > ) -> Self {
206
+ fn new ( segmenter : & ' l LstmSegmenter < ' data > , input_seq : Vec < u16 > ) -> Self {
204
207
let hunits = segmenter. fw_u . dim ( ) . 1 ;
205
208
206
209
// Backward LSTM
@@ -231,13 +234,13 @@ impl<'l> BiesIterator<'l> {
231
234
}
232
235
}
233
236
234
- impl ExactSizeIterator for BiesIterator < ' _ > {
237
+ impl ExactSizeIterator for BiesIterator < ' _ , ' _ > {
235
238
fn len ( & self ) -> usize {
236
239
self . input_seq . len ( )
237
240
}
238
241
}
239
242
240
- impl Iterator for BiesIterator < ' _ > {
243
+ impl Iterator for BiesIterator < ' _ , ' _ > {
241
244
type Item = bool ;
242
245
243
246
fn next ( & mut self ) -> Option < Self :: Item > {
@@ -321,6 +324,7 @@ fn compute_hc<'a>(
321
324
#[ cfg( test) ]
322
325
mod tests {
323
326
use super :: * ;
327
+ use crate :: GraphemeClusterSegmenter ;
324
328
use icu_provider:: prelude:: * ;
325
329
use serde:: Deserialize ;
326
330
@@ -357,10 +361,7 @@ mod tests {
357
361
..Default :: default ( )
358
362
} )
359
363
. unwrap ( ) ;
360
- let lstm = LstmSegmenter :: new (
361
- lstm. payload . get ( ) ,
362
- crate :: provider:: Baked :: SINGLETON_SEGMENTER_BREAK_GRAPHEME_CLUSTER_V1 ,
363
- ) ;
364
+ let lstm = LstmSegmenter :: new ( lstm. payload . get ( ) , GraphemeClusterSegmenter :: new ( ) ) ;
364
365
365
366
// Importing the test data
366
367
let test_text_data = serde_json:: from_str ( if lstm. grapheme . is_some ( ) {
@@ -376,7 +377,7 @@ mod tests {
376
377
// Testing
377
378
for test_case in & test_text. data . testcases {
378
379
let lstm_output = lstm
379
- . segment_str_p ( & test_case. unseg )
380
+ . segment_str ( & test_case. unseg )
380
381
. bies
381
382
. map ( |is_e| if is_e { 'e' } else { '?' } )
382
383
. collect :: < String > ( ) ;
0 commit comments