Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add symbol and symbolic forward check for histogram
Browse files Browse the repository at this point in the history
  • Loading branch information
Hao Jin committed Jun 22, 2018
1 parent 3b68540 commit 8bc8e18
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 5 deletions.
4 changes: 2 additions & 2 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3754,7 +3754,7 @@ def histogram(a, bins=10, range=None):
If bins is an int, it defines the number of equal-width bins in the
given range (10, by default). If bins is a sequence, it defines the bin edges,
including the rightmost edge, allowing for non-uniform bin widths.
range_ : (float, float), optional
range : (float, float), optional
The lower and upper range of the bins. If not provided, range is simply (a.min(), a.max()).
Values outside the range are ignored. The first element of the range must be less than or
equal to the second. range affects the automatic bin computation as well, the range will
Expand All @@ -3766,7 +3766,7 @@ def histogram(a, bins=10, range=None):
return _internal._histogram(data=a, bins=bins)
elif isinstance(bins, integer_types):
if range is None:
warnings.warn("range_ is not specified, using numpy's result "
warnings.warn("range is not specified, using numpy's result "
"to ensure consistency with numpy")
res, bin_bounds = np.histogram(a.asnumpy(), bins=bins)
return array(res), array(bin_bounds)
Expand Down
30 changes: 28 additions & 2 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from ..attribute import AttrScope
from ..base import _LIB, numeric_types, c_array, c_array_buf, c_str, c_str_array, c_handle_array
from ..base import mx_uint, py_str, string_types
from ..base import mx_uint, py_str, string_types, integer_types
from ..base import NDArrayHandle, ExecutorHandle, SymbolHandle
from ..base import check_call, MXNetError, NotImplementedForSymbol
from ..context import Context, current_context
Expand All @@ -47,7 +47,8 @@
from ._internal import SymbolBase, _set_symbol_class

__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
"pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange"]
"pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange",
"histogram"]


class Symbol(SymbolBase):
Expand Down Expand Up @@ -2862,4 +2863,29 @@ def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):
return _internal._arange(start=start, stop=stop, step=step, repeat=repeat,
name=name, dtype=dtype)

def histogram(a, bins=10, range=None, **kwargs):
"""Compute the histogram of the input data.
Parameters
----------
a : NDArray
Input data. The histogram is computed over the flattened array.
bins : int or sequence of scalars
If bins is an int, it defines the number of equal-width bins in the
given range (10, by default). If bins is a sequence, it defines the bin edges,
including the rightmost edge, allowing for non-uniform bin widths.
range : (float, float), required if bins is an integer
The lower and upper range of the bins. If not provided, range is simply (a.min(), a.max()).
Values outside the range are ignored. The first element of the range must be less than or
equal to the second. range affects the automatic bin computation as well, the range will
be equally divided by the number of bins.
"""
if isinstance(bins, Symbol):
return _internal._histogram(data=a, bins=bins, **kwargs)
elif isinstance(bins, integer_types):
if range is None:
raise ValueError("null range is not supported in symbol mode")
return _internal._histogram(data=a, bin_cnt=bins, range=range, **kwargs)
raise ValueError("bins argument should be either an integer or an NDArray")

_set_symbol_class(Symbol)
5 changes: 4 additions & 1 deletion src/operator/tensor/histogram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,10 @@ Example::
.set_num_outputs(2)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "bins"};
const HistogramParam& params = nnvm::get<HistogramParam>(attrs.parsed);
return params.bin_cnt.has_value() ?
std::vector<std::string>{"data"} :
std::vector<std::string>{"data", "bins"};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6024,6 +6024,18 @@ def f(x, bins=10, range=None):
assert_almost_equal(mx_histo2.asnumpy(), np_histo2, rtol=1e-3, atol=1e-5)
assert_almost_equal(mx_bins2.asnumpy(), np_bins2, rtol=1e-3, atol=1e-5)

data = mx.sym.Variable("data")

bins = mx.sym.Variable("bins")
histo1 = mx.sym.histogram(a=data, bins=bin_cnt, range=bin_range)
histo2 = mx.sym.histogram(a=data, bins=bins)
executor1 = histo1.bind(ctx=default_context(), args={"data" : x})
executor1.forward(is_train=False)
assert_almost_equal(np_histo1, executor1.outputs[0].asnumpy(), 0, 0, ("EXPECTED_histo1", "FORWARD_histo1"), equal_nan=False)
executor2 = histo2.bind(ctx=default_context(), args={"data" : x, "bins" : mx_bins})
executor2.forward(is_train=False)
assert_almost_equal(np_histo2, executor2.outputs[0].asnumpy(), 0, 0, ("EXPECTED_histo2", "FORWARD_histo2"), equal_nan=False)


def test_op_output_names_monitor():
def check_name(op_sym, expected_names):
Expand Down

0 comments on commit 8bc8e18

Please sign in to comment.