faster linear gradient
lehner committed Jan 16, 2025
1 parent 0caf060 commit 4e24cf8
Showing 2 changed files with 54 additions and 9 deletions.
50 changes: 43 additions & 7 deletions lib/gpt/ml/layer/linear.py
@@ -20,9 +20,38 @@
from gpt.ml.layer import base


def projector_color_trace(x):
    return g.color_trace(x)

cache = {}
def projector_color_trace(a, b):
    cache_tag = f"{a.otype.__name__}_{a.grid}_{b.otype.__name__}_{b.grid}"

    if cache_tag not in cache:
        # build a tensor stencil that evaluates the color trace of a * adj(b)
        # as an Ns x Ns spin matrix, accumulating over the color index
        ti = g.stencil.tensor_instructions
        Ns = a.otype.spin_ndim
        Nc = a.otype.color_ndim
        code = []
        for spin1 in range(Ns):
            for spin2 in range(Ns):
                for color in range(Nc):
                    aa = spin1 * Nc + color
                    bb = spin2 * Nc + color
                    dst = spin1 * Ns + spin2
                    code.append((0, dst, ti.mov_cc if color == 0 else ti.inc_cc, 1.0, [(2, 0, bb), (1, 0, aa)]))

        # verify the stencil against the reference expression once per type/grid combination
        res = g(g.color_trace(a * g.adj(b)))
        res2 = g.lattice(res)
        segments = [(len(code) // (Ns * Ns), Ns * Ns)]
        ein = g.stencil.tensor(res2, [(0, 0, 0, 0)], code, segments)
        ein(res2, a, b)

        eps2 = g.norm2(res - res2) / g.norm2(res)
        assert eps2 < 1e-10

        cache[cache_tag] = (res2, ein)

    # reuse the cached stencil for later calls with the same field types and grid
    res2, ein = cache[cache_tag]
    res3 = g.lattice(res2)
    ein(res3, a, b)
    return res3

class linear(base):
    def __init__(
@@ -80,17 +109,24 @@ def _contract(self, w, f):
        return ret

    def __call__(self, weights, layer_input):
        t = g.timer("linear")
        t("weights")
        layer_input = g.util.to_list(layer_input)
        w = self._get_weight_list(weights)
        return self._contract(w, layer_input)

        t("contract")
        x = self._contract(w, layer_input)
        t()
        if g.default.is_verbose("linear_performance"):
            g.message(t)
        return x

    def projected_gradient_adj(self, weights, layer_input, left):
        layer_input = g.util.to_list(layer_input)
        left = g.util.to_list(left)

        assert len(weights) == 1

        t = g.timer("projected_gradient_adj")
        t = g.timer("linear.projected_gradient_adj")
        t("weight list")
        w = self._get_weight_list(weights)
        t("field list")
@@ -104,7 +140,7 @@ def projected_gradient_adj(self, weights, layer_input, left):
        for i in range(len(left)):
            for j in range(n):
                t("sums")
                ip_left_f = g.sum(self.projector(left[i] * g.adj(layer_input[j])))
                ip_left_f = g.sum(self.projector(left[i], layer_input[j]))
                pos = n * i + j
                if pos not in self.access_cache:
                    self.access_cache[pos] = {}
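The rewritten projector replaces the generic expression `g.color_trace(left * g.adj(layer_input))` with a cached `g.stencil.tensor` contraction, avoiding the intermediate spin-color outer product on every gradient evaluation. Below is a minimal sketch of how the two formulations relate; the field constructors (`g.grid`, `g.vspincolor`, `g.random`) are standard gpt calls assumed here and are not part of this diff.

```python
import gpt as g
from gpt.ml.layer.linear import projector_color_trace  # two-argument projector added in this commit

# random spin-color test fields; constructors below are assumed standard gpt calls
grid = g.grid([8, 8, 8, 8], g.double)
rng = g.random("check")
a = g.vspincolor(grid)
b = g.vspincolor(grid)
rng.cnormal(a)
rng.cnormal(b)

# reference expression used by the previous projector versus the cached stencil version
reference = g(g.color_trace(a * g.adj(b)))
fast = projector_color_trace(a, b)

# the commit asserts agreement to 1e-10 relative precision when the stencil is first built
eps2 = g.norm2(reference - fast) / g.norm2(reference)
assert eps2 < 1e-10
```

Later calls with the same field types and grid reuse the cached stencil, so the reference comparison is paid only once per `cache_tag`.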
13 changes: 11 additions & 2 deletions lib/gpt/ml/layer/parallel_transport.py
@@ -63,7 +63,13 @@ def _get_field_list(self, layer_input, ttr):
        return ret_f

    def __call__(self, weights, layer_input):
        return self._get_field_list(layer_input, self.transport)
        t = g.timer("parallel_transport")
        t("get fields")
        x = self._get_field_list(layer_input, self.transport)
        t()
        if g.default.is_verbose("parallel_transport_performance"):
            g.message(t)
        return x

    def projected_gradient_adj(self, weights, layer_input, left):
        left = g.util.to_list(left)
@@ -73,7 +79,7 @@ def projected_gradient_adj(self, weights, layer_input, left):
        assert len(left) == len(layer_input)
        assert len(left) == 1 + len(self.paths)

        t = g.timer("projected_gradient_adj")
        t = g.timer("parallel_transport.projected_gradient_adj")
        t("field list")
        if self.itransport is None:
            self.itransport = [
@@ -86,4 +92,7 @@ def projected_gradient_adj(self, weights, layer_input, left):

        t()

        if g.default.is_verbose("parallel_transport_performance"):
            g.message(t)

        return [ileft]
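Both files now guard their timer reports behind verbosity flags (`linear_performance` and `parallel_transport_performance`). A short usage sketch follows, assuming gpt's usual `g.default.set_verbose` switch, which is not part of this diff.

```python
import gpt as g

# enable the performance reports introduced in this commit;
# set_verbose is assumed to be the standard gpt verbosity switch
g.default.set_verbose("linear_performance")
g.default.set_verbose("parallel_transport_performance")

# subsequent calls to linear.__call__, parallel_transport.__call__ and their
# projected_gradient_adj methods then print their g.timer breakdown via g.message(t)
```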
