Skip to content

Commit

Permalink
Merge branch 'main' into feature/easy_install_cmdstanr
Browse files Browse the repository at this point in the history
  • Loading branch information
gowerc committed Mar 25, 2024
2 parents 52b8caf + 2c10109 commit 89b508a
Show file tree
Hide file tree
Showing 13 changed files with 62 additions and 18 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

S3method(as.CmdStanMCMC,JointModelSamples)
S3method(as.StanModule,JointModel)
S3method(as.StanModule,Link)
S3method(as.StanModule,LinkComponent)
Expand Down Expand Up @@ -134,6 +135,7 @@ export(SurvivalLogLogistic)
export(SurvivalModel)
export(SurvivalQuantities)
export(SurvivalWeibullPH)
export(as.CmdStanMCMC)
export(as_stan_list)
export(autoplot)
export(brierScore)
Expand Down
17 changes: 12 additions & 5 deletions R/JointModelSamples.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ generateQuantities.JointModelSamples <- function(object, patients, time_grid_lm,
devnull <- utils::capture.output(
results <- model$generate_quantities(
data = data,
fitted_params = object@results$draws()
fitted_params = as.CmdStanMCMC(object)$draws()
)
)
return(results)
Expand All @@ -74,7 +74,7 @@ generateQuantities.JointModelSamples <- function(object, patients, time_grid_lm,
#' @export
as_print_string.JointModelSamples <- function(object, indent = 1, ...) {
sizes <- vapply(
object@results$metadata()[["stan_variable_sizes"]],
as.CmdStanMCMC(object)$metadata()[["stan_variable_sizes"]],
\(x) {
if (length(x) == 1 && x == 1) return("")
paste0("[", paste(x, collapse = ", "), "]")
Expand All @@ -83,7 +83,7 @@ as_print_string.JointModelSamples <- function(object, indent = 1, ...) {
)
variable_string <- paste0(
" ",
object@results$metadata()[["stan_variables"]],
as.CmdStanMCMC(object)$metadata()[["stan_variables"]],
sizes
)
template <- c(
Expand All @@ -99,8 +99,8 @@ as_print_string.JointModelSamples <- function(object, indent = 1, ...) {
template_padded <- paste(pad, template)
sprintf(
paste(template_padded, collapse = "\n"),
object@results$metadata()$iter_sampling,
object@results$metadata()$num_chains
as.CmdStanMCMC(object)$metadata()$iter_sampling,
as.CmdStanMCMC(object)$metadata()$num_chains
)
}

Expand All @@ -114,3 +114,10 @@ setMethod(
cat("\n", string, "\n\n")
}
)


#' @rdname as.CmdStanMCMC
#' @export
as.CmdStanMCMC.JointModelSamples <- function(object, ...) {
return(object@results)
}
14 changes: 14 additions & 0 deletions R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,17 @@ sampleSubjects <- function(object, subjects_df) {
hazardWindows <- function(object, ...) {
UseMethod("hazardWindows")
}


#' Coerce to `CmdStanMCMC`
#'
#' @param object to be converted
#' @param ... additional options
#'
#' @description
#' Coerces an object to a [`cmdstanr::CmdStanMCMC`] object
#'
#' @export
as.CmdStanMCMC <- function(object, ...) {
UseMethod("as.CmdStanMCMC")
}
2 changes: 2 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ reference:
- merge
- show
- as_stan_list
- as.CmdStanMCMC
- as.CmdStanMCMC.JointModelSamples
- as.character.JointModel
- as.StanModule.Parameter
- as.StanModule.ParameterList
Expand Down
19 changes: 19 additions & 0 deletions man/as.CmdStanMCMC.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/test-LongitudinalGSF.R
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ test_that("Can recover known distributional parameters from a full GSF joint mod
}

dat <- summary_post(
mp@results,
as.CmdStanMCMC(mp),
c("lm_gsf_mu_bsld", "lm_gsf_mu_ks", "lm_gsf_mu_kg"),
TRUE
)
Expand All @@ -180,7 +180,7 @@ test_that("Can recover known distributional parameters from a full GSF joint mod
expect_true(all(dat$ess_bulk > 100))

dat <- summary_post(
mp@results,
as.CmdStanMCMC(mp),
c("link_dsld", "link_ttg", "lm_gsf_a_phi", "lm_gsf_b_phi", "sm_exp_lambda")
)

Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-LongitudinalRandomSlope.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ test_that("LongitudinalRandomSlope correctly generates an intercept per study",
})
})

samples <- mp@results$draws(
samples <- as.CmdStanMCMC(mp)$draws(
c("lm_rs_intercept", "lm_rs_slope_mu", "lm_rs_slope_sigma", "lm_rs_sigma"),
format = "draws_matrix"
)
Expand Down Expand Up @@ -163,7 +163,7 @@ test_that("Random Slope Model can recover known parameter values", {
"lm_rs_sigma" # 3
)

pars <- mp@results$summary(vars)
pars <- as.CmdStanMCMC(mp)$summary(vars)


## Check that we can recover main effects parameters
Expand All @@ -173,7 +173,7 @@ test_that("Random Slope Model can recover known parameter values", {

## Check that we can recover random effects parameters
pars <- suppressWarnings({
mp@results$summary("lm_rs_ind_rnd_slope")$mean
as.CmdStanMCMC(mp)$summary("lm_rs_ind_rnd_slope")$mean
})

## Extract real random effects per patient
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-LongitudinalSteinFojo.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ test_that("Can recover known distributional parameters from a SF joint model", {
}

dat <- summary_post(
mp@results,
as.CmdStanMCMC(mp),
c("lm_sf_mu_bsld", "lm_sf_mu_ks", "lm_sf_mu_kg"),
TRUE
)
Expand All @@ -177,7 +177,7 @@ test_that("Can recover known distributional parameters from a SF joint model", {
expect_true(all(dat$ess_bulk > 100))

dat <- summary_post(
mp@results,
as.CmdStanMCMC(mp),
c("link_dsld", "link_ttg", "sm_exp_lambda")
)

Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-SurvivalExponential.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ test_that("SurvivalExponential can recover true parameter (including covariates)

# Variables to extract (order important)
vars <- c("sm_exp_lambda", "beta_os_cov")
results_summary <- mp@results$summary(vars)
results_summary <- as.CmdStanMCMC(mp)$summary(vars)

# calculate Z-scores
par_mean <- results_summary$mean
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-SurvivalLoglogistic.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ test_that("SurvivalLogLogistic can recover known values", {

# Variables to extract (order important)
vars <- c("sm_loglogis_a", "sm_loglogis_b", "beta_os_cov")
results_summary <- mp@results$summary(vars)
results_summary <- as.CmdStanMCMC(mp)$summary(vars)

# calculate Z-scores
par_mean <- results_summary$mean
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-SurvivalWeibullPH.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ test_that("SurvivalWeibullPH can recover known values", {

# Variables to extract (order important)
vars <- c("sm_weibull_ph_lambda", "sm_weibull_ph_gamma", "beta_os_cov")
results_summary <- mp@results$summary(vars)
results_summary <- as.CmdStanMCMC(mp)$summary(vars)

# calculate Z-scores
par_mean <- results_summary$mean
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-model_multi_chain.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ test_that("Can recover known distribution parameters from random slope model whe
"link_dsld" = 0.1
)

results_df <- mp@results$summary(names(vars))
results_df <- as.CmdStanMCMC(mp)$summary(names(vars))

z_score <- abs(vars - results_df$mean) / results_df$sd
expect_true(all(z_score < qnorm(0.95)))
Expand Down
4 changes: 2 additions & 2 deletions vignettes/model_fitting.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ vars <- c(
"sm_weibull_ph_gamma",
"beta_os_cov"
)
mcmc_results@results$summary(vars)
as.CmdStanMCMC(mcmc_results)$summary(vars)
```

We can already see here from the `rhat` statistic whether the MCMC sampler
Expand Down Expand Up @@ -306,7 +306,7 @@ in a format that is understood by `bayesplot`. We can e.g. look
at some simple trace plots:

```{r fig.width=6}
vars_draws <- mcmc_results@results$draws(vars)
vars_draws <- as.CmdStanMCMC(mcmc_results)$draws(vars)
library(bayesplot)
mcmc_trace(vars_draws)
Expand Down

0 comments on commit 89b508a

Please sign in to comment.