
Commit 087e3c1

jwfromm authored and facebook-github-bot committed
Add CPU registrations to custom operators (pytorch#363)
Summary:
X-link: pytorch#3262

Pull Request resolved: facebookresearch/FBGEMM#363

While CPU arguments shouldn't be used for custom CUDA kernels, it turns out they sometimes are in production. The outputs will be garbage, but doing so seems to be part of the model construction process. This small diff fixes the issue by adding CPU registrations for custom operators. This should enable production use cases without breaking torch.export support.

Reviewed By: jaconey, jianyuh, jiawenliu64

Differential Revision: D64703788

fbshipit-source-id: c0c8cfb7f0b67c13be10f419c8e3d83991429edb
1 parent df208f5 commit 087e3c1
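
For context, here is a minimal sketch of the registration pattern the commit relies on. The operator myops::my_allreduce and its schema are hypothetical, invented for illustration; only the TORCH_LIBRARY / TORCH_LIBRARY_IMPL macros and the dual-key registration mirror what the diffs below actually do.

#include <ATen/ATen.h>
#include <torch/library.h>

// Stand-in kernel. The real FBGEMM functions launch CUDA work; when invoked
// with CPU tensors their outputs are explicitly allowed to be garbage.
at::Tensor my_allreduce(at::Tensor input) {
  return input.clone();
}

// Define the operator schema once.
TORCH_LIBRARY(myops, m) {
  m.def("my_allreduce(Tensor input) -> Tensor");
}

// Register the kernel for CUDA dispatch.
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
  m.impl("my_allreduce", my_allreduce);
}

// The pattern this commit adds: a parallel CPU registration of the same
// function, so CPU-tensor calls made during model construction dispatch
// successfully instead of raising a missing-kernel error.
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("my_allreduce", my_allreduce);
}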

File tree

2 files changed: 31 additions & 0 deletions


fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp

Lines changed: 11 additions & 0 deletions
@@ -278,4 +278,15 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
 }
 
+// Though it shouldn't be used, it is useful to define these functions for CPU
+// to accommodate model creation.
+TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
+  m.impl("nccl_allreduce", nccl_allreduce);
+  m.impl("nccl_allgather", nccl_allgather);
+  m.impl("nccl_alltoall", nccl_alltoall);
+  m.impl("nccl_reducescatter", nccl_reducescatter);
+  m.impl("one_shot_car_allreduce", one_shot_car_allreduce);
+  m.impl("two_shot_car_allreduce", two_shot_car_allreduce);
+}
+
 } // namespace fbgemm_gpu
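
A hedged way to observe the effect of the registration above from C++, using only public dispatcher API. The op name fbgemm::nccl_allreduce comes from the diff; since its full schema is not shown here, this sketch only queries for a CPU kernel rather than calling the op.

#include <ATen/core/dispatch/Dispatcher.h>
#include <iostream>

int main() {
  // Look the operator up by name (empty overload name).
  auto op = c10::Dispatcher::singleton().findOp({"fbgemm::nccl_allreduce", ""});
  if (op.has_value()) {
    // True once the TORCH_LIBRARY_IMPL(fbgemm, CPU, m) block above is linked in.
    std::cout << "has CPU kernel: "
              << op->hasKernelForDispatchKey(c10::DispatchKey::CPU) << "\n";
  }
  return 0;
}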

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp

Lines changed: 20 additions & 0 deletions
@@ -214,4 +214,24 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
 #endif
 }
 
+// Though it should never be used, it still seems helpful to define these
+// functions for CPU to accommodate model creation.
+TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
+  m.impl("f8f8bf16_blockwise", f8f8bf16_blockwise);
+  m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise);
+  m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise);
+  m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
+  m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
+  m.impl("quantize_fp8_per_col", quantize_fp8_per_col);
+#ifndef USE_ROCM
+  m.impl("i8i8bf16", i8i8bf16);
+  m.impl("f8f8bf16", f8f8bf16);
+  m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
+  m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched);
+  m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
+  m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
+  m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
+#endif
+}
+
 } // namespace fbgemm_gpu
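
To illustrate the failure mode the commit addresses, here is a self-contained sketch built around a hypothetical op demo::fake_quant (not an FBGEMM operator). Without the CPU registration block, calling the op with a CPU tensor throws a c10::Error complaining that no kernel is registered for the backend; with it, the call completes, returning output that may be garbage but lets model construction and torch.export proceed.

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <torch/torch.h>
#include <iostream>

// Placeholder body; the real kernels do device-specific work.
at::Tensor fake_quant(at::Tensor x) {
  return x;
}

TORCH_LIBRARY(demo, m) {
  m.def("fake_quant(Tensor x) -> Tensor");
}

// Comment this block out to reproduce the pre-fix behavior.
TORCH_LIBRARY_IMPL(demo, CPU, m) {
  m.impl("fake_quant", fake_quant);
}

int main() {
  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("demo::fake_quant", "")
                .typed<at::Tensor(at::Tensor)>();
  try {
    at::Tensor y = op.call(torch::zeros({4}));  // CPU tensor by default
    std::cout << "CPU call succeeded\n";
  } catch (const c10::Error&) {
    std::cout << "CPU call failed: no CPU kernel registered\n";
  }
  return 0;
}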
