Skip to content

Commit c195473

Browse files
authored
Update copy from 3.13.5 (RustPython#5913)
1 parent d58c500 commit c195473

File tree

2 files changed

+105
-9
lines changed

2 files changed

+105
-9
lines changed

Lib/copy.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
55
import copy
66
7-
x = copy.copy(y) # make a shallow copy of y
8-
x = copy.deepcopy(y) # make a deep copy of y
7+
x = copy.copy(y) # make a shallow copy of y
8+
x = copy.deepcopy(y) # make a deep copy of y
9+
x = copy.replace(y, a=1, b=2) # new object with fields replaced, as defined by `__replace__`
910
1011
For module specific errors, copy.Error is raised.
1112
@@ -56,7 +57,7 @@ class Error(Exception):
5657
pass
5758
error = Error # backward compatibility
5859

59-
__all__ = ["Error", "copy", "deepcopy"]
60+
__all__ = ["Error", "copy", "deepcopy", "replace"]
6061

6162
def copy(x):
6263
"""Shallow copy operation on arbitrary Python objects.
@@ -121,13 +122,13 @@ def deepcopy(x, memo=None, _nil=[]):
121122
See the module's __doc__ string for more info.
122123
"""
123124

125+
d = id(x)
124126
if memo is None:
125127
memo = {}
126-
127-
d = id(x)
128-
y = memo.get(d, _nil)
129-
if y is not _nil:
130-
return y
128+
else:
129+
y = memo.get(d, _nil)
130+
if y is not _nil:
131+
return y
131132

132133
cls = type(x)
133134

@@ -290,3 +291,16 @@ def _reconstruct(x, memo, func, args,
290291
return y
291292

292293
del types, weakref
294+
295+
296+
def replace(obj, /, **changes):
297+
"""Return a new object replacing specified fields with new values.
298+
299+
This is especially useful for immutable objects, like named tuples or
300+
frozen dataclasses.
301+
"""
302+
cls = obj.__class__
303+
func = getattr(cls, '__replace__', None)
304+
if func is None:
305+
raise TypeError(f"replace() does not support {cls.__name__} objects")
306+
return func(obj, **changes)

Lib/test/test_copy.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import copyreg
55
import weakref
66
import abc
7-
from operator import le, lt, ge, gt, eq, ne
7+
from operator import le, lt, ge, gt, eq, ne, attrgetter
88

99
import unittest
1010
from test import support
@@ -903,7 +903,89 @@ def m(self):
903903
g.b()
904904

905905

906+
class TestReplace(unittest.TestCase):
907+
908+
def test_unsupported(self):
909+
self.assertRaises(TypeError, copy.replace, 1)
910+
self.assertRaises(TypeError, copy.replace, [])
911+
self.assertRaises(TypeError, copy.replace, {})
912+
def f(): pass
913+
self.assertRaises(TypeError, copy.replace, f)
914+
class A: pass
915+
self.assertRaises(TypeError, copy.replace, A)
916+
self.assertRaises(TypeError, copy.replace, A())
917+
918+
def test_replace_method(self):
919+
class A:
920+
def __new__(cls, x, y=0):
921+
self = object.__new__(cls)
922+
self.x = x
923+
self.y = y
924+
return self
925+
926+
def __init__(self, *args, **kwargs):
927+
self.z = self.x + self.y
928+
929+
def __replace__(self, **changes):
930+
x = changes.get('x', self.x)
931+
y = changes.get('y', self.y)
932+
return type(self)(x, y)
933+
934+
attrs = attrgetter('x', 'y', 'z')
935+
a = A(11, 22)
936+
self.assertEqual(attrs(copy.replace(a)), (11, 22, 33))
937+
self.assertEqual(attrs(copy.replace(a, x=1)), (1, 22, 23))
938+
self.assertEqual(attrs(copy.replace(a, y=2)), (11, 2, 13))
939+
self.assertEqual(attrs(copy.replace(a, x=1, y=2)), (1, 2, 3))
940+
941+
# TODO: RUSTPYTHON
942+
@unittest.expectedFailure
943+
def test_namedtuple(self):
944+
from collections import namedtuple
945+
from typing import NamedTuple
946+
PointFromCall = namedtuple('Point', 'x y', defaults=(0,))
947+
class PointFromInheritance(PointFromCall):
948+
pass
949+
class PointFromClass(NamedTuple):
950+
x: int
951+
y: int = 0
952+
for Point in (PointFromCall, PointFromInheritance, PointFromClass):
953+
with self.subTest(Point=Point):
954+
p = Point(11, 22)
955+
self.assertIsInstance(p, Point)
956+
self.assertEqual(copy.replace(p), (11, 22))
957+
self.assertIsInstance(copy.replace(p), Point)
958+
self.assertEqual(copy.replace(p, x=1), (1, 22))
959+
self.assertEqual(copy.replace(p, y=2), (11, 2))
960+
self.assertEqual(copy.replace(p, x=1, y=2), (1, 2))
961+
with self.assertRaisesRegex(TypeError, 'unexpected field name'):
962+
copy.replace(p, x=1, error=2)
963+
964+
# TODO: RUSTPYTHON
965+
@unittest.expectedFailure
966+
def test_dataclass(self):
967+
from dataclasses import dataclass
968+
@dataclass
969+
class C:
970+
x: int
971+
y: int = 0
972+
973+
attrs = attrgetter('x', 'y')
974+
c = C(11, 22)
975+
self.assertEqual(attrs(copy.replace(c)), (11, 22))
976+
self.assertEqual(attrs(copy.replace(c, x=1)), (1, 22))
977+
self.assertEqual(attrs(copy.replace(c, y=2)), (11, 2))
978+
self.assertEqual(attrs(copy.replace(c, x=1, y=2)), (1, 2))
979+
with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
980+
copy.replace(c, x=1, error=2)
981+
982+
983+
class MiscTestCase(unittest.TestCase):
984+
def test__all__(self):
985+
support.check__all__(self, copy, not_exported={"dispatch_table", "error"})
986+
906987
def global_foo(x, y): return x+y
907988

989+
908990
if __name__ == "__main__":
909991
unittest.main()

0 commit comments

Comments
 (0)