Skip to content
This repository was archived by the owner on Dec 6, 2023. It is now read-only.

Commit a06c0da

Browse files
committed
Store pyearth version information in the Earth object.
1 parent 4fc94a0 commit a06c0da

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

pyearth/earth.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
check_X_y)
88
import numpy as np
99
from scipy import sparse
10-
10+
from ._version import get_versions
11+
__version__ = get_versions()['version']
1112

1213
class Earth(BaseEstimator, RegressorMixin, TransformerMixin):
1314

@@ -254,6 +255,11 @@ class Earth(BaseEstimator, RegressorMixin, TransformerMixin):
254255
array of shape m. If several feature importance types are
255256
specified, then it is dict where each key is a feature importance type
256257
name and its corresponding value is an array of shape m.
258+
259+
`_version`: string
260+
The version of py-earth in which the Earth object was originally
261+
created. This information may be useful when dealing with
262+
serialized Earth objects.
257263
258264
259265
References
@@ -317,6 +323,7 @@ def __init__(self, max_terms=None, max_degree=None, allow_missing=False,
317323
self.enable_pruning = enable_pruning
318324
self.feature_importance_type = feature_importance_type
319325
self.verbose = verbose
326+
self._version = __version__
320327

321328
def __eq__(self, other):
322329
if self.__class__ is not other.__class__:

pyearth/test/test_earth.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pyearth._basis import (Basis, ConstantBasisFunction,
1919
HingeBasisFunction, LinearBasisFunction)
2020
from pyearth import Earth
21+
import pyearth
2122

2223
numpy.random.seed(0)
2324

@@ -306,6 +307,16 @@ def test_pickle_compatibility():
306307
assert_true(model_copy.basis_[0] is model_copy.basis_[1]._get_root())
307308

308309

310+
def test_pickle_version_storage():
311+
earth = Earth(**default_params)
312+
model = earth.fit(X, y)
313+
assert_equal(model._version, pyearth.__version__)
314+
model._version = 'hello'
315+
assert_equal(model._version,'hello')
316+
model_copy = pickle.loads(pickle.dumps(model))
317+
assert_equal(model_copy._version, model._version)
318+
319+
309320
def test_copy_compatibility():
310321
model = Earth(**default_params).fit(X, y)
311322
model_copy = copy.copy(model)

0 commit comments

Comments
 (0)