Skip to content

Commit

Permalink
v1.0.1 OakInkShape support cache
Browse files Browse the repository at this point in the history
  • Loading branch information
lixiny committed Feb 15, 2023
1 parent a57f99a commit 3a0ec46
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 23 deletions.
1 change: 1 addition & 0 deletions oikit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "1.0.1"
71 changes: 49 additions & 22 deletions oikit/oi_shape/oi_shape.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import hashlib
import os
import re

import json
import numpy as np
import torch
import trimesh
import pickle
from manotorch.manolayer import ManoLayer, MANOOutput
from oikit import __version__ as oikit_version
from oikit.common import suppress_trimesh_logging
from oikit.oi_shape.utils import (
ALL_CAT,
Expand All @@ -22,13 +24,12 @@

class OakInkShape:

def __init__(
self,
data_split=ALL_SPLIT,
intent_mode=list(ALL_INTENT),
category=ALL_CAT,
mano_assets_root="assets/mano_v1_2",
):
def __init__(self,
data_split=ALL_SPLIT,
intent_mode=list(ALL_INTENT),
category=ALL_CAT,
mano_assets_root="assets/mano_v1_2",
use_cache=True):
self.name = "OakInkShape"

assert 'OAKINK_DIR' in os.environ, "environment variable 'OAKINK_DIR' is not set"
Expand All @@ -45,11 +46,34 @@ def __init__(
self.intent_idx = [ALL_INTENT[i] for i in self.intent_mode]

self.mano_layer = ManoLayer(center_idx=0, mano_assets_root=mano_assets_root)

if use_cache is True:
cache_identifier_dict = {
"version": oikit_version,
"data_split": self.data_split,
"categories": self.categories,
"intent_mode": self.intent_mode
}
cache_identifier_raw = json.dumps(cache_identifier_dict, sort_keys=True)
cache_identifier = hashlib.md5(cache_identifier_raw.encode("ascii")).hexdigest()
cache_path = os.path.join(os.path.expanduser('~'), ".cache", self.name, oikit_version,
f"{cache_identifier}.pkl")
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
if os.path.exists(cache_path):
print(f"{self.name} loading cache from {cache_path}")
with open(cache_path, "rb") as p_f:
cache = pickle.load(p_f)
self.grasp_list = cache["grasp_list"]
self.obj_warehouse = cache["obj_warehouse"]
self.data_dir = data_dir
self.meta_dir = meta_dir
return

# * >>>> filter with regex
grasp_list = []
category_begin_idx = []
seq_cat_matcher = re.compile(r"(.+)/(.{6})_(.{4})_([_0-9]+)/([\-0-9]+)")
for cat in tqdm(self.categories):
for cat in tqdm(self.categories, desc="Process categories"):
real_matcher = re.compile(rf"({cat}/(.{{6}})/.{{10}})/hand_param\.pkl$")
virtual_matcher = re.compile(rf"({cat}/(.{{6}})/.{{10}})/(.{{6}})/hand_param\.pkl$")
path = os.path.join(oi_shape_dir, cat)
Expand Down Expand Up @@ -120,7 +144,7 @@ def __init__(
batch_hand_pose = []
batch_hand_shape = []
batch_hand_tsl = []
for g in tqdm(grasp_list):
for _, g in enumerate(grasp_list):
batch_hand_pose.append(g["hand_pose"])
batch_hand_shape.append(g["hand_shape"])
batch_hand_tsl.append(g["hand_tsl"])
Expand All @@ -139,7 +163,7 @@ def __init__(

# * >>>> handle handover
if "handover" in self.intent_mode:
for i, g in tqdm(enumerate(grasp_list)):
for i, g in tqdm(enumerate(grasp_list), total=len(grasp_list), desc="Process handover grasp"):
if g["subject_alt_id"] is None:
continue
for bidx in category_begin_idx:
Expand All @@ -166,30 +190,33 @@ def __init__(
suppress_trimesh_logging()
self.obj_warehouse = {}
obj_id_set = {g["obj_id"] for g in grasp_list}
for oid in tqdm(obj_id_set):
obj_trimesh = trimesh.load(get_obj_path(oid, data_dir, meta_dir),
for oid in tqdm(obj_id_set, desc="Load obj model"):
obj_trimesh = trimesh.load(get_obj_path(oid, data_dir, meta_dir, use_downsample=True),
process=False,
force="mesh",
skip_materials=True)
bbox_center = (obj_trimesh.vertices.min(0) + obj_trimesh.vertices.max(0)) / 2
obj_trimesh.vertices = obj_trimesh.vertices - bbox_center

obj_holder = {
"verts": np.asfarray(obj_trimesh.vertices, dtype=np.float32), # V, in object canonical space
"faces": obj_trimesh.faces.astype(np.int32), # F, paired with V
}
self.obj_warehouse[oid] = obj_holder
self.obj_warehouse[oid] = obj_trimesh

self.grasp_list = grasp_list
self.data_dir = data_dir
self.meta_dir = meta_dir

if use_cache is True:
cache = {"grasp_list": self.grasp_list, "obj_warehouse": self.obj_warehouse}
with open(cache_path, "wb") as f:
pickle.dump(cache, f)
print(f"{self.name} cache saved to {cache_path}")

def __len__(self):
return len(self.grasp_list)

def __getitem__(self, idx):
grasp = self.grasp_list[idx]
obj_holder = self.obj_warehouse[grasp["obj_id"]]
grasp["obj_verts"] = obj_holder["verts"]
grasp["obj_faces"] = obj_holder["faces"]
obj_mesh = self.obj_warehouse[grasp["obj_id"]]
grasp["obj_verts"] = obj_mesh.vertices.astype(np.float32)
grasp["obj_faces"] = obj_mesh.faces.astype(np.int32)
if grasp["action_id"] == "0004":
joints, verts, pose, shape, tsl = self.get_hand_over(idx)
grasp["alt_joints"] = joints
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_dep():

setup(
name="oikit",
version="1.0.0",
version="1.0.1",
author="Lixin Yang",
author_email="siriusyang@sjtu.edu.cn",
description="OakInk tooKIT",
Expand Down

0 comments on commit 3a0ec46

Please # to comment.