Skip to content

Commit 9e64d01

Browse files
committed
add option to reweight likelihood for datapoints
1 parent 4ffe360 commit 9e64d01

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

src/MatrixProductStates/derivatives.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ Return also the loglikelihood, which is a byproduct of the computation.
6666
function grad_loglikelihood_two_site(p::MPS, k::Integer, X;
6767
prodA_left = [precompute_left_environments(p.ψ, x) for x in X],
6868
prodA_right = [precompute_right_environments(p.ψ, x) for x in X],
69-
Aᵏᵏ⁺¹ =_merge_tensors(p[k], p[k+1]))
69+
Aᵏᵏ⁺¹ =_merge_tensors(p[k], p[k+1]),
70+
weights = ones(length(X)))
7071

7172
Zprime, Z = grad_normalization_two_site_canonical(p, k; Aᵏᵏ⁺¹)
7273
ll = -log(Z)
@@ -78,8 +79,8 @@ function grad_loglikelihood_two_site(p::MPS, k::Integer, X;
7879
gr, val = grad_evaluate_two_site(p.ψ, k, x;
7980
Ax_left = prodA_left[n][k-1], Ax_right = prodA_right[n][k+2], Aᵏᵏ⁺¹
8081
)
81-
gA[:,:,x[k]...,x[k+1]...] .+= 2/T * gr / val
82-
ll += 1/T * log(abs2(val))
82+
gA[:,:,x[k]...,x[k+1]...] .+= 2/T * gr / val * weights[n]
83+
ll += 1/T * log(abs2(val)) * weights[n]
8384
end
8485
return gA, ll
8586
end

test/mps.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,9 @@ end
277277
X = [sample(p)[1] for _ in 1:10^2]
278278
q = rand_mps(ComplexF64, 2, length(p), 2,2)
279279
ll = loglikelihood(q, X)
280+
weights = ones(length(X))
280281
two_site_dmrg!(q, X, 1;
281-
η=1e-4, ndesc=10, svd_trunc=TruncBond(5))
282+
η=1e-4, ndesc=10, svd_trunc=TruncBond(5), weights)
282283
@test loglikelihood(q, X) > ll
283284
end
284285
end

0 commit comments

Comments
 (0)