Coordinate descent optimizer (regularized case) should find optimal weights for the given data
Ilya Gyrdymov edited this page Feb 16, 2019 · 1 revision
Features matrix X:
[10.0, 20.0, 30.0]
[20.0, 30.0, 40.0]
[70.0, 80.0, 90.0]
Labels vector y:
[2.0]
[3.0]
[2.0]
Features and labels, row by row:
[10.0, 20.0, 30.0] [2.0]
[20.0, 30.0, 40.0] [3.0]
[70.0, 80.0, 90.0] [2.0]
Initial weights vector w:
[0.0, 0.0, 0.0]
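Unregularized update of the j-th weight, summed over all rows i (x_i,-j is the i-th row without its j-th element, w_-j is the weights vector without its j-th element):
w_j = Σ_i x_ij * (y_i - x_i,-j * w_-j)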
Regularization of each weight w (soft thresholding; λ = 20 in this example):
- if w < -λ / 2 => w = w + λ / 2
- if -λ / 2 <= w <= λ / 2 => w = 0 (from subgradient)
- if w > λ / 2 => w = w - λ / 2
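The three cases above can be sketched as a small Python function. This is an illustrative sketch, not the library's implementation: the name `soft_threshold` and the parameter name `lam` are assumptions, and λ = 20 is taken from the `20 / 2` thresholds used in the walkthrough.

```python
def soft_threshold(w: float, lam: float) -> float:
    """Shrink w toward zero by lam / 2 (hypothetical helper).

    Values inside [-lam / 2, lam / 2] are set to exactly 0.
    """
    half = lam / 2
    if w > half:
        return w - half
    if w < -half:
        return w + half
    return 0.0


print(soft_threshold(220.0, 20.0))  # 210.0
print(soft_threshold(-5.0, 20.0))   # 0.0
```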
First iteration, w = [0.0, 0.0, 0.0]:
j = 0:                                     j = 1:                                     j = 2:
10.0 * (2.0 - (20.0 * 0.0 + 30.0 * 0.0))   20.0 * (2.0 - (10.0 * 0.0 + 30.0 * 0.0))   30.0 * (2.0 - (10.0 * 0.0 + 20.0 * 0.0))
20.0 * (3.0 - (30.0 * 0.0 + 40.0 * 0.0))   30.0 * (3.0 - (20.0 * 0.0 + 40.0 * 0.0))   40.0 * (3.0 - (20.0 * 0.0 + 30.0 * 0.0))
70.0 * (2.0 - (80.0 * 0.0 + 90.0 * 0.0))   80.0 * (2.0 - (70.0 * 0.0 + 90.0 * 0.0))   90.0 * (2.0 - (70.0 * 0.0 + 80.0 * 0.0))
sum up all above (column-wise):
j = 0:   j = 1:   j = 2:
    20       40       60
    60       90      120
   140      160      180
------------------------
   220      290      360
unregularized weights after the first iteration:
[220, 290, 360]
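The column-wise sums can be reproduced with a few lines of Python. This is a minimal sketch under stated assumptions: the function name `raw_update` is hypothetical, and, as in the walkthrough, every coordinate is computed from the same input weights rather than from weights already updated in the current pass.

```python
X = [[10.0, 20.0, 30.0],
     [20.0, 30.0, 40.0],
     [70.0, 80.0, 90.0]]
y = [2.0, 3.0, 2.0]


def raw_update(X, y, w):
    """Return the unregularized update for every coordinate j (hypothetical helper)."""
    n_features = len(w)
    updated = []
    for j in range(n_features):
        total = 0.0
        for x_row, y_i in zip(X, y):
            # Prediction for this row that leaves out feature j.
            partial = sum(x_row[k] * w[k] for k in range(n_features) if k != j)
            total += x_row[j] * (y_i - partial)
        updated.append(total)
    return updated


print(raw_update(X, y, [0.0, 0.0, 0.0]))  # [220.0, 290.0, 360.0]
```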
Regularize them with λ = 20 (λ / 2 = 10):
w1 = 220 > 20 / 2 => 220 - 20 / 2 = 210.0
w2 = 290 > 20 / 2 => 290 - 20 / 2 = 280.0
w3 = 360 > 20 / 2 => 360 - 20 / 2 = 350.0
Regularized weights vector after the 1st iteration:
[210.0, 280.0, 350.0]
Second iteration, w = [210.0, 280.0, 350.0]:
j = 0:                                         j = 1:                                         j = 2:
10.0 * (2.0 - (20.0 * 280.0 + 30.0 * 350.0))   20.0 * (2.0 - (10.0 * 210.0 + 30.0 * 350.0))   30.0 * (2.0 - (10.0 * 210.0 + 20.0 * 280.0))
20.0 * (3.0 - (30.0 * 280.0 + 40.0 * 350.0))   30.0 * (3.0 - (20.0 * 210.0 + 40.0 * 350.0))   40.0 * (3.0 - (20.0 * 210.0 + 30.0 * 280.0))
70.0 * (2.0 - (80.0 * 280.0 + 90.0 * 350.0))   80.0 * (2.0 - (70.0 * 210.0 + 90.0 * 350.0))   90.0 * (2.0 - (70.0 * 210.0 + 80.0 * 280.0))
sum up all above (column-wise):
 -160980    -251960    -230940
 -447940    -545910    -503880
-3772860   -3695840   -3338820
------------------------------
-4381780   -4493710   -4073640
unregularized weights after the second iteration:
[-4381780, -4493710, -4073640]
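Regularize them with λ = 20 (λ / 2 = 10):
w1 = -4381780 < -20 / 2 => -4381780 + 20 / 2 = -4381770.0
w2 = -4493710 < -20 / 2 => -4493710 + 20 / 2 = -4493700.0
w3 = -4073640 < -20 / 2 => -4073640 + 20 / 2 = -4073630.0
Regularized weights vector after the 2nd iteration:
[-4381770.0, -4493700.0, -4073630.0]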
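Both iterations can be checked end to end with a short script. This is a sketch that mirrors the arithmetic on this page only: the function name `fit` and its parameters are assumptions, not the optimizer's actual API, and all coordinates within one iteration are updated from the previous iteration's weights, as the tables above do.

```python
def fit(X, y, lam, n_iterations):
    """Run the walkthrough's update rule for a fixed number of iterations (hypothetical helper)."""
    n_features = len(X[0])
    half = lam / 2
    w = [0.0] * n_features
    for _ in range(n_iterations):
        new_w = []
        for j in range(n_features):
            # Unregularized update for coordinate j, summed over all rows.
            raw = sum(
                row[j] * (y_i - sum(row[k] * w[k] for k in range(n_features) if k != j))
                for row, y_i in zip(X, y)
            )
            # Soft thresholding with threshold lam / 2.
            if raw > half:
                new_w.append(raw - half)
            elif raw < -half:
                new_w.append(raw + half)
            else:
                new_w.append(0.0)
        w = new_w
    return w


X = [[10.0, 20.0, 30.0],
     [20.0, 30.0, 40.0],
     [70.0, 80.0, 90.0]]
y = [2.0, 3.0, 2.0]

print(fit(X, y, lam=20.0, n_iterations=1))  # [210.0, 280.0, 350.0]
print(fit(X, y, lam=20.0, n_iterations=2))  # [-4381770.0, -4493700.0, -4073630.0]
```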