Skip to content
41 changes: 41 additions & 0 deletions docs/colorbars_legends.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,44 @@
ax = axs[1]
ax.legend(hs2, loc="b", ncols=3, center=True, title="centered rows")
axs.format(xlabel="xlabel", ylabel="ylabel", suptitle="Legend formatting demo")
# %% [raw] raw_mimetype="text/restructuredtext"
# .. _ug_guides_decouple:
#
# Decoupling legend content and location
# --------------------------------------
#
# Sometimes you may want to generate a legend using handles from specific axes
# but place it relative to other axes. In UltraPlot, you can achieve this by passing
# both the `ax` and `ref` keywords to :func:`~ultraplot.figure.Figure.legend`
# (or :func:`~ultraplot.figure.Figure.colorbar`). The `ax` keyword specifies the
# axes used to generate the legend handles, while the `ref` keyword specifies the
# reference axes used to determine the legend location.
#
# For example, to draw a legend based on the handles in the second row of subplots
# but place it below the first row of subplots, you can use
# ``fig.legend(ax=axs[1, :], ref=axs[0, :], loc='bottom')``. If ``ref`` is a list
# of axes, UltraPlot intelligently infers the span (width or height) and anchors
# the legend to the appropriate outer edge (e.g., the bottom-most axis for ``loc='bottom'``
# or the right-most axis for ``loc='right'``).

# %%
import numpy as np

import ultraplot as uplt

fig, axs = uplt.subplots(nrows=2, ncols=2, refwidth=2, share=False)
axs.format(abc="A.", suptitle="Decoupled legend location demo")

# Plot data on all axes
state = np.random.RandomState(51423)
data = (state.rand(20, 4) - 0.5).cumsum(axis=0)
for ax in axs:
ax.plot(data, cycle="mplotcolors", labels=list("abcd"))

# Legend 1: Content from Row 2 (ax=axs[1, :]), Location below Row 1 (ref=axs[0, :])
# This places a legend describing the bottom row data underneath the top row.
fig.legend(ax=axs[1, :], ref=axs[0, :], loc="bottom", title="Data from Row 2")

# Legend 2: Content from Row 1 (ax=axs[0, :]), Location below Row 2 (ref=axs[1, :])
# This places a legend describing the top row data underneath the bottom row.
fig.legend(ax=axs[0, :], ref=axs[1, :], loc="bottom", title="Data from Row 1")
235 changes: 214 additions & 21 deletions ultraplot/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2594,6 +2594,8 @@ def colorbar(
"""
# Backwards compatibility
ax = kwargs.pop("ax", None)
ref = kwargs.pop("ref", None)
loc_ax = ref if ref is not None else ax
cax = kwargs.pop("cax", None)
if isinstance(values, maxes.Axes):
cax = _not_none(cax_positional=values, cax=cax)
Expand All @@ -2613,20 +2615,102 @@ def colorbar(
with context._state_context(cax, _internal_call=True): # do not wrap pcolor
cb = super().colorbar(mappable, cax=cax, **kwargs)
# Axes panel colorbar
elif ax is not None:
elif loc_ax is not None:
# Check if span parameters are provided
has_span = _not_none(span, row, col, rows, cols) is not None

# Infer span from loc_ax if it is a list and no span provided
if (
not has_span
and np.iterable(loc_ax)
and not isinstance(loc_ax, (str, maxes.Axes))
):
loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"])
side = (
loc_trans
if loc_trans in ("left", "right", "top", "bottom")
else None
)

if side:
r_min, r_max = float("inf"), float("-inf")
c_min, c_max = float("inf"), float("-inf")
valid_ax = False
for axi in loc_ax:
if not hasattr(axi, "get_subplotspec"):
continue
ss = axi.get_subplotspec()
if ss is None:
continue
ss = ss.get_topmost_subplotspec()
r1, r2, c1, c2 = ss._get_rows_columns()
r_min = min(r_min, r1)
r_max = max(r_max, r2)
c_min = min(c_min, c1)
c_max = max(c_max, c2)
valid_ax = True

if valid_ax:
if side in ("left", "right"):
rows = (r_min + 1, r_max + 1)
else:
cols = (c_min + 1, c_max + 1)
has_span = True

# Extract a single axes from array if span is provided
# Otherwise, pass the array as-is for normal colorbar behavior
if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)):
try:
ax_single = next(iter(ax))
if (
has_span
and np.iterable(loc_ax)
and not isinstance(loc_ax, (str, maxes.Axes))
):
# Pick the best axis to anchor to based on the colorbar side
loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"])
side = (
loc_trans
if loc_trans in ("left", "right", "top", "bottom")
else None
)

except (TypeError, StopIteration):
ax_single = ax
best_ax = None
best_coord = float("-inf")

# If side is determined, search for the edge axis
if side:
for axi in loc_ax:
if not hasattr(axi, "get_subplotspec"):
continue
ss = axi.get_subplotspec()
if ss is None:
continue
ss = ss.get_topmost_subplotspec()
r1, r2, c1, c2 = ss._get_rows_columns()

if side == "right":
val = c2 # Maximize column index
elif side == "left":
val = -c1 # Minimize column index
elif side == "bottom":
val = r2 # Maximize row index
elif side == "top":
val = -r1 # Minimize row index
else:
val = 0

if val > best_coord:
best_coord = val
best_ax = axi

# Fallback to first axis
if best_ax is None:
try:
ax_single = next(iter(loc_ax))
except (TypeError, StopIteration):
ax_single = loc_ax
else:
ax_single = best_ax
else:
ax_single = ax
ax_single = loc_ax

# Pass span parameters through to axes colorbar
cb = ax_single.colorbar(
Expand Down Expand Up @@ -2700,27 +2784,136 @@ def legend(
matplotlib.axes.Axes.legend
"""
ax = kwargs.pop("ax", None)
ref = kwargs.pop("ref", None)
loc_ax = ref if ref is not None else ax

# Axes panel legend
if ax is not None:
if loc_ax is not None:
content_ax = ax if ax is not None else loc_ax
# Check if span parameters are provided
has_span = _not_none(span, row, col, rows, cols) is not None
# Extract a single axes from array if span is provided
# Otherwise, pass the array as-is for normal legend behavior
# Automatically collect handles and labels from spanned axes if not provided
if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)):
# Auto-collect handles and labels if not explicitly provided
if handles is None and labels is None:
handles, labels = [], []
for axi in ax:

# Automatically collect handles and labels from content axes if not provided
# Case 1: content_ax is a list (we must auto-collect)
# Case 2: content_ax != loc_ax (we must auto-collect because loc_ax.legend won't find content_ax handles)
must_collect = (
np.iterable(content_ax)
and not isinstance(content_ax, (str, maxes.Axes))
) or (content_ax is not loc_ax)

if must_collect and handles is None and labels is None:
handles, labels = [], []
# Handle list of axes
if np.iterable(content_ax) and not isinstance(
content_ax, (str, maxes.Axes)
):
for axi in content_ax:
h, l = axi.get_legend_handles_labels()
handles.extend(h)
labels.extend(l)
try:
ax_single = next(iter(ax))
except (TypeError, StopIteration):
ax_single = ax
# Handle single axis
else:
handles, labels = content_ax.get_legend_handles_labels()

# Infer span from loc_ax if it is a list and no span provided
if (
not has_span
and np.iterable(loc_ax)
and not isinstance(loc_ax, (str, maxes.Axes))
):
loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"])
side = (
loc_trans
if loc_trans in ("left", "right", "top", "bottom")
else None
)

if side:
r_min, r_max = float("inf"), float("-inf")
c_min, c_max = float("inf"), float("-inf")
valid_ax = False
for axi in loc_ax:
if not hasattr(axi, "get_subplotspec"):
continue
ss = axi.get_subplotspec()
if ss is None:
continue
ss = ss.get_topmost_subplotspec()
r1, r2, c1, c2 = ss._get_rows_columns()
r_min = min(r_min, r1)
r_max = max(r_max, r2)
c_min = min(c_min, c1)
c_max = max(c_max, c2)
valid_ax = True

if valid_ax:
if side in ("left", "right"):
rows = (r_min + 1, r_max + 1)
else:
cols = (c_min + 1, c_max + 1)
has_span = True

# Extract a single axes from array if span is provided (or if ref is a list)
# Otherwise, pass the array as-is for normal legend behavior (only if loc_ax is list)
if (
has_span
and np.iterable(loc_ax)
and not isinstance(loc_ax, (str, maxes.Axes))
):
# Pick the best axis to anchor to based on the legend side
loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"])
side = (
loc_trans
if loc_trans in ("left", "right", "top", "bottom")
else None
)

best_ax = None
best_coord = float("-inf")

# If side is determined, search for the edge axis
if side:
for axi in loc_ax:
if not hasattr(axi, "get_subplotspec"):
continue
ss = axi.get_subplotspec()
if ss is None:
continue
ss = ss.get_topmost_subplotspec()
r1, r2, c1, c2 = ss._get_rows_columns()

if side == "right":
val = c2 # Maximize column index
elif side == "left":
val = -c1 # Minimize column index
elif side == "bottom":
val = r2 # Maximize row index
elif side == "top":
val = -r1 # Minimize row index
else:
val = 0

if val > best_coord:
best_coord = val
best_ax = axi

# Fallback to first axis if no best axis found (or side is None)
if best_ax is None:
try:
ax_single = next(iter(loc_ax))
except (TypeError, StopIteration):
ax_single = loc_ax
else:
ax_single = best_ax

else:
ax_single = ax
ax_single = loc_ax
if isinstance(ax_single, list):
try:
ax_single = pgridspec.SubplotGrid(ax_single)
except ValueError:
ax_single = ax_single[0]

leg = ax_single.legend(
handles,
labels,
Expand Down
11 changes: 10 additions & 1 deletion ultraplot/gridspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,12 @@ def _encode_indices(self, *args, which=None, panel=False):
nums = []
idxs = self._get_indices(which=which, panel=panel)
for arg in args:
if isinstance(arg, (list, np.ndarray)):
try:
nums.append([idxs[int(i)] for i in arg])
except (IndexError, TypeError):
raise ValueError(f"Invalid gridspec index {arg}.")
continue
try:
nums.append(idxs[arg])
except (IndexError, TypeError):
Expand Down Expand Up @@ -1612,10 +1618,13 @@ def __getitem__(self, key):
>>> axs[:, 0] # a SubplotGrid containing the subplots in the first column
"""
# Allow 1D list-like indexing
if isinstance(key, int):
if isinstance(key, (Integral, np.integer)):
return list.__getitem__(self, key)
elif isinstance(key, slice):
return SubplotGrid(list.__getitem__(self, key))
elif isinstance(key, (list, np.ndarray)):
# NOTE: list.__getitem__ does not support numpy integers
return SubplotGrid([list.__getitem__(self, int(i)) for i in key])

# Allow 2D array-like indexing
# NOTE: We assume this is a 2D array of subplots, because this is
Expand Down
Loading
Loading