Skip to content

Commit 334a5a7

Browse files
authored
Iterable for PyGenericAlias (RustPython#5876)
* Iterable for PyGenericAlias * GenericAlias works * typevar.rs
1 parent e7c18f1 commit 334a5a7

File tree

8 files changed

+1101
-960
lines changed

8 files changed

+1101
-960
lines changed

Lib/test/test_typing.py

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -680,8 +680,6 @@ def test_typevartuple(self):
680680
class A(Generic[Unpack[Ts]]): ...
681681
Alias = Optional[Unpack[Ts]]
682682

683-
# TODO: RUSTPYTHON
684-
@unittest.expectedFailure
685683
def test_typevartuple_specialization(self):
686684
T = TypeVar("T")
687685
Ts = TypeVarTuple('Ts', default=Unpack[Tuple[str, int]])
@@ -691,8 +689,6 @@ class A(Generic[T, Unpack[Ts]]): ...
691689
self.assertEqual(A[float, range].__args__, (float, range))
692690
self.assertEqual(A[float, *tuple[int, ...]].__args__, (float, *tuple[int, ...]))
693691

694-
# TODO: RUSTPYTHON
695-
@unittest.expectedFailure
696692
def test_typevar_and_typevartuple_specialization(self):
697693
T = TypeVar("T")
698694
U = TypeVar("U", default=float)
@@ -740,8 +736,6 @@ class A(Generic[T, P]): ...
740736
self.assertEqual(A[float].__args__, (float, (str, int)))
741737
self.assertEqual(A[float, [range]].__args__, (float, (range,)))
742738

743-
# TODO: RUSTPYTHON
744-
@unittest.expectedFailure
745739
def test_typevar_and_paramspec_specialization(self):
746740
T = TypeVar("T")
747741
U = TypeVar("U", default=float)
@@ -752,8 +746,6 @@ class A(Generic[T, U, P]): ...
752746
self.assertEqual(A[float, int].__args__, (float, int, (str, int)))
753747
self.assertEqual(A[float, int, [range]].__args__, (float, int, (range,)))
754748

755-
# TODO: RUSTPYTHON
756-
@unittest.expectedFailure
757749
def test_paramspec_and_typevar_specialization(self):
758750
T = TypeVar("T")
759751
P = ParamSpec('P', default=[str, int])
@@ -1049,8 +1041,6 @@ class C(Generic[T1, T2]): pass
10491041
eval(expected_str)
10501042
)
10511043

1052-
# TODO: RUSTPYTHON
1053-
@unittest.expectedFailure
10541044
def test_three_parameters(self):
10551045
T1 = TypeVar('T1')
10561046
T2 = TypeVar('T2')
@@ -2543,8 +2533,6 @@ def __call__(self):
25432533
self.assertIs(a().__class__, C1)
25442534
self.assertEqual(a().__orig_class__, C1[[int], T])
25452535

2546-
# TODO: RUSTPYTHON
2547-
@unittest.expectedFailure
25482536
def test_paramspec(self):
25492537
Callable = self.Callable
25502538
fullname = f"{Callable.__module__}.Callable"
@@ -2579,8 +2567,6 @@ def test_paramspec(self):
25792567
self.assertEqual(repr(C2), f"{fullname}[~P, int]")
25802568
self.assertEqual(repr(C2[int, str]), f"{fullname}[[int, str], int]")
25812569

2582-
# TODO: RUSTPYTHON
2583-
@unittest.expectedFailure
25842570
def test_concatenate(self):
25852571
Callable = self.Callable
25862572
fullname = f"{Callable.__module__}.Callable"
@@ -2608,8 +2594,6 @@ def test_concatenate(self):
26082594
Callable[Concatenate[int, str, P2], int])
26092595
self.assertEqual(C[...], Callable[Concatenate[int, ...], int])
26102596

2611-
# TODO: RUSTPYTHON
2612-
@unittest.expectedFailure
26132597
def test_nested_paramspec(self):
26142598
# Since Callable has some special treatment, we want to be sure
26152599
# that substituion works correctly, see gh-103054
@@ -2652,8 +2636,6 @@ class My(Generic[P, T]):
26522636
self.assertEqual(C4[bool, bytes, float],
26532637
My[[Callable[[int, bool, bytes, str], float], float], float])
26542638

2655-
# TODO: RUSTPYTHON
2656-
@unittest.expectedFailure
26572639
def test_errors(self):
26582640
Callable = self.Callable
26592641
alias = Callable[[int, str], float]
@@ -2682,6 +2664,11 @@ def test_consistency(self):
26822664
class CollectionsCallableTests(BaseCallableTests, BaseTestCase):
26832665
Callable = collections.abc.Callable
26842666

2667+
# TODO: RUSTPYTHON
2668+
@unittest.expectedFailure
2669+
def test_errors(self):
2670+
super().test_errors()
2671+
26852672

26862673
class LiteralTests(BaseTestCase):
26872674
def test_basics(self):
@@ -4631,8 +4618,6 @@ class Base(Generic[T_co]):
46314618
class Sub(Base, Generic[T]):
46324619
...
46334620

4634-
# TODO: RUSTPYTHON
4635-
@unittest.expectedFailure
46364621
def test_parameter_detection(self):
46374622
self.assertEqual(List[T].__parameters__, (T,))
46384623
self.assertEqual(List[List[T]].__parameters__, (T,))
@@ -4650,8 +4635,6 @@ class A:
46504635
# C version of GenericAlias
46514636
self.assertEqual(list[A()].__parameters__, (T,))
46524637

4653-
# TODO: RUSTPYTHON
4654-
@unittest.expectedFailure
46554638
def test_non_generic_subscript(self):
46564639
T = TypeVar('T')
46574640
class G(Generic[T]):
@@ -8858,8 +8841,6 @@ def test_bad_var_substitution(self):
88588841
with self.assertRaises(TypeError):
88598842
collections.abc.Callable[P, T][arg, str]
88608843

8861-
# TODO: RUSTPYTHON
8862-
@unittest.expectedFailure
88638844
def test_type_var_subst_for_other_type_vars(self):
88648845
T = TypeVar('T')
88658846
T2 = TypeVar('T2')
@@ -8981,8 +8962,6 @@ class PandT(Generic[P, T]):
89818962
self.assertEqual(C3.__args__, ((int, *Ts), T))
89828963
self.assertEqual(C3[str, bool, bytes], PandT[[int, str, bool], bytes])
89838964

8984-
# TODO: RUSTPYTHON
8985-
@unittest.expectedFailure
89868965
def test_paramspec_in_nested_generics(self):
89878966
# Although ParamSpec should not be found in __parameters__ of most
89888967
# generics, they probably should be found when nested in
@@ -9001,8 +8980,6 @@ def test_paramspec_in_nested_generics(self):
90018980
self.assertEqual(G2[[int, str], float], list[C])
90028981
self.assertEqual(G3[[int, str], float], list[C] | int)
90038982

9004-
# TODO: RUSTPYTHON
9005-
@unittest.expectedFailure
90068983
def test_paramspec_gets_copied(self):
90078984
# bpo-46581
90088985
P = ParamSpec('P')
@@ -9090,8 +9067,6 @@ def test_invalid_uses(self):
90909067
):
90919068
Concatenate[int]
90929069

9093-
# TODO: RUSTPYTHON
9094-
@unittest.expectedFailure
90959070
def test_var_substitution(self):
90969071
T = TypeVar('T')
90979072
P = ParamSpec('P')

vm/src/builtins/genericalias.rs

Lines changed: 113 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ use crate::{
1212
function::{FuncArgs, PyComparisonValue},
1313
protocol::{PyMappingMethods, PyNumberMethods},
1414
types::{
15-
AsMapping, AsNumber, Callable, Comparable, Constructor, GetAttr, Hashable, PyComparisonOp,
16-
Representable,
15+
AsMapping, AsNumber, Callable, Comparable, Constructor, GetAttr, Hashable, Iterable,
16+
PyComparisonOp, Representable,
1717
},
1818
};
1919
use std::fmt;
@@ -78,6 +78,7 @@ impl Constructor for PyGenericAlias {
7878
Constructor,
7979
GetAttr,
8080
Hashable,
81+
Iterable,
8182
Representable
8283
),
8384
flags(BASETYPE)
@@ -166,17 +167,17 @@ impl PyGenericAlias {
166167
}
167168

168169
#[pymethod]
169-
fn __getitem__(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
170+
fn __getitem__(zelf: PyRef<Self>, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
170171
let new_args = subs_parameters(
171-
|vm| self.repr(vm),
172-
self.args.clone(),
173-
self.parameters.clone(),
172+
zelf.to_owned().into(),
173+
zelf.args.clone(),
174+
zelf.parameters.clone(),
174175
needle,
175176
vm,
176177
)?;
177178

178179
Ok(
179-
PyGenericAlias::new(self.origin.clone(), new_args.to_pyobject(vm), vm)
180+
PyGenericAlias::new(zelf.origin.clone(), new_args.to_pyobject(vm), vm)
180181
.into_pyobject(vm),
181182
)
182183
}
@@ -277,6 +278,18 @@ fn tuple_index(vec: &[PyObjectRef], item: &PyObjectRef) -> Option<usize> {
277278
vec.iter().position(|element| element.is(item))
278279
}
279280

281+
fn is_unpacked_typevartuple(arg: &PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
282+
if arg.class().is(vm.ctx.types.type_type) {
283+
return Ok(false);
284+
}
285+
286+
if let Ok(attr) = arg.get_attr(identifier!(vm, __typing_is_unpacked_typevartuple__), vm) {
287+
attr.try_to_bool(vm)
288+
} else {
289+
Ok(false)
290+
}
291+
}
292+
280293
fn subs_tvars(
281294
obj: PyObjectRef,
282295
params: &PyTupleRef,
@@ -324,22 +337,40 @@ fn subs_tvars(
324337
}
325338

326339
// _Py_subs_parameters
327-
pub fn subs_parameters<F: Fn(&VirtualMachine) -> PyResult<String>>(
328-
repr: F,
340+
pub fn subs_parameters(
341+
alias: PyObjectRef, // The GenericAlias object itself
329342
args: PyTupleRef,
330343
parameters: PyTupleRef,
331344
needle: PyObjectRef,
332345
vm: &VirtualMachine,
333346
) -> PyResult<PyTupleRef> {
334347
let num_params = parameters.len();
335348
if num_params == 0 {
336-
return Err(vm.new_type_error(format!("There are no type variables left in {}", repr(vm)?)));
349+
return Err(vm.new_type_error(format!("{} is not a generic class", alias.repr(vm)?)));
337350
}
338351

339-
let items = needle.try_to_ref::<PyTuple>(vm);
352+
// Handle __typing_prepare_subst__ for each parameter
353+
// Following CPython: each prepare function transforms the args
354+
let mut prepared_args = needle.clone();
355+
356+
// Ensure args is a tuple
357+
if prepared_args.try_to_ref::<PyTuple>(vm).is_err() {
358+
prepared_args = PyTuple::new_ref(vec![prepared_args], &vm.ctx).into();
359+
}
360+
361+
for param in parameters.iter() {
362+
if let Ok(prepare) = param.get_attr(identifier!(vm, __typing_prepare_subst__), vm) {
363+
if !prepare.is(&vm.ctx.none) {
364+
// Call prepare(cls, args) where cls is the GenericAlias
365+
prepared_args = prepare.call((alias.clone(), prepared_args), vm)?;
366+
}
367+
}
368+
}
369+
370+
let items = prepared_args.try_to_ref::<PyTuple>(vm);
340371
let arg_items = match items {
341372
Ok(tuple) => tuple.as_slice(),
342-
Err(_) => std::slice::from_ref(&needle),
373+
Err(_) => std::slice::from_ref(&prepared_args),
343374
};
344375

345376
let num_items = arg_items.len();
@@ -362,40 +393,82 @@ pub fn subs_parameters<F: Fn(&VirtualMachine) -> PyResult<String>>(
362393

363394
let min_required = num_params - params_with_defaults;
364395
if num_items < min_required {
396+
let repr_str = alias.repr(vm)?;
365397
return Err(vm.new_type_error(format!(
366-
"Too few arguments for {}; actual {}, expected at least {}",
367-
repr(vm)?,
368-
num_items,
369-
min_required
398+
"Too few arguments for {repr_str}; actual {num_items}, expected at least {min_required}"
370399
)));
371400
}
372401
} else if num_items > num_params {
402+
let repr_str = alias.repr(vm)?;
373403
return Err(vm.new_type_error(format!(
374-
"Too many arguments for {}; actual {}, expected {}",
375-
repr(vm)?,
376-
num_items,
377-
num_params
404+
"Too many arguments for {repr_str}; actual {num_items}, expected {num_params}"
378405
)));
379406
}
380407

381-
let mut new_args = Vec::new();
408+
let mut new_args = Vec::with_capacity(args.len());
382409

383410
for arg in args.iter() {
411+
// Skip bare Python classes
412+
if arg.class().is(vm.ctx.types.type_type) {
413+
new_args.push(arg.clone());
414+
continue;
415+
}
416+
417+
// Check if this is an unpacked TypeVarTuple
418+
let unpack = is_unpacked_typevartuple(arg, vm)?;
419+
384420
// Check for __typing_subst__ attribute directly (like CPython)
385421
if let Ok(subst) = arg.get_attr(identifier!(vm, __typing_subst__), vm) {
386-
let idx = tuple_index(parameters.as_slice(), arg).unwrap();
387-
if idx < num_items {
388-
// Call __typing_subst__ with the argument
389-
let substituted = subst.call((arg_items[idx].clone(),), vm)?;
390-
new_args.push(substituted);
422+
if let Some(idx) = tuple_index(parameters.as_slice(), arg) {
423+
if idx < num_items {
424+
// Call __typing_subst__ with the argument
425+
let substituted = subst.call((arg_items[idx].clone(),), vm)?;
426+
427+
if unpack {
428+
// Unpack the tuple if it's a TypeVarTuple
429+
if let Ok(tuple) = substituted.try_to_ref::<PyTuple>(vm) {
430+
for elem in tuple.iter() {
431+
new_args.push(elem.clone());
432+
}
433+
} else {
434+
new_args.push(substituted);
435+
}
436+
} else {
437+
new_args.push(substituted);
438+
}
439+
} else {
440+
// Use default value if available
441+
if let Ok(default_val) = vm.call_method(arg, "__default__", ()) {
442+
if !default_val.is(&vm.ctx.typing_no_default) {
443+
new_args.push(default_val);
444+
} else {
445+
return Err(vm.new_type_error(format!(
446+
"No argument provided for parameter at index {idx}"
447+
)));
448+
}
449+
} else {
450+
return Err(vm.new_type_error(format!(
451+
"No argument provided for parameter at index {idx}"
452+
)));
453+
}
454+
}
391455
} else {
392-
// CPython doesn't support default values in this context
393-
return Err(
394-
vm.new_type_error(format!("No argument provided for parameter at index {idx}"))
395-
);
456+
new_args.push(arg.clone());
396457
}
397458
} else {
398-
new_args.push(subs_tvars(arg.clone(), &parameters, arg_items, vm)?);
459+
let subst_arg = subs_tvars(arg.clone(), &parameters, arg_items, vm)?;
460+
if unpack {
461+
// Unpack the tuple if it's a TypeVarTuple
462+
if let Ok(tuple) = subst_arg.try_to_ref::<PyTuple>(vm) {
463+
for elem in tuple.iter() {
464+
new_args.push(elem.clone());
465+
}
466+
} else {
467+
new_args.push(subst_arg);
468+
}
469+
} else {
470+
new_args.push(subst_arg);
471+
}
399472
}
400473
}
401474

@@ -406,7 +479,8 @@ impl AsMapping for PyGenericAlias {
406479
fn as_mapping() -> &'static PyMappingMethods {
407480
static AS_MAPPING: LazyLock<PyMappingMethods> = LazyLock::new(|| PyMappingMethods {
408481
subscript: atomic_func!(|mapping, needle, vm| {
409-
PyGenericAlias::mapping_downcast(mapping).__getitem__(needle.to_owned(), vm)
482+
let zelf = PyGenericAlias::mapping_downcast(mapping);
483+
PyGenericAlias::__getitem__(zelf.to_owned(), needle.to_owned(), vm)
410484
}),
411485
..PyMappingMethods::NOT_IMPLEMENTED
412486
});
@@ -490,6 +564,13 @@ impl Representable for PyGenericAlias {
490564
}
491565
}
492566

567+
impl Iterable for PyGenericAlias {
568+
fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
569+
// Return an iterator over the args tuple
570+
Ok(zelf.args.clone().to_pyobject(vm).get_iter(vm)?.into())
571+
}
572+
}
573+
493574
pub fn init(context: &Context) {
494575
let generic_alias_type = &context.types.generic_alias_type;
495576
PyGenericAlias::extend_class(context, generic_alias_type);

0 commit comments

Comments
 (0)