Skip to content

Commit

Permalink
once more
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Dec 10, 2024
1 parent e25785f commit 6420d67
Show file tree
Hide file tree
Showing 7 changed files with 436 additions and 365 deletions.
59 changes: 31 additions & 28 deletions doc/example/distributions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,34 +32,34 @@
"import seaborn as sns\n",
"\n",
"from petab.v1.C import *\n",
"from petab.v1.distributions import *\n",
"from petab.v1.priors import Prior\n",
"\n",
"sns.set_style(None)\n",
"\n",
"\n",
"def plot(distr: Distribution, ax=None):\n",
"def plot(prior: Prior, ax=None):\n",
" \"\"\"Visualize a distribution.\"\"\"\n",
" if ax is None:\n",
" fig, ax = plt.subplots()\n",
"\n",
" sample = distr.sample(10000)\n",
" sample = prior.sample(10000)\n",
"\n",
" # pdf\n",
" xmin = min(sample.min(), distr.lb_scaled if distr.bounds is not None else sample.min())\n",
" xmax = max(sample.max(), distr.ub_scaled if distr.bounds is not None else sample.max())\n",
" xmin = min(sample.min(), prior.lb_scaled if prior.bounds is not None else sample.min())\n",
" xmax = max(sample.max(), prior.ub_scaled if prior.bounds is not None else sample.max())\n",
" x = np.linspace(xmin, xmax, 500)\n",
" y = distr.pdf(x)\n",
" y = prior.pdf(x)\n",
" ax.plot(x, y, color='red', label='pdf')\n",
"\n",
" sns.histplot(sample, stat='density', ax=ax, label=\"sample\")\n",
"\n",
" # bounds\n",
" if distr.bounds is not None:\n",
" for bound in (distr.lb_scaled, distr.ub_scaled):\n",
" if prior.bounds is not None:\n",
" for bound in (prior.lb_scaled, prior.ub_scaled):\n",
" if bound is not None and np.isfinite(bound):\n",
" ax.axvline(bound, color='black', linestyle='--', label='bound')\n",
"\n",
" ax.set_title(str(distr))\n",
" ax.set_title(str(prior))\n",
" ax.set_xlabel('Parameter value on the parameter scale')\n",
" ax.grid(False)\n",
" handles, labels = ax.get_legend_handles_labels()\n",
Expand All @@ -81,11 +81,11 @@
"metadata": {},
"cell_type": "code",
"source": [
"plot(Uniform(0, 1))\n",
"plot(Normal(0, 1))\n",
"plot(Laplace(0, 1))\n",
"plot(LogNormal(0, 1))\n",
"plot(LogLaplace(1, 0.5))"
"plot(Prior(UNIFORM, (0, 1)))\n",
"plot(Prior(NORMAL, (0, 1)))\n",
"plot(Prior(LAPLACE, (0, 1)))\n",
"plot(Prior(LOG_NORMAL, (0, 1)))\n",
"plot(Prior(LOG_LAPLACE, (1, 0.5)))"
],
"id": "4f09e50a3db06d9f",
"outputs": [],
Expand All @@ -101,10 +101,11 @@
"metadata": {},
"cell_type": "code",
"source": [
"plot(Normal(10, 2, transformation=LIN))\n",
"plot(Normal(10, 2, transformation=LOG))\n",
"plot(Prior(NORMAL, (10, 2), transformation=LIN))\n",
"plot(Prior(NORMAL, (10, 2), transformation=LOG))\n",
"\n",
"# Note that the log-normal distribution is different from a log-transformed normal distribution:\n",
"plot(LogNormal(10, 2, transformation=LIN))"
"plot(Prior(LOG_NORMAL, (10, 2), transformation=LIN))"
],
"id": "f6192c226f179ef9",
"outputs": [],
Expand All @@ -120,8 +121,8 @@
"metadata": {},
"cell_type": "code",
"source": [
"plot(LogNormal(10, 2, transformation=LOG))\n",
"plot(ParameterScaleNormal(10, 2))"
"plot(Prior(LOG_NORMAL, (10, 2), transformation=LOG))\n",
"plot(Prior(PARAMETER_SCALE_NORMAL, (10, 2)))"
],
"id": "34c95268e8921070",
"outputs": [],
Expand All @@ -137,11 +138,11 @@
"metadata": {},
"cell_type": "code",
"source": [
"plot(Uniform(0, 1, transformation=LOG10))\n",
"plot(ParameterScaleUniform(0, 1, transformation=LOG10))\n",
"plot(Prior(UNIFORM, (0.01, 2), transformation=LOG10))\n",
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LOG10))\n",
"\n",
"plot(Uniform(0, 1, transformation=LIN))\n",
"plot(ParameterScaleUniform(0, 1, transformation=LIN))\n"
"plot(Prior(UNIFORM, (0.01, 2), transformation=LIN))\n",
"plot(Prior(PARAMETER_SCALE_UNIFORM, (0.01, 2), transformation=LIN))\n"
],
"id": "5ca940bc24312fc6",
"outputs": [],
Expand All @@ -157,8 +158,8 @@
"metadata": {},
"cell_type": "code",
"source": [
"plot(Normal(0, 1, bounds=(-4, 4))) # negligible clipping-bias at 4 sigma\n",
"plot(Uniform(0, 1, bounds=(0.1, 0.9))) # significant clipping-bias"
"plot(Prior(NORMAL, (0, 1), bounds=(-4, 4))) # negligible clipping-bias at 4 sigma\n",
"plot(Prior(UNIFORM, (0, 1), bounds=(0.1, 0.9))) # significant clipping-bias"
],
"id": "4ac42b1eed759bdd",
"outputs": [],
Expand All @@ -174,8 +175,10 @@
"metadata": {},
"cell_type": "code",
"source": [
"plot(Normal(10, 1, bounds=(6, 14), transformation=\"log10\"))\n",
"plot(ParameterScaleNormal(10, 1, bounds=(10**6, 10**14), transformation=\"log10\"))\n"
"plot(Prior(NORMAL, (10, 1), bounds=(6, 14), transformation=\"log10\"))\n",
"plot(Prior(PARAMETER_SCALE_NORMAL, (10, 1), bounds=(10**6, 10**14), transformation=\"log10\"))\n",
"plot(Prior(LAPLACE, (10, 2), bounds=(6, 14)))\n",
"\n"
],
"id": "581e1ac431860419",
"outputs": [],
Expand All @@ -185,7 +188,7 @@
"metadata": {},
"cell_type": "code",
"source": "",
"id": "802a64be56a6c94f",
"id": "633733651bbc3ef0",
"outputs": [],
"execution_count": null
}
Expand Down
3 changes: 2 additions & 1 deletion petab/v1/C.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@
LOG10 = "log10"
#: Supported observable transformations
OBSERVABLE_TRANSFORMATIONS = [LIN, LOG, LOG10]

#: Supported parameter transformations
PARAMETER_SCALES = [LIN, LOG, LOG10]

# NOISE MODELS

Expand Down
Loading

0 comments on commit 6420d67

Please sign in to comment.