Skip to content

Commit d6e04aa

Browse files
committed
Comparable for typing
1 parent c2f8d5b commit d6e04aa

File tree

2 files changed

+85
-31
lines changed

2 files changed

+85
-31
lines changed

Lib/test/test_typing.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,6 @@ def test_basic_plain(self):
387387
self.assertIs(T.__infer_variance__, False)
388388
self.assertEqual(T.__module__, __name__)
389389

390-
# TODO: RUSTPYTHON
391-
@unittest.expectedFailure
392390
def test_basic_with_exec(self):
393391
ns = {}
394392
exec('from typing import TypeVar; T = TypeVar("T", bound=float)', ns, ns)
@@ -1286,8 +1284,6 @@ def test_module(self):
12861284
Ts = TypeVarTuple('Ts')
12871285
self.assertEqual(Ts.__module__, __name__)
12881286

1289-
# TODO: RUSTPYTHON
1290-
@unittest.expectedFailure
12911287
def test_exec(self):
12921288
ns = {}
12931289
exec('from typing import TypeVarTuple; Ts = TypeVarTuple("Ts")', ns)
@@ -8788,8 +8784,6 @@ def test_basic_plain(self):
87888784
self.assertEqual(P.__name__, 'P')
87898785
self.assertEqual(P.__module__, __name__)
87908786

8791-
# TODO: RUSTPYTHON
8792-
@unittest.expectedFailure
87938787
def test_basic_with_exec(self):
87948788
ns = {}
87958789
exec('from typing import ParamSpec; P = ParamSpec("P")', ns, ns)
@@ -9177,8 +9171,6 @@ def test_paramspec_gets_copied(self):
91779171
self.assertEqual(C2[Concatenate[str, P2]].__parameters__, (P2,))
91789172
self.assertEqual(C2[Concatenate[T, P2]].__parameters__, (T, P2))
91799173

9180-
# TODO: RUSTPYTHON
9181-
@unittest.expectedFailure
91829174
def test_cannot_subclass(self):
91839175
with self.assertRaisesRegex(TypeError, NOT_A_BASE_TYPE % 'ParamSpec'):
91849176
class C(ParamSpec): pass

vm/src/stdlib/typing.rs

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef<PyModule> {
1313
#[pymodule(name = "_typing")]
1414
pub(crate) mod decl {
1515
use crate::{
16-
AsObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
16+
AsObject, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
1717
builtins::{PyGenericAlias, PyTupleRef, PyTypeRef, pystr::AsPyStr},
18-
function::{FuncArgs, IntoFuncArgs},
18+
function::{FuncArgs, IntoFuncArgs, PyComparisonValue},
1919
protocol::PyNumberMethods,
20-
types::{AsNumber, Constructor, Representable},
20+
types::{AsNumber, Comparable, Constructor, PyComparisonOp, Representable},
2121
};
2222

2323
pub(crate) fn _call_typing_func_object<'a>(
@@ -751,7 +751,7 @@ pub(crate) mod decl {
751751
pub(crate) struct ParamSpecArgs {
752752
__origin__: PyObjectRef,
753753
}
754-
#[pyclass(flags(BASETYPE), with(Constructor, Representable))]
754+
#[pyclass(with(Constructor, Representable, Comparable))]
755755
impl ParamSpecArgs {
756756
#[pymethod(magic)]
757757
fn mro_entries(&self, _bases: PyObjectRef, vm: &VirtualMachine) -> PyResult {
@@ -762,15 +762,6 @@ pub(crate) mod decl {
762762
fn origin(&self) -> PyObjectRef {
763763
self.__origin__.clone()
764764
}
765-
766-
#[pymethod(magic)]
767-
fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
768-
// Check if other has __origin__ attribute
769-
if let Ok(other_origin) = other.get_attr("__origin__", vm) {
770-
return Ok(self.__origin__.is(&other_origin));
771-
}
772-
Ok(false)
773-
}
774765
}
775766

776767
impl Constructor for ParamSpecArgs {
@@ -794,14 +785,52 @@ pub(crate) mod decl {
794785
}
795786
}
796787

788+
impl Comparable for ParamSpecArgs {
789+
fn cmp(
790+
zelf: &crate::Py<Self>,
791+
other: &PyObject,
792+
op: PyComparisonOp,
793+
vm: &VirtualMachine,
794+
) -> PyResult<PyComparisonValue> {
795+
fn eq(
796+
zelf: &crate::Py<ParamSpecArgs>,
797+
other: PyObjectRef,
798+
vm: &VirtualMachine,
799+
) -> PyResult<bool> {
800+
// Check if other has __origin__ attribute
801+
if let Ok(other_origin) = other.get_attr("__origin__", vm) {
802+
return Ok(zelf.__origin__.is(&other_origin));
803+
}
804+
Ok(false)
805+
}
806+
match op {
807+
PyComparisonOp::Eq => {
808+
if let Ok(result) = eq(zelf, other.to_owned(), vm) {
809+
Ok(result.into())
810+
} else {
811+
Ok(PyComparisonValue::NotImplemented)
812+
}
813+
}
814+
PyComparisonOp::Ne => {
815+
if let Ok(result) = eq(zelf, other.to_owned(), vm) {
816+
Ok((!result).into())
817+
} else {
818+
Ok(PyComparisonValue::NotImplemented)
819+
}
820+
}
821+
_ => Ok(PyComparisonValue::NotImplemented),
822+
}
823+
}
824+
}
825+
797826
#[pyattr]
798827
#[pyclass(name = "ParamSpecKwargs", module = "typing")]
799828
#[derive(Debug, PyPayload)]
800829
#[allow(dead_code)]
801830
pub(crate) struct ParamSpecKwargs {
802831
__origin__: PyObjectRef,
803832
}
804-
#[pyclass(flags(BASETYPE), with(Constructor, Representable))]
833+
#[pyclass(with(Constructor, Representable, Comparable))]
805834
impl ParamSpecKwargs {
806835
#[pymethod(magic)]
807836
fn mro_entries(&self, _bases: PyObjectRef, vm: &VirtualMachine) -> PyResult {
@@ -812,15 +841,6 @@ pub(crate) mod decl {
812841
fn origin(&self) -> PyObjectRef {
813842
self.__origin__.clone()
814843
}
815-
816-
#[pymethod(magic)]
817-
fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
818-
// Check if other has __origin__ attribute
819-
if let Ok(other_origin) = other.get_attr("__origin__", vm) {
820-
return Ok(self.__origin__.is(&other_origin));
821-
}
822-
Ok(false)
823-
}
824844
}
825845

826846
impl Constructor for ParamSpecKwargs {
@@ -844,6 +864,44 @@ pub(crate) mod decl {
844864
}
845865
}
846866

867+
impl Comparable for ParamSpecKwargs {
868+
fn cmp(
869+
zelf: &crate::Py<Self>,
870+
other: &PyObject,
871+
op: PyComparisonOp,
872+
vm: &VirtualMachine,
873+
) -> PyResult<PyComparisonValue> {
874+
fn eq(
875+
zelf: &crate::Py<ParamSpecKwargs>,
876+
other: PyObjectRef,
877+
vm: &VirtualMachine,
878+
) -> PyResult<bool> {
879+
// Check if other has __origin__ attribute
880+
if let Ok(other_origin) = other.get_attr("__origin__", vm) {
881+
return Ok(zelf.__origin__.is(&other_origin));
882+
}
883+
Ok(false)
884+
}
885+
match op {
886+
PyComparisonOp::Eq => {
887+
if let Ok(result) = eq(zelf, other.to_owned(), vm) {
888+
Ok(result.into())
889+
} else {
890+
Ok(PyComparisonValue::NotImplemented)
891+
}
892+
}
893+
PyComparisonOp::Ne => {
894+
if let Ok(result) = eq(zelf, other.to_owned(), vm) {
895+
Ok((!result).into())
896+
} else {
897+
Ok(PyComparisonValue::NotImplemented)
898+
}
899+
}
900+
_ => Ok(PyComparisonValue::NotImplemented),
901+
}
902+
}
903+
}
904+
847905
#[pyattr]
848906
#[pyclass(name)]
849907
#[derive(Debug, PyPayload)]
@@ -922,13 +980,17 @@ pub(crate) mod decl {
922980
if let Ok(name_str) = module_name.str(vm) {
923981
let name = name_str.as_str();
924982
// CPython sets __module__ to None for builtins and <...> modules
983+
// Also set to None for exec contexts (no __name__ in globals means exec)
925984
if name == "builtins" || name.starts_with('<') {
926985
// Don't set __module__ attribute at all (CPython behavior)
927986
// This allows the typing module to handle it
928987
return Ok(());
929988
}
930989
}
931990
obj.set_attr("__module__", module_name, vm)?;
991+
} else {
992+
// If no module name is found (e.g., in exec context), set __module__ to None
993+
obj.set_attr("__module__", vm.ctx.none(), vm)?;
932994
}
933995
Ok(())
934996
}

0 commit comments

Comments
 (0)