Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Claret Bruno Model #350

Merged
merged 21 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/cron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ jobs:
CMDSTANR_NO_VER_CHECK=true
JMPOST_CACHE_DIR=${{ github.workspace }}/.cache
JMPOST_FULL_TEST=TRUE
JMPOST_GRAPH_SNAPSHOT=TRUE
additional-caches: |
${{ github.workspace }}/.cache

Expand Down
9 changes: 6 additions & 3 deletions .vscode/tasks.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
"devtools::test()"
],
"env": {
"JMPOST_CACHE_DIR": "${workspaceFolder}/local/test_cache"
"JMPOST_CACHE_DIR": "${workspaceFolder}/local/test_cache",
"JMPOST_GRAPH_SNAPSHOT" : "TRUE"
},
"problemMatcher": [
"$testthat"
Expand All @@ -36,7 +37,8 @@
],
"env": {
"JMPOST_CACHE_DIR": "${workspaceFolder}/local/test_cache",
"JMPOST_FULL_TEST": "TRUE"
"JMPOST_FULL_TEST": "TRUE",
"JMPOST_GRAPH_SNAPSHOT" : "TRUE"
},
"problemMatcher": [
"$testthat"
Expand Down Expand Up @@ -65,7 +67,8 @@
],
"env": {
"JMPOST_CACHE_DIR": "${workspaceFolder}/local/test_cache",
"NOT_CRAN": "TRUE"
"NOT_CRAN": "TRUE",
"JMPOST_GRAPH_SNAPSHOT" : "TRUE"
},
"problemMatcher": [
"$testthat"
Expand Down
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ Collate:
'Quantities.R'
'SurvivalQuantities.R'
'JointModelSamples.R'
'LongitudinalClaretBruno.R'
'LongitudinalGSF.R'
'LongitudinalQuantities.R'
'LongitudinalRandomSlope.R'
Expand All @@ -103,6 +104,7 @@ Collate:
'SimGroup.R'
'SimJointData.R'
'SimLongitudinal.R'
'SimLongitudinalClaretBruno.R'
'SimLongitudinalGSF.R'
'SimLongitudinalRandomSlope.R'
'SimLongitudinalSteinFojo.R'
Expand Down
12 changes: 12 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ S3method(enableGQ,LongitudinalGSF)
S3method(enableGQ,LongitudinalRandomSlope)
S3method(enableGQ,LongitudinalSteinFojo)
S3method(enableGQ,default)
S3method(enableLink,LongitudinalClaretBruno)
S3method(enableLink,LongitudinalGSF)
S3method(enableLink,LongitudinalRandomSlope)
S3method(enableLink,LongitudinalSteinFojo)
Expand All @@ -102,6 +103,7 @@ S3method(getParameters,Link)
S3method(getParameters,LinkComponent)
S3method(getParameters,StanModel)
S3method(getParameters,default)
S3method(getPredictionNames,LongitudinalClaretBruno)
S3method(getPredictionNames,LongitudinalGSF)
S3method(getPredictionNames,LongitudinalRandomSlope)
S3method(getPredictionNames,LongitudinalSteinFojo)
Expand All @@ -114,21 +116,25 @@ S3method(initialValues,Prior)
S3method(initialValues,StanModel)
S3method(length,Link)
S3method(length,QuantityCollapser)
S3method(linkDSLD,LongitudinalClaretBruno)
S3method(linkDSLD,LongitudinalGSF)
S3method(linkDSLD,LongitudinalRandomSlope)
S3method(linkDSLD,LongitudinalSteinFojo)
S3method(linkDSLD,PromiseLongitudinalModel)
S3method(linkDSLD,default)
S3method(linkGrowth,LongitudinalClaretBruno)
S3method(linkGrowth,LongitudinalGSF)
S3method(linkGrowth,LongitudinalRandomSlope)
S3method(linkGrowth,LongitudinalSteinFojo)
S3method(linkGrowth,PromiseLongitudinalModel)
S3method(linkGrowth,default)
S3method(linkIdentity,LongitudinalClaretBruno)
S3method(linkIdentity,LongitudinalGSF)
S3method(linkIdentity,LongitudinalRandomSlope)
S3method(linkIdentity,LongitudinalSteinFojo)
S3method(linkIdentity,PromiseLongitudinalModel)
S3method(linkIdentity,default)
S3method(linkTTG,LongitudinalClaretBruno)
S3method(linkTTG,LongitudinalGSF)
S3method(linkTTG,LongitudinalSteinFojo)
S3method(linkTTG,PromiseLongitudinalModel)
Expand All @@ -139,11 +145,13 @@ S3method(names,ParameterList)
S3method(resolvePromise,Link)
S3method(resolvePromise,PromiseLinkComponent)
S3method(resolvePromise,default)
S3method(sampleObservations,SimLongitudinalClaretBruno)
S3method(sampleObservations,SimLongitudinalGSF)
S3method(sampleObservations,SimLongitudinalRandomSlope)
S3method(sampleObservations,SimLongitudinalSteinFojo)
S3method(sampleObservations,SimSurvival)
S3method(sampleStanModel,JointModel)
S3method(sampleSubjects,SimLongitudinalClaretBruno)
S3method(sampleSubjects,SimLongitudinalGSF)
S3method(sampleSubjects,SimLongitudinalRandomSlope)
S3method(sampleSubjects,SimLongitudinalSteinFojo)
Expand Down Expand Up @@ -171,6 +179,7 @@ export(GridPrediction)
export(JointModel)
export(Link)
export(LinkComponent)
export(LongitudinalClaretBruno)
export(LongitudinalGSF)
export(LongitudinalModel)
export(LongitudinalQuantities)
Expand All @@ -185,6 +194,7 @@ export(STAN_BLOCKS)
export(SimGroup)
export(SimJointData)
export(SimLongitudinal)
export(SimLongitudinalClaretBruno)
export(SimLongitudinalGSF)
export(SimLongitudinalRandomSlope)
export(SimLongitudinalSteinFojo)
Expand Down Expand Up @@ -245,6 +255,7 @@ exportClasses(DataSurvival)
exportClasses(JointModel)
exportClasses(JointModelSamples)
exportClasses(Link)
exportClasses(LongitudinalClaretBruno)
exportClasses(LongitudinalGSF)
exportClasses(LongitudinalModel)
exportClasses(LongitudinalRandomSlope)
Expand All @@ -258,6 +269,7 @@ exportClasses(PromiseLongitudinalModel)
exportClasses(SimGroup)
exportClasses(SimJointData)
exportClasses(SimLongitudinal)
exportClasses(SimLongitudinalClaretBruno)
exportClasses(SimLongitudinalGSF)
exportClasses(SimLongitudinalRandomSlope)
exportClasses(SimLongitudinalSteinFojo)
Expand Down
176 changes: 176 additions & 0 deletions R/LongitudinalClaretBruno.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#' @include LongitudinalModel.R
#' @include StanModule.R
#' @include generics.R
#' @include ParameterList.R
#' @include Parameter.R
#' @include Link.R
NULL


#' `LongitudinalClaretBruno`
#'
#' This class extends the general [`LongitudinalModel`] class for using the
#' Claret-Bruno model for the longitudinal outcome.
#'
#' @section Available Links:
#' - [`linkDSLD()`]
#' - [`linkTTG()`]
#' - [`linkIdentity()`]
#' - [`linkGrowth()`]
#' @exportClass LongitudinalClaretBruno
.LongitudinalClaretBruno <- setClass(
Class = "LongitudinalClaretBruno",
contains = "LongitudinalModel"
)


#' @rdname LongitudinalClaretBruno-class
#'
#' @param mu_b (`Prior`)\cr for the mean population baseline sld value.
#' @param mu_g (`Prior`)\cr for the mean population growth rate.
#' @param mu_c (`Prior`)\cr for the mean population resistance rate.
#' @param mu_p (`Prior`)\cr for the mean population growth inhibition
#'
#' @param omega_b (`Prior`)\cr for the population standard deviation for the baseline sld value.
#' @param omega_g (`Prior`)\cr for the population standard deviation for the growth rate.
#' @param omega_c (`Prior`)\cr for the population standard deviation for the resistance rate.
#' @param omega_p (`Prior`)\cr for the population standard deviation for the growth inhibition.
#'
#' @param sigma (`Prior`)\cr for the variance of the longitudinal values.
#'
#' @param centred (`logical`)\cr whether to use the centred parameterization.
#'
#' @export
LongitudinalClaretBruno <- function(

mu_b = prior_normal(log(60), 0.5),
mu_g = prior_normal(log(1), 0.5),
mu_c = prior_normal(log(0.4), 0.5),
mu_p = prior_normal(log(2), 0.5),

omega_b = prior_lognormal(log(0.2), 0.5),
omega_g = prior_lognormal(log(0.2), 0.5),
omega_c = prior_lognormal(log(0.2), 0.5),
omega_p = prior_lognormal(log(0.2), 0.5),

sigma = prior_lognormal(log(0.1), 0.5),

centred = FALSE
) {

sf_model <- StanModule(decorated_render(
.x = read_stan("lm-claret-bruno/model.stan"),
centred = centred
))

parameters <- list(
Parameter(name = "lm_clbr_mu_b", prior = mu_b, size = "n_studies"),
Parameter(name = "lm_clbr_mu_g", prior = mu_g, size = "n_arms"),
Parameter(name = "lm_clbr_mu_c", prior = mu_c, size = "n_arms"),
Parameter(name = "lm_clbr_mu_p", prior = mu_p, size = "n_arms"),

Parameter(name = "lm_clbr_omega_b", prior = omega_b, size = 1),
Parameter(name = "lm_clbr_omega_g", prior = omega_g, size = 1),
Parameter(name = "lm_clbr_omega_c", prior = omega_c, size = 1),
Parameter(name = "lm_clbr_omega_p", prior = omega_p, size = 1),

Parameter(name = "lm_clbr_sigma", prior = sigma, size = 1)
)

assert_flag(centred)
parameters_extra <- if (centred) {
list(
Parameter(
name = "lm_clbr_ind_b",
prior = prior_init_only(prior_lognormal(mu_b@init, omega_b@init)),
size = "n_subjects"
),
Parameter(
name = "lm_clbr_ind_g",
prior = prior_init_only(prior_lognormal(mu_g@init, omega_g@init)),
size = "n_subjects"
),
Parameter(
name = "lm_clbr_ind_c",
prior = prior_init_only(prior_lognormal(mu_c@init, omega_c@init)),
size = "n_subjects"
),
Parameter(
name = "lm_clbr_ind_p",
prior = prior_init_only(prior_lognormal(mu_p@init, omega_p@init)),
size = "n_subjects"
)
)
} else {
list(
Parameter(name = "lm_clbr_eta_b", prior = prior_std_normal(), size = "n_subjects"),
Parameter(name = "lm_clbr_eta_g", prior = prior_std_normal(), size = "n_subjects"),
Parameter(name = "lm_clbr_eta_c", prior = prior_std_normal(), size = "n_subjects"),
Parameter(name = "lm_clbr_eta_p", prior = prior_std_normal(), size = "n_subjects")
)
}
parameters <- append(parameters, parameters_extra)

x <- LongitudinalModel(
name = "Claret-Bruno",
stan = merge(
sf_model,
StanModule("lm-claret-bruno/functions.stan")
),
parameters = do.call(ParameterList, parameters)
)
.LongitudinalClaretBruno(x)
}



#' @export
enableLink.LongitudinalClaretBruno <- function(object, ...) {
object@stan <- merge(
object@stan,
StanModule("lm-claret-bruno/link.stan")
)
object
}

#' @export
linkDSLD.LongitudinalClaretBruno <- function(prior = prior_normal(0, 2), model, ...) {
LinkComponent(
key = "link_dsld",
stan = StanModule("lm-claret-bruno/link_dsld.stan"),
prior = prior
)
}

#' @export
linkTTG.LongitudinalClaretBruno <- function(prior = prior_normal(0, 2), model, ...) {
LinkComponent(
key = "link_ttg",
stan = StanModule("lm-claret-bruno/link_ttg.stan"),
prior = prior
)
}

#' @export
linkIdentity.LongitudinalClaretBruno <- function(prior = prior_normal(0, 2), model, ...) {
LinkComponent(
key = "link_identity",
stan = StanModule("lm-claret-bruno/link_identity.stan"),
prior = prior
)
}

#' @export
linkGrowth.LongitudinalClaretBruno <- function(prior = prior_normal(0, 2), model, ...) {
LinkComponent(
key = "link_growth",
stan = StanModule("lm-claret-bruno/link_growth.stan"),
prior = prior
)
}

#' @rdname getPredictionNames
#' @export
getPredictionNames.LongitudinalClaretBruno <- function(object, ...) {
c("b", "g", "c", "p")
}
1 change: 1 addition & 0 deletions R/LongitudinalGSF.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ NULL
#' - [`linkDSLD()`]
#' - [`linkTTG()`]
#' - [`linkIdentity()`]
#' - [`linkGrowth()`]
#' @exportClass LongitudinalGSF
.LongitudinalGSF <- setClass(
Class = "LongitudinalGSF",
Expand Down
1 change: 1 addition & 0 deletions R/LongitudinalSteinFojo.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ NULL
#' - [`linkDSLD()`]
#' - [`linkTTG()`]
#' - [`linkIdentity()`]
#' - [`linkGrowth()`]
#' @exportClass LongitudinalSteinFojo
.LongitudinalSteinFojo <- setClass(
Class = "LongitudinalSteinFojo",
Expand Down
Loading
Loading