diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 216e886180ca..93e839933c7a 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -643,7 +643,7 @@ unittest_ubuntu_cpu_scala() {
unittest_ubuntu_gpu_scala() {
set -ex
- make scalapkg USE_OPENCV=1 USE_BLAS=openblas USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda USE_CUDNN=1 USE_DIST_KVSTORE=1
+ make scalapkg USE_OPENCV=1 USE_BLAS=openblas USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda USE_CUDNN=1 USE_DIST_KVSTORE=1 SCALA_ON_GPU=1
make scalatest USE_OPENCV=1 USE_BLAS=openblas USE_CUDA=1 USE_CUDA_PATH=/usr/local/cuda USE_CUDNN=1 SCALA_TEST_ON_GPU=1 USE_DIST_KVSTORE=1
}
diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index 361bfab5d611..3b1b051f60b1 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -65,6 +65,26 @@
org.scalastyle
scalastyle-maven-plugin
+
+ org.scalastyle
+ scalastyle-maven-plugin
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ 3.3.2
+
+
+
+
+ package
+ attach-javadocs
+
+ doc-jar
+
+
+
+
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 49f4d35136f8..c2de6ea43f2c 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -28,10 +28,11 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.ref.WeakReference
/**
- * NDArray API of mxnet
- */
+ * NDArray Object extends from NDArrayBase for abstract function signatures
+ * Main code will be generated during compile time through Macros
+ */
@AddNDArrayFunctions(false)
-object NDArray {
+object NDArray extends NDArrayBase {
implicit def getFirstResult(ret: NDArrayFuncReturn): NDArray = ret(0)
private val logger = LoggerFactory.getLogger(classOf[NDArray])
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index a17fe57dde65..194d3681523f 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -822,8 +822,12 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
}
}
+/**
+ * Symbol Object extends from SymbolBase for abstract function signatures
+ * Main code will be generated during compile time through Macros
+ */
@AddSymbolFunctions(false)
-object Symbol {
+object Symbol extends SymbolBase {
private type SymbolCreateNamedFunc = Map[String, Any] => Symbol
private val logger = LoggerFactory.getLogger(classOf[Symbol])
private val functions: Map[String, SymbolFunction] = initSymbolModule()
diff --git a/scala-package/infer/pom.xml b/scala-package/infer/pom.xml
index 13d3cc1387e0..208d19ee9ce8 100644
--- a/scala-package/infer/pom.xml
+++ b/scala-package/infer/pom.xml
@@ -65,6 +65,22 @@
org.scalastyle
scalastyle-maven-plugin
+
+ net.alchim31.maven
+ scala-maven-plugin
+ 3.3.2
+
+
+
+
+ package
+ attach-javadocs
+
+ doc-jar
+
+
+
+
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index 3bbc7fd6a90b..9a8ec645f272 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -19,8 +19,11 @@ package org.apache.mxnet
import org.apache.mxnet.init.Base._
import org.apache.mxnet.utils.CToScalaUtils
+import java.io._
+import java.security.MessageDigest
import scala.collection.mutable.ListBuffer
+import scala.io.Source
/**
* This object will generate the Scala documentation of the new Scala API
@@ -35,15 +38,25 @@ private[mxnet] object APIDocGenerator{
def main(args: Array[String]) : Unit = {
val FILE_PATH = args(0)
- absClassGen(FILE_PATH, true)
- absClassGen(FILE_PATH, false)
+ val hashCollector = ListBuffer[String]()
+ hashCollector += absClassGen(FILE_PATH, true)
+ hashCollector += absClassGen(FILE_PATH, false)
+ hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
+ hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
+ val finalHash = hashCollector.mkString("\n")
}
- def absClassGen(FILE_PATH : String, isSymbol : Boolean) : Unit = {
+ def MD5Generator(input : String) : String = {
+ val md = MessageDigest.getInstance("MD5")
+ md.update(input.getBytes("UTF-8"))
+ val digest = md.digest()
+ org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
+ }
+
+ def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
// scalastyle:off
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
- // TODO: Add Filter to the same location in case of refactor
- val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")).map(absClassFunction => {
+ val absFuncs = absClassFunctions.map(absClassFunction => {
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
val defBody = generateAPISignature(absClassFunction, isSymbol)
s"$scalaDoc\n$defBody"
@@ -55,16 +68,44 @@ private[mxnet] object APIDocGenerator{
val imports = "import org.apache.mxnet.annotation.Experimental"
val absClassDef = s"abstract class $packageName"
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
+ val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
+ pw.write(finalStr)
+ pw.close()
+ MD5Generator(finalStr)
+ }
+
+ def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
+ // scalastyle:off
+ val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
+ val absFuncs = absClassFunctions.map(absClassFunction => {
+ val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
+ if (isSymbol) {
+ val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol"
+ s"$scalaDoc\n$defBody"
+ } else {
+ val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
+ val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
+ s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
+ }
+ })
+ val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
+ val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
+ val scalaStyle = "// scalastyle:off"
+ val packageDef = "package org.apache.mxnet"
+ val imports = "import org.apache.mxnet.annotation.Experimental"
+ val absClassDef = s"abstract class $packageName"
+ val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
import java.io._
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
pw.write(finalStr)
pw.close()
+ MD5Generator(finalStr)
}
// Generate ScalaDoc type
- def generateAPIDocFromBackend(func : absClassFunction) : String = {
+ def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = {
val desc = func.desc.split("\n").map({ currStr =>
- s" * $currStr"
+ s" * $currStr
"
})
val params = func.listOfArgs.map({ absClassArg =>
val currArgName = absClassArg.argName match {
@@ -75,7 +116,11 @@ private[mxnet] object APIDocGenerator{
s" * @param $currArgName\t\t${absClassArg.argDesc}"
})
val returnType = s" * @return ${func.returnType}"
- s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */"
+ if (withParam) {
+ s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */"
+ } else {
+ s" /**\n${desc.mkString("\n")}\n$returnType\n */"
+ }
}
def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
@@ -112,11 +157,12 @@ private[mxnet] object APIDocGenerator{
val returnType = if (isSymbol) "Symbol" else "NDArray"
_LIB.mxListAllOpNames(opNames)
// TODO: Add '_linalg_', '_sparse_', '_image_' support
+ // TODO: Add Filter to the same location in case of refactor
opNames.map(opName => {
val opHandle = new RefLong
_LIB.nnGetOpHandle(opName, opHandle)
makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
- }).toList
+ }).toList.filterNot(_.name.startsWith("_"))
}
// Create an atomic symbol function by handle and function name.
@@ -136,7 +182,7 @@ private[mxnet] object APIDocGenerator{
val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})"
val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
- val typeAndOption = CToScalaUtils.argumentCleaner(argType, returnType)
+ val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType)
new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
}
new absClassFunction(aliasName, desc.value, argList.toList, returnType)
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index 082c64a609c3..644bc5c4489d 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -225,7 +225,8 @@ private[mxnet] object NDArrayMacro {
}
// scalastyle:on println
val argList = argNames zip argTypes map { case (argName, argType) =>
- val typeAndOption = CToScalaUtils.argumentCleaner(argType, "org.apache.mxnet.NDArray")
+ val typeAndOption =
+ CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.NDArray")
new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
}
new NDArrayFunction(aliasName, argList.toList)
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
index 81430c2ab263..3e790ef4126b 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
@@ -230,7 +230,8 @@ private[mxnet] object SymbolImplMacros {
}
// scalastyle:on println
val argList = argNames zip argTypes map { case (argName, argType) =>
- val typeAndOption = CToScalaUtils.argumentCleaner(argType, "org.apache.mxnet.Symbol")
+ val typeAndOption =
+ CToScalaUtils.argumentCleaner(argName, argType, "org.apache.mxnet.Symbol")
new SymbolArg(argName, typeAndOption._1, typeAndOption._2)
}
new SymbolFunction(aliasName, argList.toList)
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index ca50a741012b..b07e6f97eee5 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -21,7 +21,8 @@ private[mxnet] object CToScalaUtils {
// Convert C++ Types to Scala Types
- def typeConversion(in : String, argType : String = "", returnType : String) : String = {
+ def typeConversion(in : String, argType : String = "",
+ argName : String, returnType : String) : String = {
in match {
case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
@@ -35,7 +36,7 @@ private[mxnet] object CToScalaUtils {
case "boolean" | "booleanorNone" => "Boolean"
case "tupleof" | "tupleof" | "tupleof<>" | "ptr" | "" => "Any"
case default => throw new IllegalArgumentException(
- s"Invalid type for args: $default, $argType")
+ s"Invalid type for args: $default\nString argType: $argType\nargName: $argName")
}
}
@@ -47,10 +48,12 @@ private[mxnet] object CToScalaUtils {
* The three field shown above do not usually come at the same time
* This function used the above format to determine if the argument is
* optional, what is it Scala type and possibly pass in a default value
+ * @param argName The name of the argument
* @param argType Raw arguement Type description
* @return (Scala_Type, isOptional)
*/
- def argumentCleaner(argType : String, returnType : String) : (String, Boolean) = {
+ def argumentCleaner(argName: String,
+ argType : String, returnType : String) : (String, Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
// Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
@@ -66,9 +69,9 @@ private[mxnet] object CToScalaUtils {
// arg: Type, optional, default = Null
require(commaRemoved(1).equals("optional"))
require(commaRemoved(2).startsWith("default="))
- (typeConversion(commaRemoved(0), argType, returnType), true)
+ (typeConversion(commaRemoved(0), argType, argName, returnType), true)
} else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
- val tempType = typeConversion(commaRemoved(0), argType, returnType)
+ val tempType = typeConversion(commaRemoved(0), argType, argName, returnType)
val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
(tempType, tempOptional)
} else {
diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
index 5883a00c3315..c3a7c58c1afc 100644
--- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
+++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
@@ -43,7 +43,7 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll {
)
for (idx <- input.indices) {
- val result = CToScalaUtils.argumentCleaner(input(idx), "org.apache.mxnet.Symbol")
+ val result = CToScalaUtils.argumentCleaner("Sample", input(idx), "org.apache.mxnet.Symbol")
assert(result._1 === output(idx)._1 && result._2 === output(idx)._2)
}
}
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index cd5dba85dfd5..043aaae5e9e3 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -340,6 +340,22 @@
+
+ net.alchim31.maven
+ scala-maven-plugin
+ 3.3.2
+
+
+
+
+ package
+ attach-javadocs
+
+ doc-jar
+
+
+
+