Skip to content

Commit 58fde9f

Browse files
fynnosbigabig
authored and committed
disable chunking on inference, doc-classifier chunking
1 parent 60f878e commit 58fde9f

File tree

2 files changed

+75
-25
lines changed

2 files changed

+75
-25
lines changed

backend/src/modules/classifier/models/doc_class_model_service.py

Lines changed: 70 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from core.doc.source_document_data_orm import SourceDocumentDataORM
2626
from core.doc.source_document_orm import SourceDocumentORM
2727
from core.tag.tag_crud import crud_tag
28-
from core.tag.tag_orm import TagORM
28+
from core.tag.tag_orm import SourceDocumentTagLinkTable, TagORM
2929
from modules.classifier.classifier_crud import crud_classifier
3030
from modules.classifier.classifier_dto import (
3131
ClassifierCreate,
@@ -234,7 +234,8 @@ def _retrieve_and_build_dataset(
234234
class_ids: list[int],
235235
classid2labelid: dict[int, int],
236236
tokenizer,
237-
) -> tuple[dict[int, list[TagORM]], Dataset]:
237+
use_chunking: bool,
238+
) -> tuple[dict[int, list[int]], Dataset]:
238239
# Find documents
239240
sdoc_ids = [
240241
sdoc.id
@@ -248,10 +249,7 @@ def _retrieve_and_build_dataset(
248249

249250
# Get annotations
250251
results = (
251-
db.query(
252-
SourceDocumentORM,
253-
TagORM,
254-
)
252+
db.query(SourceDocumentTagLinkTable)
255253
.filter(
256254
SourceDocumentORM.id.in_(sdoc_ids),
257255
SourceDocumentORM.tags.any(TagORM.id.in_(class_ids)),
@@ -269,12 +267,11 @@ def _retrieve_and_build_dataset(
269267
sdocid2data[sdoc_data.id] = sdoc_data
270268

271269
# Group classifications by source document
272-
sdoc_id2annotations: dict[int, list[TagORM]] = {
270+
sdoc_id2annotation_ids: dict[int, list[int]] = {
273271
sdoc_id: [] for sdoc_id in sdoc_ids
274272
}
275273
for row in results:
276-
doc, tag = row._tuple()
277-
sdoc_id2annotations[doc.id].append(tag)
274+
sdoc_id2annotation_ids[row.source_document_id].append(row.tag_id)
278275

279276
# Create a labeled dataset
280277
# Every source document is part of the training data
@@ -288,8 +285,8 @@ def _retrieve_and_build_dataset(
288285
continue # skip documents without data
289286

290287
# We only use the first annotation (if multiple exist)
291-
annotations = sdoc_id2annotations[sdoc_id]
292-
annotation = annotations[0].id if len(annotations) > 0 else 0
288+
annotations = sdoc_id2annotation_ids[sdoc_id]
289+
annotation = annotations[0] if len(annotations) > 0 else 0
293290
dataset.append(
294291
{
295292
"sdoc_id": sdoc_data.id,
@@ -303,13 +300,19 @@ def _retrieve_and_build_dataset(
303300

304301
# Construct a tokenized huggingface dataset
305302
def tokenize_text(examples):
306-
return tokenizer(examples["text"], truncation=True)
303+
tokenized_inputs = tokenizer(
304+
examples["text"],
305+
truncation=not use_chunking,
306+
is_split_into_words=False,
307+
add_special_tokens=not use_chunking,
308+
)
309+
return tokenized_inputs
307310

308311
hf_dataset = Dataset.from_list(dataset) # type: ignore
309312
tokenized_hf_dataset = hf_dataset.map(tokenize_text, batched=True)
310313
tokenized_hf_dataset = tokenized_hf_dataset.remove_columns(["text"])
311314

312-
return sdoc_id2annotations, tokenized_hf_dataset
315+
return sdoc_id2annotation_ids, tokenized_hf_dataset
313316

314317
def train(
315318
self, db: Session, job: Job, payload: ClassifierJobInput
@@ -328,6 +331,8 @@ def train(
328331
raise BaseModelDoesNotExistError(parameters.base_name)
329332

330333
tokenizer = AutoTokenizer.from_pretrained(parameters.base_name)
334+
if parameters.chunk_size:
335+
tokenizer.model_max_length = parameters.chunk_size
331336
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
332337

333338
job.update(
@@ -352,25 +357,66 @@ def train(
352357
id2label[0] = "O"
353358

354359
# Build dataset
355-
sdoc_id2annotations, dataset = self._retrieve_and_build_dataset(
360+
sdoc_id2annotation_ids, dataset = self._retrieve_and_build_dataset(
356361
db=db,
357362
project_id=payload.project_id,
358363
tag_ids=parameters.tag_ids,
359364
class_ids=parameters.class_ids,
360365
classid2labelid=classid2labelid,
361366
tokenizer=tokenizer,
367+
use_chunking=True,
362368
)
363369

364370
# Train test split
365371
split_dataset = dataset.train_test_split(test_size=0.2, seed=42)
372+
373+
def split_in_chunks(examples: dict):
374+
chunk_len = tokenizer.model_max_length - 2
375+
for key, values in examples.items():
376+
if "labels" == key:
377+
continue
378+
elif "attention_mask" == key:
379+
pre = [1]
380+
post = [1]
381+
elif "input_ids" == key:
382+
pre = [tokenizer.added_tokens_encoder["[CLS]"]]
383+
post = [tokenizer.added_tokens_encoder["[SEP]"]]
384+
else:
385+
raise ValueError(f"Unsupported {key} in batch examples dict")
386+
387+
result = []
388+
original_labels = examples["labels"]
389+
labels = []
390+
for val, lab in zip(values, original_labels):
391+
chunks = [
392+
pre + val[i : i + chunk_len] + post
393+
for i in range(0, len(val), chunk_len)
394+
]
395+
result.extend(chunks)
396+
labels.extend([lab] * len(chunks))
397+
examples[key] = result
398+
examples["labels"] = labels
399+
return examples
400+
401+
train_dataset = (
402+
split_dataset["train"]
403+
.remove_columns(["sdoc_id"])
404+
.map(split_in_chunks, batched=True)
405+
)
406+
val_dataset = (
407+
split_dataset["test"]
408+
.remove_columns(["sdoc_id"])
409+
.map(split_in_chunks, batched=True)
410+
)
411+
366412
train_dataloader = DataLoader(
367-
split_dataset["train"], # type: ignore
413+
train_dataset, # type: ignore
368414
shuffle=True,
369415
collate_fn=data_collator,
370416
batch_size=parameters.batch_size,
371417
)
372418
val_dataloader = DataLoader(
373-
split_dataset["test"], # type: ignore
419+
val_dataset, # type: ignore
374420
shuffle=False,
375421
collate_fn=data_collator,
376422
batch_size=parameters.batch_size,
@@ -379,13 +425,13 @@ def train(
379425
# Dataset statistics (number of annotations per code)
380426
train_dataset_stats: dict[int, int] = {tag.id: 0 for tag in tags}
381427
for sdoc_id in split_dataset["train"]["sdoc_id"]:
382-
for annotation in sdoc_id2annotations[sdoc_id]:
383-
train_dataset_stats[annotation.id] += 1
428+
for annotation in sdoc_id2annotation_ids[sdoc_id]:
429+
train_dataset_stats[annotation] += 1
384430

385431
eval_dataset_stats: dict[int, int] = {tag.id: 0 for tag in tags}
386432
for sdoc_id in split_dataset["test"]["sdoc_id"]:
387-
for annotation in sdoc_id2annotations[sdoc_id]:
388-
eval_dataset_stats[annotation.id] += 1
433+
for annotation in sdoc_id2annotation_ids[sdoc_id]:
434+
eval_dataset_stats[annotation] += 1
389435

390436
# Calculate class weights
391437
# Count the occurrences of each label in the training set
@@ -568,13 +614,14 @@ def eval(
568614
job.update(current_step=2)
569615

570616
# Build dataset
571-
sdoc_id2annotations, dataset = self._retrieve_and_build_dataset(
617+
sdoc_id2annotation_ids, dataset = self._retrieve_and_build_dataset(
572618
db=db,
573619
project_id=payload.project_id,
574620
tag_ids=parameters.tag_ids,
575621
class_ids=classifier.class_ids,
576622
classid2labelid=classid2labelid,
577623
tokenizer=tokenizer,
624+
use_chunking=False,
578625
)
579626

580627
# Build dataloader
@@ -590,8 +637,8 @@ def eval(
590637
tag_id: 0 for tag_id, label_id in classid2labelid.items() if label_id != 0
591638
}
592639
for sdoc_id in dataset["sdoc_id"]:
593-
for annotation in sdoc_id2annotations[sdoc_id]:
594-
eval_dataset_stats[annotation.id] += 1
640+
for annotation in sdoc_id2annotation_ids[sdoc_id]:
641+
eval_dataset_stats[annotation] += 1
595642

596643
# 3. Load the model
597644
job.update(current_step=3)

backend/src/modules/classifier/models/span_class_model_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def _retrieve_and_build_dataset(
248248
class_ids: list[int],
249249
classid2labelid: dict[int, int],
250250
tokenizer,
251+
use_chunking: bool,
251252
) -> tuple[dict[int, dict[int, list[SpanAnnotationORM]]], Dataset]:
252253
# Find documents
253254
sdoc_ids = [
@@ -315,9 +316,9 @@ def _retrieve_and_build_dataset(
315316
def tokenize_and_align_labels(examples: dict):
316317
tokenized_inputs = tokenizer(
317318
examples["words"],
318-
truncation=False,
319+
truncation=not use_chunking,
319320
is_split_into_words=True,
320-
add_special_tokens=False,
321+
add_special_tokens=not use_chunking,
321322
)
322323

323324
labels = []
@@ -399,6 +400,7 @@ def train(
399400
class_ids=parameters.class_ids,
400401
classid2labelid=classid2labelid,
401402
tokenizer=tokenizer,
403+
use_chunking=True,
402404
)
403405

404406
# Train test split
@@ -659,6 +661,7 @@ def eval(
659661
class_ids=classifier.class_ids,
660662
classid2labelid=classid2labelid,
661663
tokenizer=tokenizer,
664+
use_chunking=False,
662665
)
663666

664667
# Build dataloader

0 commit comments

Comments (0)