Skip to content

Commit

Permalink
Backport PR #3372: Deprecate RandomState (using names only)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Nov 22, 2024
1 parent 7768360 commit aa33af6
Show file tree
Hide file tree
Showing 23 changed files with 72 additions and 61 deletions.
3 changes: 3 additions & 0 deletions src/scanpy/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
from packaging.version import Version

if TYPE_CHECKING:
from importlib.metadata import PackageMetadata

_LegacyRandom = int | np.random.RandomState | None


if TYPE_CHECKING:
# type checkers are confused and can only see …core.Array
Expand Down
13 changes: 8 additions & 5 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
import sys
import warnings
from collections import namedtuple
from collections.abc import Sequence
from contextlib import contextmanager, suppress
from enum import Enum
from functools import partial, singledispatch, wraps
from operator import mul, truediv
from textwrap import dedent
from types import MethodType, ModuleType
from typing import TYPE_CHECKING, overload
from typing import TYPE_CHECKING, Union, overload
from weakref import WeakSet

import h5py
Expand Down Expand Up @@ -49,12 +50,14 @@
from anndata import AnnData
from numpy.typing import ArrayLike, DTypeLike, NDArray

from .._compat import _LegacyRandom
from ..neighbors import NeighborsParams, RPForestDict


# e.g. https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html
# maybe in the future random.Generator
AnyRandom = int | np.random.RandomState | None
SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence
RNGLike = np.random.Generator | np.random.BitGenerator

LegacyUnionType = type(Union[int, str]) # noqa: UP007


class Empty(Enum):
Expand Down Expand Up @@ -477,7 +480,7 @@ def moving_average(a: np.ndarray, n: int):
return ret[n - 1 :] / n


def get_random_state(seed: AnyRandom) -> np.random.RandomState:
def _get_legacy_random(seed: _LegacyRandom) -> np.random.RandomState:
if isinstance(seed, np.random.RandomState):
return seed
return np.random.RandomState(seed)
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/datasets/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
if TYPE_CHECKING:
from typing import Literal

from .._utils import AnyRandom
from .._compat import _LegacyRandom

VisiumSampleID = Literal[
"V1_Breast_Cancer_Block_A_Section_1",
Expand Down Expand Up @@ -63,7 +63,7 @@ def blobs(
n_centers: int = 5,
cluster_std: float = 1.0,
n_observations: int = 640,
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
) -> AnnData:
"""\
Gaussian Blobs.
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/external/pp/_dca.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from anndata import AnnData

from ..._utils import AnyRandom
from ..._compat import _LegacyRandom

_AEType = Literal["zinb-conddisp", "zinb", "nb-conddisp", "nb"]

Expand Down Expand Up @@ -62,7 +62,7 @@ def dca(
early_stop: int = 15,
batch_size: int = 32,
optimizer: str = "RMSprop",
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
threads: int | None = None,
learning_rate: float | None = None,
verbose: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/external/pp/_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from anndata import AnnData

from ..._utils import AnyRandom
from ..._compat import _LegacyRandom

MIN_VERSION = "2.0"

Expand All @@ -36,7 +36,7 @@ def magic(
n_pca: int | None = 100,
solver: Literal["exact", "approximate"] = "exact",
knn_dist: str = "euclidean",
random_state: AnyRandom = None,
random_state: _LegacyRandom = None,
n_jobs: int | None = None,
verbose: bool = False,
copy: bool | None = None,
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/external/tl/_phate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from anndata import AnnData

from ..._utils import AnyRandom
from ..._compat import _LegacyRandom


@old_positionals(
Expand Down Expand Up @@ -49,7 +49,7 @@ def phate(
mds_dist: str = "euclidean",
mds: Literal["classic", "metric", "nonmetric"] = "metric",
n_jobs: int | None = None,
random_state: AnyRandom = None,
random_state: _LegacyRandom = None,
verbose: bool | int | None = None,
copy: bool = False,
**kwargs,
Expand Down
12 changes: 6 additions & 6 deletions src/scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from igraph import Graph
from scipy.sparse import csr_matrix

from .._utils import AnyRandom
from .._compat import _LegacyRandom
from ._types import KnnTransformerLike, _Metric, _MetricFn


Expand All @@ -54,13 +54,13 @@ class KwdsForTransformer(TypedDict):
n_neighbors: int
metric: _Metric | _MetricFn
metric_params: Mapping[str, Any]
random_state: AnyRandom
random_state: _LegacyRandom


class NeighborsParams(TypedDict):
n_neighbors: int
method: _Method
random_state: AnyRandom
random_state: _LegacyRandom
metric: _Metric | _MetricFn
metric_kwds: NotRequired[Mapping[str, Any]]
use_rep: NotRequired[str]
Expand All @@ -79,7 +79,7 @@ def neighbors(
transformer: KnnTransformerLike | _KnownTransformer | None = None,
metric: _Metric | _MetricFn = "euclidean",
metric_kwds: Mapping[str, Any] = MappingProxyType({}),
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
key_added: str | None = None,
copy: bool = False,
) -> AnnData | None:
Expand Down Expand Up @@ -521,7 +521,7 @@ def compute_neighbors(
transformer: KnnTransformerLike | _KnownTransformer | None = None,
metric: _Metric | _MetricFn = "euclidean",
metric_kwds: Mapping[str, Any] = MappingProxyType({}),
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
) -> None:
"""\
Compute distances and connectivities of neighbors.
Expand Down Expand Up @@ -755,7 +755,7 @@ def compute_eigen(
n_comps: int = 15,
sym: bool | None = None,
sort: Literal["decrease", "increase"] = "decrease",
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
):
"""\
Compute eigen decomposition of transition matrix.
Expand Down
3 changes: 2 additions & 1 deletion src/scanpy/plotting/_tools/paga.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from matplotlib.colors import Colormap
from scipy.sparse import spmatrix

from ..._compat import _LegacyRandom
from ...tools._draw_graph import _Layout as _LayoutWithoutEqTree
from .._utils import _FontSize, _FontWeight, _LegendLoc

Expand Down Expand Up @@ -210,7 +211,7 @@ def _compute_pos(
adjacency_solid: spmatrix | np.ndarray,
*,
layout: _Layout | None = None,
random_state: _sc_utils.AnyRandom = 0,
random_state: _LegacyRandom = 0,
init_pos: np.ndarray | None = None,
adj_tree=None,
root: int = 0,
Expand Down
7 changes: 4 additions & 3 deletions src/scanpy/preprocessing/_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from scipy.sparse import spmatrix
from sklearn.decomposition import PCA

from .._utils import AnyRandom, Empty
from .._compat import _LegacyRandom
from .._utils import Empty


@_doc_params(
Expand All @@ -39,7 +40,7 @@ def pca(
layer: str | None = None,
zero_center: bool | None = True,
svd_solver: str | None = None,
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
return_info: bool = False,
mask_var: NDArray[np.bool_] | str | None | Empty = _empty,
use_highly_variable: bool | None = None,
Expand Down Expand Up @@ -396,7 +397,7 @@ def _pca_with_sparse(
*,
solver: str = "arpack",
mu: NDArray[np.floating] | None = None,
random_state: AnyRandom = None,
random_state: _LegacyRandom = None,
) -> tuple[NDArray[np.floating], PCA]:
random_state = check_random_state(random_state)
np.random.set_state(random_state.get_state())
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/preprocessing/_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
if TYPE_CHECKING:
from anndata import AnnData

from .._utils import AnyRandom
from .._compat import _LegacyRandom


@old_positionals(
Expand All @@ -36,7 +36,7 @@ def recipe_weinreb17(
cv_threshold: int = 2,
n_pcs: int = 50,
svd_solver="randomized",
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
copy: bool = False,
) -> AnnData | None:
"""\
Expand Down
8 changes: 4 additions & 4 deletions src/scanpy/preprocessing/_scrublet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .core import Scrublet

if TYPE_CHECKING:
from ..._utils import AnyRandom
from ..._compat import _LegacyRandom
from ...neighbors import _Metric, _MetricFn


Expand Down Expand Up @@ -58,7 +58,7 @@ def scrublet(
threshold: float | None = None,
verbose: bool = True,
copy: bool = False,
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
) -> AnnData | None:
"""\
Predict doublets using Scrublet :cite:p:`Wolock2019`.
Expand Down Expand Up @@ -309,7 +309,7 @@ def _scrublet_call_doublets(
knn_dist_metric: _Metric | _MetricFn = "euclidean",
get_doublet_neighbor_parents: bool = False,
threshold: float | None = None,
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
verbose: bool = True,
) -> AnnData:
"""\
Expand Down Expand Up @@ -503,7 +503,7 @@ def scrublet_simulate_doublets(
layer: str | None = None,
sim_doublet_ratio: float = 2.0,
synthetic_doublet_umi_subsampling: float = 1.0,
random_seed: AnyRandom = 0,
random_seed: _LegacyRandom = 0,
) -> AnnData:
"""\
Simulate doublets by adding the counts of random observed transcriptome pairs.
Expand Down
10 changes: 5 additions & 5 deletions src/scanpy/preprocessing/_scrublet/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from scipy import sparse

from ... import logging as logg
from ..._utils import get_random_state
from ..._utils import _get_legacy_random
from ...neighbors import (
Neighbors,
_get_indices_distances_from_sparse_matrix,
Expand All @@ -21,7 +21,7 @@
from numpy.random import RandomState
from numpy.typing import NDArray

from ..._utils import AnyRandom
from ..._compat import _LegacyRandom
from ...neighbors import _Metric, _MetricFn

__all__ = ["Scrublet"]
Expand Down Expand Up @@ -73,7 +73,7 @@ class Scrublet:
n_neighbors: InitVar[int | None] = None
expected_doublet_rate: float = 0.1
stdev_doublet_rate: float = 0.02
random_state: InitVar[AnyRandom] = 0
random_state: InitVar[_LegacyRandom] = 0

# private fields

Expand Down Expand Up @@ -174,7 +174,7 @@ def __post_init__(
counts_obs: sparse.csr_matrix | sparse.csc_matrix | NDArray[np.integer],
total_counts_obs: NDArray[np.integer] | None,
n_neighbors: int | None,
random_state: AnyRandom,
random_state: _LegacyRandom,
) -> None:
self._counts_obs = sparse.csc_matrix(counts_obs)
self._total_counts_obs = (
Expand All @@ -187,7 +187,7 @@ def __post_init__(
if n_neighbors is None
else n_neighbors
)
self._random_state = get_random_state(random_state)
self._random_state = _get_legacy_random(random_state)

def simulate_doublets(
self,
Expand Down
6 changes: 3 additions & 3 deletions src/scanpy/preprocessing/_scrublet/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
if TYPE_CHECKING:
from typing import Literal

from ..._utils import AnyRandom
from ..._compat import _LegacyRandom
from .core import Scrublet


Expand Down Expand Up @@ -49,7 +49,7 @@ def truncated_svd(
self: Scrublet,
n_prin_comps: int = 30,
*,
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
algorithm: Literal["arpack", "randomized"] = "arpack",
) -> None:
if self._counts_sim_norm is None:
Expand All @@ -68,7 +68,7 @@ def pca(
self: Scrublet,
n_prin_comps: int = 50,
*,
random_state: AnyRandom = 0,
random_state: _LegacyRandom = 0,
svd_solver: Literal["auto", "full", "arpack", "randomized"] = "arpack",
) -> None:
if self._counts_sim_norm is None:
Expand Down
8 changes: 4 additions & 4 deletions src/scanpy/preprocessing/_scrublet/sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

from scanpy.preprocessing._utils import _get_mean_var

from ..._utils import get_random_state
from ..._utils import _get_legacy_random

if TYPE_CHECKING:
from numpy.typing import NDArray

from ..._utils import AnyRandom
from .._compat import _LegacyRandom


def sparse_multiply(
Expand Down Expand Up @@ -47,10 +47,10 @@ def subsample_counts(
*,
rate: float,
original_totals,
random_seed: AnyRandom = 0,
random_seed: _LegacyRandom = 0,
) -> tuple[sparse.csr_matrix | sparse.csc_matrix, NDArray[np.int64]]:
if rate < 1:
random_seed = get_random_state(random_seed)
random_seed = _get_legacy_random(random_seed)
E.data = random_seed.binomial(np.round(E.data).astype(int), rate)
current_totals = np.asarray(E.sum(1)).squeeze()
unsampled_orig_totals = original_totals - current_totals
Expand Down
Loading

0 comments on commit aa33af6

Please sign in to comment.