Skip to content

Commit

Permalink
Fix usage of DataFrame.set_axis()
Browse files Browse the repository at this point in the history
  • Loading branch information
jpdunc23 committed Feb 9, 2024
1 parent ad4c4e4 commit 34391b9
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 11 deletions.
11 changes: 2 additions & 9 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,9 @@
import pandas as pd
import pytest
from numpy.testing import assert_equal

from vflow.subkey import Subkey as sm
from vflow.utils import (
PREV_KEY,
apply_vfuncs,
combine_dicts,
dict_to_df,
perturbation_stats,
to_list,
)
from vflow.utils import (PREV_KEY, apply_vfuncs, combine_dicts, dict_to_df,
perturbation_stats, to_list)


@pytest.mark.parametrize(
Expand Down
4 changes: 2 additions & 2 deletions vflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def dict_to_df(d: dict, param_key=None):
cols = [
c if c != "init" else init_step(idx, cols) for idx, c in enumerate(cols)
]
df.set_axis(cols, axis=1, copy=False)
df = df.set_axis(cols, axis=1)
if param_key:
param_keys = df[
param_key
Expand All @@ -235,7 +235,7 @@ def dict_to_df(d: dict, param_key=None):
param_keys = [[s.split("=")[1] for s in t] for t in param_keys]
df = df.join(pd.DataFrame(param_keys)).drop(columns=param_key)
new_cols = df.columns[: len(cols) - 1].tolist() + param_key_cols
df.set_axis(new_cols, axis=1, copy=False)
df = df.set_axis(new_cols, axis=1)
new_idx = list(range(len(new_cols)))
new_idx = (
new_idx[:param_loc]
Expand Down

0 comments on commit 34391b9

Please sign in to comment.