The issue comes from the large number of intermediate matrices being constructed while computing the final matrix for diagonalization. In inference, some of those intermediate variables are dropped as the program progresses, but gradient computation is particularly bad in this respect, as all intermediate results are kept for the backward pass. Here are some thoughts on what could be improved:
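For reference, here is a minimal illustration of the effect (made-up shapes and a dummy scalar reduction, not the actual code of this package), assuming the autograd backend: every rank-2 intermediate created on the way to the output is retained on the tape until the backward pass has run.

```python
import autograd.numpy as np
from autograd import grad

def matrix_trace(u, v, eps_inv_mat):
    m1 = np.outer(u, v)        # N x N intermediate, kept for the backward pass
    m2 = m1 * eps_inv_mat      # another N x N intermediate, also kept
    return np.trace(m2)

N = 2000
u, v = np.random.rand(N), np.random.rand(N)
eps_inv_mat = np.random.rand(N, N)
g = grad(matrix_trace)(u, v, eps_inv_mat)  # peak memory scales with the number of N x N terms
```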
In the computation of the matrix elements, there are matrices being constructed from vectors using broadcasting (or `np.outer`). These correspond to rank-1 updates that create rank-2 intermediate results. Some of those could be wrapped in a primitive that only depends on the input vectors, which should reduce the memory needed for the backward pass (a rough sketch is given below).
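Something along these lines, assuming the autograd backend; the fused block, its name, and the exact combination of terms are invented for illustration, not taken from the existing code. Because the whole block is a single `@primitive`, autograd records only its inputs and output, not the rank-2 intermediates created inside it, and the hand-written VJPs depend only on the input vectors and `eps_inv_mat` (real-valued inputs assumed; complex inputs would need conjugation).

```python
import autograd.numpy as np
from autograd.extend import primitive, defvjp

@primitive
def rank1_block(u1, v1, u2, v2, eps_inv_mat):
    # Two rank-1 updates modulated element-wise by eps_inv_mat.
    # The np.outer intermediates live only inside this call.
    return (np.outer(u1, v1) + np.outer(u2, v2)) * eps_inv_mat

defvjp(rank1_block,
       lambda ans, u1, v1, u2, v2, e: lambda g: (g * e) @ v1,    # d/du1
       lambda ans, u1, v1, u2, v2, e: lambda g: (g * e).T @ u1,  # d/dv1
       lambda ans, u1, v1, u2, v2, e: lambda g: (g * e) @ v2,    # d/du2
       lambda ans, u1, v1, u2, v2, e: lambda g: (g * e).T @ u2,  # d/dv2
       lambda ans, u1, v1, u2, v2, e:
           lambda g: g * (np.outer(u1, v1) + np.outer(u2, v2)))  # d/deps_inv_mat
```

The last VJP still forms rank-2 arrays, but only transiently while the backward pass visits this node, not for the whole lifetime of the tape.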
Also in that computation, if a layer is homogeneous, as is very often the case for the claddings, the corresponding matrix is non-zero only on the main diagonal (due to the element-wise multiplication by `eps_inv_mat`). However, the other matrices are still computed in their entirety, even though only their diagonals end up being used.
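A hedged sketch of what the special case could look like, assuming the homogeneous-layer `eps_inv_mat` is (or can be) represented by just its diagonal; the function names and signatures are hypothetical:

```python
import autograd.numpy as np

def block_dense(u, v, eps_inv_diag):
    # Current style of computation: the full N x N matrix is built even
    # though only its diagonal survives the element-wise product.
    return np.outer(u, v) * np.diag(eps_inv_diag)

def block_diag_only(u, v, eps_inv_diag):
    # Homogeneous-layer shortcut: form only the diagonal.
    # O(N) memory instead of O(N^2), in the forward pass and on the tape.
    return u * v * eps_inv_diag
```

The length-N diagonal could then be expanded with `np.diag` (or handled as a special case) only at the point where the final matrix for diagonalization is assembled.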
Maybe something can also be improved in the way the final matrix for diagonalization is constructed.
It might be possible to significantly improve the memory usage, especially when computing gradients, but how exactly to do this requires some thought.