diff --git a/src/spatialdata_plot/pl/_color.py b/src/spatialdata_plot/pl/_color.py index b3e3425a..b2aae827 100644 --- a/src/spatialdata_plot/pl/_color.py +++ b/src/spatialdata_plot/pl/_color.py @@ -1160,6 +1160,29 @@ def _modify_categorical_color_mapping( return modified_mapping +def _default_categorical_palette(n: int) -> list[str]: + """Return the scanpy default categorical palette sized for ``n`` categories (grey beyond 103).""" + if n <= 20: + return list(default_20) + if n <= 28: + return list(default_28) + if n <= len(default_102): + return list(default_102) + logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.") + return ["grey"] * n + + +def _next_palette_colors(used_colors: set[str], n: int) -> list[str]: + """Pick ``n`` default-palette colors skipping ``used_colors``, keeping a 2nd categorical render distinct (#364).""" + used_norm = {to_hex(to_rgba(c)) for c in used_colors} + pool = _default_categorical_palette(n + len(used_norm)) + unused = [c for c in pool if to_hex(to_rgba(c)) not in used_norm] + if len(unused) < n: # palette exhausted; some colors will repeat an earlier render's + logger.warning("Not enough distinct default colors left; stacked legends may share colors.") + return pool[:n] + return unused[:n] + + def _get_default_categorial_color_mapping( color_source_vector: ArrayLike | pd.Series[CategoricalDtype], cmap_params: CmapParams | None = None, @@ -1179,17 +1202,8 @@ def _get_default_categorial_color_mapping( else: palette = None - # Fall back to default palettes if needed if palette is None: - if len_cat <= 20: - palette = default_20 - elif len_cat <= 28: - palette = default_28 - elif len_cat <= len(default_102): # 103 colors - palette = default_102 - else: - palette = ["grey"] * len_cat - logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.") + palette = _default_categorical_palette(len_cat) return dict(zip(color_source_vector.categories, palette[:len_cat], strict=True)) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 3001a317..cbeacfcb 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -22,6 +22,7 @@ from matplotlib.backend_bases import RendererBase from matplotlib.colors import Colormap, LogNorm, Normalize from matplotlib.figure import Figure +from matplotlib.legend import Legend from mpl_toolkits.axes_grid1.inset_locator import inset_axes from spatialdata._utils import _deprecation_alias from spatialdata.transformations.operations import get_transformation @@ -31,6 +32,7 @@ from spatialdata_plot._logging import _log_context, logger from spatialdata_plot.pl._color import ( _maybe_set_colors, + _next_palette_colors, _prepare_cmap_norm, _set_outline, ) @@ -1584,6 +1586,10 @@ def show( _draw_scalebar(ax, scalebar_params_obj, panel_idx=i) + if fig_params.fig is not None: + candidate_axes = fig_params.axs if fig_params.axs is not None else [fig_params.ax] + _setup_stacked_legends(fig_params.fig, [a for a in candidate_axes if isinstance(a, Axes)]) + _layout_pending_colorbars(pending_colorbars, fig_params, colorbar_params) if fig_params.fig is not None and save is not None: @@ -1870,6 +1876,48 @@ def _draw_colorbar( trackers_axes[location] = pad_axes + (bbox_axes.width if vertical else bbox_axes.height) +def _stacked_legends(ax: Axes) -> list[Legend]: + """Return the per-render categorical legends (#364) this code tagged on ``ax``.""" + return [c for c in ax.get_children() if isinstance(c, Legend) and hasattr(c, "_sdata_column")] + + +def _reposition_stacked_legends(fig: Figure, renderer: object, gap_px: float = 10.0) -> None: + """Lay each axis' 2+ tagged legends left-to-right along its right edge, measured at ``renderer``. + + A legend is fixed-pixel (unlike a colorbar, which scales to fill its inset), so its axes-fraction + width shifts whenever the axes is rescaled; recompute the offsets per draw at the actual geometry. + """ + for ax in fig.axes: + legends = _stacked_legends(ax) + if len(legends) < 2: + continue + ax_w = ax.get_window_extent().width or 1.0 + x = 1.02 + for leg in legends: + leg.set_bbox_to_anchor((x, 1.0), transform=ax.transAxes) + x += (leg.get_window_extent(renderer).width + gap_px) / ax_w + + +def _setup_stacked_legends(fig: Figure, panel_axes: list[Axes]) -> None: + """Title 2+ same-axis categorical legends and keep them laid out side-by-side across redraws. + + A single figure-level ``draw_event`` handler (connected once) repositions every panel using the + event's renderer, so it works on any backend and isn't re-registered when a figure is reused. + """ + if not any(len(_stacked_legends(ax)) >= 2 for ax in panel_axes): + return + for ax in panel_axes: + for leg in _stacked_legends(ax): + if not leg.get_title().get_text(): # explicit title wins + leg.set_title(leg._sdata_column) + if hasattr(leg, "set_loc"): # mpl >= 3.8 + leg.set_loc("upper left") + if not getattr(fig, "_sdata_legend_cb", False): + fig._sdata_legend_cb = True # type: ignore[attr-defined] + fig.canvas.mpl_connect("draw_event", lambda e: _reposition_stacked_legends(fig, e.renderer)) + fig.canvas.draw() # eager placement; the handler keeps it correct across later resizes/redraws + + def _layout_pending_colorbars( pending_colorbars: list[tuple[Axes, list[ColorbarSpec]]], fig_params: FigParams, @@ -1944,19 +1992,32 @@ def _should_rasterize( return scale is None or (isinstance(scale, str) and scale != "full" and (dpi is not None or figsize is not None)) -def _maybe_set_label_colors(sdata: sd.SpatialData, render_params: LabelsRenderParams) -> None: - """Materialize a categorical palette on the table annotating a labels element, if applicable.""" +def _maybe_set_label_colors( + sdata: sd.SpatialData, + render_params: LabelsRenderParams, + used_colors: set[str] | None = None, +) -> None: + """Materialize a categorical palette on the table annotating a labels element, if applicable. + + ``used_colors`` accumulates the colors already taken by earlier categorical label renders on + the same panel. When a column's colors are auto-generated (no user palette, not already in + ``.uns``), they are shifted to skip ``used_colors`` so stacked legends stay distinct (#364). + """ table = render_params.table_name - if table is None or render_params.col_for_color is None: + col = render_params.col_for_color + if table is None or col is None: return - colors = sc.get.obs_df(sdata[table], [render_params.col_for_color]) - if isinstance(colors[render_params.col_for_color].dtype, pd.CategoricalDtype): - _maybe_set_colors( - source=sdata[table], - target=sdata[table], - key=render_params.col_for_color, - palette=render_params.palette, - ) + colors = sc.get.obs_df(sdata[table], [col]) + if not isinstance(colors[col].dtype, pd.CategoricalDtype): + return + adata = sdata[table] + color_key = f"{col}_colors" + if render_params.palette is None and used_colors and color_key not in adata.uns: + adata.uns[color_key] = _next_palette_colors(used_colors, len(colors[col].cat.categories)) + else: + _maybe_set_colors(source=adata, target=adata, key=col, palette=render_params.palette) + if used_colors is not None and color_key in adata.uns: + used_colors.update(adata.uns[color_key]) # _next_palette_colors normalizes for comparison def _render_panel( @@ -1985,6 +2046,9 @@ def _render_panel( """ wants = dict.fromkeys(("images", "labels", "points", "shapes"), False) wanted_elements: list[str] = [] + # Colors already taken by categorical label renders on this panel, so later renders can + # avoid reusing them and their stacked legends stay distinct (#364). + used_label_colors: set[str] = set() for cmd, params in render_cmds: # Skip render entries that belong to a different color panel. Entries with no @@ -2033,7 +2097,7 @@ def _render_panel( cast("ImageRenderParams | LabelsRenderParams", element_params), dpi, figsize ) if cmd == "render_labels": - _maybe_set_label_colors(sdata, cast(LabelsRenderParams, element_params)) + _maybe_set_label_colors(sdata, cast(LabelsRenderParams, element_params), used_label_colors) _RENDERERS[cmd](**kwargs) # Panel finalization depends only on per-panel values, so run it once after the loop. diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 1ab4825a..28598fd4 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -78,9 +78,11 @@ colormap_with_alpha, ) from spatialdata_plot.pl.utils import ( + _categorical_legend_handles, _decorate_axs, _fast_extent, _join_table_for_element, + _legend_ncol, _mpl_ax_contains_elements, _multiscale_to_spatial_image, _pixel_to_coord, @@ -505,7 +507,7 @@ def _add_outline_legend( ) color_map = mapping_df.drop_duplicates("cats").set_index("cats")["color"].to_dict() - outline_handles = [ax.scatter([], [], c=color_map[c], label=str(c)) for c in cats] + outline_handles = _categorical_legend_handles(ax, {c: color_map[c] for c in cats}) anchor_y: float | None = None if fill_has_legend: @@ -548,7 +550,7 @@ def _add_outline_legend( loc=loc, bbox_to_anchor=anchor, fontsize=legend_params.legend_fontsize, - ncol=(1 if len(outline_handles) <= 14 else 2 if len(outline_handles) <= 30 else 3), + ncol=_legend_ncol(len(outline_handles)), ) @@ -697,8 +699,7 @@ def _render_shapes( nan_count = int(pd.isna(cv).sum()) if nan_count: logger.warning( - f"Found {nan_count} NaN values in color data. " - "These observations will be colored with the 'na_color'." + f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'." ) color_spec = color_spec.evolve(color_vector=cv) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 323d6c3e..f61615c8 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -26,6 +26,7 @@ ) from matplotlib.figure import Figure from matplotlib.gridspec import GridSpec +from matplotlib.legend import Legend from matplotlib_scalebar.scalebar import ScaleBar from pandas.api.types import CategoricalDtype, is_numeric_dtype from pandas.core.arrays.categorical import Categorical @@ -405,6 +406,47 @@ def _build_alignment_dtype_hint( return "" +def _legend_ncol(n: int) -> int: + """Column count for a categorical legend with ``n`` entries.""" + return 1 if n <= 14 else 2 if n <= 30 else 3 + + +def _categorical_legend_handles(ax: Axes, color_map: Mapping[Any, Any], na_hex: str | None = None) -> list[Any]: + """Empty-scatter handles (colored dots) for a categorical legend, with an optional NA entry.""" + handles = [ax.scatter([], [], c=color, label=str(cat)) for cat, color in color_map.items()] + if na_hex is not None: + handles.append(ax.scatter([], [], c=na_hex, label="NA")) + return handles + + +def _stack_categorical_legend( + ax: Axes, + color_mapping: Mapping[Any, Any], + *, + na_hex: str | None, + title: str | None, + column: str | None, + legend_fontsize: int | float | _FontSize | None, +) -> None: + """Build the 2nd+ categorical legend on a shared axes without dropping existing ones (#364). + + Placement and the column auto-title are finalized later by ``_setup_stacked_legends``. + """ + handles = _categorical_legend_handles(ax, color_mapping, na_hex) + if (cur := ax.get_legend()) is not None: + ax.add_artist(cur) # else ax.legend() below drops it + new_leg = ax.legend( + handles=handles, + title=title, + frameon=False, + loc="upper left", + bbox_to_anchor=(1.02, 1.0), + fontsize=legend_fontsize, + ncol=_legend_ncol(len(handles)), + ) + new_leg._sdata_column = column # type: ignore[attr-defined] + + def _decorate_axs( ax: Axes, cax: PatchCollection, @@ -449,22 +491,43 @@ def _decorate_axs( } ) color_mapping = group_to_color_matching.drop_duplicates("cats").set_index("cats")["color"].to_dict() - _add_categorical_legend( - ax, - pd.Categorical(values=color_source_vector, categories=clusters), - palette=color_mapping, - legend_loc=legend_loc, - legend_fontweight=legend_fontweight, - legend_fontsize=legend_fontsize, - legend_fontoutline=path_effect, - na_color=[na_color.get_hex()], - na_in_legend=na_in_legend, - multi_panel=fig_params.axs is not None, - ) - # scanpy's helper doesn't accept a title; set it post-hoc so the user can - # disambiguate fill vs outline when both legends are drawn. - if legend_title is not None and (legend := ax.get_legend()) is not None: - legend.set_title(legend_title) + color_mapping = {k: v for k, v in color_mapping.items() if not pd.isnull(k)} # NA handled separately + # A 2nd categorical render would make scanpy's bare `ax.legend()` merge every labeled + # artist into one legend and drop the first (#364), so route 2nd+ legends (i.e. when a + # tagged legend already exists) through a helper that keeps them separate. + tagged = (getattr(c, "_sdata_column", None) is not None for c in ax.get_children() if isinstance(c, Legend)) + already = any(tagged) + if legend_loc in (None, "none"): + pass # legend suppressed + elif already: + na_hex = na_color.get_hex() if (na_in_legend and pd.isnull(color_source_vector).any()) else None + _stack_categorical_legend( + ax, + color_mapping, + na_hex=na_hex, + title=legend_title, + column=value_to_plot, + legend_fontsize=legend_fontsize, + ) + else: + _add_categorical_legend( + ax, + pd.Categorical(values=color_source_vector, categories=clusters), + palette=color_mapping, + legend_loc=legend_loc, + legend_fontweight=legend_fontweight, + legend_fontsize=legend_fontsize, + legend_fontoutline=path_effect, + na_color=[na_color.get_hex()], + na_in_legend=na_in_legend, + multi_panel=fig_params.axs is not None, + ) + # Tag with the column; the column auto-title (when 2+ legends) is applied in + # `_setup_stacked_legends`. An explicit title wins now. + if (legend := ax.get_legend()) is not None: + legend._sdata_column = value_to_plot # type: ignore[attr-defined] + if legend_title is not None: + legend.set_title(legend_title) elif colorbar and colorbar_requests is not None and cax is not None: colorbar_requests.append( ColorbarSpec( diff --git a/tests/_images/Shapes_can_color_two_queried_shapes_elements_by_annotation.png b/tests/_images/Shapes_can_color_two_queried_shapes_elements_by_annotation.png index 77d14762..93e476c7 100644 Binary files a/tests/_images/Shapes_can_color_two_queried_shapes_elements_by_annotation.png and b/tests/_images/Shapes_can_color_two_queried_shapes_elements_by_annotation.png differ diff --git a/tests/_images/Shapes_can_color_two_shapes_elements_by_annotation.png b/tests/_images/Shapes_can_color_two_shapes_elements_by_annotation.png index e99f972d..2fe75b07 100644 Binary files a/tests/_images/Shapes_can_color_two_shapes_elements_by_annotation.png and b/tests/_images/Shapes_can_color_two_shapes_elements_by_annotation.png differ diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index 2048d11c..4f8a9551 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -7,6 +7,7 @@ import scanpy as sc from anndata import AnnData from matplotlib.colors import Normalize +from matplotlib.legend import Legend from spatial_image import to_spatial_image from spatialdata import SpatialData, deepcopy, get_element_instances from spatialdata.models import Labels2DModel, Labels3DModel, TableModel @@ -96,6 +97,99 @@ def test_plot_can_stack_render_labels(self, sdata_blobs: SpatialData): .pl.show() ) + def test_two_categorical_label_renders_make_two_distinct_legends(self, sdata_blobs: SpatialData): + # Regression test for #364: two render_labels calls coloring by two categorical columns + # must produce two separate, column-titled legends (not one merged/replaced legend), and + # the second render must not reuse the first's colors. State-based (no image comparison). + n = sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_labels"] * n) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_labels" + sdata_blobs["table"].obs["cat0"] = pd.Categorical((["A", "B"] * ((n + 1) // 2))[:n]) + sdata_blobs["table"].obs["cat1"] = pd.Categorical((["C", "D"] * ((n + 1) // 2))[:n]) + + ( + sdata_blobs.pl.render_labels("blobs_labels", color="cat0") + .pl.render_labels("blobs_labels", color="cat1") + .pl.show() + ) + + fig = plt.gcf() + ax = fig.axes[0] + legends = [c for c in ax.get_children() if isinstance(c, Legend)] + assert len(legends) == 2 + entries = {leg.get_title().get_text(): {t.get_text() for t in leg.get_texts()} for leg in legends} + assert entries == {"cat0": {"A", "B"}, "cat1": {"C", "D"}} + + # palette offset: the second categorical render must not reuse the first's colors + c0 = set(sdata_blobs["table"].uns["cat0_colors"]) + c1 = set(sdata_blobs["table"].uns["cat1_colors"]) + assert c0.isdisjoint(c1) + + # legends are placed side-by-side in the right margin: top-aligned, non-overlapping in x (#364) + fig.canvas.draw() + inv = ax.transAxes.inverted() + boxes = sorted((leg.get_window_extent().transformed(inv) for leg in legends), key=lambda b: b.x0) + assert boxes[0].x1 <= boxes[1].x0 # no horizontal overlap + assert abs(boxes[0].y1 - boxes[1].y1) < 0.01 # tops aligned + plt.close() + + def test_three_categorical_label_renders_make_three_legends(self, sdata_blobs: SpatialData): + # Regression test for #364: re-adding prior legends must not duplicate them; three renders + # yield exactly three distinct legends (not four with a repeat). + n = sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_labels"] * n) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_labels" + for col, (a, b) in {"cat0": ("A", "B"), "cat1": ("C", "D"), "cat2": ("E", "F")}.items(): + sdata_blobs["table"].obs[col] = pd.Categorical(([a, b] * ((n + 1) // 2))[:n]) + + ( + sdata_blobs.pl.render_labels("blobs_labels", color="cat0") + .pl.render_labels("blobs_labels", color="cat1") + .pl.render_labels("blobs_labels", color="cat2") + .pl.show() + ) + + ax = plt.gcf().axes[0] + titles = sorted(c.get_title().get_text() for c in ax.get_children() if isinstance(c, Legend)) + assert titles == ["cat0", "cat1", "cat2"] + + # palette offset accumulates: every render skips all colors used by earlier ones, so all + # three palettes are mutually disjoint (not just cat0 vs cat1). + c0, c1, c2 = (set(sdata_blobs["table"].uns[f"cat{i}_colors"]) for i in range(3)) + assert c0.isdisjoint(c1) and c1.isdisjoint(c2) and c0.isdisjoint(c2) + plt.close() + + def test_single_categorical_label_render_legend_has_no_title(self, sdata_blobs: SpatialData): + # A lone categorical render produces exactly one, untitled legend: the column title is only + # added to disambiguate 2+ legends on an axis (#364). + n = sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_labels"] * n) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_labels" + sdata_blobs["table"].obs["cat0"] = pd.Categorical((["A", "B"] * ((n + 1) // 2))[:n]) + + sdata_blobs.pl.render_labels("blobs_labels", color="cat0").pl.show() + + ax = plt.gcf().axes[0] + legends = [c for c in ax.get_children() if isinstance(c, Legend)] + assert len(legends) == 1 + assert legends[0].get_title().get_text() == "" + plt.close() + + def test_two_legend_plot_saves_to_vector_backend(self, sdata_blobs: SpatialData, tmp_path): + # Regression for #364: the side-by-side legend layout runs on every draw, so it must use the + # draw event's renderer (valid on PDF/SVG) — not the Agg-only fig.canvas.get_renderer(). + n = sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_labels"] * n) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_labels" + sdata_blobs["table"].obs["cat0"] = pd.Categorical((["A", "B"] * ((n + 1) // 2))[:n]) + sdata_blobs["table"].obs["cat1"] = pd.Categorical((["C", "D"] * ((n + 1) // 2))[:n]) + + sdata_blobs.pl.render_labels("blobs_labels", color="cat0").pl.render_labels( + "blobs_labels", color="cat1" + ).pl.show() + plt.gcf().savefig(tmp_path / "out.pdf") # must not raise on the Pdf canvas + plt.close() + def test_plot_can_color_by_rgba_array(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_labels("blobs_labels", color=[0.5, 0.5, 1.0, 0.5]).pl.show()