Skip to content

Commit

Permalink
Feature: MJX convert functions (#8)
Browse files Browse the repository at this point in the history
* dump

* feat: tuned optax to solve logchol

* dump

* fix: remove WIP examples

* fix: bump version
  • Loading branch information
lvjonok authored Sep 15, 2024
1 parent d1000da commit 2c64beb
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 2 deletions.
133 changes: 133 additions & 0 deletions mujoco_sysid/mjx/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import jax.numpy as np


def theta2pseudo(theta: np.ndarray) -> np.ndarray:
m = theta[0]
h = theta[1:4]
I_xx, I_xy, I_yy, I_xz, I_yz, I_zz = theta[4:]

I_bar = np.array([[I_xx, I_xy, I_xz], [I_xy, I_yy, I_yz], [I_xz, I_yz, I_zz]])

Sigma = 0.5 * np.trace(I_bar) * np.eye(3) - I_bar

pseudo_inertia = np.zeros((4, 4))
pseudo_inertia = pseudo_inertia.at[:3, :3].set(Sigma)
pseudo_inertia = pseudo_inertia.at[:3, 3].set(h)
pseudo_inertia = pseudo_inertia.at[3, :3].set(h)
pseudo_inertia = pseudo_inertia.at[3, 3].set(m)

return pseudo_inertia


def pseudo2theta(pseudo_inertia: np.ndarray) -> np.ndarray:
m = pseudo_inertia[3, 3]
h = pseudo_inertia[:3, 3]
Sigma = pseudo_inertia[:3, :3]

I_bar = np.trace(Sigma) * np.eye(3) - Sigma

I_xx = I_bar[0, 0]
I_xy = I_bar[0, 1]
I_yy = I_bar[1, 1]
I_xz = I_bar[0, 2]
I_yz = I_bar[1, 2]
I_zz = I_bar[2, 2]

theta = np.array([m, h[0], h[1], h[2], I_xx, I_xy, I_yy, I_xz, I_yz, I_zz])

return theta


def logchol2chol(log_cholesky):
alpha, d1, d2, d3, s12, s23, s13, t1, t2, t3 = log_cholesky

exp_alpha = np.exp(alpha)
exp_d1 = np.exp(d1)
exp_d2 = np.exp(d2)
exp_d3 = np.exp(d3)

U = np.zeros((4, 4))
U = U.at[0, 0].set(exp_d1)
U = U.at[0, 1].set(s12)
U = U.at[0, 2].set(s13)
U = U.at[0, 3].set(t1)
U = U.at[1, 1].set(exp_d2)
U = U.at[1, 2].set(s23)
U = U.at[1, 3].set(t2)
U = U.at[2, 2].set(exp_d3)
U = U.at[2, 3].set(t3)
U = U.at[3, 3].set(1)

U *= exp_alpha

return U


def chol2logchol(U: np.ndarray) -> np.ndarray:
alpha = np.log(U[3, 3])
d1 = np.log(U[0, 0] / U[3, 3])
d2 = np.log(U[1, 1] / U[3, 3])
d3 = np.log(U[2, 2] / U[3, 3])
s12 = U[0, 1] / U[3, 3]
s23 = U[1, 2] / U[3, 3]
s13 = U[0, 2] / U[3, 3]
t1 = U[0, 3] / U[3, 3]
t2 = U[1, 3] / U[3, 3]
t3 = U[2, 3] / U[3, 3]
return np.array([alpha, d1, d2, d3, s12, s23, s13, t1, t2, t3])


def logchol2theta(log_cholesky: np.ndarray) -> np.ndarray:
alpha, d1, d2, d3, s12, s23, s13, t1, t2, t3 = log_cholesky

exp_d1 = np.exp(d1)
exp_d2 = np.exp(d2)
exp_d3 = np.exp(d3)

theta = np.array(
[
1,
t1,
t2,
t3,
s23**2 + t2**2 + t3**2 + exp_d2**2 + exp_d3**2,
-s12 * exp_d2 - s13 * s23 - t1 * t2,
s12**2 + s13**2 + t1**2 + t3**2 + exp_d1**2 + exp_d3**2,
-s13 * exp_d3 - t1 * t3,
-s23 * exp_d3 - t2 * t3,
s12**2 + s13**2 + s23**2 + t1**2 + t2**2 + exp_d1**2 + exp_d2**2,
]
)

exp_2_alpha = np.exp(2 * alpha)
theta *= exp_2_alpha

return theta


def pseudo2cholesky(pseudo_inertia: np.ndarray) -> np.ndarray:
n = pseudo_inertia.shape[0]
indices = np.arange(n - 1, -1, -1)

reversed_inertia = pseudo_inertia[indices][:, indices]

L_prime = np.linalg.cholesky(reversed_inertia)

U = L_prime[indices][:, indices]

return U


def cholesky2pseudo(U: np.ndarray) -> np.ndarray:
return U @ U.T


def pseudo2logchol(pseudo_inertia: np.ndarray) -> np.ndarray:
U = pseudo2cholesky(pseudo_inertia)
logchol = chol2logchol(U)
return logchol


def theta2logchol(theta: np.ndarray) -> np.ndarray:
pseudo_inertia = theta2pseudo(theta)
return pseudo2logchol(pseudo_inertia)
2 changes: 1 addition & 1 deletion mujoco_sysid/regressors.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def joint_body_regressor(mj_model, mj_data, body_id) -> npt.ArrayLike:
def get_jacobian(mjmodel, mjdata, bodyid):
R = mjdata.xmat[bodyid].reshape(3, 3)

jac_lin, jac_rot = np.zeros((3, 6)), np.zeros((3, 6))
jac_lin, jac_rot = np.zeros((3, mjmodel.nv)), np.zeros((3, mjmodel.nv))
mujoco.mj_jacBody(mjmodel, mjdata, jac_lin, jac_rot, bodyid)

return np.vstack([R.T @ jac_lin, R.T @ jac_rot])
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "mujoco_sysid"
description = "MuJoCo System Identification tools"
version = "0.2.0"
version = "0.2.1"
authors = [
{ name = "Lev Kozlov", email = "kozlov.l.a10@gmail.com" },
{ name = "Simeon Nedelchev", email = "simkaned@gmail.com" },
Expand Down
6 changes: 6 additions & 0 deletions tests/test_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np

from mujoco_sysid import regressors
from mujoco_sysid import parameters
from mujoco_sysid.utils import muj2pin

np.random.seed(0)
Expand Down Expand Up @@ -104,6 +105,8 @@ def test_joint_torque_regressor():

SAMPLES = 10000

theta = np.concatenate([parameters.get_dynamic_parameters(mjmodel, i) for i in mjmodel.jnt_bodyid])

for _ in range(SAMPLES):
q, v, dv = np.random.rand(pinmodel.nq), np.random.rand(pinmodel.nv), np.random.rand(pinmodel.nv)
pin.rnea(pinmodel, pindata, q, v, dv)
Expand All @@ -117,6 +120,9 @@ def test_joint_torque_regressor():
pinY = pin.computeJointTorqueRegressor(pinmodel, pindata, q, v, dv)
mjY = regressors.joint_torque_regressor(mjmodel, mjdata)

tau = pin.rnea(pinmodel, pindata, q, v, dv)

assert np.allclose(mjY @ theta, tau, atol=1e-6), f"Norm diff: {np.linalg.norm(mjY @ theta - tau)}"
assert np.allclose(mjY, pinY, atol=1e-6), f"Norm diff: {np.linalg.norm(mjY - pinY)}"


Expand Down

0 comments on commit 2c64beb

Please # to comment.