Skip to content

Commit e5fee97

Browse files
committed
[functorch] Fix index.Tensor, index_put batching rules (pytorch/functorch#862)
Fixes pytorch/functorch#859 Start reading at `NOTE: [advanced indexing (index.Tensor) batch rule]` in the code for details. This PR rewrites the index.Tensor and index_put batching rules. The TL;DR is: - advanced indexing has different behavior depending on if the "advanced indices are adjacent": https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing - we have to take this into account in our batching rules, because index.Tensor and index_put handle these internally. Test Plan - I added new test cases for getitem and aten.ops.index_put via OpInfo testing. Future - primtorch should have a sane decomposition that we can use - We haven't fixed the index_put_ batching rule yet. TODO later... - Upstream our test cases (see next section) into pytorch/pytorch
1 parent b14de0c commit e5fee97

File tree

4 files changed

+383
-71
lines changed

4 files changed

+383
-71
lines changed

0 commit comments

Comments
 (0)