From 8c7d1f887966b119b9d210e840f3f07335e0369d Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 25 Jul 2018 10:32:24 -0700 Subject: [PATCH] add one more class to convert Strings to DTypes --- .../core/src/main/scala/org/apache/mxnet/DType.scala | 9 +++++++++ .../core/src/main/scala/org/apache/mxnet/IO.scala | 5 ++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala index 4458a7c7aeb8..b015bd2169b7 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala @@ -35,4 +35,13 @@ object DType extends Enumeration { case DType.Unknown => 0 } } + private[mxnet] def getType(dtypeStr: String): DType = { + dtypeStr match { + case "UInt8" => DType.UInt8 + case "Int32" => DType.Int32 + case "Float16" => DType.Float16 + case "Float32" => DType.Float32 + case "Float64" => DType.Float64 + } + } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 6d90ece4be18..56cd59a9c24a 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -25,7 +25,6 @@ import org.slf4j.LoggerFactory import scala.annotation.varargs import scala.collection.immutable.ListMap import scala.collection.mutable.ListBuffer -import scala.reflect.runtime.universe._ /** * IO iterators for loading training & validation data */ @@ -112,8 +111,8 @@ object IO { val labelDType = params.getOrElse("labelDType", "Int32") new MXDataIter(out.value, dataName, labelName, dataLayout = dataLayout, labelLayout = labelLayout, - dataDType = q"DType ${TermName(dataDType)}".asInstanceOf[DType], - labelDType = q"DType ${TermName(labelDType)}".asInstanceOf[DType]) + dataDType = DType.getType(dataDType), + labelDType = DType.getType(labelDType)) } // Convert data into canonical form.