Skip to content

Manipulation of Trees #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
Manifest.toml
samples/
Manifest.toml
48 changes: 48 additions & 0 deletions samples/sample_structural_opt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using Test
using TreeTCI: crossinterpolate_adaptivetree, crossinterpolate
using NamedGraphs: NamedGraph, add_edge!, vertices
using Random

function evaluate_error_sampled(f, ttn1, ttn2, localdims::Vector{Int}, nsamples::Int; rng=Random.default_rng())
total_error1 = 0.0
total_error2 = 0.0

for _ in 1:nsamples
xvec = [rand(rng, 1:d) for d in localdims]
fx = f(xvec)
total_error1 += abs(fx - ttn1(xvec))
total_error2 += abs(fx - ttn2(xvec))
end

mean_error1 = total_error1 / nsamples
mean_error2 = total_error2 / nsamples

println("Sampled mean |f(x) - original(x)| = ", mean_error1)
println("Sampled mean |f(x) - optimized(x)| = ", mean_error2)

return mean_error1, mean_error2
end


function main()
# make graph
g = NamedGraph(8)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
add_edge!(g, 3, 4)
add_edge!(g, 4, 5)
add_edge!(g, 5, 6)
add_edge!(g, 6, 7)
add_edge!(g, 7, 8)

localdims = fill(10, length(vertices(g)))
f(v) = 1 / (1 + v' * v)
kwargs = (maxbonddim = 20, maxiter = 10)
tt, ranks, errors = crossinterpolate(Float64, f, localdims, g; kwargs...)
optimized_tt, ranks, errors = crossinterpolate_adaptivetree(Float64, f, localdims, g, 20; kwargs...)

evaluate_error_sampled(f, tt, optimized_tt, localdims, 1000)
return 0
end

main()
27 changes: 27 additions & 0 deletions samples/sample_treetci.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using Test
using TreeTCI
import NamedGraphs: NamedGraph, NamedEdge, add_edge!, vertices, edges, has_edge

function main()
# make graph
g = NamedGraph(10)
add_edge!(g, 1, 3)
add_edge!(g, 2, 3)
add_edge!(g, 3, 5)

add_edge!(g, 4, 5)
add_edge!(g, 5, 7)
add_edge!(g, 6, 7)
add_edge!(g, 7, 8)
add_edge!(g, 8, 9)
add_edge!(g, 8, 10)

localdims = fill(2, length(vertices(g)))
f(v) = 1 / (1 + v' * v)
kwargs = (maxbonddim = 20, maxiter = 10)
ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g; kwargs...)
@show ttn([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), f([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
@show ttn([1, 2, 1, 2, 1, 2, 1, 2, 1, 2]), f([1, 2, 1, 2, 1, 2, 1, 2, 1, 2])
end

main()
32 changes: 32 additions & 0 deletions samples/sample_treetci2.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using Test
using TreeTCI
import NamedGraphs: NamedGraph, NamedEdge, add_edge!, vertices, edges, has_edge

function f_pairwise(v::Vector{Int})
s = 0.0
for i in 1:2:length(v)-1
s += v[i] * v[i+1]
end
return 1 / (1+s)
end


function main()
# make graph
g = NamedGraph(8)
add_edge!(g, 1, 2)
add_edge!(g, 2, 3)
add_edge!(g, 3, 4)
add_edge!(g, 4, 5)
add_edge!(g, 5, 6)
add_edge!(g, 6, 7)
add_edge!(g, 7, 8)

localdims = fill(8, length(vertices(g)))
f(v) = f_pairwise(v)
kwargs = (maxbonddim = 64, maxiter = 10)
ttn, ranks, errors = TreeTCI.crossinterpolate(Float64, f, localdims, g; kwargs...)

end

main()
18 changes: 2 additions & 16 deletions src/TreeTCI.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,9 @@

module TreeTCI
import Graphs
import NamedGraphs:
NamedGraph,
NamedEdge,
is_directed,
outneighbors,
has_edge,
edges,
vertices,
src,
dst,
namedgraph_dijkstra_shortest_paths
import TensorCrossInterpolation as TCI
import SimpleTensorNetworks:
TensorNetwork, IndexedArray, Index, complete_contraction, getindex, contract
import Random: shuffle
include("imports.jl")
include("treegraph_utils.jl")
include("simpletci.jl")
include("newstructureproposer.jl")
include("pivotcandidateproposer.jl")
include("sweep2sitepathproposer.jl")
include("simpletci_optimize.jl")
Expand Down
39 changes: 39 additions & 0 deletions src/abstracttreetensornetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
abstract type AbstractTreeTensorNetwork{V} <: Function end

"""
function evaluate(
ttn::TreeTensorNetwork{V},
indexset::Union{AbstractVector{Int}, NTuple{N, Int}}
)::V where {V}

Evaluates the tensor train `tt` at indices given by `indexset`.
"""
function evaluate(
ttn::AbstractTreeTensorNetwork{V},
indexset::Union{AbstractVector{Int},NTuple{N,Int}},
)::V where {N,V}
if length(indexset) != length(ttn.sitetensors)
throw(
ArgumentError(
"To evaluate a tt of length $(length(ttn)), you have to provide $(length(ttn)) indices, but there were $(length(indexset)).",
),
)
end
sitetensors = IndexedArray[]
for (Tinfo, i) in zip(ttn.sitetensors, indexset)
T, edges = Tinfo
inds = (i, ntuple(_ -> :, ndims(T) - 1)...)
T = T[inds...]
indexs = [
Index(size(T)[j], "$(src(edges[j]))=>$(dst(edges[j]))") for j = 1:length(edges)
]
t = IndexedArray(T, indexs)
push!(sitetensors, t)
end
tn = TensorNetwork(sitetensors)
return only(complete_contraction(tn))
end

function (ttn::AbstractTreeTensorNetwork{V})(indexset) where {V}
return evaluate(ttn, indexset)
end
22 changes: 22 additions & 0 deletions src/imports.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using Random
using Graphs: simplecycles_limited_length, has_edge, SimpleGraph, center, steiner_tree
using NamedGraphs:
NamedGraph,
NamedEdge,
is_cyclic,
is_directed,
neighbors,
outneighbors,
has_edge,
edges,
vertices,
namedgraph_dijkstra_shortest_paths
using NamedGraphs.GraphsExtensions:
src,
dst,
is_connected,
degree,
add_vertices!, add_vertex!, rem_vertices!, rem_vertex!,
rem_edge!, add_edge!
import TensorCrossInterpolation as TCI
import SimpleTensorNetworks: TensorNetwork, IndexedArray, Index, complete_contraction, getindex, contract
36 changes: 36 additions & 0 deletions src/newstructureproposer.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
Abstract type for structure proposal methods
"""
abstract type AbstractNewStructureProposer end

struct NewStructureLocalSwap <: AbstractNewStructureProposer end

struct NewStructureGlobalSwap <: AbstractNewStructureProposer end

function generate_new_structure(
::NewStructureLocalSwap,
tci::SimpleTCI{ValueType},
) where {ValueType}
es = collect(edges(tci.g))
edge = rand(es)
vs = src(edge) => dst(edge)
return swap_2site(tci.g, vs)
end

function generate_new_structure(
::NewStructureGlobalSwap,
tci::SimpleTCI{ValueType},
) where {ValueType}
vs = collect(vertices(tci.g))
es = Set(edges(tci.g))

all_pairs = Set((v1 => v2) for v1 in vs for v2 in vs if v1 < v2)

es_symmetric = Set(min(src(e), dst(e)) => max(src(e), dst(e)) for e in es)

candidate_pairs = setdiff(all_pairs, es_symmetric)

e = rand(candidate_pairs)
@show e
return swap_2site(tci.g, first(e) => last(e))
end
7 changes: 3 additions & 4 deletions src/pivotcandidateproposer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,9 @@ function generate_pivot_candidates(
Iset = kronecker(Ipivots, Isite_index, tci.localdims[vp])
Jset = kronecker(Jpivots, Jsite_index, tci.localdims[vq])

extraIJset = if length(tci.IJset_history) > 0
extraIJset = tci.IJset_history[end]
else
Dict(key => MultiIndex[] for key in keys(tci.IJset))
extraIJset = tci.IJset
for (key, pivots) in tci.converged_IJset
extraIJset[key] = pivots
end

Icombined = union(Iset, extraIJset[Ikey])
Expand Down
12 changes: 8 additions & 4 deletions src/simpletci.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,33 +34,36 @@ addglobalpivots!(tci, [[1,1,1], [2,1,1]])
"""
mutable struct SimpleTCI{ValueType}
IJset::Dict{SubTreeVertex,Vector{MultiIndex}}
converged_IJset::Dict{SubTreeVertex,Vector{MultiIndex}}
localdims::Vector{Int}
g::NamedGraph
bonderrors::Dict{NamedEdge,Float64}
pivoterrors::Vector{Float64}
maxsamplevalue::Float64
IJset_history::Vector{Dict{SubTreeVertex,Vector{MultiIndex}}}

function SimpleTCI{ValueType}(localdims::Vector{Int}, g::NamedGraph) where {ValueType}
n = length(localdims)
n > 1 || error("localdims should have at least 2 elements!")
n == length(vertices(g)) || error(
"The number of vertices in the graph must be equal to the length of localdims.",
)
!Graphs.is_cyclic(g) ||
!is_cyclic(g) ||
error("SimpleTCI is not supported for loopy tensor network.")

# assign the key for each bond
bonderrors = Dict(e => 0.0 for e in edges(g))
bonderrors = Dict(e => typemax(Float64) for e in edges(g))

!is_cyclic(g) ||
error("TreeTensorNetwork is not supported for loopy tensor network.")

new{ValueType}(
Dict{SubTreeVertex,Vector{MultiIndex}}(), # IJset
Dict{SubTreeVertex,Vector{MultiIndex}}(), # converged_IJset
localdims,
g,
bonderrors,
Float64[],
0.0, # maxsamplevalue
Vector{Dict{SubTreeVertex,Vector{MultiIndex}}}(), # IJset_history
)
end
end
Expand Down Expand Up @@ -115,6 +118,7 @@ function addglobalpivots!(
nothing
end


function pushunique!(collection, item)
if !(item in collection)
push!(collection, item)
Expand Down
Loading
Loading