Skip to content

Commit

Permalink
fix pytest general nn error
Browse files Browse the repository at this point in the history
  • Loading branch information
SunsetWolf committed Dec 14, 2024
1 parent 27813eb commit fd1ba4c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ prerequisite:
echo "No shared library files found, building..."; \
pip install --upgrade setuptools wheel; \
python -m pip install cython; \
python -m pip install "numpy>=1.24.0"; \
python -m pip install "numpy<2.0.0"; \
python -c "from setuptools import setup, Extension; from Cython.Build import cythonize; import numpy; extensions = [Extension('qlib.data._libs.rolling', ['qlib/data/_libs/rolling.pyx'], language='c++', include_dirs=[numpy.get_include()]), Extension('qlib.data._libs.expanding', ['qlib/data/_libs/expanding.pyx'], language='c++', include_dirs=[numpy.get_include()])]; setup(ext_modules=cythonize(extensions, language_level='3'), script_args=['build_ext', '--inplace'])"; \
fi

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools", "cython", "numpy>=1.24.0"]
requires = ["setuptools", "cython", "numpy<2.0.0>"]
build-backend = "setuptools.build_meta"

[project]
Expand Down
31 changes: 25 additions & 6 deletions qlib/contrib/model/pytorch_general_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,15 @@ def fit(
evals_result=dict(),
save_path=None,
reweighter=None,
batch_size=None,
n_jobs=None,
):
if batch_size is None:
batch_size = self.batch_size

if n_jobs is None:
n_jobs = self.n_jobs

ists = isinstance(dataset, TSDatasetH) # is this time series dataset

dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
Expand Down Expand Up @@ -261,16 +269,16 @@ def fit(

train_loader = DataLoader(
ConcatDataset(dl_train, wl_train),
batch_size=self.batch_size,
batch_size=batch_size,
shuffle=True,
num_workers=self.n_jobs,
num_workers=n_jobs,
drop_last=True,
)
valid_loader = DataLoader(
ConcatDataset(dl_valid, wl_valid),
batch_size=self.batch_size,
batch_size=batch_size,
shuffle=False,
num_workers=self.n_jobs,
num_workers=n_jobs,
drop_last=True,
)
del dl_train, dl_valid, wl_train, wl_valid
Expand Down Expand Up @@ -319,7 +327,18 @@ def fit(
if self.use_gpu:
torch.cuda.empty_cache()

def predict(self, dataset: Union[DatasetH, TSDatasetH]):
def predict(
self,
dataset: Union[DatasetH, TSDatasetH],
batch_size=None,
n_jobs=None,
):
if batch_size is None:
batch_size = self.batch_size

if n_jobs is None:
n_jobs = self.n_jobs

if not self.fitted:
raise ValueError("model is not fitted yet!")

Expand All @@ -333,7 +352,7 @@ def predict(self, dataset: Union[DatasetH, TSDatasetH]):
index = dl_test.index
dl_test = dl_test.values

test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
test_loader = DataLoader(dl_test, batch_size=batch_size, num_workers=n_jobs)
self.dnn_model.eval()
preds = []

Expand Down
4 changes: 2 additions & 2 deletions tests/model/test_general_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def test_both_dataset(self):
]

for ds, model in list(zip((tsds, tbds), model_l)):
model.fit(ds) # It works
model.predict(ds) # It works
model.fit(ds, batch_size=32, n_jobs=0) # It works
model.predict(ds, batch_size=32, n_jobs=0) # It works


if __name__ == "__main__":
Expand Down

0 comments on commit fd1ba4c

Please sign in to comment.