Skip to content

Commit 47a7a00

Browse files
authored
Rework issubclass (#5867)
* check_exact * check_class * Type compatibility tools * abstract_issubclass * recursive_issubclass
1 parent 7a6e5c4 commit 47a7a00

File tree

2 files changed

+123
-55
lines changed

2 files changed

+123
-55
lines changed

vm/src/builtins/type.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,18 @@ fn downcast_qualname(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<
158158
}
159159
}
160160

161+
fn is_subtype_with_mro(a_mro: &[PyTypeRef], a: &Py<PyType>, b: &Py<PyType>) -> bool {
162+
if a.is(b) {
163+
return true;
164+
}
165+
for item in a_mro {
166+
if item.is(b) {
167+
return true;
168+
}
169+
}
170+
false
171+
}
172+
161173
impl PyType {
162174
pub fn new_simple_heap(
163175
name: &str,
@@ -197,6 +209,12 @@ impl PyType {
197209
Self::new_heap_inner(base, bases, attrs, slots, heaptype_ext, metaclass, ctx)
198210
}
199211

212+
/// Equivalent to CPython's PyType_Check macro
213+
/// Checks if obj is an instance of type (or its subclass)
214+
pub(crate) fn check(obj: &PyObject) -> Option<&Py<Self>> {
215+
obj.downcast_ref::<Self>()
216+
}
217+
200218
fn resolve_mro(bases: &[PyRef<Self>]) -> Result<Vec<PyTypeRef>, String> {
201219
// Check for duplicates in bases.
202220
let mut unique_bases = HashSet::new();
@@ -439,6 +457,16 @@ impl PyType {
439457
}
440458

441459
impl Py<PyType> {
460+
pub(crate) fn is_subtype(&self, other: &Py<PyType>) -> bool {
461+
is_subtype_with_mro(&self.mro.read(), self, other)
462+
}
463+
464+
/// Equivalent to CPython's PyType_CheckExact macro
465+
/// Checks if obj is exactly a type (not a subclass)
466+
pub fn check_exact<'a>(obj: &'a PyObject, vm: &VirtualMachine) -> Option<&'a Py<PyType>> {
467+
obj.downcast_ref_if_exact::<PyType>(vm)
468+
}
469+
442470
/// Determines if `subclass` is actually a subclass of `cls`, this doesn't call __subclasscheck__,
443471
/// so only use this if `cls` is known to have not overridden the base __subclasscheck__ magic
444472
/// method.

vm/src/protocol/object.rs

Lines changed: 95 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -371,80 +371,120 @@ impl PyObject {
371371
})
372372
}
373373

374-
// Equivalent to check_class. Masks Attribute errors (into TypeErrors) and lets everything
375-
// else go through.
376-
fn check_cls<F>(&self, cls: &PyObject, vm: &VirtualMachine, msg: F) -> PyResult
374+
// Equivalent to CPython's check_class. Returns Ok(()) if cls is a valid class,
375+
// Err with TypeError if not. Uses abstract_get_bases internally.
376+
fn check_class<F>(&self, vm: &VirtualMachine, msg: F) -> PyResult<()>
377377
where
378378
F: Fn() -> String,
379379
{
380-
cls.get_attr(identifier!(vm, __bases__), vm).map_err(|e| {
381-
// Only mask AttributeErrors.
382-
if e.class().is(vm.ctx.exceptions.attribute_error) {
383-
vm.new_type_error(msg())
384-
} else {
385-
e
380+
let cls = self;
381+
match cls.abstract_get_bases(vm)? {
382+
Some(_bases) => Ok(()), // Has __bases__, it's a valid class
383+
None => {
384+
// No __bases__ or __bases__ is not a tuple
385+
Err(vm.new_type_error(msg()))
386386
}
387-
})
387+
}
388388
}
389389

390-
fn abstract_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
391-
let mut derived = self;
392-
let mut first_item: PyObjectRef;
393-
loop {
394-
if derived.is(cls) {
395-
return Ok(true);
390+
/// abstract_get_bases() has logically 4 return states:
391+
/// 1. getattr(cls, '__bases__') could raise an AttributeError
392+
/// 2. getattr(cls, '__bases__') could raise some other exception
393+
/// 3. getattr(cls, '__bases__') could return a tuple
394+
/// 4. getattr(cls, '__bases__') could return something other than a tuple
395+
///
396+
/// Only state #3 returns Some(tuple). AttributeErrors are masked by returning None.
397+
/// If an object other than a tuple comes out of __bases__, then again, None is returned.
398+
/// Other exceptions are propagated.
399+
fn abstract_get_bases(&self, vm: &VirtualMachine) -> PyResult<Option<PyTupleRef>> {
400+
match vm.get_attribute_opt(self.to_owned(), identifier!(vm, __bases__))? {
401+
Some(bases) => {
402+
// Check if it's a tuple
403+
match PyTupleRef::try_from_object(vm, bases) {
404+
Ok(tuple) => Ok(Some(tuple)),
405+
Err(_) => Ok(None), // Not a tuple, return None
406+
}
396407
}
408+
None => Ok(None), // AttributeError was masked
409+
}
410+
}
397411

398-
let bases = derived.get_attr(identifier!(vm, __bases__), vm)?;
399-
let tuple = PyTupleRef::try_from_object(vm, bases)?;
412+
fn abstract_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
413+
// # Safety: The lifetime of `derived` is forced to be ignored
414+
let bases = unsafe {
415+
let mut derived = self;
416+
// First loop: handle single inheritance without recursion
417+
loop {
418+
if derived.is(cls) {
419+
return Ok(true);
420+
}
400421

401-
let n = tuple.len();
402-
match n {
403-
0 => {
422+
let Some(bases) = derived.abstract_get_bases(vm)? else {
404423
return Ok(false);
405-
}
406-
1 => {
407-
first_item = tuple[0].clone();
408-
derived = &first_item;
409-
continue;
410-
}
411-
_ => {
412-
for i in 0..n {
413-
let check = vm.with_recursion("in abstract_issubclass", || {
414-
tuple[i].abstract_issubclass(cls, vm)
415-
})?;
416-
if check {
417-
return Ok(true);
418-
}
424+
};
425+
let n = bases.len();
426+
match n {
427+
0 => return Ok(false),
428+
1 => {
429+
// Avoid recursion in the single inheritance case
430+
// # safety
431+
// Intention:
432+
// ```
433+
// derived = bases.as_slice()[0].as_object();
434+
// ```
435+
// Though type-system cannot guarantee, derived does live long enough in the loop.
436+
derived = &*(bases.as_slice()[0].as_object() as *const _);
437+
continue;
438+
}
439+
_ => {
440+
// Multiple inheritance - break out to handle recursively
441+
break bases;
419442
}
420443
}
421444
}
445+
};
422446

423-
return Ok(false);
447+
// Second loop: handle multiple inheritance with recursion
448+
// At this point we know n >= 2
449+
let n = bases.len();
450+
debug_assert!(n >= 2);
451+
452+
for i in 0..n {
453+
let result = vm.with_recursion("in __issubclass__", || {
454+
bases.as_slice()[i].abstract_issubclass(cls, vm)
455+
})?;
456+
if result {
457+
return Ok(true);
458+
}
424459
}
460+
461+
Ok(false)
425462
}
426463

427464
fn recursive_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
428-
if let (Ok(obj), Ok(cls)) = (self.try_to_ref::<PyType>(vm), cls.try_to_ref::<PyType>(vm)) {
429-
Ok(obj.fast_issubclass(cls))
430-
} else {
431-
// Check if derived is a class
432-
self.check_cls(self, vm, || {
433-
format!("issubclass() arg 1 must be a class, not {}", self.class())
465+
// Fast path for both being types (matches CPython's PyType_Check)
466+
if let Some(cls) = PyType::check(cls)
467+
&& let Some(derived) = PyType::check(self)
468+
{
469+
// PyType_IsSubtype equivalent
470+
return Ok(derived.is_subtype(cls));
471+
}
472+
// Check if derived is a class
473+
self.check_class(vm, || {
474+
format!("issubclass() arg 1 must be a class, not {}", self.class())
475+
})?;
476+
477+
// Check if cls is a class, tuple, or union (matches CPython's order and message)
478+
if !cls.class().is(vm.ctx.types.union_type) {
479+
cls.check_class(vm, || {
480+
format!(
481+
"issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}",
482+
cls.class()
483+
)
434484
})?;
435-
436-
// Check if cls is a class, tuple, or union
437-
if !cls.class().is(vm.ctx.types.union_type) {
438-
self.check_cls(cls, vm, || {
439-
format!(
440-
"issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}",
441-
cls.class()
442-
)
443-
})?;
444-
}
445-
446-
self.abstract_issubclass(cls, vm)
447485
}
486+
487+
self.abstract_issubclass(cls, vm)
448488
}
449489

450490
/// Real issubclass check without going through __subclasscheck__
@@ -520,7 +560,7 @@ impl PyObject {
520560
Ok(retval)
521561
} else {
522562
// Not a type object, check if it's a valid class
523-
self.check_cls(cls, vm, || {
563+
cls.check_class(vm, || {
524564
format!(
525565
"isinstance() arg 2 must be a type, a tuple of types, or a union, not {}",
526566
cls.class()

0 commit comments

Comments
 (0)