diff --git a/R/inference-tensorflow.R b/R/inference-tensorflow.R index b19f7e6..c93732d 100644 --- a/R/inference-tensorflow.R +++ b/R/inference-tensorflow.R @@ -91,11 +91,17 @@ inference_tensorflow <- function(Y, tf$constant(Inf, dtype = tf$float64)) }) - beta <- tf$Variable(tf$random_normal(shape(G,P), - mean = 0, - stddev = 1, - seed = random_seed, - dtype = tf$float64), + # beta <- tf$Variable(tf$random_normal(shape(G,P), + # mean = 0, + # stddev = 1, + # seed = random_seed, + # dtype = tf$float64), + # dtype = tf$float64) + + beta_0_init <- scale(colMeans(Y)) + beta_init <- cbind(beta_0_init, + matrix(0, nrow = G, ncol = P-1)) + beta <- tf$Variable(tf$constant(beta_init, dtype = tf$float64), dtype = tf$float64) theta_logit <- tf$Variable(tf$random_normal(shape(C),