-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-357] New Scala API Design (Symbol) #10660
Conversation
) | ||
) | ||
) | ||
val newFunctionDefs = newSymbolFunctions map { symbolfunction => |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove if it is not used currently.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will be implemented in the next commit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know, but remove this for now.
|
||
private def argumentCleaner(argType : String) : (String, Boolean) = { | ||
val spaceRemoved = argType.replaceAll("\\s+", "") | ||
var commaRemoved : Array[String] = new Array[String](0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like you can write
val commaRemoved = if ... else ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seemed not applicable as I need to change one of the element in the Array
@@ -174,6 +183,10 @@ private[mxnet] object SymbolImplMacros { | |||
println("Symbol function definition:\n" + docStr) | |||
} | |||
// scalastyle:on println | |||
(aliasName, new SymbolFunction(handle, keyVarNumArgs.value)) | |||
val argList = argNames zip argTypes map { case (argName, argType) => | |||
val tup = argumentCleaner(argType) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
name the return tuple.
val endIdx = spaceRemoved.indexOf('}') | ||
commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") | ||
// commaRemoved(0) = spaceRemoved.substring(0, endIdx+1) | ||
commaRemoved(0) = "string" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you explain more about the process logic here? I don't quite get the point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The input can be in the format:
e.g: stype : {'csr', 'default', 'row_sparse'}
In which case we need to get rid of the "{}" and set the type as string. This part is just a part of the data cleaning
commaRemoved = spaceRemoved.split(",") | ||
} | ||
// Optional Field | ||
if (commaRemoved.length >= 3) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 for name
and attrs
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are current different format for this, usually:
arg : Type
arg: Type, required
arg: Type, optional, default = Null
The logic here is trying to handle all of these cases
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'd better make the pattern clear. For example, do assertion
if (commaRemoved.length >= 3) {
// arg: Type, optional, default = Null
require(commaRemoved[1] == "required")
require(commaRemoved[2].startsWith("default = ")
so that if some other patterns appear one day, we can fail immediately and have chance to fix it before it leaks to public and causes strange error.
val opNames = ListBuffer.empty[String] | ||
_LIB.mxListAllOpNames(opNames) | ||
opNames.map(opName => { | ||
opNames.filter(!_.startsWith("_")).map(opName => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why filter _
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason for filter _ is to remove the internal function to be compiled. The Documentation for Internal function has not been updated for a long time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we have a list of those internal functions? Since it is a broken for api compatibility, we need to review whether it is safe to remove.
btw, it removes _contrib_
as well, you don't mean that, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, not mean't to remove contrib
The recent commit including the new API functions, currently we call them "New". To access them, build the package and call <Function_Name>New and you will be able to access. In order to merge this PR, Example is required to prove the API will functioned normally. At least one example will be added in this case |
val funcName = symbolfunction.name | ||
val tName = TermName(funcName) | ||
q""" | ||
@Deprecated |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggest to deprecate them later, when the new api is proved to be stable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will remove it in the next commit
impl += "createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" | ||
// Combine and build the function string | ||
val returnType = "org.apache.mxnet.Symbol" | ||
var finalStr = s"def ${symbolfunction.name}New" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about postfix 'Ex' for 'Expand', which is also consistent with those in c_api.h
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of adding postfix, we decided to call the API as Symbol.api.Function name
if (spaceRemoved.charAt(0)== '{') { | ||
val endIdx = spaceRemoved.indexOf('}') | ||
commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") | ||
commaRemoved(0) = "string" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if argType = {'csr', 'default', 'row_sparse'}
, then we will do typeConversion("string", "{'csr', 'default', 'row_sparse'}")
?
if so, then these two lines are pretty confusing. why not simply commaRemoved = Array("string")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Previously I was thinking adding these into the default field, but just found them unnecessary. The reason not doing Array("string") is we are not sure if this arg contains optional field. We need to do a split "," to make sure of that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding more comments here to avoid misunderstanding
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gotcha
commaRemoved = spaceRemoved.split(",") | ||
} | ||
// Optional Field | ||
if (commaRemoved.length >= 3) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'd better make the pattern clear. For example, do assertion
if (commaRemoved.length >= 3) {
// arg: Type, optional, default = Null
require(commaRemoved[1] == "required")
require(commaRemoved[2].startsWith("default = ")
so that if some other patterns appear one day, we can fail immediately and have chance to fix it before it leaks to public and causes strange error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great Job! 💯 . This is using advanced features and scala macros are really cryptic, I want to encourage you to add lots of comments so someone else do not have to break their head like you had to.
*/ | ||
package org.apache.mxnet | ||
@AddNewSymbolFunctions(false) | ||
object NewSymbol { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NewSymbol=> SymbolAPI
?
@@ -830,6 +830,8 @@ object Symbol { | |||
private val functions: Map[String, SymbolFunction] = initSymbolModule() | |||
private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3) | |||
|
|||
val api = NewSymbol |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
val api : SymbolBase = SymbolAPI
import org.apache.mxnet.init.Base._ | ||
import org.apache.mxnet.utils.OperatorBuildUtils | ||
|
||
private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnotation { | ||
private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addDefs | ||
} | ||
|
||
private[mxnet] class AddNewSymbolFunctions(isContrib: Boolean) extends StaticAnnotation { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AddNewSymbolFunctions-> AddSymbolAPIs ?
* limitations under the License. | ||
*/ | ||
package org.apache.mxnet | ||
@AddNewSymbolFunctions(false) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AddNewSymbolFunctions->GenerateSymbolAPIs ?
} | ||
// scalastyle:off havetype | ||
def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { | ||
impl(c)(false, true, annottees: _*) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please create a new method for generating the new APIs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should keep using impl to avoid duplicated code, since the new API implementation is just a small component in
argDef += "attr : Map[String, String] = null" | ||
// Construct Implementation field | ||
var impl = ListBuffer[String]() | ||
impl += "val map = scala.collection.mutable.Map[String, Any]()" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you not do this in the above loop
private def typeConversion(in : String, argType : String = "") : String = { | ||
in match { | ||
case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape" | ||
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.Symbol" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should NDArray also return Symbol?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, they share the same Function name
case "double" => "Double" | ||
case "string" => "String" | ||
case "boolean" => "Boolean" | ||
case "tupleof<float>" => "Any" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why Any? Any will remove type checks
val opNames = ListBuffer.empty[String] | ||
_LIB.mxListAllOpNames(opNames) | ||
opNames.map(opName => { | ||
opNames.filter(op => !op.startsWith("_") || op.startsWith("_contrib_")).map(opName => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are filtering all operators that start with _
unfortunately the sparse, linear algebra and other operators also start with _
. Look at this https://github.com/apache/incubator-mxnet/blob/4fb5241b47c8147690fd6408b55cb694d544656e/python/mxnet/base.py#L455 We need to revisit this again
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, let's add a TODO in here to make sure adding more functions.
|
||
|
||
|
||
private def argumentCleaner(argType : String) : (String, Boolean) = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add comments of what this method is trying to achieve. Also, please add how the structure of the API looks like when extracted from C++ and at the end when you are done cleaning up.
Now is turn out to be fun. I removed the underscore filter and adding support for underscore function generation. Please kindly review the code parser section and find if there are possible ways to convert some "Any"s to actual Scala types |
In the latest commit, unit test were added for testing Scala API. In order to help init-native/base find the correct path to import the library, a environment variable called BASEDIR were used to help user customize the base directory they need. Unit test specifically focus on the Argument cleaner to make sure it can correctly handle different argument type descriptions |
var baseDir = System.getProperty("user.dir") + "/init-native" | ||
if (System.getenv().containsKey("BASEDIR")) { | ||
baseDir = sys.env("BASEDIR") | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what are you expecting BASEDIR variable to be?
I think you should have a else
else { baseDir System.getProperty("user.dir") } baseDir = baseDir + "/init-native"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BASEDIR locate in the pom file (set as an environment variable) which determine the current base directory.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for system environment var, better to specify it is mxnet-related, e.g., MXNET_SCALA_MACRO_BASEDIR or something like it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
} | ||
// scalastyle:off havetype | ||
def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { | ||
newAPIImpl(c)(annottees: _*) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we have a more meaningful name, for example, typeSafeAPI
?
|
||
// Construct argument field | ||
var argDef = ListBuffer[String]() | ||
symbolfunction.listOfArgs.foreach(symbolarg => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we combine this with the next foreach
, i.e., line 120?
if (spaceRemoved.charAt(0)== '{') { | ||
val endIdx = spaceRemoved.indexOf('}') | ||
commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") | ||
commaRemoved(0) = "string" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gotcha
// Optional Field | ||
if (commaRemoved.length >= 3) { | ||
// arg: Type, optional, default = Null | ||
require(commaRemoved(1).equals("optional")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just remind, better to use ==
in scala. ==
behaves the same as equals
in Java, and equals
in Scala behaves the same as ==
in Java... sigh... Since String is immutable, equals
here is fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Understood, will note this.
|
||
override def afterAll(): Unit = { | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can be removed?
*/ | ||
package org.apache.mxnet | ||
@AddSymbolAPIs(false) | ||
object SymbolAPI { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comment a bit for this placeholder.
@@ -37,7 +37,11 @@ object Base { | |||
|
|||
@throws(classOf[UnsatisfiedLinkError]) | |||
private def tryLoadInitLibrary(): Unit = { | |||
val baseDir = System.getProperty("user.dir") + "/init-native" | |||
// val baseDir = System.getProperty("user.dir") + "/init-native" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove this line.
var baseDir = System.getProperty("user.dir") + "/init-native" | ||
if (System.getenv().containsKey("BASEDIR")) { | ||
baseDir = sys.env("BASEDIR") | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for system environment var, better to specify it is mxnet-related, e.g., MXNET_SCALA_MACRO_BASEDIR or something like it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor comments, please address them. we can merge it.
@@ -37,7 +37,10 @@ object Base { | |||
|
|||
@throws(classOf[UnsatisfiedLinkError]) | |||
private def tryLoadInitLibrary(): Unit = { | |||
val baseDir = System.getProperty("user.dir") + "/init-native" | |||
var baseDir = System.getProperty("user.dir") + "/init-native" | |||
if (System.getenv().containsKey("MXNET_SCALA_MACRO_BASEDIR")) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we change this to MXNET_BASEDIR and we can append the macro dir location relative to it. I don't think the user needs to understand what macros are or find out macro dir.
scala-package/macros/pom.xml
Outdated
<artifactId>scalatest-maven-plugin</artifactId> | ||
<configuration> | ||
<environmentVariables> | ||
<MXNET_SCALA_MACRO_BASEDIR>${project.parent.basedir}/init-native</MXNET_SCALA_MACRO_BASEDIR> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MXNET_BASEDIR
if (isContrib) symbolFunctions.filter(_._1.startsWith("_contrib_")) | ||
else symbolFunctions.filter(!_._1.startsWith("_contrib_")) | ||
if (isContrib) symbolFunctions.filter( | ||
func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thinking of how this flag would be used, one possible use-case is we pull all the contrib APIs into a separate Object to make it more explicit as in ?
Symbol.api.foo
--> Stable APIs
Symbol.contrib.bar
--> Contrib APIs
We can do this as a separate PR, thoughts?
if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length()) | ||
else funcName | ||
|
||
val functionDefs = newSymbolFunctions map { symbolfunction => |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this code way better than the old, but I think we introducing new way of generating code for old and new APIs, if there is any bug that we haven't caught it will break both old and new APIs. I think its a risk.
|
||
val newSymbolFunctions = { | ||
if (isContrib) symbolFunctions.filter( | ||
func => func.name.startsWith("_contrib_") || !func.name.startsWith("_")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above, thoughts? can be done as a separate PR
argDef += "name : String = null" | ||
argDef += "attr : Map[String, String] = null" | ||
// scalastyle:off | ||
impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + symbolfunction.name + "\", name, attr, Seq(), map.toMap)" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we make a note to Seq() when we deprecate old APIs
Add relative path to MXNET_BASEDIR
val baseDir = System.getProperty("user.dir") + "/init-native" | ||
var baseDir = System.getProperty("user.dir") + "/init-native" | ||
if (System.getenv().containsKey("MXNET_BASEDIR")) { | ||
baseDir = sys.env("MXNET_BASEDIR") + "/scala-package/init-native" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just FYI, I updated this line from
baseDir = sys.env("MXNET_BASEDIR")
to
baseDir = sys.env("MXNET_BASEDIR") + "/scala-package/init-native"
* Simplfied current Macros impl to Quasiquote * Change the Symbol Function Field, add SymbolArg * Fix the Macros problem, disable the hidden function _ * Add Implementation for New API * Add examples and comments * Add _contrib_ support * New namespace for Symbol API * Change names and add comments * add TODOs and name changes * Add relative path to MXNET_BASEDIR * Update Base.scala
* Simplfied current Macros impl to Quasiquote * Change the Symbol Function Field, add SymbolArg * Fix the Macros problem, disable the hidden function _ * Add Implementation for New API * Add examples and comments * Add _contrib_ support * New namespace for Symbol API * Change names and add comments * add TODOs and name changes * Add relative path to MXNET_BASEDIR * Update Base.scala
* Simplfied current Macros impl to Quasiquote * Change the Symbol Function Field, add SymbolArg * Fix the Macros problem, disable the hidden function _ * Add Implementation for New API * Add examples and comments * Add _contrib_ support * New namespace for Symbol API * Change names and add comments * add TODOs and name changes * Add relative path to MXNET_BASEDIR * Update Base.scala
Description
See full design document
@nswamy @yzhliu
This PR is the Addition for new Symbol functions of Scala API
Checklist
Essentials