Skip to content

Commit

Permalink
fix: index tensor on cpu (#961)
Browse files Browse the repository at this point in the history
### Summary of Changes

The index tensor of an `ImageDataset` sometimes ended up on the CPU,
instead of the default device, which led to runtime errors. This PR
fixes this.
  • Loading branch information
lars-reimann authored Nov 26, 2024
1 parent 5b32acc commit afafd43
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "initial_id",
"metadata": {
"collapsed": true
Expand All @@ -57,7 +56,8 @@
"\n",
"images, filepaths = ImageList.from_files(\"data/shapes\", return_filenames=True)"
],
"outputs": []
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -84,8 +84,8 @@
"collapsed": false
},
"id": "66dcf95a3fa51f23",
"execution_count": null,
"outputs": []
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -108,8 +108,8 @@
"collapsed": false
},
"id": "32056ddf5396e070",
"execution_count": null,
"outputs": []
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand Down Expand Up @@ -149,8 +149,8 @@
"collapsed": false
},
"id": "806a8091249d533a",
"execution_count": null,
"outputs": []
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -175,8 +175,8 @@
"collapsed": false
},
"id": "af68cc0d32655d32",
"execution_count": null,
"outputs": []
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -198,15 +198,13 @@
},
{
"cell_type": "code",
"source": [
"cnn_fitted = cnn.fit(dataset, epoch_size=32, batch_size=16)"
],
"source": "cnn_fitted = cnn.fit(dataset, epoch_size=8, batch_size=16)",
"metadata": {
"collapsed": false
},
"id": "381627a94d500675",
"execution_count": null,
"outputs": []
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -227,8 +225,8 @@
"collapsed": false
},
"id": "62f63dd68362c8b7",
"execution_count": null,
"outputs": []
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -249,8 +247,8 @@
"collapsed": false
},
"id": "779277d73e30554d",
"execution_count": null,
"outputs": []
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -271,8 +269,8 @@
"collapsed": false
},
"id": "a5ddbbfba41aa7f",
"execution_count": null,
"outputs": []
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
Expand All @@ -293,8 +291,8 @@
"collapsed": false
},
"id": "7081595d7100fb42",
"execution_count": null,
"outputs": []
"outputs": [],
"execution_count": null
}
],
"metadata": {
Expand Down
2 changes: 2 additions & 0 deletions src/safeds/data/labeled/containers/_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,8 @@ def split(
upper_bound=_ClosedBound(1),
)

_init_default_device()

first_dataset: ImageDataset[Out_co] = copy.copy(self)
second_dataset: ImageDataset[Out_co] = copy.copy(self)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from safeds.data.image.containers._single_size_image_list import _SingleSizeImageList
from safeds.data.labeled.containers import ImageDataset
from safeds.data.labeled.containers._image_dataset import _ColumnAsTensor
from safeds.data.tabular.containers import Column

from ._input_converter_image import _InputConversionImage

if TYPE_CHECKING:
from torch import Tensor

from safeds.data.image.containers import ImageList
from safeds.data.tabular.containers import Column
from safeds.data.tabular.transformation import OneHotEncoder


Expand Down Expand Up @@ -43,9 +43,9 @@ def _data_conversion_output(
output = torch.zeros(len(input_data), len(one_hot_encoder._get_names_of_added_columns()))
output[torch.arange(len(input_data)), output_data] = 1

im_dataset: ImageDataset[Column] = ImageDataset[Column].__new__(ImageDataset)
im_dataset: ImageDataset[Column] = object.__new__(ImageDataset)
im_dataset._output = _ColumnAsTensor._from_tensor(output, column_name, one_hot_encoder)
im_dataset._shuffle_tensor_indices = torch.LongTensor(list(range(len(input_data))))
im_dataset._shuffle_tensor_indices = torch.arange(len(input_data))
im_dataset._shuffle_after_epoch = False
im_dataset._batch_size = 1
im_dataset._next_batch_index = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from safeds.data.image.containers._single_size_image_list import _SingleSizeImageList
from safeds.data.labeled.containers import ImageDataset
from safeds.data.labeled.containers._image_dataset import _TableAsTensor
from safeds.data.tabular.containers import Table

from ._input_converter_image import _InputConversionImage

if TYPE_CHECKING:
from torch import Tensor

from safeds.data.image.containers import ImageList
from safeds.data.tabular.containers import Table


class InputConversionImageToTable(_InputConversionImage):
Expand All @@ -33,9 +33,9 @@ def _data_conversion_output(self, input_data: ImageList, output_data: Tensor) ->
output = torch.zeros(len(input_data), len(column_names))
output[torch.arange(len(input_data)), output_data] = 1

im_dataset: ImageDataset[Table] = ImageDataset[Table].__new__(ImageDataset)
im_dataset: ImageDataset[Table] = object.__new__(ImageDataset)
im_dataset._output = _TableAsTensor._from_tensor(output, column_names)
im_dataset._shuffle_tensor_indices = torch.LongTensor(list(range(len(input_data))))
im_dataset._shuffle_tensor_indices = torch.arange(len(input_data))
im_dataset._shuffle_after_epoch = False
im_dataset._batch_size = 1
im_dataset._next_batch_index = 0
Expand Down
4 changes: 0 additions & 4 deletions src/src/resources/to_csv_file.csv

This file was deleted.

5 changes: 0 additions & 5 deletions src/src/resources/to_json_file.json

This file was deleted.

Binary file removed src/src/resources/to_parquet_file.parquet
Binary file not shown.

0 comments on commit afafd43

Please sign in to comment.