Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

LeRobotDataset v2.1 #711

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions examples/port_datasets/pusht_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import numpy as np
import torch
from huggingface_hub import HfApi

from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw

PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
Expand Down Expand Up @@ -134,8 +136,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
if mode not in ["video", "image", "keypoints"]:
raise ValueError(mode)

if (LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
if (HF_LEROBOT_HOME / repo_id).exists():
shutil.rmtree(HF_LEROBOT_HOME / repo_id)

if not raw_dir.exists():
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
Expand All @@ -148,6 +150,10 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
action = zarr_data["action"][:]
image = zarr_data["img"] # (b, h, w, c)

if image.dtype == np.float32 and image.max() == np.float32(255):
# HACK: images are loaded as float32 but they actually encode uint8 data
image = image.astype(np.uint8)

episode_data_index = {
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
"to": zarr_data.meta["episode_ends"],
Expand Down Expand Up @@ -180,6 +186,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
# Shift reward and success by +1 until the last item of the episode
"next.reward": reward[i + (frame_idx < num_frames - 1)],
"next.success": success[i + (frame_idx < num_frames - 1)],
"task": PUSHT_TASK,
}

frame["observation.state"] = torch.from_numpy(agent_pos[i])
Expand All @@ -191,12 +198,12 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T

dataset.add_frame(frame)

dataset.save_episode(task=PUSHT_TASK)

dataset.consolidate()
dataset.save_episode()

if push_to_hub:
dataset.push_to_hub()
hub_api = HfApi()
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")


if __name__ == "__main__":
Expand All @@ -218,5 +225,5 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
main(raw_dir, repo_id=repo_id, mode=mode)

# Uncomment if you want to load the local dataset and explore it
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
# dataset = LeRobotDataset(repo_id=repo_id)
# breakpoint()
15 changes: 15 additions & 0 deletions lerobot/common/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
# keys
import os
from pathlib import Path

from huggingface_hub.constants import HF_HOME

OBS_ENV = "observation.environment_state"
OBS_ROBOT = "observation.state"
OBS_IMAGE = "observation.image"
Expand All @@ -15,3 +20,13 @@
OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"

# cache dir
default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()

if "LEROBOT_HOME" in os.environ:
raise ValueError(
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
"'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
)
40 changes: 40 additions & 0 deletions lerobot/common/datasets/backward_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import packaging.version

V2_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format since v2.0 which is not backward compatible with v1.x.
Please, use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""

V21_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
```
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
```
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""


class BackwardCompatibilityError(Exception):
def __init__(self, repo_id: str, version: packaging.version.Version):
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
super().__init__(message)
Loading
Loading