Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,7 @@ def render_labels(
sdata.plotting_tree[f"{n_steps + 1}_render_labels"] = LabelsRenderParams(
element=element,
color=param_values["color"],
col_for_color=param_values["col_for_color"],
groups=param_values["groups"],
contour_px=param_values["contour_px"],
cmap_params=cmap_params,
Expand Down Expand Up @@ -1121,14 +1122,13 @@ def _draw_colorbar(

if wanted_labels_on_this_cs:
table = params_copy.table_name
if table is not None:
assert isinstance(params_copy.color, str)
colors = sc.get.obs_df(sdata[table], [params_copy.color])
if isinstance(colors[params_copy.color].dtype, pd.CategoricalDtype):
if table is not None and params_copy.col_for_color is not None:
colors = sc.get.obs_df(sdata[table], [params_copy.col_for_color])
if isinstance(colors[params_copy.col_for_color].dtype, pd.CategoricalDtype):
_maybe_set_colors(
source=sdata[table],
target=sdata[table],
key=params_copy.color,
key=params_copy.col_for_color,
palette=params_copy.palette,
)

Expand Down
16 changes: 9 additions & 7 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,7 @@ def _render_labels(
table_name = render_params.table_name
table_layer = render_params.table_layer
palette = render_params.palette
color = render_params.color
col_for_color = render_params.col_for_color
groups = render_params.groups
scale = render_params.scale

Expand Down Expand Up @@ -1311,23 +1311,25 @@ def _render_labels(

_, trans_data = _prepare_transformation(label, coordinate_system, ax)

na_color = render_params.color if render_params.color else render_params.cmap_params.na_color
color_source_vector, color_vector, categorical = _set_color_source_vec(
sdata=sdata_filt,
element=label,
element_name=element,
value_to_plot=color,
value_to_plot=col_for_color,
groups=groups,
palette=palette,
na_color=render_params.cmap_params.na_color,
na_color=na_color,
cmap_params=render_params.cmap_params,
table_name=table_name,
table_layer=table_layer,
render_type="labels",
coordinate_system=coordinate_system,
)

# rasterize could have removed labels from label
# only problematic if color is specified
if rasterize and color is not None:
if rasterize and col_for_color is not None:
labels_in_rasterized_image = np.unique(label.values)
mask = np.isin(instance_id, labels_in_rasterized_image)
instance_id = instance_id[mask]
Expand Down Expand Up @@ -1405,15 +1407,15 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
colorbar_requested = _should_request_colorbar(
render_params.colorbar,
has_mappable=cax is not None,
is_continuous=color is not None and color_source_vector is None and not categorical,
is_continuous=col_for_color is not None and color_source_vector is None and not categorical,
)

_ = _decorate_axs(
ax=ax,
cax=cax,
fig_params=fig_params,
adata=table,
value_to_plot=color,
value_to_plot=col_for_color,
color_source_vector=color_source_vector,
color_vector=color_vector,
palette=palette,
Expand All @@ -1429,7 +1431,7 @@ def _draw_labels(seg_erosionpx: int | None, seg_boundaries: bool, alpha: float)
colorbar_requests=colorbar_requests,
colorbar_label=_resolve_colorbar_label(
render_params.colorbar_params,
color if isinstance(color, str) else None,
col_for_color if isinstance(col_for_color, str) else None,
),
scalebar_dx=scalebar_params.scalebar_dx,
scalebar_units=scalebar_params.scalebar_units,
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ class LabelsRenderParams:

cmap_params: CmapParams
element: str
color: str | None = None
color: Color | None = None
col_for_color: str | None = None
groups: str | list[str] | None = None
contour_px: int | None = None
outline: bool = False
Expand Down
27 changes: 15 additions & 12 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,7 @@ def _set_color_source_vec(
alpha: float = 1.0,
table_name: str | None = None,
table_layer: str | None = None,
render_type: Literal["points"] | None = None,
render_type: Literal["points", "labels"] | None = None,
coordinate_system: str | None = None,
) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]:
if value_to_plot is None and element is not None:
Expand Down Expand Up @@ -1451,7 +1451,7 @@ def _get_categorical_color_mapping(
alpha: float = 1,
groups: list[str] | str | None = None,
palette: list[str] | str | None = None,
render_type: Literal["points"] | None = None,
render_type: Literal["points", "labels"] | None = None,
) -> Mapping[str, str]:
if not isinstance(color_source_vector, Categorical):
raise TypeError(f"Expected `categories` to be a `Categorical`, but got {type(color_source_vector).__name__}")
Expand Down Expand Up @@ -2138,15 +2138,15 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
}:
if not isinstance(color, str | tuple | list):
raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.")
if element_type in {"shapes", "points"}:
if element_type in {"shapes", "points", "labels"}:
if _is_color_like(color):
logger.info("Value for parameter 'color' appears to be a color, using it as such.")
param_dict["col_for_color"] = None
param_dict["color"] = Color(color)
if param_dict["color"].alpha_is_user_defined():
if element_type == "points" and param_dict.get("alpha") is None:
param_dict["alpha"] = param_dict["color"].get_alpha_as_float()
elif element_type == "shapes" and param_dict.get("fill_alpha") is None:
elif element_type in {"shapes", "labels"} and param_dict.get("fill_alpha") is None:
param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float()
else:
logger.info(
Expand All @@ -2158,7 +2158,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
param_dict["color"] = None
else:
raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.")
elif "color" in param_dict and element_type != "labels":
elif "color" in param_dict and element_type != "images":
param_dict["col_for_color"] = None

outline_width = param_dict.get("outline_width")
Expand Down Expand Up @@ -2455,15 +2455,18 @@ def _validate_label_render_params(
element_params[el]["table_layer"] = param_dict["table_layer"]

element_params[el]["table_name"] = None
element_params[el]["color"] = None
color = param_dict["color"]
if color is not None:
color, table_name = _validate_col_for_column_table(sdata, el, color, param_dict["table_name"], labels=True)
element_params[el]["color"] = param_dict["color"] # literal Color or None
element_params[el]["col_for_color"] = None
if (col_for_color := param_dict["col_for_color"]) is not None:
col_for_color, table_name = _validate_col_for_column_table(
sdata, el, col_for_color, param_dict["table_name"], labels=True
)
element_params[el]["table_name"] = table_name
element_params[el]["color"] = color
element_params[el]["col_for_color"] = col_for_color

element_params[el]["palette"] = param_dict["palette"] if element_params[el]["table_name"] is not None else None
element_params[el]["groups"] = param_dict["groups"] if element_params[el]["table_name"] is not None else None
has_col = element_params[el]["col_for_color"] is not None
element_params[el]["palette"] = param_dict["palette"] if has_col else None
element_params[el]["groups"] = param_dict["groups"] if has_col else None
element_params[el]["colorbar"] = param_dict["colorbar"]
element_params[el]["colorbar_params"] = param_dict["colorbar_params"]

Expand Down
3 changes: 3 additions & 0 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def test_plot_can_stack_render_labels(self, sdata_blobs: SpatialData):
.pl.show()
)

def test_plot_can_color_by_color_name(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_labels("blobs_labels", color="red").pl.show()

def test_plot_can_color_labels_by_continuous_variable(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_labels("blobs_labels", color="channel_0_sum").pl.show()

Expand Down
Loading