Skip to content

Commit 43f243d

Browse files
committed
sympy_compat: Add a (fake) functions sub-module that resembles
sympy.functions.
1 parent 653a3fc commit 43f243d

File tree

2 files changed

+49
-22
lines changed

2 files changed

+49
-22
lines changed

symengine/sympy_compat.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -64,158 +64,182 @@ def __new__(cls, name):
6464
return symengine.UndefFunction(name)
6565

6666

67-
class log(Function):
67+
from types import ModuleType
68+
69+
functions = ModuleType(__name__ + ".functions")
70+
import sys
71+
sys.modules[functions.__name__] = functions
72+
73+
functions.sqrt = sqrt
74+
functions.exp = exp
75+
76+
77+
class _FunctionRegistrarMeta(BasicMeta):
78+
79+
def __new__(mcls, name, bases, dict):
80+
cls = BasicMeta.__new__(mcls, name, bases, dict)
81+
if not name.startswith("_"):
82+
setattr(functions, name, cls)
83+
return cls
84+
85+
86+
class _RegisteredFunction(with_metaclass(_FunctionRegistrarMeta, Function)):
87+
pass
88+
89+
90+
class log(_RegisteredFunction):
6891
_classes = (symengine.Log,)
6992

7093
def __new__(cls, a, b = E):
7194
return symengine.log(a, b)
7295

7396

74-
class sin(Function):
97+
class sin(_RegisteredFunction):
7598
_classes = (symengine.Sin,)
7699

77100
def __new__(cls, a):
78101
return symengine.sin(a)
79102

80103

81-
class cos(Function):
104+
class cos(_RegisteredFunction):
82105
_classes = (symengine.Cos,)
83106

84107
def __new__(cls, a):
85108
return symengine.cos(a)
86109

87110

88-
class tan(Function):
111+
class tan(_RegisteredFunction):
89112
_classes = (symengine.Tan,)
90113

91114
def __new__(cls, a):
92115
return symengine.tan(a)
93116

94-
class gamma(Function):
117+
class gamma(_RegisteredFunction):
95118
_classes = (symengine.Gamma,)
96119

97120
def __new__(cls, a):
98121
return symengine.gamma(a)
99122

100123

101-
class cot(Function):
124+
class cot(_RegisteredFunction):
102125
_classes = (symengine.Cot,)
103126

104127
def __new__(cls, a):
105128
return symengine.cot(a)
106129

107130

108-
class csc(Function):
131+
class csc(_RegisteredFunction):
109132
_classes = (symengine.Csc,)
110133

111134
def __new__(cls, a):
112135
return symengine.csc(a)
113136

114137

115-
class sec(Function):
138+
class sec(_RegisteredFunction):
116139
_classes = (symengine.Sec,)
117140

118141
def __new__(cls, a):
119142
return symengine.sec(a)
120143

121144

122-
class asin(Function):
145+
class asin(_RegisteredFunction):
123146
_classes = (symengine.ASin,)
124147

125148
def __new__(cls, a):
126149
return symengine.asin(a)
127150

128151

129-
class acos(Function):
152+
class acos(_RegisteredFunction):
130153
_classes = (symengine.ACos,)
131154

132155
def __new__(cls, a):
133156
return symengine.acos(a)
134157

135158

136-
class atan(Function):
159+
class atan(_RegisteredFunction):
137160
_classes = (symengine.ATan,)
138161

139162
def __new__(cls, a):
140163
return symengine.atan(a)
141164

142165

143-
class acot(Function):
166+
class acot(_RegisteredFunction):
144167
_classes = (symengine.ACot,)
145168

146169
def __new__(cls, a):
147170
return symengine.acot(a)
148171

149172

150-
class acsc(Function):
173+
class acsc(_RegisteredFunction):
151174
_classes = (symengine.ACsc,)
152175

153176
def __new__(cls, a):
154177
return symengine.acsc(a)
155178

156179

157-
class asec(Function):
180+
class asec(_RegisteredFunction):
158181
_classes = (symengine.ASec,)
159182

160183
def __new__(cls, a):
161184
return symengine.asec(a)
162185

163186

164-
class sinh(Function):
187+
class sinh(_RegisteredFunction):
165188
_classes = (symengine.Sinh,)
166189

167190
def __new__(cls, a):
168191
return symengine.sinh(a)
169192

170193

171-
class cosh(Function):
194+
class cosh(_RegisteredFunction):
172195
_classes = (symengine.Cosh,)
173196

174197
def __new__(cls, a):
175198
return symengine.cosh(a)
176199

177200

178-
class tanh(Function):
201+
class tanh(_RegisteredFunction):
179202
_classes = (symengine.Tanh,)
180203

181204
def __new__(cls, a):
182205
return symengine.tanh(a)
183206

184207

185-
class coth(Function):
208+
class coth(_RegisteredFunction):
186209
_classes = (symengine.Coth,)
187210

188211
def __new__(cls, a):
189212
return symengine.coth(a)
190213

191214

192-
class asinh(Function):
215+
class asinh(_RegisteredFunction):
193216
_classes = (symengine.ASinh,)
194217

195218
def __new__(cls, a):
196219
return symengine.asinh(a)
197220

198221

199-
class acosh(Function):
222+
class acosh(_RegisteredFunction):
200223
_classes = (symengine.ACosh,)
201224

202225
def __new__(cls, a):
203226
return symengine.acosh(a)
204227

205228

206-
class atanh(Function):
229+
class atanh(_RegisteredFunction):
207230
_classes = (symengine.ATanh,)
208231

209232
def __new__(cls, a):
210233
return symengine.atanh(a)
211234

212235

213-
class acoth(Function):
236+
class acoth(_RegisteredFunction):
214237
_classes = (symengine.ACoth,)
215238

216239
def __new__(cls, a):
217240
return symengine.acoth(a)
218241

242+
219243
'''
220244
for i in ("""Sin Cos Tan Gamma Cot Csc Sec ASin ACos ATan
221245
ACot ACsc ASec Sinh Cosh Tanh Coth ASinh ACosh ATanh

symengine/tests/test_sympy_compat.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,6 @@ def test_log():
5353
def test_zeros():
5454
assert zeros(3, c=2).shape == (3, 2)
5555

56+
def test_has_functions_module():
57+
import symengine.sympy_compat as sp
58+
assert sp.functions.sin(0) == 0

0 commit comments

Comments
 (0)