|
| 1 | +using Test |
| 2 | +using SymEngine |
| 3 | +import SymbolicUtils: simplify, @rule, @acrule, Chain, Fixpoint |
| 4 | + |
| 5 | + |
| 6 | +@testset "SymbolicUtils" begin |
| 7 | + # from SymbolicUtils.jl docs |
| 8 | + # https://symbolicutils.juliasymbolics.org/rewrite/#rule-based_rewriting |
| 9 | + @vars w x y z |
| 10 | + @vars α β |
| 11 | + @vars a b c d |
| 12 | + |
| 13 | + @test simplify(cos(x)^2 + sin(x)^2) == 1 |
| 14 | + |
| 15 | + r1 = @rule sin(2(~x)) => 2sin(~x)*cos(~x) |
| 16 | + @test r1(sin(2z)) == 2*cos(z)*sin(z) |
| 17 | + @test r1(sin(3z)) === nothing |
| 18 | + @test r1(sin(2*(w-z))) == 2cos(w - z)*sin(w - z) |
| 19 | + @test r1(sin(2*(w+z)*(α+β))) === nothing |
| 20 | + |
| 21 | + r2 = @rule sin(~x + ~y) => sin(~x)*cos(~y) + cos(~x)*sin(~y); |
| 22 | + @test r2(sin(α+β)) == sin(α)*cos(β) + cos(α)*sin(β) |
| 23 | + |
| 24 | + xs = @rule(+(~~xs) => ~~xs)(x + y + z) # segment variable |
| 25 | + @test Set(xs) == Set([x,y,z]) |
| 26 | + |
| 27 | + r3 = @rule ~x * +(~~ys) => sum(map(y-> ~x * y, ~~ys)); |
| 28 | + @test r3(2 * (w+w+α+β)) == 4w + 2α + 2β |
| 29 | + |
| 30 | + r4 = @rule ~x + ~~y::(ys->iseven(length(ys))) => "odd terms"; # Predicates for matching |
| 31 | + |
| 32 | + @test r4(a + b + c + d) == nothing |
| 33 | + @test r4(b + c + d) == "odd terms" |
| 34 | + @test r4(b + c + b) == nothing |
| 35 | + @test r4(a + b) == nothing |
| 36 | + |
| 37 | + sqexpand = @rule (~x + ~y)^2 => (~x)^2 + (~y)^2 + 2 * ~x * ~y |
| 38 | + @test sqexpand((cos(x) + sin(x))^2) == cos(x)^2 + sin(x)^2 + 2cos(x)*sin(x) |
| 39 | + |
| 40 | + pyid = @rule sin(~x)^2 + cos(~x)^2 => 1 |
| 41 | + @test_broken pyid(cos(x)^2 + sin(x)^2) === nothing # order should matter, but this works |
| 42 | + |
| 43 | + acpyid = @acrule sin(~x)^2 + cos(~x)^2 => 1 # acrule is commutative |
| 44 | + @test acpyid(cos(x)^2 + sin(x)^2 + 2cos(x)*sin(x)) == 1 + 2cos(x)*sin(x) |
| 45 | + |
| 46 | + csa = Chain([sqexpand, acpyid]) # chain composes rules |
| 47 | + @test csa((cos(x) + sin(x))^2) == 1 + 2cos(x)*sin(x) |
| 48 | + |
| 49 | + cas = Chain([acpyid, sqexpand]) # order matters |
| 50 | + @test cas((cos(x) + sin(x))^2) == cos(x)^2 + sin(x)^2 + 2cos(x)*sin(x) |
| 51 | + |
| 52 | + @test Fixpoint(cas)((cos(x) + sin(x))^2) == 1 + 2cos(x)*sin(x) |
| 53 | + |
| 54 | +end |
0 commit comments