Normalization #281


Merged: 4 commits into tensorly:main on Nov 29, 2021

Conversation

caglayantuna (Member)

This pull request is related to issue #264. We suggest some updates to the _nn_cp.py and _cp.py files in order to provide consistent normalization options for the non_negative_parafac, non_negative_parafac_hals and parafac functions. In these three functions, we added the weights to the mttkrp and pseudo_inverse (accum for non_negative_parafac) computations:

mttkrp = unfolding_dot_khatri_rao(tensor, (weights, factors), mode)
pseudo_inverse = tl.reshape(weights, (-1, 1)) * pseudo_inverse * tl.reshape(weights, (1, -1))

Since we now use the weights to compute mttkrp, we removed the weights from the iprod computation:

iprod = tl.sum(tl.sum(mttkrp * factors[-1], axis=0))
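
A quick way to see why no extra weights are needed in iprod (a sketch of the reasoning; here $A_m$ denotes the last-updated factor, $w$ the weights, and $X_{(m)}$ the corresponding unfolding of the tensor):

$$\langle \text{tensor}, \widehat{\text{tensor}} \rangle
= \operatorname{trace}\!\Big(X_{(m)}\,\big(\textstyle\bigodot_{k \neq m} A_k\big)\,\operatorname{diag}(w)\,A_m^\top\Big)
= \sum_{i,r} \mathrm{mttkrp}_{ir}\,[A_m]_{ir},$$

so the weights are already carried by mttkrp and must not be applied a second time in iprod.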

Finally, we suggest calling the cp_normalize function after the error computation in all three functions.

With these modifications, our experiments no longer show the error reported in issue #264.
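
For illustration, a minimal usage sketch of the behaviour targeted here (a sketch only; the shapes and rank are arbitrary, and it relies on the existing normalize_factors option of parafac):

import numpy as np
import tensorly as tl
from tensorly.decomposition import parafac

# random third-order tensor; sizes are only illustrative
tensor = tl.tensor(np.random.random_sample((10, 11, 12)))

# with normalize_factors=True, every factor should come back with
# unit-norm columns and the scale absorbed into `weights`
weights, factors = parafac(tensor, rank=3, normalize_factors=True)

for factor in factors:
    print(tl.norm(factor, axis=0))  # expected: all close to 1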


codecov bot commented Jun 17, 2021

Codecov Report

Merging #281 (cff18e7) into main (c41092d) will increase coverage by 0.03%.
The diff coverage is 97.56%.


@@            Coverage Diff             @@
##             main     #281      +/-   ##
==========================================
+ Coverage   88.09%   88.12%   +0.03%     
==========================================
  Files         103      103              
  Lines        5963     5988      +25     
==========================================
+ Hits         5253     5277      +24     
- Misses        710      711       +1     
Impacted Files Coverage Δ
tensorly/decomposition/_nn_cp.py 85.83% <94.44%> (+1.39%) ⬆️
tensorly/decomposition/_cp.py 87.22% <100.00%> (-0.32%) ⬇️
tensorly/decomposition/tests/test_cp.py 100.00% <100.00%> (ø)
tensorly/tenalg/proximal.py 66.98% <0.00%> (-0.48%) ⬇️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

JeanKossaifi (Member)

I'm not sure this is necessarily better: it will be slightly heavier computationally. It's also not strictly necessary, since the weights would just be naturally absorbed in the last factor and, as we discussed, I don't have a strong intuition about whether one approach is better than the other. In Kolda's seminal paper(s) (and in the tensor-toolbox), I believe they do something similar to what we already have.

I think a discussion is needed to simplify and make both the code and the API more uniform; I'll post that in the issue.

It would be helpful to benchmark both approaches and see how they affect convergence, performance and numerical stability (e.g. are the factors in a better range?).
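
As an illustration, one possible minimal benchmark along these lines (a sketch only; the tensor size, rank and iteration count are arbitrary, and it assumes the normalize_factors flag discussed in this PR is available on non_negative_parafac):

import time
import numpy as np
import tensorly as tl
from tensorly.decomposition import non_negative_parafac

tensor = tl.tensor(np.random.random_sample((50, 60, 70)))

for normalize in (False, True):
    start = time.perf_counter()
    cp = non_negative_parafac(tensor, rank=10, n_iter_max=200,
                              tol=0, normalize_factors=normalize)
    elapsed = time.perf_counter() - start
    # relative reconstruction error, to compare how well both variants converge
    rel_error = tl.norm(tensor - tl.cp_to_tensor(cp)) / tl.norm(tensor)
    print(f"normalize_factors={normalize}: {elapsed:.2f}s, relative error {rel_error:.2e}")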

cohenjer (Contributor)

@JeanKossaifi Sure, we should study normalization even further, but for now the nonnegative Parafac normalization is buggy (#264), and with this PR we add a class method for normalization, which we can always tinker with later on. I think it would be better to merge the current PR and maybe open an issue for further discussion?

JeanKossaifi (Member) left a comment

Do you mean an attribute for normalization? I don't see a class method in the PR.
I left a few comments in the code.

The first priority is of course correctness. Once that is achieved, however, I'm wary of adding complexity to the basic methods / algorithms that are used often, as this can quickly result in much slower algorithms. For instance, if the weights don't influence convergence or numerical stability, why incorporate them in the mttkrp / update calculations when they would otherwise automatically get absorbed in the last factor?


-mttkrp = unfolding_dot_khatri_rao(tensor, (None, factors), mode)
+pseudo_inverse = tl.reshape(weights, (-1, 1)) * pseudo_inverse * tl.reshape(weights, (1, -1))
+mttkrp = unfolding_dot_khatri_rao(tensor, (weights, factors), mode)
Member

I don't think we need all that, since the weights are automatically absorbed in the last factor. If there is no advantage, it's just additional computation (a slower algorithm) for no strong reason.

Contributor

I guess we can compare the two versions on examples of various sizes to check whether this change effectively makes the algorithm slower.
If it does run slower, we can revert to ignoring the weights when the user does not request normalization. However, if the user asks for normalization, then we must store the weights somewhere. The whole idea of normalization, I think, was to avoid factors exploding in norm or, conversely, becoming extremely small, so we should not pull the norms back into the factors when the user asks for normalization.
But this would mean using the no-weights updates when normalization is off and the weighted updates when normalization is on (see the sketch below). I would argue that if the proposed change does not lead to any noticeable difference in computation time, we should keep it so that the code is easier to understand and maintain.
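
A rough sketch of that branching, for illustration only (the helper name update_terms is made up; it reuses the variable names from the diff above and assumes a normalize_factors flag):

import tensorly as tl
from tensorly.cp_tensor import unfolding_dot_khatri_rao

def update_terms(tensor, weights, factors, pseudo_inverse, mode, normalize_factors):
    # hypothetical helper: return the (mttkrp, pseudo_inverse) pair used in one ALS update
    if normalize_factors:
        # weighted updates: factors keep unit-norm columns, the scale lives in `weights`,
        # so the weights enter the Gram term on both sides and the mttkrp once
        pseudo_inverse = tl.reshape(weights, (-1, 1)) * pseudo_inverse * tl.reshape(weights, (1, -1))
        mttkrp = unfolding_dot_khatri_rao(tensor, (weights, factors), mode)
    else:
        # unweighted updates: the scale is simply absorbed in the last updated factor
        mttkrp = unfolding_dot_khatri_rao(tensor, (None, factors), mode)
    return mttkrp, pseudo_inverse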

Member

I agree with all these points :)

factors[mode] = factor
if normalize_factors and mode != modes_list[-1]:
    weights, factors = cp_normalize((weights, factors))
Member

This is much better - thanks for uniformizing it

@@ -243,29 +245,28 @@ def non_negative_parafac(tensor, rank, n_iter_max=100, init='svd', svd='numpy_sv
        accum *= tl.dot(tl.transpose(factors[e]), factors[e])
    else:
        accum = tl.dot(tl.transpose(factors[e]), factors[e])

accum = tl.reshape(weights, (-1, 1)) * accum * tl.reshape(weights, (1, -1))
Member

Is this correct? The weights are for the full tensor, so this is scaling the factors up every time.
E.g. the full tensor is sum_r weights[r] * factors[0][:, r] ∘ ... ∘ factors[-1][:, r];
in other words, each element of weights is used only once.

Contributor

It is correct, I think:

  • If there is no normalization, the weights are always 1 and this line does nothing.
  • If there is normalization, then in the current version all the factors have unit-norm columns and the weights must be accounted for in the gradient computation. Just like each non-updated factor appears twice in accum, the weights also appear twice (see the sketch below).
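
As a sanity check of the "weights appear twice" argument, a minimal NumPy sketch (sizes and names are illustrative, not taken from the PR):

import numpy as np

rank, J, K = 3, 4, 5
rng = np.random.default_rng(0)

# fixed (non-updated) factors with unit-norm columns, as produced by cp_normalize
B = rng.random((J, rank)); B /= np.linalg.norm(B, axis=0)
C = rng.random((K, rank)); C /= np.linalg.norm(C, axis=0)
w = rng.random(rank)  # the weights pulled out by the normalization

# Khatri-Rao (column-wise Kronecker) product of the weighted factors
kr_weighted = np.einsum('kr,jr->kjr', C, B).reshape(-1, rank) * w

# its Gram matrix, as it appears in the normal equations of the weighted model ...
gram_direct = kr_weighted.T @ kr_weighted

# ... equals the unweighted Hadamard product of Grams scaled by the weights on both sides,
# i.e. exactly accum = reshape(w, (-1, 1)) * accum * reshape(w, (1, -1))
accum = (B.T @ B) * (C.T @ C)
gram_scaled = w.reshape(-1, 1) * accum * w.reshape(1, -1)

assert np.allclose(gram_direct, gram_scaled)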

JeanKossaifi (Member)

Thanks @cohenjer, I agree with all your points.

JeanKossaifi (Member)

Thanks for the great work @caglayantuna and @cohenjer.
Is everyone happy with merging this?

caglayantuna (Member, Author)

Thanks @JeanKossaifi. From my side, it is ok.

cohenjer (Contributor)

Let's go!

JeanKossaifi (Member)

Awesome, merging!

JeanKossaifi merged commit c2fa4ec into tensorly:main on Nov 29, 2021.
caglayantuna deleted the normalization branch on Nov 30, 2021.