@@ -5,30 +5,64 @@ import Base: diff
55# # what is the rest of the interface. This does:
66# # diff(ex, x, n) f^(n)
77# # diff(ex, x, y, ...) f_{xy...} # also diff(ex, (x,y))
8- # # no support for diff(ex, x,n1, y,n2, ...), but can do diff(ex, (x,y), (n1, n2))
8+ # # Support for diff(ex, x,n1, y,n2, ...),
9+ # # but can also do diff(ex, (x,y), (n1, n2))
910
10- function diff(b1:: SymbolicType , b2:: BasicType{Val{:Symbol}} )
11- a = Basic()
11+
12+ function diff!(a:: Basic , b1:: SymbolicType , b2:: Basic )
13+ is_symbol(b2) || throw(ArgumentError(" Must differentiate with respect to a symbol" ))
1214 ret = ccall((:basic_diff, libsymengine), Int, (Ref{Basic}, Ref{Basic}, Ref{Basic}), a, b1, b2)
1315 return a
1416end
1517
16- diff(b1:: SymbolicType , b2:: BasicType ) =
17- throw(ArgumentError(" Second argument must be of Symbol type" ))
18+ function diff(b1:: SymbolicType , b2:: Basic )
19+ a = Basic()
20+ diff!(a, b1, b2)
21+ a
22+ end
1823
19- function diff(b1:: SymbolicType , b2:: SymbolicType , n:: Integer = 1 )
24+ function diff(b1:: SymbolicType , b2:: SymbolicType , n:: Integer )
2025 n < 0 && throw(DomainError(" n must be non-negative integer" ))
21- n== 0 && return b1
22- n== 1 && return diff(b1, BasicType(b2))
23- n > 1 && return diff(diff(b1, BasicType(b2)), BasicType(b2), n- 1 )
26+ n == 0 && return b1
27+ x = Basic(b2)
28+ out = Basic()
29+ diff!(out, b1, x)
30+ for _ in (n- 1 ): - 1 : 1
31+ diff!(out, out, x)
32+ end
33+ out
34+ end
35+
36+ function diff(b1:: SymbolicType , b2:: SymbolicType , n:: Integer , xs... )
37+ diff(diff(b1,b2,n), xs... )
2438end
2539
2640function diff(b1:: SymbolicType , b2:: SymbolicType , b3:: SymbolicType )
27- isa(BasicType(b3), BasicType{Val{:Integer}}) ? diff(b1, b2, N(b3)) : diff(b1, (b2, b3))
41+ if isinteger(b3)
42+ n = N(b3):: Int
43+ diff(b1, b2, n)
44+ else
45+ ex = diff(b1, b2)
46+ diff(ex, b3)
47+ end
48+ end
49+
50+ function diff(b1:: SymbolicType , b2:: SymbolicType , b3:: SymbolicType , bs... )
51+ diff(diff(b1,b2,b3), bs... )
2852end
2953
30- diff(b1:: SymbolicType , b2:: SymbolicType , b3:: SymbolicType , b4:: SymbolicType , b5... ) =
31- diff(b1, (b2,b3,b4,b5... ))
54+ function diff(b1:: SymbolicType )
55+ xs = free_symbols(b1)
56+ n = length(xs)
57+ n == 0 && return zero(b1)
58+ n > 1 && throw(ArgumentError(" More than one variable; one must be specified" ))
59+ diff(b1, only(xs))
60+ end
61+
62+ # # deprecate
63+ diff(b1:: SymbolicType , b2:: BasicType{Val{:Symbol}} ) = diff(b1, Basic(b2))
64+ diff(b1:: SymbolicType , b2:: BasicType ) =
65+ throw(ArgumentError(" Second argument must be of Symbol type" ))
3266
3367# # mixed partials
3468diff(ex:: SymbolicType , bs:: Tuple ) = reduce((ex, x) -> diff(ex, x), bs, init= ex)
0 commit comments