Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,10 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

=============================================================================

Some parts of src/decl.jl is from Symbolics.jl and SymPy.jl licensed under the
same license with the copyrights

Copyright (c) <2013> <j verzani>
Copyright (c) 2021: Shashi Gowda, Yingbo Ma, Chris Rackauckas, Julia Computing.
1 change: 1 addition & 0 deletions src/SymEngine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const libversion = get_libversion()
include("exceptions.jl")
include("types.jl")
include("ctypes.jl")
include("decl.jl")
include("display.jl")
include("mathops.jl")
include("mathfuns.jl")
Expand Down
150 changes: 150 additions & 0 deletions src/decl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# !!! Note:
# Many thanks to `@matthieubulte` for this contribution to `SymPy`.

# The map_subscripts function is stolen from Symbolics.jl
const IndexMap = Dict{Char,Char}(
'-' => '₋',
'0' => '₀',
'1' => '₁',
'2' => '₂',
'3' => '₃',
'4' => '₄',
'5' => '₅',
'6' => '₆',
'7' => '₇',
'8' => '₈',
'9' => '₉')

function map_subscripts(indices)
str = string(indices)
join(IndexMap[c] for c in str)
end

# Define a type hierarchy to describe a variable declaration. This is mainly for convenient pattern matching later.
abstract type VarDecl end

struct SymDecl <: VarDecl
sym :: Symbol
end

struct NamedDecl <: VarDecl
name :: String
rest :: VarDecl
end

struct FunctionDecl <: VarDecl
rest :: VarDecl
end

struct TensorDecl <: VarDecl
ranges :: Vector{AbstractRange}
rest :: VarDecl
end

struct AssumptionsDecl <: VarDecl
assumptions :: Vector{Symbol}
rest :: VarDecl
end

# Transform a Decl struct in an Expression that calls SymPy to declare the corresponding symbol
function gendecl(x::VarDecl)
asstokw(a) = Expr(:kw, esc(a), true)
val = :($(ctor(x))($(name(x, missing)), $(map(asstokw, assumptions(x))...)))
:($(esc(sym(x))) = $(genreshape(val, x)))
end

# Transform an expression in a Decl struct
function parsedecl(expr)
# @vars x
if isa(expr, Symbol)
return SymDecl(expr)

# @vars x::assumptions, where assumption = assumptionkw | (assumptionkw...)
#= no assumptions in SymEngine
elseif isa(expr, Expr) && expr.head == :(::)
symexpr, assumptions = expr.args
assumptions = isa(assumptions, Symbol) ? [assumptions] : assumptions.args
return AssumptionsDecl(assumptions, parsedecl(symexpr))
=#

# @vars x=>"name"
elseif isa(expr, Expr) && expr.head == :call && expr.args[1] == :(=>)
length(expr.args) == 3 || parseerror()
isa(expr.args[3], String) || parseerror()

expr, strname = expr.args[2:end]
return NamedDecl(strname, parsedecl(expr))

# @vars x()
elseif isa(expr, Expr) && expr.head == :call && expr.args[1] != :(=>)
length(expr.args) == 1 || parseerror()
return FunctionDecl(parsedecl(expr.args[1]))

# @vars x[1:5, 3:9]
elseif isa(expr, Expr) && expr.head == :ref
length(expr.args) > 1 || parseerror()
ranges = map(parserange, expr.args[2:end])
return TensorDecl(ranges, parsedecl(expr.args[1]))
else
parseerror()
end
end

function parserange(expr)
range = eval(expr)
isa(range, AbstractRange) || parseerror()
range
end

sym(x::SymDecl) = x.sym
sym(x::NamedDecl) = sym(x.rest)
sym(x::FunctionDecl) = sym(x.rest)
sym(x::TensorDecl) = sym(x.rest)
sym(x::AssumptionsDecl) = sym(x.rest)

ctor(::SymDecl) = :symbols
ctor(x::NamedDecl) = ctor(x.rest)
ctor(::FunctionDecl) = :SymFunction
ctor(x::TensorDecl) = ctor(x.rest)
ctor(x::AssumptionsDecl) = ctor(x.rest)

assumptions(::SymDecl) = []
assumptions(x::NamedDecl) = assumptions(x.rest)
assumptions(x::FunctionDecl) = assumptions(x.rest)
assumptions(x::TensorDecl) = assumptions(x.rest)
assumptions(x::AssumptionsDecl) = x.assumptions

# Reshape is not used by most nodes, but TensorNodes require the output to be given
# the shape matching the specification. For instance if @vars x[1:3, 2:6], we should
# have size(x) = (3, 5)
genreshape(expr, ::SymDecl) = expr
genreshape(expr, x::NamedDecl) = genreshape(expr, x.rest)
genreshape(expr, x::FunctionDecl) = genreshape(expr, x.rest)
genreshape(expr, x::TensorDecl) = let
shape = tuple(length.(x.ranges)...)
:(reshape(collect($(expr)), $(shape)))
end
genreshape(expr, x::AssumptionsDecl) = genreshape(expr, x.rest)

# To find out the name, we need to traverse in both directions to make sure that each node can get
# information from parents and children about possible name.
# This is done because the expr tree will always look like NamedDecl -> ... -> TensorDecl -> ... -> SymDecl
# and the TensorDecl node will need to know if it should create names base on a NamedDecl parent or
# based on the SymDecl leaf.
name(x::SymDecl, parentname) = coalesce(parentname, String(x.sym))
name(x::NamedDecl, parentname) = coalesce(name(x.rest, x.name), x.name)
name(x::FunctionDecl, parentname) = name(x.rest, parentname)
name(x::AssumptionsDecl, parentname) = name(x.rest, parentname)
name(x::TensorDecl, parentname) = let
basename = name(x.rest, parentname)
# we need to double reverse the indices to make sure that we traverse them in the natural order
namestensor = map(Iterators.product(x.ranges...)) do ind
sub = join(map(map_subscripts, ind), "_")
string(basename, sub)
end
join(namestensor[:], ", ")
end

function parseerror()
error("Incorrect @vars syntax. Try `@vars x=>\"x₀\" y() z[0:4]` for instance.")
end
39 changes: 26 additions & 13 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,30 +145,44 @@ end
## Follow, somewhat, the python names: symbols to construct symbols, @vars

"""
Macro to define 1 or more variables in the main workspace.
@vars x y[1:5] z()

Symbolic values are defined with `_symbol`. This is a convenience
Macro to define 1 or more variables or symbolic function

Example
```
@vars x y z
@vars x[1:4]
@vars u(), x
```

"""
macro vars(x...)
q=Expr(:block)
if length(x) == 1 && isa(x[1],Expr)
@assert x[1].head === :tuple "@syms expected a list of symbols"
x = x[1].args
macro vars(xs...)
# If the user separates declaration with commas, the top-level expression is a tuple
if length(xs) == 1 && isa(xs[1], Expr) && xs[1].head == :tuple
_gensyms(xs[1].args...)
elseif length(xs) > 0
_gensyms(xs...)
end
for s in x
@assert isa(s,Symbol) "@syms expected a list of symbols"
push!(q.args, Expr(:(=), esc(s), Expr(:call, :(SymEngine._symbol), Expr(:quote, s))))
end

function _gensyms(xs...)
asstokw(a) = Expr(:kw, esc(a), true)

# Each declaration is parsed and generates a declaration using `symbols`
symdefs = map(xs) do expr
decl = parsedecl(expr)
symname = sym(decl)
symname, gendecl(decl)
end
push!(q.args, Expr(:tuple, map(esc, x)...))
q
syms, defs = collect(zip(symdefs...))

# The macro returns a tuple of Symbols that were declared
Expr(:block, defs..., :(tuple($(map(esc,syms)...))))
end



## We also have a wrapper type that can be used to control dispatch
## pros: wrapping adds overhead, so if possible best to use Basic
## cons: have to write methods meth(x::Basic, ...) = meth(BasicType(x),...)
Expand Down Expand Up @@ -305,4 +319,3 @@ function Serialization.deserialize(s::Serialization.AbstractSerializer, ::Type{B
throw_if_error(res)
return a
end

12 changes: 8 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ end
@test_throws UndefVarError isdefined(w)
@test_throws Exception show(Basic())

# test @vars constructions
@vars a, b[0:4], c(), d=>"D"
@test length(b) == 5
@test isa(c, SymFunction)
@test repr(d) == "D"

a = x^2 + x/2 - x*y*5
b = diff(a, x)
@test b == 2*x + 1//2 - 5*y
Expand Down Expand Up @@ -63,10 +69,8 @@ c = Basic(-5)
@test abs(c) == 5
@test abs(c) != 4

show(a)
println()
show(b)
println()
repr("text/plain", a) == (1/2)*x - 5*x*y + x^2
repr("text/plain", b) == 1/2 + 2*x - 5*y

@test 1 // x == 1 / x
@test x // 2 == (1//2) * x
Expand Down
Loading