
Commit 1b8004a

nipung90 authored and facebook-github-bot committed
Remove usage of _running_with_deploy in torchrec (#3216)
Summary:
Pull Request resolved: #3216

As per https://fb.workplace.com/groups/pytorch.dev/permalink/1828123831099422, we can now safely remove torch._running_with_deploy. Since no models should be using torch.deploy, the check always returns False. This diff removes all references to _running_with_deploy from the torchrec codebase.

Reviewed By: iamzainhuda, jd7-tr

Differential Revision: D78667510

fbshipit-source-id: 520a2c099cdb026c158a10984a6b418fe6a31c06
1 parent 332b8b4 commit 1b8004a
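In effect (a minimal illustrative sketch, not code from this diff; the op name mylib::toy_op is hypothetical): because torch._running_with_deploy() now always returns False, a Torch Library registration that used to be wrapped in the guard can simply run unconditionally at import time.

import torch

# Previously, registrations like the ones in comm_ops.py were wrapped in
#     if not torch._running_with_deploy():  # Torch Library op defs could not be used in Deploy
#         ...
# With torch.deploy gone, the guard always takes the True branch, so the body
# moves to module level.

@torch.library.custom_op("mylib::toy_op", mutates_args=())
def toy_op(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical op, used only to show unconditional module-level registration.
    return x.clone()

print(torch.ops.mylib.toy_op(torch.ones(2)))  # tensor([1., 1.])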

3 files changed (+181, −181 lines)


torchrec/distributed/comm_ops.py

Lines changed: 178 additions & 172 deletions
@@ -54,11 +54,6 @@ def get_gradient_division() -> bool:
 
 
 def set_use_sync_collectives(val: bool) -> None:
-    if val and torch._running_with_deploy():
-        raise RuntimeError(
-            "TorchRec sync_collectives are not supported in torch.deploy."
-        )
-
     global USE_SYNC_COLLECTIVES
     USE_SYNC_COLLECTIVES = val
 
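A hedged usage sketch (assuming torchrec is installed and torchrec.distributed.comm_ops imports in your environment): after this hunk, set_use_sync_collectives is a plain toggle of the module-level flag and no longer raises, since torch.deploy no longer exists.

from torchrec.distributed import comm_ops

# Formerly raised RuntimeError("TorchRec sync_collectives are not supported in torch.deploy.")
# when running under torch.deploy; now it simply sets the USE_SYNC_COLLECTIVES flag.
comm_ops.set_use_sync_collectives(True)
comm_ops.set_use_sync_collectives(False)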

@@ -2356,202 +2351,213 @@ def backward(ctx, grad_output: Tensor) -> Tuple[None, None, Tensor]:
         return (None, None, myreq.dummy_tensor)
 
 
-if not torch._running_with_deploy():  # noqa C901
-    # Torch Library op def can not be used in Deploy
-    class AllToAllSingle(torch.autograd.Function):
-        @staticmethod
-        # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
-        def forward(
-            # pyre-fixme[2]: Parameter must be annotated.
-            ctx,
-            input: Tensor,
-            output_split_sizes: List[int],
-            input_split_sizes: List[int],
-            group_name: str,
-            group_size: int,
-            gradient_division: bool,
-        ) -> Tensor:
-            ctx.output_split_sizes = input_split_sizes
-            ctx.input_split_sizes = output_split_sizes
-            ctx.group_name = group_name
-            ctx.group_size = group_size
-            ctx.gradient_division = gradient_division
-            return torch.distributed._functional_collectives.all_to_all_single(
-                input, output_split_sizes, input_split_sizes, group_name
-            )
-
-        @staticmethod
-        # pyre-ignore
-        def backward(ctx, grad):
-            grad = torch.distributed._functional_collectives.all_to_all_single(
-                grad,
-                ctx.output_split_sizes,
-                ctx.input_split_sizes,
-                ctx.group_name,
-            )
-            if ctx.gradient_division:
-                grad.div_(ctx.group_size)
-
-            return grad, None, None, None, None, None
-
-    # torchrec::reduce_scatter_tensor
-    @torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=())
-    def reduce_scatter_tensor(
-        input: Tensor,
-        reduceOp: str,
-        group_size: int,
-        group_name: str,
-        gradient_division: bool,
-    ) -> Tensor:
-        out = torch.ops._c10d_functional.reduce_scatter_tensor(
-            input,
-            reduceOp,
-            group_size,
-            group_name,
-        )
-        return torch.ops._c10d_functional.wait_tensor(out)
-
-    @torch.library.register_fake("torchrec::reduce_scatter_tensor")
-    def reduce_scatter_tensor_fake(
-        input: Tensor,
-        reduceOp: str,
-        group_size: int,
-        group_name: str,
-        gradient_division: bool,
-    ) -> Tensor:
-        return torch.ops._c10d_functional.reduce_scatter_tensor(
-            input,
-            reduceOp,
-            group_size,
-            group_name,
-        )
-
-    # pyre-ignore
-    def reduce_scatter_tensor_setup_context(ctx, inputs, output) -> None:
-        _, _, group_size, group_name, gradient_division = inputs
-        ctx.group_size = group_size
-        ctx.group_name = group_name
-        ctx.gradient_division = gradient_division
-
-    # pyre-ignore
-    def reduce_scatter_tensor_backward(ctx, grad):
-        # TODO(ivankobzarev): Support codecs(quantization) on backward
-        out = torch.ops._c10d_functional.all_gather_into_tensor(
-            grad,
-            ctx.group_size,
-            ctx.group_name,
-        )
-        grad = torch.ops._c10d_functional.wait_tensor(out)
-        if ctx.gradient_division:
-            grad.div_(ctx.group_size)
-
-        return grad, None, None, None, None, None
-
-    torch.library.register_autograd(
-        "torchrec::reduce_scatter_tensor",
-        reduce_scatter_tensor_backward,
-        setup_context=reduce_scatter_tensor_setup_context,
-    )
-
-    # torchrec::all_gather_into_tensor
-    @torch.library.custom_op("torchrec::all_gather_into_tensor", mutates_args=())
-    def all_gather_into_tensor(
-        shard: Tensor,
-        gather_dim: int,
-        group_size: int,
-        group_name: str,
-        gradient_division: bool,
-    ) -> Tensor:
-        out = torch.ops._c10d_functional.all_gather_into_tensor(
-            shard, group_size, group_name
-        )
-        return torch.ops._c10d_functional.wait_tensor(out)
-
-    @torch.library.register_fake("torchrec::all_gather_into_tensor")
-    def all_gather_into_tensor_fake(
-        shard: Tensor,
-        gather_dim: int,
-        group_size: int,
-        group_name: str,
-        gradient_division: bool,
-    ) -> Tensor:
-        return torch.ops._c10d_functional.all_gather_into_tensor(
-            shard, group_size, group_name
-        )
-
-    # pyre-ignore
-    def all_gather_into_tensor_setup_context(ctx, inputs, output) -> None:
-        _, gather_dim, group_size, group_name, gradient_division = inputs
-        ctx.group_size = group_size
-        ctx.group_name = group_name
-        ctx.gradient_division = gradient_division
-
-    # pyre-ignore
-    def all_gather_into_tensor_backward(ctx, grad):
-        # TODO(ivankobzarev): Support codecs(quantization) on backward
-        out = torch.ops._c10d_functional.reduce_scatter_tensor(
-            grad,
-            "sum",
-            ctx.group_size,
-            ctx.group_name,
-        )
-        grad = torch.ops._c10d_functional.wait_tensor(out)
-        if ctx.gradient_division:
-            grad.div_(ctx.group_size)
-
-        return grad, None, None, None, None
-
-    torch.library.register_autograd(
-        "torchrec::all_gather_into_tensor",
-        all_gather_into_tensor_backward,
-        setup_context=all_gather_into_tensor_setup_context,
-    )
-
-    @torch.library.custom_op("torchrec::_split_1d_cat_2d", mutates_args=())
-    def _split_1d_cat_2d_impl(
-        t: torch.Tensor, dim0: int, dim1s: List[int]
-    ) -> torch.Tensor:
-        torch._check_is_size(dim0)
-        [torch._check_is_size(dim1) for dim1 in dim1s]
-        splits: List[torch.Tensor] = t.split([dim0 * dim1 for dim1 in dim1s])
-        return torch.cat(
-            [s.reshape(dim0, dim1) for s, dim1 in zip(splits, dim1s)],
-            dim=1,
-        )
-
-    @torch.library.register_fake("torchrec::_split_1d_cat_2d")
-    def _split_1d_cat_2d_impl_abstract(
-        t: torch.Tensor, dim0: int, dim1s: List[int]
-    ) -> torch.Tensor:
-        return t.new_empty([dim0, sum(dim1s)])
-
-    @torch.library.custom_op(
-        "torchrec::_split_1d_cat_2d_backward_impl", mutates_args=()
-    )
-    def _split_1d_cat_2d_backward_impl(
-        grad: torch.Tensor, dim1s: List[int]
-    ) -> torch.Tensor:
-        splits = grad.split(dim1s, dim=1)
-        return torch.cat([s.reshape(-1) for s in splits], dim=0)
-
-    @torch.library.register_fake("torchrec::_split_1d_cat_2d_backward_impl")
-    def _split_1d_cat_2d_backward_impl_fake(
-        grad: torch.Tensor, dim1s: List[int]
-    ) -> torch.Tensor:
-        return grad.new_empty([grad.numel()])
-
-    # pyre-ignore
-    def _split_1d_cat_2d_backward(ctx, grad):
-        ret = torch.ops.torchrec._split_1d_cat_2d_backward_impl(grad, ctx.dim1s)
-        return ret, None, None
-
-    # pyre-ignore
-    def _split_1d_cat_2d_setup_context(ctx, inputs, output):
-        (x, dim0, dim1s) = inputs
-        ctx.dim1s = dim1s
-
-    torch.library.register_autograd(
-        "torchrec::_split_1d_cat_2d",
-        _split_1d_cat_2d_backward,
-        setup_context=_split_1d_cat_2d_setup_context,
-    )
+class AllToAllSingle(torch.autograd.Function):
+    @staticmethod
+    # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
+    def forward(
+        # pyre-fixme[2]: Parameter must be annotated.
+        ctx,
+        input: Tensor,
+        output_split_sizes: List[int],
+        input_split_sizes: List[int],
+        group_name: str,
+        group_size: int,
+        gradient_division: bool,
+    ) -> Tensor:
+        ctx.output_split_sizes = input_split_sizes
+        ctx.input_split_sizes = output_split_sizes
+        ctx.group_name = group_name
+        ctx.group_size = group_size
+        ctx.gradient_division = gradient_division
+        return torch.distributed._functional_collectives.all_to_all_single(
+            input, output_split_sizes, input_split_sizes, group_name
+        )
+
+    @staticmethod
+    # pyre-ignore
+    def backward(ctx, grad):
+        grad = torch.distributed._functional_collectives.all_to_all_single(
+            grad,
+            ctx.output_split_sizes,
+            ctx.input_split_sizes,
+            ctx.group_name,
+        )
+        if ctx.gradient_division:
+            grad.div_(ctx.group_size)
+
+        return grad, None, None, None, None, None
+
+
+# torchrec::reduce_scatter_tensor
+@torch.library.custom_op("torchrec::reduce_scatter_tensor", mutates_args=())
+def reduce_scatter_tensor(
+    input: Tensor,
+    reduceOp: str,
+    group_size: int,
+    group_name: str,
+    gradient_division: bool,
+) -> Tensor:
+    out = torch.ops._c10d_functional.reduce_scatter_tensor(
+        input,
+        reduceOp,
+        group_size,
+        group_name,
+    )
+    return torch.ops._c10d_functional.wait_tensor(out)
+
+
+@torch.library.register_fake("torchrec::reduce_scatter_tensor")
+def reduce_scatter_tensor_fake(
+    input: Tensor,
+    reduceOp: str,
+    group_size: int,
+    group_name: str,
+    gradient_division: bool,
+) -> Tensor:
+    return torch.ops._c10d_functional.reduce_scatter_tensor(
+        input,
+        reduceOp,
+        group_size,
+        group_name,
+    )
+
+
+# pyre-ignore
+def reduce_scatter_tensor_setup_context(ctx, inputs, output) -> None:
+    _, _, group_size, group_name, gradient_division = inputs
+    ctx.group_size = group_size
+    ctx.group_name = group_name
+    ctx.gradient_division = gradient_division
+
+
+# pyre-ignore
+def reduce_scatter_tensor_backward(ctx, grad):
+    # TODO(ivankobzarev): Support codecs(quantization) on backward
+    out = torch.ops._c10d_functional.all_gather_into_tensor(
+        grad,
+        ctx.group_size,
+        ctx.group_name,
+    )
+    grad = torch.ops._c10d_functional.wait_tensor(out)
+    if ctx.gradient_division:
+        grad.div_(ctx.group_size)
+
+    return grad, None, None, None, None, None
+
+
+torch.library.register_autograd(
+    "torchrec::reduce_scatter_tensor",
+    reduce_scatter_tensor_backward,
+    setup_context=reduce_scatter_tensor_setup_context,
+)
+
+
+# torchrec::all_gather_into_tensor
+@torch.library.custom_op("torchrec::all_gather_into_tensor", mutates_args=())
+def all_gather_into_tensor(
+    shard: Tensor,
+    gather_dim: int,
+    group_size: int,
+    group_name: str,
+    gradient_division: bool,
+) -> Tensor:
+    out = torch.ops._c10d_functional.all_gather_into_tensor(
+        shard, group_size, group_name
+    )
+    return torch.ops._c10d_functional.wait_tensor(out)
+
+
+@torch.library.register_fake("torchrec::all_gather_into_tensor")
+def all_gather_into_tensor_fake(
+    shard: Tensor,
+    gather_dim: int,
+    group_size: int,
+    group_name: str,
+    gradient_division: bool,
+) -> Tensor:
+    return torch.ops._c10d_functional.all_gather_into_tensor(
+        shard, group_size, group_name
+    )
+
+
+# pyre-ignore
+def all_gather_into_tensor_setup_context(ctx, inputs, output) -> None:
+    _, gather_dim, group_size, group_name, gradient_division = inputs
+    ctx.group_size = group_size
+    ctx.group_name = group_name
+    ctx.gradient_division = gradient_division
+
+
+# pyre-ignore
+def all_gather_into_tensor_backward(ctx, grad):
+    # TODO(ivankobzarev): Support codecs(quantization) on backward
+    out = torch.ops._c10d_functional.reduce_scatter_tensor(
+        grad,
+        "sum",
+        ctx.group_size,
+        ctx.group_name,
+    )
+    grad = torch.ops._c10d_functional.wait_tensor(out)
+    if ctx.gradient_division:
+        grad.div_(ctx.group_size)
+
+    return grad, None, None, None, None
+
+
+torch.library.register_autograd(
+    "torchrec::all_gather_into_tensor",
+    all_gather_into_tensor_backward,
+    setup_context=all_gather_into_tensor_setup_context,
+)
+
+
+@torch.library.custom_op("torchrec::_split_1d_cat_2d", mutates_args=())
+def _split_1d_cat_2d_impl(t: torch.Tensor, dim0: int, dim1s: List[int]) -> torch.Tensor:
+    torch._check_is_size(dim0)
+    [torch._check_is_size(dim1) for dim1 in dim1s]
+    splits: List[torch.Tensor] = t.split([dim0 * dim1 for dim1 in dim1s])
+    return torch.cat(
+        [s.reshape(dim0, dim1) for s, dim1 in zip(splits, dim1s)],
+        dim=1,
+    )
+
+
+@torch.library.register_fake("torchrec::_split_1d_cat_2d")
+def _split_1d_cat_2d_impl_abstract(
+    t: torch.Tensor, dim0: int, dim1s: List[int]
+) -> torch.Tensor:
+    return t.new_empty([dim0, sum(dim1s)])
+
+
+@torch.library.custom_op("torchrec::_split_1d_cat_2d_backward_impl", mutates_args=())
+def _split_1d_cat_2d_backward_impl(
+    grad: torch.Tensor, dim1s: List[int]
+) -> torch.Tensor:
+    splits = grad.split(dim1s, dim=1)
+    return torch.cat([s.reshape(-1) for s in splits], dim=0)
+
+
+@torch.library.register_fake("torchrec::_split_1d_cat_2d_backward_impl")
+def _split_1d_cat_2d_backward_impl_fake(
+    grad: torch.Tensor, dim1s: List[int]
+) -> torch.Tensor:
+    return grad.new_empty([grad.numel()])
+
+
+# pyre-ignore
+def _split_1d_cat_2d_backward(ctx, grad):
+    ret = torch.ops.torchrec._split_1d_cat_2d_backward_impl(grad, ctx.dim1s)
+    return ret, None, None
+
+
+# pyre-ignore
+def _split_1d_cat_2d_setup_context(ctx, inputs, output):
+    (x, dim0, dim1s) = inputs
+    ctx.dim1s = dim1s
+
+
+torch.library.register_autograd(
+    "torchrec::_split_1d_cat_2d",
+    _split_1d_cat_2d_backward,
+    setup_context=_split_1d_cat_2d_setup_context,
+)
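For readers skimming the now top-level ops, here is a small, self-contained sketch of what torchrec::_split_1d_cat_2d computes, re-expressed in plain torch rather than through the registered custom op (the helper name split_1d_cat_2d_reference and the example sizes are illustrative assumptions):

from typing import List

import torch


def split_1d_cat_2d_reference(t: torch.Tensor, dim0: int, dim1s: List[int]) -> torch.Tensor:
    # Split the flat input into one chunk of dim0 * dim1 elements per shard,
    # reshape each chunk to (dim0, dim1), and concatenate the shards along dim=1,
    # mirroring the body of _split_1d_cat_2d_impl above.
    splits = t.split([dim0 * dim1 for dim1 in dim1s])
    return torch.cat([s.reshape(dim0, d) for s, d in zip(splits, dim1s)], dim=1)


t = torch.arange(10.0)  # a flattened 2x2 shard followed by a flattened 2x3 shard
out = split_1d_cat_2d_reference(t, dim0=2, dim1s=[2, 3])
print(out.shape)  # torch.Size([2, 5])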
