diff --git a/DESCRIPTION b/DESCRIPTION index 210545bf..7045a1a3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,7 +25,7 @@ License: Apache License (>= 2) Encoding: UTF-8 Language: en-GB Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 Depends: R (>= 4.1.0) Imports: diff --git a/NAMESPACE b/NAMESPACE index 62ec805d..e1ad0cb2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -140,6 +140,7 @@ S3method(linkTTG,LongitudinalGSF) S3method(linkTTG,LongitudinalSteinFojo) S3method(linkTTG,PromiseLongitudinalModel) S3method(linkTTG,default) +S3method(median,Prior) S3method(names,LinkComponent) S3method(names,Parameter) S3method(names,ParameterList) @@ -157,6 +158,7 @@ S3method(sampleSubjects,SimLongitudinalGSF) S3method(sampleSubjects,SimLongitudinalRandomSlope) S3method(sampleSubjects,SimLongitudinalSteinFojo) S3method(sampleSubjects,SimSurvival) +S3method(set_limits,Prior) S3method(size,Parameter) S3method(size,ParameterList) S3method(subset,DataJoint) @@ -247,6 +249,7 @@ export(resolvePromise) export(sampleObservations) export(sampleStanModel) export(sampleSubjects) +export(set_limits) export(show) export(write_stan) exportClasses(DataJoint) @@ -295,6 +298,7 @@ importFrom(stats,.checkMFClasses) importFrom(stats,acf) importFrom(stats,as.formula) importFrom(stats,delete.response) +importFrom(stats,median) importFrom(stats,model.frame) importFrom(stats,model.matrix) importFrom(stats,rbeta) diff --git a/R/LongitudinalClaretBruno.R b/R/LongitudinalClaretBruno.R index c171dae4..1d633fde 100755 --- a/R/LongitudinalClaretBruno.R +++ b/R/LongitudinalClaretBruno.R @@ -63,6 +63,14 @@ LongitudinalClaretBruno <- function( centred = centred )) + # Apply constraints + omega_b <- set_limits(omega_b, lower = 0) + omega_g <- set_limits(omega_g, lower = 0) + omega_c <- set_limits(omega_c, lower = 0) + omega_p <- set_limits(omega_p, lower = 0) + sigma <- set_limits(sigma, lower = 0) + + parameters <- list( Parameter(name = "lm_clbr_mu_b", prior = mu_b, size = "n_studies"), Parameter(name = "lm_clbr_mu_g", prior = mu_g, size = "n_arms"), @@ -82,22 +90,22 @@ LongitudinalClaretBruno <- function( list( Parameter( name = "lm_clbr_ind_b", - prior = prior_init_only(prior_lognormal(mu_b@init, omega_b@init)), + prior = prior_init_only(prior_lognormal(median(mu_b), median(omega_b))), size = "n_subjects" ), Parameter( name = "lm_clbr_ind_g", - prior = prior_init_only(prior_lognormal(mu_g@init, omega_g@init)), + prior = prior_init_only(prior_lognormal(median(mu_g), median(omega_g))), size = "n_subjects" ), Parameter( name = "lm_clbr_ind_c", - prior = prior_init_only(prior_lognormal(mu_c@init, omega_c@init)), + prior = prior_init_only(prior_lognormal(median(mu_c), median(omega_c))), size = "n_subjects" ), Parameter( name = "lm_clbr_ind_p", - prior = prior_init_only(prior_lognormal(mu_p@init, omega_p@init)), + prior = prior_init_only(prior_lognormal(median(mu_p), median(omega_p))), size = "n_subjects" ) ) diff --git a/R/LongitudinalGSF.R b/R/LongitudinalGSF.R index 38d96cba..d468ce1a 100755 --- a/R/LongitudinalGSF.R +++ b/R/LongitudinalGSF.R @@ -66,6 +66,14 @@ LongitudinalGSF <- function( centred = centred )) + # Apply constraints + omega_bsld <- set_limits(omega_bsld, lower = 0) + omega_ks <- set_limits(omega_ks, lower = 0) + omega_kg <- set_limits(omega_kg, lower = 0) + omega_phi <- set_limits(omega_phi, lower = 0) + sigma <- set_limits(sigma, lower = 0) + + parameters <- list( Parameter(name = "lm_gsf_mu_bsld", prior = mu_bsld, size = "n_studies"), Parameter(name = "lm_gsf_mu_ks", prior = mu_ks, size = "n_arms"), @@ -85,22 +93,22 @@ LongitudinalGSF <- function( list( Parameter( name = "lm_gsf_psi_bsld", - prior = prior_init_only(prior_lognormal(mu_bsld@init, omega_bsld@init)), + prior = prior_init_only(prior_lognormal(median(mu_bsld), median(omega_bsld))), size = "n_subjects" ), Parameter( name = "lm_gsf_psi_ks", - prior = prior_init_only(prior_lognormal(mu_ks@init, omega_ks@init)), + prior = prior_init_only(prior_lognormal(median(mu_ks), median(omega_ks))), size = "n_subjects" ), Parameter( name = "lm_gsf_psi_kg", - prior = prior_init_only(prior_lognormal(mu_kg@init, omega_kg@init)), + prior = prior_init_only(prior_lognormal(median(mu_kg), median(omega_kg))), size = "n_subjects" ), Parameter( name = "lm_gsf_psi_phi_logit", - prior = prior_init_only(prior_normal(mu_phi@init, omega_phi@init)), + prior = prior_init_only(prior_normal(median(mu_phi), median(omega_phi))), size = "n_subjects" ) ) diff --git a/R/LongitudinalRandomSlope.R b/R/LongitudinalRandomSlope.R index 57b70a45..b431a661 100755 --- a/R/LongitudinalRandomSlope.R +++ b/R/LongitudinalRandomSlope.R @@ -31,7 +31,7 @@ NULL #' @export LongitudinalRandomSlope <- function( intercept = prior_normal(30, 10), - slope_mu = prior_normal(0, 15), + slope_mu = prior_normal(1, 3), slope_sigma = prior_lognormal(0, 1.5), sigma = prior_lognormal(0, 1.5) ) { @@ -40,6 +40,10 @@ LongitudinalRandomSlope <- function( x = "lm-random-slope/model.stan" ) + # Apply constriants + sigma <- set_limits(sigma, lower = 0) + slope_sigma <- set_limits(slope_sigma, lower = 0) + .LongitudinalRandomSlope( LongitudinalModel( name = "Random Slope", @@ -51,7 +55,7 @@ LongitudinalRandomSlope <- function( Parameter(name = "lm_rs_sigma", prior = sigma, size = 1), Parameter( name = "lm_rs_ind_rnd_slope", - prior = prior_init_only(prior_normal(slope_mu@init, slope_sigma@init)), + prior = prior_init_only(prior_normal(median(slope_mu), median(slope_sigma))), size = "n_subjects" ) ) diff --git a/R/LongitudinalSteinFojo.R b/R/LongitudinalSteinFojo.R index 34aa2919..73cf1cb5 100755 --- a/R/LongitudinalSteinFojo.R +++ b/R/LongitudinalSteinFojo.R @@ -59,6 +59,12 @@ LongitudinalSteinFojo <- function( centred = centred )) + # Apply constriants + omega_bsld <- set_limits(omega_bsld, lower = 0) + omega_ks <- set_limits(omega_ks, lower = 0) + omega_kg <- set_limits(omega_kg, lower = 0) + sigma <- set_limits(sigma, lower = 0) + parameters <- list( Parameter(name = "lm_sf_mu_bsld", prior = mu_bsld, size = "n_studies"), Parameter(name = "lm_sf_mu_ks", prior = mu_ks, size = "n_arms"), @@ -76,17 +82,17 @@ LongitudinalSteinFojo <- function( list( Parameter( name = "lm_sf_psi_bsld", - prior = prior_init_only(prior_lognormal(mu_bsld@init, omega_bsld@init)), + prior = prior_init_only(prior_lognormal(median(mu_bsld), median(omega_bsld))), size = "n_subjects" ), Parameter( name = "lm_sf_psi_ks", - prior = prior_init_only(prior_lognormal(mu_ks@init, omega_ks@init)), + prior = prior_init_only(prior_lognormal(median(mu_ks), median(omega_ks))), size = "n_subjects" ), Parameter( name = "lm_sf_psi_kg", - prior = prior_init_only(prior_lognormal(mu_kg@init, omega_kg@init)), + prior = prior_init_only(prior_lognormal(median(mu_kg), median(omega_kg))), size = "n_subjects" ) ) diff --git a/R/Prior.R b/R/Prior.R index c08ee946..a978a054 100755 --- a/R/Prior.R +++ b/R/Prior.R @@ -7,7 +7,8 @@ NULL #' The documentation lists all the conventional arguments for [`Prior`] #' constructors. #' -#' @param init (`number`)\cr initial value. +#' @param centre (`number`)\cr the central point of distribution to shrink sampled values towards +#' (for most distributions this is the mean or median if the mean is undefined) #' @param x ([`Prior`])\cr a prior Distribution #' @param object ([`Prior`])\cr a prior Distribution #' @param name (`character`)\cr the name of the parameter the prior distribution is for @@ -26,10 +27,11 @@ NULL #' @slot parameters (`list`)\cr See arguments. #' @slot repr_model (`string`)\cr See arguments. #' @slot repr_data (`string`)\cr See arguments. -#' @slot init (`numeric`)\cr See arguments. +#' @slot centre (`numeric`)\cr See arguments. #' @slot validation (`list`)\cr See arguments. #' @slot display (`string`)\cr See arguments. #' @slot sample (`function`)\cr See arguments. +#' @slot limits (`numeric`)\cr See arguments. #' #' @family Prior-internal #' @export Prior @@ -41,9 +43,10 @@ NULL "display" = "character", "repr_model" = "character", "repr_data" = "character", - "init" = "numeric", + "centre" = "numeric", "validation" = "list", - "sample" = "function" + "sample" = "function", + "limits" = "numeric" ) ) @@ -52,28 +55,31 @@ NULL #' @param repr_model (`string`)\cr the Stan code representation for the model block. #' @param repr_data (`string`)\cr the Stan code representation for the data block. #' @param display (`string`)\cr the string to display when object is printed. -#' @param init (`numeric`)\cr the initial value. +#' @param centre (`numeric`)\cr the central point of distribution to shrink sampled values towards #' @param validation (`list`)\cr the prior distribution parameter validation functions. Must have #' the same names as the `paramaters` slot. #' @param sample (`function`)\cr a function to sample from the prior distribution. +#' @param limits (`numeric`)\cr the lower and upper limits for a truncated distribution #' @rdname Prior-class Prior <- function( parameters, display, repr_model, repr_data, - init, + centre, validation, - sample + sample, + limits = c(-Inf, Inf) ) { .Prior( parameters = parameters, repr_model = repr_model, repr_data = repr_data, - init = init, + centre = centre, display = display, validation = validation, - sample = sample + sample = sample, + limits = limits ) } @@ -99,6 +105,15 @@ setValidity( ) + +#' @rdname set_limits +#' @export +set_limits.Prior <- function(object, lower = -Inf, upper = Inf) { + object@limits <- c(lower, upper) + return(object) +} + + #' `Prior` -> `Character` #' #' Converts a [`Prior`] object to a character vector @@ -188,8 +203,18 @@ NULL #' @describeIn Prior-Getter-Methods The prior's initial value #' @export initialValues.Prior <- function(object, ...) { - getOption("jmpost.prior_shrinkage") * object@init + - (1 - getOption("jmpost.prior_shrinkage")) * object@sample(1) + samples <- getOption("jmpost.prior_shrinkage") * object@centre + + (1 - getOption("jmpost.prior_shrinkage")) * object@sample(100) + + valid_samples <- samples[samples >= min(object@limits) & samples <= max(object@limits)] + assert_that( + length(valid_samples) >= 1, + msg = "Unable to generate an initial value that meets the required constraints" + ) + if (length(valid_samples) == 1) { + return(valid_samples) + } + return(sample(valid_samples, 1)) } @@ -210,7 +235,7 @@ prior_normal <- function(mu, sigma) { "real prior_mu_{name};", "real prior_sigma_{name};" ), - init = mu, + centre = mu, sample = \(n) local_rnorm(n, mu, sigma), validation = list( mu = is.numeric, @@ -231,7 +256,7 @@ prior_std_normal <- function() { display = "std_normal()", repr_model = "{name} ~ std_normal();", repr_data = "", - init = 0, + centre = 0, sample = \(n) local_rnorm(n), validation = list() ) @@ -253,7 +278,7 @@ prior_cauchy <- function(mu, sigma) { "real prior_mu_{name};", "real prior_sigma_{name};" ), - init = mu, + centre = mu, sample = \(n) local_rcauchy(n, mu, sigma), validation = list( mu = is.numeric, @@ -262,6 +287,7 @@ prior_cauchy <- function(mu, sigma) { ) } + #' Gamma Prior Distribution #' #' @param alpha (`number`)\cr shape. @@ -278,7 +304,7 @@ prior_gamma <- function(alpha, beta) { "real prior_alpha_{name};", "real prior_beta_{name};" ), - init = alpha / beta, + centre = alpha / beta, sample = \(n) local_rgamma(n, shape = alpha, rate = beta), validation = list( alpha = \(x) x > 0, @@ -303,7 +329,7 @@ prior_lognormal <- function(mu, sigma) { "real prior_mu_{name};", "real prior_sigma_{name};" ), - init = exp(mu + (sigma^2) / 2), + centre = exp(mu + (sigma^2) / 2), sample = \(n) local_rlnorm(n, mu, sigma), validation = list( mu = is.numeric, @@ -328,7 +354,7 @@ prior_beta <- function(a, b) { "real prior_a_{name};", "real prior_b_{name};" ), - init = a / (a + b), + centre = a / (a + b), sample = \(n) local_rbeta(n, a, b), validation = list( a = \(x) x > 0, @@ -356,7 +382,7 @@ prior_init_only <- function(dist) { sample = \(n) { dist@sample(n) }, - init = dist@init, + centre = dist@centre, validation = list() ) } @@ -384,7 +410,7 @@ prior_uniform <- function(alpha, beta) { "real prior_alpha_{name};", "real prior_beta_{name};" ), - init = 0.5 * (alpha + beta), + centre = 0.5 * (alpha + beta), sample = \(n) local_runif(n, alpha, beta), validation = list( alpha = is.numeric, @@ -416,7 +442,7 @@ prior_student_t <- function(nu, mu, sigma) { "real prior_mu_{name};", "real prior_sigma_{name};" ), - init = mu, + centre = mu, sample = \(n) local_rt(n, nu, mu, sigma), validation = list( nu = \(x) x > 0, @@ -447,7 +473,7 @@ prior_logistic <- function(mu, sigma) { "real prior_mu_{name};", "real prior_sigma_{name};" ), - init = mu, + centre = mu, sample = \(n) local_rlogis(n, mu, sigma), validation = list( mu = is.numeric, @@ -476,7 +502,7 @@ prior_loglogistic <- function(alpha, beta) { "real prior_alpha_{name};", "real prior_beta_{name};" ), - init = alpha * pi / (beta * sin(pi / beta)), + centre = alpha * pi / (beta * sin(pi / beta)), sample = \(n) { local_rloglogis(n, alpha, beta) }, @@ -507,7 +533,7 @@ prior_invgamma <- function(alpha, beta) { "real prior_alpha_{name};", "real prior_beta_{name};" ), - init = beta / (alpha - 1), + centre = beta / (alpha - 1), sample = \(n) local_rinvgamma(n, alpha, beta), validation = list( alpha = \(x) x > 0, @@ -517,6 +543,38 @@ prior_invgamma <- function(alpha, beta) { } +# nolint start +# +# Developer Notes +# +# The `median.Prior` function is a rough workaround to help generate initial values for +# hierarchical distributions. The original implementation involved sampling initial values +# for the random effects using the medians of the parent distribution e.g. +# ``` +# random_effect ~ beta(a_prior@centre, b_prior@centre) +# ``` +# A problem came up though when we implemented support for constrained distributions +# as there was no longer any guarantee that the median/centre of the distribution is +# a valid value e.g. `a_prior ~ prior_normal(-200, 400)`. +# +# To resolve this issue the `median.Prior` method was created which simply samples +# multiple observations from the constrained distribution and then takes the median +# of those constrained observations; this then ensures that the value being used +# for the parameters is a valid value +# +# nolint end +#' @importFrom stats median +#' @export +median.Prior <- function(x, na.rm, ...) { + vals <- replicate( + n = 500, + initialValues(x), + simplify = FALSE + ) |> + unlist() + median(vals) +} + diff --git a/R/generics.R b/R/generics.R index cb5082ed..56879d0e 100755 --- a/R/generics.R +++ b/R/generics.R @@ -470,3 +470,18 @@ as_formula <- function(x, ...) { as_formula.default <- function(x, ...) { as.formula(x, ...) } + + +#' Set Constraints +#' +#' Applies constraints to a prior distribution to ensure any sampled numbers +#' from the distribution fall within the constraints +#' +#' @param object (`Prior`)\cr a prior distribution to apply constraints to +#' @param lower (`numeric`)\cr lower constraint boundary +#' @param upper (`numeric`)\cr upper constraint boundary +#' +#' @export +set_limits <- function(object, lower = -Inf, upper = Inf) { + UseMethod("set_limits") +} diff --git a/_pkgdown.yml b/_pkgdown.yml index f22443b4..af180ac5 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -174,6 +174,7 @@ reference: - enableGQ - as_formula - getPredictionNames + - set_limits - title: Promises contents: diff --git a/man/LongitudinalRandomSlope-class.Rd b/man/LongitudinalRandomSlope-class.Rd index 915e1406..a40277a2 100644 --- a/man/LongitudinalRandomSlope-class.Rd +++ b/man/LongitudinalRandomSlope-class.Rd @@ -9,7 +9,7 @@ \usage{ LongitudinalRandomSlope( intercept = prior_normal(30, 10), - slope_mu = prior_normal(0, 15), + slope_mu = prior_normal(1, 3), slope_sigma = prior_lognormal(0, 1.5), sigma = prior_lognormal(0, 1.5) ) diff --git a/man/Prior-Shared.Rd b/man/Prior-Shared.Rd index 2290370e..6f4cb023 100644 --- a/man/Prior-Shared.Rd +++ b/man/Prior-Shared.Rd @@ -4,7 +4,8 @@ \alias{Prior-Shared} \title{\code{Prior} Function Arguments} \arguments{ -\item{init}{(\code{number})\cr initial value.} +\item{centre}{(\code{number})\cr the central point of distribution to shrink sampled values towards +(for most distributions this is the mean or median if the mean is undefined)} \item{x}{(\code{\link{Prior}})\cr a prior Distribution} diff --git a/man/Prior-class.Rd b/man/Prior-class.Rd index d3e8a522..f666c2d6 100644 --- a/man/Prior-class.Rd +++ b/man/Prior-class.Rd @@ -7,7 +7,16 @@ \alias{Prior} \title{Prior Object and Constructor Function} \usage{ -Prior(parameters, display, repr_model, repr_data, init, validation, sample) +Prior( + parameters, + display, + repr_model, + repr_data, + centre, + validation, + sample, + limits = c(-Inf, Inf) +) } \arguments{ \item{parameters}{(\code{list})\cr the prior distribution parameters.} @@ -18,12 +27,14 @@ Prior(parameters, display, repr_model, repr_data, init, validation, sample) \item{repr_data}{(\code{string})\cr the Stan code representation for the data block.} -\item{init}{(\code{numeric})\cr the initial value.} +\item{centre}{(\code{numeric})\cr the central point of distribution to shrink sampled values towards} \item{validation}{(\code{list})\cr the prior distribution parameter validation functions. Must have the same names as the \code{paramaters} slot.} \item{sample}{(\code{function})\cr a function to sample from the prior distribution.} + +\item{limits}{(\code{numeric})\cr the lower and upper limits for a truncated distribution} } \description{ Specifies the prior distribution in a Stan Model @@ -37,13 +48,15 @@ Specifies the prior distribution in a Stan Model \item{\code{repr_data}}{(\code{string})\cr See arguments.} -\item{\code{init}}{(\code{numeric})\cr See arguments.} +\item{\code{centre}}{(\code{numeric})\cr See arguments.} \item{\code{validation}}{(\code{list})\cr See arguments.} \item{\code{display}}{(\code{string})\cr See arguments.} \item{\code{sample}}{(\code{function})\cr See arguments.} + +\item{\code{limits}}{(\code{numeric})\cr See arguments.} }} \seealso{ diff --git a/man/set_limits.Rd b/man/set_limits.Rd new file mode 100644 index 00000000..ce05f7bd --- /dev/null +++ b/man/set_limits.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/generics.R, R/Prior.R +\name{set_limits} +\alias{set_limits} +\alias{set_limits.Prior} +\title{Set Constraints} +\usage{ +set_limits(object, lower = -Inf, upper = Inf) + +\method{set_limits}{Prior}(object, lower = -Inf, upper = Inf) +} +\arguments{ +\item{object}{(\code{Prior})\cr a prior distribution to apply constraints to} + +\item{lower}{(\code{numeric})\cr lower constraint boundary} + +\item{upper}{(\code{numeric})\cr upper constraint boundary} +} +\description{ +Applies constraints to a prior distribution to ensure any sampled numbers +from the distribution fall within the constraints +} diff --git a/tests/testthat/_snaps/JointModel.md b/tests/testthat/_snaps/JointModel.md index 6323643e..6c737856 100644 --- a/tests/testthat/_snaps/JointModel.md +++ b/tests/testthat/_snaps/JointModel.md @@ -17,7 +17,7 @@ Longitudinal: Random Slope Longitudinal Model with parameters: lm_rs_intercept ~ normal(mu = 30, sigma = 10) - lm_rs_slope_mu ~ normal(mu = 0, sigma = 15) + lm_rs_slope_mu ~ normal(mu = 1, sigma = 3) lm_rs_slope_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_ind_rnd_slope ~ @@ -45,7 +45,7 @@ Longitudinal: Random Slope Longitudinal Model with parameters: lm_rs_intercept ~ normal(mu = 30, sigma = 10) - lm_rs_slope_mu ~ normal(mu = 0, sigma = 15) + lm_rs_slope_mu ~ normal(mu = 1, sigma = 3) lm_rs_slope_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_ind_rnd_slope ~ @@ -126,7 +126,7 @@ Longitudinal: Random Slope Longitudinal Model with parameters: lm_rs_intercept ~ normal(mu = 30, sigma = 10) - lm_rs_slope_mu ~ normal(mu = 0, sigma = 15) + lm_rs_slope_mu ~ normal(mu = 1, sigma = 3) lm_rs_slope_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_ind_rnd_slope ~ diff --git a/tests/testthat/_snaps/LongitudinalRandomSlope.md b/tests/testthat/_snaps/LongitudinalRandomSlope.md index 49563b6e..997081e9 100644 --- a/tests/testthat/_snaps/LongitudinalRandomSlope.md +++ b/tests/testthat/_snaps/LongitudinalRandomSlope.md @@ -7,7 +7,7 @@ Random Slope Longitudinal Model with parameters: lm_rs_intercept ~ normal(mu = 30, sigma = 10) - lm_rs_slope_mu ~ normal(mu = 0, sigma = 15) + lm_rs_slope_mu ~ normal(mu = 1, sigma = 3) lm_rs_slope_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_ind_rnd_slope ~ @@ -23,7 +23,7 @@ Random Slope Longitudinal Model with parameters: lm_rs_intercept ~ normal(mu = 0, sigma = 1) - lm_rs_slope_mu ~ normal(mu = 0, sigma = 15) + lm_rs_slope_mu ~ normal(mu = 1, sigma = 3) lm_rs_slope_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_sigma ~ gamma(alpha = 2, beta = 1) lm_rs_ind_rnd_slope ~ diff --git a/tests/testthat/helper-example_data.R b/tests/testthat/helper-example_data.R index d87ac619..3b28bcb8 100644 --- a/tests/testthat/helper-example_data.R +++ b/tests/testthat/helper-example_data.R @@ -1,12 +1,6 @@ -test_data_1 <- new.env() - ensure_test_data_1 <- function() { - if (!is.null(test_data_1$jsamples)) { - return(invisible(test_data_1)) - } - set.seed(739) simjdat <- SimJointData( design = list( @@ -74,10 +68,12 @@ ensure_test_data_1 <- function() { }) - test_data_1$dat_os <- dat_os - test_data_1$dat_lm <- dat_lm - test_data_1$jmodel <- jm - test_data_1$jdata <- jdat - test_data_1$jsamples <- mp - return(invisible(test_data_1)) + results <- list( + dat_os = dat_os, + dat_lm = dat_lm, + jmodel = jm, + jdata = jdat, + jsamples = mp + ) + return(invisible(results)) } diff --git a/tests/testthat/test-JointModelSamples.R b/tests/testthat/test-JointModelSamples.R index ff183e41..886d2100 100644 --- a/tests/testthat/test-JointModelSamples.R +++ b/tests/testthat/test-JointModelSamples.R @@ -1,7 +1,7 @@ +test_data_1 <- ensure_test_data_1() test_that("print works as expected for JointModelSamples", { - ensure_test_data_1() expect_snapshot({ print(test_data_1$jsamples) }) diff --git a/tests/testthat/test-LongitudinalClaretBruno.R b/tests/testthat/test-LongitudinalClaretBruno.R index de2b35b8..0241a863 100644 --- a/tests/testthat/test-LongitudinalClaretBruno.R +++ b/tests/testthat/test-LongitudinalClaretBruno.R @@ -295,3 +295,53 @@ test_that("Quantity models pass the parser", { ) expect_stan_syntax(stanmod) }) + + +test_that("Can generate valid initial values", { + + pars <- c( + "lm_clbr_omega_b", "lm_clbr_omega_g", "lm_clbr_omega_c", + "lm_clbr_omega_p", "lm_clbr_sigma", "lm_gsf_sigma" + ) + + # Defaults work as expected + mod <- LongitudinalClaretBruno() + vals <- initialValues(mod, n_chains = 1) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + + + # Test all individual parameters throw error if given prior that can't sample + # valid value + args <- list( + omega_b = prior_normal(-200, 1), + omega_g = prior_normal(-200, 1), + omega_c = prior_normal(-200, 1), + omega_p = prior_normal(-200, 1), + sigma = prior_normal(-200, 1) + ) + for (n_arg in names(args)) { + arg <- args[n_arg] + expect_error( + { + mod <- do.call(LongitudinalClaretBruno, arg) + initialValues(mod, n_chains = 1) + }, + regexp = "Unable to generate" + ) + } + + # Test initial values can be found for weird priors that do overlap the valid region + mod <- LongitudinalClaretBruno( + omega_b = prior_normal(-200, 400), + omega_g = prior_gamma(2, 5), + omega_c = prior_uniform(-200, 400), + omega_p = prior_lognormal(-200, 2), + sigma = prior_cauchy(-200, 400) + ) + set.seed(1001) + vals <- unlist(initialValues(mod, n_chains = 200)) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + +}) diff --git a/tests/testthat/test-LongitudinalGSF.R b/tests/testthat/test-LongitudinalGSF.R index 490016f8..73a94f71 100644 --- a/tests/testthat/test-LongitudinalGSF.R +++ b/tests/testthat/test-LongitudinalGSF.R @@ -220,3 +220,53 @@ test_that("Quantity models pass the parser", { ) expect_stan_syntax(stanmod) }) + + +test_that("Can generate valid initial values", { + + pars <- c( + "lm_gsf_omega_bsld", "lm_gsf_omega_ks", "lm_gsf_omega_kg", + "lm_gsf_a_phi", "lm_gsf_b_phi", "lm_gsf_sigma" + ) + + # Defaults work as expected + mod <- LongitudinalGSF() + vals <- initialValues(mod, n_chains = 1) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + + + # Test all individual parameters throw error if given prior that can't sample + # valid value + args <- list( + omega_bsld = prior_normal(-200, 1), + omega_ks = prior_normal(-200, 1), + omega_kg = prior_normal(-200, 1), + omega_phi = prior_normal(-200, 1), + sigma = prior_normal(-200, 1) + ) + for (n_arg in names(args)) { + arg <- args[n_arg] + expect_error( + { + mod <- do.call(LongitudinalGSF, arg) + initialValues(mod, n_chains = 1) + }, + regexp = "Unable to generate" + ) + } + + # Test initial values can be found for weird priors that do overlap the valid region + mod <- LongitudinalGSF( + omega_bsld = prior_normal(-200, 400), + omega_ks = prior_gamma(2, 5), + omega_kg = prior_uniform(-200, 400), + omega_phi = prior_lognormal(-200, 2), + sigma = prior_cauchy(-200, 400) + ) + set.seed(1001) + vals <- unlist(initialValues(mod, n_chains = 200)) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + +}) diff --git a/tests/testthat/test-LongitudinalQuantiles.R b/tests/testthat/test-LongitudinalQuantiles.R index b2af0755..2b744e20 100644 --- a/tests/testthat/test-LongitudinalQuantiles.R +++ b/tests/testthat/test-LongitudinalQuantiles.R @@ -1,7 +1,7 @@ +test_data_1 <- ensure_test_data_1() test_that("Test that LongitudinalQuantities works as expected", { - ensure_test_data_1() expected_column_names <- c("group", "time", "median", "lower", "upper") @@ -82,7 +82,7 @@ test_that("autoplot.LongitudinalQuantities works as expected", { test_that("LongitudinalQuantities print method works as expected", { - ensure_test_data_1() + expect_snapshot({ subjectgroups <- c("subject_0011", "subject_0061", "subject_0001", "subject_0002") times <- seq(0, 100, by = 10) @@ -118,7 +118,6 @@ test_that("LongitudinalQuantities print method works as expected", { test_that("LongitudinalQuantities can recover known results", { set.seed(101) - ensure_test_data_1() longsamps <- LongitudinalQuantities( test_data_1$jsamples, grid = GridFixed( @@ -141,7 +140,6 @@ test_that("LongitudinalQuantities can recover known results", { test_that("LongitudinalQuantities correctly subsets subjects and rebuilds correct value for each sample", { set.seed(101) - ensure_test_data_1() times <- c(-100, 0, 1, 100, 200) longsamps <- LongitudinalQuantities( diff --git a/tests/testthat/test-LongitudinalRandomSlope.R b/tests/testthat/test-LongitudinalRandomSlope.R index 4afbb6ff..015ccbfd 100644 --- a/tests/testthat/test-LongitudinalRandomSlope.R +++ b/tests/testthat/test-LongitudinalRandomSlope.R @@ -226,6 +226,7 @@ test_that("Random Slope Model left-censoring works as expected", { jm <- JointModel( longitudinal = LongitudinalRandomSlope( intercept = prior_normal(30, 2), + slope_mu = prior_normal(1, 2), slope_sigma = prior_lognormal(log(0.2), sigma = 0.5), sigma = prior_lognormal(log(3), sigma = 0.5) ) @@ -317,3 +318,47 @@ test_that("Quantity models pass the parser", { ) expect_stan_syntax(stanmod) }) + + + +test_that("Can generate valid initial values", { + + pars <- c( + "lm_rs_slope_sigma", "lm_rs_sigma" + ) + + # Defaults work as expected + mod <- LongitudinalRandomSlope() + vals <- initialValues(mod, n_chains = 1) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + + + # Test all individual parameters throw error if given prior that can't sample + # valid value + args <- list( + slope_sigma = prior_normal(-200, 1), + sigma = prior_normal(-200, 1) + ) + for (n_arg in names(args)) { + arg <- args[n_arg] + expect_error( + { + mod <- do.call(LongitudinalRandomSlope, arg) + initialValues(mod, n_chains = 1) + }, + regexp = "Unable to generate" + ) + } + + # Test initial values can be found for weird priors that do overlap the valid region + mod <- LongitudinalRandomSlope( + slope_sigma = prior_normal(-200, 400), + sigma = prior_normal(-200, 400) + ) + set.seed(1001) + vals <- unlist(initialValues(mod, n_chains = 200)) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + +}) diff --git a/tests/testthat/test-LongitudinalSteinFojo.R b/tests/testthat/test-LongitudinalSteinFojo.R index 03d42b09..17c9b21e 100644 --- a/tests/testthat/test-LongitudinalSteinFojo.R +++ b/tests/testthat/test-LongitudinalSteinFojo.R @@ -411,3 +411,52 @@ test_that("Quantity models pass the parser", { ) expect_stan_syntax(stanmod) }) + + + +test_that("Can generate valid initial values", { + + pars <- c( + "lm_gsf_omega_bsld", "lm_gsf_omega_ks", "lm_gsf_omega_kg", + "lm_gsf_sigma" + ) + + # Defaults work as expected + mod <- LongitudinalSteinFojo() + vals <- initialValues(mod, n_chains = 1) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + + + # Test all individual parameters throw error if given prior that can't sample + # valid value + args <- list( + omega_bsld = prior_normal(-200, 1), + omega_ks = prior_normal(-200, 1), + omega_kg = prior_normal(-200, 1), + sigma = prior_normal(-200, 1) + ) + for (n_arg in names(args)) { + arg <- args[n_arg] + expect_error( + { + mod <- do.call(LongitudinalSteinFojo, arg) + initialValues(mod, n_chains = 1) + }, + regexp = "Unable to generate" + ) + } + + # Test initial values can be found for weird priors that do overlap the valid region + mod <- LongitudinalSteinFojo( + omega_bsld = prior_normal(-200, 400), + omega_ks = prior_gamma(2, 5), + omega_kg = prior_uniform(-200, 400), + sigma = prior_cauchy(-200, 400) + ) + set.seed(1001) + vals <- unlist(initialValues(mod, n_chains = 200)) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + +}) diff --git a/tests/testthat/test-Prior.R b/tests/testthat/test-Prior.R index 3e434193..b48d3701 100644 --- a/tests/testthat/test-Prior.R +++ b/tests/testthat/test-Prior.R @@ -186,3 +186,58 @@ test_that("jmpost.prior_shrinkage works as expected", { local_rnorm = \(...) 4 ) }) + + + +test_that("Limits work as expected", { + x <- prior_normal(0, 1) + x <- set_limits(x, lower = 0, upper = 1) + ivs <- replicate( + n = 100, + initialValues(x) + ) + expect_true(all(ivs > 0)) + expect_true(all(ivs < 1)) + + + x <- prior_cauchy(-200, 150) + x <- set_limits(x, lower = 0) + ivs <- replicate( + n = 100, + initialValues(x) + ) + expect_true(all(ivs > 0)) + + + ## Put an impossible constraint on the distribution + x <- prior_lognormal(0, 1) + x <- set_limits(x, upper = 0) + expect_error(initialValues(x), regex = "Unable to generate") +}) + + + +test_that("median(Prior) works as expected", { + set.seed(2410) + + # Unrestricted + p1 <- prior_normal(-200, 400) + expect_equal( + median(p1), + -200, + tolerance = 0.15 + ) + + + # Constrained + p2 <- set_limits(p1, lower = 0) + + actual <- rnorm(6000, -200, 400) * 0.5 + -200 * 0.5 + actual_red <- actual[actual >= 0] + + expect_equal( + median(p2), + median(actual_red), + tolerance = 0.15 + ) +}) diff --git a/tests/testthat/test-SurvivalQuantities.R b/tests/testthat/test-SurvivalQuantities.R index d6618a10..e8943e42 100644 --- a/tests/testthat/test-SurvivalQuantities.R +++ b/tests/testthat/test-SurvivalQuantities.R @@ -1,8 +1,7 @@ - +test_data_1 <- ensure_test_data_1() test_that("SurvivalQuantities and autoplot.SurvivalQuantities works as expected", { - ensure_test_data_1() expected_column_names <- c("group", "time", "median", "lower", "upper") @@ -173,7 +172,6 @@ test_that("SurvivalQuantities print method works as expected", { test_that("SurvivalQuantities() works with time = 0", { - ensure_test_data_1() expect_error( { diff --git a/tests/testthat/test-brierScore.R b/tests/testthat/test-brierScore.R index ea67a31f..15dff624 100644 --- a/tests/testthat/test-brierScore.R +++ b/tests/testthat/test-brierScore.R @@ -1,4 +1,7 @@ +test_data_1 <- ensure_test_data_1() + + test_that("brierScore(SurvivalQuantities) returns same results as survreg", { # @@ -11,7 +14,6 @@ test_that("brierScore(SurvivalQuantities) returns same results as survreg", { # related to extracting the predicted values & the time & event data are working # as expected. # - ensure_test_data_1() dat_os <- test_data_1$dat_os mp <- test_data_1$jsamples diff --git a/tests/testthat/test-misc_models.R b/tests/testthat/test-misc_models.R index 9e1a2b10..ffcfd8f5 100644 --- a/tests/testthat/test-misc_models.R +++ b/tests/testthat/test-misc_models.R @@ -1,10 +1,10 @@ +test_data_1 <- ensure_test_data_1() test_that("Longitudinal Model doesn't print sampler rejection messages", { # These rejections typically happen when the sampler samples a # 0 value for the variance parameter. Sensible initial values + # setting near 0 limits (as opposed to 0) should avoid this - ensure_test_data_1() mp <- capture_messages({ devnull_out <- capture.output({ diff --git a/vignettes/model_fitting.Rmd b/vignettes/model_fitting.Rmd index c1d8f5df..b2ceafb5 100644 --- a/vignettes/model_fitting.Rmd +++ b/vignettes/model_fitting.Rmd @@ -462,3 +462,6 @@ in a vector of the same length as the number of covariates. That is, if you are using a `prior_cauchy(0, 1)` prior for a parameter that should be `>0` then you will need to manually set the initial value (as described above) to ensure that it is a valid value. +- For constrained parameters (e.g. variance parameters that must be $> 0$) `initialValues()` +will continuously sample and discard initial values until it generates one that meet the constraints. +If after 100 attempts no valid initial value has been found then it will throw an error.