Skip to content

Commit e3f92e9

Browse files
mdesprieeJose Luis Contreras
authored andcommitted
[MXNET-918] Introduce Random module / Refact code generation (apache#13038)
* refactor code gen * remove xxxAPIMacroBase (overkill) * CI errors / scala-style * PR review comments
1 parent 48c9b79 commit e3f92e9

4 files changed

Lines changed: 411 additions & 493 deletions

File tree

scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala

Lines changed: 105 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -17,178 +17,154 @@
1717

1818
package org.apache.mxnet
1919

20-
import org.apache.mxnet.init.Base._
21-
import org.apache.mxnet.utils.CToScalaUtils
2220
import java.io._
2321
import java.security.MessageDigest
2422

25-
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
23+
import scala.collection.mutable.ListBuffer
2624

2725
/**
2826
* This object will generate the Scala documentation of the new Scala API
2927
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
3028
* The code will be executed during Macros stage and file live in Core stage
3129
*/
32-
private[mxnet] object APIDocGenerator{
33-
case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean)
34-
case class absClassFunction(name : String, desc : String,
35-
listOfArgs: List[absClassArg], returnType : String)
30+
private[mxnet] object APIDocGenerator extends GeneratorBase {
3631

37-
38-
def main(args: Array[String]) : Unit = {
32+
def main(args: Array[String]): Unit = {
3933
val FILE_PATH = args(0)
4034
val hashCollector = ListBuffer[String]()
41-
hashCollector += absClassGen(FILE_PATH, true)
42-
hashCollector += absClassGen(FILE_PATH, false)
35+
hashCollector += typeSafeClassGen(FILE_PATH, true)
36+
hashCollector += typeSafeClassGen(FILE_PATH, false)
4337
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
4438
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
4539
val finalHash = hashCollector.mkString("\n")
4640
}
4741

48-
def MD5Generator(input : String) : String = {
42+
def MD5Generator(input: String): String = {
4943
val md = MessageDigest.getInstance("MD5")
5044
md.update(input.getBytes("UTF-8"))
5145
val digest = md.digest()
5246
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
5347
}
5448

55-
def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
56-
// scalastyle:off
57-
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
58-
// Defines Operators that should not generated
59-
val notGenerated = Set("Custom")
60-
// TODO: Add Filter to the same location in case of refactor
61-
val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
62-
.filterNot(ele => notGenerated.contains(ele.name))
63-
.map(absClassFunction => {
64-
val scalaDoc = generateAPIDocFromBackend(absClassFunction)
65-
val defBody = generateAPISignature(absClassFunction, isSymbol)
66-
s"$scalaDoc\n$defBody"
67-
})
68-
val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
69-
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
70-
val scalaStyle = "// scalastyle:off"
71-
val packageDef = "package org.apache.mxnet"
72-
val imports = "import org.apache.mxnet.annotation.Experimental"
73-
val absClassDef = s"abstract class $packageName"
74-
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
75-
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
76-
pw.write(finalStr)
77-
pw.close()
78-
MD5Generator(finalStr)
49+
def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
50+
val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false)
51+
.map { func =>
52+
val scalaDoc = generateAPIDocFromBackend(func)
53+
val decl = generateAPISignature(func, isSymbol)
54+
s"$scalaDoc\n$decl"
55+
}
56+
57+
writeFile(
58+
FILE_PATH,
59+
if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
60+
"package org.apache.mxnet",
61+
generated)
7962
}
8063

81-
def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
82-
// scalastyle:off
83-
val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
84-
val absFuncs = absClassFunctions.map(absClassFunction => {
85-
val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
86-
if (isSymbol) {
87-
val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol"
88-
s"$scalaDoc\n$defBody"
89-
} else {
90-
val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
91-
val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
92-
s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
64+
def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
65+
val absFuncs = functionsToGenerate(isSymbol, isContrib = false)
66+
.map { func =>
67+
val scalaDoc = generateAPIDocFromBackend(func, false)
68+
if (isSymbol) {
69+
s"""$scalaDoc
70+
|def ${func.name}(name : String = null, attr : Map[String, String] = null)
71+
| (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null):
72+
| org.apache.mxnet.Symbol
73+
""".stripMargin
74+
} else {
75+
s"""$scalaDoc
76+
|def ${func.name}(kwargs: Map[String, Any] = null)
77+
| (args: Any*): org.apache.mxnet.NDArrayFuncReturn
78+
|
79+
|$scalaDoc
80+
|def ${func.name}(args: Any*): org.apache.mxnet.NDArrayFuncReturn
81+
""".stripMargin
82+
}
9383
}
94-
})
95-
val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
96-
val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n"
97-
val scalaStyle = "// scalastyle:off"
98-
val packageDef = "package org.apache.mxnet"
99-
val imports = "import org.apache.mxnet.annotation.Experimental"
100-
val absClassDef = s"abstract class $packageName"
101-
val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}"
102-
import java.io._
103-
val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
104-
pw.write(finalStr)
105-
pw.close()
106-
MD5Generator(finalStr)
84+
85+
writeFile(
86+
FILE_PATH,
87+
if (isSymbol) "SymbolBase" else "NDArrayBase",
88+
"package org.apache.mxnet",
89+
absFuncs)
10790
}
10891

109-
// Generate ScalaDoc type
110-
def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = {
111-
val desc = ArrayBuffer[String]()
112-
desc += " * <pre>"
113-
func.desc.split("\n").foreach({ currStr =>
114-
desc += s" * $currStr"
115-
})
116-
desc += " * </pre>"
117-
val params = func.listOfArgs.map({ absClassArg =>
118-
val currArgName = absClassArg.argName match {
119-
case "var" => "vari"
120-
case "type" => "typeOf"
121-
case _ => absClassArg.argName
122-
}
123-
s" * @param $currArgName\t\t${absClassArg.argDesc}"
124-
})
92+
def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = {
93+
val desc = func.desc.split("\n")
94+
.mkString(" * <pre>\n", "\n * ", " * </pre>\n")
95+
96+
val params = func.listOfArgs.map { absClassArg =>
97+
s" * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}"
98+
}
99+
125100
val returnType = s" * @return ${func.returnType}"
101+
126102
if (withParam) {
127-
s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */"
103+
s""" /**
104+
|$desc
105+
|${params.mkString("\n")}
106+
|$returnType
107+
| */""".stripMargin
128108
} else {
129-
s" /**\n${desc.mkString("\n")}\n$returnType\n */"
109+
s""" /**
110+
|$desc
111+
|$returnType
112+
| */""".stripMargin
130113
}
131114
}
132115

133-
def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
134-
var argDef = ListBuffer[String]()
135-
func.listOfArgs.foreach(absClassArg => {
136-
val currArgName = absClassArg.argName match {
137-
case "var" => "vari"
138-
case "type" => "typeOf"
139-
case _ => absClassArg.argName
140-
}
141-
if (absClassArg.isOptional) {
142-
argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
143-
}
144-
else {
145-
argDef += s"$currArgName : ${absClassArg.argType}"
146-
}
147-
})
148-
var returnType = func.returnType
116+
def generateAPISignature(func: Func, isSymbol: Boolean): String = {
117+
val argDef = ListBuffer[String]()
118+
119+
argDef ++= typedFunctionCommonArgDef(func)
120+
149121
if (isSymbol) {
150122
argDef += "name : String = null"
151123
argDef += "attr : Map[String, String] = null"
152124
} else {
153125
argDef += "out : Option[NDArray] = None"
154-
returnType = "org.apache.mxnet.NDArrayFuncReturn"
155126
}
156-
val experimentalTag = "@Experimental"
157-
s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"
158-
}
159127

128+
val returnType = func.returnType
160129

161-
// List and add all the atomic symbol functions to current module.
162-
private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = {
163-
val opNames = ListBuffer.empty[String]
164-
val returnType = if (isSymbol) "Symbol" else "NDArray"
165-
_LIB.mxListAllOpNames(opNames)
166-
// TODO: Add '_linalg_', '_sparse_', '_image_' support
167-
// TODO: Add Filter to the same location in case of refactor
168-
opNames.map(opName => {
169-
val opHandle = new RefLong
170-
_LIB.nnGetOpHandle(opName, opHandle)
171-
makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType)
172-
}).toList.filterNot(_.name.startsWith("_"))
130+
s"""@Experimental
131+
|def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
173132
}
174133

175-
// Create an atomic symbol function by handle and function name.
176-
private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String)
177-
: absClassFunction = {
178-
val name = new RefString
179-
val desc = new RefString
180-
val keyVarNumArgs = new RefString
181-
val numArgs = new RefInt
182-
val argNames = ListBuffer.empty[String]
183-
val argTypes = ListBuffer.empty[String]
184-
val argDescs = ListBuffer.empty[String]
185-
186-
_LIB.mxSymbolGetAtomicSymbolInfo(
187-
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
188-
val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
189-
val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType)
190-
new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
191-
}
192-
new absClassFunction(aliasName, desc.value, argList.toList, returnType)
134+
def writeFile(FILE_PATH: String, className: String, packageDef: String,
135+
absFuncs: Seq[String]): String = {
136+
137+
val finalStr =
138+
s"""/*
139+
|* Licensed to the Apache Software Foundation (ASF) under one or more
140+
|* contributor license agreements. See the NOTICE file distributed with
141+
|* this work for additional information regarding copyright ownership.
142+
|* The ASF licenses this file to You under the Apache License, Version 2.0
143+
|* (the "License"); you may not use this file except in compliance with
144+
|* the License. You may obtain a copy of the License at
145+
|*
146+
|* http://www.apache.org/licenses/LICENSE-2.0
147+
|*
148+
|* Unless required by applicable law or agreed to in writing, software
149+
|* distributed under the License is distributed on an "AS IS" BASIS,
150+
|* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
151+
|* See the License for the specific language governing permissions and
152+
|* limitations under the License.
153+
|*/
154+
|
155+
|$packageDef
156+
|
157+
|import org.apache.mxnet.annotation.Experimental
158+
|
159+
|// scalastyle:off
160+
|abstract class $className {
161+
|${absFuncs.mkString("\n")}
162+
|}""".stripMargin
163+
164+
val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))
165+
pw.write(finalStr)
166+
pw.close()
167+
MD5Generator(finalStr)
193168
}
169+
194170
}

0 commit comments

Comments
 (0)