Support for differentiation through arrays #133

Open
wants to merge 9 commits into
base: main

@richfitz richfitz commented Dec 17, 2024

There is some duplication of code from odin here; I've moved it within maths so that we can reuse it easily in odin. The tests have been copied over too.

This PR includes support for differentiating array expressions. These are quite hard to think about.

When we talk about differentiating with respect to x where x is an array, we seek to create an object with the same dimensions as x. So if x is a vector then d/dx(expr) is a vector whose elements are the derivatives of expr with respect to (x[1], x[2], ..., x[n]). Similarly, if x is a matrix then d/dx(expr) is a matrix where element (i, j) is the derivative of expr with respect to x[i, j], and so on.

Array indexing is therefore pretty easy. The derivative of x[...] with respect to x is 1 at exactly the same index and 0 otherwise. In many expressions we'll have f(x[i]) on the rhs, in which case we end up with f'(x[i]) on the way out.
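As a rough numerical illustration of the "same dimensions as x" idea, the sketch below uses central finite differences. The helper `numeric_gradient` is hypothetical (it is not part of this package), and the matrix and the expression `x[1, 2]^2` are chosen only for the example.

```r
# Hypothetical helper, not part of the package: central finite-difference
# gradient of a scalar function f with respect to every element of x.
# The result has the same dimensions as x.
numeric_gradient <- function(f, x, h = 1e-6) {
  g <- x  # copy so that the dim attribute is preserved
  for (k in seq_along(x)) {
    up <- x
    down <- x
    up[k] <- up[k] + h
    down[k] <- down[k] - h
    g[k] <- (f(up) - f(down)) / (2 * h)
  }
  g
}

x <- matrix(c(1, 2, 3, 4, 5, 6), nrow = 2)

# expr = x[1, 2]^2, so only element (1, 2) has a non-zero derivative,
# and there it is f'(x[1, 2]) = 2 * x[1, 2]
g <- numeric_gradient(function(x) x[1, 2]^2, x)
dim(g)
g
```

The gradient comes back as a 2 x 3 matrix, with a value close to 6 at position (1, 2) and zeros elsewhere.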

Things are harder if we have other indices within the expression:

> differentiate(quote(x[2]), "x")
if (2 == i) 1 else 0

This makes sense when we see the overall loop:

for (size_t i = 1; i <= n; ++i) {
  ... x[2] ...
}

so here the derivative of this expression is 1 where i is 2 and 0 otherwise.
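One way to see this is to evaluate the returned expression once per loop index. This is only a sketch: `n = 4` is an arbitrary choice, and the expression is copied from the output above rather than produced by calling the package.

```r
# The derivative expression reported for x[2], held as an unevaluated call
deriv_expr <- quote(if (2 == i) 1 else 0)

n <- 4  # arbitrary length, chosen for illustration

# Evaluate it once for each value of the loop index i
d <- vapply(seq_len(n), function(i) eval(deriv_expr, list(i = i)), numeric(1))
d
# [1] 0 1 0 0
```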

For sums, the same logic applies: if we sum over the whole array x, the derivative with respect to x[i, j, ...] is 1 everywhere:

> differentiate(quote(sum(x)), "x")

This is because each element x[...] appears exactly once in the sum. (It may be useful to think of this within some loop driven by the indices on the lhs in odin.)
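A quick numerical check of this claim, as a sketch only: because `sum()` is linear, bumping a single element by 1 changes the total by exactly the derivative with respect to that element. The 2 x 3 x 4 array is an arbitrary choice.

```r
x <- array(as.numeric(1:24), dim = c(2, 3, 4))

# Bump each element in turn and record how much the total changes
d <- x
for (k in seq_along(x)) {
  bumped <- x
  bumped[k] <- bumped[k] + 1
  d[k] <- sum(bumped) - sum(x)
}

dim(d)        # same dimensions as x
all(d == 1)   # every element contributes exactly once
```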

Partial sums are much nastier to think about but the solution here follows from the logic above: check that we include x[...] in the sum and return 1 if we do, 0 otherwise.

To do this we check:

  • are we summing over x?
  • can we see that our x is excluded (that is, the same position in x is never consumed)? If so, the derivative is 0
  • otherwise, generate an expression for each remaining index and && them together; return 1 if they are all TRUE and 0 otherwise
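The last step above can be sketched roughly as follows. This is not the code in this PR: `build_condition` is a hypothetical helper, the representation of the index (a list with `NULL` for an empty dimension) is an assumption made for the example, and the "can we see that x is excluded" check is left out entirely.

```r
# Hypothetical helper, not the implementation in this PR.
# index: one entry per dimension; NULL means the whole dimension is summed,
#   a call to `:` is a range, anything else is a single index.
# loop_vars: name of the loop index used for each dimension.
build_condition <- function(index, loop_vars) {
  terms <- list()
  for (d in seq_along(index)) {
    idx <- index[[d]]
    v <- as.name(loop_vars[[d]])
    if (is.null(idx)) {
      next  # whole dimension: always included, no condition needed
    } else if (is.call(idx) && identical(idx[[1]], quote(`:`))) {
      # range from:to becomes v >= from && v <= to
      terms <- c(terms, list(call(">=", v, idx[[2]]), call("<=", v, idx[[3]])))
    } else if (!identical(idx, v)) {
      # a single index that differs from the loop variable
      terms <- c(terms, list(call("==", idx, v)))
    }
  }
  if (length(terms) == 0) {
    return(1)  # nothing to check: the element is always in the sum
  }
  cond <- Reduce(function(lhs, rhs) call("&&", lhs, rhs), terms)
  call("if", cond, 1, 0)
}

build_condition(list(quote(i), NULL), c("i", "j"))
build_condition(list(quote(j), NULL), c("i", "j"))
build_condition(list(quote(j), NULL, quote(a:b)), c("i", "j", "k"))
```

Under these assumptions the three calls produce `1`, `if (j == i) 1 else 0` and `if (j == i && k >= a && k <= b) 1 else 0` respectively.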

Some examples:

> differentiate(quote(sum(x[i, ])), "x")
[1] 1
> differentiate(quote(sum(x[j, ])), "x")
if (j == i) 1 else 0
> differentiate(quote(sum(x[j, , a:b])), "x")
if (j == i && k >= a && k <= b) 1 else 0
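The last result can be checked numerically. In this sketch `j`, `a` and `b` are treated as fixed values and the array dimensions are arbitrary; these are assumptions for illustration. The loop variable for the second dimension is called `m` here so that it does not clash with the fixed `j`.

```r
# Assumed values for illustration only
dims <- c(3, 2, 4)
x <- array(as.numeric(seq_len(prod(dims))), dim = dims)
j <- 2
a <- 2
b <- 3

partial_sum <- function(x) sum(x[j, , a:b])

# The expression reported above for sum(x[j, , a:b])
cond <- quote(if (j == i && k >= a && k <= b) 1 else 0)

expected <- array(0, dim = dims)
observed <- array(0, dim = dims)
for (i in seq_len(dims[1])) {
  for (m in seq_len(dims[2])) {  # second index is unconstrained
    for (k in seq_len(dims[3])) {
      expected[i, m, k] <- eval(cond, list(i = i, j = j, k = k, a = a, b = b))
      # the partial sum is linear, so a unit bump gives the derivative
      bumped <- x
      bumped[i, m, k] <- bumped[i, m, k] + 1
      observed[i, m, k] <- partial_sum(bumped) - partial_sum(x)
    }
  }
}

identical(expected, observed)
```

With these values the mask is 1 for the four elements in row 2 with third index 2 or 3, and 0 everywhere else.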


codecov bot commented Dec 18, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.78%. Comparing base (bbe410f) to head (7e68bb6).

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #133   +/-   ##
=======================================
  Coverage   99.78%   99.78%           
=======================================
  Files          66       66           
  Lines        5232     5232           
=======================================
  Hits         5221     5221           
  Misses         11       11           


@richfitz richfitz marked this pull request as ready for review December 19, 2024 14:57
@richfitz richfitz requested a review from weshinsley December 19, 2024 15:06