Skip to content

Commit fe46dd6

Browse files
committed
Add option checkbatchevaluatable to TCI2
1 parent 42670ed commit fe46dd6

File tree

3 files changed

+24
-11
lines changed

3 files changed

+24
-11
lines changed

src/tensorci2.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,8 @@ function convergencecriterion(
587587
) || all(lastranks .>= maxbonddim)
588588
end
589589

590+
591+
590592
"""
591593
function optimize!(
592594
tci::TensorCI2{ValueType},
@@ -628,6 +630,7 @@ Arguments:
628630
- `nsearchglobalpivot::Int` can be set to `>= 0`. Default: `0`.
629631
- `tolmarginglobalsearch` can be set to `>= 1.0`. Seach global pivots where the interpolation error is larger than the tolerance by `tolmarginglobalsearch`. Default: `10.0`.
630632
- `strictlynested::Bool` determines whether to preserve partial nesting in the TCI algorithm. Default: `false`.
633+
- `checkbatchevaluatable::Bool` Check if the function `f` is batch evaluatable. Default: `false`.
631634
632635
Notes:
633636
- Set `tolerance` to be > 0 or `maxbonddim` to some reasonable value. Otherwise, convergence is not reachable.
@@ -652,12 +655,17 @@ function optimize!(
652655
maxnglobalpivot::Int=5,
653656
nsearchglobalpivot::Int=0,
654657
tolmarginglobalsearch::Float64=10.0,
655-
strictlynested::Bool=false
658+
strictlynested::Bool=false,
659+
checkbatchevaluatable::Bool=false
656660
) where {ValueType}
657661
errors = Float64[]
658662
ranks = Int[]
659663
nglobalpivots = Int[]
660664

665+
if checkbatchevaluatable && !(f isa BatchEvaluator)
666+
error("Function `f` is not batch evaluatable")
667+
end
668+
661669
#if maxnglobalpivot > 0 && nsearchglobalpivot > 0
662670
#!strictlynested || error("nglobalpivots > 0 requires strictlynested=false!")
663671
#end
@@ -868,6 +876,7 @@ Arguments:
868876
- `nsearchglobalpivot::Int` can be set to `>= 0`. Default: `0`.
869877
- `tolmarginglobalsearch` can be set to `>= 1.0`. Seach global pivots where the interpolation error is larger than the tolerance by `tolmarginglobalsearch`. Default: `10.0`.
870878
- `strictlynested::Bool=true` determines whether to preserve partial nesting in the TCI algorithm. Default: `true`.
879+
- `checkbatchevaluatable::Bool` Check if the function `f` is batch evaluatable. Default: `false`.
871880
872881
Notes:
873882
- Set `tolerance` to be > 0 or `maxbonddim` to some reasonable value. Otherwise, convergence is not reachable.

test/runtests.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,4 @@ include("test_tensorci1.jl")
1717
include("test_tensorci2.jl")
1818
include("test_tensortrain.jl")
1919
include("test_contraction.jl")
20-
include("test_integration.jl")
21-
22-
#==
23-
if VERSION.major >= 2 || (VERSION.major == 1 && VERSION.minor >= 9)
24-
@testset "ITensor conversion interface" begin
25-
include("test_ttmpsconversion.jl")
26-
include("test_mpsutil.jl")
27-
end
28-
end
29-
==#
20+
include("test_integration.jl")

test/test_tensorci2.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,19 @@ import QuanticsGrids as QD
3737
@test tci.pivoterrors == diags
3838
end
3939

40+
@testset "checkbatchevaluatable" begin
41+
f(x) = 1.0 # Constant function without batch evaluation
42+
L = 10
43+
localdims = fill(2, L)
44+
firstpivots = [fill(1, L)]
45+
@test_throws ErrorException crossinterpolate2(
46+
Float64,
47+
f,
48+
localdims,
49+
firstpivots;
50+
checkbatchevaluatable=true
51+
)
52+
end
4053

4154
@testset "trivial MPS(exp): pivotsearch=$pivotsearch" for pivotsearch in [:full, :rook], strictlynested in [false, true], nsearchglobalpivot in [0, 10]
4255
if nsearchglobalpivot > 0 && strictlynested

0 commit comments

Comments
 (0)