@@ -371,80 +371,120 @@ impl PyObject {
371
371
} )
372
372
}
373
373
374
- // Equivalent to check_class. Masks Attribute errors (into TypeErrors) and lets everything
375
- // else go through .
376
- fn check_cls < F > ( & self , cls : & PyObject , vm : & VirtualMachine , msg : F ) -> PyResult
374
+ // Equivalent to CPython's check_class. Returns Ok(()) if cls is a valid class,
375
+ // Err with TypeError if not. Uses abstract_get_bases internally .
376
+ fn check_class < F > ( & self , vm : & VirtualMachine , msg : F ) -> PyResult < ( ) >
377
377
where
378
378
F : Fn ( ) -> String ,
379
379
{
380
- cls. get_attr ( identifier ! ( vm , __bases__ ) , vm ) . map_err ( |e| {
381
- // Only mask AttributeErrors.
382
- if e . class ( ) . is ( vm . ctx . exceptions . attribute_error ) {
383
- vm . new_type_error ( msg ( ) )
384
- } else {
385
- e
380
+ let cls = self ;
381
+ match cls . abstract_get_bases ( vm ) ? {
382
+ Some ( _bases ) => Ok ( ( ) ) , // Has __bases__, it's a valid class
383
+ None => {
384
+ // No __bases__ or __bases__ is not a tuple
385
+ Err ( vm . new_type_error ( msg ( ) ) )
386
386
}
387
- } )
387
+ }
388
388
}
389
389
390
- fn abstract_issubclass ( & self , cls : & PyObject , vm : & VirtualMachine ) -> PyResult < bool > {
391
- let mut derived = self ;
392
- let mut first_item: PyObjectRef ;
393
- loop {
394
- if derived. is ( cls) {
395
- return Ok ( true ) ;
390
+ /// abstract_get_bases() has logically 4 return states:
391
+ /// 1. getattr(cls, '__bases__') could raise an AttributeError
392
+ /// 2. getattr(cls, '__bases__') could raise some other exception
393
+ /// 3. getattr(cls, '__bases__') could return a tuple
394
+ /// 4. getattr(cls, '__bases__') could return something other than a tuple
395
+ ///
396
+ /// Only state #3 returns Some(tuple). AttributeErrors are masked by returning None.
397
+ /// If an object other than a tuple comes out of __bases__, then again, None is returned.
398
+ /// Other exceptions are propagated.
399
+ fn abstract_get_bases ( & self , vm : & VirtualMachine ) -> PyResult < Option < PyTupleRef > > {
400
+ match vm. get_attribute_opt ( self . to_owned ( ) , identifier ! ( vm, __bases__) ) ? {
401
+ Some ( bases) => {
402
+ // Check if it's a tuple
403
+ match PyTupleRef :: try_from_object ( vm, bases) {
404
+ Ok ( tuple) => Ok ( Some ( tuple) ) ,
405
+ Err ( _) => Ok ( None ) , // Not a tuple, return None
406
+ }
396
407
}
408
+ None => Ok ( None ) , // AttributeError was masked
409
+ }
410
+ }
397
411
398
- let bases = derived. get_attr ( identifier ! ( vm, __bases__) , vm) ?;
399
- let tuple = PyTupleRef :: try_from_object ( vm, bases) ?;
412
+ fn abstract_issubclass ( & self , cls : & PyObject , vm : & VirtualMachine ) -> PyResult < bool > {
413
+ // # Safety: The lifetime of `derived` is forced to be ignored
414
+ let bases = unsafe {
415
+ let mut derived = self ;
416
+ // First loop: handle single inheritance without recursion
417
+ loop {
418
+ if derived. is ( cls) {
419
+ return Ok ( true ) ;
420
+ }
400
421
401
- let n = tuple. len ( ) ;
402
- match n {
403
- 0 => {
422
+ let Some ( bases) = derived. abstract_get_bases ( vm) ? else {
404
423
return Ok ( false ) ;
405
- }
406
- 1 => {
407
- first_item = tuple[ 0 ] . clone ( ) ;
408
- derived = & first_item;
409
- continue ;
410
- }
411
- _ => {
412
- for i in 0 ..n {
413
- let check = vm. with_recursion ( "in abstract_issubclass" , || {
414
- tuple[ i] . abstract_issubclass ( cls, vm)
415
- } ) ?;
416
- if check {
417
- return Ok ( true ) ;
418
- }
424
+ } ;
425
+ let n = bases. len ( ) ;
426
+ match n {
427
+ 0 => return Ok ( false ) ,
428
+ 1 => {
429
+ // Avoid recursion in the single inheritance case
430
+ // # safety
431
+ // Intention:
432
+ // ```
433
+ // derived = bases.as_slice()[0].as_object();
434
+ // ```
435
+ // Though type-system cannot guarantee, derived does live long enough in the loop.
436
+ derived = & * ( bases. as_slice ( ) [ 0 ] . as_object ( ) as * const _ ) ;
437
+ continue ;
438
+ }
439
+ _ => {
440
+ // Multiple inheritance - break out to handle recursively
441
+ break bases;
419
442
}
420
443
}
421
444
}
445
+ } ;
422
446
423
- return Ok ( false ) ;
447
+ // Second loop: handle multiple inheritance with recursion
448
+ // At this point we know n >= 2
449
+ let n = bases. len ( ) ;
450
+ debug_assert ! ( n >= 2 ) ;
451
+
452
+ for i in 0 ..n {
453
+ let result = vm. with_recursion ( "in __issubclass__" , || {
454
+ bases. as_slice ( ) [ i] . abstract_issubclass ( cls, vm)
455
+ } ) ?;
456
+ if result {
457
+ return Ok ( true ) ;
458
+ }
424
459
}
460
+
461
+ Ok ( false )
425
462
}
426
463
427
464
fn recursive_issubclass ( & self , cls : & PyObject , vm : & VirtualMachine ) -> PyResult < bool > {
428
- if let ( Ok ( obj) , Ok ( cls) ) = ( self . try_to_ref :: < PyType > ( vm) , cls. try_to_ref :: < PyType > ( vm) ) {
429
- Ok ( obj. fast_issubclass ( cls) )
430
- } else {
431
- // Check if derived is a class
432
- self . check_cls ( self , vm, || {
433
- format ! ( "issubclass() arg 1 must be a class, not {}" , self . class( ) )
465
+ // Fast path for both being types (matches CPython's PyType_Check)
466
+ if let Some ( cls) = PyType :: check ( cls)
467
+ && let Some ( derived) = PyType :: check ( self )
468
+ {
469
+ // PyType_IsSubtype equivalent
470
+ return Ok ( derived. is_subtype ( cls) ) ;
471
+ }
472
+ // Check if derived is a class
473
+ self . check_class ( vm, || {
474
+ format ! ( "issubclass() arg 1 must be a class, not {}" , self . class( ) )
475
+ } ) ?;
476
+
477
+ // Check if cls is a class, tuple, or union (matches CPython's order and message)
478
+ if !cls. class ( ) . is ( vm. ctx . types . union_type ) {
479
+ cls. check_class ( vm, || {
480
+ format ! (
481
+ "issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}" ,
482
+ cls. class( )
483
+ )
434
484
} ) ?;
435
-
436
- // Check if cls is a class, tuple, or union
437
- if !cls. class ( ) . is ( vm. ctx . types . union_type ) {
438
- self . check_cls ( cls, vm, || {
439
- format ! (
440
- "issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}" ,
441
- cls. class( )
442
- )
443
- } ) ?;
444
- }
445
-
446
- self . abstract_issubclass ( cls, vm)
447
485
}
486
+
487
+ self . abstract_issubclass ( cls, vm)
448
488
}
449
489
450
490
/// Real issubclass check without going through __subclasscheck__
@@ -520,7 +560,7 @@ impl PyObject {
520
560
Ok ( retval)
521
561
} else {
522
562
// Not a type object, check if it's a valid class
523
- self . check_cls ( cls , vm, || {
563
+ cls . check_class ( vm, || {
524
564
format ! (
525
565
"isinstance() arg 2 must be a type, a tuple of types, or a union, not {}" ,
526
566
cls. class( )
0 commit comments