Skip to content

Commit 5e54eed

Browse files
committed
Merge branch 'main' into 37-recompression-interface
2 parents 31893ac + eec5288 commit 5e54eed

14 files changed

+167
-27
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@ version = "0.9.0"
66
[deps]
77
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
910

1011
[compat]
1112
EllipsisNotation = "1"
1213
julia = "1.6"
1314

1415
[extras]
15-
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1616
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
1717
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
18+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1819
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1920
QuanticsGrids = "634c7f73-3e90-4749-a1bd-001b8efc642d"
2021
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

docs/src/documentation.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22

33
Documentation of all types and methods in module [TensorCrossInterpolation](https://gitlab.com/tensors4fields/TensorCrossInterpolation.jl).
44

5-
```@index
6-
```
7-
85
## Matrix approximation
96

107

@@ -47,6 +44,12 @@ Modules = [TensorCrossInterpolation]
4744
Pages = ["tensorci2.jl"]
4845
```
4946

47+
### Integration
48+
```@autodocs
49+
Modules = [TensorCrossInterpolation]
50+
Pages = ["integration.jl"]
51+
```
52+
5053
## Helpers and utility methods
5154
```@autodocs
5255
Modules = [TensorCrossInterpolation]

docs/src/index.md

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,6 @@ This is the documentation for [TensorCrossInterpolation](https://gitlab.com/tens
88

99
With the user manual and usage examples below, users should be able to use this library as a "black box" in most cases. Detailed documentation of (almost) all methods can be found in the [Documentation](@ref) section, and [Implementation](@ref) contains a detailed explanation of this implementation of TCI.
1010

11-
```@contents
12-
Pages = ["index.md", "documentation.md", "implementation.md"]
13-
```
14-
1511
## Interpolating functions
1612

1713
The most convenient way to create a TCI is [`crossinterpolate2`](@ref). For example, consider the lorentzian in 5 dimensions, i.e. $f(\mathbf v) = 1/(1 + \mathbf v^2)$ on a mesh $\mathbf{v} \in \{1, 2, ..., 10\}^5$.
@@ -34,7 +30,7 @@ println("TCI approximation: $(tci([1, 2, 3, 4, 5]))")
3430
For easy integration into tensor network algorithms, the tensor train can be converted to ITensors MPS format. If you're using julia version 1.9 or later, an extension is automatically loaded if both `TensorCrossInterpolation.jl` and `ITensors.jl` are present.
3531
For older versions of julia, use the package using [TCIITensorConversion.jl](https://gitlab.com/tensors4fields/tciitensorconversion.jl).
3632

37-
## Sums
33+
## Sums and Integrals
3834

3935
Tensor trains are a way to efficiently obtain sums over all lattice sites, since this sum can be factorized:
4036
```@example simple
@@ -47,6 +43,22 @@ println("Sum of tensor train: $sumtt")
4743
```
4844
For further information, see [`sum`](@ref).
4945

46+
This factorized sum can be used for efficient evaluation of high-dimensional integrals. This is implemented with Gauss-Kronrod quadrature rules in [`integrate`](@ref). For example, the integral
47+
```math
48+
I = 10^3 \int\limits_{[-1, +1]^{10}} d^{10} \vec{x} \,
49+
\cos\!\left(10 \textstyle\sum_{n=1}^{10} x_n^2 \right)
50+
\exp\!\left[-10^{-3}\left(\textstyle\sum_{n=1}^{10} x_n\right)^4\right]
51+
```
52+
is evaluated by the following code:
53+
```@example simple
54+
function f(x)
55+
return 1e3 * cos(10 * sum(x .^ 2)) * exp(-sum(x)^4 / 1e3)
56+
end
57+
I = TCI.integrate(Float64, f, fill(-1.0, 10), fill(+1.0, 10); GKorder=15, tolerance=1e-8)
58+
println("GK15 integral value: $I")
59+
```
60+
The argument `GKorder` controls the Gauss-Kronrod quadrature rule used for the integration, and `tolerance` controls the tolerance in the TCI approximation, which is distinct from the tolerance in the integral. For complicated functions, it is recommended to integrate using two different GK rules and to compare the results to get a good estimate of the discretization error.
61+
5062
## Properties of the TCI object
5163

5264
After running the code above, `tci` is a [`TensorCI2`](@ref) object that can be interrogated for various properties. The most important ones are the rank (i.e. maximum bond dimension) and the link dimensions:

src/TensorCrossInterpolation.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module TensorCrossInterpolation
22

33
using LinearAlgebra
44
using EllipsisNotation
5+
import QuadGK
56

67
# To add a method for rank(tci)
78
import LinearAlgebra: rank, diag
@@ -12,7 +13,7 @@ import Base: ==, +
1213
import Base: isempty, iterate, getindex, lastindex, broadcastable
1314
import Base: length, size, sum
1415

15-
export crossinterpolate, crossinterpolate2, optfirstpivot
16+
export crossinterpolate1, crossinterpolate2, optfirstpivot
1617
export tensortrain
1718

1819
include("util.jl")
@@ -31,5 +32,6 @@ include("tensorci1.jl")
3132
include("tensorci2.jl")
3233
include("tensortrain.jl")
3334
include("conversion.jl")
35+
include("integration.jl")
3436

3537
end

src/integration.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""
2+
function integrate(
3+
::Type{ValueType},
4+
f,
5+
a::Vector{ValueType},
6+
b::Vector{ValueType};
7+
tolerance=1e-8,
8+
GKorder::Int=15
9+
) where {ValueType}
10+
11+
Integrate the function `f` using TCI and Gauss--Kronrod quadrature rules.
12+
13+
Arguments:
14+
- ValueType: return type of `f`.
15+
- a: Vector of lower bounds in each dimension. Effectively, the lower corner of the hypercube that is being integrated over.
16+
- b: Vector of upper bounds in each dimension.
17+
- tolerance: tolerance of the TCI approximation for the values of f.
18+
- GKorder: Order of the Gauss--Kronrod rule, e.g. 15.
19+
"""
20+
function integrate(
21+
::Type{ValueType},
22+
f,
23+
a::Vector{ValueType},
24+
b::Vector{ValueType};
25+
tolerance=1e-8,
26+
GKorder::Int=15
27+
) where {ValueType}
28+
if iseven(GKorder)
29+
error("Gauss--Kronrod order must be odd, e.g. 15 or 61.")
30+
end
31+
32+
if length(a) != length(b)
33+
error("Integral bounds must have the same dimensionality, but got $(length(a)) lower bounds and $(length(b)) upper bounds.")
34+
end
35+
36+
nodes1d, weights1d, _ = QuadGK.kronrod(GKorder ÷ 2, -1, +1)
37+
nodes = @. (b - a) * (nodes1d' + 1) / 2 + a
38+
weights = @. (b - a) * weights1d' / 2
39+
normalization = GKorder^length(a)
40+
41+
localdims = fill(length(nodes1d), length(a))
42+
43+
function F(indices)
44+
x = [nodes[n, i] for (n, i) in enumerate(indices)]
45+
w = prod(weights[n, i] for (n, i) in enumerate(indices))
46+
return w * f(x) * normalization
47+
end
48+
49+
tci2, ranks, errors = crossinterpolate2(
50+
ValueType,
51+
F,
52+
localdims;
53+
tolerance
54+
)
55+
56+
return sum(tci2) / normalization
57+
end

src/tensorci1.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
mutable struct TensorCI1{ValueType} <: AbstractTensorTrain{ValueType}
33
4-
Type that represents tensor cross interpolations created using the TCI1 algorithm. Users may want to create these using [`crossinterpolate`](@ref) rather than calling a constructor directly.
4+
Type that represents tensor cross interpolations created using the TCI1 algorithm. Users may want to create these using [`crossinterpolate1`](@ref) rather than calling a constructor directly.
55
"""
66
mutable struct TensorCI1{ValueType} <: AbstractTensorTrain{ValueType}
77
Iset::Vector{IndexSet{MultiIndex}}
@@ -469,7 +469,7 @@ end
469469

470470

471471
@doc raw"""
472-
function crossinterpolate(
472+
function crossinterpolate1(
473473
::Type{ValueType},
474474
f,
475475
localdims::Union{Vector{Int},NTuple{N,Int}},
@@ -504,7 +504,7 @@ Notes:
504504
505505
See also: [`optfirstpivot`](@ref), [`CachedFunction`](@ref), [`crossinterpolate2`](@ref)
506506
"""
507-
function crossinterpolate(
507+
function crossinterpolate1(
508508
::Type{ValueType},
509509
f,
510510
localdims::Union{Vector{Int},NTuple{N,Int}},
@@ -551,3 +551,30 @@ function crossinterpolate(
551551
errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0
552552
return tci, ranks, errors ./ errornormalization
553553
end
554+
555+
@doc raw"""
556+
function crossinterpolate(
557+
::Type{ValueType},
558+
f,
559+
localdims::Union{Vector{Int},NTuple{N,Int}},
560+
firstpivot::MultiIndex=ones(Int, length(localdims));
561+
tolerance::Float64=1e-8,
562+
maxiter::Int=200,
563+
sweepstrategy::Symbol=:backandforth,
564+
pivottolerance::Float64=1e-12,
565+
verbosity::Int=0,
566+
additionalpivots::Vector{MultiIndex}=MultiIndex[],
567+
normalizeerror::Bool=true
568+
) where {ValueType, N}
569+
570+
Deprecated, and only included for backward compatibility. Please use [`crossinterpolate1`](@ref) instead.
571+
"""
572+
function crossinterpolate(
573+
::Type{ValueType},
574+
f,
575+
localdims::Union{Vector{Int},NTuple{N,Int}},
576+
firstpivot::MultiIndex=ones(Int, length(localdims));
577+
kwargs...
578+
) where {ValueType, N}
579+
return crossinterpolate1(ValueType, f, localdims, firstpivot; kwargs...)
580+
end

src/tensorci2.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ Notes:
632632
- By default, no caching takes place. Use the [`CachedFunction`](@ref) wrapper if your function is expensive to evaluate.
633633
634634
635-
See also: [`crossinterpolate2`](@ref), [`optfirstpivot`](@ref), [`CachedFunction`](@ref), [`crossinterpolate`](@ref)
635+
See also: [`crossinterpolate2`](@ref), [`optfirstpivot`](@ref), [`CachedFunction`](@ref), [`crossinterpolate1`](@ref)
636636
"""
637637
function optimize!(
638638
tci::TensorCI2{ValueType},
@@ -778,8 +778,8 @@ function sweep2site!(
778778
extraIset = tci.Iset
779779
extraJset = tci.Jset
780780
if length(tci.Iset_history) > 0
781-
extraIset = union(extraIset, tci.Iset_history[end])
782-
extraJset = union(extraJset, tci.Jset_history[end])
781+
extraIset = union.(extraIset, tci.Iset_history[end])
782+
extraJset = union.(extraJset, tci.Jset_history[end])
783783
end
784784
end
785785

@@ -871,7 +871,7 @@ Notes:
871871
- By default, no caching takes place. Use the [`CachedFunction`](@ref) wrapper if your function is expensive to evaluate.
872872
873873
874-
See also: [`optimize!`](@ref), [`optfirstpivot`](@ref), [`CachedFunction`](@ref), [`crossinterpolate`](@ref)
874+
See also: [`optimize!`](@ref), [`optfirstpivot`](@ref), [`CachedFunction`](@ref), [`crossinterpolate1`](@ref)
875875
"""
876876
function crossinterpolate2(
877877
::Type{ValueType},

src/tensortrain.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ function to_tensors(obj::TensorTrainFit{ValueType}, x::Vector{ValueType}) where
146146
]
147147
end
148148

149-
_evaluate(tt, indexset) = only(prod(T[:, i, :] for (T, i) in zip(tt, indexset)))
149+
function _evaluate(tt::Vector{Array{V, 3}}, indexset) where {V}
150+
only(prod(T[:, i, :] for (T, i) in zip(tt, indexset)))
151+
end
150152

151153
function (obj::TensorTrainFit{ValueType})(x::Vector{ValueType}) where {ValueType}
152154
tensors = to_tensors(obj, x)

src/util.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ Optimize the first pivot for a tensor cross interpolation.
6969
7070
Arguments:
7171
- `f` is function to be interpolated.
72-
- `localdims::Union{Vector{Int},NTuple{N,Int}}` determines the local dimensions of the function parameters (see [`crossinterpolate`](@ref)).
72+
- `localdims::Union{Vector{Int},NTuple{N,Int}}` determines the local dimensions of the function parameters (see [`crossinterpolate1`](@ref)).
7373
- `fistpivot::MultiIndex=ones(Int, length(localdims))` is the starting point for the optimization. It is advantageous to choose it close to a global maximum of the function.
7474
- `maxsweep` is the maximum number of optimization sweeps. Default: `1000`.
7575
76-
See also: [`crossinterpolate`](@ref)
76+
See also: [`crossinterpolate1`](@ref)
7777
"""
7878
function optfirstpivot(
7979
f,

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import TensorCrossInterpolation as TCI
22
using Test
33
using LinearAlgebra
4-
import ITensors
54

65
include("test_with_aqua.jl")
76
include("test_with_jet.jl")
@@ -17,6 +16,7 @@ include("test_batcheval.jl")
1716
include("test_tensorci1.jl")
1817
include("test_tensorci2.jl")
1918
include("test_tensortrain.jl")
19+
include("test_integration.jl")
2020

2121
#==
2222
if VERSION.major >= 2 || (VERSION.major == 1 && VERSION.minor >= 9)

0 commit comments

Comments
 (0)