Skip to content

Commit

Permalink
allow tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
gjhuizing committed Jun 1, 2022
1 parent 9c3bf2e commit 349a341
Show file tree
Hide file tree
Showing 2 changed files with 226 additions and 76 deletions.
110 changes: 110 additions & 0 deletions tests/main_api_with_tuples_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from context import wsingular
from context import utils
from context import distance

import torch

# Define the dtype and device to work with.
dtype = torch.double
device = "cpu"

# Define the dimensions of our problem.
n_samples = 15
n_features = 20

# Initialize an empty dataset.
dataset = torch.zeros((n_samples, n_features), dtype=dtype)

# Iterate over the features and samples.
for i in range(n_samples):
for j in range(n_features):

# Fill the dataset with translated histograms.
dataset[i, j] = i / n_samples - j / n_features
dataset[i, j] = torch.abs(dataset[i, j] % 1)

# Take the distance to 0 on the torus.
dataset = torch.min(dataset, 1 - dataset)

# Make it a guassian.
dataset = torch.exp(-(dataset**2) / 0.1)

# Compute the normalizations.
A, B = wsingular.utils.normalize_dataset(dataset, dtype=dtype, device=device)

def test_wasserstein_singular_vectors():

# Compute the WSV.
C, D = wsingular.wasserstein_singular_vectors(
(A, B),
n_iter=10,
dtype=dtype,
device=device,
progress_bar=True,
tau=1e-3,
)

# Assert positivity of C.
assert torch.sum(C < 0) == 0

# Assert positivity of D.
assert torch.sum(D < 0) == 0


def test_sinkhorn_singular_vectors():

# Compute the SSV.
C, D = wsingular.sinkhorn_singular_vectors(
(A, B),
eps=5e-2,
dtype=dtype,
device=device,
n_iter=10,
progress_bar=True,
tau=1e-3,
)

# Assert positivity of C.
assert torch.sum(C < 0) == 0

# Assert positivity of D.
assert torch.sum(D < 0) == 0


def test_stochastic_wasserstein_singular_vectors():

# Compute the WSV.
C, D = wsingular.stochastic_wasserstein_singular_vectors(
(A, B),
n_iter=20,
dtype=dtype,
device=device,
progress_bar=True,
tau=1e-3,
)

# Assert positivity of C.
assert torch.sum(C < 0) == 0

# Assert positivity of D.
assert torch.sum(D < 0) == 0


def test_stochastic_sinkhorn_singular_vectors():

# Compute the SSV.
C, D = wsingular.stochastic_sinkhorn_singular_vectors(
(A, B),
eps=5e-2,
dtype=dtype,
device=device,
n_iter=20,
progress_bar=True,
tau=1e-3,
)

# Assert positivity of C.
assert torch.sum(C < 0) == 0

# Assert positivity of D.
assert torch.sum(D < 0) == 0
192 changes: 116 additions & 76 deletions wsingular/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def wasserstein_singular_vectors(
n_iter: int,
tau: float = 0,
p: int = 1,
writer = None,
writer=None,
small_value: float = 1e-6,
normalization_steps: int = 1,
C_ref: torch.Tensor = None,
Expand All @@ -24,7 +24,7 @@ def wasserstein_singular_vectors(
"""Performs power iterations and returns Wasserstein Singular Vectors. Early stopping is possible with Ctrl-C.
Args:
dataset (torch.Tensor): The input dataset, rows as samples
dataset (torch.Tensor): The input dataset, rows as samples. Alternatively, you can give a tuple of tensors (A, B).
dtype (str): The dtype
device (str): The device
n_iter (int): The number of power iterations.
Expand All @@ -43,22 +43,32 @@ def wasserstein_singular_vectors(
"""

# Perform some sanity checks.
assert len(dataset.shape) == 2 # correct shape
assert torch.sum(dataset < 0) == 0 # positivity
assert n_iter > 0 # at least one iteration
assert tau >= 0 # a positive regularization
assert p > 0 # a valid norm
assert small_value > 0 # a positive numerical offset
assert normalization_steps > 0 # normalizing at least once

# Make the transposed datasets A and B from the dataset.
A, B = utils.normalize_dataset(
dataset,
normalization_steps=normalization_steps,
small_value=small_value,
dtype=dtype,
device=device,
)
assert n_iter > 0 # at least one iteration
assert tau >= 0 # a positive regularization
assert p > 0 # a valid norm
assert small_value > 0 # a positive numerical offset
assert normalization_steps > 0 # normalizing at least once

if type(dataset) is tuple:
assert len(dataset) == 2 # correct shape

A, B = dataset # Recover A and B

assert torch.sum(A < 0) == 0 # positivity
assert torch.sum(B < 0) == 0 # positivity

else:
assert len(dataset.shape) == 2 # correct shape
assert torch.sum(dataset < 0) == 0 # positivity

# Make the transposed datasets A and B from the dataset.
A, B = utils.normalize_dataset(
dataset,
normalization_steps=normalization_steps,
small_value=small_value,
dtype=dtype,
device=device,
)

# Compute the regularization matrices.
R_A = utils.regularization_matrix(A, p=p, dtype=dtype, device=device)
Expand Down Expand Up @@ -145,7 +155,7 @@ def sinkhorn_singular_vectors(
tau: float = 0,
eps: float = 5e-2,
p: int = 1,
writer = None,
writer=None,
small_value: float = 1e-6,
normalization_steps: int = 1,
C_ref: torch.Tensor = None,
Expand All @@ -156,7 +166,7 @@ def sinkhorn_singular_vectors(
"""Performs power iterations and returns Sinkhorn Singular Vectors. Early stopping is possible with Ctrl-C.
Args:
dataset (torch.Tensor): The input dataset
dataset (torch.Tensor): The input dataset. Alternatively, you can give a tuple of tensors (A, B).
dtype (str): The dtype
device (str): The device
n_iter (int): The number of power iterations.
Expand All @@ -176,23 +186,33 @@ def sinkhorn_singular_vectors(
"""

# Perform some sanity checks.
assert len(dataset.shape) == 2 # correct shape
assert torch.sum(dataset < 0) == 0 # positivity
assert n_iter > 0 # at least one iteration
assert tau >= 0 # a positive regularization
assert eps >= 0 # a positive entropic regularization
assert p > 0 # a valid norm
assert small_value > 0 # a positive numerical offset
assert normalization_steps > 0 # normalizing at least once

# Make the transposed datasets A and B from the dataset U.
A, B = utils.normalize_dataset(
dataset,
normalization_steps=normalization_steps,
small_value=small_value,
dtype=dtype,
device=device,
)
assert n_iter > 0 # at least one iteration
assert tau >= 0 # a positive regularization
assert eps >= 0 # a positive entropic regularization
assert p > 0 # a valid norm
assert small_value > 0 # a positive numerical offset
assert normalization_steps > 0 # normalizing at least once

if type(dataset) is tuple:
assert len(dataset) == 2 # correct shape

A, B = dataset # Recover A and B

assert torch.sum(A < 0) == 0 # positivity
assert torch.sum(B < 0) == 0 # positivity

else:
assert len(dataset.shape) == 2 # correct shape
assert torch.sum(dataset < 0) == 0 # positivity

# Make the transposed datasets A and B from the dataset.
A, B = utils.normalize_dataset(
dataset,
normalization_steps=normalization_steps,
small_value=small_value,
dtype=dtype,
device=device,
)

# Compute the regularization matrices.
R_A = utils.regularization_matrix(A, p=p, dtype=dtype, device=device)
Expand Down Expand Up @@ -282,7 +302,7 @@ def stochastic_wasserstein_singular_vectors(
p: int = 1,
step_fn: Callable = lambda k: 1 / np.sqrt(k),
mult_update: bool = False,
writer = None,
writer=None,
small_value: float = 1e-6,
normalization_steps: int = 1,
C_ref: torch.Tensor = None,
Expand All @@ -292,7 +312,7 @@ def stochastic_wasserstein_singular_vectors(
"""Performs stochastic power iterations and returns Wasserstein Singular Vectors. Early stopping is possible with Ctrl-C.
Args:
dataset (torch.Tensor): The input dataset
dataset (torch.Tensor): The input dataset. Alternatively, you can give a tuple of tensors (A, B).
dtype (torch.dtype): The dtype
device (str): The device
n_iter (int): The number of power iterations.
Expand All @@ -313,23 +333,33 @@ def stochastic_wasserstein_singular_vectors(
"""

# Perform some sanity checks.
assert len(dataset.shape) == 2 # correct shape
assert torch.sum(dataset < 0) == 0 # positivity
assert n_iter > 0 # at least one iteration
assert tau >= 0 # a positive regularization
assert 0 < sample_prop <= 1 # a valid proportion
assert p > 0 # a valid norm
assert small_value > 0 # a positive numerical offset
assert normalization_steps > 0 # normalizing at least once

# Make the transposed datasets A and B from the dataset U.
A, B = utils.normalize_dataset(
dataset,
normalization_steps=normalization_steps,
small_value=small_value,
dtype=dtype,
device=device,
)
assert n_iter > 0 # at least one iteration
assert tau >= 0 # a positive regularization
assert 0 < sample_prop <= 1 # a valid proportion
assert p > 0 # a valid norm
assert small_value > 0 # a positive numerical offset
assert normalization_steps > 0 # normalizing at least once

if type(dataset) is tuple:
assert len(dataset) == 2 # correct shape

A, B = dataset # Recover A and B

assert torch.sum(A < 0) == 0 # positivity
assert torch.sum(B < 0) == 0 # positivity

else:
assert len(dataset.shape) == 2 # correct shape
assert torch.sum(dataset < 0) == 0 # positivity

# Make the transposed datasets A and B from the dataset.
A, B = utils.normalize_dataset(
dataset,
normalization_steps=normalization_steps,
small_value=small_value,
dtype=dtype,
device=device,
)

# Compute the regularization matrices.
R_A = utils.regularization_matrix(A, p=p, dtype=dtype, device=device)
Expand Down Expand Up @@ -484,7 +514,7 @@ def stochastic_sinkhorn_singular_vectors(
p: int = 1,
step_fn: Callable = lambda k: 1 / np.sqrt(k),
mult_update: bool = False,
writer = None,
writer=None,
small_value: float = 1e-6,
normalization_steps: int = 1,
C_ref: torch.Tensor = None,
Expand All @@ -494,7 +524,7 @@ def stochastic_sinkhorn_singular_vectors(
"""Performs stochastic power iterations and returns Sinkhorn Singular Vectors. Early stopping is possible with Ctrl-C.
Args:
dataset (torch.Tensor): The input dataset
dataset (torch.Tensor): The input dataset. Alternatively, you can give a tuple of tensors (A, B).
dtype (torch.dtype): The dtype
device (str): The device
n_iter (int): The number of power iterations.
Expand All @@ -516,24 +546,34 @@ def stochastic_sinkhorn_singular_vectors(
"""

# Perform some sanity checks.
assert len(dataset.shape) == 2 # correct shape
assert torch.sum(dataset < 0) == 0 # positivity
assert n_iter > 0 # at least one iteration
assert tau >= 0 # a positive regularization
assert 0 < sample_prop <= 1 # a valid proportion
assert eps >= 0 # a positive entropic regularization
assert p > 0 # a valid norm
assert small_value > 0 # a positive numerical offset
assert normalization_steps > 0 # normalizing at least once

# Make the transposed datasets A and B from the dataset U.
A, B = utils.normalize_dataset(
dataset,
normalization_steps=normalization_steps,
small_value=small_value,
dtype=dtype,
device=device,
)
assert n_iter > 0 # at least one iteration
assert tau >= 0 # a positive regularization
assert 0 < sample_prop <= 1 # a valid proportion
assert eps >= 0 # a positive entropic regularization
assert p > 0 # a valid norm
assert small_value > 0 # a positive numerical offset
assert normalization_steps > 0 # normalizing at least once

if type(dataset) is tuple:
assert len(dataset) == 2 # correct shape

A, B = dataset # Recover A and B

assert torch.sum(A < 0) == 0 # positivity
assert torch.sum(B < 0) == 0 # positivity

else:
assert len(dataset.shape) == 2 # correct shape
assert torch.sum(dataset < 0) == 0 # positivity

# Make the transposed datasets A and B from the dataset.
A, B = utils.normalize_dataset(
dataset,
normalization_steps=normalization_steps,
small_value=small_value,
dtype=dtype,
device=device,
)

# Compute the regularization matrices.
R_A = utils.regularization_matrix(A, p=p, dtype=dtype, device=device)
Expand Down

0 comments on commit 349a341

Please # to comment.