Skip to content
34 changes: 24 additions & 10 deletions src/spatialdata_plot/pl/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,6 +1160,29 @@ def _modify_categorical_color_mapping(
return modified_mapping


def _default_categorical_palette(n: int) -> list[str]:
"""Return the scanpy default categorical palette sized for ``n`` categories (grey beyond 103)."""
if n <= 20:
return list(default_20)
if n <= 28:
return list(default_28)
if n <= len(default_102):
return list(default_102)
logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.")
return ["grey"] * n


def _next_palette_colors(used_colors: set[str], n: int) -> list[str]:
"""Pick ``n`` default-palette colors skipping ``used_colors``, keeping a 2nd categorical render distinct (#364)."""
used_norm = {to_hex(to_rgba(c)) for c in used_colors}
pool = _default_categorical_palette(n + len(used_norm))
unused = [c for c in pool if to_hex(to_rgba(c)) not in used_norm]
if len(unused) < n: # palette exhausted; some colors will repeat an earlier render's
logger.warning("Not enough distinct default colors left; stacked legends may share colors.")
return pool[:n]
return unused[:n]


def _get_default_categorial_color_mapping(
color_source_vector: ArrayLike | pd.Series[CategoricalDtype],
cmap_params: CmapParams | None = None,
Expand All @@ -1179,17 +1202,8 @@ def _get_default_categorial_color_mapping(
else:
palette = None

# Fall back to default palettes if needed
if palette is None:
if len_cat <= 20:
palette = default_20
elif len_cat <= 28:
palette = default_28
elif len_cat <= len(default_102): # 103 colors
palette = default_102
else:
palette = ["grey"] * len_cat
logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.")
palette = _default_categorical_palette(len_cat)

return dict(zip(color_source_vector.categories, palette[:len_cat], strict=True))

Expand Down
88 changes: 76 additions & 12 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from matplotlib.backend_bases import RendererBase
from matplotlib.colors import Colormap, LogNorm, Normalize
from matplotlib.figure import Figure
from matplotlib.legend import Legend
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from spatialdata._utils import _deprecation_alias
from spatialdata.transformations.operations import get_transformation
Expand All @@ -31,6 +32,7 @@
from spatialdata_plot._logging import _log_context, logger
from spatialdata_plot.pl._color import (
_maybe_set_colors,
_next_palette_colors,
_prepare_cmap_norm,
_set_outline,
)
Expand Down Expand Up @@ -1584,6 +1586,10 @@ def show(

_draw_scalebar(ax, scalebar_params_obj, panel_idx=i)

if fig_params.fig is not None:
candidate_axes = fig_params.axs if fig_params.axs is not None else [fig_params.ax]
_setup_stacked_legends(fig_params.fig, [a for a in candidate_axes if isinstance(a, Axes)])

_layout_pending_colorbars(pending_colorbars, fig_params, colorbar_params)

if fig_params.fig is not None and save is not None:
Expand Down Expand Up @@ -1870,6 +1876,48 @@ def _draw_colorbar(
trackers_axes[location] = pad_axes + (bbox_axes.width if vertical else bbox_axes.height)


def _stacked_legends(ax: Axes) -> list[Legend]:
"""Return the per-render categorical legends (#364) this code tagged on ``ax``."""
return [c for c in ax.get_children() if isinstance(c, Legend) and hasattr(c, "_sdata_column")]


def _reposition_stacked_legends(fig: Figure, renderer: object, gap_px: float = 10.0) -> None:
"""Lay each axis' 2+ tagged legends left-to-right along its right edge, measured at ``renderer``.

A legend is fixed-pixel (unlike a colorbar, which scales to fill its inset), so its axes-fraction
width shifts whenever the axes is rescaled; recompute the offsets per draw at the actual geometry.
"""
for ax in fig.axes:
legends = _stacked_legends(ax)
if len(legends) < 2:
continue
ax_w = ax.get_window_extent().width or 1.0
x = 1.02
for leg in legends:
leg.set_bbox_to_anchor((x, 1.0), transform=ax.transAxes)
x += (leg.get_window_extent(renderer).width + gap_px) / ax_w


def _setup_stacked_legends(fig: Figure, panel_axes: list[Axes]) -> None:
"""Title 2+ same-axis categorical legends and keep them laid out side-by-side across redraws.

A single figure-level ``draw_event`` handler (connected once) repositions every panel using the
event's renderer, so it works on any backend and isn't re-registered when a figure is reused.
"""
if not any(len(_stacked_legends(ax)) >= 2 for ax in panel_axes):
return
for ax in panel_axes:
for leg in _stacked_legends(ax):
if not leg.get_title().get_text(): # explicit title wins
leg.set_title(leg._sdata_column)
if hasattr(leg, "set_loc"): # mpl >= 3.8
leg.set_loc("upper left")
if not getattr(fig, "_sdata_legend_cb", False):
fig._sdata_legend_cb = True # type: ignore[attr-defined]
fig.canvas.mpl_connect("draw_event", lambda e: _reposition_stacked_legends(fig, e.renderer))
fig.canvas.draw() # eager placement; the handler keeps it correct across later resizes/redraws


def _layout_pending_colorbars(
pending_colorbars: list[tuple[Axes, list[ColorbarSpec]]],
fig_params: FigParams,
Expand Down Expand Up @@ -1944,19 +1992,32 @@ def _should_rasterize(
return scale is None or (isinstance(scale, str) and scale != "full" and (dpi is not None or figsize is not None))


def _maybe_set_label_colors(sdata: sd.SpatialData, render_params: LabelsRenderParams) -> None:
"""Materialize a categorical palette on the table annotating a labels element, if applicable."""
def _maybe_set_label_colors(
sdata: sd.SpatialData,
render_params: LabelsRenderParams,
used_colors: set[str] | None = None,
) -> None:
"""Materialize a categorical palette on the table annotating a labels element, if applicable.

``used_colors`` accumulates the colors already taken by earlier categorical label renders on
the same panel. When a column's colors are auto-generated (no user palette, not already in
``.uns``), they are shifted to skip ``used_colors`` so stacked legends stay distinct (#364).
"""
table = render_params.table_name
if table is None or render_params.col_for_color is None:
col = render_params.col_for_color
if table is None or col is None:
return
colors = sc.get.obs_df(sdata[table], [render_params.col_for_color])
if isinstance(colors[render_params.col_for_color].dtype, pd.CategoricalDtype):
_maybe_set_colors(
source=sdata[table],
target=sdata[table],
key=render_params.col_for_color,
palette=render_params.palette,
)
colors = sc.get.obs_df(sdata[table], [col])
if not isinstance(colors[col].dtype, pd.CategoricalDtype):
return
adata = sdata[table]
color_key = f"{col}_colors"
if render_params.palette is None and used_colors and color_key not in adata.uns:
adata.uns[color_key] = _next_palette_colors(used_colors, len(colors[col].cat.categories))
else:
_maybe_set_colors(source=adata, target=adata, key=col, palette=render_params.palette)
if used_colors is not None and color_key in adata.uns:
used_colors.update(adata.uns[color_key]) # _next_palette_colors normalizes for comparison


def _render_panel(
Expand Down Expand Up @@ -1985,6 +2046,9 @@ def _render_panel(
"""
wants = dict.fromkeys(("images", "labels", "points", "shapes"), False)
wanted_elements: list[str] = []
# Colors already taken by categorical label renders on this panel, so later renders can
# avoid reusing them and their stacked legends stay distinct (#364).
used_label_colors: set[str] = set()

for cmd, params in render_cmds:
# Skip render entries that belong to a different color panel. Entries with no
Expand Down Expand Up @@ -2033,7 +2097,7 @@ def _render_panel(
cast("ImageRenderParams | LabelsRenderParams", element_params), dpi, figsize
)
if cmd == "render_labels":
_maybe_set_label_colors(sdata, cast(LabelsRenderParams, element_params))
_maybe_set_label_colors(sdata, cast(LabelsRenderParams, element_params), used_label_colors)
_RENDERERS[cmd](**kwargs)

# Panel finalization depends only on per-panel values, so run it once after the loop.
Expand Down
9 changes: 5 additions & 4 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@
colormap_with_alpha,
)
from spatialdata_plot.pl.utils import (
_categorical_legend_handles,
_decorate_axs,
_fast_extent,
_join_table_for_element,
_legend_ncol,
_mpl_ax_contains_elements,
_multiscale_to_spatial_image,
_pixel_to_coord,
Expand Down Expand Up @@ -505,7 +507,7 @@ def _add_outline_legend(
)
color_map = mapping_df.drop_duplicates("cats").set_index("cats")["color"].to_dict()

outline_handles = [ax.scatter([], [], c=color_map[c], label=str(c)) for c in cats]
outline_handles = _categorical_legend_handles(ax, {c: color_map[c] for c in cats})

anchor_y: float | None = None
if fill_has_legend:
Expand Down Expand Up @@ -548,7 +550,7 @@ def _add_outline_legend(
loc=loc,
bbox_to_anchor=anchor,
fontsize=legend_params.legend_fontsize,
ncol=(1 if len(outline_handles) <= 14 else 2 if len(outline_handles) <= 30 else 3),
ncol=_legend_ncol(len(outline_handles)),
)


Expand Down Expand Up @@ -697,8 +699,7 @@ def _render_shapes(
nan_count = int(pd.isna(cv).sum())
if nan_count:
logger.warning(
f"Found {nan_count} NaN values in color data. "
"These observations will be colored with the 'na_color'."
f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'."
)
color_spec = color_spec.evolve(color_vector=cv)

Expand Down
95 changes: 79 additions & 16 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from matplotlib.legend import Legend
from matplotlib_scalebar.scalebar import ScaleBar
from pandas.api.types import CategoricalDtype, is_numeric_dtype
from pandas.core.arrays.categorical import Categorical
Expand Down Expand Up @@ -405,6 +406,47 @@ def _build_alignment_dtype_hint(
return ""


def _legend_ncol(n: int) -> int:
"""Column count for a categorical legend with ``n`` entries."""
return 1 if n <= 14 else 2 if n <= 30 else 3


def _categorical_legend_handles(ax: Axes, color_map: Mapping[Any, Any], na_hex: str | None = None) -> list[Any]:
"""Empty-scatter handles (colored dots) for a categorical legend, with an optional NA entry."""
handles = [ax.scatter([], [], c=color, label=str(cat)) for cat, color in color_map.items()]
if na_hex is not None:
handles.append(ax.scatter([], [], c=na_hex, label="NA"))
return handles


def _stack_categorical_legend(
ax: Axes,
color_mapping: Mapping[Any, Any],
*,
na_hex: str | None,
title: str | None,
column: str | None,
legend_fontsize: int | float | _FontSize | None,
) -> None:
"""Build the 2nd+ categorical legend on a shared axes without dropping existing ones (#364).

Placement and the column auto-title are finalized later by ``_setup_stacked_legends``.
"""
handles = _categorical_legend_handles(ax, color_mapping, na_hex)
if (cur := ax.get_legend()) is not None:
ax.add_artist(cur) # else ax.legend() below drops it
new_leg = ax.legend(
handles=handles,
title=title,
frameon=False,
loc="upper left",
bbox_to_anchor=(1.02, 1.0),
fontsize=legend_fontsize,
ncol=_legend_ncol(len(handles)),
)
new_leg._sdata_column = column # type: ignore[attr-defined]


def _decorate_axs(
ax: Axes,
cax: PatchCollection,
Expand Down Expand Up @@ -449,22 +491,43 @@ def _decorate_axs(
}
)
color_mapping = group_to_color_matching.drop_duplicates("cats").set_index("cats")["color"].to_dict()
_add_categorical_legend(
ax,
pd.Categorical(values=color_source_vector, categories=clusters),
palette=color_mapping,
legend_loc=legend_loc,
legend_fontweight=legend_fontweight,
legend_fontsize=legend_fontsize,
legend_fontoutline=path_effect,
na_color=[na_color.get_hex()],
na_in_legend=na_in_legend,
multi_panel=fig_params.axs is not None,
)
# scanpy's helper doesn't accept a title; set it post-hoc so the user can
# disambiguate fill vs outline when both legends are drawn.
if legend_title is not None and (legend := ax.get_legend()) is not None:
legend.set_title(legend_title)
color_mapping = {k: v for k, v in color_mapping.items() if not pd.isnull(k)} # NA handled separately
# A 2nd categorical render would make scanpy's bare `ax.legend()` merge every labeled
# artist into one legend and drop the first (#364), so route 2nd+ legends (i.e. when a
# tagged legend already exists) through a helper that keeps them separate.
tagged = (getattr(c, "_sdata_column", None) is not None for c in ax.get_children() if isinstance(c, Legend))
already = any(tagged)
if legend_loc in (None, "none"):
pass # legend suppressed
elif already:
na_hex = na_color.get_hex() if (na_in_legend and pd.isnull(color_source_vector).any()) else None
_stack_categorical_legend(
ax,
color_mapping,
na_hex=na_hex,
title=legend_title,
column=value_to_plot,
legend_fontsize=legend_fontsize,
)
else:
_add_categorical_legend(
ax,
pd.Categorical(values=color_source_vector, categories=clusters),
palette=color_mapping,
legend_loc=legend_loc,
legend_fontweight=legend_fontweight,
legend_fontsize=legend_fontsize,
legend_fontoutline=path_effect,
na_color=[na_color.get_hex()],
na_in_legend=na_in_legend,
multi_panel=fig_params.axs is not None,
)
# Tag with the column; the column auto-title (when 2+ legends) is applied in
# `_setup_stacked_legends`. An explicit title wins now.
if (legend := ax.get_legend()) is not None:
legend._sdata_column = value_to_plot # type: ignore[attr-defined]
if legend_title is not None:
legend.set_title(legend_title)
elif colorbar and colorbar_requests is not None and cax is not None:
colorbar_requests.append(
ColorbarSpec(
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading