1- #include < unordered_map>
2- #include < c10/core/impl/alloc_cpu.h>
31#include < c10/core/Allocator.h>
42#include < c10/core/ScalarType.h>
3+ #include < c10/core/impl/DeviceGuardImplInterface.h>
4+ #include < c10/core/impl/alloc_cpu.h>
5+ #include < c10/macros/Macros.h>
56#include < c10/util/ArrayRef.h>
67
78#include < torch/csrc/Device.h>
89#include < torch/csrc/jit/serialization/pickler.h>
9- #include < c10/core/impl/DeviceGuardImplInterface.h>
10- #include < c10/macros/Macros.h>
1110#include < torch/extension.h>
1211
13- #include < ATen/native/cpu/Loops.h>
14- #include < ATen/native/quantized/AffineQuantizer.h>
12+ #include < ATen/EmptyTensor.h>
13+ #include < ATen/detail/PrivateUse1HooksInterface.h>
14+ #include < ATen/native/CPUFallback.h>
1515#include < ATen/native/DispatchStub.h>
1616#include < ATen/native/Resize.h>
1717#include < ATen/native/UnaryOps.h>
18- #include < ATen/native/CPUFallback.h>
18+ #include < ATen/native/cpu/Loops.h>
19+ #include < ATen/native/quantized/AffineQuantizer.h>
20+ #include < ATen/native/transformers/attention.h>
21+ #include < ATen/native/transformers/sdp_utils_cpp.h>
1922#include < ATen/ops/abs_native.h>
20- #include < ATen/EmptyTensor.h>
21- #include < ATen/core/GeneratorForPrivateuseone.h>
22- #include < ATen/detail/PrivateUse1HooksInterface.h>
2323#include < ATen/ops/view.h>
24- # include < ATen/native/transformers/sdp_utils_cpp.h >
25- #include < ATen/native/transformers/attention.h >
24+
25+ #include < unordered_map >
2626
2727static uint64_t add_counter = 0 ;
2828static uint64_t last_saved_value = 0 ;
@@ -551,8 +551,15 @@ bool custom_add_called() {
551551 return called;
552552}
553553
554+ void set_custom_device_index (c10::DeviceIndex device_index) {
555+ custom_device_index = device_index;
556+ }
557+
558+ // a global flag used for dummy pin_memory of custom device
559+ bool custom_pinned_flag = false ;
560+
554561class PrivateGeneratorImpl : public at ::CPUGeneratorImpl {
555- public:
562+ public:
556563 // Constructors
557564 PrivateGeneratorImpl (c10::DeviceIndex device_index) {
558565 device_ = c10::Device (c10::DeviceType::PrivateUse1, device_index);
@@ -561,45 +568,33 @@ class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
561568 ~PrivateGeneratorImpl () override = default ;
562569};
563570
564- // this is used to register generator
565- at::Generator make_generator_privateuse1 (c10::DeviceIndex device_index) {
566- return at::make_generator<PrivateGeneratorImpl>(device_index);
567- }
568-
569- void register_generator_first () {
570- REGISTER_GENERATOR_PRIVATEUSE1 (make_generator_privateuse1)
571- }
571+ struct FooHooksArgs : public at ::PrivateUse1HooksArgs {};
572572
573- void register_generator_second () {
574- REGISTER_GENERATOR_PRIVATEUSE1 (make_generator_privateuse1)
575- }
573+ struct FooHooksInterface : public at ::PrivateUse1HooksInterface {
574+ FooHooksInterface (FooHooksArgs) {}
575+ ~FooHooksInterface () override = default ;
576576
577- void set_custom_device_index (c10::DeviceIndex device_index) {
578- custom_device_index = device_index;
579- }
577+ const at::Generator& getDefaultGenerator (
578+ c10::DeviceIndex device_index) const override {
579+ static auto device_gen = at::make_generator<PrivateGeneratorImpl>(device_index);
580+ return device_gen;
581+ }
580582
581- // a global flag used for dummy pin_memory of custom device
582- bool custom_pinned_flag = false ;
583+ at::Generator getNewGenerator (c10::DeviceIndex device_index) const {
584+ return at::make_generator<PrivateGeneratorImpl>(device_index);
585+ }
583586
584- struct FooHooksArgs : public at ::PrivateUse1HooksArgs {};
587+ // this is a simple implementation, custom_pinned_flag will be set as true
588+ // once tensor.pin_memory() is called. And then tensor.is_pinned()
589+ // always return true no matter what tensor it's called on.
590+ bool isPinnedPtr (const void * data) const override {
591+ return custom_pinned_flag;
592+ }
585593
586- struct FooHooksInterface : public at ::PrivateUse1HooksInterface {
587- FooHooksInterface (FooHooksArgs) {}
588- ~FooHooksInterface () override = default ;
589- const at::Generator& getDefaultGenerator (c10::DeviceIndex device_index) const override {
590- static auto device_gen = make_generator_privateuse1 (device_index);
591- return device_gen;
592- }
593- // this is a simple implementation, custom_pinned_flag will be set as true
594- // once tensor.pin_memory() is called. And then tensor.is_pinned()
595- // always return true no matter what tensor it's called on.
596- bool isPinnedPtr (const void * data) const override {
597- return custom_pinned_flag;
598- }
599- c10::Allocator* getPinnedMemoryAllocator () const override {
600- custom_pinned_flag = true ;
601- return c10::GetCPUAllocator ();
602- }
594+ c10::Allocator* getPinnedMemoryAllocator () const override {
595+ custom_pinned_flag = true ;
596+ return c10::GetCPUAllocator ();
597+ }
603598};
604599
605600TORCH_DECLARE_REGISTRY (PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs);
@@ -682,8 +677,6 @@ at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
682677PYBIND11_MODULE (TORCH_EXTENSION_NAME, m) {
683678 m.def (" custom_device" , &get_custom_device, " get custom device object" );
684679 m.def (" custom_add_called" , &custom_add_called, " check if our custom add function was called" );
685- m.def (" register_generator_first" , ®ister_generator_first, " register generator for custom device firstly" );
686- m.def (" register_generator_second" , ®ister_generator_second, " register generator for custom device secondly" );
687680 m.def (" set_custom_device_index" , &set_custom_device_index, " set custom device index" );
688681 m.def (" custom_storage_registry" , &custom_storage_registry, " set custom storageImpl creat method" );
689682 m.def (" custom_storageImpl_called" , &custom_storageImpl_called, " check if our custom abs function was called" );
0 commit comments