-
Notifications
You must be signed in to change notification settings - Fork 98
/
Copy pathHumanML3D.py
127 lines (107 loc) · 5.35 KB
/
HumanML3D.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import numpy as np
import torch
import os
from os.path import join as pjoin
from .humanml.utils.word_vectorizer import WordVectorizer
from .humanml.scripts.motion_process import (process_file, recover_from_ric)
from . import BASEDataModule
from .humanml import Text2MotionDatasetEval, Text2MotionDataset, Text2MotionDatasetCB, MotionDataset, MotionDatasetVQ, Text2MotionDatasetToken, Text2MotionDatasetM2T
from .utils import humanml3d_collate
class HumanML3DDataModule(BASEDataModule):
    """Data module for the HumanML3D motion/text dataset.

    The constructor resolves dataset paths from ``cfg``, loads the
    normalization statistics (one set used for training features, one set
    used for evaluation), and selects which dataset classes to use based on
    ``cfg.TRAIN.STAGE``. The remaining methods convert between normalized
    features, raw features and joint positions.
    """

    def __init__(self, cfg, **kwargs):
        """Configure paths, statistics and dataset classes from ``cfg``.

        Args:
            cfg: Configuration object. This constructor reads
                ``cfg.DATASET``, ``cfg.TRAIN.STAGE``, ``cfg.DEBUG`` and (for
                the ``"vae"`` stage) ``cfg.model.params.motion_vae.target``.
                It also writes ``cfg.DATASET.JOINT_TYPE`` and
                ``cfg.DATASET.NFEATS``.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        super().__init__(collate_fn=humanml3d_collate)
        self.cfg = cfg
        # NOTE(review): presumably Lightning's save_hyperparameters, which is
        # what makes ``self.hparams`` available below - confirm in the base.
        self.save_hyperparameters(logger=False)

        # Basic info of the dataset
        cfg.DATASET.JOINT_TYPE = 'humanml3d'
        self.name = "humanml3d"
        self.njoints = 22

        # Path to the dataset
        data_root = cfg.DATASET.HUMANML3D.ROOT
        self.hparams.data_root = data_root
        self.hparams.text_dir = pjoin(data_root, "texts")
        self.hparams.motion_dir = pjoin(data_root, 'new_joint_vecs')

        # Mean and std of the dataset
        dis_data_root = pjoin(cfg.DATASET.HUMANML3D.MEAN_STD_PATH, 't2m',
                              "VQVAEV3_CB1024_CMT_H1024_NRES3", "meta")
        self.hparams.mean = np.load(pjoin(dis_data_root, "mean.npy"))
        self.hparams.std = np.load(pjoin(dis_data_root, "std.npy"))

        # Mean and std for fair evaluation
        dis_data_root_eval = pjoin(cfg.DATASET.HUMANML3D.MEAN_STD_PATH, 't2m',
                                   "Comp_v6_KLD01", "meta")
        self.hparams.mean_eval = np.load(pjoin(dis_data_root_eval, "mean.npy"))
        self.hparams.std_eval = np.load(pjoin(dis_data_root_eval, "std.npy"))

        # Length of the dataset
        self.hparams.max_motion_length = cfg.DATASET.HUMANML3D.MAX_MOTION_LEN
        self.hparams.min_motion_length = cfg.DATASET.HUMANML3D.MIN_MOTION_LEN
        self.hparams.max_text_len = cfg.DATASET.HUMANML3D.MAX_TEXT_LEN
        self.hparams.unit_length = cfg.DATASET.HUMANML3D.UNIT_LEN

        # Additional parameters
        self.hparams.debug = cfg.DEBUG
        self.hparams.stage = cfg.TRAIN.STAGE
        self.hparams.w_vectorizer = WordVectorizer(
            cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab")

        # Dataset switch: the evaluation dataset defaults to
        # Text2MotionDatasetEval and is only overridden by the "token" and
        # "m2t" stages.
        self.DatasetEval = Text2MotionDatasetEval

        if cfg.TRAIN.STAGE == "vae":
            # The last dotted component of the VAE target decides whether the
            # VQ variant of the motion dataset is used.
            vae_name = cfg.model.params.motion_vae.target.split('.')[-1]
            if vae_name.lower() == "vqvae":
                self.hparams.win_size = 64
                self.Dataset = MotionDatasetVQ
            else:
                self.Dataset = MotionDataset
        elif 'lm' in cfg.TRAIN.STAGE:
            self.hparams.code_path = cfg.DATASET.CODE_PATH
            self.hparams.task_path = cfg.DATASET.TASK_PATH
            self.hparams.std_text = cfg.DATASET.HUMANML3D.STD_TEXT
            self.Dataset = Text2MotionDatasetCB
        elif cfg.TRAIN.STAGE == "token":
            self.Dataset = Text2MotionDatasetToken
            self.DatasetEval = Text2MotionDatasetToken
        elif cfg.TRAIN.STAGE == "m2t":
            self.Dataset = Text2MotionDatasetM2T
            self.DatasetEval = Text2MotionDatasetM2T
        else:
            self.Dataset = Text2MotionDataset

        # Get additional info of the dataset
        self._sample_set = self.get_sample_set(
            overrides={"split": "test", "tiny": True})
        self.nfeats = self._sample_set.nfeats
        cfg.DATASET.NFEATS = self.nfeats

    def _stat_like(self, stat, features):
        """Return ``stat`` as a tensor converted with ``.to(features)``.

        ``Tensor.to(other)`` gives the result the dtype and device of
        ``other``, so the statistic can be combined with ``features``.
        """
        return torch.tensor(stat).to(features)

    def feats2joints(self, features):
        """De-normalize ``features`` and recover joints via ``recover_from_ric``.

        Args:
            features: Tensor of features normalized with the training
                mean/std.

        Returns:
            The result of ``recover_from_ric(features, self.njoints)``.
        """
        return recover_from_ric(self.denormalize(features), self.njoints)

    def joints2feats(self, features):
        """Convert joint data to features with ``process_file``.

        Loads ``joints/000021.npy`` under the dataset root on every call and
        passes it to ``process_file`` as the example data.

        Args:
            features: Joint data forwarded unchanged to ``process_file``.

        Returns:
            The first element of the ``process_file`` result.
        """
        example_path = os.path.join(self.hparams.data_root, 'joints',
                                    '000021.npy')
        example_data = np.load(example_path)
        # Reshape to (len(example_data), -1, 3): the trailing axis has size 3.
        example_data = example_data.reshape(len(example_data), -1, 3)
        example_data = torch.from_numpy(example_data)
        return process_file(features, self.njoints, example_data, 't2m')[0]

    def normalize(self, features):
        """Return ``(features - mean) / std`` using the training statistics."""
        mean = self._stat_like(self.hparams.mean, features)
        std = self._stat_like(self.hparams.std, features)
        return (features - mean) / std

    def denormalize(self, features):
        """Return ``features * std + mean`` using the training statistics."""
        mean = self._stat_like(self.hparams.mean, features)
        std = self._stat_like(self.hparams.std, features)
        return features * std + mean

    def renorm4t2m(self, features):
        """Re-normalize features from training statistics to eval statistics.

        The features are first de-normalized with the training mean/std and
        then normalized with ``mean_eval``/``std_eval`` so that the t2m
        evaluators can be used.
        """
        eval_mean = self._stat_like(self.hparams.mean_eval, features)
        eval_std = self._stat_like(self.hparams.std_eval, features)
        return (self.denormalize(features) - eval_mean) / eval_std

    def mm_mode(self, mm_on=True):
        """Switch the test dataset between the full and a sampled name list.

        When enabled, ``cfg.METRIC.MM_NUM_SAMPLES`` names are drawn without
        replacement from the full test name list and installed on
        ``self.test_dataset``; when disabled, the full list is restored.

        Args:
            mm_on: ``True`` to install a freshly sampled subset, ``False`` to
                restore the full list.

        Raises:
            ValueError: From ``np.random.choice`` if more samples are
                requested than there are names (sampling is without
                replacement).
        """
        if mm_on:
            # Snapshot the full list only when it is not already saved.
            # Enabling twice in a row would otherwise overwrite the snapshot
            # with the previously sampled subset and lose the full list.
            already_on = (getattr(self, "is_mm", False)
                          and hasattr(self, "name_list"))
            if not already_on:
                self.name_list = self.test_dataset.name_list
            self.is_mm = True
            self.mm_list = np.random.choice(self.name_list,
                                            self.cfg.METRIC.MM_NUM_SAMPLES,
                                            replace=False)
            self.test_dataset.name_list = self.mm_list
        else:
            self.is_mm = False
            # If the mode was never enabled there is no snapshot and the test
            # dataset still holds its full list, so there is nothing to do.
            if hasattr(self, "name_list"):
                self.test_dataset.name_list = self.name_list