Skip to content

Commit

Permalink
feat: added _get_device in config/_device.py for the default device (…
Browse files Browse the repository at this point in the history
…see #524)

refactor: rearranged methods in Image
  • Loading branch information
Marsmaennchen221 committed Jan 10, 2024
1 parent e2234b1 commit b30af81
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 31 deletions.
7 changes: 7 additions & 0 deletions src/safeds/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Configuration for Safe-DS"""

from ._device import _get_device

__all__ = [
"_get_device",
]
6 changes: 6 additions & 0 deletions src/safeds/config/_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import torch
from torch.types import Device


def _get_device() -> Device:
return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
56 changes: 25 additions & 31 deletions src/safeds/data/image/containers/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from PIL.Image import open as pil_image_open
from torch import Tensor

from safeds.config import _get_device

if TYPE_CHECKING:
from torch.types import Device
import torchvision
Expand All @@ -35,10 +37,9 @@ class Image:
"""

_pil_to_tensor = PILToTensor()
_default_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

@staticmethod
def from_file(path: str | Path, device: Device = _default_device) -> Image:
def from_file(path: str | Path, device: Device = _get_device()) -> Image:
"""
Create an image from a file.
Expand All @@ -57,7 +58,7 @@ def from_file(path: str | Path, device: Device = _default_device) -> Image:
return Image(image_tensor=Image._pil_to_tensor(pil_image_open(path)), device=device)

@staticmethod
def from_bytes(data: bytes, device: Device = _default_device) -> Image:
def from_bytes(data: bytes, device: Device = _get_device()) -> Image:
"""
Create an image from bytes.
Expand All @@ -81,9 +82,29 @@ def from_bytes(data: bytes, device: Device = _default_device) -> Image:
input_tensor = torch.frombuffer(data, dtype=torch.uint8)
return Image(image_tensor=torchvision.io.decode_image(input_tensor), device=device)

def __init__(self, image_tensor: Tensor, device: Device = _default_device) -> None:
def __init__(self, image_tensor: Tensor, device: Device = _get_device()) -> None:
self._image_tensor: Tensor = image_tensor.to(device)

def __eq__(self, other: object) -> bool:
"""
Compare two images.
Parameters
----------
other: The image to compare to.
Returns
-------
equals : bool
Whether the two images contain equal pixel data.
"""
if not isinstance(other, Image):
return NotImplemented
return (
self._image_tensor.size() == other._image_tensor.size()
and torch.all(torch.eq(self._image_tensor, other._set_device(self.device)._image_tensor)).item()
)

def _repr_jpeg_(self) -> bytes | None:
"""
Return a JPEG image as bytes.
Expand Down Expand Up @@ -209,30 +230,6 @@ def to_png_file(self, path: str | Path) -> None:
Path(path).parent.mkdir(parents=True, exist_ok=True)
save_image(self._image_tensor.to(torch.float32) / 255, path, format="png")

# ------------------------------------------------------------------------------------------------------------------
# IPython integration
# ------------------------------------------------------------------------------------------------------------------

def __eq__(self, other: object) -> bool:
"""
Compare two images.
Parameters
----------
other: The image to compare to.
Returns
-------
equals : bool
Whether the two images contain equal pixel data.
"""
if not isinstance(other, Image):
return NotImplemented
return (
self._image_tensor.size() == other._image_tensor.size()
and torch.all(torch.eq(self._image_tensor, other._set_device(self.device)._image_tensor)).item()
)

# ------------------------------------------------------------------------------------------------------------------
# Transformations
# ------------------------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -545,6 +542,3 @@ def rotate_left(self) -> Image:
The image rotated 90 degrees counter-clockwise.
"""
return Image(func2.rotate(self._image_tensor, 90, expand=True), device=self.device)

# def find_edges(self) -> Image:
# pass
Empty file added tests/safeds/config/__init__.py
Empty file.
10 changes: 10 additions & 0 deletions tests/safeds/config/test_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import torch

from safeds.config import _get_device


def test_device() -> None:
if torch.cuda.is_available():
assert _get_device() == torch.device('cuda')
else:
assert _get_device() == torch.device('cpu')

0 comments on commit b30af81

Please sign in to comment.