Skip to content

Commit 1420b81

Browse files
author
ematejska
authored
[Autodiff upstream] Upstream autodiff unittests (swiftlang#30709)
* Adding @transpose attr deserialization support * Turning on the transpose serialization test * Adding the autodiff unittests
1 parent a460fb5 commit 1420b81

File tree

1 file changed

+116
-0
lines changed

1 file changed

+116
-0
lines changed

unittests/AST/IndexSubsetTests.cpp

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,119 @@ TEST(IndexSubset, FindPrevious) {
210210
EXPECT_EQ(indices1->findPrevious(/*endIndex*/ 1), 0);
211211
EXPECT_EQ(indices1->findPrevious(/*endIndex*/ 0), -1);
212212
}
213+
214+
TEST(IndexSubset, Lowering) {
215+
TestContext testCtx;
216+
auto &C = testCtx.Ctx;
217+
// ((T, T)) -> ()
218+
EXPECT_EQ(
219+
autodiff::getLoweredParameterIndices(
220+
IndexSubset::get(C, 1, {0}),
221+
FunctionType::get({
222+
FunctionType::Param(
223+
TupleType::get({C.TheAnyType, C.TheAnyType}, C))},
224+
C.TheEmptyTupleType)),
225+
IndexSubset::get(C, 2, {0, 1}));
226+
// ((), (T, T)) -> ()
227+
EXPECT_EQ(
228+
autodiff::getLoweredParameterIndices(
229+
IndexSubset::get(C, 2, {1}),
230+
FunctionType::get({
231+
FunctionType::Param(C.TheEmptyTupleType),
232+
FunctionType::Param(
233+
TupleType::get({C.TheAnyType, C.TheAnyType}, C))},
234+
C.TheEmptyTupleType)),
235+
IndexSubset::get(C, 2, {0, 1}));
236+
// (T, (T, T)) -> ()
237+
EXPECT_EQ(
238+
autodiff::getLoweredParameterIndices(
239+
IndexSubset::get(C, 2, {1}),
240+
FunctionType::get({
241+
FunctionType::Param(C.TheAnyType),
242+
FunctionType::Param(
243+
TupleType::get({C.TheAnyType, C.TheAnyType}, C))},
244+
C.TheEmptyTupleType)),
245+
IndexSubset::get(C, 3, {1, 2}));
246+
// (T, (T, T)) -> ()
247+
EXPECT_EQ(
248+
autodiff::getLoweredParameterIndices(
249+
IndexSubset::get(C, 2, {0, 1}),
250+
FunctionType::get({
251+
FunctionType::Param(C.TheAnyType),
252+
FunctionType::Param(
253+
TupleType::get({C.TheAnyType, C.TheAnyType}, C))},
254+
C.TheEmptyTupleType)),
255+
IndexSubset::get(C, 3, {0, 1, 2}));
256+
// (T, ((T, T)), (T, T), T) -> ()
257+
EXPECT_EQ(
258+
autodiff::getLoweredParameterIndices(
259+
IndexSubset::get(C, 4, {0, 1, 3}),
260+
FunctionType::get({
261+
FunctionType::Param(C.TheAnyType),
262+
FunctionType::Param(
263+
TupleType::get({
264+
TupleType::get({C.TheAnyType, C.TheAnyType}, C)}, C)),
265+
FunctionType::Param(
266+
TupleType::get({C.TheAnyType, C.TheAnyType}, C)),
267+
FunctionType::Param(C.TheAnyType)},
268+
C.TheEmptyTupleType)),
269+
IndexSubset::get(C, 6, {0, 1, 2, 5}));
270+
// Method (T) -> ((T, T), (T, T), T) -> ()
271+
// TODO(TF-874): Fix this unit test.
272+
// The current actual result is:
273+
// `(autodiff_index_subset capacity=6 indices=(0, 1, 4))`.
274+
#if 0
275+
EXPECT_EQ(
276+
autodiff::getLoweredParameterIndices(
277+
IndexSubset::get(C, 4, {0, 1, 3}),
278+
FunctionType::get(
279+
{FunctionType::Param(C.TheAnyType)},
280+
FunctionType::get({
281+
FunctionType::Param(
282+
TupleType::get({C.TheAnyType, C.TheAnyType}, C)),
283+
FunctionType::Param(
284+
TupleType::get({C.TheAnyType, C.TheAnyType}, C)),
285+
FunctionType::Param(C.TheAnyType)},
286+
C.TheEmptyTupleType)->withExtInfo(
287+
FunctionType::ExtInfo().withSILRepresentation(
288+
SILFunctionTypeRepresentation::Method)))),
289+
IndexSubset::get(C, 6, {0, 1, 4, 5}));
290+
#endif
291+
}
292+
293+
TEST(IndexSubset, GetSubsetParameterTypes) {
294+
TestContext testCtx;
295+
auto &C = testCtx.Ctx;
296+
// (T, T) -> ()
297+
{
298+
SmallVector<AnyFunctionType::Param, 8> params;
299+
auto *functionType = FunctionType::get({FunctionType::Param(C.TheAnyType),
300+
FunctionType::Param(C.TheAnyType)},
301+
C.TheEmptyTupleType);
302+
functionType->getSubsetParameters(IndexSubset::get(C, 1, {0}), params);
303+
AnyFunctionType::Param expected[] = {AnyFunctionType::Param(C.TheAnyType)};
304+
EXPECT_TRUE(std::equal(params.begin(), params.end(), expected,
305+
[](auto param1, auto param2) {
306+
return param1.getPlainType()->isEqual(param2.getPlainType());
307+
}));
308+
}
309+
// (T) -> (T, T) -> ()
310+
{
311+
SmallVector<AnyFunctionType::Param, 8> params;
312+
auto *functionType =
313+
FunctionType::get({FunctionType::Param(C.TheIEEE16Type)},
314+
FunctionType::get({FunctionType::Param(C.TheAnyType),
315+
FunctionType::Param(C.TheAnyType)},
316+
C.TheEmptyTupleType));
317+
functionType->getSubsetParameters(IndexSubset::get(C, 3, {0, 1, 2}),
318+
params);
319+
AnyFunctionType::Param expected[] = {
320+
AnyFunctionType::Param(C.TheIEEE16Type),
321+
AnyFunctionType::Param(C.TheAnyType),
322+
AnyFunctionType::Param(C.TheAnyType)};
323+
EXPECT_TRUE(std::equal(
324+
params.begin(), params.end(), expected, [](auto param1, auto param2) {
325+
return param1.getPlainType()->isEqual(param2.getPlainType());
326+
}));
327+
}
328+
}

0 commit comments

Comments
 (0)