Support for differentiation through arrays #133
Some duplication of code from odin here; I've moved that within `maths` so that we can reuse it easily in odin. Tests have been copied over too.

This PR includes support for differentiating array expressions. These are quite hard to think about.
When we talk about differentiating with respect to `x` where `x` is an array, we seek to create an object with the same dimensions as `x`. So if `x` is a vector then `d/dx(expr)` is a vector of the derivatives of `expr` with respect to `(x[1], x[2], ..., x[n])`. Similarly if `x` is a matrix then `d/dx(expr)` is a matrix where element `(i, j)` is the derivative of `expr` with respect to `x[i, j]`
and so on.

Array indexing is therefore pretty easy. The derivative of an element of `x` with respect to `x[...]` is 1 if that element sits at exactly the same index and 0 otherwise. In many expressions we'll have `f(x[i])` on the rhs, in which case we end up with `f'(x[i])` on the way out.

Things are harder if we have other indices within the expression:
this makes sense when we see the overall loop:
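As a hypothetical stand-in for such a loop, assume the expression is `x[2]` and that we differentiate it with respect to each `x[i]` in turn. This is a minimal Python sketch with concrete indices, not the package's own code; `d_index` and `gradient_of_fixed_element` are illustrative names only.

```python
def d_index(expr_index, wrt_index):
    """Derivative of x[expr_index] with respect to x[wrt_index].

    Both arguments are tuples of 1-based indices (an assumption made to
    mirror the text). The result is 1 on an exact match, 0 otherwise.
    """
    return 1 if tuple(expr_index) == tuple(wrt_index) else 0


def gradient_of_fixed_element(n, fixed=2):
    """Differentiate x[fixed] with respect to x[1], ..., x[n] in a loop."""
    return [d_index((fixed,), (i,)) for i in range(1, n + 1)]


print(gradient_of_fixed_element(4))  # [0, 1, 0, 0]
```
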
so here the derivative of this expression is 1 where `i` is 2 and 0 otherwise.

For sums, the same logic applies: if we have the sum over the whole array `x`, the derivative of this with respect to `x[i, j, ...]` is an array of 1 everywhere:

(it's useful to think of this within some loop driven by the indices on the lhs of odin, perhaps?). This is because each element `x[...]` appears exactly once in the sum.

Partial sums are much nastier to think about but the solution here follows from the logic above: check that we include `x[...]`
in the sum and return 1 if we do, 0 otherwise.

To do this we check:

- for each dimension of `x`, that the index into `x` is not excluded (an index is excluded when the same position in `x` is never consumed)
- `&&` them together, return `1` if they are all `TRUE`
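The check above can be sketched in Python under some loudly stated assumptions: the partial sum is described by one inclusive, 1-based `(lo, hi)` range per dimension, and the indices are concrete integers rather than symbolic expressions. The name `d_partial_sum` and this representation are hypothetical and are not taken from odin or from this PR.

```python
def d_partial_sum(ranges, index):
    """Derivative of a partial sum over x with respect to x[index].

    ranges: one inclusive, 1-based (lo, hi) pair per dimension of x
            (a hypothetical representation chosen for this sketch).
    index:  tuple of 1-based indices we differentiate with respect to.
    """
    if len(ranges) != len(index):
        raise ValueError("ranges and index must have the same rank")
    # One check per dimension: is this position consumed by the sum?
    included = [lo <= i <= hi for (lo, hi), i in zip(ranges, index)]
    # "&&" the checks together: 1 only if every dimension includes the index.
    return 1 if all(included) else 0


# Sum over row 2 of a 3 x 4 matrix: only elements in row 2 contribute.
row_two = [(2, 2), (1, 4)]
grid = [[d_partial_sum(row_two, (i, j)) for j in range(1, 5)]
        for i in range(1, 4)]
print(grid)  # [[0, 0, 0, 0], [1, 1, 1, 1], [0, 0, 0, 0]]
```

With ranges that span every dimension in full, every element is included, so the sketch reduces to the whole-array case and returns 1 everywhere.
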
Some examples: