diff --git a/src/operator/tensor/histogram-inl.h b/src/operator/tensor/histogram-inl.h
index 18ae86249d9e..08620e841725 100644
--- a/src/operator/tensor/histogram-inl.h
+++ b/src/operator/tensor/histogram-inl.h
@@ -112,18 +112,14 @@ inline bool HistogramOpType(const nnvm::NodeAttrs& attrs,
 }
 
 template<typename xpu>
-void HistogramForwardImpl(mshadow::Stream<xpu>* s,
-                          const OpContext& ctx,
-                          const nnvm::NodeAttrs& attrs,
+void HistogramForwardImpl(const OpContext& ctx,
                           const TBlob& in_data,
                           const TBlob& bin_bounds,
                           const TBlob& out_data,
                           const TBlob& out_bins);
 
 template<typename xpu>
-void HistogramForwardImpl(mshadow::Stream<xpu>* s,
-                          const OpContext& ctx,
-                          const nnvm::NodeAttrs& attrs,
+void HistogramForwardImpl(const OpContext& ctx,
                           const TBlob& in_data,
                           const TBlob& out_data,
                           const TBlob& out_bins,
@@ -146,7 +142,6 @@ void HistogramOpForward(const nnvm::NodeAttrs& attrs,
   const bool legal_params = (has_cnt && has_range) || (!has_cnt && !has_range);
   CHECK(legal_params) << "width and range should both or neither be specified";
 
-  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const TBlob& in_data = inputs[0];
   const TBlob& out_data = outputs[0];
   const TBlob& out_bins = outputs[1];
@@ -164,10 +159,10 @@ void HistogramOpForward(const nnvm::NodeAttrs& attrs,
       max += 0.5f;
       LOG(INFO) << min << " " << max;
     }
-    HistogramForwardImpl<xpu>(s, ctx, attrs, in_data, out_data, out_bins, bin_cnt, min, max);
+    HistogramForwardImpl<xpu>(ctx, in_data, out_data, out_bins, bin_cnt, min, max);
   } else {
     const TBlob& bin_bounds = inputs[1];
-    HistogramForwardImpl<xpu>(s, ctx, attrs, in_data, bin_bounds, out_data, out_bins);
+    HistogramForwardImpl<xpu>(ctx, in_data, bin_bounds, out_data, out_bins);
   }
 }
 
diff --git a/src/operator/tensor/histogram.cc b/src/operator/tensor/histogram.cc
index e3bda5926c31..855441847799 100644
--- a/src/operator/tensor/histogram.cc
+++ b/src/operator/tensor/histogram.cc
@@ -63,16 +63,15 @@ void ComputeHistogram(const int* bin_indices, CType* out_data, size_t input_size
   }
 }
 
-template<typename cpu>
-void HistogramForwardImpl(mshadow::Stream<cpu>* s,
-                          const OpContext& ctx,
-                          const nnvm::NodeAttrs& attrs,
-                          const TBlob& in_data,
-                          const TBlob& bin_bounds,
-                          const TBlob& out_data,
-                          const TBlob& out_bins) {
+template<>
+void HistogramForwardImpl<cpu>(const OpContext& ctx,
+                               const TBlob& in_data,
+                               const TBlob& bin_bounds,
+                               const TBlob& out_data,
+                               const TBlob& out_bins) {
   using namespace mshadow;
   using namespace mxnet_op;
+  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
   Tensor<cpu, 1, int> bin_indices =
     ctx.requested[0].get_space_typed<cpu, 1, int>(Shape1(in_data.Size()), s);
   const int bin_cnt = out_data.Size();
@@ -90,18 +89,17 @@ void HistogramForwardImpl<cpu>(const OpContext& ctx,
   });
 }
 
-template<typename cpu>
-void HistogramForwardImpl(mshadow::Stream<cpu>* s,
-                          const OpContext& ctx,
-                          const nnvm::NodeAttrs& attrs,
-                          const TBlob& in_data,
-                          const TBlob& out_data,
-                          const TBlob& out_bins,
-                          const int bin_cnt,
-                          const double min,
-                          const double max) {
+template<>
+void HistogramForwardImpl<cpu>(const OpContext& ctx,
+                               const TBlob& in_data,
+                               const TBlob& out_data,
+                               const TBlob& out_bins,
+                               const int bin_cnt,
+                               const double min,
+                               const double max) {
   using namespace mshadow;
   using namespace mxnet_op;
+  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
   Tensor<cpu, 1, int> bin_indices =
     ctx.requested[0].get_space_typed<cpu, 1, int>(Shape1(in_data.Size()), s);
 
@@ -149,10 +147,6 @@ Example::
 .set_attr<nnvm::FInferShape>("FInferShape", HistogramOpShape)
 .set_attr<nnvm::FInferType>("FInferType", HistogramOpType)
 .set_attr<FCompute>("FCompute<cpu>", HistogramOpForward<cpu>)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-                                [](const NodeAttrs& attrs) {
-                                  return std::vector<std::pair<int, int> >{};
-                                })
 .add_argument("data", "NDArray-or-Symbol", "Input ndarray")
 .add_argument("bins", "NDArray-or-Symbol", "Input ndarray")
 .add_arguments(HistogramParam::__FIELDS__());
diff --git a/src/operator/tensor/histogram.cu b/src/operator/tensor/histogram.cu
index 58174227875a..c3c836a4b498 100644
--- a/src/operator/tensor/histogram.cu
+++ b/src/operator/tensor/histogram.cu
@@ -58,16 +58,15 @@ struct HistogramFusedKernel {
   }
 };
 
-template<typename gpu>
-void HistogramForwardImpl(mshadow::Stream<gpu>* s,
-                          const OpContext& ctx,
-                          const nnvm::NodeAttrs& attrs,
-                          const TBlob& in_data,
-                          const TBlob& bin_bounds,
-                          const TBlob& out_data,
-                          const TBlob& out_bins) {
+template<>
+void HistogramForwardImpl<gpu>(const OpContext& ctx,
+                               const TBlob& in_data,
+                               const TBlob& bin_bounds,
+                               const TBlob& out_data,
+                               const TBlob& out_bins) {
   using namespace mshadow;
   using namespace mxnet_op;
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
   MSHADOW_TYPE_SWITCH(in_data.type_flag_, DType, {
     MSHADOW_IDX_TYPE_SWITCH(out_data.type_flag_, CType, {
       int bin_cnt = out_bins.Size() - 1;
@@ -81,18 +80,17 @@ void HistogramForwardImpl<gpu>(const OpContext& ctx,
   });
 }
 
-template<typename gpu>
-void HistogramForwardImpl(mshadow::Stream<gpu>* s,
-                          const OpContext& ctx,
-                          const nnvm::NodeAttrs& attrs,
-                          const TBlob& in_data,
-                          const TBlob& out_data,
-                          const TBlob& out_bins,
-                          const int bin_cnt,
-                          const double min,
-                          const double max) {
+template<>
+void HistogramForwardImpl<gpu>(const OpContext& ctx,
+                               const TBlob& in_data,
+                               const TBlob& out_data,
+                               const TBlob& out_bins,
+                               const int bin_cnt,
+                               const double min,
+                               const double max) {
   using namespace mshadow;
   using namespace mxnet_op;
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
   MSHADOW_TYPE_SWITCH(in_data.type_flag_, DType, {
     MSHADOW_IDX_TYPE_SWITCH(out_data.type_flag_, CType, {
       Kernel<set_zero, gpu>::Launch(s, bin_cnt, out_data.dptr<CType>());