We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4ffe360 commit 9e64d01Copy full SHA for 9e64d01
src/MatrixProductStates/derivatives.jl
@@ -66,7 +66,8 @@ Return also the loglikelihood, which is a byproduct of the computation.
66
function grad_loglikelihood_two_site(p::MPS, k::Integer, X;
67
prodA_left = [precompute_left_environments(p.ψ, x) for x in X],
68
prodA_right = [precompute_right_environments(p.ψ, x) for x in X],
69
- Aᵏᵏ⁺¹ =_merge_tensors(p[k], p[k+1]))
+ Aᵏᵏ⁺¹ =_merge_tensors(p[k], p[k+1]),
70
+ weights = ones(length(X)))
71
72
Zprime, Z = grad_normalization_two_site_canonical(p, k; Aᵏᵏ⁺¹)
73
ll = -log(Z)
@@ -78,8 +79,8 @@ function grad_loglikelihood_two_site(p::MPS, k::Integer, X;
78
79
gr, val = grad_evaluate_two_site(p.ψ, k, x;
80
Ax_left = prodA_left[n][k-1], Ax_right = prodA_right[n][k+2], Aᵏᵏ⁺¹
81
)
- gA[:,:,x[k]...,x[k+1]...] .+= 2/T * gr / val
82
- ll += 1/T * log(abs2(val))
+ gA[:,:,x[k]...,x[k+1]...] .+= 2/T * gr / val * weights[n]
83
+ ll += 1/T * log(abs2(val)) * weights[n]
84
end
85
return gA, ll
86
test/mps.jl
@@ -277,8 +277,9 @@ end
277
X = [sample(p)[1] for _ in 1:10^2]
278
q = rand_mps(ComplexF64, 2, length(p), 2,2)
279
ll = loglikelihood(q, X)
280
+ weights = ones(length(X))
281
two_site_dmrg!(q, X, 1;
- η=1e-4, ndesc=10, svd_trunc=TruncBond(5))
282
+ η=1e-4, ndesc=10, svd_trunc=TruncBond(5), weights)
283
@test loglikelihood(q, X) > ll
284
285
0 commit comments