This is an implementation of matrix one-norm estimation in jax, as specified by http://eprints.maths.manchester.ac.uk/321/1/35608.pdf
The implementation passes the scipy test suite with some minor relaxations (e.g. the number of column resamples). The relaxed tests are documented in `./test_onenormest.py`.
In basic benchmarks on a Google Colab free-tier GPU, this implementation is roughly 8x faster than the scipy CPU implementation for 4096x4096 matrices.
There are existing implementations in scipy and octave.
The algorithm as specified is imperative and control-flow heavy. Additionally, a few variables have non-constant dimensions. This implementation has a few quirks to get jax to jit-compile it.
The main loop has many conditional early breaks. We handle this with manual continuation passing into a branch of `jax.lax.cond`, as sketched below.
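The following sketch is illustrative only (the `est`/`done`/`k` carry and the break test are made up, not the real loop state): each break test picks between a branch that simply flags the loop as finished and a branch containing the rest of the iteration, and the `done` flag lets the `while_loop` predicate terminate.

```python
import jax
import jax.numpy as jnp

def body(state):
    est, done, k = state

    def take_break(s):
        # early "break": leave the carry untouched and flag completion
        est, done, k = s
        return est, jnp.array(True), k

    def keep_going(s):
        # continuation: the remainder of the iteration lives in this branch
        est, done, k = s
        return est + 1.0, done, k + 1

    hit_break = est > 10.0  # stand-in for one of the algorithm's break conditions
    return jax.lax.cond(hit_break, take_break, keep_going, state)

def cond(state):
    _, done, k = state
    return jnp.logical_and(jnp.logical_not(done), k < 20)

init = (jnp.float32(0.0), jnp.array(False), jnp.int32(0))
final_est, _, _ = jax.lax.while_loop(cond, body, init)
```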
`ind_hist` and `ind` must have fixed dimensions. In the scipy implementation and in Higham, `ind_hist` is a growable array that stores the indices of the unit vectors already used. In the octave implementation, `ind_hist` is a fixed-size array in which a 1 is written at index `j` when `e_j` is used. We follow the octave approach to keep the array at a fixed size.
`ind` has shape `(n,)` in Higham, but only its first `t` values are read: they are written into `ind_hist`, and `ind` is indexed with column indices of `Y`, which has shape `(n, t)`. Because we test each elementary vector only once, we are not guaranteed to have `t` elementary vectors left to test on each loop. We handle this by filling the unused elements of `ind` with the sentinel value `n`. The sentinel causes the corresponding columns of `X` to be filled with the zero vector instead of an elementary vector. These zero columns produce norm estimates of 0, which are always valid underestimates of the one norm. Note that because `ind` can now contain the extra sentinel value `n`, `ind_hist` must be extended to length `n + 1`; recording the sentinel in `ind_hist` has no effect.
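A sketch of the sentinel scheme with made-up sizes; the `basis` gather below is just one way to express "column `j` of `X` is `e_j` for `j < n` and the zero vector for `j == n`", not necessarily how the implementation builds `X`:

```python
import jax.numpy as jnp

n, t = 6, 3
# e_0 .. e_{n-1} followed by one extra zero column selected by the sentinel n
basis = jnp.concatenate([jnp.eye(n), jnp.zeros((n, 1))], axis=1)

ind = jnp.array([2, 5, n, n, n, n])  # only two fresh unit vectors available; pad with the sentinel
X = basis[:, ind[:t]]                # columns e_2, e_5, and a zero vector

ind_hist = jnp.zeros(n + 1, dtype=jnp.int32)
ind_hist = ind_hist.at[ind[:t]].set(1)  # writing at the sentinel slot n is harmless
```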