Skip to content

Commit

Permalink
Query fixes (#1487)
Browse files Browse the repository at this point in the history
  • Loading branch information
farizrahman4u authored Feb 14, 2022
1 parent 10a21d5 commit 18cf7d4
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 20 deletions.
1 change: 1 addition & 0 deletions hub/core/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def _lock_lost_handler(self):
always_warn(
"Unable to update dataset lock as another machine has locked it for writing. Switching to read only mode."
)
self._locked_out = True

def __enter__(self):
self._initial_autoflush.append(self.storage.autoflush)
Expand Down
52 changes: 32 additions & 20 deletions hub/core/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from hub.core.tensor import Tensor


import numpy
import numpy as np

NP_RESULT = Union[numpy.ndarray, List[numpy.ndarray]]

NP_RESULT = Union[np.ndarray, List[np.ndarray]]
NP_ACCESS = Callable[[str], NP_RESULT]


Expand Down Expand Up @@ -127,17 +128,17 @@ def __getitem__(self, item):

@property
def min(self):
"""Returns numpy.min() for the tensor"""
return numpy.amin(self.val)
"""Returns np.min() for the tensor"""
return np.amin(self.val)

@property
def max(self):
"""Returns numpy.max() for the tensor"""
return numpy.amax(self.val)
"""Returns np.max() for the tensor"""
return np.amax(self.val)

@property
def mean(self):
"""Returns numpy.mean() for the tensor"""
"""Returns np.mean() for the tensor"""
return self.val.mean()

@property
Expand All @@ -151,7 +152,13 @@ def size(self):
return self.val.size # type: ignore

def __eq__(self, o: object) -> bool:
return self.val == o
if isinstance(self.val, (list, np.ndarray)):
if isinstance(o, (list, tuple)):
return set(o) == set(self.val)
else:
return o in self.val
else:
return self.val == o

def __lt__(self, o: object) -> bool:
return self.val < o
Expand All @@ -165,9 +172,6 @@ def __gt__(self, o: object) -> bool:
def __ge__(self, o: object) -> bool:
return self.val >= o

def __ne__(self, o: object) -> bool:
return self.val != o

def __mod__(self, o: object):
return self.val % o

Expand All @@ -189,6 +193,9 @@ def __mul__(self, o: object):
def __pow__(self, o: object):
return self.val**o

def __contains__(self, o: object):
return self.contains(o)


class GroupTensor:
def __init__(self, dataset: Dataset, wrappers, prefix: str) -> None:
Expand Down Expand Up @@ -228,11 +235,17 @@ def __init__(self, tensor: Tensor) -> None:
_classes = tensor.info["class_names"] # type: ignore
self._classes_dict = {v: idx for idx, v in enumerate(_classes)}

def __eq__(self, o: object) -> bool:
def _norm_labels(self, o: object):
if isinstance(o, str):
return self.val == self._classes_dict[o]
else:
return self.val == o
return self._classes_dict[o]
elif isinstance(o, int):
return o
elif isinstance(o, (list, tuple)):
return o.__class__(map(self._norm_labels, o))

def __eq__(self, o: object) -> bool:
o = self._norm_labels(o)
return super(ClassLabelsTensor, self).__eq__(o)

def __lt__(self, o: object) -> bool:
if isinstance(o, str):
Expand All @@ -254,8 +267,7 @@ def __ge__(self, o: object) -> bool:
raise ValueError("label class is not comparable")
return self.val >= o

def __ne__(self, o: object) -> bool:
if isinstance(o, str):
return self.val != self._classes_dict[o]
else:
return self.val != o
def contains(self, v: Any):
if isinstance(v, str):
v = self._classes_dict[v]
return super(ClassLabelsTensor, self).contains(v)
25 changes: 25 additions & 0 deletions hub/core/query/test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,28 @@ def test_group(local_ds):

result = local_ds.filter("labels.t2 == 1", progressbar=False)
assert len(result) == 1


def test_multi_category_labels(local_ds):
ds = local_ds
with ds:
ds.create_tensor("image", htype="image", sample_compression="png")
ds.create_tensor(
"label", htype="class_label", class_names=["cat", "dog", "tree"]
)
r = np.random.randint(50, 100, (32, 32, 3), dtype=np.uint8)
ds.image.append(r)
ds.label.append([0, 1])
ds.image.append(r + 2)
ds.label.append([1, 2])
ds.image.append(r * 2)
ds.label.append([0, 2])
view1 = ds.filter("label == 0")
view2 = ds.filter("label == 'cat'")
view3 = ds.filter("'cat' in label")
view4 = ds.filter("label.contains('cat')")
exp_images = np.array([r, r * 2])
exp_labels = np.array([[0, 1], [0, 2]], dtype=np.uint8)
for v in (view1, view2, view3, view4):
np.testing.assert_array_equal(v.image.numpy(), exp_images)
np.testing.assert_array_equal(v.label.numpy(), exp_labels)

0 comments on commit 18cf7d4

Please sign in to comment.