diff --git a/.gitignore b/.gitignore index b02ba6e..2251642 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1 @@ -Manifest.toml -samples/ \ No newline at end of file +Manifest.toml \ No newline at end of file diff --git a/samples/sample_structural_opt.jl b/samples/sample_structural_opt.jl new file mode 100644 index 0000000..250c572 --- /dev/null +++ b/samples/sample_structural_opt.jl @@ -0,0 +1,48 @@ +using Test +using TreeTCI: crossinterpolate_adaptivetree, crossinterpolate +using NamedGraphs: NamedGraph, add_edge!, vertices +using Random + +function evaluate_error_sampled(f, ttn1, ttn2, localdims::Vector{Int}, nsamples::Int; rng=Random.default_rng()) + total_error1 = 0.0 + total_error2 = 0.0 + + for _ in 1:nsamples + xvec = [rand(rng, 1:d) for d in localdims] + fx = f(xvec) + total_error1 += abs(fx - ttn1(xvec)) + total_error2 += abs(fx - ttn2(xvec)) + end + + mean_error1 = total_error1 / nsamples + mean_error2 = total_error2 / nsamples + + println("Sampled mean |f(x) - original(x)| = ", mean_error1) + println("Sampled mean |f(x) - optimized(x)| = ", mean_error2) + + return mean_error1, mean_error2 +end + + +function main() + # make graph + g = NamedGraph(8) + add_edge!(g, 1, 2) + add_edge!(g, 2, 3) + add_edge!(g, 3, 4) + add_edge!(g, 4, 5) + add_edge!(g, 5, 6) + add_edge!(g, 6, 7) + add_edge!(g, 7, 8) + + localdims = fill(10, length(vertices(g))) + f(v) = 1 / (1 + v' * v) + kwargs = (maxbonddim = 20, maxiter = 10) + tt, ranks, errors = crossinterpolate(Float64, f, localdims, g; kwargs...) + optimized_tt, ranks, errors = crossinterpolate_adaptivetree(Float64, f, localdims, g, 20; kwargs...) + + evaluate_error_sampled(f, tt, optimized_tt, localdims, 1000) + return 0 +end + +main() diff --git a/samples/sample_treetci.jl b/samples/sample_treetci.jl new file mode 100644 index 0000000..b0c5d3a --- /dev/null +++ b/samples/sample_treetci.jl @@ -0,0 +1,27 @@ +using Test +using TreeTCI +import NamedGraphs: NamedGraph, NamedEdge, add_edge!, vertices, edges, has_edge + +function main() + # make graph + g = NamedGraph(10) + add_edge!(g, 1, 3) + add_edge!(g, 2, 3) + add_edge!(g, 3, 5) + + add_edge!(g, 4, 5) + add_edge!(g, 5, 7) + add_edge!(g, 6, 7) + add_edge!(g, 7, 8) + add_edge!(g, 8, 9) + add_edge!(g, 8, 10) + + localdims = fill(2, length(vertices(g))) + f(v) = 1 / (1 + v' * v) + kwargs = (maxbonddim = 20, maxiter = 10) + ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g; kwargs...) + @show ttn([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), f([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) + @show ttn([1, 2, 1, 2, 1, 2, 1, 2, 1, 2]), f([1, 2, 1, 2, 1, 2, 1, 2, 1, 2]) +end + +main() diff --git a/samples/sample_treetci2.jl b/samples/sample_treetci2.jl new file mode 100644 index 0000000..4628ce5 --- /dev/null +++ b/samples/sample_treetci2.jl @@ -0,0 +1,32 @@ +using Test +using TreeTCI +import NamedGraphs: NamedGraph, NamedEdge, add_edge!, vertices, edges, has_edge + +function f_pairwise(v::Vector{Int}) + s = 0.0 + for i in 1:2:length(v)-1 + s += v[i] * v[i+1] + end + return 1 / (1+s) +end + + +function main() + # make graph + g = NamedGraph(8) + add_edge!(g, 1, 2) + add_edge!(g, 2, 3) + add_edge!(g, 3, 4) + add_edge!(g, 4, 5) + add_edge!(g, 5, 6) + add_edge!(g, 6, 7) + add_edge!(g, 7, 8) + + localdims = fill(8, length(vertices(g))) + f(v) = f_pairwise(v) + kwargs = (maxbonddim = 64, maxiter = 10) + ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g; kwargs...) + +end + +main() diff --git a/src/TreeTCI.jl b/src/TreeTCI.jl index 759d43c..6b9df46 100644 --- a/src/TreeTCI.jl +++ b/src/TreeTCI.jl @@ -1,23 +1,9 @@ module TreeTCI -import Graphs -import NamedGraphs: - NamedGraph, - NamedEdge, - is_directed, - outneighbors, - has_edge, - edges, - vertices, - src, - dst, - namedgraph_dijkstra_shortest_paths -import TensorCrossInterpolation as TCI -import SimpleTensorNetworks: - TensorNetwork, IndexedArray, Index, complete_contraction, getindex, contract -import Random: shuffle +include("imports.jl") include("treegraph_utils.jl") include("simpletci.jl") +include("newstructureproposer.jl") include("pivotcandidateproposer.jl") include("sweep2sitepathproposer.jl") include("simpletci_optimize.jl") diff --git a/src/abstracttreetensornetwork.jl b/src/abstracttreetensornetwork.jl new file mode 100644 index 0000000..acfe542 --- /dev/null +++ b/src/abstracttreetensornetwork.jl @@ -0,0 +1,39 @@ +abstract type AbstractTreeTensorNetwork{V} <: Function end + +""" + function evaluate( + ttn::TreeTensorNetwork{V}, + indexset::Union{AbstractVector{Int}, NTuple{N, Int}} + )::V where {V} + +Evaluates the tensor train `tt` at indices given by `indexset`. +""" +function evaluate( + ttn::AbstractTreeTensorNetwork{V}, + indexset::Union{AbstractVector{Int},NTuple{N,Int}}, +)::V where {N,V} + if length(indexset) != length(ttn.sitetensors) + throw( + ArgumentError( + "To evaluate a tt of length $(length(ttn)), you have to provide $(length(ttn)) indices, but there were $(length(indexset)).", + ), + ) + end + sitetensors = IndexedArray[] + for (Tinfo, i) in zip(ttn.sitetensors, indexset) + T, edges = Tinfo + inds = (i, ntuple(_ -> :, ndims(T) - 1)...) + T = T[inds...] + indexs = [ + Index(size(T)[j], "$(src(edges[j]))=>$(dst(edges[j]))") for j = 1:length(edges) + ] + t = IndexedArray(T, indexs) + push!(sitetensors, t) + end + tn = TensorNetwork(sitetensors) + return only(complete_contraction(tn)) +end + +function (ttn::AbstractTreeTensorNetwork{V})(indexset) where {V} + return evaluate(ttn, indexset) +end diff --git a/src/imports.jl b/src/imports.jl new file mode 100644 index 0000000..940df06 --- /dev/null +++ b/src/imports.jl @@ -0,0 +1,22 @@ +using Random +using Graphs: simplecycles_limited_length, has_edge, SimpleGraph, center, steiner_tree +using NamedGraphs: + NamedGraph, + NamedEdge, + is_cyclic, + is_directed, + neighbors, + outneighbors, + has_edge, + edges, + vertices, + namedgraph_dijkstra_shortest_paths +using NamedGraphs.GraphsExtensions: + src, + dst, + is_connected, + degree, + add_vertices!, add_vertex!, rem_vertices!, rem_vertex!, + rem_edge!, add_edge! +import TensorCrossInterpolation as TCI +import SimpleTensorNetworks: TensorNetwork, IndexedArray, Index, complete_contraction, getindex, contract diff --git a/src/newstructureproposer.jl b/src/newstructureproposer.jl new file mode 100644 index 0000000..afee3fe --- /dev/null +++ b/src/newstructureproposer.jl @@ -0,0 +1,36 @@ +""" +Abstract type for structure proposal methods +""" +abstract type AbstractNewStructureProposer end + +struct NewStructureLocalSwap <: AbstractNewStructureProposer end + +struct NewStructureGlobalSwap <: AbstractNewStructureProposer end + +function generate_new_structure( + ::NewStructureLocalSwap, + tci::SimpleTCI{ValueType}, +) where {ValueType} + es = collect(edges(tci.g)) + edge = rand(es) + vs = src(edge) => dst(edge) + return swap_2site(tci.g, vs) +end + +function generate_new_structure( + ::NewStructureGlobalSwap, + tci::SimpleTCI{ValueType}, +) where {ValueType} + vs = collect(vertices(tci.g)) + es = Set(edges(tci.g)) + + all_pairs = Set((v1 => v2) for v1 in vs for v2 in vs if v1 < v2) + + es_symmetric = Set(min(src(e), dst(e)) => max(src(e), dst(e)) for e in es) + + candidate_pairs = setdiff(all_pairs, es_symmetric) + + e = rand(candidate_pairs) + @show e + return swap_2site(tci.g, first(e) => last(e)) +end diff --git a/src/pivotcandidateproposer.jl b/src/pivotcandidateproposer.jl index 5a65ca2..3659493 100644 --- a/src/pivotcandidateproposer.jl +++ b/src/pivotcandidateproposer.jl @@ -44,10 +44,9 @@ function generate_pivot_candidates( Iset = kronecker(Ipivots, Isite_index, tci.localdims[vp]) Jset = kronecker(Jpivots, Jsite_index, tci.localdims[vq]) - extraIJset = if length(tci.IJset_history) > 0 - extraIJset = tci.IJset_history[end] - else - Dict(key => MultiIndex[] for key in keys(tci.IJset)) + extraIJset = tci.IJset + for (key, pivots) in tci.converged_IJset + extraIJset[key] = pivots end Icombined = union(Iset, extraIJset[Ikey]) diff --git a/src/simpletci.jl b/src/simpletci.jl index 963d1dc..4d49f29 100644 --- a/src/simpletci.jl +++ b/src/simpletci.jl @@ -34,12 +34,12 @@ addglobalpivots!(tci, [[1,1,1], [2,1,1]]) """ mutable struct SimpleTCI{ValueType} IJset::Dict{SubTreeVertex,Vector{MultiIndex}} + converged_IJset::Dict{SubTreeVertex,Vector{MultiIndex}} localdims::Vector{Int} g::NamedGraph bonderrors::Dict{NamedEdge,Float64} pivoterrors::Vector{Float64} maxsamplevalue::Float64 - IJset_history::Vector{Dict{SubTreeVertex,Vector{MultiIndex}}} function SimpleTCI{ValueType}(localdims::Vector{Int}, g::NamedGraph) where {ValueType} n = length(localdims) @@ -47,20 +47,23 @@ mutable struct SimpleTCI{ValueType} n == length(vertices(g)) || error( "The number of vertices in the graph must be equal to the length of localdims.", ) - !Graphs.is_cyclic(g) || + !is_cyclic(g) || error("SimpleTCI is not supported for loopy tensor network.") # assign the key for each bond - bonderrors = Dict(e => 0.0 for e in edges(g)) + bonderrors = Dict(e => typemax(Float64) for e in edges(g)) + + !is_cyclic(g) || + error("TreeTensorNetwork is not supported for loopy tensor network.") new{ValueType}( Dict{SubTreeVertex,Vector{MultiIndex}}(), # IJset + Dict{SubTreeVertex,Vector{MultiIndex}}(), # converged_IJset localdims, g, bonderrors, Float64[], 0.0, # maxsamplevalue - Vector{Dict{SubTreeVertex,Vector{MultiIndex}}}(), # IJset_history ) end end @@ -115,6 +118,7 @@ function addglobalpivots!( nothing end + function pushunique!(collection, item) if !(item in collection) push!(collection, item) diff --git a/src/simpletci_optimize.jl b/src/simpletci_optimize.jl index 04a8f3f..3f64b99 100644 --- a/src/simpletci_optimize.jl +++ b/src/simpletci_optimize.jl @@ -48,7 +48,9 @@ function optimize!( loginterval::Int = 10, normalizeerror::Bool = true, ncheckhistory::Int = 3, -) where {ValueType} + ) where {ValueType} + + # Histories of properties for checking convergence. errors = Float64[] ranks = Int[] @@ -62,7 +64,6 @@ function optimize!( ) end - globalpivots = MultiIndex[] for iter = 1:maxiter errornormalization = normalizeerror ? tci.maxsamplevalue : 1.0 abstol = tolerance * errornormalization @@ -74,31 +75,27 @@ function optimize!( sweep2site!( tci, - f, - 2; + f; abstol = abstol, maxbonddim = maxbonddim, verbosity = verbosity, sweepstrategy = sweepstrategy, pivotstrategy = pivotstrategy, ) - if verbosity > 0 && length(globalpivots) > 0 && mod(iter, loginterval) == 0 - abserr = [abs(evaluate(tci, p) - f(p)) for p in globalpivots] - nrejections = length(abserr .> abstol) - if nrejections > 0 - println( - " Rejected $(nrejections) global pivots added in the previous iteration, errors are $(abserr)", - ) - flush(stdout) - end - end - push!(errors, last(pivoterror(tci))) - if verbosity > 1 - println( - " Walltime $(1e-9*(time_ns() - tstart)) sec: start searching global pivots", - ) - flush(stdout) + push!(ranks, rank(tci)) + push!(errors, pivoterror(tci)) + + if convergencecriterion( + ranks, + errors, + maxbonddim, + tolerance, + ncheckhistory + ) + println("Converged at $(iter)th-sweep.") + tci.converged_IJset = deepcopy(tci.IJset) + break end end @@ -111,8 +108,7 @@ end """ function sweep2site!( tci::SimpleTCI{ValueType}, - f, - niter::Int; + f; abstol::Float64 = 1e-8, maxbonddim::Int = typemax(Int), sweepstrategy::AbstractSweep2sitePathProposer = DefaultSweep2sitePathProposer(), @@ -122,24 +118,18 @@ function sweep2site!( edge_path = generate_sweep2site_path(sweepstrategy, tci) - for _ = 1:niter - extraIJset = Dict(key => MultiIndex[] for key in keys(tci.IJset)) - - push!(tci.IJset_history, deepcopy(tci.IJset)) + flushpivoterror!(tci) - flushpivoterror!(tci) - - for edge in edge_path - updatepivots!( - tci, - edge, - f; - abstol = abstol, - maxbonddim = maxbonddim, - pivotstrategy = pivotstrategy, - verbosity = verbosity, - ) - end + for edge in edge_path + updatepivots!( + tci, + edge, + f; + abstol = abstol, + maxbonddim = maxbonddim, + pivotstrategy = pivotstrategy, + verbosity = verbosity, + ) end nothing @@ -224,6 +214,10 @@ function updatepivoterror!(tci::SimpleTCI{T}, errors::AbstractVector{Float64}) w nothing end +function rank(tci::SimpleTCI{ValueType}) where {ValueType} + return maximum(length(IJset) for IJset in values(tci.IJset)) +end + function pivoterror(tci::SimpleTCI{T}) where {T} return maxbonderror(tci) end @@ -231,3 +225,20 @@ end function maxbonderror(tci::SimpleTCI{T}) where {T} return maximum(values(tci.bonderrors)) end + +function convergencecriterion( + ranks::AbstractVector{Int}, + errors::AbstractVector{Float64}, + maxbonddim::Int, + tolerance::Float64, + ncheckhistory::Int, +)::Bool + if length(errors) < ncheckhistory + return false + end + lastranks = last(ranks, ncheckhistory) + return ( + all(last(errors, ncheckhistory) .< tolerance) && + minimum(lastranks) == lastranks[end] + ) || all(lastranks .>= maxbonddim) +end \ No newline at end of file diff --git a/src/simpletci_tensors.jl b/src/simpletci_tensors.jl index a5d9ea7..f5adc98 100644 --- a/src/simpletci_tensors.jl +++ b/src/simpletci_tensors.jl @@ -70,8 +70,8 @@ function sitetensor( return reshape( T, tci.localdims[site], - [length(tci.IJset[key]) for key in Inkeys]..., - [length(tci.IJset[key]) for key in Outkeys]..., + [length(tci.converged_IJset[key]) for key in Inkeys]..., + [length(tci.converged_IJset[key]) for key in Outkeys]..., ) end @@ -85,11 +85,11 @@ function sitetensor( ) where {ValueType} Inkeys, Outkeys = InOutkeys L = length(tci.localdims) - Pi1 = filltensor(ValueType, f, tci.localdims, tci.IJset, Inkeys, Outkeys, Val(1)) + Pi1 = filltensor(ValueType, f, tci.localdims, tci.converged_IJset, Inkeys, Outkeys, Val(1)) Pi1 = reshape( Pi1, - prod(vcat([tci.localdims[site]], [length(tci.IJset[key]) for key in Inkeys])), - prod([length(tci.IJset[key]) for key in Outkeys]), + prod(vcat([tci.localdims[site]], [length(tci.converged_IJset[key]) for key in Inkeys])), + prod([length(tci.converged_IJset[key]) for key in Outkeys]), ) updatemaxsample!(tci, Pi1) @@ -106,17 +106,17 @@ function sitetensor( end P = reshape( - filltensor(ValueType, f, tci.localdims, tci.IJset, [I1key], Outkeys, Val(0)), - length(tci.IJset[I1key]), - prod([length(tci.IJset[key]) for key in Outkeys]), + filltensor(ValueType, f, tci.localdims, tci.converged_IJset, [I1key], Outkeys, Val(0)), + length(tci.converged_IJset[I1key]), + prod([length(tci.converged_IJset[key]) for key in Outkeys]), ) - length(tci.IJset[I1key]) == sum([length(tci.IJset[key]) for key in Outkeys]) || error("Pivot matrix at bond $(site) is not square!") + length(tci.converged_IJset[I1key]) == sum([length(tci.converged_IJset[key]) for key in Outkeys]) || error("Pivot matrix at bond $(site) is not square!") Tmat = transpose(transpose(P) \ transpose(Pi1)) T = reshape( Tmat, tci.localdims[site], - [length(tci.IJset[key]) for key in Inkeys]..., - [length(tci.IJset[key]) for key in Outkeys]..., + [length(tci.converged_IJset[key]) for key in Inkeys]..., + [length(tci.converged_IJset[key]) for key in Outkeys]..., ) return T end diff --git a/src/treegraph_utils.jl b/src/treegraph_utils.jl index e0c4a86..49e1d30 100644 --- a/src/treegraph_utils.jl +++ b/src/treegraph_utils.jl @@ -82,3 +82,45 @@ function distanceedges(g::NamedGraph, edge::NamedEdge)::Dict{NamedEdge,Int} distances[edge] = 0 return distances end + +function swap_2site(g_old::NamedGraph, vs::Pair{Int, Int}) + g = deepcopy(g_old) + p, q = vs + p_neighbors = neighbors(g, p) + q_neighbors = neighbors(g, q) + + rem_vertices!(g, [p, q]) + add_vertex!(g, p) + add_vertex!(g, q) + + for p_i in p_neighbors + p_i == q && continue # avoid self-loop + add_edge!(g, NamedEdge(p_i => q)) + end + + for q_i in q_neighbors + q_i == p && continue # avoid self-loop + add_edge!(g, NamedEdge(q_i => p)) + end + + if has_edge(g_old, NamedEdge(p => q)) + add_edge!(g, NamedEdge(p => q)) + end + + return g +end + +function add_subtree( + g_old::NamedGraph, + v::Int, + edge::NamedEdge, +) + g = deepcopy(g_old) + p, q = separatevertices(g, edge) + p_regions = subtreevertices(g, q => p) + q_regions = subtreevertices(g, p => q) + parent = v in p_regions ? q : p + rem_edge!(g, edge) + add_edge!(g, NamedEdge(parent => v)) + return g +end \ No newline at end of file diff --git a/src/treetensornetwork.jl b/src/treetensornetwork.jl index 1499d90..76982db 100644 --- a/src/treetensornetwork.jl +++ b/src/treetensornetwork.jl @@ -5,7 +5,7 @@ mutable struct TreeTensorNetwork{ValueType} g::NamedGraph, sitetensors::Vector{Pair{Array{ValueType},Vector{NamedEdge}}}, ) where {ValueType} - !Graphs.is_cyclic(g) || + !is_cyclic(g) || error("TreeTensorNetwork is not supported for loopy tensor network.") ttntensors = Vector{IndexedArray}() for (i, (T, edges)) in enumerate(sitetensors) @@ -90,6 +90,36 @@ function crossinterpolate( return TreeTensorNetwork(tci.g, sitetensors), ranks, errors end +function crossinterpolate_adaptivetree( + ::Type{ValueType}, + f, + localdims::Union{Vector{Int},NTuple{N,Int}}, + g::NamedGraph, + nsearch:: Int = 10, + initialpivots::Vector{MultiIndex} = [ones(Int, length(localdims))]; + kwargs..., +) where {ValueType,N} + + tci = SimpleTCI{ValueType}(f, localdims, g, initialpivots) + ranks, errors = optimize!(tci, f; kwargs...) + + for i = 1:nsearch + # Generate a new structure + g_new = generate_new_structure(NewStructureLocalSwap(), tci) + tci_new = SimpleTCI{ValueType}(f, localdims, g_new, initialpivots) + tci_new.converged_IJset = tci.converged_IJset + ranks_new, errors_new = optimize!(tci_new, f; kwargs...) + if sum(errors_new) < sum(errors) + println("Structure is changed") + tci = tci_new + ranks = ranks_new + errors = errors_new + end + end + sitetensors = fillsitetensors(tci, f) + return TreeTensorNetwork(tci.g, sitetensors), ranks, errors +end + function evaluate( ttn::TreeTensorNetwork{ValueType}, indexset::Union{AbstractVector{Int},NTuple{N,Int}},