Skip to content

Commit a590555

Browse files
committed
enh: jit gradient and value
1 parent 3003009 commit a590555

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

src/zfit_physics/tfpwa/loss.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import zfit
1010
import zfit.z.numpy as znp
11+
from zfit import z
1112
from zfit.core.interfaces import ZfitParameter
1213
from zfit.util.container import convert_to_container
1314

@@ -32,14 +33,12 @@ def nll_from_fcn(fcn: tf_pwa.model.FCN, *, params: ParamType = None):
3233

3334
# something is off here: for the value, we need to pass the parameters as a dict
3435
# but for the gradient/hesse, we need to pass them as a list
35-
# TODO: activate if https://github.com/jiangyi15/tf-pwa/pull/153 is merged
36-
# @z.function(wraps="loss")
36+
@z.function(wraps="loss")
3737
def eval_func(params):
3838
paramdict = make_paramdict(params)
3939
return fcn(paramdict)
4040

41-
# TODO: activate if https://github.com/jiangyi15/tf-pwa/pull/153 is merged
42-
# @z.function(wraps="loss")
41+
@z.function(wraps="loss")
4342
def eval_grad(params):
4443
return fcn.nll_grad(params)[1]
4544

tests/tfpwa/test_basic_example_tfpwa.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def test_example1_tfpwa():
6767
fit_result = config.fit(method="BFGS")
6868

6969
kwargs = dict(gradient='zfit', tol=0.01)
70+
# kwargs = dict(gradient=False, tol=0.01)
71+
# kwargs = dict(tol=0.01)
7072
assert pytest.approx(nll.value(), 0.001) == initial_val
7173
v, g, h = fcn.nll_grad_hessian()
7274
vz, gz, hz = nll.value_gradient_hessian()
@@ -79,7 +81,7 @@ def test_example1_tfpwa():
7981
np.testing.assert_allclose(g, gz1, atol=0.001)
8082

8183
minimizer = zfit.minimize.Minuit(verbosity=7, **kwargs)
82-
# minimizer = zfit.minimize.ScipyBFGS(verbosity=7, **kwargs) # performs bestamba
84+
# minimizer = zfit.minimize.ScipyBFGS(verbosity=7, **kwargs) # performs best
8385
# minimizer = zfit.minimize.NLoptMMAV1(verbosity=7, **kwargs)
8486
# minimizer = zfit.minimize.ScipyLBFGSBV1(verbosity=7, **kwargs)
8587
# minimizer = zfit.minimize.NLoptLBFGSV1(verbosity=7, **kwargs)

0 commit comments

Comments
 (0)