Skip to content

Commit 60f878e

Browse files
fynnosbigabig
authored and committed
review feedback
1 parent 13d31bc commit 60f878e

File tree

9 files changed

+28
-14
lines changed

9 files changed

+28
-14
lines changed

backend/configs/development.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ vllm:
112112
port: ${oc.env:VLLM_EMB_PORT, 8000}
113113
model: ${oc.env:VLLM_EMB_MODEL, snowflake-arctic-embed2:568m}
114114

115+
job:
116+
gpu_memory_limit: ${oc.env:RQ_WORKER_GPU_MEM_GB, 20}
117+
115118
llm_assistant:
116119
sentence_annotation:
117120
few_shot_threshold: 4

backend/configs/production.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ vllm:
112112
port: ${oc.env:VLLM_EMB_PORT, 8000}
113113
model: ${oc.env:VLLM_EMB_MODEL, snowflake-arctic-embed2:568m}
114114

115+
job:
116+
gpu_memory_limit: ${oc.env:RQ_WORKER_GPU_MEM_GB, 20}
117+
115118
llm_assistant:
116119
sentence_annotation:
117120
few_shot_threshold: 4

backend/src/modules/classifier/classifier_service.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,6 @@ def handle_classifier_job(
6565
status_message="Started ClassifierJob!",
6666
)
6767

68-
gpu_mem_limit_gb = 20
69-
gpu_mem_limit_bytes = gpu_mem_limit_gb * 1024 * 1024 * 1024
70-
71-
# set GPU memory limit for job
72-
total = torch.cuda.get_device_properties().total_memory
73-
allowed_fraction = gpu_mem_limit_bytes / total
74-
torch.cuda.set_per_process_memory_fraction(allowed_fraction)
75-
7668
# get the correct classifier service
7769
tcs: TextClassificationModelService
7870
match payload.model_type:

backend/src/modules/classifier/models/doc_class_model_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(
7777
dropout: float,
7878
learning_rate: float,
7979
weight_decay: float,
80-
class_weights: torch.Tensor,
80+
class_weights: list[float],
8181
id2label: dict[int, str] | None = None,
8282
label2id: dict[str, int] | None = None,
8383
):
@@ -455,7 +455,7 @@ def train(
455455
dropout=parameters.dropout,
456456
learning_rate=parameters.learning_rate,
457457
weight_decay=parameters.weight_decay,
458-
class_weights=torch.tensor(class_weights, dtype=torch.float32),
458+
class_weights=class_weights,
459459
id2label=id2label,
460460
label2id={v: k for k, v in id2label.items()},
461461
)

backend/src/modules/classifier/models/sent_class_model_service.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(
7575
dropout: float,
7676
learning_rate: float,
7777
weight_decay: float,
78-
class_weights: torch.Tensor,
78+
class_weights: list[float],
7979
# special params
8080
embedding_model_name: str,
8181
embedding_dim: int,
@@ -200,20 +200,23 @@ def _val_test_step(self, prefix: str, batch, batch_idx: int) -> torch.Tensor:
200200
)
201201
return loss
202202

203+
@torch.no_grad()
203204
def validation_step(self, batch, batch_idx):
204205
return self._val_test_step(
205206
prefix="eval",
206207
batch=batch,
207208
batch_idx=batch_idx,
208209
)
209210

211+
@torch.no_grad()
210212
def test_step(self, batch, batch_idx):
211213
return self._val_test_step(
212214
prefix="test",
213215
batch=batch,
214216
batch_idx=batch_idx,
215217
)
216218

219+
@torch.no_grad()
217220
def predict_step(self, batch: dict[str, Any], batch_idx: int) -> Any:
218221
# Get predictions and ground truth tags
219222
predictions = self(sentences=batch["sentences"], mask=batch["mask"])
@@ -544,7 +547,7 @@ def train(
544547
dropout=parameters.dropout,
545548
learning_rate=parameters.learning_rate,
546549
weight_decay=parameters.weight_decay,
547-
class_weights=torch.tensor(class_weights, dtype=torch.float32),
550+
class_weights=class_weights,
548551
id2label=id2label,
549552
label2id={v: k for k, v in id2label.items()},
550553
)

backend/src/modules/classifier/models/span_class_model_service.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,6 @@ def split_in_chunks(examples: dict):
523523
callbacks.append(JobProgressCallback(job=job))
524524

525525
trainer = pl.Trainer(
526-
accelerator="gpu",
527526
logger=csv_logger,
528527
max_epochs=parameters.epochs,
529528
callbacks=callbacks,

backend/src/systems/job_system/job_handler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from datetime import datetime
22

3-
from common.gpu_utils import find_unused_cuda_device
43
from common.job_type import JobType
4+
from config import conf
55
from modules.doc_processing.doc_processing_pipeline import (
66
handle_job_error,
77
handle_job_finished,
88
)
99
from systems.job_system.job_dto import Job, JobInputBase
10+
from utils.gpu_utils import find_unused_cuda_device, set_cuda_memory_limit
1011

1112

1213
def rq_job_handler(jobtype: JobType, handler, payload: JobInputBase):
@@ -18,6 +19,7 @@ def rq_job_handler(jobtype: JobType, handler, payload: JobInputBase):
1819

1920
cuda_device = find_unused_cuda_device()
2021
with torch.cuda.device(cuda_device):
22+
set_cuda_memory_limit(conf.job.gpu_memory_limit)
2123
output = handler(payload=payload, job=job)
2224
else:
2325
output = handler(payload=payload, job=job)
backend/src/utils/gpu_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,17 @@ def find_unused_cuda_device() -> str:
3232
return "cuda:0"
3333

3434

35+
def set_cuda_memory_limit(limit_gb: int):
36+
import torch
37+
38+
limit_bytes = limit_gb * 1024 * 1024 * 1024
39+
40+
# set GPU memory limit for job
41+
total = torch.cuda.get_device_properties().total_memory
42+
allowed_fraction = limit_bytes / total
43+
torch.cuda.set_per_process_memory_fraction(allowed_fraction)
44+
45+
3546
def parse_device_string(device_str: str) -> tuple[str, list[int]]:
3647
"""
3748
Parses a device string and returns the appropriate device configuration.

docker/.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ RQ_WORKERS_CPU=8
106106
RQ_WORKERS_API=16
107107
RQ_WORKERS_GPU=1
108108
RQ_DEVICE_IDS=1
109+
RQ_WORKER_GPU_MEM_GB=20
109110
# NUM DB CONNECTIONS
110111
RQ_POSTGRES_POOL_SIZE=1
111112
RQ_POSTGRES_MAX_OVERFLOW=1

0 commit comments

Comments (0)