
Commit

Partial copy of Doggettx's minimal memory requirement improvements
swfsql committed Sep 5, 2022
1 parent: 90679eb · commit: 9c9c1ea
Showing 1 changed file with 40 additions and 12 deletions.
ldm/modules/attention.py (52 changes: 40 additions & 12 deletions)
@@ -174,23 +174,51 @@ def forward(self, x, context=None, mask=None):
         context = default(context, x)
         k = self.to_k(context)
         v = self.to_v(context)
+        del context, x

         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

-        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)

-        if exists(mask):
-            mask = rearrange(mask, 'b ... -> b (...)')
-            max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, 'b j -> (b h) () j', h=h)
-            sim.masked_fill_(~mask, max_neg_value)
+        #if exists(mask):
+        #    mask = rearrange(mask, 'b ... -> b (...)')

-        # attention, what we cannot get enough of
-        attn = sim.softmax(dim=-1)
+        stats = torch.cuda.memory_stats(q.device)
+        mem_total = torch.cuda.get_device_properties(0).total_memory
+        mem_active = stats['active_bytes.all.current']
+        mem_free = mem_total - mem_active
+
+        mem_required = q.shape[0] * q.shape[1] * k.shape[1] * 4 * 2.5
+        steps = 1
+
+        if mem_required > mem_free:
+            steps = 2**(math.ceil(math.log(mem_required / mem_free, 2)))
+
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
+            s1 *= self.scale
+
+            #if exists(mask):
+            #    max_neg_value = -torch.finfo(s1.dtype).max
+            #    mask2 = repeat(mask, 'b j -> (b h) () j', h=h)
+            #    s1.masked_fill_(~mask2, max_neg_value)
+            #    del mask2
+
+            # attention, what we cannot get enough of
+            s2 = s1.softmax(dim=-1)
+            del s1
+
+            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+            del s2
+
+
+        r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
+        del r1
+
+        return self.to_out(r2)

-        out = einsum('b i j, b j d -> b i d', attn, v)
-        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-        return self.to_out(out)


 class BasicTransformerBlock(nn.Module):
@@ -258,4 +286,4 @@ def forward(self, x, context=None):
             x = block(x, context=context)
         x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
         x = self.proj_out(x)
-        return x + x_in
+        return x + x_in
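For reference, the change in the first hunk boils down to one idea: never materialize the full (batch*heads, n_query, n_key) similarity tensor at once. The query rows are instead processed in slices, and the number of slices is a power of two derived from comparing an estimate of the memory the full float32 tensor would need (its byte size times a 2.5x safety margin) against the memory currently free on the GPU. The sketch below is a minimal standalone illustration of that idea, not the commit's code; the function name sliced_attention, the explicit steps argument, and the CPU fallback for the free-memory estimate are assumptions made for the example.

# Minimal sketch of query-sliced attention, assuming q, k, v of shape (batch*heads, n, d).
import math
import torch

def sliced_attention(q, k, v, scale, steps=None):
    if steps is None:
        # Bytes the full float32 similarity tensor would take, with a 2.5x margin
        # for softmax temporaries (the same estimate the commit uses).
        bytes_required = q.shape[0] * q.shape[1] * k.shape[1] * 4 * 2.5
        if q.is_cuda:
            stats = torch.cuda.memory_stats(q.device)
            free = torch.cuda.get_device_properties(q.device).total_memory - stats['active_bytes.all.current']
        else:
            free = bytes_required  # illustrative fallback: no slicing on CPU
        steps = 1
        if bytes_required > free:
            # Round the slice count up to the next power of two.
            steps = 2 ** math.ceil(math.log2(bytes_required / free))

    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    # If the query length is not divisible by `steps`, fall back to a single slice.
    slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        # Per-slice similarity and softmax: only (batch*heads, slice_size, n_key) lives at once.
        sim = torch.einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        out[:, i:end] = torch.einsum('b i j, b j d -> b i d', sim.softmax(dim=-1), v)
        del sim  # release the slice's similarity tensor before the next iteration
    return out

# Example usage (CPU, forcing 4 slices): 8 heads, 1024 tokens, head dim 40.
q = torch.randn(8, 1024, 40)
k = torch.randn(8, 1024, 40)
v = torch.randn(8, 1024, 40)
out = sliced_attention(q, k, v, scale=40 ** -0.5, steps=4)
print(out.shape)  # torch.Size([8, 1024, 40])

Because slice_size falls back to the full query length whenever it is not evenly divisible by the chosen step count, the worst case of this scheme is simply the original single-pass behaviour. Note also that the mask-handling path is left commented out in both places in the commit, so masked attention is effectively unsupported in this partial copy.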
