From ee2d546d85f7c2c1dc9f55a790f817a8aeb08470 Mon Sep 17 00:00:00 2001 From: jverzani Date: Wed, 26 Feb 2025 10:07:20 -0500 Subject: [PATCH] make at vars more expressive --- LICENSE | 7 +++ src/SymEngine.jl | 1 + src/decl.jl | 150 +++++++++++++++++++++++++++++++++++++++++++++++ src/types.jl | 39 ++++++++---- test/runtests.jl | 12 ++-- 5 files changed, 192 insertions(+), 17 deletions(-) create mode 100644 src/decl.jl diff --git a/LICENSE b/LICENSE index 41305ea..64d3d3f 100644 --- a/LICENSE +++ b/LICENSE @@ -18,3 +18,10 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +============================================================================= + +Some parts of src/decl.jl is from Symbolics.jl and SymPy.jl licensed under the +same license with the copyrights + +Copyright (c) <2013> +Copyright (c) 2021: Shashi Gowda, Yingbo Ma, Chris Rackauckas, Julia Computing. \ No newline at end of file diff --git a/src/SymEngine.jl b/src/SymEngine.jl index e5607a9..c5d6693 100644 --- a/src/SymEngine.jl +++ b/src/SymEngine.jl @@ -23,6 +23,7 @@ const libversion = get_libversion() include("exceptions.jl") include("types.jl") include("ctypes.jl") +include("decl.jl") include("display.jl") include("mathops.jl") include("mathfuns.jl") diff --git a/src/decl.jl b/src/decl.jl new file mode 100644 index 0000000..6622b08 --- /dev/null +++ b/src/decl.jl @@ -0,0 +1,150 @@ +# !!! Note: +# Many thanks to `@matthieubulte` for this contribution to `SymPy`. + +# The map_subscripts function is stolen from Symbolics.jl +const IndexMap = Dict{Char,Char}( + '-' => '₋', + '0' => '₀', + '1' => '₁', + '2' => '₂', + '3' => '₃', + '4' => '₄', + '5' => '₅', + '6' => '₆', + '7' => '₇', + '8' => '₈', + '9' => '₉') + +function map_subscripts(indices) + str = string(indices) + join(IndexMap[c] for c in str) +end + +# Define a type hierarchy to describe a variable declaration. This is mainly for convenient pattern matching later. +abstract type VarDecl end + +struct SymDecl <: VarDecl + sym :: Symbol +end + +struct NamedDecl <: VarDecl + name :: String + rest :: VarDecl +end + +struct FunctionDecl <: VarDecl + rest :: VarDecl +end + +struct TensorDecl <: VarDecl + ranges :: Vector{AbstractRange} + rest :: VarDecl +end + +struct AssumptionsDecl <: VarDecl + assumptions :: Vector{Symbol} + rest :: VarDecl +end + +# Transform a Decl struct in an Expression that calls SymPy to declare the corresponding symbol +function gendecl(x::VarDecl) + asstokw(a) = Expr(:kw, esc(a), true) + val = :($(ctor(x))($(name(x, missing)), $(map(asstokw, assumptions(x))...))) + :($(esc(sym(x))) = $(genreshape(val, x))) +end + +# Transform an expression in a Decl struct +function parsedecl(expr) + # @vars x + if isa(expr, Symbol) + return SymDecl(expr) + + # @vars x::assumptions, where assumption = assumptionkw | (assumptionkw...) + #= no assumptions in SymEngine + elseif isa(expr, Expr) && expr.head == :(::) + symexpr, assumptions = expr.args + assumptions = isa(assumptions, Symbol) ? [assumptions] : assumptions.args + return AssumptionsDecl(assumptions, parsedecl(symexpr)) + =# + + # @vars x=>"name" + elseif isa(expr, Expr) && expr.head == :call && expr.args[1] == :(=>) + length(expr.args) == 3 || parseerror() + isa(expr.args[3], String) || parseerror() + + expr, strname = expr.args[2:end] + return NamedDecl(strname, parsedecl(expr)) + + # @vars x() + elseif isa(expr, Expr) && expr.head == :call && expr.args[1] != :(=>) + length(expr.args) == 1 || parseerror() + return FunctionDecl(parsedecl(expr.args[1])) + + # @vars x[1:5, 3:9] + elseif isa(expr, Expr) && expr.head == :ref + length(expr.args) > 1 || parseerror() + ranges = map(parserange, expr.args[2:end]) + return TensorDecl(ranges, parsedecl(expr.args[1])) + else + parseerror() + end +end + +function parserange(expr) + range = eval(expr) + isa(range, AbstractRange) || parseerror() + range +end + +sym(x::SymDecl) = x.sym +sym(x::NamedDecl) = sym(x.rest) +sym(x::FunctionDecl) = sym(x.rest) +sym(x::TensorDecl) = sym(x.rest) +sym(x::AssumptionsDecl) = sym(x.rest) + +ctor(::SymDecl) = :symbols +ctor(x::NamedDecl) = ctor(x.rest) +ctor(::FunctionDecl) = :SymFunction +ctor(x::TensorDecl) = ctor(x.rest) +ctor(x::AssumptionsDecl) = ctor(x.rest) + +assumptions(::SymDecl) = [] +assumptions(x::NamedDecl) = assumptions(x.rest) +assumptions(x::FunctionDecl) = assumptions(x.rest) +assumptions(x::TensorDecl) = assumptions(x.rest) +assumptions(x::AssumptionsDecl) = x.assumptions + +# Reshape is not used by most nodes, but TensorNodes require the output to be given +# the shape matching the specification. For instance if @vars x[1:3, 2:6], we should +# have size(x) = (3, 5) +genreshape(expr, ::SymDecl) = expr +genreshape(expr, x::NamedDecl) = genreshape(expr, x.rest) +genreshape(expr, x::FunctionDecl) = genreshape(expr, x.rest) +genreshape(expr, x::TensorDecl) = let + shape = tuple(length.(x.ranges)...) + :(reshape(collect($(expr)), $(shape))) +end +genreshape(expr, x::AssumptionsDecl) = genreshape(expr, x.rest) + +# To find out the name, we need to traverse in both directions to make sure that each node can get +# information from parents and children about possible name. +# This is done because the expr tree will always look like NamedDecl -> ... -> TensorDecl -> ... -> SymDecl +# and the TensorDecl node will need to know if it should create names base on a NamedDecl parent or +# based on the SymDecl leaf. +name(x::SymDecl, parentname) = coalesce(parentname, String(x.sym)) +name(x::NamedDecl, parentname) = coalesce(name(x.rest, x.name), x.name) +name(x::FunctionDecl, parentname) = name(x.rest, parentname) +name(x::AssumptionsDecl, parentname) = name(x.rest, parentname) +name(x::TensorDecl, parentname) = let + basename = name(x.rest, parentname) + # we need to double reverse the indices to make sure that we traverse them in the natural order + namestensor = map(Iterators.product(x.ranges...)) do ind + sub = join(map(map_subscripts, ind), "_") + string(basename, sub) + end + join(namestensor[:], ", ") +end + +function parseerror() + error("Incorrect @vars syntax. Try `@vars x=>\"x₀\" y() z[0:4]` for instance.") +end diff --git a/src/types.jl b/src/types.jl index d778a69..2b4dced 100644 --- a/src/types.jl +++ b/src/types.jl @@ -145,30 +145,44 @@ end ## Follow, somewhat, the python names: symbols to construct symbols, @vars """ -Macro to define 1 or more variables in the main workspace. + @vars x y[1:5] z() -Symbolic values are defined with `_symbol`. This is a convenience +Macro to define 1 or more variables or symbolic function Example ``` @vars x y z +@vars x[1:4] +@vars u(), x ``` + """ -macro vars(x...) - q=Expr(:block) - if length(x) == 1 && isa(x[1],Expr) - @assert x[1].head === :tuple "@syms expected a list of symbols" - x = x[1].args +macro vars(xs...) + # If the user separates declaration with commas, the top-level expression is a tuple + if length(xs) == 1 && isa(xs[1], Expr) && xs[1].head == :tuple + _gensyms(xs[1].args...) + elseif length(xs) > 0 + _gensyms(xs...) end - for s in x - @assert isa(s,Symbol) "@syms expected a list of symbols" - push!(q.args, Expr(:(=), esc(s), Expr(:call, :(SymEngine._symbol), Expr(:quote, s)))) +end + +function _gensyms(xs...) + asstokw(a) = Expr(:kw, esc(a), true) + + # Each declaration is parsed and generates a declaration using `symbols` + symdefs = map(xs) do expr + decl = parsedecl(expr) + symname = sym(decl) + symname, gendecl(decl) end - push!(q.args, Expr(:tuple, map(esc, x)...)) - q + syms, defs = collect(zip(symdefs...)) + + # The macro returns a tuple of Symbols that were declared + Expr(:block, defs..., :(tuple($(map(esc,syms)...)))) end + ## We also have a wrapper type that can be used to control dispatch ## pros: wrapping adds overhead, so if possible best to use Basic ## cons: have to write methods meth(x::Basic, ...) = meth(BasicType(x),...) @@ -305,4 +319,3 @@ function Serialization.deserialize(s::Serialization.AbstractSerializer, ::Type{B throw_if_error(res) return a end - diff --git a/test/runtests.jl b/test/runtests.jl index 746cd17..f2527ef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,6 +21,12 @@ end @test_throws UndefVarError isdefined(w) @test_throws Exception show(Basic()) +# test @vars constructions +@vars a, b[0:4], c(), d=>"D" +@test length(b) == 5 +@test isa(c, SymFunction) +@test repr(d) == "D" + a = x^2 + x/2 - x*y*5 b = diff(a, x) @test b == 2*x + 1//2 - 5*y @@ -63,10 +69,8 @@ c = Basic(-5) @test abs(c) == 5 @test abs(c) != 4 -show(a) -println() -show(b) -println() +repr("text/plain", a) == (1/2)*x - 5*x*y + x^2 +repr("text/plain", b) == 1/2 + 2*x - 5*y @test 1 // x == 1 / x @test x // 2 == (1//2) * x