Skip to content

Commit

Permalink
Use Dictionary lookup for supplied IDs to Embedding Operator (#148)
Browse files Browse the repository at this point in the history
* Use lookup dict for Embedding operator with ids to speed up transform

* Add test for embedding operator with unknown value

* Update typehint for unknown_value

* Remove embedding tag from input schema to embeddings tests

The op now adds the embedding tag automatically, and the input schema
cannot have both a CATEGORICAL and EMBEDDING tag

* Flatten array passed  as ids to EmbeddingOperator

* Add assertion for ids shape

* Set default value of `embedding_index_mapping`

* Correct casing of message in test
  • Loading branch information
oliverholworthy authored May 12, 2023
1 parent dbe0ade commit ec9bedf
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 15 deletions.
78 changes: 68 additions & 10 deletions merlin/dataloader/ops/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from typing import Optional, Union

import numpy as np
Expand All @@ -25,9 +26,9 @@


class EmbeddingOperator(BaseOperator):
"""Create an operator that will apply a torch embedding table to supplied indices.
This operator allows the user to supply an id lookup table if the indices supplied
via the id_lookup_table.
"""Create an operator that will apply an embedding table to supplied indices.
An id lookup table for the embeddings can be supplied with the argument `id_lookup_table`.
Parameters
----------
Expand All @@ -39,6 +40,12 @@ class EmbeddingOperator(BaseOperator):
name of new column of embeddings, added to output, by default "embeddings"
id_lookup_table : np.array, optional
numpy array of values that represent embedding indices, by default None
mmap : bool, default False
When loading embeddings from a file, specify whether we should memory map the file
This is useful for accessing a large file without reading the entire file into memory.
unknown_value : Union[float, int, np.ndarray]
If an embedding index is not found.
Specifies the value should we return for the corresponding embedding.
"""

def __init__(
Expand All @@ -47,24 +54,75 @@ def __init__(
lookup_key: str = "id",
embedding_name: str = "embeddings",
id_lookup_table: Optional[Union[np.ndarray, str]] = None,
mmap=False,
mmap: bool = False,
unknown_value: Union[float, int, np.ndarray] = 0,
):
if mmap:
embeddings = np.load(embeddings, mmap_mode="r")
id_lookup_table = np.load(id_lookup_table) if id_lookup_table else None
if isinstance(embeddings, (str, os.PathLike)):
mmap_mode = "r" if mmap else None
embeddings = np.load(embeddings, mmap_mode=mmap_mode)
elif isinstance(embeddings, np.ndarray):
pass
else:
raise ValueError(
f"Unsupported type '{type(embeddings)}' passed to argument `embeddings` "
f"of '{type(self).__name__}'. "
"Expected either a numpy.ndarray "
"or a (string or pathlike object corresponding to a numpy file) "
"containing embeddings. "
)
self.embeddings = embeddings

embedding_index_mapping = None
if isinstance(id_lookup_table, (str, os.PathLike)):
_ids = np.load(id_lookup_table)
embedding_index_mapping = self._get_embedding_index_mapping(_ids)
elif isinstance(id_lookup_table, np.ndarray):
_ids = id_lookup_table
embedding_index_mapping = self._get_embedding_index_mapping(_ids)
elif id_lookup_table is None:
pass
else:
raise ValueError(
f"Unsupported type '{type(id_lookup_table)}' passed to argument `id_lookup_table` "
f"of '{type(self).__name__}'. "
"Expected either a numpy.ndarray "
"or a (string or pathlike object corresponding to a numpy file) "
"containing the IDs that correspond to the embeddings. "
)
self.embedding_index_mapping = embedding_index_mapping

self.lookup_key = lookup_key
self.embedding_name = embedding_name
self.id_lookup_table = id_lookup_table
self.unknown_value = unknown_value

def _get_embedding_index_mapping(self, ids):
expected_ids_shape = (self.embeddings.shape[0],)
assert ids.shape == expected_ids_shape, (
"IDs provided must match the number of embeddings. "
f"Expected IDs with shape {expected_ids_shape} "
f"Received IDs with shape: {ids.shape} "
f"Embeddings shape: {self.embeddings.shape} "
)
id_to_index_mapping = dict(zip(ids, range(len(ids))))
return id_to_index_mapping

def transform(
self, col_selector: ColumnSelector, transformable: Transformable
) -> Transformable:
keys = transformable[self.lookup_key]
indices = keys.cpu().values
if self.id_lookup_table is not None:
indices = np.in1d(self.id_lookup_table, indices)

if self.embedding_index_mapping is not None:
indices = np.array([self.embedding_index_mapping.get(_id, -1) for _id in indices])

embeddings = self.embeddings[indices]

# set unknown embedding to zero
for idx in np.ndindex(indices.shape):
embedding_index = indices[idx]
if embedding_index == -1:
embeddings[idx] = self.unknown_value

embeddings_col = TensorColumn(embeddings, offsets=keys.cpu().offsets)
transformable[self.embedding_name] = (
embeddings_col.gpu() if keys.device == Device.GPU else embeddings_col
Expand Down
66 changes: 61 additions & 5 deletions tests/unit/dataloader/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pandas as pd
import pytest

from merlin.core.compat import cupy
from merlin.core.dispatch import HAS_GPU
from merlin.dataloader.loader_base import LoaderBase as Loader # noqa
from merlin.dataloader.ops.embeddings import EmbeddingOperator
Expand All @@ -29,6 +30,59 @@
from merlin.table import TensorColumn, TensorTable


def test_embeddings_invalid_ids():
ids = np.array(["a", "b"])
embeddings = np.random.rand(3, 10)
with pytest.raises(AssertionError) as exc_info:
EmbeddingOperator(
embeddings,
lookup_key="id",
embedding_name="id_embedding",
id_lookup_table=ids,
)
assert "IDs provided must match the number of embeddings" in str(exc_info.value)
assert "Expected IDs with shape (3,)" in str(exc_info.value)


@pytest.mark.parametrize("unknown_value", [0, 1, np.random.uniform(size=10)])
def test_embedding_lookup_with_unknown_value(unknown_value):
ids = np.array(["a", "b", "c"])
embeddings = np.random.rand(3, 10)
df = pd.DataFrame(
{
"id": ["a", "unknown"],
"feature": [1, 2],
}
)

dataset = Dataset(df, cpu=True)

data_loader = Loader(
dataset,
batch_size=3,
transforms=[
EmbeddingOperator(
embeddings,
lookup_key="id",
embedding_name="id_embedding",
id_lookup_table=ids,
unknown_value=unknown_value,
),
],
shuffle=False,
)
x, y = data_loader.peek()

assert x["id"].values.shape == (2,)
embedding_values = x["id_embedding"].values
if cupy and isinstance(embedding_values, cupy.ndarray):
embedding_values = embedding_values.get()
assert embedding_values.shape == (2, 10)
np.testing.assert_equal(embedding_values[0], embeddings[0])
np.testing.assert_equal(embedding_values[1], unknown_value)
assert data_loader.output_schema.column_names == ["id", "feature", "id_embedding"]


def test_embedding_with_target():
id_embeddings = np.random.rand(1000, 10)
df = pd.DataFrame(
Expand Down Expand Up @@ -116,7 +170,7 @@ def test_embedding_np_mmap_dl_with_lookup(tmpdir, rev_embedding_ids, np_embeddin
dataset = dataset.repartition(10)
schema = dataset.schema
for col_name in cat_names:
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL])
dataset.schema = schema

data_loader = Loader(
Expand Down Expand Up @@ -148,7 +202,7 @@ def test_embedding_np_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_datafr
dataset = dataset.repartition(10)
schema = dataset.schema
for col_name in cat_names:
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL])
dataset.schema = schema
paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
embeddings_ds = Dataset(paths)
Expand Down Expand Up @@ -183,15 +237,17 @@ def test_embedding_np_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_from_
dataset = dataset.repartition(10)
schema = dataset.schema
for col_name in cat_names:
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL])
dataset.schema = schema
paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
embeddings_ds = Dataset(paths)
embeddings_np = embeddings_ds.to_ddf().compute().to_numpy()[:, 1:]
data_loader = Loader(
dataset,
batch_size=batch_size,
transforms=[EmbeddingOperator(embeddings_np, id_lookup_table=embedding_ids.to_numpy())],
transforms=[
EmbeddingOperator(embeddings_np, id_lookup_table=embedding_ids.to_numpy().ravel())
],
shuffle=False,
device=cpu,
)
Expand Down Expand Up @@ -222,7 +278,7 @@ def test_embedding_np_dl_with_lookup_ragged(
dataset = dataset.repartition(10)
schema = dataset.schema
for col_name in cat_names:
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL])
dataset.schema = schema
paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
embeddings_ds = Dataset(paths)
Expand Down

0 comments on commit ec9bedf

Please # to comment.