Skip to content

Commit 683e023

Browse files
authored
make it easier to find out what ⊢is and how to type it (#216)
* make it easier to find out what ⊢is and how to type it * more edits * try * try again * apparently referring to unicode operators is really difficult * a bit more editing of docstrings
1 parent e3ff6b4 commit 683e023

File tree

2 files changed

+30
-26
lines changed

2 files changed

+30
-26
lines changed

docs/src/index.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,14 @@ test_frule(foo ⊢ Tangent{Foo}(;a=rand()), rand())
134134

135135
## Specifying Tangents
136136
[`test_frule`](@ref) and [`test_rrule`](@ref) allow you to specify the tangents used for testing.
137-
This is done by passing in `x ⊢ Δx`, where `x` is the primal and `Δx` is the tangent, in the place of the primal inputs.
138-
If this is not done the tangent will be automatically generated via `FiniteDifferences.rand_tangent`.
137+
By default, tangents will be automatically generated via `FiniteDifferences.rand_tangent`.
138+
To explicitly specify a tangent, pass in `x ⊢ Δx`, where `x` is the primal and `Δx` is the tangent, in the place of the primal inputs.
139+
(You can enter [``](@ref) via `\vdash` + tab in the Julia REPL and supporting editors.)
139140
A special case of this is that if you specify it as `x ⊢ NoTangent()` then finite differencing will not be used on that input.
140141
Similarly, by setting the `output_tangent` keyword argument, you can specify the tangent for the primal output.
141142

142143
This can be useful when the default provided `FiniteDifferences.rand_tangent` doesn't produce the desired tangent for your type.
143-
For example the default tangent for an `Int` is `NoTangent()`.
144-
Which is correct e.g. when the `Int` represents a discrete integer like in indexing.
144+
For example, the default tangent for an `Int` is `NoTangent()`, which is correct e.g. when the `Int` represents a discrete integer like in indexing.
145145
But if you are testing something where the `Int` is actually a special case of a real number, then you would want to specify the tangent as a `Float64`.
146146

147147
Care must be taken when manually specifying tangents.
@@ -273,4 +273,4 @@ Test.DefaultTestSet("test_rrule: abs on Float64", Any[], 5, false, false)
273273
```
274274

275275
This behavior can also be overridden globally by setting the environment variable `CHAINRULES_TEST_INFERRED` before ChainRulesTestUtils is loaded or by changing `ChainRulesTestUtils.TEST_INFERRED[]` from inside Julia.
276-
ChainRulesTestUtils can detect whether a test is run as part of [PkgEval](https://github.com/JuliaCI/PkgEval.jl)and in this case disables inference tests automatically. Packages can use [`@maybe_inferred`](@ref) to get the same behavior for other inference tests.
276+
ChainRulesTestUtils can detect whether a test is run as part of [PkgEval](https://github.com/JuliaCI/PkgEval.jl) and in this case disables inference tests automatically. Packages can use [`@maybe_inferred`](@ref) to get the same behavior for other inference tests.

src/testers.jl

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ Given a function `f` with scalar input and scalar output, perform finite differe
55
at input point `z` to confirm that there are correct `frule` and `rrule`s provided.
66
77
# Arguments
8-
- `f`: Function for which the `frule` and `rrule` should be tested.
9-
- `z`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
8+
- `f`: function for which the `frule` and `rrule` should be tested.
9+
- `z`: input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).
1010
11-
`fkwargs` are passed to `f` as keyword arguments.
12-
If `check_inferred=true`, then the type-stability of the `frule` and `rrule` are checked.
13-
All remaining keyword arguments are passed to `isapprox`.
11+
# Keyword Arguments
12+
- `fdm`: the finite differencing method to use.
13+
- `fkwargs` are passed to `f` as keyword arguments.
14+
- If `check_inferred=true`, then the inferrability (type-stability) of the `frule` and `rrule` are checked.
15+
- All remaining keyword arguments are passed to `isapprox`.
1416
"""
1517
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), check_inferred=true, kwargs...)
1618
# To simplify some of the calls we make later lets group the kwargs for reuse
@@ -71,19 +73,20 @@ end
7173
7274
# Arguments
7375
- `config`: defaults to `ChainRulesTestUtils.ADviaRuleConfig`.
74-
- `f`: Function for which the `frule` should be tested. Can also provide `f ⊢ ḟ`.
75-
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
76-
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
77-
- `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
76+
- `f`: function for which the `frule` should be tested. Its tangent can be provided using `f ⊢ ḟ`.
77+
(You can enter `⊢` via `\\vdash` + tab in the Julia REPL and supporting editors.)
78+
- `args...`: either the primal args `x`, or primals and their tangents: `x ⊢ ẋ`
79+
- `x`: input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).
80+
- `ẋ`: differential w.r.t. `x`; will be generated automatically if not provided.
7881
Non-differentiable arguments, such as indices, should have `ẋ` set as `NoTangent()`.
7982
8083
# Keyword Arguments
81-
- `output_tangent` tangent to test accumulation of derivatives against
82-
should be a differential for the output of `f`. Is set automatically if not provided.
84+
- `output_tangent`: tangent against which to test accumulation of derivatives.
85+
Should be a differential for the output of `f`. Is set automatically if not provided.
8386
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
84-
- `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to
87+
- `frule_f=frule`: function with an `frule`-like API that is tested (defaults to
8588
`frule`). Used for testing gradients from AD systems.
86-
- If `check_inferred=true`, then the inferrability of the `frule` is checked,
89+
- If `check_inferred=true`, then the inferrability (type-stability) of the `frule` is checked,
8790
as long as `f` is itself inferrable.
8891
- `fkwargs` are passed to `f` as keyword arguments.
8992
- All remaining keyword arguments are passed to `isapprox`.
@@ -144,21 +147,22 @@ end
144147
145148
# Arguments
146149
- `config`: defaults to `ChainRulesTestUtils.ADviaRuleConfig`.
147-
- `f`: Function to which rule should be applied. Can also provide `f ⊢ f̄`.
148-
- `args` either the primal args `x`, or primals and their tangents: `x ⊢ x̄`
149-
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
150-
- `x̄`: currently accumulated cotangent, will be generated automatically if not provided
150+
- `f`: function for which the `rrule` should be tested. Its tangent can be provided using `f ⊢ f̄`.
151+
(You can enter `⊢` via `\\vdash` + tab in the Julia REPL and supporting editors.)
152+
- `args...`: either the primal args `x`, or primals and their tangents: `x ⊢ x̄`
153+
- `x`: input at which to evaluate `f` (should generally be set to an arbitrary point in the domain).
154+
- `x̄`: currently accumulated cotangent; will be generated automatically if not provided.
151155
Non-differentiable arguments, such as indices, should have `x̄` set as `NoTangent()`.
152156
153157
# Keyword Arguments
154-
- `output_tangent` the seed to propagate backward for testing (technically a cotangent).
158+
- `output_tangent`: the seed to propagate backward for testing (technically a cotangent).
155159
should be a differential for the output of `f`. Is set automatically if not provided.
156-
- `check_thunked_output_tangent=true`: also checks that passing a thunked version of the
160+
- `check_thunked_output_tangent=true`: also checks that passing a thunked version of the
157161
output tangent to the pullback returns the same result.
158162
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
159-
- `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`).
163+
- `rrule_f=rrule`: function with an `rrule`-like API that is tested (defaults to `rrule`).
160164
Used for testing gradients from AD systems.
161-
- If `check_inferred=true`, then the inferrability of the `rrule` is checked
165+
- If `check_inferred=true`, then the inferrability (type-stability) of the `rrule` is checked
162166
— if `f` is itself inferrable — along with the inferrability of the pullback it returns.
163167
- `fkwargs` are passed to `f` as keyword arguments.
164168
- All remaining keyword arguments are passed to `isapprox`.

0 commit comments

Comments
 (0)