@@ -371,80 +371,120 @@ impl PyObject {
371371 } )
372372 }
373373
374- // Equivalent to check_class. Masks Attribute errors (into TypeErrors) and lets everything
375- // else go through .
376- fn check_cls < F > ( & self , cls : & PyObject , vm : & VirtualMachine , msg : F ) -> PyResult
374+ // Equivalent to CPython's check_class. Returns Ok(()) if cls is a valid class,
375+ // Err with TypeError if not. Uses abstract_get_bases internally .
376+ fn check_class < F > ( & self , vm : & VirtualMachine , msg : F ) -> PyResult < ( ) >
377377 where
378378 F : Fn ( ) -> String ,
379379 {
380- cls. get_attr ( identifier ! ( vm , __bases__ ) , vm ) . map_err ( |e| {
381- // Only mask AttributeErrors.
382- if e . class ( ) . is ( vm . ctx . exceptions . attribute_error ) {
383- vm . new_type_error ( msg ( ) )
384- } else {
385- e
380+ let cls = self ;
381+ match cls . abstract_get_bases ( vm ) ? {
382+ Some ( _bases ) => Ok ( ( ) ) , // Has __bases__, it's a valid class
383+ None => {
384+ // No __bases__ or __bases__ is not a tuple
385+ Err ( vm . new_type_error ( msg ( ) ) )
386386 }
387- } )
387+ }
388388 }
389389
390- fn abstract_issubclass ( & self , cls : & PyObject , vm : & VirtualMachine ) -> PyResult < bool > {
391- let mut derived = self ;
392- let mut first_item: PyObjectRef ;
393- loop {
394- if derived. is ( cls) {
395- return Ok ( true ) ;
390+ /// abstract_get_bases() has logically 4 return states:
391+ /// 1. getattr(cls, '__bases__') could raise an AttributeError
392+ /// 2. getattr(cls, '__bases__') could raise some other exception
393+ /// 3. getattr(cls, '__bases__') could return a tuple
394+ /// 4. getattr(cls, '__bases__') could return something other than a tuple
395+ ///
396+ /// Only state #3 returns Some(tuple). AttributeErrors are masked by returning None.
397+ /// If an object other than a tuple comes out of __bases__, then again, None is returned.
398+ /// Other exceptions are propagated.
399+ fn abstract_get_bases ( & self , vm : & VirtualMachine ) -> PyResult < Option < PyTupleRef > > {
400+ match vm. get_attribute_opt ( self . to_owned ( ) , identifier ! ( vm, __bases__) ) ? {
401+ Some ( bases) => {
402+ // Check if it's a tuple
403+ match PyTupleRef :: try_from_object ( vm, bases) {
404+ Ok ( tuple) => Ok ( Some ( tuple) ) ,
405+ Err ( _) => Ok ( None ) , // Not a tuple, return None
406+ }
396407 }
408+ None => Ok ( None ) , // AttributeError was masked
409+ }
410+ }
397411
398- let bases = derived. get_attr ( identifier ! ( vm, __bases__) , vm) ?;
399- let tuple = PyTupleRef :: try_from_object ( vm, bases) ?;
412+ fn abstract_issubclass ( & self , cls : & PyObject , vm : & VirtualMachine ) -> PyResult < bool > {
413+ // # Safety: The lifetime of `derived` is forced to be ignored
414+ let bases = unsafe {
415+ let mut derived = self ;
416+ // First loop: handle single inheritance without recursion
417+ loop {
418+ if derived. is ( cls) {
419+ return Ok ( true ) ;
420+ }
400421
401- let n = tuple. len ( ) ;
402- match n {
403- 0 => {
422+ let Some ( bases) = derived. abstract_get_bases ( vm) ? else {
404423 return Ok ( false ) ;
405- }
406- 1 => {
407- first_item = tuple[ 0 ] . clone ( ) ;
408- derived = & first_item;
409- continue ;
410- }
411- _ => {
412- for i in 0 ..n {
413- let check = vm. with_recursion ( "in abstract_issubclass" , || {
414- tuple[ i] . abstract_issubclass ( cls, vm)
415- } ) ?;
416- if check {
417- return Ok ( true ) ;
418- }
424+ } ;
425+ let n = bases. len ( ) ;
426+ match n {
427+ 0 => return Ok ( false ) ,
428+ 1 => {
429+ // Avoid recursion in the single inheritance case
430+ // # safety
431+ // Intention:
432+ // ```
433+ // derived = bases.as_slice()[0].as_object();
434+ // ```
435+ // Though type-system cannot guarantee, derived does live long enough in the loop.
436+ derived = & * ( bases. as_slice ( ) [ 0 ] . as_object ( ) as * const _ ) ;
437+ continue ;
438+ }
439+ _ => {
440+ // Multiple inheritance - break out to handle recursively
441+ break bases;
419442 }
420443 }
421444 }
445+ } ;
422446
423- return Ok ( false ) ;
447+ // Second loop: handle multiple inheritance with recursion
448+ // At this point we know n >= 2
449+ let n = bases. len ( ) ;
450+ debug_assert ! ( n >= 2 ) ;
451+
452+ for i in 0 ..n {
453+ let result = vm. with_recursion ( "in __issubclass__" , || {
454+ bases. as_slice ( ) [ i] . abstract_issubclass ( cls, vm)
455+ } ) ?;
456+ if result {
457+ return Ok ( true ) ;
458+ }
424459 }
460+
461+ Ok ( false )
425462 }
426463
427464 fn recursive_issubclass ( & self , cls : & PyObject , vm : & VirtualMachine ) -> PyResult < bool > {
428- if let ( Ok ( obj) , Ok ( cls) ) = ( self . try_to_ref :: < PyType > ( vm) , cls. try_to_ref :: < PyType > ( vm) ) {
429- Ok ( obj. fast_issubclass ( cls) )
430- } else {
431- // Check if derived is a class
432- self . check_cls ( self , vm, || {
433- format ! ( "issubclass() arg 1 must be a class, not {}" , self . class( ) )
465+ // Fast path for both being types (matches CPython's PyType_Check)
466+ if let Some ( cls) = PyType :: check ( cls)
467+ && let Some ( derived) = PyType :: check ( self )
468+ {
469+ // PyType_IsSubtype equivalent
470+ return Ok ( derived. is_subtype ( cls) ) ;
471+ }
472+ // Check if derived is a class
473+ self . check_class ( vm, || {
474+ format ! ( "issubclass() arg 1 must be a class, not {}" , self . class( ) )
475+ } ) ?;
476+
477+ // Check if cls is a class, tuple, or union (matches CPython's order and message)
478+ if !cls. class ( ) . is ( vm. ctx . types . union_type ) {
479+ cls. check_class ( vm, || {
480+ format ! (
481+ "issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}" ,
482+ cls. class( )
483+ )
434484 } ) ?;
435-
436- // Check if cls is a class, tuple, or union
437- if !cls. class ( ) . is ( vm. ctx . types . union_type ) {
438- self . check_cls ( cls, vm, || {
439- format ! (
440- "issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}" ,
441- cls. class( )
442- )
443- } ) ?;
444- }
445-
446- self . abstract_issubclass ( cls, vm)
447485 }
486+
487+ self . abstract_issubclass ( cls, vm)
448488 }
449489
450490 /// Real issubclass check without going through __subclasscheck__
@@ -520,7 +560,7 @@ impl PyObject {
520560 Ok ( retval)
521561 } else {
522562 // Not a type object, check if it's a valid class
523- self . check_cls ( cls , vm, || {
563+ cls . check_class ( vm, || {
524564 format ! (
525565 "isinstance() arg 2 must be a type, a tuple of types, or a union, not {}" ,
526566 cls. class( )
0 commit comments