Skip to content

Commit 5702b04

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Allow data casting along with cloning. (pytorch#15510)
Summary: . Differential Revision: D86070966
1 parent a11d555 commit 5702b04

File tree

3 files changed

+172
-21
lines changed

3 files changed

+172
-21
lines changed

extension/tensor/tensor_ptr.cpp

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,9 @@ TensorPtr make_tensor_ptr(
164164
[data = std::move(data)](void*) {});
165165
}
166166

167-
TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor) {
167+
TensorPtr clone_tensor_ptr(
168+
const executorch::aten::Tensor& tensor,
169+
executorch::aten::ScalarType type) {
168170
std::vector<executorch::aten::SizesType> sizes(
169171
tensor.sizes().begin(), tensor.sizes().end());
170172
std::vector<executorch::aten::DimOrderType> dim_order{
@@ -178,23 +180,63 @@ TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor) {
178180
#ifndef USE_ATEN_LIB
179181
dynamism = tensor.shape_dynamism();
180182
#endif // USE_ATEN_LIB
181-
return tensor.const_data_ptr()
182-
? make_tensor_ptr(
183-
std::move(sizes),
184-
std::vector<uint8_t>(
185-
(uint8_t*)tensor.const_data_ptr(),
186-
(uint8_t*)tensor.const_data_ptr() + tensor.nbytes()),
187-
std::move(dim_order),
188-
std::move(strides),
189-
tensor.scalar_type(),
190-
dynamism)
191-
: make_tensor_ptr(
192-
std::move(sizes),
193-
nullptr,
194-
std::move(dim_order),
195-
std::move(strides),
196-
tensor.scalar_type(),
197-
dynamism);
183+
const auto* tensor_data = tensor.const_data_ptr();
184+
if (!tensor_data) {
185+
return make_tensor_ptr(
186+
std::move(sizes),
187+
nullptr,
188+
std::move(dim_order),
189+
std::move(strides),
190+
type,
191+
dynamism);
192+
}
193+
const auto tensor_type = tensor.scalar_type();
194+
if (tensor_type == type) {
195+
return make_tensor_ptr(
196+
std::move(sizes),
197+
std::vector<uint8_t>(
198+
(uint8_t*)tensor_data, (uint8_t*)tensor_data + tensor.nbytes()),
199+
std::move(dim_order),
200+
std::move(strides),
201+
tensor_type,
202+
dynamism);
203+
}
204+
ET_CHECK_MSG(
205+
runtime::canCast(tensor_type, type),
206+
"Cannot cast tensor type to desired type.");
207+
const auto tensor_numel = static_cast<size_t>(tensor.numel());
208+
std::vector<uint8_t> data(tensor_numel * aten::elementSize(type));
209+
210+
// Create a minimal context for error handling in ET_SWITCH
211+
struct {
212+
[[noreturn]] void fail(torch::executor::Error /* error */) {
213+
ET_CHECK_MSG(false, "Unsupported dtype in clone_tensor_ptr");
214+
}
215+
} ctx;
216+
217+
ET_SWITCH_REALHBBF16_AND_UINT_TYPES(
218+
tensor_type, ctx, "clone_tensor_ptr_from", CTYPE_FROM, [&] {
219+
const CTYPE_FROM* tensor_data_ptr =
220+
static_cast<const CTYPE_FROM*>(tensor_data);
221+
ET_SWITCH_REALHBBF16_AND_UINT_TYPES(
222+
type, ctx, "clone_tensor_ptr_to", CTYPE_TO, [&] {
223+
CTYPE_TO* data_ptr = reinterpret_cast<CTYPE_TO*>(data.data());
224+
std::transform(
225+
tensor_data_ptr,
226+
tensor_data_ptr + tensor_numel,
227+
data_ptr,
228+
[](const CTYPE_FROM& val) {
229+
return static_cast<CTYPE_TO>(val);
230+
});
231+
});
232+
});
233+
return make_tensor_ptr(
234+
std::move(sizes),
235+
std::move(data),
236+
std::move(dim_order),
237+
std::move(strides),
238+
type,
239+
dynamism);
198240
}
199241

200242
runtime::Error resize_tensor_ptr(

extension/tensor/tensor_ptr.h

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ inline TensorPtr make_tensor_ptr(
114114
ET_CHECK_MSG(
115115
runtime::canCast(deduced_type, type),
116116
"Cannot cast deduced type to specified type.");
117-
std::vector<uint8_t> casted_data(data.size() * runtime::elementSize(type));
117+
std::vector<uint8_t> casted_data(data.size() * aten::elementSize(type));
118118

119119
// Create a minimal context for error handling in ET_SWITCH
120120
struct {
@@ -408,6 +408,21 @@ inline TensorPtr make_tensor_ptr(
408408
[tensor_ptr](void*) {});
409409
}
410410

411+
/**
412+
* Creates a TensorPtr that manages a new Tensor with the same properties
413+
* as the given Tensor, but with a copy of the data owned by the returned
414+
* TensorPtr, or nullptr if the original data is null.
415+
*
416+
* @param tensor The Tensor to clone.
417+
* @param type The data type for the cloned tensor. The data will be cast
418+
* from the source tensor's type.
419+
* @return A new TensorPtr that manages a Tensor with the specified type
420+
* and copied/cast data.
421+
*/
422+
TensorPtr clone_tensor_ptr(
423+
const executorch::aten::Tensor& tensor,
424+
executorch::aten::ScalarType type);
425+
411426
/**
412427
* Creates a TensorPtr that manages a new Tensor with the same properties
413428
* as the given Tensor, but with a copy of the data owned by the returned
@@ -417,7 +432,25 @@ inline TensorPtr make_tensor_ptr(
417432
* @return A new TensorPtr that manages a Tensor with the same properties as the
418433
* original but with copied data.
419434
*/
420-
TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor);
435+
inline TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor) {
436+
return clone_tensor_ptr(tensor, tensor.scalar_type());
437+
}
438+
439+
/**
440+
* Creates a new TensorPtr by cloning the given TensorPtr, copying the
441+
* underlying data.
442+
*
443+
* @param tensor The TensorPtr to clone.
444+
* @param type The data type for the cloned tensor. The data will be cast
445+
* from the source tensor's type.
446+
* @return A new TensorPtr that manages a Tensor with the specified type
447+
* and copied/cast data.
448+
*/
449+
inline TensorPtr clone_tensor_ptr(
450+
const TensorPtr& tensor,
451+
executorch::aten::ScalarType type) {
452+
return clone_tensor_ptr(*tensor, type);
453+
}
421454

422455
/**
423456
* Creates a new TensorPtr by cloning the given TensorPtr, copying the
@@ -428,7 +461,7 @@ TensorPtr clone_tensor_ptr(const executorch::aten::Tensor& tensor);
428461
* original but with copied data.
429462
*/
430463
inline TensorPtr clone_tensor_ptr(const TensorPtr& tensor) {
431-
return clone_tensor_ptr(*tensor);
464+
return clone_tensor_ptr(*tensor, tensor->scalar_type());
432465
}
433466

434467
/**

extension/tensor/test/tensor_ptr_test.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,82 @@ TEST_F(TensorPtrTest, CloneTensorPtrFromExistingTensorInt32) {
571571
EXPECT_EQ(cloned_tensor->scalar_type(), executorch::aten::ScalarType::Int);
572572
}
573573

574+
TEST_F(TensorPtrTest, CloneTensorPtrCastInt32ToFloat) {
575+
std::vector<int32_t> data = {1, 2, 3, 4};
576+
auto tensor = make_tensor_ptr({2, 2}, std::move(data));
577+
auto cloned_tensor =
578+
clone_tensor_ptr(*tensor, executorch::aten::ScalarType::Float);
579+
580+
EXPECT_EQ(cloned_tensor->dim(), 2);
581+
EXPECT_EQ(cloned_tensor->size(0), 2);
582+
EXPECT_EQ(cloned_tensor->size(1), 2);
583+
EXPECT_EQ(cloned_tensor->scalar_type(), executorch::aten::ScalarType::Float);
584+
auto ptr = cloned_tensor->const_data_ptr<float>();
585+
EXPECT_FLOAT_EQ(ptr[0], 1.0f);
586+
EXPECT_FLOAT_EQ(ptr[1], 2.0f);
587+
EXPECT_FLOAT_EQ(ptr[2], 3.0f);
588+
EXPECT_FLOAT_EQ(ptr[3], 4.0f);
589+
}
590+
591+
TEST_F(TensorPtrTest, CloneTensorPtrCastFloatToBFloat16) {
592+
std::vector<float> data = {1.0f, 2.0f, 3.5f};
593+
auto tensor = make_tensor_ptr({3}, std::move(data));
594+
auto cloned_tensor =
595+
clone_tensor_ptr(*tensor, executorch::aten::ScalarType::BFloat16);
596+
597+
EXPECT_EQ(cloned_tensor->dim(), 1);
598+
EXPECT_EQ(cloned_tensor->size(0), 3);
599+
EXPECT_EQ(
600+
cloned_tensor->scalar_type(), executorch::aten::ScalarType::BFloat16);
601+
auto ptr = cloned_tensor->const_data_ptr<executorch::aten::BFloat16>();
602+
EXPECT_NEAR(static_cast<float>(ptr[0]), 1.0f, 0.01f);
603+
EXPECT_NEAR(static_cast<float>(ptr[1]), 2.0f, 0.01f);
604+
EXPECT_NEAR(static_cast<float>(ptr[2]), 3.5f, 0.01f);
605+
}
606+
607+
TEST_F(TensorPtrTest, CloneTensorPtrCastKeepsMetadata) {
608+
std::vector<uint8_t> data(
609+
6 * executorch::aten::elementSize(executorch::aten::ScalarType::Float));
610+
auto tensor = make_tensor_ptr({2, 3}, std::move(data));
611+
auto cloned_tensor =
612+
clone_tensor_ptr(*tensor, executorch::aten::ScalarType::Float);
613+
614+
EXPECT_EQ(cloned_tensor->dim(), 2);
615+
EXPECT_EQ(cloned_tensor->size(0), 2);
616+
EXPECT_EQ(cloned_tensor->size(1), 3);
617+
EXPECT_EQ(cloned_tensor->strides()[0], 3);
618+
EXPECT_EQ(cloned_tensor->strides()[1], 1);
619+
EXPECT_EQ(cloned_tensor->scalar_type(), executorch::aten::ScalarType::Float);
620+
}
621+
622+
TEST_F(TensorPtrTest, CloneTensorPtrCastNullData) {
623+
auto tensor = make_tensor_ptr(
624+
{2, 2},
625+
nullptr,
626+
{},
627+
{},
628+
executorch::aten::ScalarType::Float,
629+
executorch::aten::TensorShapeDynamism::DYNAMIC_BOUND);
630+
auto cloned_tensor =
631+
clone_tensor_ptr(*tensor, executorch::aten::ScalarType::Int);
632+
633+
EXPECT_EQ(cloned_tensor->dim(), 2);
634+
EXPECT_EQ(cloned_tensor->size(0), 2);
635+
EXPECT_EQ(cloned_tensor->size(1), 2);
636+
EXPECT_EQ(cloned_tensor->const_data_ptr(), nullptr);
637+
EXPECT_EQ(cloned_tensor->scalar_type(), executorch::aten::ScalarType::Int);
638+
}
639+
640+
TEST_F(TensorPtrTest, CloneTensorPtrCastInvalidExpectDeath) {
641+
std::vector<float> data = {1.0f, 2.0f};
642+
auto tensor = make_tensor_ptr({2}, std::move(data));
643+
ET_EXPECT_DEATH(
644+
{
645+
auto _ = clone_tensor_ptr(*tensor, executorch::aten::ScalarType::Int);
646+
},
647+
"");
648+
}
649+
574650
TEST_F(TensorPtrTest, MakeTensorPtrFromTensorPtrInt32) {
575651
std::vector<int32_t> data = {1, 2, 3, 4};
576652
auto tensor = make_tensor_ptr({2, 2}, data);

0 commit comments

Comments
 (0)