Skip to content

Commit

Permalink
Change API to make more flexible in practice (#33)
Browse files Browse the repository at this point in the history
* total change

* format

* format

* format

* oh well then

* more

* pre-commit

* format

* add more examples

Co-authored-by: Nathan Simpson <phinate@protonmail.com>
  • Loading branch information
Nathan Simpson and Nathan Simpson authored Jun 24, 2022
1 parent 17b85c0 commit 5b8157c
Show file tree
Hide file tree
Showing 26 changed files with 3,605 additions and 4,524 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: 22.1.0
rev: 22.3.0
hooks:
- id: black-jupyter

Expand Down
1,655 changes: 1,272 additions & 383 deletions demo.ipynb

Large diffs are not rendered by default.

Binary file added examples/ap000.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1,081 changes: 1,081 additions & 0 deletions examples/binning.ipynb

Large diffs are not rendered by default.

457 changes: 457 additions & 0 deletions examples/cuts.ipynb

Large diffs are not rendered by default.

380 changes: 380 additions & 0 deletions examples/diffable_histograms.ipynb

Large diffs are not rendered by default.

Binary file added examples/float.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
celluloid
git+https://github.com/scikit-hep/pyhf.git@make_difffable_model_ctor
plothelp
234 changes: 234 additions & 0 deletions examples/simple-analysis-optimisation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import optax\n",
"from jaxopt import OptaxSolver\n",
"import relaxed\n",
"from celluloid import Camera\n",
"from functools import partial\n",
"import matplotlib.lines as mlines\n",
"\n",
"# matplotlib settings\n",
"plt.rc(\"figure\", figsize=(6, 3), dpi=220, facecolor=\"w\")\n",
"plt.rc(\"legend\", fontsize=6)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Optimising a simple one-bin analysis with `relaxed`\n",
"\n",
"Let's define an analysis with a predicted number of signal and background events, with some uncertainty on the background estimate. We'll abstract the analysis configuration into a single parameter $\\phi$ like so:\n",
"\n",
"$$s = 15 + \\phi $$\n",
"$$b = 45 - 2 \\phi $$\n",
"$$\\sigma_b = 0.5 + 0.1*\\phi^2 $$\n",
"\n",
"Note that $s \\propto \\phi$ and $\\propto -2\\phi$, so increasing $\\phi$ corresponds to increasing the signal/backround ratio. However, our uncertainty scales like $\\phi^2$, so we're also going to compromise in our certainty of the background count as we do that. This kind of tradeoff between $s/b$ ratio and uncertainty is important for the discovery of a new signal, so we can't get away with optimising $s/b$ alone.\n",
"\n",
"To illustrate this, we'll plot the discovery significance for this model with and without uncertainty."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# model definition\n",
"def yields(phi, uncertainty=True):\n",
" s = 15 + phi\n",
" b = 45 - 2 * phi\n",
" db = (\n",
" 0.5 + 0.1 * phi**2 if uncertainty else jnp.zeros_like(phi) + 0.001\n",
" ) # small enough to be negligible\n",
" return jnp.asarray([s]), jnp.asarray([b]), jnp.asarray([db])\n",
"\n",
"\n",
"# our analysis pipeline, from phi to p-value\n",
"def pipeline(phi, return_yields=False, uncertainty=True):\n",
" y = yields(phi, uncertainty=uncertainty)\n",
" # use a dummy version of pyhf for simplicity + compatibility with jax\n",
" model = relaxed.dummy_pyhf.uncorrelated_background(*y)\n",
" nominal_pars = jnp.array([1.0, 1.0])\n",
" data = model.expected_data(nominal_pars) # we expect the nominal model\n",
" # do the hypothesis test (and fit model pars with gradient descent)\n",
" pvalue = relaxed.infer.hypotest(\n",
" 0.0, # value of mu for the alternative hypothesis\n",
" data,\n",
" model,\n",
" test_stat=\"q0\", # discovery significance test\n",
" lr=1e-3,\n",
" expected_pars=nominal_pars, # optionally providing MLE pars in advance\n",
" )\n",
" if return_yields:\n",
" return pvalue, y\n",
" else:\n",
" return pvalue\n",
"\n",
"\n",
"# calculate p-values for a range of phi values\n",
"phis = jnp.linspace(0, 10, 100)\n",
"\n",
"# with uncertainty\n",
"pipe = partial(pipeline, return_yields=True, uncertainty=True)\n",
"pvals, ys = jax.vmap(pipe)(phis) # map over phi grid\n",
"# without uncertainty\n",
"pipe_no_uncertainty = partial(pipeline, uncertainty=False)\n",
"pvals_no_uncertainty = jax.vmap(pipe_no_uncertainty)(phis)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axs = plt.subplots(2, 1, sharex=True)\n",
"axs[0].plot(phis, pvals, label=\"with uncertainty\", color=\"C2\")\n",
"axs[0].plot(phis, pvals_no_uncertainty, label=\"no uncertainty\", color=\"C4\")\n",
"axs[0].set_ylabel(\"$p$-value\")\n",
"# plot vertical dotted line at minimum of p-values + s/b\n",
"best_phi = phis[jnp.argmin(pvals)]\n",
"axs[0].axvline(x=best_phi, linestyle=\"dotted\", color=\"C2\", label=\"optimal p-value\")\n",
"axs[0].axvline(\n",
" x=phis[jnp.argmin(pvals_no_uncertainty)],\n",
" linestyle=\"dotted\",\n",
" color=\"C4\",\n",
" label=r\"optimal $s/b$\",\n",
")\n",
"axs[0].legend(loc=\"upper left\", ncol=2)\n",
"s, b, db = ys\n",
"s, b, db = s.ravel(), b.ravel(), db.ravel() # everything is [[x]] for pyhf\n",
"axs[1].fill_between(phis, s + b, b, color=\"C9\", label=\"signal\")\n",
"axs[1].fill_between(phis, b, color=\"C1\", label=\"background\")\n",
"axs[1].fill_between(phis, b - db, b + db, facecolor=\"k\", alpha=0.2, label=r\"$\\sigma_b$\")\n",
"axs[1].set_xlabel(\"$\\phi$\")\n",
"axs[1].set_ylabel(\"yield\")\n",
"axs[1].legend(loc=\"lower left\")\n",
"plt.suptitle(\"Discovery p-values, with and without uncertainty\")\n",
"plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using gradient descent, we can optimise this analysis in an uncertainty-aware way by directly optimising $\\phi$ for the lowest discovery p-value. Here's how you do that:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The fast way!\n",
"# use the OptaxSolver wrapper from jaxopt to perform the minimisation\n",
"# set a couple of tolerance kwargs to make sure we don't get stuck\n",
"solver = OptaxSolver(pipeline, opt=optax.adam(1e-3), tol=1e-8, maxiter=10000)\n",
"pars = 9.0 # random init\n",
"result = solver.run(pars).params\n",
"print(\n",
" f\"our solution: phi={result:.5f}\\ntrue optimum: phi={phis[jnp.argmin(pvals)]:.5f}\\nbest s/b: phi=10\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The longer way (but with plots)!\n",
"pipe = partial(pipeline, return_yields=True, uncertainty=True)\n",
"solver = OptaxSolver(pipe, opt=optax.adam(1e-1), has_aux=True)\n",
"pars = 9.0\n",
"state = solver.init_state(pars) # we're doing init, update steps instead of .run()\n",
"\n",
"plt.rc(\"figure\", figsize=(6, 3), dpi=220, facecolor=\"w\")\n",
"plt.rc(\"legend\", fontsize=8)\n",
"fig, axs = plt.subplots(1, 2)\n",
"cam = Camera(fig)\n",
"steps = 5 # increase me for better results! (100ish works well)\n",
"for i in range(steps):\n",
" pars, state = solver.update(pars, state)\n",
" s, b, db = state.aux\n",
" val = state.value\n",
" ax = axs[0]\n",
" cv = ax.plot(phis, pvals, c=\"C0\")\n",
" cvs = ax.plot(phis, pvals_no_uncertainty, c=\"green\")\n",
" current = ax.scatter(pars, val, c=\"C0\")\n",
" ax.set_xlabel(r\"analysis config $\\phi$\")\n",
" ax.set_ylabel(\"p-value\")\n",
" ax.legend(\n",
" [\n",
" mlines.Line2D([], [], color=\"C0\"),\n",
" mlines.Line2D([], [], color=\"green\"),\n",
" current,\n",
" ],\n",
" [\"p-value (with uncert)\", \"p-value (without uncert)\", \"current value\"],\n",
" frameon=False,\n",
" )\n",
" ax.text(0.3, 0.61, f\"step {i}\", transform=ax.transAxes)\n",
" ax = axs[1]\n",
" ax.set_ylim((0, 80))\n",
" b1 = ax.bar(0.5, b, facecolor=\"C1\", label=\"b\")\n",
" b2 = ax.bar(0.5, s, bottom=b, facecolor=\"C9\", label=\"s\")\n",
" b3 = ax.bar(\n",
" 0.5, db, bottom=b - db / 2, facecolor=\"k\", alpha=0.5, label=r\"$\\sigma_b$\"\n",
" )\n",
" ax.set_ylabel(\"yield\")\n",
" ax.set_xticks([])\n",
" ax.legend([b1, b2, b3], [\"b\", \"s\", r\"$\\sigma_b$\"], frameon=False)\n",
" plt.tight_layout()\n",
" cam.snap()\n",
"\n",
"ani = cam.animate()\n",
"# uncomment this to save and view the animation!\n",
"# ani.save(\"ap00.gif\", fps=9)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "22d6333b89854cd01c2018f3ca2f5a59a2cde2765fbca789ff36cfad48ca629b"
},
"kernelspec": {
"display_name": "Python 3.9.12 ('venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Binary file added examples/withbinfloat.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/withnobinfloat.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 5b8157c

Please # to comment.