Skip to content

Commit f47aac6

Browse files
fffrog and pytorchmergebot
authored and committed
Make Context to be Device-agnostic Step by Step (3/N) (pytorch#137578)
Detailed description:
- Use the unified device-agnostic API to create a new generator for an accelerator.
- Add deprecation warnings for GeneratorForPrivateuseone.

Pull Request resolved: pytorch#137578
Approved by: https://github.com/cyyever, https://github.com/ezyang
1 parent 80a4239 commit f47aac6

File tree

13 files changed

+125
-140
lines changed

13 files changed

+125
-140
lines changed

aten/src/ATen/core/GeneratorForPrivateuseone.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
#include <mutex>
21
#include <ATen/core/GeneratorForPrivateuseone.h>
32

3+
#include <mutex>
4+
45
namespace at {
56

67
static std::mutex _generator_mutex_lock;
@@ -12,6 +13,11 @@ std::optional<GeneratorFuncType>& GetGeneratorPrivate() {
1213

1314
_GeneratorRegister::_GeneratorRegister(const GeneratorFuncType& func) {
1415
std::lock_guard<std::mutex> lock(_generator_mutex_lock);
16+
17+
TORCH_WARN_DEPRECATION(
18+
"REGISTER_GENERATOR_PRIVATEUSE1 is deprecated. \
19+
Please derive PrivateUse1HooksInterface to implememt getNewGenerator instead.")
20+
1521
TORCH_CHECK(
1622
!GetGeneratorPrivate().has_value(),
1723
"Only can register a generator to the PrivateUse1 dispatch key once!");
@@ -21,6 +27,10 @@ _GeneratorRegister::_GeneratorRegister(const GeneratorFuncType& func) {
2127
}
2228

2329
at::Generator GetGeneratorForPrivateuse1(c10::DeviceIndex device_index) {
30+
TORCH_WARN_DEPRECATION(
31+
"GetGeneratorForPrivateuse1() is deprecated. Please use \
32+
globalContext().getAcceleratorHooksInterface(device_type).getNewGenerator() instead.")
33+
2434
TORCH_CHECK(
2535
GetGeneratorPrivate().has_value(),
2636
"Please register a generator to the PrivateUse1 dispatch key, \

aten/src/ATen/core/GeneratorForPrivateuseone.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace at {
77

88
using GeneratorFuncType = std::function<at::Generator(c10::DeviceIndex)>;
99

10-
std::optional<GeneratorFuncType>& GetGeneratorPrivate();
10+
TORCH_API std::optional<GeneratorFuncType>& GetGeneratorPrivate();
1111

1212
class TORCH_API _GeneratorRegister {
1313
public:

aten/src/ATen/cuda/detail/CUDAHooks.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ const Generator& CUDAHooks::getDefaultGenerator(DeviceIndex device_index) const
106106
return at::cuda::detail::getDefaultCUDAGenerator(device_index);
107107
}
108108

109+
Generator CUDAHooks::getNewGenerator(DeviceIndex device_index) const {
110+
return make_generator<at::CUDAGeneratorImpl>(device_index);
111+
}
112+
109113
Device CUDAHooks::getDeviceFromPtr(void* data) const {
110114
return at::cuda::getDeviceFromPtr(data);
111115
}

aten/src/ATen/cuda/detail/CUDAHooks.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
2323
bool isPinnedPtr(const void* data) const override;
2424
const Generator& getDefaultGenerator(
2525
DeviceIndex device_index = -1) const override;
26+
Generator getNewGenerator(
27+
DeviceIndex device_index = -1) const override;
2628
bool hasCUDA() const override;
2729
bool hasMAGMA() const override;
2830
bool hasCuDNN() const override;

aten/src/ATen/detail/CUDAHooksInterface.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,14 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
7474
CUDA_HELP);
7575
}
7676

77+
Generator getNewGenerator(
78+
[[maybe_unused]] DeviceIndex device_index = -1) const override {
79+
TORCH_CHECK(
80+
false,
81+
"Cannot get CUDA generator without ATen_cuda library. ",
82+
CUDA_HELP);
83+
}
84+
7785
Device getDeviceFromPtr(void* /*data*/) const override {
7886
TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
7987
}

aten/src/ATen/detail/MPSHooksInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface {
3535
[[maybe_unused]] DeviceIndex device_index = -1) const override {
3636
FAIL_MPSHOOKS_FUNC(__func__);
3737
}
38+
Generator getNewGenerator(
39+
[[maybe_unused]] DeviceIndex device_index) const override {
40+
FAIL_MPSHOOKS_FUNC(__func__);
41+
}
3842
virtual Allocator* getMPSDeviceAllocator() const {
3943
FAIL_MPSHOOKS_FUNC(__func__);
4044
}

aten/src/ATen/detail/PrivateUse1HooksInterface.h

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22

3+
#include <ATen/core/GeneratorForPrivateuseone.h>
34
#include <ATen/detail/AcceleratorHooksInterface.h>
5+
46
#include <c10/core/Allocator.h>
57
#include <c10/core/Device.h>
68
#include <c10/core/Storage.h>
@@ -11,45 +13,54 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
1113
namespace at {
1214

1315
struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
16+
#define FAIL_PRIVATEUSE1HOOKS_FUNC(func) \
17+
TORCH_CHECK_NOT_IMPLEMENTED( \
18+
false, \
19+
"You should register `PrivateUse1HooksInterface`", \
20+
"by `RegisterPrivateUse1HooksInterface` and implement `", \
21+
func, \
22+
"` at the same time for PrivateUse1.");
23+
1424
~PrivateUse1HooksInterface() override = default;
1525

1626
const at::Generator& getDefaultGenerator(
1727
c10::DeviceIndex device_index) const override {
18-
TORCH_CHECK_NOT_IMPLEMENTED(
19-
false,
20-
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`.");
28+
FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
29+
}
30+
31+
Generator getNewGenerator(
32+
[[maybe_unused]] DeviceIndex device_index = -1) const override {
33+
// TODO(FFFrog): Preserved for BC (backward compatibility); will be removed in the future.
34+
if (at::GetGeneratorPrivate().has_value())
35+
return at::GetGeneratorForPrivateuse1(device_index);
36+
37+
FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
2138
}
2239

2340
at::Device getDeviceFromPtr(void* data) const override {
24-
TORCH_CHECK_NOT_IMPLEMENTED(
25-
false,
26-
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`.");
41+
FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
2742
}
2843

2944
bool isPinnedPtr(const void* data) const override {
3045
return false;
3146
}
3247

3348
Allocator* getPinnedMemoryAllocator() const override {
34-
TORCH_CHECK(
35-
false,
36-
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getPinnedMemoryAllocator`.");
49+
FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
3750
}
3851

3952
bool hasPrimaryContext(DeviceIndex device_index) const override {
40-
TORCH_CHECK_NOT_IMPLEMENTED(
41-
false,
42-
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`.");
53+
FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
4354
}
4455

4556
void init() const override {}
4657
virtual void resizePrivateUse1Bytes(
4758
const c10::Storage& storage,
4859
size_t newsize) const {
49-
TORCH_CHECK_NOT_IMPLEMENTED(
50-
false,
51-
"You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `resizePrivateUse1Bytes`.");
60+
FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
5261
}
62+
63+
#undef FAIL_PRIVATEUSE1HOOKS_FUNC
5364
};
5465

5566
struct TORCH_API PrivateUse1HooksArgs {};
@@ -66,4 +77,5 @@ TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks();
6677
} // namespace detail
6778

6879
} // namespace at
80+
6981
C10_DIAGNOSTIC_POP()

aten/src/ATen/mps/MPSHooks.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ struct MPSHooks : public at::MPSHooksInterface {
2121
// MPSGeneratorImpl interface
2222
const Generator& getDefaultGenerator(
2323
DeviceIndex device_index = -1) const override;
24+
Generator getNewGenerator(DeviceIndex device_index = -1) const override;
2425

2526
// MPSStream interface
2627
void deviceSynchronize() const override;

aten/src/ATen/mps/MPSHooks.mm

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@
6969
return at::mps::detail::getDefaultMPSGenerator();
7070
}
7171

72+
Generator MPSHooks::getNewGenerator([[maybe_unused]] DeviceIndex device_index) const {
73+
return make_generator<at::MPSGeneratorImpl>();
74+
}
75+
7276
void MPSHooks::deviceSynchronize() const {
7377
at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
7478
}

test/cpp_extensions/open_registration_extension.cpp

Lines changed: 42 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
1-
#include <unordered_map>
2-
#include <c10/core/impl/alloc_cpu.h>
31
#include <c10/core/Allocator.h>
42
#include <c10/core/ScalarType.h>
3+
#include <c10/core/impl/DeviceGuardImplInterface.h>
4+
#include <c10/core/impl/alloc_cpu.h>
5+
#include <c10/macros/Macros.h>
56
#include <c10/util/ArrayRef.h>
67

78
#include <torch/csrc/Device.h>
89
#include <torch/csrc/jit/serialization/pickler.h>
9-
#include <c10/core/impl/DeviceGuardImplInterface.h>
10-
#include <c10/macros/Macros.h>
1110
#include <torch/extension.h>
1211

13-
#include <ATen/native/cpu/Loops.h>
14-
#include <ATen/native/quantized/AffineQuantizer.h>
12+
#include <ATen/EmptyTensor.h>
13+
#include <ATen/detail/PrivateUse1HooksInterface.h>
14+
#include <ATen/native/CPUFallback.h>
1515
#include <ATen/native/DispatchStub.h>
1616
#include <ATen/native/Resize.h>
1717
#include <ATen/native/UnaryOps.h>
18-
#include <ATen/native/CPUFallback.h>
18+
#include <ATen/native/cpu/Loops.h>
19+
#include <ATen/native/quantized/AffineQuantizer.h>
20+
#include <ATen/native/transformers/attention.h>
21+
#include <ATen/native/transformers/sdp_utils_cpp.h>
1922
#include <ATen/ops/abs_native.h>
20-
#include <ATen/EmptyTensor.h>
21-
#include <ATen/core/GeneratorForPrivateuseone.h>
22-
#include <ATen/detail/PrivateUse1HooksInterface.h>
2323
#include <ATen/ops/view.h>
24-
#include <ATen/native/transformers/sdp_utils_cpp.h>
25-
#include <ATen/native/transformers/attention.h>
24+
25+
#include <unordered_map>
2626

2727
static uint64_t add_counter = 0;
2828
static uint64_t last_saved_value = 0;
@@ -551,8 +551,15 @@ bool custom_add_called() {
551551
return called;
552552
}
553553

554+
void set_custom_device_index(c10::DeviceIndex device_index) {
555+
custom_device_index = device_index;
556+
}
557+
558+
// a global flag used for dummy pin_memory of custom device
559+
bool custom_pinned_flag = false;
560+
554561
class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
555-
public:
562+
public:
556563
// Constructors
557564
PrivateGeneratorImpl(c10::DeviceIndex device_index) {
558565
device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
@@ -561,45 +568,33 @@ class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
561568
~PrivateGeneratorImpl() override = default;
562569
};
563570

564-
// this is used to register generator
565-
at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
566-
return at::make_generator<PrivateGeneratorImpl>(device_index);
567-
}
568-
569-
void register_generator_first() {
570-
REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
571-
}
571+
struct FooHooksArgs : public at::PrivateUse1HooksArgs {};
572572

573-
void register_generator_second() {
574-
REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
575-
}
573+
struct FooHooksInterface : public at::PrivateUse1HooksInterface {
574+
FooHooksInterface(FooHooksArgs) {}
575+
~FooHooksInterface() override = default;
576576

577-
void set_custom_device_index(c10::DeviceIndex device_index) {
578-
custom_device_index = device_index;
579-
}
577+
const at::Generator& getDefaultGenerator(
578+
c10::DeviceIndex device_index) const override {
579+
static auto device_gen = at::make_generator<PrivateGeneratorImpl>(device_index);
580+
return device_gen;
581+
}
580582

581-
// a global flag used for dummy pin_memory of custom device
582-
bool custom_pinned_flag = false;
583+
at::Generator getNewGenerator(c10::DeviceIndex device_index) const {
584+
return at::make_generator<PrivateGeneratorImpl>(device_index);
585+
}
583586

584-
struct FooHooksArgs : public at::PrivateUse1HooksArgs {};
587+
// this is a simple implementation, custom_pinned_flag will be set as true
588+
// once tensor.pin_memory() is called. And then tensor.is_pinned()
589+
// always return true no matter what tensor it's called on.
590+
bool isPinnedPtr(const void* data) const override {
591+
return custom_pinned_flag;
592+
}
585593

586-
struct FooHooksInterface : public at::PrivateUse1HooksInterface {
587-
FooHooksInterface(FooHooksArgs) {}
588-
~FooHooksInterface() override = default;
589-
const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override {
590-
static auto device_gen = make_generator_privateuse1(device_index);
591-
return device_gen;
592-
}
593-
// this is a simple implementation, custom_pinned_flag will be set as true
594-
// once tensor.pin_memory() is called. And then tensor.is_pinned()
595-
// always return true no matter what tensor it's called on.
596-
bool isPinnedPtr(const void* data) const override {
597-
return custom_pinned_flag;
598-
}
599-
c10::Allocator* getPinnedMemoryAllocator() const override {
600-
custom_pinned_flag = true;
601-
return c10::GetCPUAllocator();
602-
}
594+
c10::Allocator* getPinnedMemoryAllocator() const override {
595+
custom_pinned_flag = true;
596+
return c10::GetCPUAllocator();
597+
}
603598
};
604599

605600
TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs);
@@ -682,8 +677,6 @@ at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
682677
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
683678
m.def("custom_device", &get_custom_device, "get custom device object");
684679
m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
685-
m.def("register_generator_first", &register_generator_first, "register generator for custom device firstly");
686-
m.def("register_generator_second", &register_generator_second, "register generator for custom device secondly");
687680
m.def("set_custom_device_index", &set_custom_device_index, "set custom device index");
688681
m.def("custom_storage_registry", &custom_storage_registry, "set custom storageImpl creat method");
689682
m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom abs function was called");

0 commit comments

Comments
 (0)