Skip to content

Commit 629b07e

Browse files
hanbinyooncopybara-github
authored andcommitted
Handle setting tensors of complex128 types. Add FFT test coverage.
PiperOrigin-RevId: 444399447
1 parent 5af72f8 commit 629b07e

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

lib/tensor/dense_host_tensor_kernels.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,25 @@ Error SyncSetDenseTensorWithConstantValues(
145145
return Error::success();
146146
}
147147

148+
template <>
149+
Error SyncSetDenseTensorWithConstantValues(
150+
MutableDHTArrayView<std::complex<double>> in,
151+
ArrayAttribute<std::complex<double>> values) {
152+
// In actuality, 'values' is an ArrayAttribute<double>. Treating it as an
153+
// ArrayAttribute<std::complex<double>> makes the copy easier, and fits the
154+
// template specialization nicely.
155+
const int total_value_count = in.NumElements() * 2; // real and imaginary
156+
if (total_value_count != values.size()) {
157+
return MakeStringError(
158+
"Incorrect number of real and imaginary values for the complex "
159+
"tensor: ",
160+
values.size(), ", but expected ", total_value_count);
161+
}
162+
std::copy(values.data().begin(), values.data().begin() + in.NumElements(),
163+
in.Elements().begin());
164+
return Error::success();
165+
}
166+
148167
template <typename T>
149168
static void SetDenseTensorWithConstantValues(
150169
ArgumentView<MutableDHTArrayView<T>> in, Argument<Chain> chain_in,

0 commit comments

Comments
 (0)