diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 52950a12..6c3dac09 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -748,6 +748,7 @@ def render_labels( sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams( element=element, color=param_values["color"], + col_for_color=param_values["col_for_color"], groups=param_values["groups"], contour_px=param_values["contour_px"], cmap_params=cmap_params, @@ -1121,14 +1122,13 @@ def _draw_colorbar( if wanted_labels_on_this_cs: table = params_copy.table_name - if table is not None: - assert isinstance(params_copy.color, str) - colors = sc.get.obs_df(sdata[table], [params_copy.color]) - if isinstance(colors[params_copy.color].dtype, pd.CategoricalDtype): + if table is not None and params_copy.col_for_color is not None: + colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color]) + if isinstance(colors[params_copy.col_for_color].dtype, pd.CategoricalDtype): _maybe_set_colors( source=sdata[table], target=sdata[table], - key=params_copy.color, + key=params_copy.col_for_color, palette=params_copy.palette, ) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 014c3cc5..226b5a5a 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -1262,7 +1262,7 @@ def _render_labels( table_name = render_params.table_name table_layer = render_params.table_layer palette = render_params.palette - color = render_params.color + col_for_color = render_params.col_for_color groups = render_params.groups scale = render_params.scale @@ -1311,23 +1311,25 @@ def _render_labels( _, trans_data = _prepare_transformation(label, coordinate_system, ax) + na_color = render_params.color if render_params.color else render_params.cmap_params.na_color color_source_vector, color_vector, categorical = _set_color_source_vec( sdata=sdata_filt, element=label, element_name=element, - value_to_plot=color, + value_to_plot=col_for_color, groups=groups, palette=palette, - na_color=render_params.cmap_params.na_color, + na_color=na_color, cmap_params=render_params.cmap_params, table_name=table_name, table_layer=table_layer, + render_type="labels", coordinate_system=coordinate_system, ) # rasterize could have removed labels from label # only problematic if color is specified - if rasterize and color is not None: + if rasterize and col_for_color is not None: labels_in_rasterized_image = np.unique(label.values) mask = np.isin(instance_id, labels_in_rasterized_image) instance_id = instance_id[mask] @@ -1405,7 +1407,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) colorbar_requested = _should_request_colorbar( render_params.colorbar, has_mappable=cax is not None, - is_continuous=color is not None and color_source_vector is None and not categorical, + is_continuous=col_for_color is not None and color_source_vector is None and not categorical, ) _ = _decorate_axs( @@ -1413,7 +1415,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) cax=cax, fig_params=fig_params, adata=table, - value_to_plot=color, + value_to_plot=col_for_color, color_source_vector=color_source_vector, color_vector=color_vector, palette=palette, @@ -1429,7 +1431,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float) colorbar_requests=colorbar_requests, colorbar_label=_resolve_colorbar_label( render_params.colorbar_params, - color if isinstance(color, str) else None, + col_for_color if isinstance(col_for_color, str) else None, ), scalebar_dx=scalebar_params.scalebar_dx, scalebar_units=scalebar_params.scalebar_units, diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index 4936468f..a108e131 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -278,7 +278,8 @@ class LabelsRenderParams: cmap_params: CmapParams element: str - color: str | None = None + color: Color | None = None + col_for_color: str | None = None groups: str | list[str] | None = None contour_px: int | None = None outline: bool = False diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 11d5d10d..a4f8f2d7 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -981,7 +981,7 @@ def _set_color_source_vec( alpha: float = 1.0, table_name: str | None = None, table_layer: str | None = None, - render_type: Literal["points"] | None = None, + render_type: Literal["points", "labels"] | None = None, coordinate_system: str | None = None, ) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]: if value_to_plot is None and element is not None: @@ -1451,7 +1451,7 @@ def _get_categorical_color_mapping( alpha: float = 1, groups: list[str] | str | None = None, palette: list[str] | str | None = None, - render_type: Literal["points"] | None = None, + render_type: Literal["points", "labels"] | None = None, ) -> Mapping[str, str]: if not isinstance(color_source_vector, Categorical): raise TypeError(f"Expected `categories` to be a `Categorical`, but got {type(color_source_vector).__name__}") @@ -2138,7 +2138,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st }: if not isinstance(color, str | tuple | list): raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.") - if element_type in {"shapes", "points"}: + if element_type in {"shapes", "points", "labels"}: if _is_color_like(color): logger.info("Value for parameter 'color' appears to be a color, using it as such.") param_dict["col_for_color"] = None @@ -2146,7 +2146,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if param_dict["color"].alpha_is_user_defined(): if element_type == "points" and param_dict.get("alpha") is None: param_dict["alpha"] = param_dict["color"].get_alpha_as_float() - elif element_type == "shapes" and param_dict.get("fill_alpha") is None: + elif element_type in {"shapes", "labels"} and param_dict.get("fill_alpha") is None: param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float() else: logger.info( @@ -2158,7 +2158,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st param_dict["color"] = None else: raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.") - elif "color" in param_dict and element_type != "labels": + elif "color" in param_dict and element_type != "images": param_dict["col_for_color"] = None outline_width = param_dict.get("outline_width") @@ -2455,15 +2455,18 @@ def _validate_label_render_params( element_params[el]["table_layer"] = param_dict["table_layer"] element_params[el]["table_name"] = None - element_params[el]["color"] = None - color = param_dict["color"] - if color is not None: - color, table_name = _validate_col_for_column_table(sdata, el, color, param_dict["table_name"], labels=True) + element_params[el]["color"] = param_dict["color"] # literal Color or None + element_params[el]["col_for_color"] = None + if (col_for_color := param_dict["col_for_color"]) is not None: + col_for_color, table_name = _validate_col_for_column_table( + sdata, el, col_for_color, param_dict["table_name"], labels=True + ) element_params[el]["table_name"] = table_name - element_params[el]["color"] = color + element_params[el]["col_for_color"] = col_for_color - element_params[el]["palette"] = param_dict["palette"] if element_params[el]["table_name"] is not None else None - element_params[el]["groups"] = param_dict["groups"] if element_params[el]["table_name"] is not None else None + has_col = element_params[el]["col_for_color"] is not None + element_params[el]["palette"] = param_dict["palette"] if has_col else None + element_params[el]["groups"] = param_dict["groups"] if has_col else None element_params[el]["colorbar"] = param_dict["colorbar"] element_params[el]["colorbar_params"] = param_dict["colorbar_params"] diff --git a/tests/pl/test_render_labels.py b/tests/pl/test_render_labels.py index a585d4eb..7000d077 100644 --- a/tests/pl/test_render_labels.py +++ b/tests/pl/test_render_labels.py @@ -84,6 +84,9 @@ def test_plot_can_stack_render_labels(self, sdata_blobs: SpatialData): .pl.show() ) + def test_plot_can_color_by_color_name(self, sdata_blobs: SpatialData): + sdata_blobs.pl.render_labels("blobs_labels", color="red").pl.show() + def test_plot_can_color_labels_by_continuous_variable(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum").pl.show()