diff --git a/src/palantir/plot.py b/src/palantir/plot.py index 21c08eff..ed35d480 100644 --- a/src/palantir/plot.py +++ b/src/palantir/plot.py @@ -1940,7 +1940,9 @@ def plot_trajectory( arrowprops : dict, optional Properties for the arrowstyle. If None, defaults to black arrow with lw=1. scanpy_kwargs : dict, optional - Keyword arguments for the scanpy.pl.emebdding function to plot the cells. + Keyword arguments for the scanpy.pl.emebdding function to plot the cells + unless `masks_key == "branch_masks"` in which case these arguments are + passed to `matplotlib.pyplot.scatter`. figsize : Tuple[float, float], optional Size of the plot in inches, as (width, height). Defaults to (5, 5). **kwargs @@ -1993,12 +1995,14 @@ def plot_trajectory( umap[~mask, 1], c=config.DESELECTED_COLOR, label="Other Cells", + **scanpy_kwargs ) ax.scatter( umap[mask, 0], umap[mask, 1], c=config.SELECTED_COLOR, label="Selected Cells", + **scanpy_kwargs ) elif cell_color is not None: b = embedding_basis[2:] if embedding_basis.startswith("X_") else embedding_basis