Skip to content

Commit 4ea0ff8

Browse files
committed
Update test_pyplot.py to include type on signature
1 parent f4693e3 commit 4ea0ff8

File tree

1 file changed

+50
-5
lines changed

1 file changed

+50
-5
lines changed

lib/matplotlib/tests/test_pyplot.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ast
12
import difflib
23
import inspect
34

@@ -487,16 +488,60 @@ def test_matshow():
487488

488489

489490
def assert_signatures_identical(plt_meth, original_meth, remove_self_param=False):
490-
plt_params = inspect.signature(plt_meth).parameters
491-
original_params = inspect.signature(original_meth).parameters
491+
def get_src(meth):
492+
meth_src = Path(inspect.getfile(meth))
493+
meth_stub = meth_src.with_suffix(".pyi")
494+
return meth_stub if meth_stub.exists() else meth_src
495+
496+
def tree_loop(tree, name, class_):
497+
for item in tree.body:
498+
if class_ and isinstance(item, ast.ClassDef) and item.name == class_:
499+
return tree_loop(item, name, None)
500+
501+
if isinstance(item, ast.FunctionDef) and item.name == name:
502+
return item
503+
504+
raise ValueError(f"Cannot find {class_}.{name} in ast")
505+
506+
def get_signature(meth):
507+
qualname = meth.__qualname__
508+
class_ = None if "." not in qualname else qualname.split(".")[-2]
509+
path = get_src(meth)
510+
tree = ast.parse(path.read_text())
511+
node = tree_loop(tree, meth.__name__, class_)
512+
513+
params = dict(inspect.signature(meth).parameters)
514+
args = node.args
515+
for param in (*args.posonlyargs, *args.args, args.vararg, *args.kwonlyargs, args.kwarg):
516+
if param is None:
517+
continue
518+
if param.annotation is None:
519+
continue
520+
annotation = ast.unparse(param.annotation)
521+
params[param.arg] = params[param.arg].replace(annotation=annotation)
522+
523+
if node.returns is not None:
524+
return inspect.Signature(
525+
params.values(),
526+
return_annotation=ast.unparse(node.returns)
527+
)
528+
else:
529+
return inspect.Signature(params.values())
530+
531+
plt_sig = get_signature(plt_meth)
532+
original_sig = get_signature(original_meth)
533+
534+
assert plt_sig.return_annotation == original_sig.return_annotation
535+
536+
original_params = original_sig.parameters
492537
if remove_self_param:
493538
if next(iter(original_params)) not in ["self"]:
494-
raise AssertionError(f"{original_params} is not an instance method")
539+
raise ValueError(f"{original_sig} is not an instance method")
495540

496-
original_params = dict(original_params)
541+
original_params = original_params.copy()
497542
del original_params["self"]
498543

499-
assert plt_params == original_params
544+
assert plt_sig.parameters == original_params
500545

501546

502547
def test_setloglevel_signature():

0 commit comments

Comments
 (0)