Skip to content

Commit 672e9a4

Browse files
committed
Fix mro representation to contain self
1 parent 6cf5c50 commit 672e9a4

File tree

5 files changed

+27
-29
lines changed

5 files changed

+27
-29
lines changed

vm/src/builtins/tuple.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,9 +537,9 @@ impl<T: TransmuteFromObject> AsRef<[T]> for PyTupleTyped<T> {
537537
}
538538

539539
impl<T: TransmuteFromObject> PyTupleTyped<T> {
540-
pub fn empty(vm: &VirtualMachine) -> Self {
540+
pub fn empty(ctx: &Context) -> Self {
541541
Self {
542-
tuple: vm.ctx.empty_tuple.clone(),
542+
tuple: ctx.empty_tuple.clone(),
543543
_marker: PhantomData,
544544
}
545545
}

vm/src/builtins/type.rs

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ use std::{borrow::Borrow, collections::HashSet, ops::Deref, pin::Pin, ptr::NonNu
3737
pub struct PyType {
3838
pub base: Option<PyTypeRef>,
3939
pub bases: PyRwLock<Vec<PyTypeRef>>,
40-
pub mro: PyRwLock<Vec<PyTypeRef>>,
40+
pub mro: PyRwLock<Vec<PyTypeRef>>, // TODO: PyTypedTuple<PyTypeRef>
4141
pub subclasses: PyRwLock<Vec<PyRef<PyWeak>>>,
4242
pub attributes: PyRwLock<PyAttributes>,
4343
pub slots: PyTypeSlots,
@@ -48,7 +48,7 @@ unsafe impl crate::object::Traverse for PyType {
4848
fn traverse(&self, tracer_fn: &mut crate::object::TraverseFn<'_>) {
4949
self.base.traverse(tracer_fn);
5050
self.bases.traverse(tracer_fn);
51-
self.mro.traverse(tracer_fn);
51+
// self.mro.traverse(tracer_fn);
5252
self.subclasses.traverse(tracer_fn);
5353
self.attributes
5454
.read_recursive()
@@ -238,8 +238,6 @@ impl PyType {
238238
metaclass: PyRef<Self>,
239239
ctx: &Context,
240240
) -> Result<PyRef<Self>, String> {
241-
let mro = Self::resolve_mro(&bases)?;
242-
243241
if base.slots.flags.has_feature(PyTypeFlags::HAS_DICT) {
244242
slots.flags |= PyTypeFlags::HAS_DICT
245243
}
@@ -256,6 +254,7 @@ impl PyType {
256254
}
257255
}
258256

257+
let mro = Self::resolve_mro(&bases)?;
259258
let new_type = PyRef::new_ref(
260259
PyType {
261260
base: Some(base),
@@ -269,6 +268,7 @@ impl PyType {
269268
metaclass,
270269
None,
271270
);
271+
new_type.mro.write().insert(0, new_type.clone());
272272

273273
new_type.init_slots(ctx);
274274

@@ -300,7 +300,6 @@ impl PyType {
300300

301301
let bases = PyRwLock::new(vec![base.clone()]);
302302
let mro = base.mro_map_collect(|x| x.to_owned());
303-
304303
let new_type = PyRef::new_ref(
305304
PyType {
306305
base: Some(base),
@@ -314,6 +313,7 @@ impl PyType {
314313
metaclass,
315314
None,
316315
);
316+
new_type.mro.write().insert(0, new_type.clone());
317317

318318
let weakref_type = super::PyWeak::static_type();
319319
for base in new_type.bases.read().iter() {
@@ -332,7 +332,7 @@ impl PyType {
332332
#[allow(clippy::mutable_key_type)]
333333
let mut slot_name_set = std::collections::HashSet::new();
334334

335-
for cls in self.mro.read().iter() {
335+
for cls in self.mro.read()[1..].iter() {
336336
for &name in cls.attributes.read().keys() {
337337
if name == identifier!(ctx, __new__) {
338338
continue;
@@ -381,18 +381,15 @@ impl PyType {
381381
}
382382

383383
pub fn get_super_attr(&self, attr_name: &'static PyStrInterned) -> Option<PyObjectRef> {
384-
self.mro
385-
.read()
384+
self.mro.read()[1..]
386385
.iter()
387386
.find_map(|class| class.attributes.read().get(attr_name).cloned())
388387
}
389388

390389
// This is the internal has_attr implementation for fast lookup on a class.
391390
pub fn has_attr(&self, attr_name: &'static PyStrInterned) -> bool {
392391
self.attributes.read().contains_key(attr_name)
393-
|| self
394-
.mro
395-
.read()
392+
|| self.mro.read()[1..]
396393
.iter()
397394
.any(|c| c.attributes.read().contains_key(attr_name))
398395
}
@@ -401,10 +398,7 @@ impl PyType {
401398
// Gather all members here:
402399
let mut attributes = PyAttributes::default();
403400

404-
for bc in std::iter::once(self)
405-
.chain(self.mro.read().iter().map(|cls| -> &PyType { cls }))
406-
.rev()
407-
{
401+
for bc in self.mro.read().iter().map(|cls| -> &PyType { cls }).rev() {
408402
for (name, value) in bc.attributes.read().iter() {
409403
attributes.insert(name.to_owned(), value.clone());
410404
}
@@ -468,22 +462,21 @@ impl Py<PyType> {
468462
/// so only use this if `cls` is known to have not overridden the base __subclasscheck__ magic
469463
/// method.
470464
pub fn fast_issubclass(&self, cls: &impl Borrow<PyObject>) -> bool {
471-
self.as_object().is(cls.borrow()) || self.mro.read().iter().any(|c| c.is(cls.borrow()))
465+
self.as_object().is(cls.borrow()) || self.mro.read()[1..].iter().any(|c| c.is(cls.borrow()))
472466
}
473467

474468
pub fn mro_map_collect<F, R>(&self, f: F) -> Vec<R>
475469
where
476470
F: Fn(&Self) -> R,
477471
{
478-
std::iter::once(self)
479-
.chain(self.mro.read().iter().map(|x| x.deref()))
480-
.map(f)
481-
.collect()
472+
self.mro.read().iter().map(|x| x.deref()).map(f).collect()
482473
}
483474

484475
pub fn mro_collect(&self) -> Vec<PyRef<PyType>> {
485-
std::iter::once(self)
486-
.chain(self.mro.read().iter().map(|x| x.deref()))
476+
self.mro
477+
.read()
478+
.iter()
479+
.map(|x| x.deref())
487480
.map(|x| x.to_owned())
488481
.collect()
489482
}
@@ -497,7 +490,7 @@ impl Py<PyType> {
497490
if let Some(r) = f(self) {
498491
Some(r)
499492
} else {
500-
self.mro.read().iter().find_map(|cls| f(cls))
493+
self.mro.read()[1..].iter().find_map(|cls| f(cls))
501494
}
502495
}
503496

@@ -556,8 +549,10 @@ impl PyType {
556549
*zelf.bases.write() = bases;
557550
// Recursively update the mros of this class and all subclasses
558551
fn update_mro_recursively(cls: &PyType, vm: &VirtualMachine) -> PyResult<()> {
559-
*cls.mro.write() =
552+
let mut mro =
560553
PyType::resolve_mro(&cls.bases.read()).map_err(|msg| vm.new_type_error(msg))?;
554+
mro.insert(0, cls.mro.read()[0].to_owned());
555+
*cls.mro.write() = mro;
561556
for subclass in cls.subclasses.write().iter() {
562557
let subclass = subclass.upgrade().unwrap();
563558
let subclass: &PyType = subclass.payload().unwrap();

vm/src/frame.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1384,7 +1384,7 @@ impl ExecutingFrame<'_> {
13841384
fn import(&mut self, vm: &VirtualMachine, module_name: Option<&Py<PyStr>>) -> PyResult<()> {
13851385
let module_name = module_name.unwrap_or(vm.ctx.empty_str);
13861386
let from_list = <Option<PyTupleTyped<PyStrRef>>>::try_from_object(vm, self.pop_value())?
1387-
.unwrap_or_else(|| PyTupleTyped::empty(vm));
1387+
.unwrap_or_else(|| PyTupleTyped::empty(&vm.ctx));
13881388
let level = usize::try_from_object(vm, self.pop_value())?;
13891389

13901390
let module = vm.import_from(module_name, from_list, level)?;

vm/src/object/core.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1252,12 +1252,14 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) {
12521252
ptr::write(&mut (*type_type_ptr).typ, PyAtomicRef::from(type_type));
12531253

12541254
let object_type = PyTypeRef::from_raw(object_type_ptr.cast());
1255+
(*object_type_ptr).payload.mro = PyRwLock::new(vec![object_type.clone()]);
12551256

1256-
(*type_type_ptr).payload.mro = PyRwLock::new(vec![object_type.clone()]);
12571257
(*type_type_ptr).payload.bases = PyRwLock::new(vec![object_type.clone()]);
12581258
(*type_type_ptr).payload.base = Some(object_type.clone());
12591259

12601260
let type_type = PyTypeRef::from_raw(type_type_ptr.cast());
1261+
(*type_type_ptr).payload.mro =
1262+
PyRwLock::new(vec![type_type.clone(), object_type.clone()]);
12611263

12621264
(type_type, object_type)
12631265
}
@@ -1273,6 +1275,7 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) {
12731275
heaptype_ext: None,
12741276
};
12751277
let weakref_type = PyRef::new_ref(weakref_type, type_type.clone(), None);
1278+
weakref_type.mro.write().insert(0, weakref_type.clone());
12761279

12771280
object_type.subclasses.write().push(
12781281
type_type

vm/src/vm/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ impl VirtualMachine {
580580
#[inline]
581581
pub fn import<'a>(&self, module_name: impl AsPyStr<'a>, level: usize) -> PyResult {
582582
let module_name = module_name.as_pystr(&self.ctx);
583-
let from_list = PyTupleTyped::empty(self);
583+
let from_list = PyTupleTyped::empty(&self.ctx);
584584
self.import_inner(module_name, from_list, level)
585585
}
586586

0 commit comments

Comments
 (0)