@@ -261,11 +261,11 @@ function lazytensor_enable_cache(; maxsize::Int = -1, maxrelsize::Real = 0.0)
261261 return
262262end
263263
264- function _tp_matmul_first! (result:: Base.ReshapedArray , a:: AbstractMatrix , b:: Base.ReshapedArray , α:: Number , β:: Number )
265- d_first = size (b, 1 )
264+ function _tp_matmul_first! (result, a:: AbstractMatrix , b, α:: Number , β:: Number )
265+ d_first = size (a, 2 )
266266 d_rest = length (b)÷ d_first
267- bp = b . parent
268- rp = result . parent
267+ bp = parent (b)
268+ rp = parent (result)
269269 @uviews bp rp begin # avoid allocations on reshape
270270 br = reshape (bp, (d_first, d_rest))
271271 result_r = reshape (rp, (size (a, 1 ), d_rest))
@@ -274,11 +274,11 @@ function _tp_matmul_first!(result::Base.ReshapedArray, a::AbstractMatrix, b::Bas
274274 result
275275end
276276
277- function _tp_matmul_last! (result:: Base.ReshapedArray , a:: AbstractMatrix , b:: Base.ReshapedArray , α:: Number , β:: Number )
278- d_last = size (b, ndims (b) )
277+ function _tp_matmul_last! (result, a:: AbstractMatrix , b, α:: Number , β:: Number )
278+ d_last = size (a, 2 )
279279 d_rest = length (b)÷ d_last
280- bp = b . parent
281- rp = result . parent
280+ bp = parent (b)
281+ rp = parent (result)
282282 @uviews a bp rp begin # avoid allocations on reshape
283283 br = reshape (bp, (d_rest, d_last))
284284 result_r = reshape (rp, (d_rest, size (a, 1 )))
@@ -287,7 +287,7 @@ function _tp_matmul_last!(result::Base.ReshapedArray, a::AbstractMatrix, b::Base
287287 result
288288end
289289
290- function _tp_matmul_get_tmp (:: Type{T} , shp:: NTuple{N,Int} , sym) where {T,N}
290+ function _tp_matmul_get_tmp (:: Type{T} , shp:: NTuple{N,Int} , sym, :: Array ) where {T,N}
291291 len = prod (shp)
292292 use_cache = lazytensor_use_cache ()
293293 key = (sym, taskid (), UInt (len), T)
@@ -301,7 +301,17 @@ function _tp_matmul_get_tmp(::Type{T}, shp::NTuple{N,Int}, sym) where {T,N}
301301 Base. ReshapedArray (tmp, shp, ())
302302end
303303
304- function _tp_matmul_mid! (result:: Base.ReshapedArray , a:: AbstractMatrix , loc:: Integer , b:: Base.ReshapedArray , α:: Number , β:: Number )
# Generic method of `_tp_matmul_get_tmp` for any `AbstractArray` `arr`.
#
# Returns a temporary array with element type `T` and shape `shp`, choosing how
# to obtain it from the kind of array `arr` is:
#   * if `parent(arr) === arr`, the result is `similar(arr, T, shp)`;
#   * otherwise the method calls itself again on `parent(arr)`, peeling one
#     wrapper layer per call.
# `sym` is not inspected here; it is only passed along to the next call.
304+ function _tp_matmul_get_tmp (:: Type{T} , shp:: NTuple{N,Int} , sym, arr:: AbstractArray ) where {T,N}
305+ if parent (arr) === arr
306+ # Fallback that does not use the cache. Not reached for `arr::Array`, which has its own, more specific method.
307+ return similar (arr, T, shp)
308+ end
309+ # Unwrap one layer of a wrapped array. If the chain ends in an `Array`, the `::Array` method (and its cache) is used.
310+ # If it ends in some other array type that is its own parent, the branch above calls `similar()`.
311+ _tp_matmul_get_tmp (T, shp, sym, parent (arr))
312+ end
313+
314+ function _tp_matmul_mid! (result, a:: AbstractMatrix , loc:: Integer , b, α:: Number , β:: Number )
305315 sz_b_1 = 1
306316 for i in 1 : loc- 1
307317 sz_b_1 *= size (b,i)
@@ -320,11 +330,11 @@ function _tp_matmul_mid!(result::Base.ReshapedArray, a::AbstractMatrix, loc::Int
320330 move_left = sz_b_1 < sz_b_3
321331 perm = move_left ? (2 ,1 ,3 ) : (1 ,3 ,2 )
322332
323- br_p = _tp_matmul_get_tmp (eltype (br), ((size (br, i) for i in perm). .. ,), :_tp_matmul_mid_in )
333+ br_p = _tp_matmul_get_tmp (eltype (br), ((size (br, i) for i in perm). .. ,), :_tp_matmul_mid_in , br )
324334 @strided permutedims! (br_p, br, perm)
325335 # permutedims!(br_p, br, perm)
326336
327- result_r_p = _tp_matmul_get_tmp (eltype (result_r), ((size (result_r, i) for i in perm). .. ,), :_tp_matmul_mid_out )
337+ result_r_p = _tp_matmul_get_tmp (eltype (result_r), ((size (result_r, i) for i in perm). .. ,), :_tp_matmul_mid_out , result_r )
328338 β == 0.0 || @strided permutedims! (result_r_p, result_r, perm)
329339 # β == 0.0 || permutedims!(result_r_p, result_r, perm)
330340
366376
# Obtain a temporary array for applying the matrix `op` along dimension `loc`
# of `arr`.
#
# The requested shape equals `size(arr)` except in dimension `loc`, where it is
# `size(op, 1)`. The requested element type is `promote_type(T, S)`, i.e. the
# promotion of the element types of `op` and `arr`. `sym` is forwarded
# unchanged to `_tp_matmul_get_tmp`.
#
# NOTE(review): this text appears to be a diff rendering; the `-` line looks
# like the removed call and the `+` line its replacement, which additionally
# passes `arr` as a trailing argument. Confirm against the actual file.
367377function _tp_sum_get_tmp (op:: AbstractMatrix{T} , loc:: Integer , arr:: AbstractArray{S,N} , sym) where {T,S,N}
# Build the N-dimensional shape: `size(op, 1)` at position `loc`, `size(arr, i)` elsewhere.
368378 shp = ntuple (i -> i == loc ? size (op,1 ) : size (arr,i), N)
369- _tp_matmul_get_tmp (promote_type (T,S), shp, sym)
379+ _tp_matmul_get_tmp (promote_type (T,S), shp, sym, arr )
370380end
371381
372382# Apply a tensor product of operators to a vector.
@@ -434,7 +444,7 @@ Base.size(A::_SimpleIsometry, i) = A.shape[i]
434444
# Obtain a temporary array for applying the `_SimpleIsometry` `op` along
# dimension `loc` of `arr`.
#
# The requested shape equals `size(arr)` except in dimension `loc`, where it is
# `size(op, 1)`. The requested element type is `S`, the element type of `arr`;
# unlike the `AbstractMatrix` method, no promotion with `op` is performed here.
# `sym` is forwarded unchanged to `_tp_matmul_get_tmp`.
#
# NOTE(review): this text appears to be a diff rendering; the `-` line looks
# like the removed call and the `+` line its replacement, which additionally
# passes `arr` as a trailing argument. Confirm against the actual file.
435445function _tp_sum_get_tmp (op:: _SimpleIsometry , loc:: Integer , arr:: AbstractArray{S,N} , sym) where {S,N}
# Build the N-dimensional shape: `size(op, 1)` at position `loc`, `size(arr, i)` elsewhere.
436446 shp = ntuple (i -> i == loc ? size (op,1 ) : size (arr,i), N)
437- _tp_matmul_get_tmp (S, shp, sym)
447+ _tp_matmul_get_tmp (S, shp, sym, arr )
438448end
439449
440450function _tp_matmul! (result, a:: _SimpleIsometry , loc:: Integer , b, α:: Number , β:: Number )
0 commit comments