Skip to content

Commit ebc391f

Browse files
BUG: fix df deep copy (#709)
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 384e453 commit ebc391f

File tree

7 files changed

+187
-8
lines changed

7 files changed

+187
-8
lines changed

python/xorbits/_mars/core/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __copy__(self):
9494
return self.copy()
9595

9696
def copy(self):
97-
return self.copy_to(type(self)(_key=self.key))
97+
return self.copy_to(type(self)())
9898

9999
def copy_to(self, target: "Base"):
100100
target_fields = target._FIELDS

python/xorbits/_mars/core/entity/tileables.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,14 @@ def __copy__(self):
364364
def _view(self):
365365
return super().copy()
366366

367-
def copy(self: TileableType) -> TileableType:
367+
def copy(self: TileableType, **kw) -> TileableType:
368+
from ...dataframe import Index
369+
from ...deploy.oscar.session import SyncSession
370+
371+
new_name = None
372+
if isinstance(self, Index):
373+
new_name = kw.pop("name", None)
374+
368375
new_op = self.op.copy()
369376
if new_op.create_view:
370377
# if the operand is a view, make it a copy
@@ -378,6 +385,24 @@ def copy(self: TileableType) -> TileableType:
378385
new_outs = new_op.new_tileables(
379386
self.op.inputs, kws=params, output_limit=len(params)
380387
)
388+
389+
sess = self._executed_sessions[-1] if self._executed_sessions else None
390+
to_incref_keys = []
391+
for _out in new_outs:
392+
if sess:
393+
_out._attach_session(sess)
394+
to_incref_keys.append(_out.key)
395+
if self.data in sess._tileable_to_fetch:
396+
sess._tileable_to_fetch[_out.data] = sess._tileable_to_fetch[
397+
self.data
398+
]
399+
if new_name:
400+
_out.name = new_name
401+
402+
if to_incref_keys:
403+
assert sess is not None
404+
SyncSession.from_isolated_session(sess).incref(*to_incref_keys)
405+
381406
pos = -1
382407
for i, out in enumerate(self.op.outputs):
383408
# create a ref to copied one

python/xorbits/_mars/dataframe/base/tests/test_base_execution.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3182,3 +3182,53 @@ def test_nunique(setup, method, chunked, axis):
31823182
raw_df.nunique(axis=axis),
31833183
mdf.nunique(axis=axis, method=method).execute().fetch(),
31843184
)
3185+
3186+
3187+
@pytest.mark.parametrize("chunk_size", [None, 10])
3188+
def test_copy_deep(setup, chunk_size):
3189+
ns = np.random.RandomState(0)
3190+
df = pd.DataFrame(ns.rand(100, 10), columns=["a" + str(i) for i in range(10)])
3191+
mdf = from_pandas_df(df, chunk_size=chunk_size)
3192+
3193+
# test case that there is no other result between copy and origin data
3194+
res = mdf.copy()
3195+
res["a0"] = res["a0"] + 1
3196+
dfc = df.copy(deep=True)
3197+
dfc["a0"] = dfc["a0"] + 1
3198+
pd.testing.assert_frame_equal(res.execute().fetch(), dfc)
3199+
pd.testing.assert_frame_equal(mdf.execute().fetch(), df)
3200+
3201+
s = pd.Series(ns.randint(0, 100, size=(100,)))
3202+
ms = from_pandas_series(s, chunk_size=chunk_size)
3203+
3204+
res = ms.copy()
3205+
res.iloc[0] = 111.0
3206+
sc = s.copy(deep=True)
3207+
sc.iloc[0] = 111.0
3208+
pd.testing.assert_series_equal(res.execute().fetch(), sc)
3209+
pd.testing.assert_series_equal(ms.execute().fetch(), s)
3210+
3211+
index = pd.Index([i for i in range(100)], name="test")
3212+
m_index = from_pandas_index(index, chunk_size=chunk_size)
3213+
3214+
res = m_index.copy()
3215+
assert res is not m_index
3216+
pd.testing.assert_index_equal(res.execute().fetch(), index.copy())
3217+
pd.testing.assert_index_equal(m_index.execute().fetch(), index)
3218+
3219+
res = m_index.copy(name="abc")
3220+
pd.testing.assert_index_equal(res.execute().fetch(), index.copy(name="abc"))
3221+
pd.testing.assert_index_equal(m_index.execute().fetch(), index)
3222+
3223+
# test case that there is other ops between copy and origin data
3224+
xdf = (mdf + 1) * 2 / 7
3225+
expected = (df + 1) * 2 / 7
3226+
pd.testing.assert_frame_equal(xdf.execute().fetch(), expected)
3227+
3228+
xdf_c = xdf.copy()
3229+
expected_c = expected.copy(deep=True)
3230+
pd.testing.assert_frame_equal(xdf_c.execute().fetch(), expected)
3231+
xdf_c["a1"] = xdf_c["a1"] + 0.8
3232+
expected_c["a1"] = expected_c["a1"] + 0.8
3233+
pd.testing.assert_frame_equal(xdf_c.execute().fetch(), expected_c)
3234+
pd.testing.assert_frame_equal(xdf.execute().fetch(), expected)

python/xorbits/_mars/dataframe/core.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,37 @@ def to_series(self, index=None, name=None):
11671167

11681168
return series_from_index(self, index=index, name=name)
11691169

1170+
def copy(self, name=None, deep=False):
1171+
"""
1172+
Make a copy of this object.
1173+
1174+
Name is set on the new object.
1175+
1176+
Parameters
1177+
----------
1178+
name : Label, optional
1179+
Set name for new object.
1180+
deep : bool, default False
1181+
1182+
Returns
1183+
-------
1184+
Index
1185+
Index refer to new object which is a copy of this object.
1186+
1187+
Notes
1188+
-----
1189+
In most cases, there should be no functional difference from using
1190+
``deep``, but if ``deep`` is passed it will attempt to deepcopy.
1191+
1192+
Examples
1193+
--------
1194+
>>> idx = pd.Index(['a', 'b', 'c'])
1195+
>>> new_idx = idx.copy()
1196+
>>> idx is new_idx
1197+
False
1198+
"""
1199+
return super().copy(name=name)
1200+
11701201

11711202
class RangeIndex(Index):
11721203
__slots__ = ()
@@ -1591,10 +1622,9 @@ def copy(self, deep=True): # pylint: disable=arguments-differ
15911622
copy : Series or DataFrame
15921623
Object type matches caller.
15931624
"""
1594-
if deep:
1595-
return super().copy()
1596-
else:
1597-
return super()._view()
1625+
if deep is False:
1626+
raise NotImplementedError("Not support `deep=False` for now")
1627+
return super().copy()
15981628

15991629
def __len__(self):
16001630
return len(self._data)
@@ -2618,6 +2648,11 @@ def apply_if_callable(maybe_callable, obj, **kwargs):
26182648
data[k] = apply_if_callable(v, data)
26192649
return data
26202650

2651+
def copy(self, deep=True):
2652+
if deep is False:
2653+
raise NotImplementedError("Not support `deep=False` for now")
2654+
return super().copy()
2655+
26212656

26222657
class DataFrameGroupByChunkData(BaseDataFrameChunkData):
26232658
type_name = "DataFrameGroupBy"

python/xorbits/_mars/deploy/oscar/session.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,17 @@ def decref(self, *tileables_keys):
502502
Tileables' keys
503503
"""
504504

505+
@abstractmethod
506+
def incref(self, *tileables_keys):
507+
"""
508+
Incref tileables.
509+
510+
Parameters
511+
----------
512+
tileables_keys : list
513+
Tileables' keys
514+
"""
515+
505516
@abstractmethod
506517
def _get_ref_counts(self) -> Dict[str, int]:
507518
"""
@@ -960,10 +971,19 @@ async def execute(self, *tileables, **kwargs) -> ExecutionInfo:
960971
def _get_to_fetch_tileable(
961972
self, tileable: TileableType
962973
) -> Tuple[TileableType, List[Union[slice, Integral]]]:
963-
from ...dataframe.indexing.iloc import DataFrameIlocGetItem, SeriesIlocGetItem
974+
from ...dataframe.indexing.iloc import (
975+
DataFrameIlocGetItem,
976+
IndexIlocGetItem,
977+
SeriesIlocGetItem,
978+
)
964979
from ...tensor.indexing import TensorIndex
965980

966-
slice_op_types = TensorIndex, DataFrameIlocGetItem, SeriesIlocGetItem
981+
slice_op_types = (
982+
TensorIndex,
983+
DataFrameIlocGetItem,
984+
SeriesIlocGetItem,
985+
IndexIlocGetItem,
986+
)
967987

968988
if hasattr(tileable, "data"):
969989
tileable = tileable.data
@@ -1200,6 +1220,10 @@ async def decref(self, *tileable_keys):
12001220
logger.debug("Decref tileables on client: %s", tileable_keys)
12011221
return await self._lifecycle_api.decref_tileables(list(tileable_keys))
12021222

1223+
async def incref(self, *tileable_keys):
1224+
logger.debug("Incref tileables on client: %s", tileable_keys)
1225+
return await self._lifecycle_api.incref_tileables(list(tileable_keys))
1226+
12031227
async def _get_ref_counts(self) -> Dict[str, int]:
12041228
return await self._lifecycle_api.get_all_chunk_ref_counts()
12051229

@@ -1623,6 +1647,11 @@ def fetch_infos(self, *tileables, fields, **kwargs) -> list:
16231647
def decref(self, *tileables_keys):
16241648
pass # pragma: no cover
16251649

1650+
@implements(AbstractSyncSession.incref)
1651+
@_delegate_to_isolated_session
1652+
def incref(self, *tileables_keys):
1653+
pass # pragma: no cover
1654+
16261655
@implements(AbstractSyncSession._get_ref_counts)
16271656
@_delegate_to_isolated_session
16281657
def _get_ref_counts(self) -> Dict[str, int]:

python/xorbits/core/adapter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,12 @@ def collect_cls_members(
495495
) -> Dict[str, Any]:
496496
cls_members: Dict[str, Any] = {}
497497
for name, cls_member in inspect.getmembers(cls):
498+
# Tileable and TileableData object may have functions that have the same names.
499+
# For example, Index and IndexData both have `copy` function, but they have completely different semantics.
500+
# Therefore, when the Index's `copy` method has been collected,
501+
# the method of the same name on IndexData cannot be collected again.
502+
if cls.__name__.endswith("Data") and name in DATA_MEMBERS[data_type]: # type: ignore
503+
continue
498504
if inspect.isfunction(cls_member) and not name.startswith("_"):
499505
cls_members[name] = wrap_mars_callable(
500506
cls_member,

python/xorbits/pandas/pandas_adapters/tests/test_pandas_adapters.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from .... import pandas as xpd
2424
from ....core.data import DataRef
25+
from ....core.execution import need_to_execute
2526

2627

2728
def test_pandas_dataframe_methods(setup):
@@ -499,3 +500,36 @@ def test_read_pickle(setup):
499500
assert (x == y).all()
500501
finally:
501502
shutil.rmtree(tempdir)
503+
504+
505+
def test_copy(setup):
506+
index = xpd.Index([i for i in range(100)], name="test")
507+
index_iloc = index[:20]
508+
assert need_to_execute(index_iloc) is True
509+
repr(index_iloc)
510+
511+
index_copy = index_iloc.copy()
512+
assert need_to_execute(index_copy) is False
513+
pd.testing.assert_index_equal(index_copy.to_pandas(), index_iloc.to_pandas())
514+
515+
index_copy = index_iloc.copy(name="abc")
516+
assert need_to_execute(index_copy) is True
517+
pd.testing.assert_index_equal(
518+
index_copy.to_pandas(), index_iloc.to_pandas().copy(name="abc")
519+
)
520+
521+
series = xpd.Series([1, 2, 3, 4, np.nan, 6])
522+
series = series + 1
523+
assert need_to_execute(series) is True
524+
repr(series)
525+
526+
sc = series.copy()
527+
assert need_to_execute(sc) is False
528+
expected = series.to_pandas()
529+
pd.testing.assert_series_equal(sc.to_pandas(), expected)
530+
531+
sc[0] = np.nan
532+
assert need_to_execute(sc) is True
533+
ec = expected.copy()
534+
ec[0] = np.nan
535+
pd.testing.assert_series_equal(sc.to_pandas(), ec)

0 commit comments

Comments
 (0)