diff --git a/tests/plot.py b/tests/plot.py index 46b1b920..74515fba 100644 --- a/tests/plot.py +++ b/tests/plot.py @@ -179,6 +179,7 @@ def test_plot_molecules_per_cell_and_gene(): assert ax.get_xlabel() == "Molecules per gene (log10 scale)" assert ax.get_ylabel() == "Frequency" + plt.close() def test_cell_types_default_colors(mock_tsne, mock_clusters): @@ -274,23 +275,27 @@ def test_plot_tsne_by_cell_sizes(mock_data, mock_tsne): 0.2, 0.8, ), "Color limits should be set to vmin and vmax" + plt.close() def test_plot_gene_expression(mock_gene_data, mock_tsne): genes = ["gene_0", "gene_1"] fig, axs = plot_gene_expression(mock_gene_data, mock_tsne, genes, plot_scale=True) assert isinstance(fig, plt.Figure) + plt.close() def test_plot_gene_expression_missing_genes(mock_gene_data, mock_tsne): genes = ["gene_0", "nonexistent_gene"] fig, axs = plot_gene_expression(mock_gene_data, mock_tsne, genes) assert isinstance(fig, plt.Figure) # Expect a warning but still a plot + plt.close() def test_plot_gene_expression_no_genes(mock_gene_data, mock_tsne): with pytest.raises(ValueError): plot_gene_expression(mock_gene_data, mock_tsne, ["nonexistent_gene"]) + plt.close() def test_plot_diffusion_components_with_anndata(mock_anndata, mock_dm_res): @@ -298,6 +303,7 @@ def test_plot_diffusion_components_with_anndata(mock_anndata, mock_dm_res): assert isinstance(fig, plt.Figure) for ax in axs.values(): assert isinstance(ax, plt.Axes) + plt.close() def test_plot_diffusion_components_with_dataframe(mock_tsne, mock_dm_res): @@ -306,16 +312,19 @@ def test_plot_diffusion_components_with_dataframe(mock_tsne, mock_dm_res): assert isinstance(fig, plt.Figure) for ax in axs.values(): assert isinstance(ax, plt.Axes) + plt.close() def test_plot_diffusion_components_key_error_embedding(mock_anndata): with pytest.raises(KeyError): plot_diffusion_components(mock_anndata, embedding_basis="NonexistentKey") + plt.close() def test_plot_diffusion_components_key_error_dm_res(mock_anndata): with pytest.raises(KeyError): plot_diffusion_components(mock_anndata, dm_res="NonexistentKey") + plt.close() def test_plot_diffusion_components_default_args(mock_anndata): @@ -324,6 +333,7 @@ def test_plot_diffusion_components_default_args(mock_anndata): assert ( ax.collections[0].get_array().data.shape[0] == 100 ) # Checking data points + plt.close() def test_plot_diffusion_components_custom_args(mock_anndata): @@ -331,24 +341,28 @@ def test_plot_diffusion_components_custom_args(mock_anndata): for ax in axs.values(): assert ax.collections[0].get_edgecolors().all() == np.array([1, 0, 0, 1]).all() assert ax.collections[0].get_sizes()[0] == 10 + plt.close() # Test with AnnData and all keys available def test_plot_palantir_results_anndata(mock_anndata): fig = plot_palantir_results(mock_anndata) assert isinstance(fig, plt.Figure) + plt.close() # Test with DataFrame and PResults def test_plot_palantir_results_dataframe(mock_tsne, mock_presults): fig = plot_palantir_results(mock_tsne, pr_res=mock_presults) assert isinstance(fig, plt.Figure) + plt.close() # Test KeyError for missing embedding_basis def test_plot_palantir_results_key_error_embedding(mock_anndata): with pytest.raises(KeyError): plot_palantir_results(mock_anndata, embedding_basis="NonexistentKey") + plt.close() # Test KeyError for missing Palantir results in AnnData @@ -356,6 +370,7 @@ def test_plot_palantir_results_key_error_palantir(mock_anndata): mock_anndata.obs = pd.DataFrame(index=mock_anndata.obs_names) # Clearing obs with pytest.raises(KeyError): plot_palantir_results(mock_anndata) + plt.close() # Test plotting with custom arguments @@ -364,24 +379,28 @@ def test_plot_palantir_results_custom_args(mock_anndata): ax = fig.axes[0] # Assuming first subplot holds the first scatter plot assert np.all(ax.collections[0].get_edgecolors() == [1, 0, 0, 1]) assert ax.collections[0].get_sizes()[0] == 10 + plt.close() # Test with AnnData and all keys available def test_plot_terminal_state_probs_anndata(mock_anndata, mock_cells): fig = plot_terminal_state_probs(mock_anndata, mock_cells) assert isinstance(fig, plt.Figure) + plt.close() # Test with DataFrame and PResults def test_plot_terminal_state_probs_dataframe(mock_data, mock_presults, mock_cells): fig = plot_terminal_state_probs(mock_data, mock_cells, pr_res=mock_presults) assert isinstance(fig, plt.Figure) + plt.close() # Test ValueError for missing pr_res in DataFrame input def test_plot_terminal_state_probs_value_error(mock_data, mock_cells): with pytest.raises(ValueError): plot_terminal_state_probs(mock_data, mock_cells) + plt.close() # Test plotting with custom arguments @@ -389,6 +408,7 @@ def test_plot_terminal_state_probs_custom_args(mock_anndata, mock_cells): fig = plot_terminal_state_probs(mock_anndata, mock_cells, linewidth=2.0) ax = fig.axes[0] # Assuming first subplot holds the first bar plot assert ax.patches[0].get_linewidth() == 2.0 + plt.close() # Test if the function uses the correct keys and raises appropriate errors @@ -396,12 +416,15 @@ def test_plot_branch_selection_keys(mock_anndata): # This will depend on how your mock_anndata is structured with pytest.raises(KeyError): plot_branch_selection(mock_anndata, pseudo_time_key="invalid_key") + plt.close() with pytest.raises(KeyError): plot_branch_selection(mock_anndata, fate_prob_key="invalid_key") + plt.close() with pytest.raises(KeyError): plot_branch_selection(mock_anndata, embedding_basis="invalid_basis") + plt.close() # Test the scatter custom arguments @@ -417,6 +440,7 @@ def test_plot_branch_selection_custom_args(mock_anndata): alpha1 = scatter1.get_alpha() assert alpha1 == 0.5 + plt.close() # Test 1: Basic functionality @@ -425,7 +449,7 @@ def test_plot_gene_trends_legacy_basic(mock_gene_trends): axes = fig.axes # Check if the number of subplots matches the number of genes assert len(axes) == 2 - # Perform additional checks on axes content if needed + plt.close() # Test 2: Custom gene list @@ -436,6 +460,7 @@ def test_plot_gene_trends_legacy_custom_genes(mock_gene_trends): assert len(axes) == 1 # Check if the title of the subplot matches the custom gene assert axes[0].get_title() == "Gene1" + plt.close() # Test 3: Color consistency @@ -446,6 +471,7 @@ def test_plot_gene_trends_legacy_color_consistency(mock_gene_trends): colors_2 = [line.get_color() for line in axes[1].lines] # Check if the colors are consistent across different genes assert colors_1 == colors_2 + plt.close() # Test 1: Basic Functionality with AnnData @@ -453,6 +479,7 @@ def test_plot_gene_trends_basic_anndata(mock_anndata): fig = plot_gene_trends(mock_anndata) axes = fig.axes assert len(axes) == mock_anndata.n_vars + plt.close() # Test 2: Basic Functionality with Dictionary @@ -460,6 +487,7 @@ def test_plot_gene_trends_basic_dict(mock_gene_trends): fig = plot_gene_trends(mock_gene_trends) axes = fig.axes assert len(axes) == 2 # Mock data contains 2 genes + plt.close() # Test 3: Custom Genes @@ -468,6 +496,7 @@ def test_plot_gene_trends_custom_genes(mock_anndata): axes = fig.axes assert len(axes) == 1 assert axes[0].get_title() == "gene_1" + plt.close() # Test 4: Custom Branch Names @@ -475,12 +504,14 @@ def test_plot_gene_trends_custom_branch_names(mock_anndata): fig = plot_gene_trends(mock_anndata, branch_names=["a", "b"]) axes = fig.axes assert len(axes) == mock_anndata.n_vars + plt.close() # Test 5: Error Handling - Invalid Data Type def test_plot_gene_trends_invalid_data_type(): with pytest.raises(ValueError): plot_gene_trends("invalid_data_type") + plt.close() # Test 6: Error Handling - Missing Key @@ -489,12 +520,14 @@ def test_plot_gene_trends_missing_key(mock_anndata): plot_gene_trends( mock_anndata, gene_trend_key="missing_key", branch_names="missing_branch" ) + plt.close() @pytest.mark.parametrize("wrong_type", [123, True, 1.23, "unknown_key"]) def test_plot_stats_key_errors(mock_anndata, wrong_type): with pytest.raises(KeyError): plot_stats(mock_anndata, x=wrong_type, y="palantir_pseudotime") + plt.close() def test_plot_stats_basic(mock_anndata): @@ -510,6 +543,7 @@ def test_plot_stats_optional_parameters(mock_anndata): y="palantir_entropy", color="palantir_entropy", ) + plt.close() def test_plot_stats_masking(mock_anndata): @@ -522,6 +556,7 @@ def test_plot_stats_masking(mock_anndata): y="palantir_entropy", masks_key="branch_masks", ) + plt.close() @pytest.mark.parametrize( @@ -555,19 +590,23 @@ def test_plot_branch_functionality(mock_anndata): def test_plot_trend_type_validation(mock_anndata): with pytest.raises(TypeError): plot_trend("string_instead_of_anndata", "a", "gene_1") + plt.close() with pytest.raises(TypeError): plot_trend(mock_anndata, 123, "gene_1") + plt.close() def test_plot_trend_value_validation(mock_anndata): with pytest.raises((ValueError, KeyError)): plot_trend(mock_anndata, "nonexistent_branch", "gene_1") + plt.close() def test_plot_trend_plotting(mock_anndata): fig, ax = plot_trend(mock_anndata, "a", "gene_1") assert isinstance(fig, plt.Figure) assert isinstance(ax, plt.Axes) + plt.close() def test_plot_gene_trend_heatmaps(mock_anndata):