diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala index 0a5683aa7ab3..fd8269b6ac8f 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala @@ -194,7 +194,8 @@ private[mxnet] class LibInfo { argNames: ListBuffer[String], argTypes: ListBuffer[String], argDescs: ListBuffer[String], - keyVarNumArgs: RefString): Int + keyVarNumArgs: RefString, + returnType: RefString): Int @native def mxSymbolCreateAtomicSymbol(handle: SymbolHandle, paramKeys: Array[String], paramVals: Array[String], diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 416f2d74e828..561ebc96b558 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -162,13 +162,14 @@ object NDArray { val name = new RefString val desc = new RefString val keyVarNumArgs = new RefString + val returnType = new RefString val numArgs = new RefInt val argNames = ListBuffer.empty[String] val argTypes = ListBuffer.empty[String] val argDescs = ListBuffer.empty[String] checkCall(_LIB.mxSymbolGetAtomicSymbolInfo( - handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)) + handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs, returnType)) val arguments = (argTypes zip argNames).filter { case (dtype, _) => !(dtype.startsWith("NDArray") || dtype.startsWith("Symbol") || dtype.startsWith("NDArray-or-Symbol")) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 13f85a731dc4..05abcce18a57 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -29,7 +29,8 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} * WARNING: it is your responsibility to clear this object through dispose(). * */ -class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotDisposed { +class Symbol private(private[mxnet] val handle: SymbolHandle) + extends WarnIfNotDisposed { private val logger: Logger = LoggerFactory.getLogger(classOf[Symbol]) private var disposed = false protected def isDisposed = disposed @@ -822,9 +823,8 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD jsonStr.value } } - @AddSymbolFunctions(false) -object Symbol { +object Symbol extends SymbolBase { private type SymbolCreateNamedFunc = Map[String, Any] => Symbol private val logger = LoggerFactory.getLogger(classOf[Symbol]) private val functions: Map[String, SymbolFunction] = initSymbolModule() @@ -1026,13 +1026,14 @@ object Symbol { val name = new RefString val desc = new RefString val keyVarNumArgs = new RefString + val returnType = new RefString val numArgs = new RefInt val argNames = ListBuffer.empty[String] val argTypes = ListBuffer.empty[String] val argDescs = ListBuffer.empty[String] checkCall(_LIB.mxSymbolGetAtomicSymbolInfo( - handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)) + handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs, returnType)) (aliasName, new SymbolFunction(handle, keyVarNumArgs.value)) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolBase.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolBase.scala new file mode 100644 index 000000000000..30a6a4c0d3f6 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolBase.scala @@ -0,0 +1,20 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.mxnet + +trait SymbolBase {} \ No newline at end of file diff --git a/scala-package/init-native/src/main/native/org_apache_mxnet_init_native_c_api.cc b/scala-package/init-native/src/main/native/org_apache_mxnet_init_native_c_api.cc index b689521bcad1..2374c5ba3314 100644 --- a/scala-package/init-native/src/main/native/org_apache_mxnet_init_native_c_api.cc +++ b/scala-package/init-native/src/main/native/org_apache_mxnet_init_native_c_api.cc @@ -49,8 +49,8 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_init_LibInfo_mxSymbolListAtomicSymb JNIEXPORT jint JNICALL Java_org_apache_mxnet_init_LibInfo_mxSymbolGetAtomicSymbolInfo (JNIEnv *env, jobject obj, jlong symbolPtr, jobject name, jobject desc, jobject numArgs, - jobject argNames, jobject argTypes, jobject argDescs, jobject keyVarNumArgs) { - + jobject argNames, jobject argTypes, jobject argDescs, jobject keyVarNumArgs, + jobject returnType) { const char *cName; const char *cDesc; mx_uint cNumArgs; @@ -58,11 +58,12 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_init_LibInfo_mxSymbolGetAtomicSymbo const char **cArgTypes; const char **cArgDescs; const char *cKeyVarNumArgs; + const char *cReturnType; int ret = MXSymbolGetAtomicSymbolInfo(reinterpret_cast(symbolPtr), &cName, &cDesc, &cNumArgs, &cArgNames, &cArgTypes, &cArgDescs, - &cKeyVarNumArgs); + &cKeyVarNumArgs, &cReturnType); jclass refIntClass = env->FindClass("org/apache/mxnet/init/Base$RefInt"); jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I"); @@ -78,6 +79,7 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_init_LibInfo_mxSymbolGetAtomicSymbo env->SetObjectField(name, valueStr, env->NewStringUTF(cName)); env->SetObjectField(desc, valueStr, env->NewStringUTF(cDesc)); env->SetObjectField(keyVarNumArgs, valueStr, env->NewStringUTF(cKeyVarNumArgs)); + env->SetObjectField(returnType, valueStr, env->NewStringUTF(cReturnType)); env->SetIntField(numArgs, valueInt, static_cast(cNumArgs)); for (size_t i = 0; i < cNumArgs; ++i) { env->CallObjectMethod(argNames, listAppend, env->NewStringUTF(cArgNames[i])); diff --git a/scala-package/init/src/main/scala/org/apache/mxnet/init/LibInfo.scala b/scala-package/init/src/main/scala/org/apache/mxnet/init/LibInfo.scala index 7bd0c701f872..b50a81e12274 100644 --- a/scala-package/init/src/main/scala/org/apache/mxnet/init/LibInfo.scala +++ b/scala-package/init/src/main/scala/org/apache/mxnet/init/LibInfo.scala @@ -30,7 +30,8 @@ class LibInfo { argNames: ListBuffer[String], argTypes: ListBuffer[String], argDescs: ListBuffer[String], - keyVarNumArgs: RefString): Int + keyVarNumArgs: RefString, + returnType: RefString): Int @native def mxListAllOpNames(names: ListBuffer[String]): Int @native def nnGetOpHandle(opName: String, opHandle: RefLong): Int } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index c26d14c12923..1956e6b122ac 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -162,12 +162,13 @@ private[mxnet] object NDArrayMacro { val desc = new RefString val keyVarNumArgs = new RefString val numArgs = new RefInt + val returnType = new RefString val argNames = ListBuffer.empty[String] val argTypes = ListBuffer.empty[String] val argDescs = ListBuffer.empty[String] _LIB.mxSymbolGetAtomicSymbolInfo( - handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) + handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs, returnType) val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs) val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) { s"This function support variable length of positional input (${keyVarNumArgs.value})." diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolBaseMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolBaseMacro.scala new file mode 100644 index 000000000000..92ad9d4082d7 --- /dev/null +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolBaseMacro.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import scala.collection.mutable.{HashMap, ListBuffer} +import org.apache.mxnet.init.Base._ + + +private[mxnet] object SymbolDocMacros { + + case class SymbolFunction(handle: SymbolHandle, paramStr: String) + + def addDefs() : Unit = { + val baseDir = System.getProperty("user.dir") + val targetPath = baseDir + "/core/src/main/scala/org/apache/mxnet/SymbolBase.scala" + SEImpl(targetPath) + } + + def SEImpl(FILE_PATH : String) : Unit = { + var symbolFunctions: List[SymbolFunction] = initSymbolModule() + import java.io._ + val pw = new PrintWriter(new File(FILE_PATH)) + // scalastyle:off + pw.write("/*\n* Licensed to the Apache Software Foundation (ASF) under one or more\n* contributor license agreements. See the NOTICE file distributed with\n* this work for additional information regarding copyright ownership.\n* The ASF licenses this file to You under the Apache License, Version 2.0\n* (the \"License\"); you may not use this file except in compliance with\n* the License. You may obtain a copy of the License at\n*\n* http://www.apache.org/licenses/LICENSE-2.0\n*\n* Unless required by applicable law or agreed to in writing, software\n* distributed under the License is distributed on an \"AS IS\" BASIS,\n* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n* See the License for the specific language governing permissions and\n* limitations under the License.\n*/\n\npackage org.apache.mxnet\n") + // scalastyle:on + pw.write(s"\ntrait SymbolBase {\n\n") + pw.write(s" // scalastyle:off\n") + symbolFunctions = symbolFunctions.distinct + for (ele <- symbolFunctions) { + val temp = ele.paramStr + "\n\n" + pw.write(temp) + } + pw.write(s"\n\n}") + pw.close() + } + + + /* + Code copies from the SymbolMacros Class + */ + private def initSymbolModule(): List[SymbolFunction] = { + var opNames = ListBuffer.empty[String] + _LIB.mxListAllOpNames(opNames) + opNames = opNames.distinct + val result : ListBuffer[SymbolFunction] = ListBuffer[SymbolFunction]() + opNames.foreach(opName => { + val opHandle = new RefLong + // printf(opName) + _LIB.nnGetOpHandle(opName, opHandle) + makeAtomicSymbolFunction(opHandle.value, opName, result) + }) + + result.toList + } + + private def makeAtomicSymbolFunction(handle: SymbolHandle, + aliasName: String, result : ListBuffer[SymbolFunction]) + : Unit = { + val name = new RefString + val desc = new RefString + val keyVarNumArgs = new RefString + val returnType = new RefString + val numArgs = new RefInt + val argNames = ListBuffer.empty[String] + val argTypes = ListBuffer.empty[String] + val argDescs = ListBuffer.empty[String] + + _LIB.mxSymbolGetAtomicSymbolInfo( + handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs, returnType) + + if (name.value.charAt(0) == '_') { + // Internal function + } else { + val paramStr = + traitgen(name.value, desc.value, argNames, argTypes, argDescs, returnType.value) + val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) { + s"This function support variable length of positional input (${keyVarNumArgs.value})." + } else { + "" + } + result += SymbolFunction(handle, paramStr) + } + } + + + def traitgen(functionName : String, + functionDesc : String, + argNames : Seq[String], + argTypes : Seq[String], + argDescs : Seq[String], + returnType : String) : String = { + val desc = functionDesc.split("\n") map { currStr => + s" * $currStr" + } + val params = + (argNames zip argTypes zip argDescs) map { case ((argName, argType), argDesc) => + // val desc = if (argDesc.isEmpty) "" else s"\n$argDesc" + s" * @param $argName\t\t$argDesc" + } + val traitsec = + (argNames zip argTypes) map { case ((argName, argType)) => + val currArgType = CodeClean.cleanUp(argType) + var currArgName = "" + if (argName.equals("var")) { + currArgName = "vari" + } else { + currArgName = argName + } + s"$currArgName : $currArgType" + } + // scalastyle:off + val defaultConfig = s"(name : scala.Predef.String, attr : scala.Predef.Map[scala.Predef.String, scala.Predef.String])(args : org.apache.mxnet.Symbol*)(kwargs : scala.Predef.Map[scala.Predef.String, scala.Any]) : org.apache.mxnet.Symbol" + // s"/**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n* @return $returnType\n*/\ndef $functionName(${traitsec.mkString(", ")}) : Any" + s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n * @return $returnType\n */\n def $functionName$defaultConfig" + // scalastyle:on + } +} + +private[mxnet] object CodeClean { + + + val typeMap : HashMap[String, String] = HashMap( + ("Shape(tuple)", "Shape"), + ("Symbol", "Symbol"), + ("NDArray", "Symbol"), + ("NDArray-or-Symbol", "Symbol"), + // TODO: Add def + ("Symbol[]", "Any"), + ("NDArray[]", "Any"), + ("NDArray-or-Symbol[]", "Any"), + ("int(non-negative)", "Any"), + ("long(non-negative)", "Any"), + ("ShapeorNone", "Option[Shape]"), + ("real_t", "Any"), // MXFloat + ("float", "Any"), + ("intorNone", "Option[Int]"), + ("SymbolorSymbol[]", "Any"), + ("tupleof", "Any"), + // End Missing section + ("int", "Int"), + ("long", "Long"), + ("double", "Double"), + ("string", "String"), + ("boolean", "Boolean") + ) + + + def conversion(in : String, optional : String) : String = { + val out = in match { + // deal with [] + case "Shape" => "new Shape()" + // deal with '6000' => 6000 + case "Int" | "Option[Int]" | "Option[Shape]" => optional.replaceAll("'", "") + // deal with string default + case "String" => optional.replaceAll("'", "\"") + // Deal with Boolean + case "Boolean" => { + if (optional.charAt(0) == '0') { + "false" + } else { + "true" + } + } + // Anything else + case _ => optional + } + + out + } + + def cleanUp(in : String) : String = { + val spaceRemoved = in.replaceAll("\\s+", "") + var commaRemoved : Array[String] = new Array[String](0) + // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} + if (spaceRemoved.charAt(0)== '{') { + val endIdx = spaceRemoved.indexOf('}') + commaRemoved = spaceRemoved.substring(endIdx + 1).split(",") + // commaRemoved(0) = spaceRemoved.substring(0, endIdx+1) + commaRemoved(0) = "string" + } else { + commaRemoved = spaceRemoved.split(",") + } + var typeConv = "" + var optionalField = "" + // println("Try to find key " + commaRemoved(0)) + if (commaRemoved.length < 1) { + printf("Empty Field Generated\n") + } else if (commaRemoved.length == 3) { + + // Something to do with Optional + typeConv = typeMap(commaRemoved(0)) + optionalField = " = " + conversion(typeConv, commaRemoved(2).split("=")(1)) + } else if (commaRemoved.length > 3) { + // TODO: Field over 3, need to rework + typeConv = "Any" + printf("Field Over 3, please reformat %s", in) + } else { + typeConv = typeMap(commaRemoved(0)) + } + // if (!typeMap.contains(commaRemoved(0))) { + // logger.error("First Field not recognized " + commaRemoved(0)) + // } else { + // typeConv = typeMap(commaRemoved(0)) + // } + val out = typeConv + optionalField + out + } + +} \ No newline at end of file diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala index b6ddaafc7ad7..e23c2edcaadc 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala @@ -34,6 +34,7 @@ private[mxnet] object SymbolImplMacros { // scalastyle:off havetype def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { + SymbolDocMacros.addDefs() impl(c)(false, annottees: _*) } // scalastyle:off havetype @@ -154,12 +155,13 @@ private[mxnet] object SymbolImplMacros { val desc = new RefString val keyVarNumArgs = new RefString val numArgs = new RefInt + val returnType = new RefString val argNames = ListBuffer.empty[String] val argTypes = ListBuffer.empty[String] val argDescs = ListBuffer.empty[String] _LIB.mxSymbolGetAtomicSymbolInfo( - handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs) + handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs, returnType) val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs) val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) { s"This function support variable length of positional input (${keyVarNumArgs.value})." diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc index caf7e56cd3bd..ebd1e83dcb71 100644 --- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc +++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc @@ -1120,8 +1120,8 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolListAtomicSymbolCre JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetAtomicSymbolInfo (JNIEnv *env, jobject obj, jlong symbolPtr, jobject name, jobject desc, jobject numArgs, - jobject argNames, jobject argTypes, jobject argDescs, jobject keyVarNumArgs) { - + jobject argNames, jobject argTypes, jobject argDescs, jobject keyVarNumArgs, + jobject returnType) { const char *cName; const char *cDesc; mx_uint cNumArgs; @@ -1129,11 +1129,12 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetAtomicSymbolInfo const char **cArgTypes; const char **cArgDescs; const char *cKeyVarNumArgs; + const char *cReturnType; int ret = MXSymbolGetAtomicSymbolInfo(reinterpret_cast(symbolPtr), &cName, &cDesc, &cNumArgs, &cArgNames, &cArgTypes, &cArgDescs, - &cKeyVarNumArgs); + &cKeyVarNumArgs, &cReturnType); jclass refIntClass = env->FindClass("org/apache/mxnet/Base$RefInt"); jfieldID valueInt = env->GetFieldID(refIntClass, "value", "I"); @@ -1149,6 +1150,7 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolGetAtomicSymbolInfo env->SetObjectField(name, valueStr, env->NewStringUTF(cName)); env->SetObjectField(desc, valueStr, env->NewStringUTF(cDesc)); env->SetObjectField(keyVarNumArgs, valueStr, env->NewStringUTF(cKeyVarNumArgs)); + env->SetObjectField(returnType, valueStr, env->NewStringUTF(cReturnType)); env->SetIntField(numArgs, valueInt, static_cast(cNumArgs)); for (size_t i = 0; i < cNumArgs; ++i) { env->CallObjectMethod(argNames, listAppend, env->NewStringUTF(cArgNames[i]));