From b30af815c45f76fcef192b6e87d6ce64f93dd1a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Gr=C3=A9us?= Date: Wed, 10 Jan 2024 14:51:04 +0100 Subject: [PATCH] feat: added _get_device in config/_device.py for the default device (see #524) refactor: rearranged methods in Image --- src/safeds/config/__init__.py | 7 +++ src/safeds/config/_device.py | 6 +++ src/safeds/data/image/containers/_image.py | 56 ++++++++++------------ tests/safeds/config/__init__.py | 0 tests/safeds/config/test_device.py | 10 ++++ 5 files changed, 48 insertions(+), 31 deletions(-) create mode 100644 src/safeds/config/__init__.py create mode 100644 src/safeds/config/_device.py create mode 100644 tests/safeds/config/__init__.py create mode 100644 tests/safeds/config/test_device.py diff --git a/src/safeds/config/__init__.py b/src/safeds/config/__init__.py new file mode 100644 index 000000000..06e9e5ca2 --- /dev/null +++ b/src/safeds/config/__init__.py @@ -0,0 +1,7 @@ +"""Configuration for Safe-DS""" + +from ._device import _get_device + +__all__ = [ + "_get_device", +] diff --git a/src/safeds/config/_device.py b/src/safeds/config/_device.py new file mode 100644 index 000000000..c3fbb3f4b --- /dev/null +++ b/src/safeds/config/_device.py @@ -0,0 +1,6 @@ +import torch +from torch.types import Device + + +def _get_device() -> Device: + return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") diff --git a/src/safeds/data/image/containers/_image.py b/src/safeds/data/image/containers/_image.py index d948d3920..c0daf03ac 100644 --- a/src/safeds/data/image/containers/_image.py +++ b/src/safeds/data/image/containers/_image.py @@ -10,6 +10,8 @@ from PIL.Image import open as pil_image_open from torch import Tensor +from safeds.config import _get_device + if TYPE_CHECKING: from torch.types import Device import torchvision @@ -35,10 +37,9 @@ class Image: """ _pil_to_tensor = PILToTensor() - _default_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") @staticmethod - def from_file(path: str | Path, device: Device = _default_device) -> Image: + def from_file(path: str | Path, device: Device = _get_device()) -> Image: """ Create an image from a file. @@ -57,7 +58,7 @@ def from_file(path: str | Path, device: Device = _default_device) -> Image: return Image(image_tensor=Image._pil_to_tensor(pil_image_open(path)), device=device) @staticmethod - def from_bytes(data: bytes, device: Device = _default_device) -> Image: + def from_bytes(data: bytes, device: Device = _get_device()) -> Image: """ Create an image from bytes. @@ -81,9 +82,29 @@ def from_bytes(data: bytes, device: Device = _default_device) -> Image: input_tensor = torch.frombuffer(data, dtype=torch.uint8) return Image(image_tensor=torchvision.io.decode_image(input_tensor), device=device) - def __init__(self, image_tensor: Tensor, device: Device = _default_device) -> None: + def __init__(self, image_tensor: Tensor, device: Device = _get_device()) -> None: self._image_tensor: Tensor = image_tensor.to(device) + def __eq__(self, other: object) -> bool: + """ + Compare two images. + + Parameters + ---------- + other: The image to compare to. + + Returns + ------- + equals : bool + Whether the two images contain equal pixel data. + """ + if not isinstance(other, Image): + return NotImplemented + return ( + self._image_tensor.size() == other._image_tensor.size() + and torch.all(torch.eq(self._image_tensor, other._set_device(self.device)._image_tensor)).item() + ) + def _repr_jpeg_(self) -> bytes | None: """ Return a JPEG image as bytes. @@ -209,30 +230,6 @@ def to_png_file(self, path: str | Path) -> None: Path(path).parent.mkdir(parents=True, exist_ok=True) save_image(self._image_tensor.to(torch.float32) / 255, path, format="png") - # ------------------------------------------------------------------------------------------------------------------ - # IPython integration - # ------------------------------------------------------------------------------------------------------------------ - - def __eq__(self, other: object) -> bool: - """ - Compare two images. - - Parameters - ---------- - other: The image to compare to. - - Returns - ------- - equals : bool - Whether the two images contain equal pixel data. - """ - if not isinstance(other, Image): - return NotImplemented - return ( - self._image_tensor.size() == other._image_tensor.size() - and torch.all(torch.eq(self._image_tensor, other._set_device(self.device)._image_tensor)).item() - ) - # ------------------------------------------------------------------------------------------------------------------ # Transformations # ------------------------------------------------------------------------------------------------------------------ @@ -545,6 +542,3 @@ def rotate_left(self) -> Image: The image rotated 90 degrees counter-clockwise. """ return Image(func2.rotate(self._image_tensor, 90, expand=True), device=self.device) - - # def find_edges(self) -> Image: - # pass diff --git a/tests/safeds/config/__init__.py b/tests/safeds/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/safeds/config/test_device.py b/tests/safeds/config/test_device.py new file mode 100644 index 000000000..2adeea1c4 --- /dev/null +++ b/tests/safeds/config/test_device.py @@ -0,0 +1,10 @@ +import torch + +from safeds.config import _get_device + + +def test_device() -> None: + if torch.cuda.is_available(): + assert _get_device() == torch.device('cuda') + else: + assert _get_device() == torch.device('cpu')