From e46afbe3c6b5e7e5ab86b23c0dd18075f5e0be53 Mon Sep 17 00:00:00 2001 From: Lanking Date: Mon, 14 May 2018 01:54:53 -0700 Subject: [PATCH] [MXNET-357] New Scala API Design (Symbol) (#10660) * Simplfied current Macros impl to Quasiquote * Change the Symbol Function Field, add SymbolArg * Fix the Macros problem, disable the hidden function _ * Add Implementation for New API * Add examples and comments * Add _contrib_ support * New namespace for Symbol API * Change names and add comments * add TODOs and name changes * Add relative path to MXNET_BASEDIR * Update Base.scala --- .../main/scala/org/apache/mxnet/Symbol.scala | 2 + .../scala/org/apache/mxnet/SymbolAPI.scala | 26 ++ .../imclassification/TrainMnist.scala | 42 +-- .../scala/org/apache/mxnet/init/Base.scala | 7 +- scala-package/macros/pom.xml | 38 +++ .../scala/org/apache/mxnet/SymbolMacro.scala | 240 +++++++++++++----- .../src/test/resources/log4j.properties | 24 ++ .../scala/org/apache/mxnet/MacrosSuite.scala | 50 ++++ 8 files changed, 340 insertions(+), 89 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala create mode 100644 scala-package/macros/src/test/resources/log4j.properties create mode 100644 scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 13f85a731dc4..60efd2ba62bd 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -830,6 +830,8 @@ object Symbol { private val functions: Map[String, SymbolFunction] = initSymbolModule() private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3) + val api = SymbolAPI + def pow(sym1: Symbol, sym2: Symbol): Symbol = { Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2)) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala new file mode 100644 index 000000000000..49de9ae73218 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mxnet + + +@AddSymbolAPIs(false) +/** + * typesafe Symbol API: Symbol.api._ + * Main code will be generated during compile time through Macros + */ +object SymbolAPI { +} diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala index d1ec88d67c6b..e9171bd47c28 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala @@ -30,40 +30,40 @@ object TrainMnist { // multi-layer perceptron def getMlp: Symbol = { val data = Symbol.Variable("data") - val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data, "num_hidden" -> 128)) - val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1, "act_type" -> "relu")) - val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1, "num_hidden" -> 64)) - val act2 = Symbol.Activation(name = "relu2")()(Map("data" -> fc2, "act_type" -> "relu")) - val fc3 = Symbol.FullyConnected(name = "fc3")()(Map("data" -> act2, "num_hidden" -> 10)) - val mlp = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc3)) + + val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128, name = "fc1") + val act1 = Symbol.api.Activation (data = Some(fc1), "relu", name = "relu") + val fc2 = Symbol.api.FullyConnected(Some(act1), None, None, 64, name = "fc2") + val act2 = Symbol.api.Activation(data = Some(fc2), "relu", name = "relu2") + val fc3 = Symbol.api.FullyConnected(Some(act2), None, None, 10, name = "fc3") + val mlp = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc3)) mlp } // LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick // Haffner. "Gradient-based learning applied to document recognition." // Proceedings of the IEEE (1998) + def getLenet: Symbol = { val data = Symbol.Variable("data") // first conv - val conv1 = Symbol.Convolution()()( - Map("data" -> data, "kernel" -> "(5, 5)", "num_filter" -> 20)) - val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" -> "tanh")) - val pool1 = Symbol.Pooling()()(Map("data" -> tanh1, "pool_type" -> "max", - "kernel" -> "(2, 2)", "stride" -> "(2, 2)")) + val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5), num_filter = 20) + val tanh1 = Symbol.api.tanh(data = Some(conv1)) + val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"), + kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2))) // second conv - val conv2 = Symbol.Convolution()()( - Map("data" -> pool1, "kernel" -> "(5, 5)", "num_filter" -> 50)) - val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" -> "tanh")) - val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max", - "kernel" -> "(2, 2)", "stride" -> "(2, 2)")) + val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5, 5), num_filter = 50) + val tanh2 = Symbol.api.tanh(data = Some(conv2)) + val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"), + kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2))) // first fullc - val flatten = Symbol.Flatten()()(Map("data" -> pool2)) - val fc1 = Symbol.FullyConnected()()(Map("data" -> flatten, "num_hidden" -> 500)) - val tanh3 = Symbol.Activation()()(Map("data" -> fc1, "act_type" -> "tanh")) + val flatten = Symbol.api.Flatten(data = Some(pool2)) + val fc1 = Symbol.api.FullyConnected(data = Some(flatten), num_hidden = 500) + val tanh3 = Symbol.api.tanh(data = Some(fc1)) // second fullc - val fc2 = Symbol.FullyConnected()()(Map("data" -> tanh3, "num_hidden" -> 10)) + val fc2 = Symbol.api.FullyConnected(data = Some(tanh3), num_hidden = 10) // loss - val lenet = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc2)) + val lenet = Symbol.api.SoftmaxOutput(name = "softmax", data = Some(fc2)) lenet } diff --git a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala index 7af2e052255c..7402dbd3bc1d 100644 --- a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala +++ b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala @@ -37,7 +37,12 @@ object Base { @throws(classOf[UnsatisfiedLinkError]) private def tryLoadInitLibrary(): Unit = { - val baseDir = System.getProperty("user.dir") + "/init-native" + var baseDir = System.getProperty("user.dir") + "/init-native" + // TODO(lanKing520) Update this to use relative path to the MXNet director. + // TODO(lanking520) baseDir = sys.env("MXNET_BASEDIR") + "/scala-package/init-native" + if (System.getenv().containsKey("MXNET_BASEDIR")) { + baseDir = sys.env("MXNET_BASEDIR") + } val os = System.getProperty("os.name") // ref: http://lopica.sourceforge.net/os.html if (os.startsWith("Linux")) { diff --git a/scala-package/macros/pom.xml b/scala-package/macros/pom.xml index 0aa3030e7ce3..59cc181bd360 100644 --- a/scala-package/macros/pom.xml +++ b/scala-package/macros/pom.xml @@ -52,4 +52,42 @@ ${libtype} + + + + + org.apache.maven.plugins + maven-jar-plugin + + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + org.scalatest + scalatest-maven-plugin + + + ${project.parent.basedir}/init-native + + + -Djava.library.path=${project.parent.basedir}/native/${platform}/target \ + -Dlog4j.configuration=file://${project.basedir}/src/test/resources/log4j.properties + + + + + org.scalastyle + scalastyle-maven-plugin + + + + diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index b6ddaafc7ad7..234a8604cb91 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -21,7 +21,6 @@ import scala.annotation.StaticAnnotation import scala.collection.mutable.ListBuffer import scala.language.experimental.macros import scala.reflect.macros.blackbox - import org.apache.mxnet.init.Base._ import org.apache.mxnet.utils.OperatorBuildUtils @@ -29,18 +28,29 @@ private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnota private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addDefs } +private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation { + private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.typeSafeAPIDefs +} + private[mxnet] object SymbolImplMacros { - case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) + case class SymbolArg(argName: String, argType: String, isOptional : Boolean) + case class SymbolFunction(name: String, listOfArgs: List[SymbolArg]) // scalastyle:off havetype def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - impl(c)(false, annottees: _*) + impl(c)(annottees: _*) } - // scalastyle:off havetype + def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { + newAPIImpl(c)(annottees: _*) + } + // scalastyle:on havetype - private val symbolFunctions: Map[String, SymbolFunction] = initSymbolModule() + private val symbolFunctions: List[SymbolFunction] = initSymbolModule() - private def impl(c: blackbox.Context)(addSuper: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { + /** + * Implementation for fixed input API structure + */ + private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ val isContrib: Boolean = c.prefix.tree match { @@ -48,74 +58,106 @@ private[mxnet] object SymbolImplMacros { } val newSymbolFunctions = { - if (isContrib) symbolFunctions.filter(_._1.startsWith("_contrib_")) - else symbolFunctions.filter(!_._1.startsWith("_contrib_")) + if (isContrib) symbolFunctions.filter( + func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) + else symbolFunctions.filter(!_.name.startsWith("_")) } - val AST_TYPE_MAP_STRING_ANY = AppliedTypeTree(Ident(TypeName("Map")), - List(Ident(TypeName("String")), Ident(TypeName("Any")))) - val AST_TYPE_MAP_STRING_STRING = AppliedTypeTree(Ident(TypeName("Map")), - List(Ident(TypeName("String")), Ident(TypeName("String")))) - val AST_TYPE_SYMBOL_VARARG = AppliedTypeTree( - Select( - Select(Ident(termNames.ROOTPKG), TermName("scala")), - TypeName("") - ), - List(Select(Select(Select( - Ident(TermName("org")), TermName("apache")), TermName("mxnet")), TypeName("Symbol"))) - ) - - val functionDefs = newSymbolFunctions map { case (funcName, funcProp) => - val functionScope = { - if (isContrib) Modifiers() - else { - if (funcName.startsWith("_")) Modifiers(Flag.PRIVATE) else Modifiers() - } - } - val newName = { - if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length()) - else funcName + + val functionDefs = newSymbolFunctions map { symbolfunction => + val funcName = symbolfunction.name + val tName = TermName(funcName) + q""" + def $tName(name : String = null, attr : Map[String, String] = null) + (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null) + : org.apache.mxnet.Symbol = { + createSymbolGeneral($funcName,name,attr,args,kwargs) + } + """.asInstanceOf[DefDef] } - // It will generate definition something like, - // def Concat(name: String = null, attr: Map[String, String] = null) - // (args: Symbol*)(kwargs: Map[String, Any] = null) - DefDef(functionScope, TermName(newName), List(), - List( - List( - ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("name"), - Ident(TypeName("String")), Literal(Constant(null))), - ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("attr"), - AST_TYPE_MAP_STRING_STRING, Literal(Constant(null))) - ), - List( - ValDef(Modifiers(), TermName("args"), AST_TYPE_SYMBOL_VARARG, EmptyTree) - ), - List( - ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("kwargs"), - AST_TYPE_MAP_STRING_ANY, Literal(Constant(null))) - ) - ), TypeTree(), - Apply( - Ident(TermName("createSymbolGeneral")), - List( - Literal(Constant(funcName)), - Ident(TermName("name")), - Ident(TermName("attr")), - Ident(TermName("args")), - Ident(TermName("kwargs")) - ) - ) - ) + structGeneration(c)(functionDefs, annottees : _*) + } + + /** + * Implementation for Dynamic typed API Symbol.api. + */ + private def newAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = { + import c.universe._ + + val isContrib: Boolean = c.prefix.tree match { + case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b)) + } + + // TODO: Put Symbol.api.foo --> Stable APIs + // Symbol.contrib.bar--> Contrib APIs + val newSymbolFunctions = { + if (isContrib) symbolFunctions.filter( + func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) + else symbolFunctions.filter(!_.name.startsWith("_")) + } + + val functionDefs = newSymbolFunctions map { symbolfunction => + + // Construct argument field + var argDef = ListBuffer[String]() + // Construct Implementation field + var impl = ListBuffer[String]() + impl += "val map = scala.collection.mutable.Map[String, Any]()" + symbolfunction.listOfArgs.foreach({ symbolarg => + // var is a special word used to define variable in Scala, + // need to changed to something else in order to make it work + val currArgName = symbolarg.argName match { + case "var" => "vari" + case "type" => "typeOf" + case default => symbolarg.argName + } + if (symbolarg.isOptional) { + argDef += s"${currArgName} : Option[${symbolarg.argType}] = None" + } + else { + argDef += s"${currArgName} : ${symbolarg.argType}" + } + var base = "map(\"" + symbolarg.argName + "\") = " + currArgName + if (symbolarg.isOptional) { + base = "if (!" + currArgName + ".isEmpty)" + base + ".get" + } + impl += base + }) + argDef += "name : String = null" + argDef += "attr : Map[String, String] = null" + // scalastyle:off + // TODO: Seq() here allows user to place Symbols rather than normal arguments to run, need to fix if old API deprecated + impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" + // scalastyle:on + // Combine and build the function string + val returnType = "org.apache.mxnet.Symbol" + var finalStr = s"def ${symbolfunction.name}" + finalStr += s" (${argDef.mkString(",")}) : $returnType" + finalStr += s" = {${impl.mkString("\n")}}" + c.parse(finalStr).asInstanceOf[DefDef] } + structGeneration(c)(functionDefs, annottees : _*) + } + /** + * Generate class structure for all function APIs + * @param c + * @param funcDef DefDef type of function definitions + * @param annottees + * @return + */ + private def structGeneration(c: blackbox.Context) + (funcDef : List[c.universe.DefDef], annottees: c.Expr[Any]*) + : c.Expr[Any] = { + import c.universe._ val inputs = annottees.map(_.tree).toList // pattern match on the inputs val modDefs = inputs map { case ClassDef(mods, name, something, template) => val q = template match { case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ functionDefs) + Template(superMaybe, emptyValDef, defs ++ funcDef) case ex => throw new IllegalArgumentException(s"Invalid template: $ex") } @@ -123,7 +165,7 @@ private[mxnet] object SymbolImplMacros { case ModuleDef(mods, name, template) => val q = template match { case Template(superMaybe, emptyValDef, defs) => - Template(superMaybe, emptyValDef, defs ++ functionDefs) + Template(superMaybe, emptyValDef, defs ++ funcDef) case ex => throw new IllegalArgumentException(s"Invalid template: $ex") } @@ -136,20 +178,80 @@ private[mxnet] object SymbolImplMacros { result } + // Convert C++ Types to Scala Types + def typeConversion(in : String, argType : String = "") : String = { + in match { + case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" + case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.Symbol" + case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" + => "Array[org.apache.mxnet.Symbol]" + case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat" + case "int" | "intorNone" | "int(non-negative)" => "Int" + case "long" | "long(non-negative)" => "Long" + case "double" | "doubleorNone" => "Double" + case "string" => "String" + case "boolean" => "Boolean" + case "tupleof" | "tupleof" | "ptr" | "" => "Any" + case default => throw new IllegalArgumentException( + s"Invalid type for args: $default, $argType") + } + } + + + /** + * By default, the argType come from the C++ API is a description more than a single word + * For Example: + * , , + * The three field shown above do not usually come at the same time + * This function used the above format to determine if the argument is + * optional, what is it Scala type and possibly pass in a default value + * @param argType Raw arguement Type description + * @return (Scala_Type, isOptional) + */ + def argumentCleaner(argType : String) : (String, Boolean) = { + val spaceRemoved = argType.replaceAll("\\s+", "") + var commaRemoved : Array[String] = new Array[String](0) + // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} + if (spaceRemoved.charAt(0)== '{') { + val endIdx = spaceRemoved.indexOf('}') + commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") + commaRemoved(0) = "string" + } else { + commaRemoved = spaceRemoved.split(",") + } + // Optional Field + if (commaRemoved.length >= 3) { + // arg: Type, optional, default = Null + require(commaRemoved(1).equals("optional")) + require(commaRemoved(2).startsWith("default=")) + (typeConversion(commaRemoved(0), argType), true) + } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { + val tempType = typeConversion(commaRemoved(0), argType) + val tempOptional = tempType.equals("org.apache.mxnet.Symbol") + (tempType, tempOptional) + } else { + throw new IllegalArgumentException( + s"Unrecognized arg field: $argType, ${commaRemoved.length}") + } + + } + + // List and add all the atomic symbol functions to current module. - private def initSymbolModule(): Map[String, SymbolFunction] = { + private def initSymbolModule(): List[SymbolFunction] = { val opNames = ListBuffer.empty[String] _LIB.mxListAllOpNames(opNames) + // TODO: Add '_linalg_', '_sparse_', '_image_' support opNames.map(opName => { val opHandle = new RefLong _LIB.nnGetOpHandle(opName, opHandle) makeAtomicSymbolFunction(opHandle.value, opName) - }).toMap + }).toList } // Create an atomic symbol function by handle and function name. private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String) - : (String, SymbolFunction) = { + : SymbolFunction = { val name = new RefString val desc = new RefString val keyVarNumArgs = new RefString @@ -174,6 +276,10 @@ private[mxnet] object SymbolImplMacros { println("Symbol function definition:\n" + docStr) } // scalastyle:on println - (aliasName, new SymbolFunction(handle, keyVarNumArgs.value)) + val argList = argNames zip argTypes map { case (argName, argType) => + val typeAndOption = argumentCleaner(argType) + new SymbolArg(argName, typeAndOption._1, typeAndOption._2) + } + new SymbolFunction(aliasName, argList.toList) } } diff --git a/scala-package/macros/src/test/resources/log4j.properties b/scala-package/macros/src/test/resources/log4j.properties new file mode 100644 index 000000000000..d82fd7ea4f3d --- /dev/null +++ b/scala-package/macros/src/test/resources/log4j.properties @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# for development debugging +log4j.rootLogger = debug, stdout + +log4j.appender.stdout = org.apache.log4j.ConsoleAppender +log4j.appender.stdout.Target = System.out +log4j.appender.stdout.layout = org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} [%t] [%c] [%p] - %m%n diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala new file mode 100644 index 000000000000..bc8be7df5fb1 --- /dev/null +++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.slf4j.LoggerFactory + +class MacrosSuite extends FunSuite with BeforeAndAfterAll { + + private val logger = LoggerFactory.getLogger(classOf[MacrosSuite]) + + + test("MacrosSuite-testArgumentCleaner") { + val input = List( + "Symbol, optional, default = Null", + "int, required", + "Shape(tuple), optional, default = []", + "{'csr', 'default', 'row_sparse'}, optional, default = 'csr'", + ", required" + ) + val output = List( + ("org.apache.mxnet.Symbol", true), + ("Int", false), + ("org.apache.mxnet.Shape", true), + ("String", true), + ("Any", false) + ) + + for (idx <- input.indices) { + val result = SymbolImplMacros.argumentCleaner(input(idx)) + assert(result._1 === output(idx)._1 && result._2 === output(idx)._2) + } + } + +}