Skip to content

Commit ad288bb

Browse files
fynnosbigabig
authored and committed
reduce (span) classifier training memory by chunking long documents
1 parent 2c6f4ac commit ad288bb

File tree

9 files changed

+274
-99
lines changed

9 files changed

+274
-99
lines changed

backend/src/modules/classifier/classifier_dto.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from enum import Enum
33
from typing import Any, Literal
44

5+
from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT
56
from pydantic import BaseModel, ConfigDict, Field
67

78
from repos.db.dto_base import UpdateDTOBase
@@ -121,6 +122,12 @@ class ClassifierTrainingParams(BaseModel):
121122
learning_rate: float = Field(description="Learning rate to use for training")
122123
weight_decay: float = Field(description="Weight decay to use for training")
123124
dropout: float = Field(description="Dropout rate to use in the model")
125+
chunk_size: int | None = Field(
126+
description="Slice long documents into chunks of size x"
127+
)
128+
precision: _PRECISION_INPUT | None = Field(
129+
description="Precision, e.g. 32-true, 16-mixed, 16-true, bf16-true, bf16-mixed"
130+
)
124131
# specific training settings
125132
is_bio: bool = Field(description="Whether to use BIO or IO tagging")
126133

backend/src/modules/classifier/classifier_service.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,13 @@ def handle_classifier_job(
6565
status_message="Started ClassifierJob!",
6666
)
6767

68-
# free GPU memory before job starts
69-
gc.collect()
70-
torch.cuda.empty_cache()
68+
gpu_mem_limit_gb = 20
69+
gpu_mem_limit_bytes = gpu_mem_limit_gb * 1024 * 1024 * 1024
70+
71+
# set GPU memory limit for job
72+
total = torch.cuda.get_device_properties().total_memory
73+
allowed_fraction = gpu_mem_limit_bytes / total
74+
torch.cuda.set_per_process_memory_fraction(allowed_fraction)
7175

7276
# get the correct classifier service
7377
tcs: TextClassificationModelService

backend/src/modules/classifier/models/doc_class_model_service.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
149149
# labels = batch["labels"]
150150
# loss = self.loss_fn(logits.view(-1, self.num_labels), labels.view(-1))
151151

152-
self.log("train_loss", loss, on_step=False, on_epoch=True)
152+
self.log("train_loss", loss.detach(), on_step=False, on_epoch=True)
153153
return loss
154154

155155
def _val_test_step(
@@ -186,20 +186,23 @@ def _val_test_step(
186186
)
187187
return outputs.loss
188188

189+
@torch.no_grad()
189190
def validation_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
190191
return self._val_test_step(
191192
prefix="eval",
192193
batch=batch,
193194
batch_idx=batch_idx,
194195
)
195196

197+
@torch.no_grad()
196198
def test_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
197199
return self._val_test_step(
198200
prefix="test",
199201
batch=batch,
200202
batch_idx=batch_idx,
201203
)
202204

205+
@torch.no_grad()
203206
def predict_step(self, batch: dict[str, Any], batch_idx: int) -> Any:
204207
outputs = self.model(
205208
input_ids=batch["input_ids"],
@@ -214,7 +217,10 @@ def predict_step(self, batch: dict[str, Any], batch_idx: int) -> Any:
214217

215218
def configure_optimizers(self) -> torch.optim.Optimizer:
216219
optimizer = torch.optim.AdamW(
217-
self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
220+
self.parameters(),
221+
lr=self.learning_rate,
222+
weight_decay=self.weight_decay,
223+
fused=True,
218224
)
219225
return optimizer
220226

@@ -399,17 +405,6 @@ def train(
399405

400406
# 2. Initialize PyTorch Lightning components
401407
job.update(current_step=2)
402-
# Initialize the Lightning Model
403-
lightning_model = DocClassificationLightningModel(
404-
base_name=parameters.base_name,
405-
num_labels=len(classid2labelid),
406-
dropout=parameters.dropout,
407-
learning_rate=parameters.learning_rate,
408-
weight_decay=parameters.weight_decay,
409-
class_weights=torch.tensor(class_weights, dtype=torch.float32),
410-
id2label=id2label,
411-
label2id={v: k for k, v in id2label.items()},
412-
)
413408

414409
# Create the Trainer
415410
model_name: str = str(uuid4())
@@ -447,11 +442,24 @@ def train(
447442
max_epochs=parameters.epochs,
448443
callbacks=callbacks,
449444
enable_progress_bar=True,
445+
precision=parameters.precision,
450446
# Special params
451-
# precision=32, # full precision training
452447
# gradient_clip_val=1.0, # Gradient clipping
453448
)
454449

450+
with trainer.init_module():
451+
# Initialize the Lightning Model
452+
lightning_model = DocClassificationLightningModel(
453+
base_name=parameters.base_name,
454+
num_labels=len(classid2labelid),
455+
dropout=parameters.dropout,
456+
learning_rate=parameters.learning_rate,
457+
weight_decay=parameters.weight_decay,
458+
class_weights=torch.tensor(class_weights, dtype=torch.float32),
459+
id2label=id2label,
460+
label2id={v: k for k, v in id2label.items()},
461+
)
462+
455463
# 3. Train the model
456464
job.update(current_step=3)
457465
trainer.fit(

backend/src/modules/classifier/models/sent_class_model_service.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,10 @@ def predict_step(self, batch: dict[str, Any], batch_idx: int) -> Any:
225225

226226
def configure_optimizers(self) -> torch.optim.Optimizer:
227227
optimizer = torch.optim.AdamW(
228-
self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
228+
self.parameters(),
229+
lr=self.learning_rate,
230+
weight_decay=self.weight_decay,
231+
fused=True,
229232
)
230233
return optimizer
231234

@@ -485,23 +488,6 @@ def train(
485488

486489
# 2. Initialize PyTorch Lightning components
487490
job.update(current_step=2)
488-
# Initialize the Lightning Model
489-
lightning_model = SentClassificationLightningModel(
490-
# embedding model params
491-
embedding_model_name=parameters.base_name,
492-
embedding_dim=embedding_dim,
493-
# sent classifier specific params
494-
hidden_dim=int(embedding_dim / 2),
495-
use_lstm=True,
496-
# training params
497-
num_labels=len(classid2labelid),
498-
dropout=parameters.dropout,
499-
learning_rate=parameters.learning_rate,
500-
weight_decay=parameters.weight_decay,
501-
class_weights=torch.tensor(class_weights, dtype=torch.float32),
502-
id2label=id2label,
503-
label2id={v: k for k, v in id2label.items()},
504-
)
505491

506492
# Create the Trainer
507493
model_name: str = str(uuid4())
@@ -539,11 +525,30 @@ def train(
539525
max_epochs=parameters.epochs,
540526
callbacks=callbacks,
541527
enable_progress_bar=True,
528+
precision=parameters.precision,
542529
# Special params
543-
# precision=32, # full precision training
544530
# gradient_clip_val=1.0, # Gradient clipping
545531
)
546532

533+
with trainer.init_module():
534+
# Initialize the Lightning Model
535+
lightning_model = SentClassificationLightningModel(
536+
# embedding model params
537+
embedding_model_name=parameters.base_name,
538+
embedding_dim=embedding_dim,
539+
# sent classifier specific params
540+
hidden_dim=int(embedding_dim / 2),
541+
use_lstm=True,
542+
# training params
543+
num_labels=len(classid2labelid),
544+
dropout=parameters.dropout,
545+
learning_rate=parameters.learning_rate,
546+
weight_decay=parameters.weight_decay,
547+
class_weights=torch.tensor(class_weights, dtype=torch.float32),
548+
id2label=id2label,
549+
label2id={v: k for k, v in id2label.items()},
550+
)
551+
547552
# 3. Train the model
548553
job.update(current_step=3)
549554
trainer.fit(

0 commit comments

Comments
 (0)