From 4b9795191a62aedc142f10a217f4f0bc11c81196 Mon Sep 17 00:00:00 2001 From: gowerc Date: Thu, 27 Jun 2024 17:06:48 +0100 Subject: [PATCH 1/8] Initial --- NAMESPACE | 4 + R/LongitudinalClaretBruno.R | 16 ++- R/LongitudinalGSF.R | 18 ++- R/LongitudinalRandomSlope.R | 8 +- R/LongitudinalSteinFojo.R | 12 +- R/Prior.R | 103 ++++++++++++++---- R/generics.R | 15 +++ man/Prior-Shared.Rd | 2 +- man/Prior-class.Rd | 19 +++- man/set_limits.Rd | 20 ++++ tests/testthat/_snaps/JointModel.md | 6 +- .../_snaps/LongitudinalRandomSlope.md | 4 +- tests/testthat/helper-example_data.R | 20 ++-- tests/testthat/test-JointModelSamples.R | 2 +- tests/testthat/test-LongitudinalClaretBruno.R | 50 +++++++++ tests/testthat/test-LongitudinalGSF.R | 52 +++++++++ tests/testthat/test-LongitudinalQuantiles.R | 6 +- tests/testthat/test-LongitudinalRandomSlope.R | 45 ++++++++ tests/testthat/test-LongitudinalSteinFojo.R | 49 +++++++++ tests/testthat/test-Prior.R | 28 +++++ tests/testthat/test-SurvivalQuantities.R | 4 +- tests/testthat/test-brierScore.R | 4 +- tests/testthat/test-misc_models.R | 2 +- vignettes/model_fitting.Rmd | 3 + 24 files changed, 425 insertions(+), 67 deletions(-) create mode 100644 man/set_limits.Rd diff --git a/NAMESPACE b/NAMESPACE index 62ec805d..e1ad0cb2 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -140,6 +140,7 @@ S3method(linkTTG,LongitudinalGSF) S3method(linkTTG,LongitudinalSteinFojo) S3method(linkTTG,PromiseLongitudinalModel) S3method(linkTTG,default) +S3method(median,Prior) S3method(names,LinkComponent) S3method(names,Parameter) S3method(names,ParameterList) @@ -157,6 +158,7 @@ S3method(sampleSubjects,SimLongitudinalGSF) S3method(sampleSubjects,SimLongitudinalRandomSlope) S3method(sampleSubjects,SimLongitudinalSteinFojo) S3method(sampleSubjects,SimSurvival) +S3method(set_limits,Prior) S3method(size,Parameter) S3method(size,ParameterList) S3method(subset,DataJoint) @@ -247,6 +249,7 @@ export(resolvePromise) export(sampleObservations) export(sampleStanModel) 
export(sampleSubjects) +export(set_limits) export(show) export(write_stan) exportClasses(DataJoint) @@ -295,6 +298,7 @@ importFrom(stats,.checkMFClasses) importFrom(stats,acf) importFrom(stats,as.formula) importFrom(stats,delete.response) +importFrom(stats,median) importFrom(stats,model.frame) importFrom(stats,model.matrix) importFrom(stats,rbeta) diff --git a/R/LongitudinalClaretBruno.R b/R/LongitudinalClaretBruno.R index c171dae4..1d633fde 100755 --- a/R/LongitudinalClaretBruno.R +++ b/R/LongitudinalClaretBruno.R @@ -63,6 +63,14 @@ LongitudinalClaretBruno <- function( centred = centred )) + # Apply constraints + omega_b <- set_limits(omega_b, lower = 0) + omega_g <- set_limits(omega_g, lower = 0) + omega_c <- set_limits(omega_c, lower = 0) + omega_p <- set_limits(omega_p, lower = 0) + sigma <- set_limits(sigma, lower = 0) + + parameters <- list( Parameter(name = "lm_clbr_mu_b", prior = mu_b, size = "n_studies"), Parameter(name = "lm_clbr_mu_g", prior = mu_g, size = "n_arms"), @@ -82,22 +90,22 @@ LongitudinalClaretBruno <- function( list( Parameter( name = "lm_clbr_ind_b", - prior = prior_init_only(prior_lognormal(mu_b@init, omega_b@init)), + prior = prior_init_only(prior_lognormal(median(mu_b), median(omega_b))), size = "n_subjects" ), Parameter( name = "lm_clbr_ind_g", - prior = prior_init_only(prior_lognormal(mu_g@init, omega_g@init)), + prior = prior_init_only(prior_lognormal(median(mu_g), median(omega_g))), size = "n_subjects" ), Parameter( name = "lm_clbr_ind_c", - prior = prior_init_only(prior_lognormal(mu_c@init, omega_c@init)), + prior = prior_init_only(prior_lognormal(median(mu_c), median(omega_c))), size = "n_subjects" ), Parameter( name = "lm_clbr_ind_p", - prior = prior_init_only(prior_lognormal(mu_p@init, omega_p@init)), + prior = prior_init_only(prior_lognormal(median(mu_p), median(omega_p))), size = "n_subjects" ) ) diff --git a/R/LongitudinalGSF.R b/R/LongitudinalGSF.R index ced7fa0a..ea028929 100755 --- a/R/LongitudinalGSF.R +++ 
b/R/LongitudinalGSF.R @@ -67,6 +67,15 @@ LongitudinalGSF <- function( centred = centred )) + # Apply constraints + omega_bsld <- set_limits(omega_bsld, lower = 0) + omega_ks <- set_limits(omega_ks, lower = 0) + omega_kg <- set_limits(omega_kg, lower = 0) + a_phi <- set_limits(a_phi, lower = 0) + b_phi <- set_limits(b_phi, lower = 0) + sigma <- set_limits(sigma, lower = 0) + + parameters <- list( Parameter(name = "lm_gsf_mu_bsld", prior = mu_bsld, size = "n_studies"), Parameter(name = "lm_gsf_mu_ks", prior = mu_ks, size = "n_arms"), @@ -78,9 +87,10 @@ LongitudinalGSF <- function( Parameter(name = "lm_gsf_a_phi", prior = a_phi, size = "n_arms"), Parameter(name = "lm_gsf_b_phi", prior = b_phi, size = "n_arms"), + Parameter( name = "lm_gsf_psi_phi", - prior = prior_init_only(prior_beta(a_phi@init, b_phi@init)), + prior = prior_init_only(prior_beta(median(a_phi), median(b_phi))), size = "n_subjects" ), @@ -92,17 +102,17 @@ LongitudinalGSF <- function( list( Parameter( name = "lm_gsf_psi_bsld", - prior = prior_init_only(prior_lognormal(mu_bsld@init, omega_bsld@init)), + prior = prior_init_only(prior_lognormal(median(mu_bsld), median(omega_bsld))), size = "n_subjects" ), Parameter( name = "lm_gsf_psi_ks", - prior = prior_init_only(prior_lognormal(mu_ks@init, omega_ks@init)), + prior = prior_init_only(prior_lognormal(median(mu_ks), median(omega_ks))), size = "n_subjects" ), Parameter( name = "lm_gsf_psi_kg", - prior = prior_init_only(prior_lognormal(mu_kg@init, omega_kg@init)), + prior = prior_init_only(prior_lognormal(median(mu_kg), median(omega_kg))), size = "n_subjects" ) ) diff --git a/R/LongitudinalRandomSlope.R b/R/LongitudinalRandomSlope.R index 57b70a45..b431a661 100755 --- a/R/LongitudinalRandomSlope.R +++ b/R/LongitudinalRandomSlope.R @@ -31,7 +31,7 @@ NULL #' @export LongitudinalRandomSlope <- function( intercept = prior_normal(30, 10), - slope_mu = prior_normal(0, 15), + slope_mu = prior_normal(1, 3), slope_sigma = prior_lognormal(0, 1.5), sigma = 
prior_lognormal(0, 1.5) ) { @@ -40,6 +40,10 @@ LongitudinalRandomSlope <- function( x = "lm-random-slope/model.stan" ) + # Apply constraints + sigma <- set_limits(sigma, lower = 0) + slope_sigma <- set_limits(slope_sigma, lower = 0) + .LongitudinalRandomSlope( LongitudinalModel( name = "Random Slope", @@ -51,7 +55,7 @@ LongitudinalRandomSlope <- function( Parameter(name = "lm_rs_sigma", prior = sigma, size = 1), Parameter( name = "lm_rs_ind_rnd_slope", - prior = prior_init_only(prior_normal(slope_mu@init, slope_sigma@init)), + prior = prior_init_only(prior_normal(median(slope_mu), median(slope_sigma))), size = "n_subjects" ) ) diff --git a/R/LongitudinalSteinFojo.R b/R/LongitudinalSteinFojo.R index 34aa2919..73cf1cb5 100755 --- a/R/LongitudinalSteinFojo.R +++ b/R/LongitudinalSteinFojo.R @@ -59,6 +59,12 @@ LongitudinalSteinFojo <- function( centred = centred )) + # Apply constraints + omega_bsld <- set_limits(omega_bsld, lower = 0) + omega_ks <- set_limits(omega_ks, lower = 0) + omega_kg <- set_limits(omega_kg, lower = 0) + sigma <- set_limits(sigma, lower = 0) + parameters <- list( Parameter(name = "lm_sf_mu_bsld", prior = mu_bsld, size = "n_studies"), Parameter(name = "lm_sf_mu_ks", prior = mu_ks, size = "n_arms"), @@ -76,17 +82,17 @@ LongitudinalSteinFojo <- function( list( Parameter( name = "lm_sf_psi_bsld", - prior = prior_init_only(prior_lognormal(mu_bsld@init, omega_bsld@init)), + prior = prior_init_only(prior_lognormal(median(mu_bsld), median(omega_bsld))), size = "n_subjects" ), Parameter( name = "lm_sf_psi_ks", - prior = prior_init_only(prior_lognormal(mu_ks@init, omega_ks@init)), + prior = prior_init_only(prior_lognormal(median(mu_ks), median(omega_ks))), size = "n_subjects" ), Parameter( name = "lm_sf_psi_kg", - prior = prior_init_only(prior_lognormal(mu_kg@init, omega_kg@init)), + prior = prior_init_only(prior_lognormal(median(mu_kg), median(omega_kg))), size = "n_subjects" ) ) diff --git a/R/Prior.R b/R/Prior.R index c08ee946..2685e3a0 100755 --- 
a/R/Prior.R +++ b/R/Prior.R @@ -7,7 +7,7 @@ NULL #' The documentation lists all the conventional arguments for [`Prior`] #' constructors. #' -#' @param init (`number`)\cr initial value. +#' @param centre (`number`)\cr the central point of distribution to shrink sampled values towards #' @param x ([`Prior`])\cr a prior Distribution #' @param object ([`Prior`])\cr a prior Distribution #' @param name (`character`)\cr the name of the parameter the prior distribution is for @@ -26,10 +26,11 @@ NULL #' @slot parameters (`list`)\cr See arguments. #' @slot repr_model (`string`)\cr See arguments. #' @slot repr_data (`string`)\cr See arguments. -#' @slot init (`numeric`)\cr See arguments. +#' @slot centre (`numeric`)\cr See arguments. #' @slot validation (`list`)\cr See arguments. #' @slot display (`string`)\cr See arguments. #' @slot sample (`function`)\cr See arguments. +#' @slot limits (`numeric`)\cr See arguments. #' #' @family Prior-internal #' @export Prior @@ -41,9 +42,10 @@ NULL "display" = "character", "repr_model" = "character", "repr_data" = "character", - "init" = "numeric", + "centre" = "numeric", "validation" = "list", - "sample" = "function" + "sample" = "function", + "limits" = "numeric" ) ) @@ -52,28 +54,31 @@ NULL #' @param repr_model (`string`)\cr the Stan code representation for the model block. #' @param repr_data (`string`)\cr the Stan code representation for the data block. #' @param display (`string`)\cr the string to display when object is printed. -#' @param init (`numeric`)\cr the initial value. +#' @param centre (`numeric`)\cr the central point of distribution to shrink sampled values towards #' @param validation (`list`)\cr the prior distribution parameter validation functions. Must have #' the same names as the `paramaters` slot. #' @param sample (`function`)\cr a function to sample from the prior distribution. 
+#' @param limits (`numeric`)\cr the lower and upper limits, as `c(lower, upper)`, that generated initial values must lie within #' @rdname Prior-class Prior <- function( parameters, display, repr_model, repr_data, - init, + centre, validation, - sample + sample, + limits = c(-Inf, Inf) ) { .Prior( parameters = parameters, repr_model = repr_model, repr_data = repr_data, - init = init, + centre = centre, display = display, validation = validation, - sample = sample + sample = sample, + limits = limits ) } @@ -99,6 +104,15 @@ setValidity( ) + +#' @rdname set_limits +#' @export +set_limits.Prior <- function(object, lower = -Inf, upper = Inf) { + object@limits <- c(lower, upper) + return(object) +} + + #' `Prior` -> `Character` #' #' Converts a [`Prior`] object to a character vector @@ -188,8 +202,18 @@ NULL #' @describeIn Prior-Getter-Methods The prior's initial value #' @export initialValues.Prior <- function(object, ...) { - getOption("jmpost.prior_shrinkage") * object@init + - (1 - getOption("jmpost.prior_shrinkage")) * object@sample(1) + samples <- getOption("jmpost.prior_shrinkage") * object@centre + + (1 - getOption("jmpost.prior_shrinkage")) * object@sample(100) + + valid_samples <- samples[samples >= min(object@limits) & samples <= max(object@limits)] + assert_that( + length(valid_samples) >= 1, + msg = "Unable to generate an initial value that meets the required constraints" + ) + if (length(valid_samples) == 1) { + return(valid_samples) + } + return(sample(valid_samples, 1)) } @@ -210,7 +234,7 @@ prior_normal <- function(mu, sigma) { "real prior_mu_{name};", "real prior_sigma_{name};" ), - init = mu, + centre = mu, sample = \(n) local_rnorm(n, mu, sigma), validation = list( mu = is.numeric, @@ -231,7 +255,7 @@ prior_std_normal <- function() { display = "std_normal()", repr_model = "{name} ~ std_normal();", repr_data = "", - init = 0, + centre = 0, sample = \(n) local_rnorm(n), validation = list() ) @@ -253,7 +277,7 @@ prior_cauchy <- function(mu, sigma) { "real prior_mu_{name};", "real prior_sigma_{name};" ), - init = mu, + centre = mu, sample = 
\(n) local_rcauchy(n, mu, sigma), validation = list( mu = is.numeric, @@ -262,6 +286,7 @@ prior_cauchy <- function(mu, sigma) { ) } + #' Gamma Prior Distribution #' #' @param alpha (`number`)\cr shape. @@ -278,7 +303,7 @@ prior_gamma <- function(alpha, beta) { "real prior_alpha_{name};", "real prior_beta_{name};" ), - init = alpha / beta, + centre = alpha / beta, sample = \(n) local_rgamma(n, shape = alpha, rate = beta), validation = list( alpha = \(x) x > 0, @@ -303,7 +328,7 @@ prior_lognormal <- function(mu, sigma) { "real prior_mu_{name};", "real prior_sigma_{name};" ), - init = exp(mu + (sigma^2) / 2), + centre = exp(mu + (sigma^2) / 2), sample = \(n) local_rlnorm(n, mu, sigma), validation = list( mu = is.numeric, @@ -328,7 +353,7 @@ prior_beta <- function(a, b) { "real prior_a_{name};", "real prior_b_{name};" ), - init = a / (a + b), + centre = a / (a + b), sample = \(n) local_rbeta(n, a, b), validation = list( a = \(x) x > 0, @@ -356,7 +381,7 @@ prior_init_only <- function(dist) { sample = \(n) { dist@sample(n) }, - init = dist@init, + centre = dist@centre, validation = list() ) } @@ -384,7 +409,7 @@ prior_uniform <- function(alpha, beta) { "real prior_alpha_{name};", "real prior_beta_{name};" ), - init = 0.5 * (alpha + beta), + centre = 0.5 * (alpha + beta), sample = \(n) local_runif(n, alpha, beta), validation = list( alpha = is.numeric, @@ -416,7 +441,7 @@ prior_student_t <- function(nu, mu, sigma) { "real prior_mu_{name};", "real prior_sigma_{name};" ), - init = mu, + centre = mu, sample = \(n) local_rt(n, nu, mu, sigma), validation = list( nu = \(x) x > 0, @@ -447,7 +472,7 @@ prior_logistic <- function(mu, sigma) { "real prior_mu_{name};", "real prior_sigma_{name};" ), - init = mu, + centre = mu, sample = \(n) local_rlogis(n, mu, sigma), validation = list( mu = is.numeric, @@ -476,7 +501,7 @@ prior_loglogistic <- function(alpha, beta) { "real prior_alpha_{name};", "real prior_beta_{name};" ), - init = alpha * pi / (beta * sin(pi / beta)), + centre = 
alpha * pi / (beta * sin(pi / beta)), sample = \(n) { local_rloglogis(n, alpha, beta) }, @@ -507,7 +532,7 @@ prior_invgamma <- function(alpha, beta) { "real prior_alpha_{name};", "real prior_beta_{name};" ), - init = beta / (alpha - 1), + centre = beta / (alpha - 1), sample = \(n) local_rinvgamma(n, alpha, beta), validation = list( alpha = \(x) x > 0, @@ -517,6 +542,38 @@ prior_invgamma <- function(alpha, beta) { } +# nolint start +# +# Developer Notes +# +# The `median.Prior` function is a rough workaround to help generate initial values for +# hierarchical distributions. The original implementation involved sampling initial values +# for the random effects using the medians of the parent distribution e.g. +# ``` +# random_effect ~ beta(a_prior@centre, b_prior@centre) +# ``` +# A problem came up though when we implemented support for constrained distributions +# as there was no longer any guarantee that the median/centre of the distribution is +# a valid value e.g. `a_prior ~ prior_normal(-200, 400)`. +# +# To resolve this issue the `median.Prior` method was created which simply samples +# multiple observations from the constrained distribution and then takes the median +# of those constrained observations; this then ensures that the value being used +# for the parameters is a valid value +# +# nolint end +#' @importFrom stats median +#' @export +median.Prior <- function(object) { + x <- replicate( + n = 250, + initialValues(object), + simplify = FALSE + ) |> + unlist() + median(x) +} + diff --git a/R/generics.R b/R/generics.R index cb5082ed..8f40c9b5 100755 --- a/R/generics.R +++ b/R/generics.R @@ -470,3 +470,18 @@ as_formula <- function(x, ...) { as_formula.default <- function(x, ...) { as.formula(x, ...) 
} + + +#' Set Constraints +#' +#' Applies constraints to a prior distribution to ensure any sampled numbers +#' from the distribution fall within the constraints +#' +#' @param object (`Prior`)\cr a prior distribution to apply constraints to +#' @param lower (`numeric`)\cr lower constraint boundary +#' @param upper (`numeric`)\cr upper constraint boundary +#' +#' @export +set_limits <- function(object, lower = -Inf, upper = Inf) { + UseMethod("set_limits") +} diff --git a/man/Prior-Shared.Rd b/man/Prior-Shared.Rd index 2290370e..218ae8e8 100644 --- a/man/Prior-Shared.Rd +++ b/man/Prior-Shared.Rd @@ -4,7 +4,7 @@ \alias{Prior-Shared} \title{\code{Prior} Function Arguments} \arguments{ -\item{init}{(\code{number})\cr initial value.} +\item{centre}{(\code{number})\cr the central point of distribution to shrink sampled values towards} \item{x}{(\code{\link{Prior}})\cr a prior Distribution} diff --git a/man/Prior-class.Rd b/man/Prior-class.Rd index d3e8a522..f06151b6 100644 --- a/man/Prior-class.Rd +++ b/man/Prior-class.Rd @@ -7,7 +7,16 @@ \alias{Prior} \title{Prior Object and Constructor Function} \usage{ -Prior(parameters, display, repr_model, repr_data, init, validation, sample) +Prior( + parameters, + display, + repr_model, + repr_data, + centre, + validation, + sample, + limits = c(-Inf, Inf) +) } \arguments{ \item{parameters}{(\code{list})\cr the prior distribution parameters.} @@ -18,12 +27,14 @@ Prior(parameters, display, repr_model, repr_data, init, validation, sample) \item{repr_data}{(\code{string})\cr the Stan code representation for the data block.} -\item{init}{(\code{numeric})\cr the initial value.} +\item{centre}{(\code{numeric})\cr the central point of distribution to shrink sampled values towards} \item{validation}{(\code{list})\cr the prior distribution parameter validation functions. 
Must have the same names as the \code{paramaters} slot.} \item{sample}{(\code{function})\cr a function to sample from the prior distribution.} + +\item{limits}{(\code{numeric})\cr TODO} } \description{ Specifies the prior distribution in a Stan Model @@ -37,13 +48,15 @@ Specifies the prior distribution in a Stan Model \item{\code{repr_data}}{(\code{string})\cr See arguments.} -\item{\code{init}}{(\code{numeric})\cr See arguments.} +\item{\code{centre}}{(\code{numeric})\cr See arguments.} \item{\code{validation}}{(\code{list})\cr See arguments.} \item{\code{display}}{(\code{string})\cr See arguments.} \item{\code{sample}}{(\code{function})\cr See arguments.} + +\item{\code{limits}}{(\code{numeric})\cr See arguments.} }} \seealso{ diff --git a/man/set_limits.Rd b/man/set_limits.Rd new file mode 100644 index 00000000..7cbdc6bd --- /dev/null +++ b/man/set_limits.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/generics.R, R/Prior.R +\name{set_limits} +\alias{set_limits} +\alias{set_limits.Prior} +\title{Set Constraints} +\usage{ +set_limits(object, lower = -Inf, upper = Inf) + +\method{set_limits}{Prior}(object, lower = -Inf, upper = Inf) +} +\arguments{ +\item{object}{(\code{Prior})\cr a prior distribution to apply constraints to} + +\item{lower}{(\code{numeric})\cr upper constraint boundary} +} +\description{ +Applies constraints to a prior distribution to ensure any sampled numbers +from the distribution fall within the constraints +} diff --git a/tests/testthat/_snaps/JointModel.md b/tests/testthat/_snaps/JointModel.md index 79135e32..780df2a1 100644 --- a/tests/testthat/_snaps/JointModel.md +++ b/tests/testthat/_snaps/JointModel.md @@ -17,7 +17,7 @@ Longitudinal: Random Slope Longitudinal Model with parameters: lm_rs_intercept ~ normal(mu = 30, sigma = 10) - lm_rs_slope_mu ~ normal(mu = 0, sigma = 15) + lm_rs_slope_mu ~ normal(mu = 1, sigma = 3) lm_rs_slope_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_sigma ~ 
lognormal(mu = 0, sigma = 1.5) lm_rs_ind_rnd_slope ~ @@ -45,7 +45,7 @@ Longitudinal: Random Slope Longitudinal Model with parameters: lm_rs_intercept ~ normal(mu = 30, sigma = 10) - lm_rs_slope_mu ~ normal(mu = 0, sigma = 15) + lm_rs_slope_mu ~ normal(mu = 1, sigma = 3) lm_rs_slope_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_ind_rnd_slope ~ @@ -126,7 +126,7 @@ Longitudinal: Random Slope Longitudinal Model with parameters: lm_rs_intercept ~ normal(mu = 30, sigma = 10) - lm_rs_slope_mu ~ normal(mu = 0, sigma = 15) + lm_rs_slope_mu ~ normal(mu = 1, sigma = 3) lm_rs_slope_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_ind_rnd_slope ~ diff --git a/tests/testthat/_snaps/LongitudinalRandomSlope.md b/tests/testthat/_snaps/LongitudinalRandomSlope.md index 49563b6e..997081e9 100644 --- a/tests/testthat/_snaps/LongitudinalRandomSlope.md +++ b/tests/testthat/_snaps/LongitudinalRandomSlope.md @@ -7,7 +7,7 @@ Random Slope Longitudinal Model with parameters: lm_rs_intercept ~ normal(mu = 30, sigma = 10) - lm_rs_slope_mu ~ normal(mu = 0, sigma = 15) + lm_rs_slope_mu ~ normal(mu = 1, sigma = 3) lm_rs_slope_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_ind_rnd_slope ~ @@ -23,7 +23,7 @@ Random Slope Longitudinal Model with parameters: lm_rs_intercept ~ normal(mu = 0, sigma = 1) - lm_rs_slope_mu ~ normal(mu = 0, sigma = 15) + lm_rs_slope_mu ~ normal(mu = 1, sigma = 3) lm_rs_slope_sigma ~ lognormal(mu = 0, sigma = 1.5) lm_rs_sigma ~ gamma(alpha = 2, beta = 1) lm_rs_ind_rnd_slope ~ diff --git a/tests/testthat/helper-example_data.R b/tests/testthat/helper-example_data.R index d87ac619..3b28bcb8 100644 --- a/tests/testthat/helper-example_data.R +++ b/tests/testthat/helper-example_data.R @@ -1,12 +1,6 @@ -test_data_1 <- new.env() - ensure_test_data_1 <- function() { - if (!is.null(test_data_1$jsamples)) { - return(invisible(test_data_1)) - } - set.seed(739) 
simjdat <- SimJointData( design = list( @@ -74,10 +68,12 @@ ensure_test_data_1 <- function() { }) - test_data_1$dat_os <- dat_os - test_data_1$dat_lm <- dat_lm - test_data_1$jmodel <- jm - test_data_1$jdata <- jdat - test_data_1$jsamples <- mp - return(invisible(test_data_1)) + results <- list( + dat_os = dat_os, + dat_lm = dat_lm, + jmodel = jm, + jdata = jdat, + jsamples = mp + ) + return(invisible(results)) } diff --git a/tests/testthat/test-JointModelSamples.R b/tests/testthat/test-JointModelSamples.R index ff183e41..886d2100 100644 --- a/tests/testthat/test-JointModelSamples.R +++ b/tests/testthat/test-JointModelSamples.R @@ -1,7 +1,7 @@ +test_data_1 <- ensure_test_data_1() test_that("print works as expected for JointModelSamples", { - ensure_test_data_1() expect_snapshot({ print(test_data_1$jsamples) }) diff --git a/tests/testthat/test-LongitudinalClaretBruno.R b/tests/testthat/test-LongitudinalClaretBruno.R index de2b35b8..0241a863 100644 --- a/tests/testthat/test-LongitudinalClaretBruno.R +++ b/tests/testthat/test-LongitudinalClaretBruno.R @@ -295,3 +295,53 @@ test_that("Quantity models pass the parser", { ) expect_stan_syntax(stanmod) }) + + +test_that("Can generate valid initial values", { + + pars <- c( + "lm_clbr_omega_b", "lm_clbr_omega_g", "lm_clbr_omega_c", + "lm_clbr_omega_p", "lm_clbr_sigma", "lm_gsf_sigma" + ) + + # Defaults work as expected + mod <- LongitudinalClaretBruno() + vals <- initialValues(mod, n_chains = 1) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + + + # Test all individual parameters throw error if given prior that can't sample + # valid value + args <- list( + omega_b = prior_normal(-200, 1), + omega_g = prior_normal(-200, 1), + omega_c = prior_normal(-200, 1), + omega_p = prior_normal(-200, 1), + sigma = prior_normal(-200, 1) + ) + for (n_arg in names(args)) { + arg <- args[n_arg] + expect_error( + { + mod <- do.call(LongitudinalClaretBruno, arg) + initialValues(mod, n_chains = 1) + }, + regexp = "Unable to 
generate" + ) + } + + # Test initial values can be found for weird priors that do overlap the valid region + mod <- LongitudinalClaretBruno( + omega_b = prior_normal(-200, 400), + omega_g = prior_gamma(2, 5), + omega_c = prior_uniform(-200, 400), + omega_p = prior_lognormal(-200, 2), + sigma = prior_cauchy(-200, 400) + ) + set.seed(1001) + vals <- unlist(initialValues(mod, n_chains = 200)) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + +}) diff --git a/tests/testthat/test-LongitudinalGSF.R b/tests/testthat/test-LongitudinalGSF.R index 0e2f0553..392aaa89 100644 --- a/tests/testthat/test-LongitudinalGSF.R +++ b/tests/testthat/test-LongitudinalGSF.R @@ -220,3 +220,55 @@ test_that("Quantity models pass the parser", { ) expect_stan_syntax(stanmod) }) + + +test_that("Can generate valid initial values", { + + pars <- c( + "lm_gsf_omega_bsld", "lm_gsf_omega_ks", "lm_gsf_omega_kg", + "lm_gsf_a_phi", "lm_gsf_b_phi", "lm_gsf_sigma" + ) + + # Defaults work as expected + mod <- LongitudinalGSF() + vals <- initialValues(mod, n_chains = 1) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + + + # Test all individual parameters throw error if given prior that can't sample + # valid value + args <- list( + omega_bsld = prior_normal(-200, 1), + omega_ks = prior_normal(-200, 1), + omega_kg = prior_normal(-200, 1), + a_phi = prior_normal(-200, 1), + b_phi = prior_normal(-200, 1), + sigma = prior_normal(-200, 1) + ) + for (n_arg in names(args)) { + arg <- args[n_arg] + expect_error( + { + mod <- do.call(LongitudinalGSF, arg) + initialValues(mod, n_chains = 1) + }, + regexp = "Unable to generate" + ) + } + + # Test initial values can be found for weird priors that do overlap the valid region + mod <- LongitudinalGSF( + omega_bsld = prior_normal(-200, 400), + omega_ks = prior_gamma(2, 5), + omega_kg = prior_uniform(-200, 400), + a_phi = prior_lognormal(-200, 2), + b_phi = prior_cauchy(-200, 400), + sigma = prior_cauchy(-200, 400) + ) + 
set.seed(1001) + vals <- unlist(initialValues(mod, n_chains = 200)) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + +}) diff --git a/tests/testthat/test-LongitudinalQuantiles.R b/tests/testthat/test-LongitudinalQuantiles.R index b2af0755..2b744e20 100644 --- a/tests/testthat/test-LongitudinalQuantiles.R +++ b/tests/testthat/test-LongitudinalQuantiles.R @@ -1,7 +1,7 @@ +test_data_1 <- ensure_test_data_1() test_that("Test that LongitudinalQuantities works as expected", { - ensure_test_data_1() expected_column_names <- c("group", "time", "median", "lower", "upper") @@ -82,7 +82,7 @@ test_that("autoplot.LongitudinalQuantities works as expected", { test_that("LongitudinalQuantities print method works as expected", { - ensure_test_data_1() + expect_snapshot({ subjectgroups <- c("subject_0011", "subject_0061", "subject_0001", "subject_0002") times <- seq(0, 100, by = 10) @@ -118,7 +118,6 @@ test_that("LongitudinalQuantities print method works as expected", { test_that("LongitudinalQuantities can recover known results", { set.seed(101) - ensure_test_data_1() longsamps <- LongitudinalQuantities( test_data_1$jsamples, grid = GridFixed( @@ -141,7 +140,6 @@ test_that("LongitudinalQuantities can recover known results", { test_that("LongitudinalQuantities correctly subsets subjects and rebuilds correct value for each sample", { set.seed(101) - ensure_test_data_1() times <- c(-100, 0, 1, 100, 200) longsamps <- LongitudinalQuantities( diff --git a/tests/testthat/test-LongitudinalRandomSlope.R b/tests/testthat/test-LongitudinalRandomSlope.R index 4afbb6ff..015ccbfd 100644 --- a/tests/testthat/test-LongitudinalRandomSlope.R +++ b/tests/testthat/test-LongitudinalRandomSlope.R @@ -226,6 +226,7 @@ test_that("Random Slope Model left-censoring works as expected", { jm <- JointModel( longitudinal = LongitudinalRandomSlope( intercept = prior_normal(30, 2), + slope_mu = prior_normal(1, 2), slope_sigma = prior_lognormal(log(0.2), sigma = 0.5), sigma = 
prior_lognormal(log(3), sigma = 0.5) ) @@ -317,3 +318,47 @@ test_that("Quantity models pass the parser", { ) expect_stan_syntax(stanmod) }) + + + +test_that("Can generate valid initial values", { + + pars <- c( + "lm_rs_slope_sigma", "lm_rs_sigma" + ) + + # Defaults work as expected + mod <- LongitudinalRandomSlope() + vals <- initialValues(mod, n_chains = 1) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + + + # Test all individual parameters throw error if given prior that can't sample + # valid value + args <- list( + slope_sigma = prior_normal(-200, 1), + sigma = prior_normal(-200, 1) + ) + for (n_arg in names(args)) { + arg <- args[n_arg] + expect_error( + { + mod <- do.call(LongitudinalRandomSlope, arg) + initialValues(mod, n_chains = 1) + }, + regexp = "Unable to generate" + ) + } + + # Test initial values can be found for weird priors that do overlap the valid region + mod <- LongitudinalRandomSlope( + slope_sigma = prior_normal(-200, 400), + sigma = prior_normal(-200, 400) + ) + set.seed(1001) + vals <- unlist(initialValues(mod, n_chains = 200)) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + +}) diff --git a/tests/testthat/test-LongitudinalSteinFojo.R b/tests/testthat/test-LongitudinalSteinFojo.R index 03d42b09..17c9b21e 100644 --- a/tests/testthat/test-LongitudinalSteinFojo.R +++ b/tests/testthat/test-LongitudinalSteinFojo.R @@ -411,3 +411,52 @@ test_that("Quantity models pass the parser", { ) expect_stan_syntax(stanmod) }) + + + +test_that("Can generate valid initial values", { + + pars <- c( + "lm_gsf_omega_bsld", "lm_gsf_omega_ks", "lm_gsf_omega_kg", + "lm_gsf_sigma" + ) + + # Defaults work as expected + mod <- LongitudinalSteinFojo() + vals <- initialValues(mod, n_chains = 1) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + + + # Test all individual parameters throw error if given prior that can't sample + # valid value + args <- list( + omega_bsld = prior_normal(-200, 1), + omega_ks = 
prior_normal(-200, 1), + omega_kg = prior_normal(-200, 1), + sigma = prior_normal(-200, 1) + ) + for (n_arg in names(args)) { + arg <- args[n_arg] + expect_error( + { + mod <- do.call(LongitudinalSteinFojo, arg) + initialValues(mod, n_chains = 1) + }, + regexp = "Unable to generate" + ) + } + + # Test initial values can be found for weird priors that do overlap the valid region + mod <- LongitudinalSteinFojo( + omega_bsld = prior_normal(-200, 400), + omega_ks = prior_gamma(2, 5), + omega_kg = prior_uniform(-200, 400), + sigma = prior_cauchy(-200, 400) + ) + set.seed(1001) + vals <- unlist(initialValues(mod, n_chains = 200)) + vals <- vals[names(vals) %in% pars] + expect_true(all(vals > 0)) + +}) diff --git a/tests/testthat/test-Prior.R b/tests/testthat/test-Prior.R index 3e434193..3a687ade 100644 --- a/tests/testthat/test-Prior.R +++ b/tests/testthat/test-Prior.R @@ -186,3 +186,31 @@ test_that("jmpost.prior_shrinkage works as expected", { local_rnorm = \(...) 4 ) }) + + + +test_that("Limits work as expected", { + x <- prior_normal(0, 1) + x <- set_limits(x, lower = 0, upper = 1) + ivs <- replicate( + n = 100, + initialValues(x) + ) + expect_true(all(ivs > 0)) + expect_true(all(ivs < 1)) + + + x <- prior_cauchy(-200, 150) + x <- set_limits(x, lower = 0) + ivs <- replicate( + n = 100, + initialValues(x) + ) + expect_true(all(ivs > 0)) + + + ## Put an impossible constraint on the distribution + x <- prior_lognormal(0, 1) + x <- set_limits(x, upper = 0) + expect_error(initialValues(x), regex = "Unable to generate") +}) diff --git a/tests/testthat/test-SurvivalQuantities.R b/tests/testthat/test-SurvivalQuantities.R index d6618a10..e8943e42 100644 --- a/tests/testthat/test-SurvivalQuantities.R +++ b/tests/testthat/test-SurvivalQuantities.R @@ -1,8 +1,7 @@ - +test_data_1 <- ensure_test_data_1() test_that("SurvivalQuantities and autoplot.SurvivalQuantities works as expected", { - ensure_test_data_1() expected_column_names <- c("group", "time", "median", "lower", "upper") 
@@ -173,7 +172,6 @@ test_that("SurvivalQuantities print method works as expected", { test_that("SurvivalQuantities() works with time = 0", { - ensure_test_data_1() expect_error( { diff --git a/tests/testthat/test-brierScore.R b/tests/testthat/test-brierScore.R index ea67a31f..15dff624 100644 --- a/tests/testthat/test-brierScore.R +++ b/tests/testthat/test-brierScore.R @@ -1,4 +1,7 @@ +test_data_1 <- ensure_test_data_1() + + test_that("brierScore(SurvivalQuantities) returns same results as survreg", { # @@ -11,7 +14,6 @@ test_that("brierScore(SurvivalQuantities) returns same results as survreg", { # related to extracting the predicted values & the time & event data are working # as expected. # - ensure_test_data_1() dat_os <- test_data_1$dat_os mp <- test_data_1$jsamples diff --git a/tests/testthat/test-misc_models.R b/tests/testthat/test-misc_models.R index 9e1a2b10..ffcfd8f5 100644 --- a/tests/testthat/test-misc_models.R +++ b/tests/testthat/test-misc_models.R @@ -1,10 +1,10 @@ +test_data_1 <- ensure_test_data_1() test_that("Longitudinal Model doesn't print sampler rejection messages", { # These rejections typically happen when the sampler samples a # 0 value for the variance parameter. Sensible initial values + # setting near 0 limits (as opposed to 0) should avoid this - ensure_test_data_1() mp <- capture_messages({ devnull_out <- capture.output({ diff --git a/vignettes/model_fitting.Rmd b/vignettes/model_fitting.Rmd index c1d8f5df..569cc663 100644 --- a/vignettes/model_fitting.Rmd +++ b/vignettes/model_fitting.Rmd @@ -462,3 +462,6 @@ in a vector of the same length as the number of covariates. That is, if you are using a `prior_cauchy(0, 1)` prior for a parameter that should be `>0` then you will need to manually set the initial value (as described above) to ensure that it is a valid value. +- For constrained parameters (e.g. 
variance parameters that must be $> 0$) `initialValues()` +will continously sample and discard initial values until it generates one that meet the constraints. +If after 100 attempts no valid initial value has been found then it will throw an error. From cdfab28cd945aef4bce3cb81d31c244ee146852d Mon Sep 17 00:00:00 2001 From: github-actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 27 Jun 2024 16:15:46 +0000 Subject: [PATCH 2/8] [skip roxygen] [skip vbump] Roxygen Man Pages Auto Update --- man/LongitudinalRandomSlope-class.Rd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/man/LongitudinalRandomSlope-class.Rd b/man/LongitudinalRandomSlope-class.Rd index 915e1406..a40277a2 100644 --- a/man/LongitudinalRandomSlope-class.Rd +++ b/man/LongitudinalRandomSlope-class.Rd @@ -9,7 +9,7 @@ \usage{ LongitudinalRandomSlope( intercept = prior_normal(30, 10), - slope_mu = prior_normal(0, 15), + slope_mu = prior_normal(1, 3), slope_sigma = prior_lognormal(0, 1.5), sigma = prior_lognormal(0, 1.5) ) From 1615e2c3726197907cd6b0445963c3675bcc64e5 Mon Sep 17 00:00:00 2001 From: gowerc Date: Thu, 27 Jun 2024 17:16:18 +0100 Subject: [PATCH 3/8] fix spelling --- vignettes/model_fitting.Rmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vignettes/model_fitting.Rmd b/vignettes/model_fitting.Rmd index 569cc663..b2ceafb5 100644 --- a/vignettes/model_fitting.Rmd +++ b/vignettes/model_fitting.Rmd @@ -463,5 +463,5 @@ That is, if you are using a `prior_cauchy(0, 1)` prior for a parameter that shou then you will need to manually set the initial value (as described above) to ensure that it is a valid value. - For constrained parameters (e.g. variance parameters that must be $> 0$) `initialValues()` -will continously sample and discard initial values until it generates one that meet the constraints. +will continuously sample and discard initial values until it generates one that meets the constraints.
If after 100 attempts no valid initial value has been found then it will throw an error. From a6bb0b8fd39c25b8477a6c25eb631fb134c074a8 Mon Sep 17 00:00:00 2001 From: gowerc Date: Thu, 27 Jun 2024 17:17:37 +0100 Subject: [PATCH 4/8] fix pkgdown --- _pkgdown.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/_pkgdown.yml b/_pkgdown.yml index f22443b4..af180ac5 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -174,6 +174,7 @@ reference: - enableGQ - as_formula - getPredictionNames + - set_limits - title: Promises contents: From dc53a1fa25d5d9a987490b5fb427b426f8f86c65 Mon Sep 17 00:00:00 2001 From: gowerc Date: Thu, 27 Jun 2024 19:16:37 +0100 Subject: [PATCH 5/8] Fix rcmdcheck --- R/Prior.R | 8 ++++---- R/generics.R | 2 +- man/set_limits.Rd | 4 +++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/R/Prior.R b/R/Prior.R index 2685e3a0..ca73698a 100755 --- a/R/Prior.R +++ b/R/Prior.R @@ -564,14 +564,14 @@ prior_invgamma <- function(alpha, beta) { # nolint end #' @importFrom stats median #' @export -median.Prior <- function(object) { - x <- replicate( +median.Prior <- function(x, na.rm, ...) { + vals <- replicate( n = 250, - initialValues(object), + initialValues(x), simplify = FALSE ) |> unlist() - median(x) + median(vals) } diff --git a/R/generics.R b/R/generics.R index 8f40c9b5..56879d0e 100755 --- a/R/generics.R +++ b/R/generics.R @@ -479,7 +479,7 @@ as_formula.default <- function(x, ...) 
{ #' #' @param object (`Prior`)\cr a prior distribution to apply constraints to #' @param lower (`numeric`)\cr lower constraint boundary -#' @param lower (`numeric`)\cr upper constraint boundary +#' @param upper (`numeric`)\cr upper constraint boundary #' #' @export set_limits <- function(object, lower = -Inf, upper = Inf) { diff --git a/man/set_limits.Rd b/man/set_limits.Rd index 7cbdc6bd..ce05f7bd 100644 --- a/man/set_limits.Rd +++ b/man/set_limits.Rd @@ -12,7 +12,9 @@ set_limits(object, lower = -Inf, upper = Inf) \arguments{ \item{object}{(\code{Prior})\cr a prior distribution to apply constraints to} -\item{lower}{(\code{numeric})\cr upper constraint boundary} +\item{lower}{(\code{numeric})\cr lower constraint boundary} + +\item{upper}{(\code{numeric})\cr upper constraint boundary} } \description{ Applies constraints to a prior distribution to ensure any sampled numbers From 709161b5093531f39ec6f5391ebafe1f038b3873 Mon Sep 17 00:00:00 2001 From: gowerc Date: Fri, 28 Jun 2024 13:09:14 +0100 Subject: [PATCH 6/8] add test for median(Prior) --- R/Prior.R | 2 +- tests/testthat/test-Prior.R | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/R/Prior.R b/R/Prior.R index ca73698a..b32f61f3 100755 --- a/R/Prior.R +++ b/R/Prior.R @@ -566,7 +566,7 @@ prior_invgamma <- function(alpha, beta) { #' @export median.Prior <- function(x, na.rm, ...) 
{ vals <- replicate( - n = 250, + n = 500, initialValues(x), simplify = FALSE ) |> diff --git a/tests/testthat/test-Prior.R b/tests/testthat/test-Prior.R index 3a687ade..b48d3701 100644 --- a/tests/testthat/test-Prior.R +++ b/tests/testthat/test-Prior.R @@ -214,3 +214,30 @@ test_that("Limits work as expected", { x <- set_limits(x, upper = 0) expect_error(initialValues(x), regex = "Unable to generate") }) + + + +test_that("median(Prior) works as expected", { + set.seed(2410) + + # Unrestricted + p1 <- prior_normal(-200, 400) + expect_equal( + median(p1), + -200, + tolerance = 0.15 + ) + + + # Constrained + p2 <- set_limits(p1, lower = 0) + + actual <- rnorm(6000, -200, 400) * 0.5 + -200 * 0.5 + actual_red <- actual[actual >= 0] + + expect_equal( + median(p2), + median(actual_red), + tolerance = 0.15 + ) +}) From a9c9f1abc74edff453b5d08fbb34e17355190ce4 Mon Sep 17 00:00:00 2001 From: gowerc Date: Mon, 1 Jul 2024 16:31:06 +0100 Subject: [PATCH 7/8] updated docs --- R/Prior.R | 3 ++- man/Prior-Shared.Rd | 3 ++- man/Prior-class.Rd | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/R/Prior.R b/R/Prior.R index b32f61f3..a978a054 100755 --- a/R/Prior.R +++ b/R/Prior.R @@ -8,6 +8,7 @@ NULL #' constructors. #' #' @param centre (`number`)\cr the central point of distribution to shrink sampled values towards +#' (for most distributions this is the mean or median if the mean is undefined) #' @param x ([`Prior`])\cr a prior Distribution #' @param object ([`Prior`])\cr a prior Distribution #' @param name (`character`)\cr the name of the parameter the prior distribution is for @@ -58,7 +59,7 @@ NULL #' @param validation (`list`)\cr the prior distribution parameter validation functions. Must have #' the same names as the `paramaters` slot. #' @param sample (`function`)\cr a function to sample from the prior distribution. 
-#' @param limits (`numeric`)\cr TODO +#' @param limits (`numeric`)\cr the lower and upper limits for a truncated distribution #' @rdname Prior-class Prior <- function( parameters, diff --git a/man/Prior-Shared.Rd b/man/Prior-Shared.Rd index 218ae8e8..6f4cb023 100644 --- a/man/Prior-Shared.Rd +++ b/man/Prior-Shared.Rd @@ -4,7 +4,8 @@ \alias{Prior-Shared} \title{\code{Prior} Function Arguments} \arguments{ -\item{centre}{(\code{number})\cr the central point of distribution to shrink sampled values towards} +\item{centre}{(\code{number})\cr the central point of distribution to shrink sampled values towards +(for most distributions this is the mean or median if the mean is undefined)} \item{x}{(\code{\link{Prior}})\cr a prior Distribution} diff --git a/man/Prior-class.Rd b/man/Prior-class.Rd index f06151b6..f666c2d6 100644 --- a/man/Prior-class.Rd +++ b/man/Prior-class.Rd @@ -34,7 +34,7 @@ the same names as the \code{paramaters} slot.} \item{sample}{(\code{function})\cr a function to sample from the prior distribution.} -\item{limits}{(\code{numeric})\cr TODO} +\item{limits}{(\code{numeric})\cr the lower and upper limits for a truncated distribution} } \description{ Specifies the prior distribution in a Stan Model From 2dfe0ac5dd57fd14b1f058230cdc59bd278925eb Mon Sep 17 00:00:00 2001 From: github-actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 3 Jul 2024 14:52:37 +0000 Subject: [PATCH 8/8] [skip roxygen] [skip vbump] Roxygen Man Pages Auto Update --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 210545bf..7045a1a3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,7 +25,7 @@ License: Apache License (>= 2) Encoding: UTF-8 Language: en-GB Roxygen: list(markdown = TRUE) -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 Depends: R (>= 4.1.0) Imports: