Commit (forked from apache/mxnet)
Merge pull request apache#13 from terrytangyuan/terry
Random and Initializer
Showing 3 changed files with 211 additions and 1 deletion.
scala-package/core/src/main/scala/ml/dmlc/mxnet/Initializer.scala (133 additions, 0 deletions)
package ml.dmlc.mxnet

import ml.dmlc.mxnet.NDArray.{array, zeros, ones}

/**
 * Base class for Initializer.
 *
 * @author Yuan Tang
 */
abstract class Initializer {

  /**
   * Initialize an ndarray according to its name pattern.
   *
   * @param name name of the corresponding ndarray
   * @param arr ndarray to be initialized
   */
  def apply(name: String, arr: NDArray): Unit = {
    if (name.startsWith("upsampling")) {
      _initBilinear(name, arr)
    } else if (name.endsWith("bias")) {
      _initBias(name, arr)
    } else if (name.endsWith("gamma")) {
      _initGamma(name, arr)
    } else if (name.endsWith("beta")) {
      _initBeta(name, arr)
    } else if (name.endsWith("weight")) {
      _initWeight(name, arr)
    } else if (name.endsWith("moving_mean")) {
      _initZero(name, arr)
    } else if (name.endsWith("moving_var")) {
      _initZero(name, arr)
    } else if (name.endsWith("moving_avg")) {
      _initZero(name, arr)
    } else {
      throw new IllegalArgumentException(s"Unknown initialization pattern for $name.")
    }
  }

  def _initBilinear(name: String, arr: NDArray): Unit = {
    val weight = Array.fill[Float](arr.size)(0.0f)
    val shape = arr.shape
    val f = shape(3) / 2.0f
    val c = (2 * f - 1 - f % 2) / (2.0f * f)

    // `until`, not `to`: iterating 0 to arr.size would run one index past the end
    (0 until arr.size).foreach { i =>
      val x = i % shape(3)
      val y = (i / shape(3)) % shape(2)
      weight(i) = (1 - math.abs(x / f - c)) * (1 - math.abs(y / f - c))
    }

    arr.set(array(weight))
  }

  def _initZero(name: String, arr: NDArray): Unit = {
    arr.set(0f)
  }

  def _initBias(name: String, arr: NDArray): Unit = {
    arr.set(0f)
  }

  def _initGamma(name: String, arr: NDArray): Unit = {
    arr.set(1f)
  }

  def _initBeta(name: String, arr: NDArray): Unit = {
    arr.set(0f)
  }

  def _initWeight(name: String, arr: NDArray): Unit
}

/**
 * Initialize the weight with uniform [-scale, scale]
 *
 * @param scale The scale of the uniform distribution
 */
class Uniform(protected val scale: Float = 0.07f) extends Initializer {
  override def _initWeight(name: String, arr: NDArray): Unit = {
    Random.uniform(-scale, scale, out = arr)
  }
}

/**
 * Initialize the weight with normal(0, sigma)
 *
 * @param sigma Standard deviation of the Gaussian distribution.
 */
class Normal(protected val sigma: Float = 0.01f) extends Initializer {
  override def _initWeight(name: String, arr: NDArray): Unit = {
    Random.normal(0, sigma, out = arr)
  }
}
/**
 * Initialize the weight with Xavier or a similar initialization scheme.
 *
 * @param rndType Options are: "gaussian" or "uniform"
 * @param factorType Options are: "avg", "in", "out"
 * @param magnitude Scale of the random number range
 */
class Xavier(protected val rndType: String = "uniform",
             protected val factorType: String = "avg",
             protected val magnitude: Int = 3) extends Initializer {

  override def _initWeight(name: String, arr: NDArray): Unit = {
    val shape = arr.shape
    val fanIn = shape.slice(1, shape.length).product
    val fanOut = shape(0)

    // Compute the factor in Float: integer division here would truncate the scale
    val factor: Float = factorType match {
      case "avg" => (fanIn + fanOut) / 2.0f
      case "in" => fanIn.toFloat
      case "out" => fanOut.toFloat
      case _ => throw new IllegalArgumentException("Incorrect factor type")
    }
    val scale = math.sqrt(magnitude / factor).toFloat

    // Accept "gaussian" here so the code agrees with the rndType doc above
    rndType match {
      case "uniform" => Random.uniform(-scale, scale, out = arr)
      case "gaussian" => Random.normal(0, scale, out = arr)
      case _ => throw new IllegalArgumentException("Unknown random type")
    }
  }
}
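To make the name-based dispatch and the Xavier scale concrete, here is a minimal usage sketch. It is not part of this commit: `InitializerExample` is hypothetical, and it assumes `NDArray.empty(shape)` allocates an uninitialized array on the default context (mirroring the `empty` import used in Random.scala below).

// Hypothetical usage sketch, not part of this commit.
object InitializerExample {
  def main(args: Array[String]): Unit = {
    // fanOut = shape(0) = 64, fanIn = product of the rest = 128
    val weight = NDArray.empty(Array(64, 128))
    val init = new Xavier(rndType = "uniform", factorType = "avg", magnitude = 3)
    // Suffix "weight" routes to _initWeight:
    // factor = (128 + 64) / 2 = 96, scale = sqrt(3 / 96) ≈ 0.177
    init("fc1_weight", weight)

    val bias = NDArray.empty(Array(64))
    init("fc1_bias", bias) // suffix "bias" routes to _initBias, i.e. all zeros

    // Any name matching none of the patterns throws IllegalArgumentException.
  }
}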
scala-package/core/src/main/scala/ml/dmlc/mxnet/Random.scala (71 additions, 0 deletions)
package ml.dmlc.mxnet

import ml.dmlc.mxnet.Base._
import ml.dmlc.mxnet.NDArray.{_randomUniform, _randomGaussian, empty}

/**
 * Random number interface of mxnet.
 * @author Yuan Tang
 */
object Random {
  /**
   * Generate a uniform distribution in [low, high) with the given shape.
   *
   * @param low The lower bound of the distribution.
   * @param high The upper bound of the distribution.
   * @param shape Output shape of the NDArray generated.
   * @param ctx Context of the output NDArray, will use default context if not specified.
   * @param out Output placeholder
   * @return The result NDArray with the generated result.
   */
  def uniform(low: Float, high: Float,
              shape: Array[Int] = null, ctx: Context = null, out: NDArray = null): NDArray = {
    var outCopy = out
    if (outCopy != null) {
      // `&&` rather than `&`: short-circuiting boolean and is the idiomatic choice here
      require(shape == null && ctx == null, "shape and ctx are not needed when out is specified.")
    } else {
      require(shape != null, "shape is required when out is not specified")
      outCopy = empty(shape, ctx)
    }
    _randomUniform(low, high, outCopy)
  }

  /**
   * Generate a normal (Gaussian) distribution N(mean, stdvar^2) with the given shape.
   *
   * @param mean The mean of the normal distribution.
   * @param stdvar The standard deviation of the normal distribution.
   * @param shape Output shape of the NDArray generated.
   * @param ctx Context of the output NDArray, will use default context if not specified.
   * @param out Output placeholder
   * @return The result NDArray with the generated result.
   */
  def normal(mean: Float, stdvar: Float,
             shape: Array[Int] = null, ctx: Context = null, out: NDArray = null): NDArray = {
    var outCopy = out
    if (outCopy != null) {
      require(shape == null && ctx == null, "shape and ctx are not needed when out is specified.")
    } else {
      require(shape != null, "shape is required when out is not specified")
      outCopy = empty(shape, ctx)
    }
    _randomGaussian(mean, stdvar, outCopy)
  }

  /**
   * Seed the random number generators in mxnet.
   *
   * This seed affects the behavior of functions in this module,
   * as well as results from executors that contain randomized
   * operators, such as Dropout.
   *
   * @param seedState The random number seed to set on all devices.
   * @note The random number generator of mxnet is by default device specific.
   *       This means that if you set the same seed, the random number sequence
   *       generated on GPU0 can differ from the one generated on the CPU.
   */
  def seed(seedState: Int) = {
    // TODO
    // checkCall(_LIB.mxRandomSeed(seedState))
  }
}
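The two entry points in Random are mutually exclusive in how they are called: pass `shape` (plus optionally `ctx`) to allocate a fresh NDArray, or pass `out` to fill an existing one, but not both. Below is a minimal sketch under the same assumption as above that `NDArray.empty(shape)` exists; `RandomExample` is hypothetical and not part of this commit.

// Hypothetical usage sketch, not part of this commit.
object RandomExample {
  def main(args: Array[String]): Unit = {
    // Allocate-and-fill: a 2x3 NDArray drawn from U[-0.1, 0.1)
    val u = Random.uniform(-0.1f, 0.1f, shape = Array(2, 3))

    // Fill-in-place: shape and ctx must be omitted when out is given
    val n = NDArray.empty(Array(2, 3))
    Random.normal(0f, 0.01f, out = n)

    // Supplying both shape and out fails the require() check
    // with an IllegalArgumentException.
  }
}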