
Commit eda49aa

Linear regression with Var GD optimizer
committed · 1 parent 3645211 · commit eda49aa

File tree

11 files changed · +344 -78 lines changed

examples/ffnn-iris/report/report.html

+51-51
Large diffs are not rendered by default.

examples/linr-wine/main.py

+56
@@ -0,0 +1,56 @@
+import os
+import sys
+sys.path.insert(0, os.path.abspath('../..'))
+
+from nnlearn.linear import LinearRegression as LR
+from nnlearn.metrics import mean_squared_error
+from nnlearn.util import ScriptInformation
+
+# TODO: replace this with own implementation
+from sklearn import preprocessing
+from sklearn.model_selection import train_test_split
+from sklearn.datasets import load_wine
+
+def test_linr_regressor():
+
+    logger = ScriptInformation()
+    logger.section_start(":grapes: Linear Regression - Wine data")
+    logger.script_time()
+    logger.author("Ludek", "Cizinsky")
+    logger.section_start(":construction: Prepare input for the model")
+
+    logger.working_on("Load and split data")
+    X, y = load_wine(return_X_y=True)
+
+    logger.working_on("Train test split")
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
+    m = X_train.shape[1]
+
+    logger.working_on("Process the data")
+    scaler = preprocessing.StandardScaler().fit(X_train)
+    X_train = scaler.transform(X_train)
+    X_test = scaler.transform(X_test)
+
+    logger.section_start(":robot: Train the model")
+    figpath = "report/figures/"
+    clf = LR(optimizer='gd_backp',
+             epochs=25,
+             loss_func='mse',
+             batch_size=.25,
+             lr=.15,
+             shuffle=True,
+             bias=True,
+             figpath=figpath)
+    clf.fit(X_train, y_train)
+    logger.c.print(clf.report)
+
+    logger.section_start(":crystal_ball: Validate the model")
+    y_hat = clf.predict(X_test)
+    mse = mean_squared_error(y_test, y_hat, var=False)
+    logger.important_metric('MSE', mse)
+
+    logger.save("report/report.html")
+
+if __name__ == '__main__':
+    test_linr_regressor()
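A side note on the "Batch size: 29" shown in the generated report: with a fractional batch_size the size is presumably derived as that fraction of the training set (GdBase is not part of this diff, so the exact rounding is an assumption). A hypothetical check:

from math import ceil, floor

n_total = 178                  # samples in sklearn's load_wine
n_test = ceil(0.33 * n_total)  # 59 rows held out by train_test_split
n_train = n_total - n_test     # 119 training rows
print(floor(0.25 * n_train))   # 29, matching "Batch size: 29" in report.html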

examples/linr-wine/report/report.html

+72
New file: the generated HTML training report. Rendered, it reads:

─── 🍇 Linear Regression - Wine data ───
📆 November 02 2022 02:04:18
🐼 Created by Ludek Cizinsky

─── 🚧 Prepare input for the model ───
🐍 Load and split data
🐍 Train test split
🐍 Process the data

─── 🤖 Train the model ───
Hyper-parameters: Epochs: 25 · Loss func: mse · Batch size: 29 · LR: 0.15
Training plot: 📈 figures/training.png

Epoch   Loss    MAEr
1       2.670   0.82
2       0.727   0.62
3       0.304   0.43
4       0.207   0.49
5       0.195   0.33
6       0.142   0.43
7       0.137   0.46
8       0.193   0.28
9       0.108   0.28
10      0.087   0.23
11      0.093   0.29
12      0.093   0.26
13      0.085   0.26
14      0.080   0.24
15      0.118   0.45
16      0.148   0.24
17      0.088   0.40
18      0.105   0.43
19      0.126   0.26
20      0.078   0.24
21      0.076   0.27
22      0.070   0.22
23      0.066   0.31
24      0.120   0.52
25      0.118   0.23

─── 🔮 Validate the model ───
🚥 MSE: 3.6739563478051167

nnlearn/linear/__init__.py

+5
@@ -0,0 +1,5 @@
+from ._linr import LinearRegression
+
+__all__ = [
+    "LinearRegression"
+]

nnlearn/linear/_linr.py

+103
@@ -0,0 +1,103 @@
+import numpy as np
+from rich.progress import track
+
+from nnlearn.reporting import GdReport
+from nnlearn.nanograd import Var
+from nnlearn.base import GdBase
+
+class LinearRegression(GdBase, GdReport):
+
+    def __init__(self,
+                 optimizer='gd_backp',
+                 loss_func='mse',
+                 epochs=50,
+                 batch_size=1.0,
+                 shuffle=False,
+                 lr=.01,
+                 bias=True,
+                 figpath=""):
+
+        # Common attributes to models optimized via GD
+        GdBase.__init__(self,
+                        batch_size,
+                        shuffle,
+                        loss_func,
+                        epochs,
+                        lr)
+
+        # Reporting
+        GdReport.__init__(self, figpath, 'reg')
+
+        # LR specific
+        self.optimizer = optimizer
+        self.bias = bias
+        self._theta = None
+
+    def _zero_grads(self):
+        for w in self._theta[:, 0]:
+            w.grad = 0
+
+    def _update_weights(self):
+        for w in self._theta[:, 0]:
+            w.v -= self.lr * w.grad
+
+    def _forward(self, X):
+        return (X @ self._theta)[:, 0]
+
+    def _train(self):
+        for epoch in track(range(1, self.epochs + 1), "Training..."):
+            self._reshuffle()
+            X_batches, y_batches = self._get_batches()
+
+            batch = 1
+            losses = []
+            for X, y in zip(X_batches, y_batches):
+
+                # Predict
+                yhat = self._forward(X)
+
+                # Compute loss based on the prediction
+                loss = self.loss_func(y, yhat)
+                losses.append(loss.v)
+
+                # Reset gradients of variables to zero
+                self._zero_grads()
+
+                # Backward propagate
+                loss.backward()
+
+                # Update weights
+                self._update_weights()
+
+                # Increase batch number
+                batch += 1
+
+            # Epoch evaluation
+            yhat_train = self._arr_to_val(self._forward(self.Xv))
+            y_train = self._arr_to_val(self.yv)
+            self.eval_epoch(epoch, losses, y_train, yhat_train)
+
+        self.create_report(self.loss_func_name, self.batch_size, self.lr)
+
+    def _initialize_parameters(self, m):
+        m = m + 1 if self.bias else m
+        self._theta = np.random.normal(0, 1, m).reshape(-1, 1)
+        if self.optimizer == 'gd_backp':
+            self._theta = self._arr_to_var(self._theta)
+
+    def _add_constant_column(self, X):
+        if self.bias:
+            X = np.hstack((X, np.ones((X.shape[0], 1), dtype=X.dtype)))
+        return X
+
+    def fit(self, X, y):
+        self._initialize_parameters(X.shape[1])
+        X = self._add_constant_column(X)
+        self._preprocessing(X, y)
+        self._train()
+
+    def predict(self, X):
+        X = self._add_constant_column(X)
+        Xv = self._arr_to_var(X)
+        yhat = self._arr_to_val(self._forward(Xv))
+        return yhat
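For intuition, the gd_backp path drives the standard gradient-descent update for linear regression: loss.backward() fills each weight's .grad through the Var graph, and _update_weights then applies theta <- theta - lr * dL/dtheta. A minimal explicit-gradient equivalent in plain NumPy (a sketch using the closed-form MSE gradient instead of the repo's autodiff, not the committed implementation):

import numpy as np

def gd_linear_regression(X, y, lr=0.15, epochs=25):
    """Full-batch gradient descent on MSE for y ~ X @ theta, with a bias column."""
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])   # constant column, as in _add_constant_column
    theta = np.random.normal(0, 1, Xb.shape[1])
    n = Xb.shape[0]
    for _ in range(epochs):
        err = Xb @ theta - y                        # residuals
        grad = (2.0 / n) * (Xb.T @ err)             # d(MSE)/d(theta)
        theta -= lr * grad                          # same rule as _update_weights
    return theta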

nnlearn/metrics/_regression.py

+23-7
@@ -9,8 +9,9 @@
 """
 
 import numpy as np
+from nnlearn.nanograd import Var
 
-def squared_error(y, p):
+def squared_error(y, p, var):
     """Squared error.
 
     Squared error can be defined as follows:
@@ -38,26 +39,41 @@ def squared_error(y, p):
     -----
     Usually used for regression problems.
     """
-    return np.sum((y - p)**2)
+    if var:
+        return (y - p).sqr()
+    else:
+        return np.sum((y - p)**2)
 
-def mean_squared_error(y, p):
+def mean_squared_error(Y, P, var=True):
     """Mean of squared error
 
     Parameters
     ----------
-    y : :class:`ndarray`
+    Y : :class:`ndarray`
        One dimensional array with ground truth values.
 
-    p : :class:`ndarray`
+    P : :class:`ndarray`
        One dimensional array with predicted values.
 
     Returns
     -------
     float
        Mean squared error.
     """
-    n = y.shape[0]
-    return squared_error(y, p)/n
+
+    if var:
+        n = Var(Y.shape[0])
+        total = Var(0)
+        for i in range(n.v):
+            y = Y[i]                               # true value
+            yhat = P[i]                            # predicted value
+            total += squared_error(y, yhat, var)
+
+        return total/n
+
+    else:
+        n = Y.shape[0]
+        return squared_error(Y, P, var)
 
 def absolute_error(y, p):
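With var=False, as examples/linr-wine/main.py uses at validation time, the function reduces to the usual mean squared error, MSE = (1/n) * sum_i (y_i - p_i)^2. A quick sanity check with hypothetical values:

import numpy as np

y = np.array([3.0, -0.5, 2.0, 7.0])        # hypothetical ground truth
p = np.array([2.5,  0.0, 2.0, 8.0])        # hypothetical predictions
mse = np.sum((y - p) ** 2) / y.shape[0]    # the var=False branch above
assert np.isclose(mse, np.mean((y - p) ** 2))
print(mse)                                 # 0.375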

nnlearn/nanograd/_nanograd.py

+3
@@ -59,6 +59,9 @@ def log2(self) -> 'Var':
     def exp(self) -> 'Var':
         return Var(np.exp(self.v), [(self, np.exp(self.v))])
 
+    def sqr(self) -> 'Var':
+        return Var(self.v**2, [(self, 2*self.v)])
+
     def __repr__(self):
         return "Var(v=%.4f, grad=%.4f)" % (self.v, self.grad)
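The pair stored by sqr is the local derivative d(x^2)/dx = 2x, which is what the squared-error Var path in nnlearn/metrics relies on. A tiny check, assuming Var.backward() seeds the output gradient with 1 and propagates through the stored (child, local derivative) pairs, as the training loop in _linr.py does:

from nnlearn.nanograd import Var

x = Var(3.0)
y = x.sqr()     # y.v == 9.0; stores (x, 2 * 3.0) as the local derivative
y.backward()    # propagate dy/dx back to x
print(x.grad)   # expected 6.0, i.e. 2 * x.v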

nnlearn/network/_ffnn.py

+1-1
@@ -60,7 +60,7 @@ def __init__(self,
                       lr)
 
        # Reporting
-       GdReport.__init__(self, figpath)
+       GdReport.__init__(self, figpath, 'clf')
 
        # FFNN specific
        self.layers = layers
