Skip to content

Commit d4f03be

Browse files
authored
Use similar in creation of DiffResults buffer (#95)
* use similar in creation of diffresults buffer * use the input to make the DiffResults buffer * remove the logdensity argument from _diffresults_buffer * removed type-annotation * patch version bump
1 parent a6a5707 commit d4f03be

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LogDensityProblems"
22
uuid = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
33
authors = ["Tamas K. Papp <[email protected]>"]
4-
version = "1.0.2"
4+
version = "1.0.3"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/AD_ForwardDiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545

4646
function logdensity_and_gradient(fℓ::ForwardDiffLogDensity, x::AbstractVector)
4747
@unpack ℓ, gradientconfig = fℓ
48-
buffer = _diffresults_buffer(ℓ, x)
48+
buffer = _diffresults_buffer(x)
4949
result = ForwardDiff.gradient!(buffer, Base.Fix1(logdensity, ℓ), x, gradientconfig)
5050
_diffresults_extract(result)
5151
end

src/AD_ReverseDiff.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ end
5050

5151
function logdensity_and_gradient(∇ℓ::ReverseDiffLogDensity, x::AbstractVector)
5252
@unpack ℓ, compiledtape = ∇ℓ
53-
buffer = _diffresults_buffer(ℓ, x)
53+
buffer = _diffresults_buffer(x)
5454
if compiledtape === nothing
5555
result = ReverseDiff.gradient!(buffer, Base.Fix1(logdensity, ℓ), x)
5656
else

src/DiffResults_helpers.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ $(SIGNATURES)
1111
Allocate a DiffResults buffer for a gradient, taking the element type of `x` into account
1212
(heuristically).
1313
"""
14-
function _diffresults_buffer(ℓ, x)
14+
function _diffresults_buffer(x)
1515
T = eltype(x)
1616
S = T <: Real ? float(Real) : Float64 # heuristic
17-
DiffResults.MutableDiffResult(zero(S), (Vector{S}(undef, dimension(ℓ)), ))
17+
DiffResults.MutableDiffResult(zero(S), (similar(x, S), ))
1818
end
1919

2020
"""
@@ -25,5 +25,5 @@ constructed with [`diffresults_buffer`](@ref). Gradient is not copied as caller
2525
vector.
2626
"""
2727
function _diffresults_extract(diffresult::DiffResults.DiffResult)
28-
DiffResults.value(diffresult)::Real, DiffResults.gradient(diffresult)
28+
DiffResults.value(diffresult), DiffResults.gradient(diffresult)
2929
end

0 commit comments

Comments
 (0)