Skip to content

Commit 0b380d1

Browse files
fix jac with vmap
1 parent fb664ed commit 0b380d1

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
- Fix customized jax eigh operator by noting the return is a namedtuple
1818

19+
- Fix some issues in `jacfwd` and `jacrev` when integrated with vmap
20+
1921
## 1.0.2
2022

2123
### Added

tensorcircuit/backends/abstract_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1548,7 +1548,7 @@ def wrapper(*args: Any, **kws: Any) -> Any:
15481548
if i == argnum
15491549
else self.reshape(
15501550
self.zeros(
1551-
[self.sizen(arg), self.sizen(arg)],
1551+
[self.sizen(args[argnum]), self.sizen(arg)],
15521552
dtype=arg.dtype,
15531553
),
15541554
[-1] + list(self.shape_tuple(arg)),
@@ -1636,6 +1636,7 @@ def _first(x: Sequence[Any]) -> Any:
16361636
),
16371637
jj,
16381638
)
1639+
jj = [jji for ind, jji in enumerate(jj) if ind in argnums]
16391640
if len(jj) == 1:
16401641
jj = jj[0]
16411642
jjs.append(jj)

0 commit comments

Comments
 (0)