This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-357] New Scala API Design (Symbol) #10660

Merged
merged 26 commits into from
May 14, 2018

Member

@lanking520 lanking520 commented Apr 23, 2018

Description

See full design document
@nswamy @yzhliu
This PR adds the new Symbol functions to the Scala API.

Checklist

Essentials

  • Use QuasiQuote to replace original API Implementation (Reduce lines)
  • MakeAtomicFunction change to support String type parameter type in
  • Use a new namespace for the new API (temporarily Symbol.api.function_name)
  • Optional args default to None
  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

@lanking520 lanking520 requested a review from yzhliu as a code owner April 23, 2018 22:43
)
)
)
val newFunctionDefs = newSymbolFunctions map { symbolfunction =>
Member

remove if it is not used currently.

Member Author

It will be implemented in the next commit

Member

I know, but remove this for now.


private def argumentCleaner(argType : String) : (String, Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
Member

looks like you can write

val commaRemoved = if ... else ...

Member Author

That doesn't seem applicable, since I need to change one of the elements in the Array.

@@ -174,6 +183,10 @@ private[mxnet] object SymbolImplMacros {
println("Symbol function definition:\n" + docStr)
}
// scalastyle:on println
(aliasName, new SymbolFunction(handle, keyVarNumArgs.value))
val argList = argNames zip argTypes map { case (argName, argType) =>
val tup = argumentCleaner(argType)
Member

name the return tuple.

val endIdx = spaceRemoved.indexOf('}')
commaRemoved = spaceRemoved.substring(endIdx + 1).split(",")
// commaRemoved(0) = spaceRemoved.substring(0, endIdx+1)
commaRemoved(0) = "string"
Member

could you explain more about the process logic here? I don't quite get the point.

Member Author

The input can be in the format:
e.g: stype : {'csr', 'default', 'row_sparse'}
In that case we need to strip the "{...}" list and set the type to string. This is just part of the data cleaning.

commaRemoved = spaceRemoved.split(",")
}
// Optional Field
if (commaRemoved.length >= 3) {
Member

2 for name and attrs ?

Member Author

@lanking520, Apr 25, 2018

There are currently several different formats for this, usually:
arg : Type
arg: Type, required
arg: Type, optional, default = Null
The logic here is trying to handle all of these cases

Member

We'd better make the pattern clear. For example, add assertions:

if (commaRemoved.length >= 3) {
  // arg: Type, optional, default = Null
  // (whitespace has already been stripped at this point)
  require(commaRemoved(1) == "optional")
  require(commaRemoved(2).startsWith("default="))
}

so that if some other pattern appears one day, we fail immediately and have a chance to fix it before it leaks to the public and causes strange errors.
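
The three formats listed above, plus the brace-enclosed choice list, can be handled by one small fail-fast parser. The following is a minimal sketch for illustration only, not the code merged in this PR; the object name `ArgCleanerSketch` and the error messages are assumptions. It returns the type descriptor together with a flag saying whether the argument is optional.

```scala
object ArgCleanerSketch {
  // Hypothetical re-implementation for illustration; returns (typeDescriptor, isOptional).
  def argumentCleaner(argType: String): (String, Boolean) = {
    val spaceRemoved = argType.replaceAll("\\s+", "")
    require(spaceRemoved.nonEmpty, "empty type description")
    val fields: Array[String] =
      if (spaceRemoved.charAt(0) == '{') {
        // Choice list such as {'csr','default','row_sparse'}: drop the braces
        // and treat the argument as a plain string.
        val endIdx = spaceRemoved.indexOf('}')
        require(endIdx >= 0, "unterminated choice list: " + argType)
        // The text after '}' starts with a comma (or is empty), so the first
        // split element is always the empty string and is dropped.
        "string" +: spaceRemoved.substring(endIdx + 1).split(",").drop(1)
      } else {
        spaceRemoved.split(",")
      }
    fields.length match {
      case 1 =>
        // arg : Type
        (fields(0), false)
      case 2 =>
        // arg : Type, required   (or: arg : Type, optional)
        require(fields(1) == "required" || fields(1) == "optional",
          "unexpected qualifier in: " + argType)
        (fields(0), fields(1) == "optional")
      case _ =>
        // arg : Type, optional, default = Null
        require(fields(1) == "optional", "unexpected qualifier in: " + argType)
        require(fields(2).startsWith("default="), "missing default in: " + argType)
        (fields(0), true)
    }
  }
}
```

With this shape, an unknown pattern such as `int, sometimes, default=1` throws an `IllegalArgumentException` during code generation instead of silently producing a wrong signature.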

val opNames = ListBuffer.empty[String]
_LIB.mxListAllOpNames(opNames)
- opNames.map(opName => {
+ opNames.filter(!_.startsWith("_")).map(opName => {
Member

why filter _ ?

Member Author

The reason for filtering _ is to keep the internal functions from being compiled. The documentation for the internal functions has not been updated for a long time.

Member

Can we have a list of those internal functions? Since this breaks API compatibility, we need to review whether it is safe to remove them.

By the way, it removes _contrib_ as well. You don't mean that, right?

Member Author

Yeah, I didn't mean to remove contrib.
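
The difference between the two predicates discussed in this thread can be seen on a small sample. This is only a sketch: the operator names below are illustrative and are not read from the MXNet library.

```scala
object OpFilterSketch {
  // Illustrative operator names (assumed for this example).
  val opNames: Seq[String] = Seq("Convolution", "_contrib_fft", "_internal_op")

  // Drops every name starting with an underscore, including _contrib_ operators.
  val strict: Seq[String] = opNames.filter(!_.startsWith("_"))

  // Keeps public operators and _contrib_ operators, drops other underscore names.
  val keepContrib: Seq[String] =
    opNames.filter(op => !op.startsWith("_") || op.startsWith("_contrib_"))
}
```

The second predicate still drops every other underscore-prefixed operator, which is the concern raised later in the review.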

Member Author

lanking520 commented Apr 25, 2018

The recent commit includes the new API functions; for now we call them "New". To access them, build the package and call <Function_Name>New.

To merge this PR, an example is required to show that the API functions normally. At least one example will be added.

@lanking520 changed the title from "[MXNET-357][WIP] New Scala API Design" to "[MXNET-357] New Scala API Design (Symbol)" Apr 26, 2018
val funcName = symbolfunction.name
val tName = TermName(funcName)
q"""
@Deprecated
Member

Suggest deprecating them later, once the new API has proven to be stable.

Member Author

@lanking520, May 1, 2018

Will remove it in the next commit

impl += "createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)"
// Combine and build the function string
val returnType = "org.apache.mxnet.Symbol"
var finalStr = s"def ${symbolfunction.name}New"
Member

How about postfix 'Ex' for 'Expand', which is also consistent with those in c_api.h

Member Author

@lanking520, May 1, 2018

Instead of adding a postfix, we decided to expose the API as Symbol.api.function_name

if (spaceRemoved.charAt(0)== '{') {
val endIdx = spaceRemoved.indexOf('}')
commaRemoved = spaceRemoved.substring(endIdx + 1).split(",")
commaRemoved(0) = "string"
Member

if argType = {'csr', 'default', 'row_sparse'}, then we will do typeConversion("string", "{'csr', 'default', 'row_sparse'}") ?

if so, then these two lines are pretty confusing. why not simply commaRemoved = Array("string")

Member Author

Previously I was thinking of adding these to the default field, but found that unnecessary. The reason for not using Array("string") is that we are not sure whether this arg contains the optional field. We need to split on "," to find out.

Member Author

Adding more comments here to avoid misunderstanding

Member

gotcha

Member

@nswamy left a comment

Great job! 💯 This uses advanced features, and Scala macros are really cryptic. I want to encourage you to add lots of comments so that someone else does not have to rack their brain like you had to.

*/
package org.apache.mxnet
@AddNewSymbolFunctions(false)
object NewSymbol {
Member

NewSymbol => SymbolAPI?

@@ -830,6 +830,8 @@ object Symbol {
private val functions: Map[String, SymbolFunction] = initSymbolModule()
private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)

val api = NewSymbol
Member

val api : SymbolBase = SymbolAPI

import org.apache.mxnet.init.Base._
import org.apache.mxnet.utils.OperatorBuildUtils

private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnotation {
private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addDefs
}

private[mxnet] class AddNewSymbolFunctions(isContrib: Boolean) extends StaticAnnotation {
Member

AddNewSymbolFunctions-> AddSymbolAPIs ?

* limitations under the License.
*/
package org.apache.mxnet
@AddNewSymbolFunctions(false)
Member

AddNewSymbolFunctions->GenerateSymbolAPIs ?

}
// scalastyle:off havetype
def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
impl(c)(false, true, annottees: _*)
Member

Please create a new method for generating the new APIs

Member Author

I think we should keep using impl to avoid duplicated code, since the new API implementation is just a small component in

argDef += "attr : Map[String, String] = null"
// Construct Implementation field
var impl = ListBuffer[String]()
impl += "val map = scala.collection.mutable.Map[String, Any]()"
Member

Can you not do this in the above loop?
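
For readers unfamiliar with what the generated strings amount to, here is a rough sketch of the shape one generated function might take once expanded. Everything in it is an assumption made for illustration: `Sym` stands in for `org.apache.mxnet.Symbol`, `createSymbolGeneral` is a stub that only renders its inputs as text, and the `Activation` signature is hypothetical rather than copied from the generated API.

```scala
object GeneratedFunctionSketch {
  // Stand-in for org.apache.mxnet.Symbol (assumption for this example).
  final case class Sym(id: String)

  // Stub in place of Symbol.createSymbolGeneral: it only renders its inputs as text.
  def createSymbolGeneral(op: String, name: String, attr: Map[String, String],
                          symbols: Seq[Sym], kwargs: Map[String, Any]): String = {
    val rendered = kwargs.toSeq.sortBy(_._1).map { case (k, v) => k + "=" + v }
    op + "(" + rendered.mkString(", ") + ")"
  }

  // Hypothetical generated function: required args are plain parameters,
  // optional args are Option values that default to None.
  def Activation(act_type: String, data: Option[Sym] = None,
                 name: String = null, attr: Map[String, String] = null): String = {
    val map = scala.collection.mutable.Map[String, Any]()
    map("act_type") = act_type
    // Only arguments that were actually supplied are forwarded.
    data.foreach { d => map("data") = d }
    createSymbolGeneral("Activation", name, attr, Seq(), map.toMap)
  }
}
```

Calling `GeneratedFunctionSketch.Activation("relu")` forwards only `act_type`, while passing `Some(Sym("x"))` for `data` adds that entry to the map as well.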

private def typeConversion(in : String, argType : String = "") : String = {
in match {
case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.Symbol"
Member

Should NDArray also return Symbol?

Member Author

Yes, they share the same Function name

case "double" => "Double"
case "string" => "String"
case "boolean" => "Boolean"
case "tupleof<float>" => "Any"
Member

Why Any? Any will remove type checks
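
To make the trade-off concrete, the sketch below maps a cleaned type descriptor to the Scala type used in the generated signature. The explicit cases are the ones visible in the diff above; the catch-all `Any` fallback and the `argDefinition` helper, which wraps optional arguments in `Option[...] = None`, are assumptions added for illustration.

```scala
object TypeConversionSketch {
  // Explicit cases taken from the diff context; the fallback is an assumption.
  def typeConversion(in: String): String = in match {
    case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
    case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.Symbol"
    case "double" => "Double"
    case "string" => "String"
    case "boolean" => "Boolean"
    case _ => "Any" // untyped fallback: the compiler can no longer check callers
  }

  // Hypothetical helper: renders one parameter of a generated signature.
  def argDefinition(name: String, descriptor: String, optional: Boolean): String = {
    val scalaType = typeConversion(descriptor)
    if (optional) name + " : Option[" + scalaType + "] = None"
    else name + " : " + scalaType
  }
}
```

Every descriptor that falls through to `Any` yields a parameter that accepts any value at compile time, so mistakes only surface at runtime.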

val opNames = ListBuffer.empty[String]
_LIB.mxListAllOpNames(opNames)
- opNames.map(opName => {
+ opNames.filter(op => !op.startsWith("_") || op.startsWith("_contrib_")).map(opName => {
Member

We are filtering out all operators that start with _. Unfortunately, the sparse, linear algebra, and other operators also start with _. Look at https://github.com/apache/incubator-mxnet/blob/4fb5241b47c8147690fd6408b55cb694d544656e/python/mxnet/base.py#L455. We need to revisit this.

Member Author

Sure, let's add a TODO here to make sure we add more functions.




private def argumentCleaner(argType : String) : (String, Boolean) = {
Member

Please add comments on what this method is trying to achieve. Also, please describe what the API structure looks like when extracted from C++ and what it looks like once you are done cleaning it up.

@lanking520
Member Author

This turned out to be fun. I removed the underscore filter and added support for generating the underscore functions. Please review the code parser section and see whether there are ways to convert some of the "Any"s to actual Scala types.

@lanking520
Member Author

In the latest commit, unit tests were added for the Scala API. To help init-native/base find the correct path for loading the library, an environment variable called BASEDIR is used so that users can customize the base directory. The unit tests focus specifically on the argument cleaner, to make sure it correctly handles the different argument type descriptions.

var baseDir = System.getProperty("user.dir") + "/init-native"
if (System.getenv().containsKey("BASEDIR")) {
baseDir = sys.env("BASEDIR")
}
Member

What are you expecting the BASEDIR variable to be?

I think you should have an else:
else { baseDir = System.getProperty("user.dir") }
baseDir = baseDir + "/init-native"

Member Author

BASEDIR is set in the pom file (as an environment variable) and determines the current base directory.

Member

for system environment var, better to specify it is mxnet-related, e.g., MXNET_SCALA_MACRO_BASEDIR or something like it.

Member Author

Done

}
// scalastyle:off havetype
def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
newAPIImpl(c)(annottees: _*)
Member

can we have a more meaningful name, for example, typeSafeAPI?


// Construct argument field
var argDef = ListBuffer[String]()
symbolfunction.listOfArgs.foreach(symbolarg => {
Member

can we combine this with the next foreach, i.e., line 120?

// Optional Field
if (commaRemoved.length >= 3) {
// arg: Type, optional, default = Null
require(commaRemoved(1).equals("optional"))
Member

Just a reminder: prefer == in Scala. For reference types, Scala's == is null-safe and delegates to equals, so it compares values like Java's equals; reference identity (Java's ==) is eq in Scala. For a non-null String, equals gives the same result, so it is fine here.

Member Author

Understood, will note this.
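
A quick way to see the three comparison operators side by side is the following standalone sketch (the object and value names are chosen only for this example):

```scala
object EqualitySketch {
  // Two distinct String instances with the same contents.
  val a: String = new String("optional")
  val b: String = new String("optional")

  val valueEqual: Boolean = a == b        // value comparison, delegates to equals
  val equalsEqual: Boolean = a.equals(b)  // same result for non-null receivers
  val sameReference: Boolean = a eq b     // reference identity, like Java's ==

  // == is null-safe: a null left-hand side does not throw.
  val missing: String = null
  val nullSafe: Boolean = missing == "optional"
}
```

Here `valueEqual` and `equalsEqual` are true, `sameReference` is false, and `nullSafe` is false without raising a `NullPointerException`, whereas `missing.equals("optional")` would throw.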


override def afterAll(): Unit = {

}
Member

can be removed?

*/
package org.apache.mxnet
@AddSymbolAPIs(false)
object SymbolAPI {
Member

comment a bit for this placeholder.

@@ -37,7 +37,11 @@ object Base {

@throws(classOf[UnsatisfiedLinkError])
private def tryLoadInitLibrary(): Unit = {
- val baseDir = System.getProperty("user.dir") + "/init-native"
+ // val baseDir = System.getProperty("user.dir") + "/init-native"
Member

remove this line.

Member

@nswamy left a comment

Minor comments; please address them and then we can merge.

@@ -37,7 +37,10 @@ object Base {

@throws(classOf[UnsatisfiedLinkError])
private def tryLoadInitLibrary(): Unit = {
- val baseDir = System.getProperty("user.dir") + "/init-native"
+ var baseDir = System.getProperty("user.dir") + "/init-native"
+ if (System.getenv().containsKey("MXNET_SCALA_MACRO_BASEDIR")) {
Member

Can we change this to MXNET_BASEDIR and append the macro directory location relative to it? I don't think users should need to understand what macros are or locate the macro directory.

<artifactId>scalatest-maven-plugin</artifactId>
<configuration>
<environmentVariables>
<MXNET_SCALA_MACRO_BASEDIR>${project.parent.basedir}/init-native</MXNET_SCALA_MACRO_BASEDIR>
Member

MXNET_BASEDIR

- if (isContrib) symbolFunctions.filter(_._1.startsWith("_contrib_"))
- else symbolFunctions.filter(!_._1.startsWith("_contrib_"))
+ if (isContrib) symbolFunctions.filter(
+ func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
Member

Thinking about how this flag would be used, one possible use case is to pull all the contrib APIs into a separate object to make the distinction more explicit, as in:
Symbol.api.foo --> Stable APIs
Symbol.contrib.bar --> Contrib APIs

We can do this as a separate PR. Thoughts?
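
A rough sketch of how such a split could work at the name level follows. It is not part of this PR: the object, the `publicName` and `split` helpers, and the sample operator names are all hypothetical.

```scala
object ContribNamespaceSketch {
  private val ContribPrefix = "_contrib_"

  // Name a function would be exposed under: the prefix is stripped for contrib operators.
  def publicName(funcName: String, isContrib: Boolean): String =
    if (isContrib && funcName.startsWith(ContribPrefix)) funcName.substring(ContribPrefix.length)
    else funcName

  // Partition operator names into (contrib, everything else).
  def split(ops: Seq[String]): (Seq[String], Seq[String]) =
    ops.partition(_.startsWith(ContribPrefix))
}
```

Under this scheme an operator registered as `_contrib_fft` would be exposed as `fft` in the contrib object, while `Convolution` would stay in the stable one.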

if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length())
else funcName

val functionDefs = newSymbolFunctions map { symbolfunction =>
Member

I like this code way better than the old, but I think we are introducing a new way of generating code for both the old and new APIs. If there is a bug we haven't caught, it will break both. I think it's a risk.


val newSymbolFunctions = {
if (isContrib) symbolFunctions.filter(
func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
Member

same as above, thoughts? can be done as a separate PR

argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
// scalastyle:off
impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)"
Member

Can we add a note about Seq() for when we deprecate the old APIs?

lanking520 and others added 2 commits May 11, 2018 17:15
Add relative path to MXNET_BASEDIR
- val baseDir = System.getProperty("user.dir") + "/init-native"
+ var baseDir = System.getProperty("user.dir") + "/init-native"
+ if (System.getenv().containsKey("MXNET_BASEDIR")) {
+ baseDir = sys.env("MXNET_BASEDIR") + "/scala-package/init-native"
Member

just FYI, I updated this line from
baseDir = sys.env("MXNET_BASEDIR")
to
baseDir = sys.env("MXNET_BASEDIR") + "/scala-package/init-native"
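
The resolution rule that this thread converges on can be written as a small pure function, which also makes it easy to unit test. This is a sketch, not the merged code: the `BaseDirSketch` object and the `resolveBaseDir` name are assumptions, while the variable name `MXNET_BASEDIR` and the two path suffixes come from the diff above.

```scala
object BaseDirSketch {
  // Picks the init-native directory: MXNET_BASEDIR wins when set,
  // otherwise fall back to the current working directory.
  def resolveBaseDir(env: Map[String, String], userDir: String): String =
    env.get("MXNET_BASEDIR") match {
      case Some(root) => root + "/scala-package/init-native"
      case None => userDir + "/init-native"
    }
}
```

In real use it would be called as `BaseDirSketch.resolveBaseDir(sys.env, System.getProperty("user.dir"))`; passing the environment in as a parameter avoids the mutable `var` and lets tests supply their own map.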

@nswamy nswamy merged commit b011ecc into apache:master May 14, 2018
@lanking520 lanking520 deleted the scala-macros-impl branch May 18, 2018 22:47
jinhuang415 pushed a commit to jinhuang415/incubator-mxnet that referenced this pull request May 29, 2018
* Simplfied current Macros impl to Quasiquote
* Change the Symbol Function Field, add SymbolArg
* Fix the Macros problem, disable the hidden function _
* Add Implementation for New API
* Add examples and comments
* Add _contrib_ support
* New namespace for Symbol API
* Change names and add comments
* add TODOs and name changes
* Add relative path to MXNET_BASEDIR
* Update Base.scala
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
3 participants