@@ -210,3 +210,119 @@ TEST(IndexSubset, FindPrevious) {
210
210
EXPECT_EQ (indices1->findPrevious (/* endIndex*/ 1 ), 0 );
211
211
EXPECT_EQ (indices1->findPrevious (/* endIndex*/ 0 ), -1 );
212
212
}
213
+
214
+ TEST (IndexSubset, Lowering) {
215
+ TestContext testCtx;
216
+ auto &C = testCtx.Ctx ;
217
+ // ((T, T)) -> ()
218
+ EXPECT_EQ (
219
+ autodiff::getLoweredParameterIndices (
220
+ IndexSubset::get (C, 1 , {0 }),
221
+ FunctionType::get ({
222
+ FunctionType::Param (
223
+ TupleType::get ({C.TheAnyType , C.TheAnyType }, C))},
224
+ C.TheEmptyTupleType )),
225
+ IndexSubset::get (C, 2 , {0 , 1 }));
226
+ // ((), (T, T)) -> ()
227
+ EXPECT_EQ (
228
+ autodiff::getLoweredParameterIndices (
229
+ IndexSubset::get (C, 2 , {1 }),
230
+ FunctionType::get ({
231
+ FunctionType::Param (C.TheEmptyTupleType ),
232
+ FunctionType::Param (
233
+ TupleType::get ({C.TheAnyType , C.TheAnyType }, C))},
234
+ C.TheEmptyTupleType )),
235
+ IndexSubset::get (C, 2 , {0 , 1 }));
236
+ // (T, (T, T)) -> ()
237
+ EXPECT_EQ (
238
+ autodiff::getLoweredParameterIndices (
239
+ IndexSubset::get (C, 2 , {1 }),
240
+ FunctionType::get ({
241
+ FunctionType::Param (C.TheAnyType ),
242
+ FunctionType::Param (
243
+ TupleType::get ({C.TheAnyType , C.TheAnyType }, C))},
244
+ C.TheEmptyTupleType )),
245
+ IndexSubset::get (C, 3 , {1 , 2 }));
246
+ // (T, (T, T)) -> ()
247
+ EXPECT_EQ (
248
+ autodiff::getLoweredParameterIndices (
249
+ IndexSubset::get (C, 2 , {0 , 1 }),
250
+ FunctionType::get ({
251
+ FunctionType::Param (C.TheAnyType ),
252
+ FunctionType::Param (
253
+ TupleType::get ({C.TheAnyType , C.TheAnyType }, C))},
254
+ C.TheEmptyTupleType )),
255
+ IndexSubset::get (C, 3 , {0 , 1 , 2 }));
256
+ // (T, ((T, T)), (T, T), T) -> ()
257
+ EXPECT_EQ (
258
+ autodiff::getLoweredParameterIndices (
259
+ IndexSubset::get (C, 4 , {0 , 1 , 3 }),
260
+ FunctionType::get ({
261
+ FunctionType::Param (C.TheAnyType ),
262
+ FunctionType::Param (
263
+ TupleType::get ({
264
+ TupleType::get ({C.TheAnyType , C.TheAnyType }, C)}, C)),
265
+ FunctionType::Param (
266
+ TupleType::get ({C.TheAnyType , C.TheAnyType }, C)),
267
+ FunctionType::Param (C.TheAnyType )},
268
+ C.TheEmptyTupleType )),
269
+ IndexSubset::get (C, 6 , {0 , 1 , 2 , 5 }));
270
+ // Method (T) -> ((T, T), (T, T), T) -> ()
271
+ // TODO(TF-874): Fix this unit test.
272
+ // The current actual result is:
273
+ // `(autodiff_index_subset capacity=6 indices=(0, 1, 4))`.
274
+ #if 0
275
+ EXPECT_EQ(
276
+ autodiff::getLoweredParameterIndices(
277
+ IndexSubset::get(C, 4, {0, 1, 3}),
278
+ FunctionType::get(
279
+ {FunctionType::Param(C.TheAnyType)},
280
+ FunctionType::get({
281
+ FunctionType::Param(
282
+ TupleType::get({C.TheAnyType, C.TheAnyType}, C)),
283
+ FunctionType::Param(
284
+ TupleType::get({C.TheAnyType, C.TheAnyType}, C)),
285
+ FunctionType::Param(C.TheAnyType)},
286
+ C.TheEmptyTupleType)->withExtInfo(
287
+ FunctionType::ExtInfo().withSILRepresentation(
288
+ SILFunctionTypeRepresentation::Method)))),
289
+ IndexSubset::get(C, 6, {0, 1, 4, 5}));
290
+ #endif
291
+ }
292
+
293
+ TEST (IndexSubset, GetSubsetParameterTypes) {
294
+ TestContext testCtx;
295
+ auto &C = testCtx.Ctx ;
296
+ // (T, T) -> ()
297
+ {
298
+ SmallVector<AnyFunctionType::Param, 8 > params;
299
+ auto *functionType = FunctionType::get ({FunctionType::Param (C.TheAnyType ),
300
+ FunctionType::Param (C.TheAnyType )},
301
+ C.TheEmptyTupleType );
302
+ functionType->getSubsetParameters (IndexSubset::get (C, 1 , {0 }), params);
303
+ AnyFunctionType::Param expected[] = {AnyFunctionType::Param (C.TheAnyType )};
304
+ EXPECT_TRUE (std::equal (params.begin (), params.end (), expected,
305
+ [](auto param1, auto param2) {
306
+ return param1.getPlainType ()->isEqual (param2.getPlainType ());
307
+ }));
308
+ }
309
+ // (T) -> (T, T) -> ()
310
+ {
311
+ SmallVector<AnyFunctionType::Param, 8 > params;
312
+ auto *functionType =
313
+ FunctionType::get ({FunctionType::Param (C.TheIEEE16Type )},
314
+ FunctionType::get ({FunctionType::Param (C.TheAnyType ),
315
+ FunctionType::Param (C.TheAnyType )},
316
+ C.TheEmptyTupleType ));
317
+ functionType->getSubsetParameters (IndexSubset::get (C, 3 , {0 , 1 , 2 }),
318
+ params);
319
+ AnyFunctionType::Param expected[] = {
320
+ AnyFunctionType::Param (C.TheIEEE16Type ),
321
+ AnyFunctionType::Param (C.TheAnyType ),
322
+ AnyFunctionType::Param (C.TheAnyType )};
323
+ EXPECT_TRUE (std::equal (
324
+ params.begin (), params.end (), expected, [](auto param1, auto param2) {
325
+ return param1.getPlainType ()->isEqual (param2.getPlainType ());
326
+ }));
327
+ }
328
+ }
0 commit comments