-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlasso.py
54 lines (40 loc) · 1.3 KB
/
lasso.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import math
import time
import torch
import surecr
import linops as lo
# Number of observations (rows of the design matrix and length of y).
d = 250
# Number of features: twice the number of observations, so p > d.
p = 2 * d
# Seed for the data-generating RNG, derived from the problem size.
seed = d + 7
# Signal-strength parameter used by generate_data when rescaling the
# true coefficients.
r2 = 0.8
# Noise variance of the observation model; also handed to the SURE estimator.
variance = 2.0
def run_experiment():
    """Fit a lasso with FISTA on synthetic data and print its SURE value.

    Builds the synthetic problem with ``generate_data``, wraps the design
    matrix as a linear operator, and solves with a soft-thresholding
    proximal operator for the penalty ``lambda_val``. It then evaluates
    Stein's Unbiased Risk Estimate through ``surecr.SURE`` and prints the
    estimate together with the elapsed wall-clock time of the SURE
    computation.

    Requires a CUDA device: the design matrix, the response and the solver's
    starting point are all placed on the GPU.
    """
    X, y, lambda_val = generate_data()

    def prox(v, t):
        # Soft-thresholding, sign(v) * max(|v| - lambda_val * t, 0), written
        # as a difference of two ReLUs.
        return torch.relu(v - lambda_val * t) - torch.relu(-v - lambda_val * t)

    A = lo.aslinearoperator(X.cuda())  # Construct linear operator from 2D tensor.
    y_cuda = y.cuda()
    # Reseed before the SURE computation, independently of how many draws
    # generate_data consumed (presumably so any randomness inside surecr is
    # reproducible -- TODO confirm).
    torch.manual_seed(d + 76)

    # CUDA kernels launch asynchronously: synchronize on both sides of the
    # timed region so the measurement covers the GPU work itself, not just
    # the time taken to enqueue it.
    torch.cuda.synchronize(y_cuda.device)
    t0 = time.perf_counter()
    solver = surecr.FISTASolver(
        A, prox, torch.zeros(p, device=y_cuda.device), device=y_cuda.device)
    sure = surecr.SURE(variance, solver)
    sure_val = sure.compute(y_cuda)
    torch.cuda.synchronize(y_cuda.device)
    tf = time.perf_counter()
    print(f"SURE: {sure_val} time: {tf - t0}")
def generate_data():
    """Build a reproducible sparse linear-regression problem.

    Seeds the RNG with ``seed`` and draws a standard-normal design of shape
    ``(d, p)``. The sparse coefficient vector is nonzero (with equal weights)
    only on its first ``d // 20`` entries. The response adds Gaussian noise
    with variance ``variance``. The lasso penalty is 10% of
    ``||X^T y'||_inf``, where ``y'`` is a second response built from the same
    signal with a fresh noise draw.

    Returns:
        Tuple ``(X, y, lambda_val)``. ``X`` and ``y`` are left on the CPU;
        ``lambda_val`` is a 0-d tensor already moved to the GPU.
    """
    torch.manual_seed(seed)

    # Sparse support: equal weights on the leading d // 20 coordinates.
    support = torch.zeros(p)
    support[:d // 20] = 1.0

    design = torch.randn((d, p))

    # Rescale the coefficients so the signal energy is tied to r2 and the
    # noise variance.
    raw_norm = torch.linalg.vector_norm(design @ support)
    scale = torch.sqrt((1 - r2) * raw_norm**2 / d / variance)
    coef = support / scale
    # NOTE(review): algebraically ||design @ coef||^2 / d == variance / (1 - r2).
    # With expected per-sample noise power `variance`, that makes
    # signal / (signal + noise) == 1 / (2 - r2), not r2. Confirm this is the
    # intended meaning of r2 before reusing this generator.

    signal = design @ coef
    noise_std = math.sqrt(variance)
    y = signal + noise_std * torch.randn(d)
    # A separate noise draw is used only to pick the penalty, presumably so
    # lambda_val does not depend on the y that SURE is evaluated on.
    y_for_lambda = signal + noise_std * torch.randn(d)

    lambda_max = torch.linalg.vector_norm(design.T @ y_for_lambda, torch.inf)
    lambda_val = (0.1 * lambda_max).cuda()
    return design, y, lambda_val
# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    run_experiment()