# preconditioned_stochastic_gradient_descent.py
"""
Created in May, 2018
PyTorch functions for preconditioned SGD
@author: XILIN LI, lixilinx@gmail.com
Updated in Dec, 2020:
Wrapped Kronecker product preconditioner for easy use: the code will select the proper Kronecker product
preconditioner based on the formats of input left and right preconditioners.
Added torch.jit.script decorator by default.
Updates in 2022:
Added low-rank approximation (LRA) and XMat preconditioners.
Wrapped LRA, XMat and Newton preconditioners as classes for easy use.
Updates in 2023:
Added gradient whitening preconditioner.
Replaced matrix norm lower bound max(abs(A)) with sqrt(max(max_i sum_j a_ij^2, max_j sum_i a_ij^2)).
Initialize Q to ((v^T*v)/(h^T*h))^(1/4)*I if the initial scale of Q is set to None.
Wrapped affine family as a class.
Updates in 2024 Jan:
Further tightened lower bound of a matrix spectral norm (see norm_lower_bound).
Updates in 2024 Mar:
By default, the 2nd (previously 1st) order derivative info is used to normalize the step size for the preconditioner update.
For the Newton optimizer class, also provided an option to keep inv(Q) via a rank-2 update of the matrix inverse.
Update rule for a triangular Q is modified to approximately match that on GL(n, R).
Functional usage of PSGD is to be deprecated, and not updated.
Updates in 2024 Aug:
Reverting triu01 back to triu.
QR approximation via triu01, i.e.,
[I + A]_R = I + triu(A) + triu(A, 1)
is fairly accurate when ||A|| < 0.25, but causes regressions for large lr_preconditioner.
Impacted classes: Affine and Newton.
Updates in 2024 Sept:
Reverting update_precond_affine_dropv_math_ back to update_precond_affine_math_ for the PSGD affine whitening preconditioner.
Integrating out v requires more accurate Lipschitz constant estimate of the preconditioner estimation criterion (nontrivial!).
Added class Kron for a Kronecker product preconditioner applicable to tensors of any dims.
Updates in 2024 Dec:
Init Q on the fly to (torch.mean(v*v))**(1/4) * (torch.mean(h**4))**(-1/8) * I (less likely to overshoot).
For the gradient whitening preconditioner, use damped pair (v, g + 2**(-13)*mean(abs(g))*v) to address machine roundoff errors.
"""
import opt_einsum
import torch
def damped_pair_vg(g, damp=2**(-13)):
    """
    Instead of returning (v, g), it returns the pair
        (v, g + sqrt(eps)*mean(abs(g))*v)
    such that the covariance matrix of the modified g is lower bounded by
        eps * (mean(abs(g)))**2 * I
    Here damp plays the role of sqrt(eps).
    This should damp the preconditioner to encourage numerical stability.
    The default amount of damping is 2**(-13), slightly smaller than sqrt(eps('single')).
    If v is integrated out, just use the modified g;
    if the Hessian-vector product (hvp) is used, we recommend L2 regularization to lower bound the Hessian, although this method also works.
    Please check the example
    https://github.com/lixilinx/psgd_torch/blob/master/misc/psgd_with_finite_precision_arithmetic.py
    for the rationale behind setting the default damping level to 2**(-13).
    """
    v = torch.randn_like(g)
    return (v, g + damp*torch.mean(torch.abs(g))*v)
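# The covariance claim in damped_pair_vg can be checked empirically. The helper below is a
# hypothetical sketch, not part of the original module: it fixes g, draws many probes v as
# damped_pair_vg does, and compares the per-element variance of the added damping term with the
# predicted floor (damp*mean(abs(g)))**2. It returns the largest relative deviation.
def _sketch_damping_floor(damp=2**(-13), num_probes=100_000, seed=0):
    import torch
    torch.manual_seed(seed)
    g = torch.tensor([1.0, -2.0, 3.0], dtype=torch.float64)
    scale = damp*torch.mean(torch.abs(g))                      # damp*mean(abs(g))
    vs = torch.randn(num_probes, *g.shape, dtype=g.dtype)      # one probe v per row
    damped = g + scale*vs                                      # second element of the pair, per probe
    added_var = torch.mean((damped - g)**2, dim=0)             # empirical variance of the damping term
    predicted = scale**2                                       # predicted noise floor
    return (torch.max(torch.abs(added_var - predicted))/predicted).item()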
def norm_lower_bound(A):
    """
    Returns a cheap lower bound for the spectral norm of A.
    Numerical results on random matrices with a wide range of distributions and sizes suggest
        norm(A) <= sqrt(2) * norm_lower_bound(A)
    Looks to be a very tight lower bound.
    """
    max_abs = torch.max(torch.abs(A)) # used to normalize A to avoid numerical underflow or overflow
    if max_abs > 0:
        A = A/max_abs
        aa = torch.real(A * A.conj())
        value0, i = torch.max(torch.sum(aa, dim=0), 0)
        value1, j = torch.max(torch.sum(aa, dim=1), 0)
        if value0 > value1:
            x = A[:, i].conj() @ A
            # We must have norm(x) > 0 since norm(x) >= value0 > value1 >= 0
            # Also, avoid expression norm(x*A^H)/norm(x) as x*A^H could under/over flow
            return max_abs * torch.linalg.vector_norm((x / torch.linalg.vector_norm(x)) @ A.H)
        else:
            x = A @ A[j].conj()
            # normx = torch.linalg.vector_norm(x)
            # if normx > 0:
            #     # Again, avoid expression norm(A^H*x)/norm(x) as A^H*x could under/over flow
            #     return max_abs * torch.linalg.vector_norm(A.H @ (x / normx))
            # else: # A = 0
            #     return normx
            return max_abs * torch.linalg.vector_norm(A.H @ (x / torch.linalg.vector_norm(x)))
    else: # must have A=0
        return max_abs
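# A simplified sketch of the idea behind norm_lower_bound (hypothetical helper, not part of the
# original module): for a nonzero real A, take the column a_i with the largest 2-norm and form
# x = a_i^T*A; then ||x*A^T||/||x|| can never exceed the spectral norm of A, and it is at least
# ||a_i||. The max_abs rescaling and the row/column branch of the original are omitted here.
def _sketch_norm_lower_bound(seed=0):
    import torch
    torch.manual_seed(seed)
    A = torch.randn(6, 4, dtype=torch.float64)
    col_norms = torch.linalg.vector_norm(A, dim=0)
    i = torch.argmax(col_norms)
    x = A[:, i] @ A                                            # row vector a_i^T*A
    lower = torch.linalg.vector_norm((x / torch.linalg.vector_norm(x)) @ A.T)
    exact = torch.linalg.matrix_norm(A, ord=2)                 # true spectral norm, for comparison
    return col_norms[i].item(), lower.item(), exact.item()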
def woodbury_identity_(invA, U, V):
    # implements the Woodbury identity,
    #
    #   inv(A + U*V) = inv(A) - inv(A)*U*inv(I + V*inv(A)*U)*V*inv(A)
    #
    # with inplace update of invA.
    #
    # Note that using the Woodbury identity multiple times could accumulate numerical errors.
    invAU = invA @ U
    VinvAU = V @ invAU
    I = torch.eye(VinvAU.shape[0], dtype=VinvAU.dtype, device=VinvAU.device)
    invA.sub_( invAU @ torch.linalg.solve(I + VinvAU, V@invA) )
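# A quick numerical check of the Woodbury identity used by woodbury_identity_ (hypothetical helper,
# not part of the original module): it applies the same update to a copy of inv(A) and compares the
# result with a directly computed inv(A + U*V). A is kept close to I so both are well conditioned.
def _sketch_check_woodbury(n=5, k=2, seed=0):
    import torch
    torch.manual_seed(seed)
    A = torch.eye(n, dtype=torch.float64) + 0.1*torch.randn(n, n, dtype=torch.float64)
    U = 0.1*torch.randn(n, k, dtype=torch.float64)
    V = 0.1*torch.randn(k, n, dtype=torch.float64)
    invA = torch.linalg.inv(A)
    invAU = invA @ U
    I = torch.eye(k, dtype=torch.float64)
    updated = invA - invAU @ torch.linalg.solve(I + V @ invAU, V @ invA)   # same formula as above
    direct = torch.linalg.inv(A + U @ V)
    return torch.linalg.matrix_norm(updated - direct).item()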
def triu01(A):
    # Useful because, for a small symmetric A, the R factor of the QR decomposition qr(I + A)
    # is approximately I + triu(A, 0) + triu(A, 1)
    return torch.triu(A, diagonal=0) + torch.triu(A, diagonal=1)
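# A check of the approximation stated for triu01 (hypothetical helper, not part of the original
# module). The first-order argument assumes A is symmetric: writing I + A = Q*R with Q close to I
# gives R close to I + diag(A) + triu(A + A^T, 1), which reduces to I + triu(A, 0) + triu(A, 1)
# only when A = A^T. The signs of R's rows are fixed so that diag(R) > 0 before comparing.
def _sketch_triu01_vs_qr(n=5, scale=1e-4, seed=0):
    import torch
    torch.manual_seed(seed)
    B = torch.randn(n, n, dtype=torch.float64)
    A = scale*(B + B.T)/2                                     # small symmetric perturbation
    I = torch.eye(n, dtype=torch.float64)
    R = torch.linalg.qr(I + A).R
    R = torch.sign(torch.diagonal(R)).unsqueeze(1)*R          # make diag(R) positive
    approx = I + torch.triu(A, diagonal=0) + torch.triu(A, diagonal=1)
    return torch.linalg.matrix_norm(R - approx).item(), torch.linalg.matrix_norm(A).item()
# The returned error is second order in A, so it is far smaller than the size of A itself.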
###############################################################################
@torch.jit.script
def update_precond_dense(Q, dxs, dgs, step=0.01, _tiny=1.2e-38):
    # type: (Tensor, List[Tensor], List[Tensor], float, float) -> Tensor
    """
    update dense preconditioner P = Q^T*Q
    Q: Cholesky factor of preconditioner with positive diagonal entries
    dxs: list of perturbations of parameters
    dgs: list of perturbations of gradients
    step: update step size normalized to range [0, 1]
    _tiny: an offset to avoid division by zero
    """
    dx = torch.cat([torch.reshape(x, [-1, 1]) for x in dxs])
    dg = torch.cat([torch.reshape(g, [-1, 1]) for g in dgs])
    a = Q.mm(dg)
    #b = torch.triangular_solve(dx, Q, upper=True, transpose=True)[0]
    b = torch.linalg.solve_triangular(Q.t(), dx, upper=False)
    grad = torch.triu(a.mm(a.t()) - b.mm(b.t()))
    # step0 = step/(grad.abs().max() + _tiny)
    step0 = step/(norm_lower_bound(grad) + _tiny)
    return Q - step0*grad.mm(Q)
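# A minimal sketch of the dense update in action (hypothetical helper, not part of the original
# module). It assumes a quadratic model with a fixed Hessian H, so dg = H*dx, and repeats the update
# rule of update_precond_dense with two deliberate changes: each step uses a batch of perturbations
# (one per column) and normalizes by the exact spectral norm instead of norm_lower_bound.
# The fitted P = Q^T*Q should approach inv(H); it returns ||P*H - I|| before and after fitting.
def _sketch_fit_dense_precond(num_iters=1000, step=0.05, batch=64, seed=0):
    import torch
    torch.manual_seed(seed)
    n = 4
    rot, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
    H = rot @ torch.diag(torch.tensor([1.0, 2.0, 5.0, 10.0], dtype=torch.float64)) @ rot.T
    I = torch.eye(n, dtype=torch.float64)
    Q = I.clone()
    err0 = torch.linalg.matrix_norm(Q.T @ Q @ H - I, ord=2).item()
    for _ in range(num_iters):
        dx = torch.randn(n, batch, dtype=torch.float64)          # parameter perturbations
        dg = H @ dx                                              # matching gradient perturbations
        a = Q @ dg
        b = torch.linalg.solve_triangular(Q.T, dx, upper=False)  # inv(Q^T)*dx
        grad = torch.triu(a @ a.T - b @ b.T)
        Q = Q - step/(torch.linalg.matrix_norm(grad, ord=2) + 1e-30)*grad @ Q
    err = torch.linalg.matrix_norm(Q.T @ Q @ H - I, ord=2).item()
    return err0, err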
@torch.jit.script
def precond_grad_dense(Q, grads):
    # type: (Tensor, List[Tensor]) -> List[Tensor]
    """
    return preconditioned gradient using dense preconditioner
    Q: Cholesky factor of preconditioner
    grads: list of gradients
    """
    grad = [torch.reshape(g, [-1, 1]) for g in grads]
    lens = [g.shape[0] for g in grad]
    grad = torch.cat(grad)
    grad = Q.t().mm(Q.mm(grad))
    pre_grads = []
    idx = 0
    for i in range(len(grads)):
        pre_grads.append(torch.reshape(grad[idx : idx + lens[i]], grads[i].shape))
        idx = idx + lens[i]
    return pre_grads
###############################################################################
def update_precond_kron(Ql, Qr, dX, dG, step=0.01, _tiny=1.2e-38):
    """
    Update Kronecker product preconditioner P = kron_prod(Qr^T*Qr, Ql^T*Ql)
    Either Ql or Qr can be sparse, and the code can choose the right update rule.
    dX: perturbation of (matrix) parameter
    dG: perturbation of (matrix) gradient
    step: update step size
    _tiny: an offset to avoid division by zero
    """
    m, n = Ql.shape
    p, q = Qr.shape
    if m==n: # left is dense
        if p==q: #(dense, dense) format
            return _update_precond_dense_dense(Ql, Qr, dX, dG, step, _tiny)
        elif p==2: # (dense, normalization) format
            return _update_precond_norm_dense(Qr, Ql, dX.t(), dG.t(), step, _tiny)[::-1]
        elif p==1: # (dense, scaling) format
            return _update_precond_dense_scale(Ql, Qr, dX, dG, step, _tiny)
        else:
            raise Exception('Unknown Kronecker product preconditioner')
    elif m==2: # left is normalization
        if p==q: # (normalization, dense) format
            return _update_precond_norm_dense(Ql, Qr, dX, dG, step, _tiny)
        elif p==1: # (normalization, scaling) format
            return _update_precond_norm_scale(Ql, Qr, dX, dG, step, _tiny)
        else:
            raise Exception('Unknown Kronecker product preconditioner')
    elif m==1: # left is scaling
        if p==q: # (scaling, dense) format
            return _update_precond_dense_scale(Qr, Ql, dX.t(), dG.t(), step, _tiny)[::-1]
        elif p==2: # (scaling, normalization) format
            return _update_precond_norm_scale(Qr, Ql, dX.t(), dG.t(), step, _tiny)[::-1]
        else:
            raise Exception('Unknown Kronecker product preconditioner')
    else:
        raise Exception('Unknown Kronecker product preconditioner')
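# update_precond_kron and precond_grad_kron infer the (left, right) formats from the factor shapes
# alone. The pure-Python helper below is a hypothetical sketch (not part of the original module)
# that mirrors the same checks in the same order: a square factor is 'dense', otherwise a factor
# with 2 rows is 'normalization' and a factor with 1 row is 'scaling'. It raises ValueError where
# the original raises a plain Exception.
def _sketch_kron_format(ql_shape, qr_shape):
    def kind(shape):
        rows, cols = shape
        if rows == cols:
            return 'dense'
        if rows == 2:
            return 'normalization'
        if rows == 1:
            return 'scaling'
        raise ValueError('Unknown Kronecker product preconditioner')
    left, right = kind(ql_shape), kind(qr_shape)
    supported = {('dense', 'dense'), ('dense', 'normalization'), ('dense', 'scaling'),
                 ('normalization', 'dense'), ('normalization', 'scaling'),
                 ('scaling', 'dense'), ('scaling', 'normalization')}
    if (left, right) not in supported:
        raise ValueError('Unknown Kronecker product preconditioner')
    return left, right
# Because squareness is checked first, a (2, 2) or (1, 1) factor is treated as dense, e.g.,
# _sketch_kron_format((2, 2), (1, 5)) gives ('dense', 'scaling').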
def precond_grad_kron(Ql, Qr, Grad):
    """
    return preconditioned gradient using Kronecker product preconditioner P = kron_prod(Qr^T*Qr, Ql^T*Ql)
    Either Ql or Qr can be sparse, and the code can choose the right way to precondition the gradient
    Grad: (matrix) gradient
    """
    m, n = Ql.shape
    p, q = Qr.shape
    if m==n: # left is dense
        if p==q: #(dense, dense) format
            return _precond_grad_dense_dense(Ql, Qr, Grad)
        elif p==2: # (dense, normalization) format
            return _precond_grad_norm_dense(Qr, Ql, Grad.t()).t()
        elif p==1: # (dense, scaling) format
            return _precond_grad_dense_scale(Ql, Qr, Grad)
        else:
            raise Exception('Unknown Kronecker product preconditioner')
    elif m==2: # left is normalization
        if p==q: # (normalization, dense) format
            return _precond_grad_norm_dense(Ql, Qr, Grad)
        elif p==1: # (normalization, scaling) format
            return _precond_grad_norm_scale(Ql, Qr, Grad)
        else:
            raise Exception('Unknown Kronecker product preconditioner')
    elif m==1: # left is scaling
        if p==q: # (scaling, dense) format
            return _precond_grad_dense_scale(Qr, Ql, Grad.t()).t()
        elif p==2: # (scaling, normalization) format
            return _precond_grad_norm_scale(Qr, Ql, Grad.t()).t()
        else:
            raise Exception('Unknown Kronecker product preconditioner')
    else:
        raise Exception('Unknown Kronecker product preconditioner')
###############################################################################
@torch.jit.script
def _update_precond_dense_dense(Ql, Qr, dX, dG, step=0.01, _tiny=1.2e-38):
    # type: (Tensor, Tensor, Tensor, Tensor, float, float) -> Tuple[Tensor, Tensor]
    """
    update Kronecker product preconditioner P = kron_prod(Qr^T*Qr, Ql^T*Ql)
    Ql: (left side) Cholesky factor of preconditioner with positive diagonal entries
    Qr: (right side) Cholesky factor of preconditioner with positive diagonal entries
    dX: perturbation of (matrix) parameter
    dG: perturbation of (matrix) gradient
    step: update step size normalized to range [0, 1]
    _tiny: an offset to avoid division by zero
    """
    max_l = torch.max(torch.diag(Ql))
    max_r = torch.max(torch.diag(Qr))
    rho = torch.sqrt(max_l/max_r)
    Ql /= rho
    Qr *= rho
    #A = Ql.mm( dG.mm( Qr.t() ) )
    #Bt = torch.triangular_solve((torch.triangular_solve(dX.t(), Qr, upper=True, transpose=True))[0].t(),
    #     Ql, upper=True, transpose=True)[0]
    A = torch.linalg.multi_dot([Ql, dG, Qr.t()])
    Bt = torch.linalg.solve_triangular(Ql.t(), torch.linalg.solve_triangular(Qr, dX, upper=True, left=False), upper=False)
    grad1 = torch.triu(A.mm(A.t()) - Bt.mm(Bt.t()))
    grad2 = torch.triu(A.t().mm(A) - Bt.t().mm(Bt))
    # step1 = step/(torch.max(torch.abs(grad1)) + _tiny)
    # step2 = step/(torch.max(torch.abs(grad2)) + _tiny)
    step1 = step/(norm_lower_bound(grad1) + _tiny)
    step2 = step/(norm_lower_bound(grad2) + _tiny)
    return Ql - step1*grad1.mm(Ql), Qr - step2*grad2.mm(Qr)
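# The rebalancing step at the top of _update_precond_dense_dense (dividing Ql by rho and
# multiplying Qr by rho) leaves P = kron_prod(Qr^T*Qr, Ql^T*Ql) unchanged and equalizes the largest
# diagonal entries of Ql and Qr. The hypothetical helper below (not part of the original module)
# returns the largest change in P and the gap between the two maxima after rebalancing.
def _sketch_rebalance(seed=0):
    import torch
    torch.manual_seed(seed)
    f64 = torch.float64
    Ql = torch.triu(torch.rand(3, 3, dtype=f64)) + 5*torch.eye(3, dtype=f64)
    Qr = torch.triu(torch.rand(2, 2, dtype=f64)) + 0.1*torch.eye(2, dtype=f64)
    P_before = torch.kron(Qr.T @ Qr, Ql.T @ Ql)
    rho = torch.sqrt(torch.max(torch.diag(Ql))/torch.max(torch.diag(Qr)))
    Ql, Qr = Ql/rho, Qr*rho
    P_after = torch.kron(Qr.T @ Qr, Ql.T @ Ql)
    gap = torch.max(torch.diag(Ql)) - torch.max(torch.diag(Qr))
    return torch.max(torch.abs(P_after - P_before)).item(), abs(gap.item())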
@torch.jit.script
def _precond_grad_dense_dense(Ql, Qr, Grad):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    """
    return preconditioned gradient using Kronecker product preconditioner
    Ql: (left side) Cholesky factor of preconditioner
    Qr: (right side) Cholesky factor of preconditioner
    Grad: (matrix) gradient
    """
    #return torch.chain_matmul(Ql.t(), Ql, Grad, Qr.t(), Qr)
    return torch.linalg.multi_dot([Ql.t(), Ql, Grad, Qr.t(), Qr])
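# The Kronecker-product convention P = kron_prod(Qr^T*Qr, Ql^T*Ql) can be checked directly
# (hypothetical helper, not part of the original module): Ql^T*Ql*G*Qr^T*Qr equals
# kron(Qr^T*Qr, Ql^T*Ql) applied to the column-major vectorization of G, because
# vec(A*X*B) = kron(B^T, A)*vec(X) and Qr^T*Qr is symmetric. The returned value is the
# largest absolute difference between the two computations.
def _sketch_kron_identity(M=3, N=4, seed=0):
    import torch
    torch.manual_seed(seed)
    Ql = torch.triu(torch.randn(M, M, dtype=torch.float64))
    Qr = torch.triu(torch.randn(N, N, dtype=torch.float64))
    G = torch.randn(M, N, dtype=torch.float64)
    fast = torch.linalg.multi_dot([Ql.T, Ql, G, Qr.T, Qr])   # as in _precond_grad_dense_dense
    vecG = G.T.reshape(-1)                                    # column-major vec(G)
    slow = (torch.kron(Qr.T @ Qr, Ql.T @ Ql) @ vecG).reshape(N, M).T
    return torch.max(torch.abs(fast - slow)).item()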
###############################################################################
# (normalization, dense) format Kronecker product preconditioner
@torch.jit.script
def _update_precond_norm_dense(ql, Qr, dX, dG, step=0.01, _tiny=1.2e-38):
    # type: (Tensor, Tensor, Tensor, Tensor, float, float) -> Tuple[Tensor, Tensor]
    """
    update (normalization, dense) Kronecker product preconditioner P = kron_prod(Qr^T*Qr, Ql^T*Ql), where
    dX and dG have shape (M, N)
    ql has shape (2, M)
    Qr has shape (N, N)
    ql[0] is the diagonal part of Ql
    ql[1,0:-1] is the last column of Ql, excluding the last entry
    dX is perturbation of (matrix) parameter
    dG is perturbation of (matrix) gradient
    step: update step size normalized to range [0, 1]
    _tiny: an offset to avoid division by zero
    """
    # make sure that Ql and Qr have similar dynamic range
    max_l = torch.max(ql[0])
    max_r = torch.max(torch.diag(Qr))
    rho = torch.sqrt(max_l/max_r)
    ql /= rho
    Qr *= rho
    # refer to https://arxiv.org/abs/1512.04202 for details
    A = ql[0:1].t()*dG + ql[1:].t().mm( dG[-1:] ) # Ql*dG
    A = A.mm(Qr.t())
    Bt = dX/ql[0:1].t()
    Bt[-1:] -= (ql[1:]/(ql[0:1]*ql[0,-1])).mm(dX)
    #Bt = torch.triangular_solve(Bt.t(), Qr, upper=True, transpose=True)[0].t()
    Bt = torch.linalg.solve_triangular(Qr, Bt, upper=True, left=False)
    grad1_diag = torch.sum(A*A, dim=1) - torch.sum(Bt*Bt, dim=1)
    grad1_bias = A[:-1].mm(A[-1:].t()) - Bt[:-1].mm(Bt[-1:].t())
    grad1_bias = torch.cat([torch.squeeze(grad1_bias), grad1_bias.new_zeros(1)])
    step1 = step/(torch.max(torch.max(torch.abs(grad1_diag)),
                            torch.max(torch.abs(grad1_bias))) + _tiny)
    new_ql0 = ql[0] - step1*grad1_diag*ql[0]
    new_ql1 = ql[1] - step1*(grad1_diag*ql[1] + ql[0,-1]*grad1_bias)
    grad2 = torch.triu(A.t().mm(A) - Bt.t().mm(Bt))
    # step2 = step/(torch.max(torch.abs(grad2)) + _tiny)
    step2 = step/(norm_lower_bound(grad2) + _tiny)
    return torch.stack((new_ql0, new_ql1)), Qr - step2*grad2.mm(Qr)
@torch.jit.script
def _precond_grad_norm_dense(ql, Qr, Grad):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    """
    return preconditioned gradient using (normalization, dense) Kronecker product preconditioner
    Suppose Grad has shape (M, N)
    ql[0] is the diagonal part of Ql
    ql[1, 0:-1] is the last column of Ql, excluding the last entry
    Qr: shape (N, N), Cholesky factor of right preconditioner
    Grad: (matrix) gradient
    """
    preG = ql[0:1].t()*Grad + ql[1:].t().mm(Grad[-1:]) # Ql*Grad
    #preG = torch.chain_matmul(preG, Qr.t(), Qr)
    preG = torch.linalg.multi_dot([preG, Qr.t(), Qr])
    add_last_row = ql[1:].mm(preG) # use it to modify the last row
    preG *= ql[0:1].t()
    preG[-1:] += add_last_row
    return preG
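# The (normalization, dense) routines above store Ql in a (2, M) tensor ql: ql[0] is the diagonal
# and ql[1, :-1] the last column above the diagonal; the update rules keep ql[1, -1] at zero.
# The hypothetical helper below (not part of the original module) builds the dense Ql from ql and
# checks that the sparse expressions give Ql*dG, inv(Ql^T)*dX, and Ql^T*Ql*G (taking Qr = I).
def _sketch_norm_format(M=4, N=3, seed=0):
    import torch
    torch.manual_seed(seed)
    ql = torch.stack([1 + torch.rand(M, dtype=torch.float64),
                      torch.randn(M, dtype=torch.float64)])
    ql[1, -1] = 0.0
    Ql = torch.diag(ql[0])
    Ql[:-1, -1] = ql[1, :-1]                                  # dense form of the sparse factor
    dG = torch.randn(M, N, dtype=torch.float64)
    dX = torch.randn(M, N, dtype=torch.float64)
    A = ql[0:1].t()*dG + ql[1:].t().mm(dG[-1:])               # Ql*dG
    Bt = dX/ql[0:1].t()
    Bt[-1:] -= (ql[1:]/(ql[0:1]*ql[0, -1])).mm(dX)            # inv(Ql^T)*dX
    preG = ql[0:1].t()*dG + ql[1:].t().mm(dG[-1:])            # Ql*G, with G = dG
    add_last_row = ql[1:].mm(preG)
    preG = preG*ql[0:1].t()
    preG[-1:] += add_last_row                                 # Ql^T*Ql*G
    return (torch.allclose(A, Ql @ dG),
            torch.allclose(Bt, torch.linalg.solve(Ql.T, dX)),
            torch.allclose(preG, Ql.T @ Ql @ dG))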
###############################################################################
# (normalization, scaling) Kronecker product preconditioner
# the left one is a normalization preconditioner; the right one is a scaling preconditioner
@torch.jit.script
def _update_precond_norm_scale(ql, qr, dX, dG, step=0.01, _tiny=1.2e-38):
    # type: (Tensor, Tensor, Tensor, Tensor, float, float) -> Tuple[Tensor, Tensor]
    """
    update (normalization, scaling) preconditioner P = kron_prod(Qr^T*Qr, Ql^T*Ql), where
    dX and dG have shape (M, N)
    ql has shape (2, M)
    qr has shape (1, N)
    ql[0] is the diagonal part of Ql
    ql[1, 0:-1] is the last column of Ql, excluding the last entry
    qr is the diagonal part of Qr
    dX is perturbation of (matrix) parameter
    dG is perturbation of (matrix) gradient
    step: update step size
    _tiny: an offset to avoid division by zero
    """
    # make sure that Ql and Qr have similar dynamic range
    max_l = torch.max(ql[0])
    max_r = torch.max(qr) # qr is always positive
    rho = torch.sqrt(max_l/max_r)
    ql /= rho
    qr *= rho
    # refer to https://arxiv.org/abs/1512.04202 for details
    A = ql[0:1].t()*dG + ql[1:].t().mm( dG[-1:] ) # Ql*dG
    A *= qr # Ql*dG*Qr
    Bt = dX/ql[0:1].t()
    Bt[-1:] -= (ql[1:]/(ql[0:1]*ql[0,-1])).mm(dX)
    Bt /= qr # Ql^(-T)*dX*Qr^(-1)
    grad1_diag = torch.sum(A*A, dim=1) - torch.sum(Bt*Bt, dim=1)
    grad1_bias = A[:-1].mm(A[-1:].t()) - Bt[:-1].mm(Bt[-1:].t())
    grad1_bias = torch.cat([torch.squeeze(grad1_bias), grad1_bias.new_zeros(1)])
    step1 = step/(torch.max(torch.max(torch.abs(grad1_diag)),
                            torch.max(torch.abs(grad1_bias))) + _tiny)
    new_ql0 = ql[0] - step1*grad1_diag*ql[0]
    new_ql1 = ql[1] - step1*(grad1_diag*ql[1] + ql[0,-1]*grad1_bias)
    grad2 = torch.sum(A*A, dim=0, keepdim=True) - torch.sum(Bt*Bt, dim=0, keepdim=True)
    step2 = step/(torch.max(torch.abs(grad2)) + _tiny)
    return torch.stack((new_ql0, new_ql1)), qr - step2*grad2*qr
@torch.jit.script
def _precond_grad_norm_scale(ql, qr, Grad):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    """
    return preconditioned gradient using (normalization, scaling) Kronecker product preconditioner
    Suppose Grad has shape (M, N)
    ql has shape (2, M)
    qr has shape (1, N)
    ql[0] is the diagonal part of Ql
    ql[1, 0:-1] is the last column of Ql, excluding the last entry
    qr is the diagonal part of Qr
    Grad: (matrix) gradient
    """
    preG = ql[0:1].t()*Grad + ql[1:].t().mm(Grad[-1:]) # Ql*Grad
    preG *= (qr*qr) # Ql*Grad*Qr^T*Qr
    add_last_row = ql[1:].mm(preG) # use it to modify the last row
    preG *= ql[0:1].t()
    preG[-1:] += add_last_row
    return preG
###############################################################################
@torch.jit.script
def _update_precond_dense_scale(Ql, qr, dX, dG, step=0.01, _tiny=1.2e-38):
    # type: (Tensor, Tensor, Tensor, Tensor, float, float) -> Tuple[Tensor, Tensor]
    """
    update (dense, scaling) preconditioner P = kron_prod(Qr^T*Qr, Ql^T*Ql), where
    dX and dG have shape (M, N)
    Ql has shape (M, M)
    qr has shape (1, N)
    qr is the diagonal part of Qr
    dX is perturbation of (matrix) parameter
    dG is perturbation of (matrix) gradient
    step: update step size
    _tiny: an offset to avoid division by zero
    """
    max_l = torch.max(torch.diag(Ql))
    max_r = torch.max(qr)
    rho = torch.sqrt(max_l/max_r)
    Ql /= rho
    qr *= rho
    A = Ql.mm( dG*qr )
    #Bt = torch.triangular_solve(dX/qr, Ql, upper=True, transpose=True)[0]
    Bt = torch.linalg.solve_triangular(Ql.t(), dX/qr, upper=False)
    grad1 = torch.triu(A.mm(A.t()) - Bt.mm(Bt.t()))
    grad2 = torch.sum(A*A, dim=0, keepdim=True) - torch.sum(Bt*Bt, dim=0, keepdim=True)
    # step1 = step/(torch.max(torch.abs(grad1)) + _tiny)
    step1 = step/(norm_lower_bound(grad1) + _tiny)
    step2 = step/(torch.max(torch.abs(grad2)) + _tiny)
    return Ql - step1*grad1.mm(Ql), qr - step2*grad2*qr
@torch.jit.script
def _precond_grad_dense_scale(Ql, qr, Grad):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    """
    return preconditioned gradient using (dense, scaling) Kronecker product preconditioner
    Suppose Grad has shape (M, N)
    Ql: shape (M, M), (left side) Cholesky factor of preconditioner
    qr: shape (1, N), defines a diagonal matrix for output feature scaling
    Grad: (matrix) gradient
    """
    #return torch.chain_matmul(Ql.t(), Ql, Grad*(qr*qr))
    return torch.linalg.multi_dot([Ql.t(), Ql, Grad*(qr*qr)])
###############################################################################
@torch.jit.script
def update_precond_splu(L12, l3, U12, u3, dxs, dgs, step=0.01, _tiny=1.2e-38):
    # type: (Tensor,Tensor,Tensor,Tensor, List[Tensor],List[Tensor], float,float) -> Tuple[Tensor,Tensor,Tensor,Tensor]
    """
    update sparse LU preconditioner P = Q^T*Q, where
    Q = L*U,
    L12 = [L1; L2]
    U12 = [U1, U2]
    L = [L1, 0; L2, diag(l3)]
    U = [U1, U2; 0, diag(u3)]
    l3 and u3 are column vectors
    dxs: a list of random perturbations of the parameters
    dgs: a list of the resulting perturbations of the gradients
    step: update step size normalized to range [0, 1]
    _tiny: an offset to avoid division by zero
    """
    # make sure that L and U have similar dynamic range
    max_l = torch.max(torch.max(torch.diag(L12)), torch.max(l3))
    max_u = torch.max(torch.max(torch.diag(U12)), torch.max(u3))
    rho = torch.sqrt(max_l/max_u)
    L12 /= rho
    l3 /= rho
    U12 *= rho
    u3 *= rho
    # extract the blocks
    r = U12.shape[0]
    L1 = L12[:r]
    L2 = L12[r:]
    U1 = U12[:, :r]
    U2 = U12[:, r:]
    dx = torch.cat([torch.reshape(x, [-1, 1]) for x in dxs]) # a tall column vector
    dg = torch.cat([torch.reshape(g, [-1, 1]) for g in dgs]) # a tall column vector
    # U*dg
    Ug1 = U1.mm(dg[:r]) + U2.mm(dg[r:])
    Ug2 = u3*dg[r:]
    # Q*dg
    Qg1 = L1.mm(Ug1)
    Qg2 = L2.mm(Ug1) + l3*Ug2
    # inv(U^T)*dx
    #iUtx1 = torch.triangular_solve(dx[:r], U1, upper=True, transpose=True)[0]
    iUtx1 = torch.linalg.solve_triangular(U1.t(), dx[:r], upper=False)
    iUtx2 = (dx[r:] - U2.t().mm(iUtx1))/u3
    # inv(Q^T)*dx
    iQtx2 = iUtx2/l3
    #iQtx1 = torch.triangular_solve(iUtx1 - L2.t().mm(iQtx2), L1, upper=False, transpose=True)[0]
    iQtx1 = torch.linalg.solve_triangular(L1.t(), iUtx1 - L2.t().mm(iQtx2), upper=True)
    # L^T*Q*dg
    LtQg1 = L1.t().mm(Qg1) + L2.t().mm(Qg2)
    LtQg2 = l3*Qg2
    # P*dg
    Pg1 = U1.t().mm(LtQg1)
    Pg2 = U2.t().mm(LtQg1) + u3*LtQg2
    # inv(L)*inv(Q^T)*dx
    #iLiQtx1 = torch.triangular_solve(iQtx1, L1, upper=False)[0]
    iLiQtx1 = torch.linalg.solve_triangular(L1, iQtx1, upper=False)
    iLiQtx2 = (iQtx2 - L2.mm(iLiQtx1))/l3
    # inv(P)*dx
    iPx2 = iLiQtx2/u3
    #iPx1 = torch.triangular_solve(iLiQtx1 - U2.mm(iPx2), U1, upper=True)[0]
    iPx1 = torch.linalg.solve_triangular(U1, iLiQtx1 - U2.mm(iPx2), upper=True)
    # update L
    grad1 = Qg1.mm(Qg1.t()) - iQtx1.mm(iQtx1.t())
    grad1 = torch.tril(grad1)
    grad2 = Qg2.mm(Qg1.t()) - iQtx2.mm(iQtx1.t())
    grad3 = Qg2*Qg2 - iQtx2*iQtx2
    # max_abs_grad = torch.max(torch.abs(grad1))
    # max_abs_grad = torch.max(max_abs_grad, torch.max(torch.abs(grad2)))
    # max_abs_grad = torch.max(max_abs_grad, torch.max(torch.abs(grad3)))
    # step0 = step/(max_abs_grad + _tiny)
    step0 = step/(torch.maximum(norm_lower_bound(torch.cat([grad1, grad2], 0)), torch.max(torch.abs(grad3))) + _tiny)
    newL1 = L1 - step0*grad1.mm(L1)
    newL2 = L2 - step0*grad2.mm(L1) - step0*grad3*L2
    newl3 = l3 - step0*grad3*l3
    # update U
    grad1 = Pg1.mm(dg[:r].t()) - dx[:r].mm(iPx1.t())
    grad1 = torch.triu(grad1)
    grad2 = Pg1.mm(dg[r:].t()) - dx[:r].mm(iPx2.t())
    grad3 = Pg2*dg[r:] - dx[r:]*iPx2
    # max_abs_grad = torch.max(torch.abs(grad1))
    # max_abs_grad = torch.max(max_abs_grad, torch.max(torch.abs(grad2)))
    # max_abs_grad = torch.max(max_abs_grad, torch.max(torch.abs(grad3)))
    # step0 = step/(max_abs_grad + _tiny)
    step0 = step/(torch.maximum(norm_lower_bound(torch.cat([grad1, grad2], 1)), torch.max(torch.abs(grad3))) + _tiny)
    newU1 = U1 - U1.mm(step0*grad1)
    newU2 = U2 - U1.mm(step0*grad2) - step0*grad3.t()*U2
    newu3 = u3 - step0*grad3*u3
    return torch.cat([newL1, newL2], dim=0), newl3, torch.cat([newU1, newU2], dim=1), newu3
@torch.jit.script
def precond_grad_splu(L12, l3, U12, u3, grads):
    # type: (Tensor,Tensor,Tensor,Tensor, List[Tensor]) -> List[Tensor]
    """
    return preconditioned gradient with sparse LU preconditioner
    where P = Q^T*Q,
    Q = L*U,
    L12 = [L1; L2]
    U12 = [U1, U2]
    L = [L1, 0; L2, diag(l3)]
    U = [U1, U2; 0, diag(u3)]
    l3 and u3 are column vectors
    grads: a list of gradients to be preconditioned
    """
    grad = [torch.reshape(g, [-1, 1]) for g in grads] # a list of column vectors
    lens = [g.shape[0] for g in grad] # length of each column vector
    grad = torch.cat(grad) # a tall column vector
    r = U12.shape[0]
    L1 = L12[:r]
    L2 = L12[r:]
    U1 = U12[:, :r]
    U2 = U12[:, r:]
    # U*g
    Ug1 = U1.mm(grad[:r]) + U2.mm(grad[r:])
    Ug2 = u3*grad[r:]
    # Q*g
    Qg1 = L1.mm(Ug1)
    Qg2 = L2.mm(Ug1) + l3*Ug2
    # L^T*Q*g
    LtQg1 = L1.t().mm(Qg1) + L2.t().mm(Qg2)
    LtQg2 = l3*Qg2
    # P*g
    pre_grad = torch.cat([U1.t().mm(LtQg1),
                          U2.t().mm(LtQg1) + u3*LtQg2])
    pre_grads = [] # restore pre_grad to its original shapes
    idx = 0
    for i in range(len(grads)):
        pre_grads.append(torch.reshape(pre_grad[idx : idx + lens[i]], grads[i].shape))
        idx = idx + lens[i]
    return pre_grads
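# The block layout used by update_precond_splu and precond_grad_splu can be checked against dense
# matrices (hypothetical helper, not part of the original module): it assembles
#   L = [L1, 0; L2, diag(l3)] and U = [U1, U2; 0, diag(u3)]
# from (L12, l3, U12, u3), repeats the block products of precond_grad_splu, and returns the
# largest absolute difference from Q^T*Q*g with Q = L*U.
def _sketch_splu_blocks(n=5, r=2, seed=0):
    import torch
    torch.manual_seed(seed)
    f64 = torch.float64
    L12 = torch.cat([torch.tril(torch.rand(r, r, dtype=f64)) + torch.eye(r, dtype=f64),
                     torch.randn(n - r, r, dtype=f64)])
    U12 = torch.cat([torch.triu(torch.rand(r, r, dtype=f64)) + torch.eye(r, dtype=f64),
                     torch.randn(r, n - r, dtype=f64)], dim=1)
    l3 = 1 + torch.rand(n - r, 1, dtype=f64)
    u3 = 1 + torch.rand(n - r, 1, dtype=f64)
    L1, L2, U1, U2 = L12[:r], L12[r:], U12[:, :r], U12[:, r:]
    # dense factors
    L = torch.zeros(n, n, dtype=f64)
    L[:, :r] = L12
    L[r:, r:] = torch.diag(l3.squeeze(1))
    U = torch.zeros(n, n, dtype=f64)
    U[:r] = U12
    U[r:, r:] = torch.diag(u3.squeeze(1))
    g = torch.randn(n, 1, dtype=f64)
    # block products, as in precond_grad_splu
    Ug1 = U1.mm(g[:r]) + U2.mm(g[r:])
    Ug2 = u3*g[r:]
    Qg1 = L1.mm(Ug1)
    Qg2 = L2.mm(Ug1) + l3*Ug2
    LtQg1 = L1.t().mm(Qg1) + L2.t().mm(Qg2)
    LtQg2 = l3*Qg2
    pre_g = torch.cat([U1.t().mm(LtQg1), U2.t().mm(LtQg1) + u3*LtQg2])
    Q = L @ U
    return torch.max(torch.abs(pre_g - Q.T @ Q @ g)).item()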
##############################################################################
#
# The low-rank approximation (LRA) preconditioner is defined as
#
# Q = (I + U*V')*diag(d)
#
# which, after reparameterization, is equivalent to the form
#
# diag(d) + U*V'
#
# UVd is used as an alias of LRA because of the form of this preconditioner.
#
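# The two forms quoted above describe the same matrix once V is rescaled (hypothetical helper, not
# part of the original module): with Vd = diag(d)*V, (I + U*V')*diag(d) = diag(d) + U*Vd'.
# Following this file, d is stored as a column vector. The largest absolute difference is returned.
def _sketch_lra_forms(n=6, rank=2, seed=0):
    import torch
    torch.manual_seed(seed)
    U = torch.randn(n, rank, dtype=torch.float64)
    V = torch.randn(n, rank, dtype=torch.float64)
    d = 1 + torch.rand(n, 1, dtype=torch.float64)
    D = torch.diag(d.squeeze(1))
    Q = (torch.eye(n, dtype=torch.float64) + U @ V.T) @ D     # (I + U*V')*diag(d)
    Vd = d*V                                                  # diag(d)*V, by broadcasting
    return torch.max(torch.abs(Q - (D + U @ Vd.T))).item()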
#@torch.jit.script
def IpUVtmatvec(U, V, x):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    """
    Returns (I + U*V')*x. All variables are either matrices or column vectors.
    """
    return x + U.mm(V.t().mm(x))
# def IpUVtsolve(U, V, x):
#     """
#     Returns inv(I + U*V')*x. All variables are either matrices or column vectors.
#     """
#     VtU = V.t().mm(U)
#     I = torch.eye(VtU.size(dim=0), dtype=VtU.dtype, device=VtU.device)
#     return x - U.mm(torch.linalg.solve(I + VtU, V.t().mm(x))) # torch.solve is slow
# def norm_UVt(U, V):
#     """
#     Returns ||U*V'||_fro = sqrt(tr(U'*U*V'*V)) = sqrt(sum((U'*U)*(V'*V)))
#     """
#     return torch.sqrt(torch.abs(torch.sum( (U.t().mm(U))*(V.t().mm(V)) )))
#@torch.jit.script
def update_precond_UVd_math_(U, V, d, v, h, step, step_normalizer, tiny):
    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, float, str, float) -> None
    """
    Update preconditioner Q = (I + U*V')*diag(d) with (vector, Hessian-vector product) = (v, h).
    State variables U, V and d are updated inplace.
    U, V, d, v, and h are either matrices or column vectors.
    """
    # balance the numerical dynamic ranges of U and V; optional
    if torch.rand([]) < 0.01:
        normU = torch.linalg.vector_norm(U)
        normV = torch.linalg.vector_norm(V)
        rho = torch.sqrt(normU/normV)
        U.div_(rho)
        V.mul_(rho)
    Qh = IpUVtmatvec(U, V, d*h)
    Ph = d*IpUVtmatvec(V, U, Qh)
    # invQtv = IpUVtsolve(V, U, v/d)
    # invPv = IpUVtsolve(U, V, invQtv)/d
    VtU = V.t().mm(U)
    I = torch.eye(VtU.size(dim=0), dtype=VtU.dtype, device=VtU.device)
    IpVtU = I + VtU
    invQtv = v/d
    # torch's linalg.solve is slow for small matrices
    # invQtv = invQtv - V.mm(torch.linalg.solve(IpVtU.t(), U.t().mm(invQtv)))
    # invPv = invQtv - U.mm(torch.linalg.solve(IpVtU, V.t().mm(invQtv)))
    LU, pivots = torch.linalg.lu_factor(IpVtU)
    invQtv = invQtv - V.mm(torch.linalg.lu_solve(LU, pivots, U.t().mm(invQtv), adjoint=True))
    invPv = invQtv - U.mm(torch.linalg.lu_solve(LU, pivots, V.t().mm(invQtv)))
    invPv = invPv/d
    nablaD = Ph*h - v*invPv
    if step_normalizer == '2nd':
        mu = step*torch.min(torch.rsqrt(Ph*Ph + v*v + tiny)*torch.rsqrt(h*h + invPv*invPv + tiny)) # two separate rsqrt's to avoid underflow
    else:
        mu = step/(torch.max(torch.abs(nablaD)) + tiny)
    # d = d - mu*d*nablaD
    d.sub_(mu*d*nablaD)
    # update either U or V, not both at the same time
    a, b = Qh, invQtv
    if torch.rand([]) < 0.5:
        # nablaU = Qh.mm(Qh.t().mm(V)) - invQtv.mm(invQtv.t().mm(V))
        # mu = step/(norm_UVt(nablaU, V) + _tiny)
        # U = U - mu*(nablaU + nablaU.mm(V.t().mm(U)))
        atV = a.t().mm(V)
        btV = b.t().mm(V)
        atVVt = atV.mm(V.t())
        btVVt = btV.mm(V.t())
        if step_normalizer == '2nd':
            mu = step/( torch.linalg.vector_norm(a)*torch.linalg.vector_norm(atVVt)
                       +torch.linalg.vector_norm(b)*torch.linalg.vector_norm(btVVt) + tiny)
        else: # '1st'
            norm = torch.sqrt(torch.abs( (a.t().mm(a))*(atVVt.mm(atVVt.t())) # abs to avoid sqrt(-0.0)
                                        +(b.t().mm(b))*(btVVt.mm(btVVt.t()))
                                        -2*(a.t().mm(b))*(atVVt.mm(btVVt.t())) ))
            mu = step/(norm + tiny)
        # U = U - mu*( a.mm(atV.mm(IpVtU))
        #             -b.mm(btV.mm(IpVtU)) )
        U.sub_(mu*( a.mm(atV.mm(IpVtU))
                   -b.mm(btV.mm(IpVtU)) ))
    else:
        # nablaV = Qh.mm(Qh.t().mm(U)) - invQtv.mm(invQtv.t().mm(U))
        # mu = step/(norm_UVt(U, nablaV) + _tiny)
        # V = V - mu*(nablaV + V.mm(U.t().mm(nablaV)))
        atU = a.t().mm(U)
        btU = b.t().mm(U)
        UUta = U.mm(atU.t())
        UUtb = U.mm(btU.t())
        if step_normalizer == '2nd':
            mu = step/( torch.linalg.vector_norm(a)*torch.linalg.vector_norm(UUta)
                       +torch.linalg.vector_norm(b)*torch.linalg.vector_norm(UUtb) + tiny)
        else: # '1st'
            norm = torch.sqrt(torch.abs( (UUta.t().mm(UUta))*(a.t().mm(a)) # abs to avoid sqrt(-0.0)
                                        +(UUtb.t().mm(UUtb))*(b.t().mm(b))
                                        -2*(UUta.t().mm(UUtb))*(a.t().mm(b)) ))
            mu = step/(norm + tiny)
        # V = V - mu*( (a + V.mm(atU.t())).mm(atU)
        #             -(b + V.mm(btU.t())).mm(btU) )
        V.sub_(mu*( (a + V.mm(atU.t())).mm(atU)
                   -(b + V.mm(btU.t())).mm(btU) ))
    # return [U, V, d]
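# The Woodbury-style solves in update_precond_UVd_math_ only factorize the small rank x rank matrix
# I + V'*U. The hypothetical helper below (not part of the original module) repeats those steps to
# obtain inv(P)*v for P = Q'*Q, Q = (I + U*V')*diag(d), and returns the largest absolute difference
# from a dense solve with torch.linalg.solve.
def _sketch_lra_inverse(n=6, rank=2, seed=0):
    import torch
    torch.manual_seed(seed)
    f64 = torch.float64
    U = 0.3*torch.randn(n, rank, dtype=f64)
    V = 0.3*torch.randn(n, rank, dtype=f64)
    d = 1 + torch.rand(n, 1, dtype=f64)
    v = torch.randn(n, 1, dtype=f64)
    IpVtU = torch.eye(rank, dtype=f64) + V.T @ U
    LU, pivots = torch.linalg.lu_factor(IpVtU)
    invQtv = v/d
    invQtv = invQtv - V @ torch.linalg.lu_solve(LU, pivots, U.T @ invQtv, adjoint=True)  # inv(Q')*v
    invPv = (invQtv - U @ torch.linalg.lu_solve(LU, pivots, V.T @ invQtv))/d             # inv(Q)*inv(Q')*v
    Q = (torch.eye(n, dtype=f64) + U @ V.T) @ torch.diag(d.squeeze(1))
    return torch.max(torch.abs(invPv - torch.linalg.solve(Q.T @ Q, v))).item()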
#@torch.jit.script
def precond_grad_UVd_math(U, V, d, g):
    # type: (Tensor, Tensor, Tensor, Tensor) -> Tensor
    """
    Returns the preconditioned gradient P*g, where P = Q'*Q and Q = (I + U*V')*diag(d).
    All variables here are either matrices or column vectors.
    """
    g = IpUVtmatvec(U, V, d*g)
    g = d*IpUVtmatvec(V, U, g)
    return g
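# precond_grad_UVd_math applies Q and then Q', so its output is P*g with P = Q'*Q. The hypothetical
# helper below (not part of the original module) repeats the two matrix-vector products and
# returns the largest absolute difference from the dense product Q'*Q*g.
def _sketch_lra_precond(n=6, rank=2, seed=0):
    import torch
    torch.manual_seed(seed)
    f64 = torch.float64
    U = torch.randn(n, rank, dtype=f64)
    V = torch.randn(n, rank, dtype=f64)
    d = 1 + torch.rand(n, 1, dtype=f64)
    g = torch.randn(n, 1, dtype=f64)
    Qg = d*g + U @ (V.T @ (d*g))            # (I + U*V')*diag(d)*g
    Pg = d*(Qg + V @ (U.T @ Qg))            # diag(d)*(I + V*U')*(Q*g) = Q'*Q*g
    Q = (torch.eye(n, dtype=f64) + U @ V.T) @ torch.diag(d.squeeze(1))
    return torch.max(torch.abs(Pg - Q.T @ Q @ g)).item()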


class LRA:
    """
    Implements the low-rank approximation (LRA, UVd as an alias) preconditioner, Q = (I + U*V')*diag(d), as a class.
    Args for initialization:
        params_with_grad: a list of parameters or variables requiring gradients;
        rank_of_approximation: rank of approximation, i.e., max rank of U or V, with 0 for diagonal preconditioner;
        preconditioner_init_scale: initial scale of Q, or roughly, Q = preconditioner_init_scale*eye(), with None for automatic setting;
        lr_params: normalized learning rate for parameters in range [0, 1];
        lr_preconditioner: normalized learning rate for preconditioner in range [0, 1],
                           with None for 0.1 if step_normalizer='2nd' and 0.01 otherwise;
        momentum: momentum factor in range [0,1); values outside (0, 1) are treated as 0;
        grad_clip_max_norm: maximum allowable gradient norm after clipping, None for no clipping;
        preconditioner_update_probability: probability of updating Q, 1 for updating at every step, and 0 for never, i.e., SGD when Q=I;
        step_normalizer: '1st' for normalizing lr_preconditioner with 1st order derivative info, and '2nd' for normalizing with 2nd derivative info;
        exact_hessian_vector_product: True for exact Hessian-vector product via 2nd derivative,
                                      and False for approximate one via the finite difference method;
        preconditioner_type: "Newton" or "whitening", see https://arxiv.org/abs/1809.10232 for the Newton and (empirical) Fisher types.
    Notes:
        Note 1: The Hessian-vector product can be approximated using the finite difference method by setting
                exact_hessian_vector_product = False when the 2nd derivatives are not available.
                In this case, make sure that the closure produces the same outputs given the same inputs,
                except for numerical errors due to non-deterministic behaviors.
                Random numbers, if any, used inside the closure should be generated starting from the same state, where the rng state can be
                read and set by, e.g., `torch.cuda.get_rng_state' and `torch.cuda.set_rng_state', respectively.
        Note 2: Momentum here is the moving average of gradient so that its setting is decoupled from the learning rate.
                This is necessary as the learning rate in PSGD is normalized.
        Note 3: `torch.linalg.solve' is called twice in function `update_precond_UVd_math_'.
                Some solvers can be orders of magnitude faster than others, especially for small matrices
                (see https://drive.google.com/file/d/1CTNx1q67_py87jn-0OI-vSLcsM1K7VsM/view, Table 2).
                Consider replacing it with a faster one if the default solver is too slow.
        Note 4: Currently, sparse and mixed-precision gradients are not supported.
                Half precision (bfloat16) works well, except that torch.linalg.solve (v2.2) requires casting bfloat16 to float32.
        Note 5: lr_params, lr_preconditioner, momentum, grad_clip_max_norm, preconditioner_update_probability, step_normalizer,
                and exact_hessian_vector_product (bool) can all be reset on the fly.
    """

    def __init__(self, params_with_grad, rank_of_approximation:int=10, preconditioner_init_scale=None,
                 lr_params=0.01, lr_preconditioner=None, momentum=0.0,
                 grad_clip_max_norm=None, preconditioner_update_probability=1.0,
                 step_normalizer='2nd',
                 exact_hessian_vector_product:bool=True, preconditioner_type="Newton"):
        # mutable members
        self.lr_params = lr_params
        if lr_preconditioner is None:
            if step_normalizer == '2nd':
                self.lr_preconditioner = 0.1
            else:
                self.lr_preconditioner = 0.01
        else:
            self.lr_preconditioner = lr_preconditioner
        self.momentum = momentum if (0<momentum<1) else 0.0
        self.grad_clip_max_norm = grad_clip_max_norm
        self.preconditioner_update_probability = preconditioner_update_probability
        self.exact_hessian_vector_product = exact_hessian_vector_product
        self.step_normalizer = step_normalizer
        # protected members
        params_with_grad = [params_with_grad,] if isinstance(params_with_grad, torch.Tensor) else params_with_grad
        self._params_with_grad = [param for param in params_with_grad if param.requires_grad] # double check requires_grad flag
        dtype, device = self._params_with_grad[0].dtype, self._params_with_grad[0].device
        self._tiny = torch.finfo(dtype).tiny
        self._delta_param_scale = torch.finfo(dtype).eps**0.5
        self._param_sizes = [torch.numel(param) for param in self._params_with_grad]
        self._param_cumsizes = torch.cumsum(torch.tensor(self._param_sizes), 0)
        num_params = self._param_cumsizes[-1]
        # check rank_of_approximation
        if rank_of_approximation <= 0:
            print("Hint: the Xmat preconditioner may be more efficient in this case.")
        if 2*rank_of_approximation + 1 >= num_params:
            print("Hint: the Newton preconditioner may be more efficient in this case.")
        # +10 to 1) avoid /0; 2) make sure that norm(U*V') << 1 even when rank_of_approximation=1
        self._U = torch.randn(num_params, rank_of_approximation, dtype=dtype, device=device) / (num_params*(rank_of_approximation+10))**0.5
        self._V = torch.randn(num_params, rank_of_approximation, dtype=dtype, device=device) / (num_params*(rank_of_approximation+10))**0.5
        if preconditioner_init_scale is None:
            self._d = None # set it on the fly
        else:
            self._d = torch.ones(num_params, 1, dtype=dtype, device=device) * preconditioner_init_scale
        self._m = None # momentum buffer
        self._preconditioner_type = preconditioner_type

    @torch.no_grad()
    def step(self, closure):
        """
        Performs a single step of PSGD with the low-rank approximation (LRA, or UVd) preconditioner, i.e.,
        updating the trainable parameters once, and returning what closure returns.
        Args:
            closure (callable): a (stateless) closure that evaluates the function of self._params_with_grad,
                                and returns the loss, or an iterable with the first one being loss.
                                Random numbers, if any, used inside the closure should be generated starting
                                from the same rng state if exact_hessian_vector_product=False and preconditioner_type="Newton".
        """
        if (self._preconditioner_type=="Newton") and ((torch.rand([]) < self.preconditioner_update_probability) or (self._d is None)):
            # evaluates gradients, Hessian-vector product, and updates the preconditioner
            if self.exact_hessian_vector_product:
                # exact Hessian-vector product
                with torch.enable_grad():
                    closure_returns = closure()
                    loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
                    grads = torch.autograd.grad(loss, self._params_with_grad, create_graph=True)
                    vs = [torch.randn_like(param) for param in self._params_with_grad]
                    Hvs = torch.autograd.grad(grads, self._params_with_grad, vs)
            else:
                # approximate Hessian-vector product via finite-difference formulae. Use with caution.
                with torch.enable_grad():
                    closure_returns = closure()
                    loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
                    grads = torch.autograd.grad(loss, self._params_with_grad)
                vs = [self._delta_param_scale * torch.randn_like(param) for param in self._params_with_grad]
                for (param, v) in zip(self._params_with_grad, vs):
                    param.add_(v)    # perturb the parameters; the perturbation is removed in the update below
                with torch.enable_grad():
                    perturbed_returns = closure()
                    perturbed_loss = perturbed_returns if isinstance(perturbed_returns, torch.Tensor) else perturbed_returns[0]
                    perturbed_grads = torch.autograd.grad(perturbed_loss, self._params_with_grad)
                Hvs = [perturbed_g - g for (perturbed_g, g) in zip(perturbed_grads, grads)]
            # update preconditioner
            v = torch.cat([torch.reshape(v, [-1, 1]) for v in vs]) # column vector
            h = torch.cat([torch.reshape(h, [-1, 1]) for h in Hvs]) # column vector
            # set self._d if it is None
            if self._d is None:
                # self._d = (torch.sum(v*v)/torch.sum(h*h))**0.25 * torch.ones_like(v)
                self._d = (torch.mean(v*v))**(1/4) * (torch.mean(h**4))**(-1/8) * torch.ones_like(v)
            # update self._U, _V and _d
            update_precond_UVd_math_(self._U, self._V, self._d, v, h, self.lr_preconditioner, self.step_normalizer, self._tiny)
            # if self.exact_hessian_vector_product:
            #     update_precond_UVd_math_(self._U, self._V, self._d,
            #                              v[:,None], h[:,None], step=self.lr_preconditioner, tiny=self._tiny)
            # else: # compensate the levels of v and h; helpful to reduce numerical errors in half-precision training
            #     update_precond_UVd_math_(self._U, self._V, self._d,
            #                              v[:,None]/self._delta_param_scale, h[:,None]/self._delta_param_scale, step=self.lr_preconditioner, tiny=self._tiny)
        else:
            # only evaluates the gradients
            with torch.enable_grad():
                closure_returns = closure()
                loss = closure_returns if isinstance(closure_returns, torch.Tensor) else closure_returns[0]
                grads = torch.autograd.grad(loss, self._params_with_grad)
            vs = None # no vs and Hvs

        # cat grads
        grad = torch.cat([torch.reshape(g, [-1, 1]) for g in grads]) # column vector

        # update preconditioner here if it is the whitening type
        if (self._preconditioner_type!="Newton") and ((torch.rand([]) < self.preconditioner_update_probability) or (self._d is None)):
            if self._d is None:
                # self._d = (len(grad)/torch.sum(grad*grad))**0.25 * torch.ones_like(grad)
                self._d = (torch.mean(grad**4))**(-1/8) * torch.ones_like(grad)
            # update the preconditioner whitening the gradients
            # v = torch.randn_like(grad)
            # update_precond_UVd_math_(self._U, self._V, self._d, v, grad, self.lr_preconditioner, self.step_normalizer, self._tiny)
            update_precond_UVd_math_(self._U, self._V, self._d, *damped_pair_vg(grad), self.lr_preconditioner, self.step_normalizer, self._tiny)

        # preconditioned gradients; momentum is optional
        if self.momentum > 0:
            if self._m is None:
                self._m = (1 - self.momentum)*grad
            else:
                self._m.mul_(self.momentum).add_(grad, alpha=1 - self.momentum)
            pre_grad = precond_grad_UVd_math(self._U, self._V, self._d, self._m)
        else:
            self._m = None # clean the buffer when momentum is set to zero
            pre_grad = precond_grad_UVd_math(self._U, self._V, self._d, grad)

        # gradient clipping is optional
        if self.grad_clip_max_norm is None:
            lr = self.lr_params
        else:
            grad_norm = torch.linalg.vector_norm(pre_grad) + self._tiny
            lr = self.lr_params * min(self.grad_clip_max_norm/grad_norm, 1.0)

        # update the parameters
        if self.exact_hessian_vector_product or (vs is None) or (self._preconditioner_type!="Newton"):
            delta = lr * pre_grad
        else: # in this case, do not forget to remove the perturbation on parameters
            delta = lr * pre_grad + v
        # -delta
        for (param, i, j) in zip(self._params_with_grad, self._param_sizes, self._param_cumsizes):
            param.subtract_(delta[j - i:j].view_as(param))
        # return whatever closure returns
        return closure_returns
# UVd as an alias
UVd = LRA
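

# A minimal sketch (not part of the original library) of the momentum buffer in LRA.step: it is a moving
# average of gradients, m <- momentum*m + (1 - momentum)*g, initialized as (1 - momentum)*g. For a constant
# gradient g it equals (1 - momentum**k)*g after k steps, so its scale stays at the scale of g for any
# momentum factor, consistent with Note 2 in the LRA docstring. The function name is hypothetical.
import torch

def _sketch_momentum_buffer(grads, momentum):
    m = None
    for g in grads:
        if m is None:
            m = (1 - momentum)*g                 # same initialization as in LRA.step
        else:
            m = momentum*m + (1 - momentum)*g    # out-of-place form of mul_(momentum).add_(g, alpha=1-momentum)
    return m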
################## end of LRA/UVd preconditioner #################################
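

# A minimal sketch (not part of the original library) of the two Hessian-vector product options used by the
# Newton-type preconditioners: the exact product via double backpropagation, and the finite-difference one
# used when exact_hessian_vector_product=False. The finite-difference version returns H*v only for a small
# perturbation v (the library scales v by sqrt(eps), see self._delta_param_scale), so it must be compared
# with H applied to that same small v. The toy loss and function names are hypothetical.
import torch

def _sketch_hvp_exact(loss_fn, x, v):
    with torch.enable_grad():
        x = x.detach().requires_grad_(True)
        g, = torch.autograd.grad(loss_fn(x), x, create_graph=True)
        Hv, = torch.autograd.grad(g, x, v)       # grad of (g'v) w.r.t. x, i.e., H*v
    return Hv

def _sketch_hvp_finite_difference(loss_fn, x, v):
    with torch.enable_grad():
        x0 = x.detach().requires_grad_(True)
        g0, = torch.autograd.grad(loss_fn(x0), x0)
        x1 = (x.detach() + v).requires_grad_(True)   # perturbed point, like param.add_(v) in step()
        g1, = torch.autograd.grad(loss_fn(x1), x1)
    return g1 - g0   # approximates H*v when v is small

# Usage with a toy loss 0.25*sum(x**4), whose Hessian is diag(3*x**2):
#   f = lambda z: 0.25*(z**4).sum()
#   _sketch_hvp_exact(f, x, v), _sketch_hvp_finite_difference(f, x, v)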
##############################################################################
# An Xmat (X-matrix) preconditioner is defined as
#
# Q = diag(a) + adiag(b)
#
# where adiag means anti-diagonal.
# It's slightly more complicated than a diagonal preconditioner (LRA with rank=0), but may perform better.
#
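

# A minimal sketch (not part of the original library) of the X-shaped matrix above: diag(a) fills the main
# diagonal and adiag(b) puts b[i] at row i, column n-1-i, so Q*h = a*h + b*flip(h). For odd n the two
# diagonals cross at the center, which then holds a[n//2] + b[n//2]. The function name is hypothetical,
# and a, b are assumed to be 1-D tensors here.
import torch

def _sketch_dense_xmat(a, b):
    """Build the dense n x n matrix diag(a) + adiag(b) from 1-D tensors a and b."""
    return torch.diag(a) + torch.flip(torch.diag(b), [1])   # flipping columns moves b[i] to (i, n-1-i)

# Usage:
#   a, b = torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])
#   _sketch_dense_xmat(a, b)   # -> [[1, 0, 4], [0, 7, 0], [6, 0, 3]]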

#@torch.jit.script
def update_precond_Xmat_math_(a, b, v, h, step, step_normalizer, tiny):
    # type: (Tensor, Tensor, Tensor, Tensor, float, str, float) -> None
    """
    Update preconditioner Q = diag(a) + adiag(b) with (vector, Hessian-vector product) = (v, h).
    State variables a and b are updated inplace.
    """
    Qh = a*h + b*torch.flip(h, [0])
    aflip, bflip = torch.flip(a, [0]), torch.flip(b, [0])
    invQtv = (aflip*v - bflip*torch.flip(v, [0]))/(a*aflip - b*bflip)
    u, v = Qh*Qh, invQtv*invQtv   # note: v is reused here to hold invQtv*invQtv
    # nablaA = Qh*Qh - invQtv*invQtv
    nablaA = u - v
    nablaB = Qh*torch.flip(Qh, [0]) - invQtv*torch.flip(invQtv, [0])
    q, r = divmod(len(nablaB), 2)
    if r == 1:
        nablaB[q] = 0   # odd length: zero the center entry, where diag(a) and adiag(b) overlap

    if step_normalizer == '2nd':
        mu = step/(torch.max(u + v) + tiny)
    else:
        mu = step/(torch.maximum(torch.max(torch.abs(nablaA)), torch.max(torch.abs(nablaB))) + tiny)

    a.sub_(mu*(nablaA*a + nablaB*bflip))
    b.sub_(mu*(nablaA*b + nablaB*aflip))
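

# A minimal sketch (not part of the original library) checking the closed form used for invQtv in
# update_precond_Xmat_math_. Since Q' = diag(a) + adiag(flip(b)) couples only entries i and n-1-i,
# Q'*x = v splits into independent 2 x 2 systems solved by Cramer's rule; for odd n the center entry
# reduces to v[q]/(a[q] + b[q]). Function names are hypothetical; a, b, v are 1-D tensors here.
import torch

def _sketch_xmat_inv_transpose_matvec(a, b, v):
    aflip, bflip = torch.flip(a, [0]), torch.flip(b, [0])
    return (aflip*v - bflip*torch.flip(v, [0]))/(a*aflip - b*bflip)   # same formula as invQtv above

def _sketch_xmat_solve_check(n=7, seed=0):
    """Return the max abs difference between the closed form and a dense solve of Q'*x = v."""
    gen = torch.Generator().manual_seed(seed)
    a = torch.rand(n, generator=gen, dtype=torch.float64) + 1.0   # keep Q well conditioned
    b = 0.1*torch.randn(n, generator=gen, dtype=torch.float64)
    v = torch.randn(n, generator=gen, dtype=torch.float64)
    Q = torch.diag(a) + torch.flip(torch.diag(b), [1])
    dense = torch.linalg.solve(Q.t(), v)
    return (dense - _sketch_xmat_inv_transpose_matvec(a, b, v)).abs().max().item()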


#@torch.jit.script
def precond_grad_Xmat_math(a, b, g):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    """
    Precondition gradient g with P = Q'*Q, where Q = diag(a) + adiag(b).
    """
    ab = a * b
    return (a*a + torch.flip(b*b, [0]))*g + (ab + torch.flip(ab, [0]))*torch.flip(g, [0])
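

# A minimal sketch (not part of the original library) checking the formula in precond_grad_Xmat_math against
# a dense evaluation of Q'*Q*g with Q = diag(a) + adiag(b). Expanding Q'*(a*g + b*flip(g)) with
# Q'*y = a*y + flip(b)*flip(y) gives
#   Q'*Q*g = (a*a + flip(b*b))*g + (a*b + flip(a*b))*flip(g).
# The function name is hypothetical; a, b, g are 1-D tensors here.
import torch

def _sketch_xmat_precond_check(n=5, seed=0):
    """Return the max abs difference between the closed form and the dense product Q'*Q*g."""
    gen = torch.Generator().manual_seed(seed)
    a, b, g = (torch.randn(n, generator=gen, dtype=torch.float64) for _ in range(3))
    Q = torch.diag(a) + torch.flip(torch.diag(b), [1])
    dense = Q.t().mv(Q.mv(g))
    ab = a*b
    fast = (a*a + torch.flip(b*b, [0]))*g + (ab + torch.flip(ab, [0]))*torch.flip(g, [0])
    return (dense - fast).abs().max().item()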


class XMat:
    """
    Implements the Xmat preconditioner, Q = diag(a) + adiag(b), as a class.
    Args for initialization:
        params_with_grad: a list of parameters or variables requiring gradients;
        preconditioner_init_scale: initial scale of Q, i.e., Q = preconditioner_init_scale*eye(), with None for automatic setting;
        lr_params: normalized learning rate for parameters in range [0, 1];
        lr_preconditioner: normalized learning rate for preconditioner in range [0, 1];