Skip to content

Commit 665dd50

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
Name the forward pass thread in the trainer loop (#895)
Summary: Pull Request resolved: #895 Internal # Context With the sched_ext effort we are trying to build custom Linux schedulers that provide a small performance boost to AI training and improve the resource isolation on the trainer hosts. The latter is necessary to avoid cases when noisy neighbor processes, like data loaders, slow down the GPU training. More details in this note: https://fb.workplace.com/notes/1118655556176038 By naming the forward pass thread we can use its name and assign it a higher priority at the linux scheduler level. The backward pass is named inside the Pytorch implementation but the forward pass needs to be named at the application level. We did the same thing in PyPer, APS, MVAI which are the largest trainer frameworks for reco models, consuming 70%+ of fleet level GPU hours for recommender systems. # This Diff Adds core lines ``` if torch.multiprocessing._get_thread_name() != "trainer_main": torch.multiprocessing._set_thread_name("trainer_main") ``` to train/eval/predict scripts. We can check the preexisting name to avoid renaming the same thread. Reviewed By: diego-urgell Differential Revision: D61924982 fbshipit-source-id: cad51567361d6cc33d2f7d662401178360ad605c
1 parent b5b0b03 commit 665dd50

File tree

4 files changed

+86
-1
lines changed

4 files changed

+86
-1
lines changed

tests/framework/test_fit.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import math
1111
import unittest
1212
from typing import Tuple
13-
from unittest.mock import MagicMock
13+
from unittest.mock import MagicMock, patch
1414

1515
import torch
1616
from torch import nn
@@ -20,9 +20,12 @@
2020
from torchtnt.framework.state import ActivePhase, State
2121
from torchtnt.framework.unit import EvalUnit, TrainUnit, TTrainUnit
2222
from torchtnt.utils.timer import Timer
23+
from torchtnt.utils.version import is_torch_version_geq
2324

2425

2526
class FitTest(unittest.TestCase):
27+
TORCH_VERSION_GEQ_2_5_0: bool = is_torch_version_geq("2.5.0")
28+
2629
def test_fit_evaluate_every_n_epochs(self) -> None:
2730
"""
2831
Test fit entry point with evaluate_every_n_epochs=1
@@ -347,6 +350,41 @@ def test_error_message(self) -> None:
347350
log.output,
348351
)
349352

353+
@unittest.skipUnless(TORCH_VERSION_GEQ_2_5_0, "test requires PyTorch 2.5.0+")
354+
@patch(
355+
"torch.multiprocessing._get_thread_name", side_effect=["foo", "trainer_main"]
356+
)
357+
@patch("torch.multiprocessing._set_thread_name")
358+
def test_fit_set_thread_name(
359+
self, mock_set_thread_name: MagicMock, mock_get_thread_name: MagicMock
360+
) -> None:
361+
"""
362+
Test fit entry point with evaluate_every_n_epochs=1
363+
"""
364+
input_dim = 2
365+
train_dataset_len = 10
366+
eval_dataset_len = 10
367+
batch_size = 1
368+
369+
my_unit = DummyFitUnit(input_dim=input_dim)
370+
371+
train_dataloader = generate_random_dataloader(
372+
train_dataset_len, input_dim, batch_size
373+
)
374+
eval_dataloader = generate_random_dataloader(
375+
eval_dataset_len, input_dim, batch_size
376+
)
377+
378+
fit(
379+
my_unit,
380+
train_dataloader=train_dataloader,
381+
eval_dataloader=eval_dataloader,
382+
max_epochs=1,
383+
evaluate_every_n_epochs=1,
384+
)
385+
self.assertEqual(mock_get_thread_name.call_count, 2)
386+
mock_set_thread_name.assert_called_once()
387+
350388

351389
class UnitWithError(TrainUnit[int], EvalUnit[int]):
352390
def train_step(self, state: State, data: int) -> None:

torchtnt/framework/evaluate.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torchtnt.framework.unit import TEvalData, TEvalUnit
2525
from torchtnt.framework.utils import get_timing_context
2626
from torchtnt.utils.timer import get_timer_summary, TimerProtocol
27+
from torchtnt.utils.version import is_torch_version_geq
2728

2829
logger: logging.Logger = logging.getLogger(__name__)
2930

@@ -162,6 +163,21 @@ def _evaluate_impl(
162163

163164
# clear step_output to avoid retaining extra memory
164165
eval_state._step_output = None
166+
167+
if (
168+
eval_unit.eval_progress.num_steps_completed_in_epoch
169+
- prev_steps_in_epoch
170+
== 5
171+
):
172+
# Set the trainer thread name to improve debuggability. We do it after
173+
# 5 iterations to make sure that all the processes or thread pools
174+
# spawned / forked from the current process have already been created
175+
# and the trainer_main characterizes only the CPU thread that runs the
176+
# forward pass and schedules GPU work.
177+
if is_torch_version_geq("2.5.0"):
178+
if torch.multiprocessing._get_thread_name() != "trainer_main":
179+
torch.multiprocessing._set_thread_name("trainer_main")
180+
165181
except StopIteration:
166182
stop_iteration_reached = True
167183
break

torchtnt/framework/predict.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torchtnt.framework.unit import TPredictData, TPredictUnit
2525
from torchtnt.framework.utils import get_timing_context
2626
from torchtnt.utils.timer import get_timer_summary, TimerProtocol
27+
from torchtnt.utils.version import is_torch_version_geq
2728

2829
logger: logging.Logger = logging.getLogger(__name__)
2930

@@ -170,6 +171,21 @@ def _predict_impl(
170171

171172
# clear step_output to avoid retaining extra memory
172173
predict_state._step_output = None
174+
175+
if (
176+
predict_unit.predict_progress.num_steps_completed_in_epoch
177+
- prev_steps_in_epoch
178+
== 5
179+
):
180+
# Set the trainer thread name to improve debuggability. We do it after
181+
# 5 iterations to make sure that all the processes or thread pools
182+
# spawned / forked from the current process have already been created
183+
# and the trainer_main characterizes only the CPU thread that runs the
184+
# forward pass and schedules GPU work.
185+
if is_torch_version_geq("2.5.0"):
186+
if torch.multiprocessing._get_thread_name() != "trainer_main":
187+
torch.multiprocessing._set_thread_name("trainer_main")
188+
173189
except StopIteration:
174190
stop_iteration_reached = True
175191
break

torchtnt/framework/train.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from torchtnt.framework.unit import TTrainData, TTrainUnit
2828
from torchtnt.framework.utils import get_timing_context
2929
from torchtnt.utils.timer import get_timer_summary, TimerProtocol
30+
from torchtnt.utils.version import is_torch_version_geq
3031

3132
logger: logging.Logger = logging.getLogger(__name__)
3233

@@ -221,6 +222,20 @@ def _train_epoch_impl(
221222
# clear step_output to avoid retaining extra memory
222223
train_state._step_output = None
223224

225+
if (
226+
train_unit.train_progress.num_steps_completed_in_epoch
227+
- prev_steps_in_epoch
228+
== 5
229+
):
230+
# Set the trainer thread name to improve debuggability. We do it after
231+
# 5 iterations to make sure that all the processes or thread pools
232+
# spawned / forked from the current process have already been created
233+
# and the trainer_main characterizes only the CPU thread that runs the
234+
# forward pass and schedules GPU work.
235+
if is_torch_version_geq("2.5.0"):
236+
if torch.multiprocessing._get_thread_name() != "trainer_main":
237+
torch.multiprocessing._set_thread_name("trainer_main")
238+
224239
if (
225240
evaluate_every_n_steps
226241
and train_unit.train_progress.num_steps_completed

0 commit comments

Comments
 (0)