Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add more ops
Browse files Browse the repository at this point in the history
  • Loading branch information
Anirudh Acharya committed Jul 23, 2018
1 parent 65fee98 commit 3a17c61
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 22 deletions.
26 changes: 21 additions & 5 deletions python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
"""Operator attributes conversion"""
from ._op_translations import identity, random_uniform, random_normal
from ._op_translations import add, subtract, multiply, divide, absolute, negative, add_n
from ._op_translations import tanh
from ._op_translations import ceil, floor
from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan
from ._op_translations import softplus, softsign, shape, gather, lp_pooling
from ._op_translations import ceil, floor, hardsigmoid, global_lppooling
from ._op_translations import concat
from ._op_translations import leaky_relu, _elu, _prelu, softmax, fully_connected
from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm
Expand All @@ -30,12 +31,13 @@
from ._op_translations import reshape, cast, split, _slice, transpose, squeeze, flatten
from ._op_translations import reciprocal, squareroot, power, exponent, _log, unsqueeze
from ._op_translations import reduce_max, reduce_mean, reduce_min, reduce_sum
from ._op_translations import reduce_prod, avg_pooling, max_pooling
from ._op_translations import reduce_prod, avg_pooling, max_pooling, instance_norm
from ._op_translations import argmax, argmin, maximum, minimum
from ._op_translations import clip, reduce_log_sum, reduce_log_sum_exp
from ._op_translations import reduce_sum_square, reduce_l2, max_roi_pooling, instance_norm
from ._op_translations import reduce_sum_square, reduce_l1, reduce_l2, max_roi_pooling
from ._op_translations import log_softmax, softsign, lesser, greater, equal
from ._op_translations import logical_and, logical_or, logical_xor, logical_not
from ._op_translations import mean

# convert_map defines maps of ONNX operator names to converter functor(callable)
# defined in the op_translations module.
Expand Down Expand Up @@ -77,6 +79,7 @@
'FC' : fully_connected,
'GlobalAveragePool' : global_avgpooling,
'GlobalMaxPool' : global_maxpooling,
'GlobalLpPool' : global_lppooling,
'Gemm' : linalg_gemm,
'LRN' : local_response_norm,
'Dropout' : dropout,
Expand Down Expand Up @@ -113,6 +116,7 @@
'ReduceLogSum' : reduce_log_sum,
'ReduceLogSumExp' : reduce_log_sum_exp,
'ReduceSumSquare' : reduce_sum_square,
'ReduceL1' : reduce_l1,
'ReduceL2' : reduce_l2,
'MaxRoiPool' : max_roi_pooling,
'InstanceNormalization' : instance_norm,
Expand All @@ -124,5 +128,17 @@
'And' : logical_and,
'Xor' : logical_xor,
'Not' : logical_not,
'Or' : logical_or
'Or' : logical_or,
'Mean' : mean,
'Acos' : arccos,
'Asin' : arcsin,
'Atan' : arctan,
'Cos' : _cos,
'Sin' : _sin,
'Softplus' : softplus,
'Tan' : _tan,
'Shape' : shape,
'Gather' : gather,
'HardSigmoid' : hardsigmoid,
'LpPool' : lp_pooling
}
75 changes: 74 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ def divide(attrs, inputs, proto_obj):
return op_value, new_attr, inputs
return 'broadcast_div', new_attr, inputs

def mean(attrs, inputs, proto_obj):
"""Mean of two input tensors."""
concat_input = []
for input in inputs:
concat_input.append(symbol.expand_dims(input, axis=0))
concat_sym = symbol.concat(*concat_input, dim=0)
mean_sym = symbol.mean(concat_sym, axis=0)
return mean_sym, attrs, inputs

def logical_and(attrs, inputs, proto_obj):
"""Logical and of two input arrays."""
return 'broadcast_logical_and', attrs, inputs
Expand Down Expand Up @@ -186,6 +195,10 @@ def sigmoid(attrs, inputs, proto_obj):
"""Computes elementwise sigmoid of the input array"""
return 'sigmoid', attrs, inputs

def hardsigmoid(attrs, inputs, proto_obj):
"""Computes elementwise hard sigmoid of the input array"""
return 'hard_sigmoid', attrs, inputs

def relu(attrs, inputs, proto_obj):
"""Computes rectified linear function."""
return 'relu', attrs, inputs
Expand Down Expand Up @@ -348,6 +361,14 @@ def global_avgpooling(attrs, inputs, proto_obj):
'pool_type': 'avg'})
return 'Pooling', new_attrs, inputs

def global_lppooling(attrs, inputs, proto_obj):
"""Performs global lp pooling on the input."""
p_value = attrs['p'] if 'p' in attrs else 2
new_attrs = translation_utils._add_extra_attributes(attrs, {'global_pool': True,
'kernel': (1, 1),
'pool_type': 'lp',
'p_value': p_value})
return 'Pooling', new_attrs, inputs

def linalg_gemm(attrs, inputs, proto_obj):
"""Performs general matrix multiplication and accumulation"""
Expand Down Expand Up @@ -465,7 +486,6 @@ def unsqueeze(attrs, inputs, cls):

return mxnet_op, attrs, inputs


def flatten(attrs, inputs, proto_obj):
"""Flattens the input array into a 2-D array by collapsing the higher dimensions."""
#Mxnet does not have axis support. By default uses axis=1
Expand All @@ -484,6 +504,10 @@ def clip(attrs, inputs, proto_obj):
new_attrs = translation_utils._add_extra_attributes(new_attrs, {'a_min' : -np.inf})
return 'clip', new_attrs, inputs

def gather(attrs, inputs, proto_obj):
"""Gather elements from an input array along the given axis."""
return 'take', attrs, inputs

#Powers
def reciprocal(attrs, inputs, proto_obj):
"""Returns the reciprocal of the argument, element-wise."""
Expand All @@ -505,6 +529,30 @@ def exponent(attrs, inputs, proto_obj):
"""Elementwise exponent of input array."""
return 'exp', attrs, inputs

def _cos(attrs, inputs, proto_obj):
"""Elementwise cosine of input array."""
return 'cos', attrs, inputs

def _sin(attrs, inputs, proto_obj):
"""Elementwise sine of input array."""
return 'sin', attrs, inputs

def _tan(attrs, inputs, proto_obj):
"""Elementwise tan of input array."""
return 'tan', attrs, inputs

def arccos(attrs, inputs, proto_obj):
"""Elementwise inverse cos of input array."""
return 'arccos', attrs, inputs

def arcsin(attrs, inputs, proto_obj):
"""Elementwise inverse sin of input array."""
return 'arcsin', attrs, inputs

def arctan(attrs, inputs, proto_obj):
"""Elementwise inverse tan of input array."""
return 'arctan', attrs, inputs

def _log(attrs, inputs, proto_obj):
"""Elementwise log of input array."""
return 'log', attrs, inputs
Expand Down Expand Up @@ -559,6 +607,17 @@ def reduce_sum_square(attrs, inputs, proto_obj):
keepdims=attrs.get('keepdims'))
return sum_op, attrs, inputs

def reduce_l1(attrs, inputs, proto_obj):
"""Reduce input tensor by l1 normalization."""
new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
new_attrs = translation_utils._add_extra_attributes(new_attrs,
{'ord' : 1})
return 'norm', new_attrs, inputs

def shape(attrs, inputs, proto_obj):
"""Returns shape of input array."""
return 'shape_array', attrs, inputs

def reduce_l2(attrs, inputs, proto_obj):
"""Reduce input tensor by l2 normalization."""
new_attrs = translation_utils._fix_attribute_names(attrs, {'axes':'axis'})
Expand All @@ -578,6 +637,20 @@ def avg_pooling(attrs, inputs, proto_obj):

return new_op, new_attrs, inputs

def lp_pooling(attrs, inputs, proto_obj):
"""LP Pooling"""
p_value = attrs['p'] if 'p' in attrs else 2
new_attrs = translation_utils._fix_attribute_names(attrs,
{'kernel_shape': 'kernel',
'strides': 'stride',
'pads': 'pad',
'p_value': p_value
})
new_attrs = translation_utils._add_extra_attributes(new_attrs,
{'pooling_convention': 'valid'
})
new_op = translation_utils._fix_pooling('lp', inputs, new_attrs)
return new_op, new_attrs, inputs

def max_pooling(attrs, inputs, proto_obj):
""" Average pooling"""
Expand Down
39 changes: 23 additions & 16 deletions tests/python-pytest/onnx/import/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

IMPLEMENTED_OPERATORS_TEST = [
'test_split_equal'
'test_random_uniform',
'test_random_normal',
'test_random_',
'test_add',
'test_sub',
'test_mul',
Expand All @@ -38,41 +37,49 @@
'test_constant_pad',
'test_edge_pad',
'test_reflect_pad',
'test_reduce_min',
'test_reduce_max',
'test_reduce_mean',
'test_reduce_prod',
'test_squeeze',
'test_squeeze_',
'test_unsqueeze',
'test_softmax_example',
'test_softmax_large_number',
'test_softmax_axis_2',
'test_transpose',
'test_globalmaxpool',
'test_globalaveragepool',
'test_global_lppooling',
'test_slice_cpu',
'test_slice_neg',
'test_squeeze_',
'test_reciprocal',
'test_sqrt',
'test_pow',
'test_exp',
'test_argmax',
'test_argmin',
'test_min',
'test_logical_and',
'test_logical_xor',
'test_logical_not',
'test_logical_or',
'test_logical_',
# enabling partial test cases for matmul
'test_matmul_3d',
'test_matmul_4d',
'test_clip',
'test_softsign',
'test_reduce_l2',
'test_reduce_log_sum',
'test_reduce_log_sum_exp',
'test_reduce_sum_square'
'test_reduce_',
'test_softplus',
'test_mean',
'test_acos',
'test_asin',
'test_atan',
'test_cos',
'test_sin',
'test_tan',
'test_shape',
'test_hardsigmoid_',
'test_averagepool_1d',
'test_averagepool_2d_pads_count_include_pad',
'test_averagepool_2d_precomputed_pads_count_include_pad',
'test_averagepool_2d_precomputed_strides',
'test_averagepool_2d_strides',
'test_averagepool_3d',
'test_LpPool_',
'test_instancenorm_epsilon',
#pytorch operator tests
'test_operator_exp',
'test_operator_maxpool',
Expand Down

0 comments on commit 3a17c61

Please sign in to comment.