Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ private[mxnet] class LibInfo {
argNames: ListBuffer[String],
argTypes: ListBuffer[String],
argDescs: ListBuffer[String],
keyVarNumArgs: RefString): Int
keyVarNumArgs: RefString,
returnType: RefString): Int
@native def mxSymbolCreateAtomicSymbol(handle: SymbolHandle,
paramKeys: Array[String],
paramVals: Array[String],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,14 @@ object NDArray {
val name = new RefString
val desc = new RefString
val keyVarNumArgs = new RefString
val returnType = new RefString
val numArgs = new RefInt
val argNames = ListBuffer.empty[String]
val argTypes = ListBuffer.empty[String]
val argDescs = ListBuffer.empty[String]

checkCall(_LIB.mxSymbolGetAtomicSymbolInfo(
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs))
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs, returnType))
val arguments = (argTypes zip argNames).filter { case (dtype, _) =>
!(dtype.startsWith("NDArray") || dtype.startsWith("Symbol")
|| dtype.startsWith("NDArray-or-Symbol"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
* WARNING: it is your responsibility to clear this object through dispose().
* </b>
*/
class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotDisposed {
class Symbol private(private[mxnet] val handle: SymbolHandle)
extends WarnIfNotDisposed {
private val logger: Logger = LoggerFactory.getLogger(classOf[Symbol])
private var disposed = false
protected def isDisposed = disposed
Expand Down Expand Up @@ -822,9 +823,8 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
jsonStr.value
}
}

@AddSymbolFunctions(false)
object Symbol {
object Symbol extends SymbolBase {
private type SymbolCreateNamedFunc = Map[String, Any] => Symbol
private val logger = LoggerFactory.getLogger(classOf[Symbol])
private val functions: Map[String, SymbolFunction] = initSymbolModule()
Expand Down Expand Up @@ -1026,13 +1026,14 @@ object Symbol {
val name = new RefString
val desc = new RefString
val keyVarNumArgs = new RefString
val returnType = new RefString
val numArgs = new RefInt
val argNames = ListBuffer.empty[String]
val argTypes = ListBuffer.empty[String]
val argDescs = ListBuffer.empty[String]

checkCall(_LIB.mxSymbolGetAtomicSymbolInfo(
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs))
handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs, returnType))
(aliasName, new SymbolFunction(handle, keyVarNumArgs.value))
}

Expand Down
Loading