Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the GSF model implementation #249

Merged
merged 17 commits into from
Feb 1, 2024
31 changes: 31 additions & 0 deletions .github/workflows/cron.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@



on:
schedule:
- cron: '0 4 1,15 * *'
workflow_dispatch:

name: CRON

jobs:
r-cmd:
name: R CMD Check 🧬
uses: insightsengineering/r.pkg.template/.github/workflows/build-check-install.yaml@main
secrets:
REPO_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
additional-env-vars: |
CMDSTAN=/root/.cmdstan
CMDSTAN_PATH=/root/.cmdstan
CMDSTANR_NO_VER_CHECK=true
JMPOST_CACHE_DIR=${{ github.workspace }}/.cache
JMPOST_FULL_TEST=TRUE
additional-caches: |
${{ github.workspace }}/.cache






6 changes: 3 additions & 3 deletions R/DataLongitudinal.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ harmonise.DataLongitudinal <- function(object, subject_var, subject_ord, ...) {
)
assert_that(
all(data[[subject_var]] %in% subject_ord),
msg = "There are subjects `longitudinal` that are not present in `subjects`"
msg = "There are subjects in `longitudinal` that are not present in `subjects`"
)
assert_that(
all(subject_ord %in% data[[subject_var]]),
msg = "There are subjects `subjects` that are not present in `longitudinal`"
msg = "There are subjects in `subjects` that are not present in `longitudinal`"
)
data[[subject_var]] <- factor(
as.character(data[[subject_var]]),
Expand Down Expand Up @@ -169,7 +169,7 @@ as_stan_list.DataLongitudinal <- function(object, subject_var, ...) {
assert_factor(df[[subject_var]])

mat_sld_index <- stats::model.matrix(
stats::as.formula(paste("~", subject_var)),
stats::as.formula(paste("~ -1 + ", subject_var)),
data = df
) |>
t()
Expand Down
4 changes: 2 additions & 2 deletions R/DataSurvival.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ harmonise.DataSurvival <- function(object, subject_var, subject_ord, ...) {
)
assert_that(
all(data[[subject_var]] %in% subject_ord),
msg = "There are subjects `survival` that are not present in `subjects`"
msg = "There are subjects in `survival` that are not present in `subjects`"
)
assert_that(
all(subject_ord %in% data[[subject_var]]),
msg = "There are subjects `subjects` that are not present in `survival`"
msg = "There are subjects in `subjects` that are not present in `survival`"
)

data[[subject_var]] <- factor(
Expand Down
8 changes: 4 additions & 4 deletions R/JointModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ write_stan.JointModel <- function(object, file_path) {
#' @rdname compileStanModel
#' @export
compileStanModel.JointModel <- function(object) {
x <- compileStanModel(object@stan)
stanObject <- object@stan
stanObject@generated_quantities <- ""
x <- compileStanModel(stanObject)
invisible(x)
}

Expand Down Expand Up @@ -161,9 +163,7 @@ sampleStanModel.JointModel <- function(object, data, ...) {
args[["init"]] <- function() values_initial_expanded
}

stanObject <- object@stan
stanObject@generated_quantities <- ""
model <- compileStanModel(stanObject)
model <- compileStanModel(object)
results <- do.call(model$sample, args)

.JointModelSamples(
Expand Down
2 changes: 1 addition & 1 deletion R/JointModelSamples.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ as_print_string.JointModelSamples <- function(object, indent = 1, ...) {
sizes <- vapply(
object@results$metadata()[["stan_variable_sizes"]],
\(x) {
if (length(x) == 1 & x == 1) return("")
if (length(x) == 1 && x == 1) return("")
paste0("[", paste(x, collapse = ", "), "]")
},
character(1)
Expand Down
104 changes: 76 additions & 28 deletions R/LongitudinalGSF.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,47 +26,95 @@ NULL
#' @param mu_bsld (`Prior`)\cr for the mean baseline value `mu_bsld`.
#' @param mu_ks (`Prior`)\cr for the mean shrinkage rate `mu_ks`.
#' @param mu_kg (`Prior`)\cr for the mean growth rate `mu_kg`.
#' @param mu_phi (`Prior`)\cr for the mean shrinkage proportion `mu_phi`.
#'
#' @param omega_bsld (`Prior`)\cr for the baseline value standard deviation `omega_bsld`.
#' @param omega_ks (`Prior`)\cr for the shrinkage rate standard deviation `omega_ks`.
#' @param omega_kg (`Prior`)\cr for the growth rate standard deviation `omega_kg`.
#' @param omega_phi (`Prior`)\cr for the shrinkage proportion standard deviation `omega_phi`.
#'
#' @param sigma (`Prior`)\cr for the variance of the longitudinal values `sigma`.
#'
#' @param a_phi (`Prior`)\cr for the alpha parameter for the fraction of cells that respond to treatment.
#' @param b_phi (`Prior`)\cr for the beta parameter for the fraction of cells that respond to treatment.
#'
#' @param psi_bsld (`Prior`)\cr for the baseline value random effect `psi_bsld`. Only used in the
#' centered parameterization to set the initial value.
#' @param psi_ks (`Prior`)\cr for the shrinkage rate random effect `psi_ks`. Only used in the
#' centered parameterization to set the initial value.
#' @param psi_kg (`Prior`)\cr for the growth rate random effect `psi_kg`. Only used in the
#' centered parameterization to set the initial value.
#' @param psi_phi (`Prior`)\cr for the shrinkage proportion random effect `psi_phi`. Only used in the
#' centered parameterization to set the initial value.
#'
#' @param centered (`logical`)\cr whether to use the centered parameterization.
#'
#' @export
LongitudinalGSF <- function(
mu_bsld = prior_lognormal(log(55), 5, init = 55),
mu_ks = prior_lognormal(log(0.1), 0.5, init = 0.1),
mu_kg = prior_lognormal(log(0.1), 1, init = 0.1),
mu_phi = prior_beta(2, 8, init = 0.2),
omega_bsld = prior_lognormal(log(0.1), 1, init = 0.1),
omega_ks = prior_lognormal(log(0.1), 1, init = 0.1),
omega_kg = prior_lognormal(log(0.1), 1, init = 0.1),
omega_phi = prior_lognormal(log(0.1), 1, init = 0.1),
sigma = prior_lognormal(log(0.1), 0.8, init = 0.1)

mu_bsld = prior_normal(log(60), 1, init = 60),
mu_ks = prior_normal(log(0.5), 1, init = 0.5),
mu_kg = prior_normal(log(0.3), 1, init = 0.3),

omega_bsld = prior_lognormal(log(0.2), 1, init = 0.2),
omega_ks = prior_lognormal(log(0.2), 1, init = 0.2),
omega_kg = prior_lognormal(log(0.2), 1, init = 0.2),

a_phi = prior_lognormal(log(5), 1, init = 5),
b_phi = prior_lognormal(log(5), 1, init = 5),

sigma = prior_lognormal(log(0.1), 1, init = 0.1),

psi_bsld = prior_none(init = 60),
psi_ks = prior_none(init = 0.5),
psi_kg = prior_none(init = 0.5),
psi_phi = prior_none(init = 0.5),

centered = FALSE
) {
eta_prior <- prior_std_normal()

gsf_model <- StanModule(decorated_render(
.x = paste0(read_stan("lm-gsf/model.stan"), collapse = "\n"),
centered = centered
))

parameters <- list(
Parameter(name = "lm_gsf_mu_bsld", prior = mu_bsld, size = "n_studies"),
Parameter(name = "lm_gsf_mu_ks", prior = mu_ks, size = "n_arms"),
Parameter(name = "lm_gsf_mu_kg", prior = mu_kg, size = "n_arms"),

Parameter(name = "lm_gsf_omega_bsld", prior = omega_bsld, size = 1),
Parameter(name = "lm_gsf_omega_ks", prior = omega_ks, size = 1),
Parameter(name = "lm_gsf_omega_kg", prior = omega_kg, size = 1),

Parameter(name = "lm_gsf_a_phi", prior = a_phi, size = "n_arms"),
Parameter(name = "lm_gsf_b_phi", prior = b_phi, size = "n_arms"),
Parameter(name = "lm_gsf_psi_phi", prior = psi_phi, size = "Nind"),

Parameter(name = "lm_gsf_sigma", prior = sigma, size = 1)
)

assert_flag(centered)
parameters_extra <- if (centered) {
list(
Parameter(name = "lm_gsf_psi_bsld", prior = psi_bsld, size = "Nind"),
Parameter(name = "lm_gsf_psi_ks", prior = psi_ks, size = "Nind"),
Parameter(name = "lm_gsf_psi_kg", prior = psi_kg, size = "Nind")
)
} else {
list(
Parameter(name = "lm_gsf_eta_tilde_bsld", prior = prior_std_normal(), size = "Nind"),
Parameter(name = "lm_gsf_eta_tilde_ks", prior = prior_std_normal(), size = "Nind"),
Parameter(name = "lm_gsf_eta_tilde_kg", prior = prior_std_normal(), size = "Nind")
)
}
parameters <- append(parameters, parameters_extra)

x <- LongitudinalModel(
name = "Generalized Stein-Fojo",
stan = merge(
StanModule("lm-gsf/model.stan"),
gsf_model,
StanModule("lm-gsf/functions.stan")
),
parameters = ParameterList(
Parameter(name = "lm_gsf_mu_bsld", prior = mu_bsld, size = "n_studies"),
Parameter(name = "lm_gsf_mu_ks", prior = mu_ks, size = "n_arms"),
Parameter(name = "lm_gsf_mu_kg", prior = mu_kg, size = "n_arms"),
Parameter(name = "lm_gsf_mu_phi", prior = mu_phi, size = "n_arms"),
Parameter(name = "lm_gsf_omega_bsld", prior = omega_bsld, size = 1),
Parameter(name = "lm_gsf_omega_ks", prior = omega_ks, size = 1),
Parameter(name = "lm_gsf_omega_kg", prior = omega_kg, size = 1),
Parameter(name = "lm_gsf_omega_phi", prior = omega_phi, size = 1),
Parameter(name = "lm_gsf_sigma", prior = sigma, size = 1),
Parameter(name = "lm_gsf_eta_tilde_bsld", prior = eta_prior, size = "Nind"),
Parameter(name = "lm_gsf_eta_tilde_ks", prior = eta_prior, size = "Nind"),
Parameter(name = "lm_gsf_eta_tilde_kg", prior = eta_prior, size = "Nind"),
Parameter(name = "lm_gsf_eta_tilde_phi", prior = eta_prior, size = "Nind")
)
parameters = do.call(ParameterList, parameters)
)
.LongitudinalGSF(x)
}
5 changes: 4 additions & 1 deletion R/Prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,12 @@ setValidity(
#' @family Prior-internal
#' @export
as.character.Prior <- function(x, ...) {

parameters_rounded <- lapply(x@parameters, round, 5)

do.call(
glue::glue,
append(x@display, x@parameters)
append(x@display, parameters_rounded)
)
}

Expand Down
7 changes: 3 additions & 4 deletions R/StanModule.R
Original file line number Diff line number Diff line change
Expand Up @@ -433,10 +433,9 @@ as_stan_fragments <- function(x, stan_blocks = STAN_BLOCKS) {
#' @keywords internal
#' @export
as_print_string.StanModule <- function(object, indent = 1, ...) {
slots <- getSlots("StanModule")
slots <- slots[!names(slots) %in% c("priors", "inits")]
components <- Filter(\(block) all(slot(object, block) != ""), names(slots))

slots <- names(getSlots("StanModule"))
slots <- slots[!slots %in% c("priors", "inits")]
components <- Filter(\(block) paste(slot(object, block), collapse = "") != "", slots)
template <- c(
"StanModule Object with components:",
paste(" ", components)
Expand Down
42 changes: 20 additions & 22 deletions R/simulations_gsf.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@ gsf_dsld <- function(time, b, s, g, phi) {
#' @param sigma (`number`)\cr the variance of the longitudinal values.
#' @param mu_s (`numeric`)\cr the mean shrinkage rates for the two treatment arms.
#' @param mu_g (`numeric`)\cr the mean growth rates for the two treatment arms.
#' @param mu_phi (`numeric`)\cr the mean shrinkage proportions for the two treatment arms.
#' @param mu_b (`numeric`)\cr the mean baseline values for the two treatment arms.
#' @param omega_b (`number`)\cr the baseline value standard deviation.
#' @param omega_s (`number`)\cr the shrinkage rate standard deviation.
#' @param omega_g (`number`)\cr the growth rate standard deviation.
#' @param omega_phi (`number`)\cr the shrinkage proportion standard deviation.
#' @param a_phi (`number`)\cr the alpha parameter for the fraction of cells that respond to treatment.
#' @param b_phi (`number`)\cr the beta parameter for the fraction of cells that respond to treatment.
#' @param link_dsld (`number`)\cr the link coefficient for the derivative contribution.
#' @param link_ttg (`number`)\cr the link coefficient for the time-to-growth contribution.
#' @param link_identity (`number`)\cr the link coefficient for the SLD Identity contribution.
Expand All @@ -65,41 +65,39 @@ gsf_dsld <- function(time, b, s, g, phi) {
#' @export
sim_lm_gsf <- function(
sigma = 0.01,
mu_s = c(3, 4),
mu_g = c(0.2, 0.3),
mu_phi = c(0.1, 0.2),
mu_b = 50,
omega_b = 0.135,
omega_s = 0.15,
omega_g = 0.225,
omega_phi = 0.75,
mu_s = c(0.6, 0.4),
mu_g = c(0.25, 0.35),
mu_b = 60,
a_phi = c(4, 6),
b_phi = c(4, 6),
omega_b = 0.2,
omega_s = 0.2,
omega_g = 0.2,
link_dsld = 0,
link_ttg = 0,
link_identity = 0
) {
function(lm_base) {

assert_that(
length(unique(lm_base$study)) == 1,
length(unique(lm_base$study)) == length(mu_b),
length(mu_b) == 1,
length(sigma) == 1,
length(mu_s) == length(unique(lm_base$arm)),
length(mu_s) == length(mu_g),
length(mu_s) == length(mu_phi),
length(c(omega_b, omega_s, omega_g, omega_phi)) == 4
length(mu_s) == length(a_phi),
length(mu_s) == length(b_phi),
length(c(omega_b, omega_s, omega_g)) == 3
)

baseline_covs <- lm_base |>
dplyr::distinct(.data$pt, .data$arm, .data$study) |>
dplyr::mutate(arm_n = as.numeric(factor(as.character(.data$arm)))) |>
dplyr::mutate(eta_b = stats::rnorm(dplyr::n(), 0, 1)) |>
dplyr::mutate(eta_s = stats::rnorm(dplyr::n(), 0, 1)) |>
dplyr::mutate(eta_g = stats::rnorm(dplyr::n(), 0, 1)) |>
dplyr::mutate(eta_phi = stats::rnorm(dplyr::n(), 0, 1)) |>
dplyr::mutate(psi_b = exp(log(mu_b) + .data$eta_b * omega_b)) |>
dplyr::mutate(psi_s = exp(log(mu_s[.data$arm_n]) + .data$eta_s * omega_s)) |>
dplyr::mutate(psi_g = exp(log(mu_g[.data$arm_n]) + .data$eta_g * omega_g)) |>
dplyr::mutate(psi_phi = stats::plogis(stats::qlogis(mu_phi[.data$arm_n]) + .data$eta_phi * omega_phi))
dplyr::mutate(study_idx = as.numeric(factor(as.character(.data$study)))) |>
dplyr::mutate(arm_idx = as.numeric(factor(as.character(.data$arm)))) |>
dplyr::mutate(psi_b = stats::rlnorm(dplyr::n(), log(mu_b[.data$study_idx]), omega_b)) |>
dplyr::mutate(psi_s = stats::rlnorm(dplyr::n(), log(mu_s[.data$arm_idx]), omega_s)) |>
dplyr::mutate(psi_g = stats::rlnorm(dplyr::n(), log(mu_g[.data$arm_idx]), omega_g)) |>
dplyr::mutate(psi_phi = stats::rbeta(dplyr::n(), a_phi[.data$arm_idx], b_phi[.data$arm_idx]))

lm_dat <- lm_base |>
dplyr::select(!dplyr::all_of(c("study", "arm"))) |>
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ library(jmpost)
#> Registered S3 method overwritten by 'GGally':
#> method from
#> +.gg ggplot2
#> Registered S3 methods overwritten by 'ggpp':
#> method from
#> heightDetails.titleGrob ggplot2
#> widthDetails.titleGrob ggplot2
set.seed(321)
sim_data <- simulate_joint_data(
lm_fun = sim_lm_random_slope(),
Expand Down
Loading
Loading