Commit e20d64b

fixes for tikz and expectation

thomasahle committed Jan 22, 2025
1 parent c61985f commit e20d64b
Showing 11 changed files with 108 additions and 73 deletions.
18 changes: 8 additions & 10 deletions examples/mnist.py
@@ -38,18 +38,16 @@ def main():
 
     layers = []
 
-    def conv_layer(i, in_channels, out_channels):
+    def conv_layer(channels: int):
         # Declare height and width convolutions
-        c_in, c_out = symbols(f"c{i} c{i+1}")
-        h_in, h_out, w_in, w_out = symbols(f"h{i} h{i+1}, w{i} w{i+1}")
+        i = len(layers)
+        c_in, c_out, h_in, h_out, w_in, w_out = symbols(f"c{i} c{i+1} h{i} h{i+1}, w{i} w{i+1}")
         h_conv = F.Convolution(h_in, h_out, hk=kernel_size)
         w_conv = F.Convolution(w_in, w_out, wk=kernel_size)
-        # Declare kernel variable and add it to the layers list
         kernel = Variable(f"kernel_{i}", c_in, c_out, hk=kernel_size, wk=kernel_size)
-        # Save the layer and shapes of the inner dimensions
         layers.append(kernel)
-        shapes[c_in] = in_channels
-        shapes[c_out] = out_channels
+        # Save the shapes of the inner dimensions
+        shapes[c_out] = channels
         shapes[h_out] = shapes[h_in] - shapes[kernel_size] + 1
         shapes[w_out] = shapes[w_in] - shapes[kernel_size] + 1
         # Apply the convolution
@@ -59,15 +57,15 @@ def conv_layer(i, in_channels, out_channels):
     x = data
 
     if False:
-        x = F.relu(x @ conv_layer(0, 1, 2)).simplify()
-        x = F.relu(x @ conv_layer(1, 2, 3)).simplify()
+        x = F.relu(x @ conv_layer(channels=2)).simplify()
+        x = F.relu(x @ conv_layer(channels=3)).simplify()
         c2, h2, w2, c3 = symbols("c2 h2 w2 c3")
         shapes[c3] = 3 * 24**2  # c2*w2*h2
         layers.append(linear := Variable("lin", c2, h2, w2, out))
         logits = x @ linear
 
     elif True:
-        x = F.relu(x @ conv_layer(0, 1, 2)).simplify()
+        x = F.relu(x @ conv_layer(channels=2)).simplify()
         c1, h1, w1, c2 = symbols("c1 h1 w1 c2")
         shapes[c2] = 2 * 26**2  # c1*w1*h1
         layers.append(linear := Variable("lin", c1, h1, w1, out))
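The shape bookkeeping in conv_layer uses the valid-convolution output size: with stride 1 and no padding, an input of length n and a kernel of length k give n - k + 1 outputs. A quick sketch (assuming kernel_size = 3 and 28x28 MNIST inputs, which is what the 26**2 and 24**2 constants above imply):

# Valid-convolution output size (stride 1, no padding), as used in conv_layer.
def conv_out_size(n_in: int, k: int) -> int:
    return n_in - k + 1

assert conv_out_size(28, 3) == 26  # first conv: 28x28 -> 26x26
assert conv_out_size(26, 3) == 24  # second conv: 26x26 -> 24x24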
13 changes: 12 additions & 1 deletion main.py
@@ -242,5 +242,16 @@ def main14():
     save_steps(expr)
     print(to_tikz(expr.full_simplify()))
 
+def main15():
+    d = symbols("d")
+    gs = [Variable(f"g{k}", i=d) for k in range(1)]
+    Ms = [Delta(d, "i", "j") - g @ g.rename(i='j') / Delta(d) for g in gs]
+    A = F.multi_dot(Ms * 2, dims=("i", "j"))
+    assert A.edges == {"i", "j"}, A.edges
+    B = A @ A
+    for g in gs:
+        B = Expectation(B, g).full_simplify()
+    save_steps(B)
+
 if __name__ == "__main__":
-    notebook3()
+    main15()
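For orientation: main15 builds M = I - g g^T / d, squares it with multi_dot, and takes the expectation over g of the full contraction of A with itself. A hypothetical Monte-Carlo cross-check, not part of the commit, assuming tensorgrad's @ contracts all shared edge names (so B is the sum of A_ij squared) and that Expectation defaults to a standard Gaussian g as in the tests:

import torch

d, n_samples = 8, 10_000
total = 0.0
for _ in range(n_samples):
    g = torch.randn(d)
    M = torch.eye(d) - torch.outer(g, g) / d   # Delta(d,"i","j") - g g^T / Delta(d)
    A = M @ M                                  # multi_dot(Ms * 2, dims=("i", "j"))
    total += (A * A).sum().item()              # B = A @ A, both edges contracted
print(total / n_samples)  # compare against the symbolic E[B] from save_steps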
48 changes: 24 additions & 24 deletions tensorgrad/extras/expectation.py
@@ -1,7 +1,6 @@
 from sympy import Symbol
 import torch
 from tensorgrad.tensor import (
-    Constant,
     Derivative,
     Product,
     Rename,
@@ -170,16 +169,13 @@ def simplify(self, args=None):
 
         if isinstance(inner, Sum):
             return Sum(
-                [
-                    Expectation(t, self.wrt, self.mu, self.covar, self.covar_names).simplify(args)
-                    for t in inner.tensors
-                ],
+                [Expectation(t, self.wrt, self.mu, self.covar, self.covar_names) for t in inner.tensors],
                 inner.weights,
             )
 
         if isinstance(inner, Rename):
             return Rename(
-                Expectation(inner.tensor, self.wrt, self.mu, self.covar, self.covar_names).simplify(args),
+                Expectation(inner.tensor, self.wrt, self.mu, self.covar, self.covar_names),
                 inner.mapping,
            )
 
@@ -232,35 +228,39 @@ def simplify(self, args=None):
                     self.covar_names,
                 )
                 assert res.edges == self.edges, f"{res.edges=} != {self.edges=}"
-                return res.simplify(args=args)
+                return res
 
            # Look for a power function with exponent >= 1 and pull out a factor
            if (
-                fn := next(
-                    (
-                        t
-                        for t in prod.tensors
-                        if isinstance(t, Function)
-                        and isinstance(t.signature, _PowerFunction)
-                        and t.signature.k >= 1
-                    ),
-                    None,
+                True
+                and (
+                    fn := next(
+                        (
+                            t
+                            for t in prod.tensors
+                            if isinstance(t, Function)
+                            and isinstance(t.signature, _PowerFunction)
+                            and t.signature.k >= 1
+                        ),
+                        None,
+                    )
                )
-            ) is not None:
-                assert isinstance(fn, Function)
+                is not None
+            ):
                subs = prod.tensors[:]
                subs.remove(fn)
                (inner,) = fn.inputs
-                subs.append(inner * fn.weight)  # We pull the weight out as well
+                subs.append(inner)
                if fn.signature.k > 1:
                    subs.append(pow(inner, fn.signature.k - 1))
-                return Expectation(Product(subs), self.wrt, self.mu, self.covar, self.covar_names).simplify(
-                    args=args
-                )
+                res = Expectation(Product(subs), self.wrt, self.mu, self.covar, self.covar_names)
+                return res
 
            # Otherwise we look for constant factors to pull out
-            elif args.get("extract_constants_from_expectation") and any(
-                not t.depends_on(self.wrt) for t in prod.tensors
+            elif (
+                False
+                and args.get("extract_constants_from_expectation")
+                and any(not t.depends_on(self.wrt) for t in prod.tensors)
            ):
                # Separate into constant and wrt-dependent factors
                constant_terms, wrt_terms = [], []
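The pattern running through these changes: a single simplify pass now returns freshly wrapped Expectation terms without recursing into them, leaving repeated application to the caller. That matches the test updates below, which switch from simplify() to full_simplify(). A sketch of the assumed driver loop (hypothetical; the real full_simplify may differ):

def full_simplify_sketch(expr):
    # Repeatedly apply one-step simplification until a fixed point is reached.
    while True:
        simplified = expr.simplify()
        if simplified == expr:
            return simplified
        expr = simplified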
34 changes: 30 additions & 4 deletions tensorgrad/functions.py
@@ -341,10 +341,36 @@ def mean(tensor: Tensor, dim: DimType = None, keepdims: bool = False) -> Tensor:
     return s / normalization
 
 
-def dot(t1: Tensor, t2: Tensor, dim: DimType = None) -> Tensor:
-    """Contract two tensors along the given dimensions, broadcasting over the remaining shared edges."""
-    dim = parse_dim(t1.edges & t2.edges, dim)
-    return sum(t1 * t2, dim)
+def dot(t1: Tensor, t2: Tensor, dim: str | tuple[str, str]) -> Tensor:
+    """Contract two tensors along the given dimensions, broadcasting over the remaining shared edges.
+    If the dimension is a tuple of two strings, the first string is the dimension of t1
+    and the second is the dimension of t2."""
+    if isinstance(dim, str):
+        dim = (dim, dim)
+    if len(dim) != 2 or not all(isinstance(d, str) for d in dim):
+        raise ValueError(f"Dot product requires one or two dimensions, got {dim=}")
+    if dim[0] not in t1.edges or dim[1] not in t2.edges:
+        raise ValueError(f"Edges {dim} must be in the tensors")
+    free_name = dim[0] + "_"
+    while free_name in t1.edges | t2.edges:
+        free_name += "_"
+    prod = t1.rename(**{dim[0]: free_name}) * t2.rename(**{dim[1]: free_name})
+    return sum(prod, free_name)
+
+
+def multi_dot(ts: list[Tensor], dims: tuple[str, str]) -> Tensor:
+    """
+    Compute the dot product of two or more tensors.
+    """
+    if not ts:
+        return Ones()
+    prod = ts[0]
+    assert isinstance(prod, Tensor)
+    for t2 in ts[1:]:
+        # We use the right edge of t and the left edge of t2
+        prod = dot(prod, t2, dim=(dims[1], dims[0]))
+        assert isinstance(prod, Tensor)
+    return prod
 
 
 class _ScaleFunction(FunctionSignature):
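A sketch of the new signatures in use (not part of the commit; it assumes the Variable/symbols import layout seen in main.py and the tests):

from sympy import symbols
from tensorgrad import Variable
import tensorgrad.functions as F

d = symbols("d")
A = Variable("A", i=d, j=d)
B = Variable("B", i=d, j=d)

x = F.dot(A, B, ("j", "i"))                  # contract A's "j" with B's "i"; edges i, j remain
y = F.multi_dot([A, B, A], dims=("i", "j"))  # chain A.B.A, contracting "j" against the next "i"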
2 changes: 0 additions & 2 deletions tensorgrad/serializers/to_pytorch.py
@@ -368,8 +368,6 @@ def forward_with_values(values: dict[Variable, torch.Tensor], shapes: dict[sympy
                 raise KeyError(f"No value provided for variable {var.name}")
             # Ensure the torch tensors follow the order of the variable edges.
             # The code will assume that this is always the case.
-            # print(f"{values[var].names=}, {var.edges=} {var.orig=}")
-            # local_ns[placeholder_name] = values[var].align_to(*(var.orig[e] for e in var.edges))
             local_ns[placeholder_name] = values[var].align_to(*var.edges).rename(None)
 
     # Now call `_generated_forward(batch, w0, out, _var_..., _var_...)`
31 changes: 17 additions & 14 deletions tensorgrad/serializers/to_tikz.py
@@ -133,7 +133,7 @@ def add_node(self, name: str, node_type: str, label: str = None, degree: int = N
         if node_type == "identity":
             label = format_label(label)
             label_str = f"${label}$" if label else ""
-            self.lines.append(f"  {name}[identity, as={label_str},{nudge}];")
+            self.lines.append(f"  {name}[identity, as={{}}, {nudge}, pin=45:{label_str}];")
 
         elif node_type == "zero":
             self.lines.append(f"  {name}[zero, as=0,{nudge}];")
@@ -171,20 +171,20 @@ def add_node(self, name: str, node_type: str, label: str = None, degree: int = N
             # Fallback
             self.lines.append(f"  {name}[as=${label}$,{nudge}];")
 
-    def add_edge(self, ref1: NodeRef, ref2: NodeRef, label: str, directed=False, multiplicity=1):
+    def add_edge(self, ref1: NodeRef, ref2: NodeRef, directed=False, multiplicity=1):
         """
         Add an edge between two NodeRefs. We honor any edge_style stored in
         the NodeRef. If both have styles, we prefer the second's or combine them?
         """
-        # Extract the internal node names:
-        id1 = ref1.name
-        id2 = ref2.name
-
-        labels = set()
-        for edge_label in [ref1.edge_label, label, ref2.edge_label]:
+        # We use a list because python sets don't keep insertion order like dicts
+        labels = []
+        for edge_label in [ref1.edge_label, ref2.edge_label]:
             if edge_label:
                 formatted_label = format_label(edge_label)
-                labels.add(f"${formatted_label}$")
+                labels.append(f"${formatted_label}$")
+        if labels and labels[0] == labels[-1]:
+            labels = labels[:1]
 
         # Combine or choose an edge style:
         style = ref1.edge_style or ref2.edge_style or ""
@@ -212,7 +212,7 @@ def add_edge(self, ref1: NodeRef, ref2: NodeRef, label: str, directed=False, mul
             label_str = f', "{labels[0]}" at start, "{labels[-1]}" at end'
 
         self.lines.append(
-            f"  ({id1}){edge_type}[{style_str}, bend left={angle}, {side} {label_str}] ({id2});"
+            f"  ({ref1.name}){edge_type}[{style_str}, bend left={angle}, {side} {label_str}] ({ref2.name});"
        )
@@ -241,7 +241,8 @@ def handle_free_edges(self, free_edges: dict):
             name = self.namer.fresh_name("free")
             self.add_node(name, "invisible")
             # Now connect from node_ref -> dummy with the label "e":
-            self.add_edge(node_ref, NodeRef(name, edge_label=e), label="")
+            print("Adding free edge", node_ref, NodeRef(name, edge_label=e))
+            self.add_edge(node_ref, NodeRef(name, edge_label=e))
 
 ###############################################################################
 # Singledispatch for each tensor type
@@ -254,7 +255,8 @@ def _to_tikz(self, tensor):
 def _(self, tensor: Delta):
     # Make one node
     name = self.namer.fresh_name("copy")
-    self.add_node(name, "identity")
+    label = tensor._size.name if tensor.order == 0 else None
+    self.add_node(name, "identity", label=label)
     # Return that node for every edge
     return {e: NodeRef(name) for e in tensor.edges}
@@ -314,7 +316,7 @@ def _(self, tensor: Function):
         # connect these subedges to func_ref
         for e in input_edges:
             sub_ref = subedges.pop(e)
-            subgraph.add_edge(sub_ref, NodeRef(func_node), label=e, directed=True)
+            subgraph.add_edge(sub_ref, NodeRef(func_node, edge_label=e), directed=True)
         # everything else remains free
         free_edges |= subedges
@@ -356,7 +358,7 @@ def _(self, tensor: Expectation):
 def _(self, tensor: Product):
     # If empty product, return an identity node
     if len(tensor.tensors) == 0:
-        self.add_node(self.namer.fresh_name("id"), "identity")
+        self.add_node(self.namer.fresh_name("id"), "identity", label="1")
         return {}
@@ -378,7 +380,6 @@ def _(self, tensor: Product):
             self.add_edge(
                 refs[0],
                 refs[1],
-                label="",
                 multiplicity=multiplicity,
             )
@@ -590,5 +591,7 @@ def choose_layout(depth):
     midway,
     inner sep=1pt,
   },
+  pin distance=.5ex,
+  every pin/.style={font=\small\itshape}
]
"""
7 changes: 3 additions & 4 deletions tensorgrad/tensor.py
@@ -231,6 +231,7 @@ def is_isomorphic(
         G1, _ = self.edge_structural_graph(match_edges=match_edges, edge_names=edge_names)
         G2, _ = other.edge_structural_graph(match_edges=match_edges, edge_names=edge_names)
         return nx.is_isomorphic(G1, G2, node_match=lambda n1, n2: n1.get("name") == n2.get("name"))
+        # return nx.vf2pp_is_isomorphic(G1, G2, node_label="name")
 
     def isomorphisms(self, other: "Tensor") -> Generator[dict[str, str], None, None]:
         """Given self and other are isomorphic, this method returns a dictionary that renames self into other."""
@@ -305,8 +306,6 @@ def evaluate(
             return res
 
         res = self._inner_evaluate(values, dims)
-        if torch.isnan(res.rename(None)).any():
-            print(res)
         assert not torch.isnan(res.rename(None)).any(), f"Got NaN in result in {self}"
         # We guarantee that inner_evaluate returns the edges in the same order as self.edges
         assert res.names == tuple(self.edges), f"Expected {self.edges=} but got {res.names=}"
@@ -1079,9 +1078,9 @@ def grad(self, x: Variable, new_names: Optional[dict[str, str]] = None) -> Tenso
 
         # The two parts are then multiplied together on the connection names,
         # while broadcasted on their remaining shared edges.
-        from .functions import dot  # Import here to avoid circular import
+        import tensorgrad.functions as F  # Import here to avoid circular import
 
-        part = dot(outside, inner, connection_names.values())
+        part = F.sum(outside * inner, connection_names.values())
         parts.append(part)
         res = Sum(parts)
         assert res.edges == self.edges | new_edges, f"{res.edges} != {self.edges} | {new_edges}"
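The switch from dot to F.sum(outside * inner, ...) makes sense given the API change above: the new dot accepts one edge name or a pair, while a gradient may need to contract several connection names at once. Elementwise product followed by summing over those names is the same contraction; a toy check in torch (illustrative only):

import torch

a, b = torch.randn(3, 4), torch.randn(3, 4)
contracted = (a * b).sum(dim=(0, 1))       # product, then sum over both names
reference = torch.einsum("ij,ij->", a, b)  # direct two-edge contraction
assert torch.allclose(contracted, reference)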
12 changes: 6 additions & 6 deletions tests/test_expectation.py
@@ -21,7 +21,7 @@ def simple_variables():
 
 def test_names1():
     x, mu, covar, covar_names = simple_variables()
-    res = Expectation(x, x, mu, covar, covar_names).simplify()
+    res = Expectation(x, x, mu, covar, covar_names).full_simplify()
     assert res == mu
 
 
@@ -30,7 +30,7 @@ def test_names2():
     # The expectation of a X transposed should be mu transposed
     xt = x.rename(i="j", j="i")
     mut = mu.rename(i="j", j="i")
-    res = Expectation(xt, x, mu, covar, covar_names).simplify()
+    res = Expectation(xt, x, mu, covar, covar_names).full_simplify()
     assert res == mut
 
 
@@ -39,7 +39,7 @@ def test_names3():
     zero = Zero(**x.shape)
     # The expectation of the outer product x (x) x2 should be covar if mu = 0
     x2 = x.rename(i="i2", j="j2")
-    res = Expectation(x @ x2, x, zero, covar, covar_names).simplify()
+    res = Expectation(x @ x2, x, zero, covar, covar_names).full_simplify()
     assert res == covar
 
 
@@ -51,7 +51,7 @@ def test_names4():
     x2t = xt.rename(i="i2", j="j2")
     covart = covar.rename(i="j", j="i", i2="j2", j2="i2")
     ex = Expectation(xt @ x2t, x, zero, covar, covar_names)
-    res = ex.simplify()
+    res = ex.full_simplify()
     assert res == covart
 
 
@@ -68,7 +68,7 @@ def test_quadratic():
     expr = X.rename(i="i0", j="j") @ A @ X.rename(j="j1", i="i")
     assert expr.edges == {"i0", "i"}
 
-    res = Expectation(expr, X, mu, covar, {"i": "i_", "j": "j_"}).simplify().evaluate(ts)
+    res = Expectation(expr, X, mu, covar, {"i": "i_", "j": "j_"}).full_simplify().evaluate(ts)
     expected = ts[A].rename(None).trace() * torch.eye(2).rename("i0", "i")  # trace(A) * I
     assert_close(res, expected)
 
@@ -179,7 +179,7 @@ def test_x():
     # E[out]_p2 = E[out_p2] = E[S_{p2}^2] = 1
     expr = Product([S, S1, Delta(p0, "p0, p1, p2")])
     expr = Expectation(expr, S)
-    assert expr.simplify() == Delta(p0, "p2")
+    assert expr.full_simplify() == Delta(p0, "p2")
 
 
 def test_triple_S():
2 changes: 1 addition & 1 deletion tests/test_functions.py
@@ -481,7 +481,7 @@ def test_dot():
     tB = torch.randn(2, 3, names=("i", "j"))
 
     # Contract along 'j'
-    expr = F.dot(A, B, ["j"])
+    expr = F.dot(A, B, "j")
     out = expr.evaluate({A: tA, B: tB})
 
     # Behavior check: sum-of-products along 'j':