Skip to content

Commit

Permalink
some tests for presults including compute_gene_trends
Browse files Browse the repository at this point in the history
  • Loading branch information
katosh committed Nov 27, 2023
1 parent 2f33c68 commit 9909bde
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 0 deletions.
35 changes: 35 additions & 0 deletions tests/presults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import numpy as np
import pandas as pd
import palantir

def test_PResults():
# Create some dummy data
pseudotime = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
entropy = None
branch_probs = pd.DataFrame({'branch1': [0.1, 0.2, 0.3, 0.4, 0.5], 'branch2': [0.5, 0.4, 0.3, 0.2, 0.1]})
waypoints = None

# Initialize PResults object
presults = palantir.presults.PResults(pseudotime, entropy, branch_probs, waypoints)

# Asserts to check attributes
assert np.array_equal(presults.pseudotime, pseudotime)
assert presults.entropy is None
assert presults.waypoints is None
assert np.array_equal(presults.branch_probs, branch_probs.values)

def test_gam_fit_predict():
# Create some dummy data
x = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
y = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
weights = None
pred_x = None
n_splines = 4
spline_order = 2

# Call the function
y_pred, stds = palantir.presults.gam_fit_predict(x, y, weights, pred_x, n_splines, spline_order)

# Asserts to check the output
assert isinstance(y_pred, np.ndarray)
assert isinstance(stds, np.ndarray)
84 changes: 84 additions & 0 deletions tests/presults_compute_gene_trends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import pytest
import palantir

@pytest.fixture
def mock_adata():
import pandas as pd
import numpy as np
from anndata import AnnData

n_cells = 10

# Create mock data
adata = AnnData(
X=np.random.rand(n_cells, 3),
obs=pd.DataFrame(
{"palantir_pseudotime": np.random.rand(n_cells)},
index=[f"cell_{i}" for i in range(n_cells)],
),
var=pd.DataFrame(index=[f"gene_{i}" for i in range(3)]),
)

adata.obsm["branch_masks"] = pd.DataFrame(
np.random.randint(2, size=(n_cells, 2)),
columns=["branch_1", "branch_2"],
index=adata.obs_names,
)

return adata

@pytest.fixture
def mock_adata_old():
import pandas as pd
import numpy as np
from anndata import AnnData

n_cells = 10

# Create mock data
adata = AnnData(
X=np.random.rand(n_cells, 3),
obs=pd.DataFrame(
{"palantir_pseudotime": np.random.rand(n_cells)},
index=[f"cell_{i}" for i in range(n_cells)],
),
var=pd.DataFrame(index=[f"gene_{i}" for i in range(3)]),
)

# Create mock branch_masks in obsm
adata.obsm["branch_masks"] = pd.DataFrame(np.random.randint(2, size=(n_cells, 2))
adata.uns["branch_masks_columns"] = ["branch_1", "branch_2"]

return adata

@pytest.mark.parametrize("adata", [mock_adata, mock_adata_old])
def test_compute_gene_trends(adata):
# Call the function with default keys
res = palantir.presults.compute_gene_trends(adata)

# Asserts to check the output
assert isinstance(res, dict)
assert "branch_1" in res
assert "branch_2" in res
assert isinstance(res["branch_1"], dict)
assert isinstance(res["branch_1"]["trends"], pd.DataFrame)
assert "gene_0" in res["branch_1"]["trends"].index
assert adata.varm["gene_trends_branch_1"].shape == (3, 500)

# Call the function with custom keys
res = palantir.presults.compute_gene_trends(
adata,
masks_key="custom_masks",
pseudo_time_key="custom_time",
gene_trend_key="custom_trends",
)

# Asserts to check the output with custom keys
assert isinstance(res, dict)
assert "branch_1" in res
assert "branch_2" in res
assert isinstance(res["branch_1"], dict)
assert isinstance(res["branch_1"]["trends"], pd.DataFrame)
assert "gene_0" in res["branch_1"]["trends"].index
assert adata.varm["custom_trends_branch_1"].shape == (3, 500)

0 comments on commit 9909bde

Please sign in to comment.