|
| 1 | +import ast |
1 | 2 | import difflib
|
2 | 3 | import inspect
|
3 | 4 |
|
@@ -487,16 +488,60 @@ def test_matshow():
|
487 | 488 |
|
488 | 489 |
|
489 | 490 | def assert_signatures_identical(plt_meth, original_meth, remove_self_param=False):
|
490 |
| - plt_params = inspect.signature(plt_meth).parameters |
491 |
| - original_params = inspect.signature(original_meth).parameters |
| 491 | + def get_src(meth): |
| 492 | + meth_src = Path(inspect.getfile(meth)) |
| 493 | + meth_stub = meth_src.with_suffix(".pyi") |
| 494 | + return meth_stub if meth_stub.exists() else meth_src |
| 495 | + |
| 496 | + def tree_loop(tree, name, class_): |
| 497 | + for item in tree.body: |
| 498 | + if class_ and isinstance(item, ast.ClassDef) and item.name == class_: |
| 499 | + return tree_loop(item, name, None) |
| 500 | + |
| 501 | + if isinstance(item, ast.FunctionDef) and item.name == name: |
| 502 | + return item |
| 503 | + |
| 504 | + raise ValueError(f"Cannot find {class_}.{name} in ast") |
| 505 | + |
| 506 | + def get_signature(meth): |
| 507 | + qualname = meth.__qualname__ |
| 508 | + class_ = None if "." not in qualname else qualname.split(".")[-2] |
| 509 | + path = get_src(meth) |
| 510 | + tree = ast.parse(path.read_text()) |
| 511 | + node = tree_loop(tree, meth.__name__, class_) |
| 512 | + |
| 513 | + params = dict(inspect.signature(meth).parameters) |
| 514 | + args = node.args |
| 515 | + for param in (*args.posonlyargs, *args.args, args.vararg, *args.kwonlyargs, args.kwarg): |
| 516 | + if param is None: |
| 517 | + continue |
| 518 | + if param.annotation is None: |
| 519 | + continue |
| 520 | + annotation = ast.unparse(param.annotation) |
| 521 | + params[param.arg] = params[param.arg].replace(annotation=annotation) |
| 522 | + |
| 523 | + if node.returns is not None: |
| 524 | + return inspect.Signature( |
| 525 | + params.values(), |
| 526 | + return_annotation=ast.unparse(node.returns) |
| 527 | + ) |
| 528 | + else: |
| 529 | + return inspect.Signature(params.values()) |
| 530 | + |
| 531 | + plt_sig = get_signature(plt_meth) |
| 532 | + original_sig = get_signature(original_meth) |
| 533 | + |
| 534 | + assert plt_sig.return_annotation == original_sig.return_annotation |
| 535 | + |
| 536 | + original_params = original_sig.parameters |
492 | 537 | if remove_self_param:
|
493 | 538 | if next(iter(original_params)) not in ["self"]:
|
494 |
| - raise AssertionError(f"{original_params} is not an instance method") |
| 539 | + raise ValueError(f"{original_sig} is not an instance method") |
495 | 540 |
|
496 |
| - original_params = dict(original_params) |
| 541 | + original_params = original_params.copy() |
497 | 542 | del original_params["self"]
|
498 | 543 |
|
499 |
| - assert plt_params == original_params |
| 544 | + assert plt_sig.parameters == original_params |
500 | 545 |
|
501 | 546 |
|
502 | 547 | def test_setloglevel_signature():
|
|
0 commit comments