from typing import Union, TypeVar, Sequence, Callable, Optional, Dict, List, Iterable

_ElementType = TypeVar('_ElementType')
_GroupType = TypeVar('_GroupType')
_ResultType = TypeVar('_ResultType')


def group_by(s: Iterable[_ElementType],
             key: Callable[[_ElementType], _GroupType],
             gfunc: Optional[Callable[[List[_ElementType]], _ResultType]] = None) -> Dict[_GroupType, _ResultType]:
    """
    Overview:
        Divide the elements into groups.

    :param s: Elements. Any iterable is accepted; it is traversed exactly once.
    :param key: Group key, should be a callable object. It is called once per element, and its return value
        is used as a dictionary key, so it must be hashable.
    :param gfunc: Post-process function for groups, should be a callable object. It receives the list of
        elements of one group and its return value becomes the value stored for that group.
        Default is ``None`` which means no post-processing will be performed (the lists are returned as-is).
    :return: Grouping result, a dictionary mapping each group key to its (post-processed) group. Groups
        appear in the order in which their key was first encountered, and the elements inside each
        group keep their original relative order.

    Examples::
        >>> from hbutils.collection import group_by
        >>>
        >>> foods = [
        ...     'apple', 'orange', 'pear',
        ...     'banana', 'fish', 'pork', 'milk',
        ... ]
        >>> group_by(foods, len)  # group by length
        {5: ['apple'], 6: ['orange', 'banana'], 4: ['pear', 'fish', 'pork', 'milk']}
        >>> group_by(foods, len, len)  # group and get length
        {5: 1, 6: 2, 4: 4}
        >>> group_by(foods, lambda x: x[0])  # group by first letter
        {'a': ['apple'], 'o': ['orange'], 'p': ['pear', 'pork'], 'b': ['banana'], 'f': ['fish'], 'm': ['milk']}
        >>> group_by(foods, lambda x: x[0], len)  # group and get length
        {'a': 1, 'o': 1, 'p': 2, 'b': 1, 'f': 1, 'm': 1}
    """
    grouped: Dict[_GroupType, List[_ElementType]] = {}
    for item in s:
        # setdefault creates the bucket on first sight of a key, so insertion
        # order of the dict matches the first-seen order of the group keys.
        grouped.setdefault(key(item), []).append(item)

    # Compare against None explicitly: a callable that happens to be falsy
    # (e.g. defines ``__bool__``/``__len__``) must still be applied.
    if gfunc is None:
        return grouped

    return {
        group_key: gfunc(members)
        for group_key, members in grouped.items()
    }