Skip to content

Commit 64969ec

Browse files
author
Marc Ritter
committed
Merge branch 'marc_ritter/cleanup' into 'main'
some cleanup See merge request tensors4fields/TensorCrossInterpolation.jl!79
2 parents 901430b + 11bd797 commit 64969ec

File tree

4 files changed

+12
-7
lines changed

4 files changed

+12
-7
lines changed

src/integration.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ function integrate(
2222
f,
2323
a::Vector{ValueType},
2424
b::Vector{ValueType};
25-
tolerance=1e-8,
26-
GKorder::Int=15
25+
GKorder::Int=15,
26+
normalizeerror=false,
27+
kwargs...
2728
) where {ValueType}
2829
if iseven(GKorder)
2930
error("Gauss--Kronrod order must be odd, e.g. 15 or 61.")
@@ -50,7 +51,9 @@ function integrate(
5051
ValueType,
5152
F,
5253
localdims;
53-
tolerance
54+
nsearchglobalpivot=10,
55+
normalizeerror,
56+
kwargs...
5457
)
5558

5659
return sum(tci2) / normalization

src/tensorci2.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -684,14 +684,15 @@ function optimize!(
684684
end
685685

686686
sweep2site!(
687-
tci, f, 1;
687+
tci, f, 2;
688688
iter1 = 1,
689689
abstol=abstol,
690690
maxbonddim=maxbonddim,
691691
pivotsearch=pivotsearch,
692692
strictlynested=strictlynested,
693693
verbosity=verbosity,
694-
sweepstrategy=sweepstrategy
694+
sweepstrategy=sweepstrategy,
695+
fillsitetensors=true
695696
)
696697
if verbosity > 0 && length(globalpivots) > 0
697698
nrejections = length([p for p in globalpivots if abs(evaluate(tci, p) - f(p)) > abstol])

test/test_integration.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ import Random
2727
end
2828

2929
@testset "Integrate 10d function" begin
30+
Random.seed!(1234)
31+
3032
function f(x)
3133
return 1000 * cos(10 * sum(x .^ 2)) * exp(-sum(x)^4 / 1000)
3234
end

test/test_tensortrain.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ using Random
33
using Zygote
44
using Optim
55

6-
@testset "tensor train" begin
7-
g(v) = 1 / (sum(v .^ 2) + 1im)
6+
@testset "tensor train" for g in [v -> exp(exp(1im * sum(v))), v -> 1 / (sum(v .^ 2) + 1im)]
87
localdims = (6, 6, 6, 6)
98
tolerance = 1e-8
109
allindices = CartesianIndices(localdims)

0 commit comments

Comments
 (0)