Merge pull request #16 from paxcema/dim_check
Support either no categorical or no continuous input
lucidrains authored Feb 21, 2023
2 parents 3a38072 + 544c518 commit 25707af
Showing 3 changed files with 52 additions and 31 deletions.
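
Both models add the same constructor guard, assert len(categories) + num_continuous > 0 (see the first hunk of each file below), so the degenerate case of an empty categories tuple together with num_continuous = 0 is now rejected when the model is built. A minimal sketch of the expected behavior; constructor arguments not visible in this diff (dim_out in particular) follow the project README:

    from tab_transformer_pytorch import FTTransformer

    # Degenerate configuration: no categorical and no continuous columns at all.
    # The new assert should raise as soon as the model is constructed.
    try:
        FTTransformer(categories = (), num_continuous = 0, dim = 32, dim_out = 1, depth = 6, heads = 8)
    except AssertionError as err:
        print(err)  # 'input shape must not be null'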
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'tab-transformer-pytorch',
   packages = find_packages(),
-  version = '0.2.0',
+  version = '0.2.1',
   license='MIT',
   description = 'Tab Transformer - Pytorch',
   author = 'Phil Wang',
38 changes: 24 additions & 14 deletions tab_transformer_pytorch/ft_transformer.py
@@ -117,6 +117,7 @@ def __init__(
     ):
         super().__init__()
         assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'
+        assert len(categories) + num_continuous > 0, 'input shape must not be null'

         # categories related calculations

@@ -130,25 +131,29 @@ def __init__(

         # for automatically offsetting unique category ids to the correct position in the categories embedding table

-        categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
-        categories_offset = categories_offset.cumsum(dim = -1)[:-1]
-        self.register_buffer('categories_offset', categories_offset)
+        if self.num_unique_categories > 0:
+            categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
+            categories_offset = categories_offset.cumsum(dim = -1)[:-1]
+            self.register_buffer('categories_offset', categories_offset)

-        # categorical embedding
+            # categorical embedding

-        self.categorical_embeds = nn.Embedding(total_tokens, dim)
+            self.categorical_embeds = nn.Embedding(total_tokens, dim)

         # continuous

-        self.numerical_embedder = NumericalEmbedder(dim, num_continuous)
+        self.num_continuous = num_continuous
+
+        if self.num_continuous > 0:
+            self.numerical_embedder = NumericalEmbedder(dim, self.num_continuous)

         # cls token

         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

         # transformer

-        self.transformer = Transformer(
+        self.transformer = Transformer(
             dim = dim,
             depth = depth,
             heads = heads,
@@ -166,23 +171,28 @@
         )

     def forward(self, x_categ, x_numer):
-        b = x_categ.shape[0]
-
         assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'
-        x_categ += self.categories_offset

-        x_categ = self.categorical_embeds(x_categ)
+        xs = []
+        if self.num_unique_categories > 0:
+            x_categ += self.categories_offset
+
+            x_categ = self.categorical_embeds(x_categ)
+
+            xs.append(x_categ)

         # add numerically embedded tokens
-
-        x_numer = self.numerical_embedder(x_numer)
+        if self.num_continuous > 0:
+            x_numer = self.numerical_embedder(x_numer)
+
+            xs.append(x_numer)

         # concat categorical and numerical

-        x = torch.cat((x_categ, x_numer), dim = 1)
+        x = torch.cat(xs, dim = 1)

         # append cls tokens

+        b = x.shape[0]
         cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
         x = torch.cat((cls_tokens, x), dim = 1)

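With the category embedding and offset logic behind a guard, FTTransformer should now accept a table with no categorical columns: pass an empty categories tuple and a (batch, 0) integer tensor so the category-count assert is still satisfied. A minimal sketch of that usage; dim_out is taken from the project README rather than from this diff:

    import torch
    from tab_transformer_pytorch import FTTransformer

    model = FTTransformer(
        categories = (),        # no categorical columns -> num_unique_categories == 0
        num_continuous = 10,    # ten numeric columns
        dim = 32,
        dim_out = 1,
        depth = 6,
        heads = 8
    )

    x_categ = torch.empty(1, 0, dtype = torch.long)   # shape (batch, 0) passes the category-count assert
    x_numer = torch.randn(1, 10)

    pred = model(x_categ, x_numer)                    # (batch, dim_out)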
43 changes: 27 additions & 16 deletions tab_transformer_pytorch/tab_transformer_pytorch.py
@@ -149,6 +149,7 @@ def __init__(
     ):
         super().__init__()
         assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'
+        assert len(categories) + num_continuous > 0, 'input shape must not be null'

         # categories related calculations

@@ -162,18 +163,21 @@

         # for automatically offsetting unique category ids to the correct position in the categories embedding table

-        categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
-        categories_offset = categories_offset.cumsum(dim = -1)[:-1]
-        self.register_buffer('categories_offset', categories_offset)
+        if self.num_unique_categories > 0:
+            categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
+            categories_offset = categories_offset.cumsum(dim = -1)[:-1]
+            self.register_buffer('categories_offset', categories_offset)

         # continuous
+        self.num_continuous = num_continuous

-        if exists(continuous_mean_std):
-            assert continuous_mean_std.shape == (num_continuous, 2), f'continuous_mean_std must have a shape of ({num_continuous}, 2) where the last dimension contains the mean and variance respectively'
-        self.register_buffer('continuous_mean_std', continuous_mean_std)
+        if self.num_continuous > 0:
+            if exists(continuous_mean_std):
+                assert continuous_mean_std.shape == (num_continuous, 2), f'continuous_mean_std must have a shape of ({num_continuous}, 2) where the last dimension contains the mean and variance respectively'
+            self.register_buffer('continuous_mean_std', continuous_mean_std)

-        self.norm = nn.LayerNorm(num_continuous)
-        self.num_continuous = num_continuous
+            self.norm = nn.LayerNorm(num_continuous)

         # transformer

@@ -198,20 +202,27 @@
         self.mlp = MLP(all_dimensions, act = mlp_act)

     def forward(self, x_categ, x_cont):
+        xs = []
+
         assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'
-        x_categ += self.categories_offset

-        x = self.transformer(x_categ)
+        if self.num_unique_categories > 0:
+            x_categ += self.categories_offset
+
+            x = self.transformer(x_categ)

-        flat_categ = x.flatten(1)
+            flat_categ = x.flatten(1)
+            xs.append(flat_categ)

         assert x_cont.shape[1] == self.num_continuous, f'you must pass in {self.num_continuous} values for your continuous input'

-        if exists(self.continuous_mean_std):
-            mean, std = self.continuous_mean_std.unbind(dim = -1)
-            x_cont = (x_cont - mean) / std
+        if self.num_continuous > 0:
+            if exists(self.continuous_mean_std):
+                mean, std = self.continuous_mean_std.unbind(dim = -1)
+                x_cont = (x_cont - mean) / std

-        normed_cont = self.norm(x_cont)
+            normed_cont = self.norm(x_cont)
+            xs.append(normed_cont)

-        x = torch.cat((flat_categ, normed_cont), dim = -1)
+        x = torch.cat(xs, dim = -1)
         return self.mlp(x)
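
Symmetrically, TabTransformer should now run with zero continuous columns; note that the continuous-count assert is unconditional, so an empty (batch, 0) float tensor still has to be passed. A quick sketch; the category cardinalities here are arbitrary and dim_out again comes from the project README:

    import torch
    from tab_transformer_pytorch import TabTransformer

    model = TabTransformer(
        categories = (10, 5, 6, 5, 8),   # five categorical columns with these cardinalities
        num_continuous = 0,              # no numeric columns
        dim = 32,
        dim_out = 1,
        depth = 6,
        heads = 8
    )

    x_categ = torch.randint(0, 5, (1, 5))
    x_cont = torch.empty(1, 0)           # shape (batch, 0) satisfies the continuous-count assert

    pred = model(x_categ, x_cont)        # (batch, dim_out)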
