Skip to content

Commit c78c2f4

Browse files
vagnermcjtimhoffmQuLogic
authored
Refactoring: Removing axis parameter from scales (#29988)
* Adding a decorator and Refactoring functions * Fixing Ruff Errors * Update scale.pyi * Adding new line to the end of scale.pyi * Update in docstring * Fixing Handle Function * Support optional axis in scales Updated my refactor based on the feedbacks received * Fixing ruff error * change in parameters and in decorator * parameter fix * minor change in pyi * Update lib/matplotlib/scale.py Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * Update lib/matplotlib/scale.py Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * Update lib/matplotlib/scale.py Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * Update lib/matplotlib/scale.pyi Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * Updating self and axis in pyi * returning scale_factory to default * Ruff checks * description fix * Update lib/matplotlib/scale.pyi Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * Undoing Unrelated Modifications * fixing mypy tests * Update lib/matplotlib/scale.pyi Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> * keyword-argument suggestion Co-authored-by: Elliott Sales de Andrade <quantum.analyst@gmail.com> * kwargs pop before function call Co-authored-by: Elliott Sales de Andrade <quantum.analyst@gmail.com> --------- Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com> Co-authored-by: Elliott Sales de Andrade <quantum.analyst@gmail.com>
1 parent d05b43d commit c78c2f4

File tree

2 files changed

+68
-8
lines changed

2 files changed

+68
-8
lines changed

lib/matplotlib/scale.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import inspect
3333
import textwrap
34+
from functools import wraps
3435

3536
import numpy as np
3637

@@ -103,13 +104,61 @@ def limit_range_for_scale(self, vmin, vmax, minpos):
103104
return vmin, vmax
104105

105106

107+
def _make_axis_parameter_optional(init_func):
108+
"""
109+
Decorator to allow leaving out the *axis* parameter in scale constructors.
110+
111+
This decorator ensures backward compatibility for scale classes that
112+
previously required an *axis* parameter. It allows constructors to be
113+
callerd with or without the *axis* parameter.
114+
115+
For simplicity, this does not handle the case when *axis*
116+
is passed as a keyword. However,
117+
scanning GitHub, there's no evidence that that is used anywhere.
118+
119+
Parameters
120+
----------
121+
init_func : callable
122+
The original __init__ method of a scale class.
123+
124+
Returns
125+
-------
126+
callable
127+
A wrapped version of *init_func* that handles the optional *axis*.
128+
129+
Notes
130+
-----
131+
If the wrapped constructor defines *axis* as its first argument, the
132+
parameter is preserved when present. Otherwise, the value `None` is injected
133+
as the first argument.
134+
135+
Examples
136+
--------
137+
>>> from matplotlib.scale import ScaleBase
138+
>>> class CustomScale(ScaleBase):
139+
... @_make_axis_parameter_optional
140+
... def __init__(self, axis, custom_param=1):
141+
... self.custom_param = custom_param
142+
"""
143+
@wraps(init_func)
144+
def wrapper(self, *args, **kwargs):
145+
if args and isinstance(args[0], mpl.axis.Axis):
146+
return init_func(self, *args, **kwargs)
147+
else:
148+
# Remove 'axis' from kwargs to avoid double assignment
149+
axis = kwargs.pop('axis', None)
150+
return init_func(self, axis, *args, **kwargs)
151+
return wrapper
152+
153+
106154
class LinearScale(ScaleBase):
107155
"""
108156
The default linear scale.
109157
"""
110158

111159
name = 'linear'
112160

161+
@_make_axis_parameter_optional
113162
def __init__(self, axis):
114163
# This method is present only to prevent inheritance of the base class'
115164
# constructor docstring, which would otherwise end up interpolated into
@@ -180,6 +229,7 @@ class FuncScale(ScaleBase):
180229

181230
name = 'function'
182231

232+
@_make_axis_parameter_optional
183233
def __init__(self, axis, functions):
184234
"""
185235
Parameters
@@ -279,7 +329,8 @@ class LogScale(ScaleBase):
279329
"""
280330
name = 'log'
281331

282-
def __init__(self, axis, *, base=10, subs=None, nonpositive="clip"):
332+
@_make_axis_parameter_optional
333+
def __init__(self, axis=None, *, base=10, subs=None, nonpositive="clip"):
283334
"""
284335
Parameters
285336
----------
@@ -330,6 +381,7 @@ class FuncScaleLog(LogScale):
330381

331382
name = 'functionlog'
332383

384+
@_make_axis_parameter_optional
333385
def __init__(self, axis, functions, base=10):
334386
"""
335387
Parameters
@@ -455,7 +507,8 @@ class SymmetricalLogScale(ScaleBase):
455507
"""
456508
name = 'symlog'
457509

458-
def __init__(self, axis, *, base=10, linthresh=2, subs=None, linscale=1):
510+
@_make_axis_parameter_optional
511+
def __init__(self, axis=None, *, base=10, linthresh=2, subs=None, linscale=1):
459512
self._transform = SymmetricalLogTransform(base, linthresh, linscale)
460513
self.subs = subs
461514

@@ -547,7 +600,8 @@ class AsinhScale(ScaleBase):
547600
1024: (256, 512)
548601
}
549602

550-
def __init__(self, axis, *, linear_width=1.0,
603+
@_make_axis_parameter_optional
604+
def __init__(self, axis=None, *, linear_width=1.0,
551605
base=10, subs='auto', **kwargs):
552606
"""
553607
Parameters
@@ -645,7 +699,8 @@ class LogitScale(ScaleBase):
645699
"""
646700
name = 'logit'
647701

648-
def __init__(self, axis, nonpositive='mask', *,
702+
@_make_axis_parameter_optional
703+
def __init__(self, axis=None, nonpositive='mask', *,
649704
one_half=r"\frac{1}{2}", use_overline=False):
650705
r"""
651706
Parameters

lib/matplotlib/scale.pyi

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ class ScaleBase:
1515

1616
class LinearScale(ScaleBase):
1717
name: str
18+
def __init__(
19+
self,
20+
axis: Axis | None,
21+
) -> None: ...
1822

1923
class FuncTransform(Transform):
2024
input_dims: int
@@ -57,7 +61,7 @@ class LogScale(ScaleBase):
5761
subs: Iterable[int] | None
5862
def __init__(
5963
self,
60-
axis: Axis | None,
64+
axis: Axis | None = ...,
6165
*,
6266
base: float = ...,
6367
subs: Iterable[int] | None = ...,
@@ -104,7 +108,7 @@ class SymmetricalLogScale(ScaleBase):
104108
subs: Iterable[int] | None
105109
def __init__(
106110
self,
107-
axis: Axis | None,
111+
axis: Axis | None = ...,
108112
*,
109113
base: float = ...,
110114
linthresh: float = ...,
@@ -138,7 +142,7 @@ class AsinhScale(ScaleBase):
138142
auto_tick_multipliers: dict[int, tuple[int, ...]]
139143
def __init__(
140144
self,
141-
axis: Axis | None,
145+
axis: Axis | None = ...,
142146
*,
143147
linear_width: float = ...,
144148
base: float = ...,
@@ -165,7 +169,7 @@ class LogitScale(ScaleBase):
165169
name: str
166170
def __init__(
167171
self,
168-
axis: Axis | None,
172+
axis: Axis | None = ...,
169173
nonpositive: Literal["mask", "clip"] = ...,
170174
*,
171175
one_half: str = ...,
@@ -176,3 +180,4 @@ class LogitScale(ScaleBase):
176180
def get_scale_names() -> list[str]: ...
177181
def scale_factory(scale: str, axis: Axis, **kwargs) -> ScaleBase: ...
178182
def register_scale(scale_class: type[ScaleBase]) -> None: ...
183+
def _make_axis_parameter_optional(init_func: Callable[..., None]) -> Callable[..., None]: ...

0 commit comments

Comments
 (0)