Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query fixes #1487

Merged
merged 4 commits into from
Feb 14, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 32 additions & 20 deletions hub/core/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from hub.core.tensor import Tensor


import numpy
import numpy as np

NP_RESULT = Union[numpy.ndarray, List[numpy.ndarray]]

NP_RESULT = Union[np.ndarray, List[np.ndarray]]
NP_ACCESS = Callable[[str], NP_RESULT]


Expand Down Expand Up @@ -127,17 +128,17 @@ def __getitem__(self, item):

@property
def min(self):
"""Returns numpy.min() for the tensor"""
return numpy.amin(self.val)
"""Returns np.min() for the tensor"""
return np.amin(self.val)

@property
def max(self):
"""Returns numpy.max() for the tensor"""
return numpy.amax(self.val)
"""Returns np.max() for the tensor"""
return np.amax(self.val)

@property
def mean(self):
"""Returns numpy.mean() for the tensor"""
"""Returns np.mean() for the tensor"""
return self.val.mean()

@property
Expand All @@ -151,7 +152,13 @@ def size(self):
return self.val.size # type: ignore

def __eq__(self, o: object) -> bool:
# NOTE(review): this diff view lost its indentation and shows pre- and
# post-change lines together; the unconditional return directly below is
# presumably the removed pre-change body and the isinstance branches its
# replacement — confirm against the merged file.
return self.val == o
# Sequence-valued tensor: a list/tuple operand is compared as an unordered
# set (element order and duplicates are ignored); any other operand is
# treated as a membership test against the stored values.
if isinstance(self.val, (list, np.ndarray)):
if isinstance(o, (list, tuple)):
# NOTE(review): set() requires hashable elements, so this presumably
# expects a flat (1-D) value — confirm behaviour for nested values.
return set(o) == set(self.val)
else:
return o in self.val
else:
# Any other kind of value: plain equality.
return self.val == o

def __lt__(self, o: object) -> bool:
return self.val < o
Expand All @@ -165,9 +172,6 @@ def __gt__(self, o: object) -> bool:
def __ge__(self, o: object) -> bool:
return self.val >= o

def __ne__(self, o: object) -> bool:
return self.val != o

def __mod__(self, o: object):
return self.val % o

Expand All @@ -189,6 +193,9 @@ def __mul__(self, o: object):
def __pow__(self, o: object):
return self.val**o

def __contains__(self, o: object):
# Hook for the `in` operator (`o in tensor`); forwards to contains().
return self.contains(o)


class GroupTensor:
def __init__(self, dataset: Dataset, wrappers, prefix: str) -> None:
Expand Down Expand Up @@ -228,11 +235,17 @@ def __init__(self, tensor: Tensor) -> None:
_classes = tensor.info["class_names"] # type: ignore
self._classes_dict = {v: idx for idx, v in enumerate(_classes)}

def __eq__(self, o: object) -> bool:
def _norm_labels(self, o: object):
# Converts a query operand to class indices: a class-name string is looked
# up in self._classes_dict (an unknown name raises KeyError), an int is
# returned unchanged, and a list/tuple is rebuilt as the same container
# type with each element converted recursively.
# NOTE(review): the return / `else:` / return lines right after the first
# isinstance check look like pre-change diff residue (an `elif` cannot
# follow `else`) — confirm against the merged file.
# NOTE(review): there is no final else branch, so any other operand type
# (e.g. a NumPy integer or an ndarray) falls through and yields None —
# confirm this is intended.
if isinstance(o, str):
return self.val == self._classes_dict[o]
else:
return self.val == o
return self._classes_dict[o]
elif isinstance(o, int):
return o
elif isinstance(o, (list, tuple)):
return o.__class__(map(self._norm_labels, o))

def __eq__(self, o: object) -> bool:
# Translate any class names in the operand to class indices, then defer
# to the parent class's __eq__.
o = self._norm_labels(o)
return super(ClassLabelsTensor, self).__eq__(o)

def __lt__(self, o: object) -> bool:
if isinstance(o, str):
Expand All @@ -254,8 +267,7 @@ def __ge__(self, o: object) -> bool:
raise ValueError("label class is not comparable")
return self.val >= o

def __ne__(self, o: object) -> bool:
if isinstance(o, str):
return self.val != self._classes_dict[o]
else:
return self.val != o
def contains(self, v: Any):
# Accepts a class name or a raw value: a string is translated to its class
# index via self._classes_dict (an unknown name raises KeyError) before
# the parent class's contains() is called with the result.
if isinstance(v, str):
v = self._classes_dict[v]
return super(ClassLabelsTensor, self).contains(v)
25 changes: 25 additions & 0 deletions hub/core/query/test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,28 @@ def test_group(local_ds):

result = local_ds.filter("labels.t2 == 1", progressbar=False)
assert len(result) == 1


def test_multi_category_labels(local_ds):
# Checks that filter queries match samples that carry several class
# labels, using four different spellings of "has class 0 / 'cat'".
ds = local_ds
with ds:
ds.create_tensor("image", htype="image", sample_compression="png")
ds.create_tensor(
"label", htype="class_label", class_names=["cat", "dog", "tree"]
)
# Pixel values are drawn from [50, 100), so r + 2 and r * 2 stay below 256
# and do not wrap around in uint8.
r = np.random.randint(50, 100, (32, 32, 3), dtype=np.uint8)
# Three samples with label pairs [0, 1], [1, 2] and [0, 2]; only the first
# and third include class 0 ("cat").
ds.image.append(r)
ds.label.append([0, 1])
ds.image.append(r + 2)
ds.label.append([1, 2])
ds.image.append(r * 2)
ds.label.append([0, 2])
# Equality by index, equality by name, the `in` operator, and an explicit
# contains() call are all expected to select the same two samples.
view1 = ds.filter("label == 0")
view2 = ds.filter("label == 'cat'")
view3 = ds.filter("'cat' in label")
view4 = ds.filter("label.contains('cat')")
exp_images = np.array([r, r * 2])
exp_labels = np.array([[0, 1], [0, 2]], dtype=np.uint8)
for v in (view1, view2, view3, view4):
np.testing.assert_array_equal(v.image.numpy(), exp_images)
np.testing.assert_array_equal(v.label.numpy(), exp_labels)