 from core.doc.source_document_data_orm import SourceDocumentDataORM
 from core.doc.source_document_orm import SourceDocumentORM
 from core.tag.tag_crud import crud_tag
-from core.tag.tag_orm import TagORM
+from core.tag.tag_orm import SourceDocumentTagLinkTable, TagORM
 from modules.classifier.classifier_crud import crud_classifier
 from modules.classifier.classifier_dto import (
     ClassifierCreate,
@@ -234,7 +234,8 @@ def _retrieve_and_build_dataset(
         class_ids: list[int],
         classid2labelid: dict[int, int],
         tokenizer,
-    ) -> tuple[dict[int, list[TagORM]], Dataset]:
+        use_chunking: bool,
+    ) -> tuple[dict[int, list[int]], Dataset]:
         # Find documents
         sdoc_ids = [
             sdoc.id
@@ -248,10 +249,7 @@ def _retrieve_and_build_dataset(
 
         # Get annotations
         results = (
-            db.query(
-                SourceDocumentORM,
-                TagORM,
-            )
+            db.query(SourceDocumentTagLinkTable)
             .filter(
                 SourceDocumentORM.id.in_(sdoc_ids),
                 SourceDocumentORM.tags.any(TagORM.id.in_(class_ids)),
@@ -269,12 +267,11 @@ def _retrieve_and_build_dataset(
             sdocid2data[sdoc_data.id] = sdoc_data
 
         # Group classifications by source document
-        sdoc_id2annotations: dict[int, list[TagORM]] = {
+        sdoc_id2annotation_ids: dict[int, list[int]] = {
             sdoc_id: [] for sdoc_id in sdoc_ids
         }
         for row in results:
-            doc, tag = row._tuple()
-            sdoc_id2annotations[doc.id].append(tag)
+            sdoc_id2annotation_ids[row.source_document_id].append(row.tag_id)
 
         # Create a labeled dataset
         # Every source document is part of the training data
@@ -288,8 +285,8 @@ def _retrieve_and_build_dataset(
                 continue  # skip documents without data
 
             # We only use the first annotation (if multiple exist)
-            annotations = sdoc_id2annotations[sdoc_id]
-            annotation = annotations[0].id if len(annotations) > 0 else 0
+            annotations = sdoc_id2annotation_ids[sdoc_id]
+            annotation = annotations[0] if len(annotations) > 0 else 0
             dataset.append(
                 {
                     "sdoc_id": sdoc_data.id,
@@ -303,13 +300,19 @@ def _retrieve_and_build_dataset(
 
         # Construct a tokenized huggingface dataset
         def tokenize_text(examples):
-            return tokenizer(examples["text"], truncation=True)
+            tokenized_inputs = tokenizer(
+                examples["text"],
+                truncation=not use_chunking,
+                is_split_into_words=False,
+                add_special_tokens=not use_chunking,
+            )
+            return tokenized_inputs
 
         hf_dataset = Dataset.from_list(dataset)  # type: ignore
         tokenized_hf_dataset = hf_dataset.map(tokenize_text, batched=True)
         tokenized_hf_dataset = tokenized_hf_dataset.remove_columns(["text"])
 
-        return sdoc_id2annotations, tokenized_hf_dataset
+        return sdoc_id2annotation_ids, tokenized_hf_dataset
 
     def train(
         self, db: Session, job: Job, payload: ClassifierJobInput
@@ -328,6 +331,8 @@ def train(
             raise BaseModelDoesNotExistError(parameters.base_name)
 
         tokenizer = AutoTokenizer.from_pretrained(parameters.base_name)
+        if parameters.chunk_size:
+            tokenizer.model_max_length = parameters.chunk_size
         data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
 
         job.update(
@@ -352,25 +357,66 @@ def train(
         id2label[0] = "O"
 
         # Build dataset
-        sdoc_id2annotations, dataset = self._retrieve_and_build_dataset(
+        sdoc_id2annotation_ids, dataset = self._retrieve_and_build_dataset(
             db=db,
             project_id=payload.project_id,
             tag_ids=parameters.tag_ids,
             class_ids=parameters.class_ids,
             classid2labelid=classid2labelid,
             tokenizer=tokenizer,
+            use_chunking=True,
         )
 
         # Train test split
         split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
+
+        def split_in_chunks(examples: dict):
+            chunk_len = tokenizer.model_max_length - 2
+            for key, values in examples.items():
+                if "labels" == key:
+                    continue
+                elif "attention_mask" == key:
+                    pre = [1]
+                    post = [1]
+                elif "input_ids" == key:
+                    pre = [tokenizer.added_tokens_encoder["[CLS]"]]
+                    post = [tokenizer.added_tokens_encoder["[SEP]"]]
+                else:
+                    raise ValueError(f"Unsupported {key} in batch examples dict")
+
+                result = []
+                original_labels = examples["labels"]
+                labels = []
+                for val, lab in zip(values, original_labels):
+                    chunks = [
+                        pre + val[i : i + chunk_len] + post
+                        for i in range(0, len(val), chunk_len)
+                    ]
+                    result.extend(chunks)
+                    labels.extend([lab] * len(chunks))
+                examples[key] = result
+                examples["labels"] = labels
+            return examples
+
+        train_dataset = (
+            split_dataset["train"]
+            .remove_columns(["sdoc_id"])
+            .map(split_in_chunks, batched=True)
+        )
+        val_dataset = (
+            split_dataset["test"]
+            .remove_columns(["sdoc_id"])
+            .map(split_in_chunks, batched=True)
+        )
+
         train_dataloader = DataLoader(
-            split_dataset["train"],  # type: ignore
+            train_dataset,  # type: ignore
             shuffle=True,
             collate_fn=data_collator,
             batch_size=parameters.batch_size,
         )
         val_dataloader = DataLoader(
-            split_dataset["test"],  # type: ignore
+            val_dataset,  # type: ignore
             shuffle=False,
             collate_fn=data_collator,
             batch_size=parameters.batch_size,
@@ -379,13 +425,13 @@ def train(
         # Dataset statistics (number of annotations per code)
         train_dataset_stats: dict[int, int] = {tag.id: 0 for tag in tags}
         for sdoc_id in split_dataset["train"]["sdoc_id"]:
-            for annotation in sdoc_id2annotations[sdoc_id]:
-                train_dataset_stats[annotation.id] += 1
+            for annotation in sdoc_id2annotation_ids[sdoc_id]:
+                train_dataset_stats[annotation] += 1
 
         eval_dataset_stats: dict[int, int] = {tag.id: 0 for tag in tags}
         for sdoc_id in split_dataset["test"]["sdoc_id"]:
-            for annotation in sdoc_id2annotations[sdoc_id]:
-                eval_dataset_stats[annotation.id] += 1
+            for annotation in sdoc_id2annotation_ids[sdoc_id]:
+                eval_dataset_stats[annotation] += 1
 
         # Calculate class weights
         # Count the occurrences of each label in the training set
@@ -568,13 +614,14 @@ def eval(
         job.update(current_step=2)
 
         # Build dataset
-        sdoc_id2annotations, dataset = self._retrieve_and_build_dataset(
+        sdoc_id2annotation_ids, dataset = self._retrieve_and_build_dataset(
             db=db,
             project_id=payload.project_id,
             tag_ids=parameters.tag_ids,
             class_ids=classifier.class_ids,
             classid2labelid=classid2labelid,
             tokenizer=tokenizer,
+            use_chunking=False,
         )
 
         # Build dataloader
@@ -590,8 +637,8 @@ def eval(
             tag_id: 0 for tag_id, label_id in classid2labelid.items() if label_id != 0
         }
         for sdoc_id in dataset["sdoc_id"]:
-            for annotation in sdoc_id2annotations[sdoc_id]:
-                eval_dataset_stats[annotation.id] += 1
+            for annotation in sdoc_id2annotation_ids[sdoc_id]:
+                eval_dataset_stats[annotation] += 1
 
         # 3. Load the model
         job.update(current_step=3)
0 commit comments