Skip to content

Commit

Permalink
Merge pull request #105 from rfourquet/rf/MT-seed
Browse files Browse the repository at this point in the history
seed explicitly MersenneTwister(0)
  • Loading branch information
denizyuret authored Apr 2, 2017
2 parents 94d9e84 + 219524f commit 89c2026
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 20 deletions.
7 changes: 3 additions & 4 deletions deprecated/src7/deprecated/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ gnorm(m::Model,g=0)=(for p in params(m); g += vnorm(p.dif); end; g) #
# Split a data item into a two-element tuple (presumably inputs x and target y,
# going by the name -- confirm against callers):
#   - a Tuple yields (all elements but the last, the last element);
#   - `nothing` yields (nothing, nothing);
#   - any other value yields ((), item).
# The `nothing` test uses `===` (identity) rather than `==`: `==` can be
# overloaded and may return a non-Bool (e.g. `missing == nothing` is `missing`),
# which would make the ternary throw instead of falling through to the last case.
item2xy(item)=(isa(item, Tuple) ? (item[1:end-1],item[end]) : item===nothing ? (nothing,nothing) : ((),item))

# So gradient checking does not mess up random seed:
const gradcheck_rng = MersenneTwister()
const gradcheck_rng = MersenneTwister(0)


# train(m, d; seq=false, a...)=(!seq ? train1(m,d;a...) : train2(m,d;a...))
Expand Down Expand Up @@ -48,7 +48,7 @@ const gradcheck_rng = MersenneTwister()
# end

# NO: make the model interface more functional:
# back and loss rely on hidden state info.
# back and loss rely on hidden state info.
# forw has to allocate.
# purely functional models are impossible.
# forw needs to compute intermediate values.
Expand All @@ -71,7 +71,7 @@ const gradcheck_rng = MersenneTwister()
# end

# Use test with percloss instead:
#
#
# function accuracy(m::Model, d) # TODO: this only works if y is a single item
# numcorr = numinst = 0
# z = nothing
Expand Down Expand Up @@ -108,4 +108,3 @@ const gradcheck_rng = MersenneTwister()
# end
# return sumloss/numloss
# end

4 changes: 2 additions & 2 deletions deprecated/src7/gradcheck.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
function gradcheck(m, grad, loss; gcheck=10, _eps=cbrt(eps(eltype(m))), delta=_eps, atol=_eps, rtol=_eps, o...)
# 6e-6 for Float64, 5e-3 for Float32 works best
# rnum = 42 # time_ns() #DBG
# isdefined(data,:rng) && (data_rng_save = data.rng; data.rng=MersenneTwister(); srand(data.rng,rnum))
# isdefined(m,:rng) && (m_rng_save = m.rng; m.rng=MersenneTwister(); srand(m.rng,rnum))
# isdefined(data,:rng) && (data_rng_save = data.rng; data.rng=MersenneTwister(0); srand(data.rng,rnum))
# isdefined(m,:rng) && (m_rng_save = m.rng; m.rng=MersenneTwister(0); srand(m.rng,rnum))
# l = zeros(2)
# train(m, data, loss; gcheck=true, losscnt=fill!(l,0), o...)
# loss0 = l[1]
Expand Down
16 changes: 7 additions & 9 deletions examples/deprecated/adding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ end

function gradloss(f, data, loss; grad=false, seed=42)
data_rng = data.rng
data.rng = MersenneTwister()
srand(data.rng, seed)
data.rng = MersenneTwister(seed)
reset!(f)
myforw = grad ? sforw : forw
loss1 = 0
Expand All @@ -93,7 +92,7 @@ import Base: start, next, done

type Data; len; batchsize; epochsize; batch; sum; cnt; rng;
Data(len, batchsize, epochsize; rng=Base.GLOBAL_RNG) =
new(len, batchsize, epochsize, zeros(Float32,2,batchsize),
new(len, batchsize, epochsize, zeros(Float32,2,batchsize),
zeros(Float32,1,batchsize), zeros(Int,1,batchsize), rng)
end

Expand All @@ -109,7 +108,7 @@ function next(a::Data, s)
fill!(sub(a.batch,2,:),0)
togo = a.len - t
for b=1:a.batchsize
if (a.cnt[b]==0 ? rand(a.rng) <= 2/togo :
if (a.cnt[b]==0 ? rand(a.rng) <= 2/togo :
a.cnt[b]==1 ? rand(a.rng) <= 1/togo : false)
a.batch[2,b] = 1
a.cnt[b] += 1
Expand Down Expand Up @@ -407,7 +406,7 @@ end #module
# return(xx, yy)
# end
# net0 = (args["type"] == "irnn" ? Net(irnn(nh),quadlosslayer(ny)) :
# args["type"] == "lstm" ? Net(lstm(nh),quadlosslayer(ny)) :
# args["type"] == "lstm" ? Net(lstm(nh),quadlosslayer(ny)) :
# error("Unknown network type "*args["type"]))
# setparam!(net; lr=args["lr"], gc=args["gc"]) # do a global gclip instead of per parameter
# "--test"
Expand Down Expand Up @@ -529,7 +528,7 @@ end #module

# OLD GENERATOR FOR COMPARISON:

# type Adding1; len; batchsize; epochsize; b; x; y;
# type Adding1; len; batchsize; epochsize; b; x; y;
# Adding1(len, batchsize, epochsize; o...)=new(len,batchsize,epochsize,Adding0(len,batchsize,epochsize; o...))
# end

Expand All @@ -548,7 +547,7 @@ end #module
# end

# type Adding0; len; batchsize; epochsize; rng;
# Adding0(len, batchsize, epochsize; rng=MersenneTwister())=new(len, batchsize, epochsize, rng)
# Adding0(len, batchsize, epochsize; rng=MersenneTwister(0))=new(len, batchsize, epochsize, rng)
# end

# start(a::Adding0)=0
Expand All @@ -574,7 +573,7 @@ end #module


# p1 = (opts["nettype"] == "irnn" ? Net(irnn; out=opts["hidden"], winit=Gaussian(0,opts["winit"])) :
# opts["nettype"] == "lstm" ? Net(lstm; out=opts["hidden"], fbias=opts["fbias"]) :
# opts["nettype"] == "lstm" ? Net(lstm; out=opts["hidden"], fbias=opts["fbias"]) :
# error("Unknown network type "*opts["nettype"]))
# p2 = Net(wb; out=1, winit=Gaussian(0,opts["winit"]))

Expand All @@ -585,4 +584,3 @@ end #module
# @knet function p2(x; o...)
# y = wb(x; o..., out=1)
# end

9 changes: 4 additions & 5 deletions examples/deprecated/mnistpixels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ end

function gradloss(f, data, loss; grad=false, seed=42)
data_rng = data.rng
data.rng = MersenneTwister()
srand(data.rng, seed)
data.rng = MersenneTwister(seed)
reset!(f)
myforw = grad ? sforw : forw
loss1 = 0
Expand Down Expand Up @@ -118,7 +117,7 @@ import Base: start, next, done
# the last pixel should be served as x(1,batch), y(10,batch)

type Pixels; x; y; rng; datasize; epochsize; batchsize; bootstrap; shuffle; xbatch; ybatch; images;
function Pixels(x, y; rng=MersenneTwister(), epoch=ccount(x), batch=16, bootstrap=false, shuffle=false)
function Pixels(x, y; rng=MersenneTwister(0), epoch=ccount(x), batch=16, bootstrap=false, shuffle=false)
nx = ccount(x)
nx == ccount(y) || error("Item count mismatch")
shuf = (shuffle ? shuffle!(rng,[1:nx;]) : nothing)
Expand Down Expand Up @@ -176,7 +175,7 @@ function parse_commandline(args)
"--lrate"
help = "learning rate"
arg_type = Float64
default = 0.005 # paper says 1e-8?
default = 0.005 # paper says 1e-8?
"--gclip"
help = "gradient clip"
arg_type = Float64
Expand Down Expand Up @@ -277,6 +276,6 @@ end # module

# S2C no longer accepts Net, it expects kfun:
# p1 = (nettype == "irnn" ? Net(irnn; out=hidden, winit=Gaussian(0,winit)) :
# nettype == "lstm" ? Net(lstm; out=hidden, fbias=fbias) :
# nettype == "lstm" ? Net(lstm; out=hidden, fbias=fbias) :
# error("Unknown network type "*nettype))
# p2 = Net(wbf; out=10, winit=Gaussian(0,winit), f=soft)

0 comments on commit 89c2026

Please sign in to comment.