@@ -395,37 +395,37 @@ end
395395
396396A struct to represent the diffusion operator. This is used to perform the diffusion process on N different Wiener processes.
397397=#
398- struct DiffusionOperator{T,OpType<: Tuple{Vararg{AbstractSciMLOperator}} } <: AbstractSciMLOperator{T}
398+ struct DiffusionOperator{T,OpType<: Tuple{Vararg{AbstractSciMLOperator}} }
399399 ops:: OpType
400400 function DiffusionOperator (ops:: OpType ) where {OpType}
401401 T = mapreduce (eltype, promote_type, ops)
402402 return new {T,OpType} (ops)
403403 end
404404end
405405
406- @generated function update_coefficients! (L:: DiffusionOperator , u , p, t)
406+ @generated function (L:: DiffusionOperator )(w, v , p, t)
407407 ops_types = L. parameters[2 ]. parameters
408408 N = length (ops_types)
409- return quote
409+ quote
410+ M = length (v)
411+ S = (size (w, 1 ), size (w, 2 )) # This supports also `w` as a `Vector`
412+ (S[1 ] == M && S[2 ] == $ N) || throw (DimensionMismatch (" The size of the output vector is incorrect." ))
410413 Base. @nexprs $ N i -> begin
411- update_coefficients! (L. ops[i], u, p, t)
414+ # update_coefficients!(L.ops[i], v, p, t)
415+ # mul!(@view(w[:, i]), L.ops[i], v)
416+ op = L. ops[i]
417+ op (@view (w[:, i]), v, v, p, t)
412418 end
413-
414- nothing
419+ return w
415420 end
416421end
417422
418- @generated function LinearAlgebra. mul! (v:: AbstractVecOrMat , L:: DiffusionOperator , u:: AbstractVecOrMat )
419- ops_types = L. parameters[2 ]. parameters
420- N = length (ops_types)
421- quote
422- M = length (u)
423- S = (size (v, 1 ), size (v, 2 )) # This supports also `v` as a `Vector`
424- (S[1 ] == M && S[2 ] == $ N) || throw (DimensionMismatch (" The size of the output vector is incorrect." ))
425- Base. @nexprs $ N i -> begin
426- mul! (@view (v[:, i]), L. ops[i], u)
427- end
428- return v
423+ # TODO : Remove when https://github.com/SciML/StochasticDiffEq.jl/issues/615 is fixed.
424+ function (f:: SDEFunction )(du, u, p, t)
425+ if f. f isa AbstractSciMLOperator
426+ f. f (du, u, u, p, t)
427+ else
428+ f. f (du, u, p, t)
429429 end
430430end
431431
0 commit comments