|
17 | 17 |
|
18 | 18 | package org.apache.mxnet |
19 | 19 |
|
20 | | -import org.apache.mxnet.init.Base._ |
21 | | -import org.apache.mxnet.utils.CToScalaUtils |
22 | 20 | import java.io._ |
23 | 21 | import java.security.MessageDigest |
24 | 22 |
|
25 | | -import scala.collection.mutable.{ArrayBuffer, ListBuffer} |
| 23 | +import scala.collection.mutable.ListBuffer |
26 | 24 |
|
27 | 25 | /** |
28 | 26 | * This object will generate the Scala documentation of the new Scala API |
29 | 27 | * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala |
30 | 28 | * The code will be executed during Macros stage and file live in Core stage |
31 | 29 | */ |
32 | | -private[mxnet] object APIDocGenerator{ |
33 | | - case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean) |
34 | | - case class absClassFunction(name : String, desc : String, |
35 | | - listOfArgs: List[absClassArg], returnType : String) |
| 30 | +private[mxnet] object APIDocGenerator extends GeneratorBase { |
36 | 31 |
|
37 | | - |
38 | | - def main(args: Array[String]) : Unit = { |
| 32 | + def main(args: Array[String]): Unit = { |
39 | 33 | val FILE_PATH = args(0) |
40 | 34 | val hashCollector = ListBuffer[String]() |
41 | | - hashCollector += absClassGen(FILE_PATH, true) |
42 | | - hashCollector += absClassGen(FILE_PATH, false) |
| 35 | + hashCollector += typeSafeClassGen(FILE_PATH, true) |
| 36 | + hashCollector += typeSafeClassGen(FILE_PATH, false) |
43 | 37 | hashCollector += nonTypeSafeClassGen(FILE_PATH, true) |
44 | 38 | hashCollector += nonTypeSafeClassGen(FILE_PATH, false) |
45 | 39 | val finalHash = hashCollector.mkString("\n") |
46 | 40 | } |
47 | 41 |
|
48 | | - def MD5Generator(input : String) : String = { |
| 42 | + def MD5Generator(input: String): String = { |
49 | 43 | val md = MessageDigest.getInstance("MD5") |
50 | 44 | md.update(input.getBytes("UTF-8")) |
51 | 45 | val digest = md.digest() |
52 | 46 | org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest) |
53 | 47 | } |
54 | 48 |
|
55 | | - def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = { |
56 | | - // scalastyle:off |
57 | | - val absClassFunctions = getSymbolNDArrayMethods(isSymbol) |
58 | | - // Defines Operators that should not generated |
59 | | - val notGenerated = Set("Custom") |
60 | | - // TODO: Add Filter to the same location in case of refactor |
61 | | - val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_")) |
62 | | - .filterNot(ele => notGenerated.contains(ele.name)) |
63 | | - .map(absClassFunction => { |
64 | | - val scalaDoc = generateAPIDocFromBackend(absClassFunction) |
65 | | - val defBody = generateAPISignature(absClassFunction, isSymbol) |
66 | | - s"$scalaDoc\n$defBody" |
67 | | - }) |
68 | | - val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase" |
69 | | - val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n" |
70 | | - val scalaStyle = "// scalastyle:off" |
71 | | - val packageDef = "package org.apache.mxnet" |
72 | | - val imports = "import org.apache.mxnet.annotation.Experimental" |
73 | | - val absClassDef = s"abstract class $packageName" |
74 | | - val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}" |
75 | | - val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala")) |
76 | | - pw.write(finalStr) |
77 | | - pw.close() |
78 | | - MD5Generator(finalStr) |
| 49 | + def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = { |
| 50 | + val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false) |
| 51 | + .map { func => |
| 52 | + val scalaDoc = generateAPIDocFromBackend(func) |
| 53 | + val decl = generateAPISignature(func, isSymbol) |
| 54 | + s"$scalaDoc\n$decl" |
| 55 | + } |
| 56 | + |
| 57 | + writeFile( |
| 58 | + FILE_PATH, |
| 59 | + if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase", |
| 60 | + "package org.apache.mxnet", |
| 61 | + generated) |
79 | 62 | } |
80 | 63 |
|
81 | | - def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = { |
82 | | - // scalastyle:off |
83 | | - val absClassFunctions = getSymbolNDArrayMethods(isSymbol) |
84 | | - val absFuncs = absClassFunctions.map(absClassFunction => { |
85 | | - val scalaDoc = generateAPIDocFromBackend(absClassFunction, false) |
86 | | - if (isSymbol) { |
87 | | - val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol" |
88 | | - s"$scalaDoc\n$defBody" |
89 | | - } else { |
90 | | - val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn" |
91 | | - val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn" |
92 | | - s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody" |
| 64 | + def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = { |
| 65 | + val absFuncs = functionsToGenerate(isSymbol, isContrib = false) |
| 66 | + .map { func => |
| 67 | + val scalaDoc = generateAPIDocFromBackend(func, false) |
| 68 | + if (isSymbol) { |
| 69 | + s"""$scalaDoc |
| 70 | + |def ${func.name}(name : String = null, attr : Map[String, String] = null) |
| 71 | + | (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): |
| 72 | + | org.apache.mxnet.Symbol |
| 73 | + """.stripMargin |
| 74 | + } else { |
| 75 | + s"""$scalaDoc |
| 76 | + |def ${func.name}(kwargs: Map[String, Any] = null) |
| 77 | + | (args: Any*): org.apache.mxnet.NDArrayFuncReturn |
| 78 | + | |
| 79 | + |$scalaDoc |
| 80 | + |def ${func.name}(args: Any*): org.apache.mxnet.NDArrayFuncReturn |
| 81 | + """.stripMargin |
| 82 | + } |
93 | 83 | } |
94 | | - }) |
95 | | - val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase" |
96 | | - val apacheLicence = "/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n" |
97 | | - val scalaStyle = "// scalastyle:off" |
98 | | - val packageDef = "package org.apache.mxnet" |
99 | | - val imports = "import org.apache.mxnet.annotation.Experimental" |
100 | | - val absClassDef = s"abstract class $packageName" |
101 | | - val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef {\n${absFuncs.mkString("\n")}\n}" |
102 | | - import java.io._ |
103 | | - val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala")) |
104 | | - pw.write(finalStr) |
105 | | - pw.close() |
106 | | - MD5Generator(finalStr) |
| 84 | + |
| 85 | + writeFile( |
| 86 | + FILE_PATH, |
| 87 | + if (isSymbol) "SymbolBase" else "NDArrayBase", |
| 88 | + "package org.apache.mxnet", |
| 89 | + absFuncs) |
107 | 90 | } |
108 | 91 |
|
109 | | - // Generate ScalaDoc type |
110 | | - def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = { |
111 | | - val desc = ArrayBuffer[String]() |
112 | | - desc += " * <pre>" |
113 | | - func.desc.split("\n").foreach({ currStr => |
114 | | - desc += s" * $currStr" |
115 | | - }) |
116 | | - desc += " * </pre>" |
117 | | - val params = func.listOfArgs.map({ absClassArg => |
118 | | - val currArgName = absClassArg.argName match { |
119 | | - case "var" => "vari" |
120 | | - case "type" => "typeOf" |
121 | | - case _ => absClassArg.argName |
122 | | - } |
123 | | - s" * @param $currArgName\t\t${absClassArg.argDesc}" |
124 | | - }) |
| 92 | + def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = { |
| 93 | + val desc = func.desc.split("\n") |
| 94 | + .mkString(" * <pre>\n", "\n * ", " * </pre>\n") |
| 95 | + |
| 96 | + val params = func.listOfArgs.map { absClassArg => |
| 97 | + s" * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}" |
| 98 | + } |
| 99 | + |
125 | 100 | val returnType = s" * @return ${func.returnType}" |
| 101 | + |
126 | 102 | if (withParam) { |
127 | | - s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n */" |
| 103 | + s""" /** |
| 104 | + |$desc |
| 105 | + |${params.mkString("\n")} |
| 106 | + |$returnType |
| 107 | + | */""".stripMargin |
128 | 108 | } else { |
129 | | - s" /**\n${desc.mkString("\n")}\n$returnType\n */" |
| 109 | + s""" /** |
| 110 | + |$desc |
| 111 | + |$returnType |
| 112 | + | */""".stripMargin |
130 | 113 | } |
131 | 114 | } |
132 | 115 |
|
133 | | - def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = { |
134 | | - var argDef = ListBuffer[String]() |
135 | | - func.listOfArgs.foreach(absClassArg => { |
136 | | - val currArgName = absClassArg.argName match { |
137 | | - case "var" => "vari" |
138 | | - case "type" => "typeOf" |
139 | | - case _ => absClassArg.argName |
140 | | - } |
141 | | - if (absClassArg.isOptional) { |
142 | | - argDef += s"$currArgName : Option[${absClassArg.argType}] = None" |
143 | | - } |
144 | | - else { |
145 | | - argDef += s"$currArgName : ${absClassArg.argType}" |
146 | | - } |
147 | | - }) |
148 | | - var returnType = func.returnType |
| 116 | + def generateAPISignature(func: Func, isSymbol: Boolean): String = { |
| 117 | + val argDef = ListBuffer[String]() |
| 118 | + |
| 119 | + argDef ++= typedFunctionCommonArgDef(func) |
| 120 | + |
149 | 121 | if (isSymbol) { |
150 | 122 | argDef += "name : String = null" |
151 | 123 | argDef += "attr : Map[String, String] = null" |
152 | 124 | } else { |
153 | 125 | argDef += "out : Option[NDArray] = None" |
154 | | - returnType = "org.apache.mxnet.NDArrayFuncReturn" |
155 | 126 | } |
156 | | - val experimentalTag = "@Experimental" |
157 | | - s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType" |
158 | | - } |
159 | 127 |
|
| 128 | + val returnType = func.returnType |
160 | 129 |
|
161 | | - // List and add all the atomic symbol functions to current module. |
162 | | - private def getSymbolNDArrayMethods(isSymbol : Boolean): List[absClassFunction] = { |
163 | | - val opNames = ListBuffer.empty[String] |
164 | | - val returnType = if (isSymbol) "Symbol" else "NDArray" |
165 | | - _LIB.mxListAllOpNames(opNames) |
166 | | - // TODO: Add '_linalg_', '_sparse_', '_image_' support |
167 | | - // TODO: Add Filter to the same location in case of refactor |
168 | | - opNames.map(opName => { |
169 | | - val opHandle = new RefLong |
170 | | - _LIB.nnGetOpHandle(opName, opHandle) |
171 | | - makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + returnType) |
172 | | - }).toList.filterNot(_.name.startsWith("_")) |
| 130 | + s"""@Experimental |
| 131 | + |def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin |
173 | 132 | } |
174 | 133 |
|
175 | | - // Create an atomic symbol function by handle and function name. |
176 | | - private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String, returnType : String) |
177 | | - : absClassFunction = { |
178 | | - val name = new RefString |
179 | | - val desc = new RefString |
180 | | - val keyVarNumArgs = new RefString |
181 | | - val numArgs = new RefInt |
182 | | - val argNames = ListBuffer.empty[String] |
183 | | - val argTypes = ListBuffer.empty[String] |
184 | | - val argDescs = ListBuffer.empty[String] |
185 | | - |
186 | | - _LIB.mxSymbolGetAtomicSymbolInfo( |
187 | | - handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) |
188 | | - val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) => |
189 | | - val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType) |
190 | | - new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2) |
191 | | - } |
192 | | - new absClassFunction(aliasName, desc.value, argList.toList, returnType) |
| 134 | + def writeFile(FILE_PATH: String, className: String, packageDef: String, |
| 135 | + absFuncs: Seq[String]): String = { |
| 136 | + |
| 137 | + val finalStr = |
| 138 | + s"""/* |
| 139 | + |* Licensed to the Apache Software Foundation (ASF) under one or more |
| 140 | + |* contributor license agreements. See the NOTICE file distributed with |
| 141 | + |* this work for additional information regarding copyright ownership. |
| 142 | + |* The ASF licenses this file to You under the Apache License, Version 2.0 |
| 143 | + |* (the "License"); you may not use this file except in compliance with |
| 144 | + |* the License. You may obtain a copy of the License at |
| 145 | + |* |
| 146 | + |* http://www.apache.org/licenses/LICENSE-2.0 |
| 147 | + |* |
| 148 | + |* Unless required by applicable law or agreed to in writing, software |
| 149 | + |* distributed under the License is distributed on an "AS IS" BASIS, |
| 150 | + |* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 151 | + |* See the License for the specific language governing permissions and |
| 152 | + |* limitations under the License. |
| 153 | + |*/ |
| 154 | + | |
| 155 | + |$packageDef |
| 156 | + | |
| 157 | + |import org.apache.mxnet.annotation.Experimental |
| 158 | + | |
| 159 | + |// scalastyle:off |
| 160 | + |abstract class $className { |
| 161 | + |${absFuncs.mkString("\n")} |
| 162 | + |}""".stripMargin |
| 163 | + |
| 164 | + val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala")) |
| 165 | + pw.write(finalStr) |
| 166 | + pw.close() |
| 167 | + MD5Generator(finalStr) |
193 | 168 | } |
| 169 | + |
194 | 170 | } |
0 commit comments