Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add huggingface support for inferring audio/image columns. #307

Merged
merged 2 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 107 additions & 39 deletions demo/audio-embed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@
{
"data": {
"text/plain": [
"(APIInfo(api=<fastapi.applications.FastAPI object at 0x164c221f0>, port=5000, server=<meerkat.interactive.server.Server object at 0x164f5cdc0>, name='127.0.0.1', shared=False, process=None, _url=None),\n",
" FrontendInfo(package_manager='npm', port=8000, name='localhost', shared=False, process=<subprocess.Popen object at 0x164f8ad00>, _url=None))"
"(APIInfo(api=<fastapi.applications.FastAPI object at 0x16b817d90>, port=5000, server=<meerkat.interactive.server.Server object at 0x12077ef10>, name='127.0.0.1', shared=False, process=None, _url=None),\n",
" FrontendInfo(package_manager='npm', port=8001, name='localhost', shared=False, process=<Popen: returncode: None args: ['python', '-m', 'http.server', '8001']>, _url=None))"
]
},
"execution_count": 2,
Expand Down Expand Up @@ -82,29 +82,56 @@
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\">[03/12/23 11:54:17] </span><span style=\"color: #800000; text-decoration-color: #800000\">WARNING </span> <span style=\"font-weight: bold\">[</span><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">download_and_prepare</span><span style=\"font-weight: bold\">()]</span> <span style=\"font-weight: bold\">[</span>datasets.builder: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">798</span><span style=\"font-weight: bold\">]</span> :: Found cached <a href=\"file:///Users/arjundd/miniconda3/envs/meerkat_prod/lib/python3.8/site-packages/datasets/builder.py\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">builder.py</span></a><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">:</span><a href=\"file:///Users/arjundd/miniconda3/envs/meerkat_prod/lib/python3.8/site-packages/datasets/builder.py#798\" target=\"_blank\"><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">798</span></a>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> dataset parquet <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"font-weight: bold\">(</span><span style=\"color: #800080; text-decoration-color: #800080\">/Users/arjundd/.cache/huggingface/datasets/lewtun___parquet/lewtun--mu</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #800080; text-decoration-color: #800080\">sic_genres_small-2686d03f87ff3ace/0.0.0/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">2a3b91fbd88a2c90d1dbbb32b460cf6</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"<span style=\"color: #7fbfbf; text-decoration-color: #7fbfbf\"> </span> <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">21d31bd5b05b934492fdef7d8d6f236ec</span><span style=\"font-weight: bold\">)</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span>\n",
"</pre>\n"
],
"application/json": {
"ascii": false,
"bar_format": null,
"colour": null,
"elapsed": 0.018494844436645508,
"initial": 0,
"n": 0,
"ncols": null,
"nrows": null,
"postfix": null,
"prefix": "Downloading readme",
"rate": null,
"total": 487,
"unit": "B",
"unit_divisor": 1000,
"unit_scale": true
},
"application/vnd.jupyter.widget-view+json": {
"model_id": "5e1f344e17434976887db524d2c1ad5a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"\u001b[2;36m[03/12/23 11:54:17]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m \u001b[1m[\u001b[0m\u001b[1;35mdownload_and_prepare\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m \u001b[1m[\u001b[0mdatasets.builder: \u001b[1;36m798\u001b[0m\u001b[1m]\u001b[0m :: Found cached \u001b]8;id=244572;file:///Users/arjundd/miniconda3/envs/meerkat_prod/lib/python3.8/site-packages/datasets/builder.py\u001b\\\u001b[2mbuilder.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=761482;file:///Users/arjundd/miniconda3/envs/meerkat_prod/lib/python3.8/site-packages/datasets/builder.py#798\u001b\\\u001b[2m798\u001b[0m\u001b]8;;\u001b\\\n",
"\u001b[2;36m \u001b[0m dataset parquet \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m \u001b[1m(\u001b[0m\u001b[35m/Users/arjundd/.cache/huggingface/datasets/lewtun___parquet/lewtun--mu\u001b[0m \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m \u001b[35msic_genres_small-2686d03f87ff3ace/0.0.0/\u001b[0m\u001b[95m2a3b91fbd88a2c90d1dbbb32b460cf6\u001b[0m \u001b[2m \u001b[0m\n",
"\u001b[2;36m \u001b[0m \u001b[95m21d31bd5b05b934492fdef7d8d6f236ec\u001b[0m\u001b[1m)\u001b[0m \u001b[2m \u001b[0m\n"
"Downloading readme: 0%| | 0.00/487 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/json": {
"ascii": false,
"bar_format": null,
"colour": null,
"elapsed": 0.008691787719726562,
"initial": 0,
"n": 0,
"ncols": null,
"nrows": null,
"postfix": null,
"prefix": "",
"rate": null,
"total": 1,
"unit": "it",
"unit_divisor": 1000,
"unit_scale": false
},
"application/vnd.jupyter.widget-view+json": {
"model_id": "8dd26908aa7742b68faf7ba8d90189b2",
"model_id": "094b5d2c21de460fadcbb7bcfbdb80fc",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -119,7 +146,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/arjundd/miniconda3/envs/meerkat_prod/lib/python3.8/site-packages/meerkat/columns/scalar/arrow.py:205: UserWarning: Unable to check if column is a valid primary key: Function 'unique' has no kernel matching input types (struct<bytes: binary, path: string>)\n",
"/Users/sabrieyuboglu/code/meerkat/meerkat/columns/scalar/arrow.py:205: UserWarning: Unable to check if column is a valid primary key: Function 'unique' has no kernel matching input types (struct<bytes: binary, path: string>)\n",
" warnings.warn(f\"Unable to check if column is a valid primary key: {e}\")\n"
]
}
Expand Down Expand Up @@ -196,13 +223,30 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"application/json": {
"ascii": false,
"bar_format": null,
"colour": null,
"elapsed": 0.007481098175048828,
"initial": 0,
"n": 0,
"ncols": null,
"nrows": null,
"postfix": null,
"prefix": "Downloading",
"rate": null,
"total": 2875055,
"unit": "B",
"unit_divisor": 1000,
"unit_scale": true
},
"application/vnd.jupyter.widget-view+json": {
"model_id": "e8898afe833e4e1a8dba42136988b07e",
"model_id": "558f8ebcbff240cab96eff2057184780",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -301,6 +345,15 @@
"Then, we will use UMAP to decompose the embeddings."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"plot_df = df.merge(df_embed, on=\"song_id\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
Expand All @@ -310,20 +363,10 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/arjundd/miniconda3/envs/meerkat_prod/lib/python3.8/site-packages/meerkat/ops/merge.py:151: FutureWarning: iteritems is deprecated and will be removed in a future version. Use .items instead.\n",
" for name, column in merged_df.iteritems():\n"
"OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.\n"
]
}
],
"source": [
"plot_df = df.merge(df_embed, on=\"song_id\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Compute umap of embeddings. This may take a few seconds.\n",
"from umap import UMAP\n",
Expand All @@ -340,22 +383,47 @@
"metadata": {},
"outputs": [],
"source": [
"plot_df = plot_df.mark()\n",
"plot_df = plot_df.mark()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe\n",
" width=\"100%\"\n",
" height=\"600px\"\n",
" src=\"http://localhost:8001/?id=flexd05c2cab-9b27-4b77-9035-2ec8a2e30193\"\n",
" frameborder=\"0\"\n",
" allowfullscreen\n",
" \n",
" ></iframe>\n",
" "
],
"text/plain": [
"<IPython.lib.display.IFrame at 0x29764b400>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot = mk.gui.plotly.ScatterPlot(df=plot_df, x=\"umap_1\", y=\"umap_2\",)\n",
"\n",
"# Because we're using the reactive decorator, the filter function will re-run whenever\n",
"# plot.selected changes. This will update the gallery to only show the selected points.\n",
"@mk.gui.reactive\n",
"def filter(selected: list, df: mk.DataFrame):\n",
" return df[df.primary_key.isin(selected)]\n",
"\n",
"filtered_df = filter(plot.selected, plot_df)\n",
"table = mk.gui.Table(filtered_df)\n",
"table = mk.gui.Table(filtered_df, classes=\"h-full\")\n",
"\n",
"mk.gui.html.flexcol(\n",
" [plot, table],\n",
" classes=\"h-[1200px]\",\n",
")"
"mk.gui.html.flex([plot, table], classes=\"h-[600px]\") "
]
},
{
Expand All @@ -382,7 +450,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.9.12"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
1 change: 1 addition & 0 deletions meerkat/columns/deferred/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ def load_audio(path: str) -> Audio:
"loader": load_audio,
"formatters": DeferredAudioFormatterGroup,
"exts": [".wav", ".mp3"],
"defer": False,
},
}

Expand Down
61 changes: 59 additions & 2 deletions meerkat/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,18 +677,75 @@ def from_huggingface(cls, *args, **kwargs):
"""
import datasets

datasets.logging.set_verbosity_error()
import pyarrow.compute as pc

# Load the dataset
dataset = datasets.load_dataset(*args, **kwargs)

def _convert_columns(dataset: datasets.Dataset):
df = cls.from_arrow(dataset._data.table)
for name, feature in dataset.features.items():
if isinstance(feature, datasets.Audio):

column = df[name]
bytes = ArrowScalarColumn(pc.struct_field(column._data, "bytes"))
path = ArrowScalarColumn(pc.struct_field(column._data, "path"))
if (~bytes.isnull()).all():
from meerkat.interactive.formatter import AudioFormatterGroup

df[name] = (
df[name]
.defer(lambda x: x["bytes"])
.format(AudioFormatterGroup())
)
elif (~path.isnull()).all():
from meerkat.columns.deferred.file import FileColumn

df[name] = FileColumn(path, type="audio")
else:
raise ValueError(
"Huggingface column must either provide bytes or path for "
"every row."
)
elif isinstance(feature, datasets.Image):
column = df[name]
bytes = ArrowScalarColumn(pc.struct_field(column._data, "bytes"))
path = ArrowScalarColumn(pc.struct_field(column._data, "path"))
if (~ArrowScalarColumn(bytes).isnull()).all():
import io

from PIL import Image

from meerkat.interactive.formatter import ImageFormatterGroup

df[name] = bytes.defer(
lambda x: Image.open(io.BytesIO(x))
).format(ImageFormatterGroup().defer())
elif (~path.isnull()).all():
from meerkat.columns.deferred.file import FileColumn

df[name] = FileColumn(path, type="image")
else:
raise ValueError(
"Huggingface column must either provide bytes or path for "
"every row."
)

return df

if isinstance(dataset, dict):
return dict(
map(
lambda t: (t[0], cls.from_arrow(t[1]._data.table)),
lambda t: (
t[0],
_convert_columns(t[1]),
),
dataset.items(),
)
)
else:
return cls.from_arrow(dataset._data)
return _convert_columns(dataset)

@classmethod
# @capture_provenance(capture_args=["filepath"])
Expand Down
9 changes: 8 additions & 1 deletion meerkat/interactive/formatter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
from .base import Formatter, deferred_formatter_group
from .boolean import BooleanFormatter, BooleanFormatterGroup
from .code import CodeFormatter, CodeFormatterGroup
from .image import ImageFormatter, ImageFormatterGroup
from .image import (
DeferredImageFormatter,
DeferredImageFormatterGroup,
ImageFormatter,
ImageFormatterGroup,
)
from .number import NumberFormatter, NumberFormatterGroup
from .pdf import PDFFormatter, PDFFormatterGroup
from .raw_html import HTMLFormatter, HTMLFormatterGroup
Expand All @@ -14,6 +19,8 @@
"deferred_formatter_group",
"ImageFormatter",
"ImageFormatterGroup",
"DeferredImageFormatter",
"DeferredImageFormatterGroup",
"TextFormatter",
"TextFormatterGroup",
"NumberFormatter",
Expand Down