diff --git a/merlin/dataloader/ops/embeddings.py b/merlin/dataloader/ops/embeddings.py index e364ef88..f1470841 100644 --- a/merlin/dataloader/ops/embeddings.py +++ b/merlin/dataloader/ops/embeddings.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os from typing import Optional, Union import numpy as np @@ -25,9 +26,9 @@ class EmbeddingOperator(BaseOperator): - """Create an operator that will apply a torch embedding table to supplied indices. - This operator allows the user to supply an id lookup table if the indices supplied - via the id_lookup_table. + """Create an operator that will apply an embedding table to supplied indices. + + An id lookup table for the embeddings can be supplied with the argument `id_lookup_table`. Parameters ---------- @@ -39,6 +40,12 @@ class EmbeddingOperator(BaseOperator): name of new column of embeddings, added to output, by default "embeddings" id_lookup_table : np.array, optional numpy array of values that represent embedding indices, by default None + mmap : bool, default False + When loading embeddings from a file, specify whether we should memory map the file + This is useful for accessing a large file without reading the entire file into memory. + unknown_value : Union[float, int, np.ndarray] + If an embedding index is not found. + Specifies the value should we return for the corresponding embedding. """ def __init__( @@ -47,24 +54,75 @@ def __init__( lookup_key: str = "id", embedding_name: str = "embeddings", id_lookup_table: Optional[Union[np.ndarray, str]] = None, - mmap=False, + mmap: bool = False, + unknown_value: Union[float, int, np.ndarray] = 0, ): - if mmap: - embeddings = np.load(embeddings, mmap_mode="r") - id_lookup_table = np.load(id_lookup_table) if id_lookup_table else None + if isinstance(embeddings, (str, os.PathLike)): + mmap_mode = "r" if mmap else None + embeddings = np.load(embeddings, mmap_mode=mmap_mode) + elif isinstance(embeddings, np.ndarray): + pass + else: + raise ValueError( + f"Unsupported type '{type(embeddings)}' passed to argument `embeddings` " + f"of '{type(self).__name__}'. " + "Expected either a numpy.ndarray " + "or a (string or pathlike object corresponding to a numpy file) " + "containing embeddings. " + ) self.embeddings = embeddings + + embedding_index_mapping = None + if isinstance(id_lookup_table, (str, os.PathLike)): + _ids = np.load(id_lookup_table) + embedding_index_mapping = self._get_embedding_index_mapping(_ids) + elif isinstance(id_lookup_table, np.ndarray): + _ids = id_lookup_table + embedding_index_mapping = self._get_embedding_index_mapping(_ids) + elif id_lookup_table is None: + pass + else: + raise ValueError( + f"Unsupported type '{type(id_lookup_table)}' passed to argument `id_lookup_table` " + f"of '{type(self).__name__}'. " + "Expected either a numpy.ndarray " + "or a (string or pathlike object corresponding to a numpy file) " + "containing the IDs that correspond to the embeddings. " + ) + self.embedding_index_mapping = embedding_index_mapping + self.lookup_key = lookup_key self.embedding_name = embedding_name - self.id_lookup_table = id_lookup_table + self.unknown_value = unknown_value + + def _get_embedding_index_mapping(self, ids): + expected_ids_shape = (self.embeddings.shape[0],) + assert ids.shape == expected_ids_shape, ( + "IDs provided must match the number of embeddings. " + f"Expected IDs with shape {expected_ids_shape} " + f"Received IDs with shape: {ids.shape} " + f"Embeddings shape: {self.embeddings.shape} " + ) + id_to_index_mapping = dict(zip(ids, range(len(ids)))) + return id_to_index_mapping def transform( self, col_selector: ColumnSelector, transformable: Transformable ) -> Transformable: keys = transformable[self.lookup_key] indices = keys.cpu().values - if self.id_lookup_table is not None: - indices = np.in1d(self.id_lookup_table, indices) + + if self.embedding_index_mapping is not None: + indices = np.array([self.embedding_index_mapping.get(_id, -1) for _id in indices]) + embeddings = self.embeddings[indices] + + # set unknown embedding to zero + for idx in np.ndindex(indices.shape): + embedding_index = indices[idx] + if embedding_index == -1: + embeddings[idx] = self.unknown_value + embeddings_col = TensorColumn(embeddings, offsets=keys.cpu().offsets) transformable[self.embedding_name] = ( embeddings_col.gpu() if keys.device == Device.GPU else embeddings_col diff --git a/tests/unit/dataloader/test_embeddings.py b/tests/unit/dataloader/test_embeddings.py index f3dbb339..b55ff2ae 100644 --- a/tests/unit/dataloader/test_embeddings.py +++ b/tests/unit/dataloader/test_embeddings.py @@ -19,6 +19,7 @@ import pandas as pd import pytest +from merlin.core.compat import cupy from merlin.core.dispatch import HAS_GPU from merlin.dataloader.loader_base import LoaderBase as Loader # noqa from merlin.dataloader.ops.embeddings import EmbeddingOperator @@ -29,6 +30,59 @@ from merlin.table import TensorColumn, TensorTable +def test_embeddings_invalid_ids(): + ids = np.array(["a", "b"]) + embeddings = np.random.rand(3, 10) + with pytest.raises(AssertionError) as exc_info: + EmbeddingOperator( + embeddings, + lookup_key="id", + embedding_name="id_embedding", + id_lookup_table=ids, + ) + assert "IDs provided must match the number of embeddings" in str(exc_info.value) + assert "Expected IDs with shape (3,)" in str(exc_info.value) + + +@pytest.mark.parametrize("unknown_value", [0, 1, np.random.uniform(size=10)]) +def test_embedding_lookup_with_unknown_value(unknown_value): + ids = np.array(["a", "b", "c"]) + embeddings = np.random.rand(3, 10) + df = pd.DataFrame( + { + "id": ["a", "unknown"], + "feature": [1, 2], + } + ) + + dataset = Dataset(df, cpu=True) + + data_loader = Loader( + dataset, + batch_size=3, + transforms=[ + EmbeddingOperator( + embeddings, + lookup_key="id", + embedding_name="id_embedding", + id_lookup_table=ids, + unknown_value=unknown_value, + ), + ], + shuffle=False, + ) + x, y = data_loader.peek() + + assert x["id"].values.shape == (2,) + embedding_values = x["id_embedding"].values + if cupy and isinstance(embedding_values, cupy.ndarray): + embedding_values = embedding_values.get() + assert embedding_values.shape == (2, 10) + np.testing.assert_equal(embedding_values[0], embeddings[0]) + np.testing.assert_equal(embedding_values[1], unknown_value) + assert data_loader.output_schema.column_names == ["id", "feature", "id_embedding"] + + def test_embedding_with_target(): id_embeddings = np.random.rand(1000, 10) df = pd.DataFrame( @@ -116,7 +170,7 @@ def test_embedding_np_mmap_dl_with_lookup(tmpdir, rev_embedding_ids, np_embeddin dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL]) dataset.schema = schema data_loader = Loader( @@ -148,7 +202,7 @@ def test_embedding_np_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_datafr dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL]) dataset.schema = schema paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*")) embeddings_ds = Dataset(paths) @@ -183,7 +237,7 @@ def test_embedding_np_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_from_ dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL]) dataset.schema = schema paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*")) embeddings_ds = Dataset(paths) @@ -191,7 +245,9 @@ def test_embedding_np_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_from_ data_loader = Loader( dataset, batch_size=batch_size, - transforms=[EmbeddingOperator(embeddings_np, id_lookup_table=embedding_ids.to_numpy())], + transforms=[ + EmbeddingOperator(embeddings_np, id_lookup_table=embedding_ids.to_numpy().ravel()) + ], shuffle=False, device=cpu, ) @@ -222,7 +278,7 @@ def test_embedding_np_dl_with_lookup_ragged( dataset = dataset.repartition(10) schema = dataset.schema for col_name in cat_names: - schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING]) + schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL]) dataset.schema = schema paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*")) embeddings_ds = Dataset(paths)