From 0c228288dfe258475b6613a8991a5836704d0b57 Mon Sep 17 00:00:00 2001
From: Ale <alessandroassirell98@gmail.com>
Date: Mon, 8 Aug 2022 15:36:17 +0200
Subject: [PATCH] editing for whole body

---
 .../quadruped_reactive_walking/Controller.py  | 18 +++++------
 .../WB_MPC/ProblemData.py                     | 31 +++++++++----------
 .../tools/LoggerControl.py                    |  2 +-
 3 files changed, 24 insertions(+), 27 deletions(-)

diff --git a/python/quadruped_reactive_walking/Controller.py b/python/quadruped_reactive_walking/Controller.py
index 83c259b8..1b1b1315 100644
--- a/python/quadruped_reactive_walking/Controller.py
+++ b/python/quadruped_reactive_walking/Controller.py
@@ -248,7 +248,7 @@ class Controller:
                 self.save_guess()
 
             self.result.FF = self.params.Kff_main * np.ones(12)
-            self.result.tau_ff[3:6] = self.compute_torque(m)[:]
+            self.result.tau_ff = self.compute_torque(m)[:]
 
             # if self.params.interpolate_mpc:
             #     if self.mpc_result.new_result:
@@ -260,11 +260,11 @@ class Controller:
             #     q, v = self.interpolator.interpolate(t)
             # else:
             #     q, v = self.integrate_x(m)
-            q = xs[1][:3]
-            v = xs[1][3:]
+            q = xs[1][: self.pd.nq]
+            v = xs[1][self.pd.nq :]
 
-            self.result.q_des[3:6] = q[:]
-            self.result.v_des[3:6] = v[:]
+            self.result.q_des = q[:]
+            self.result.v_des = v[:]
 
             if self.axs is not None:
                 [
@@ -407,7 +407,7 @@ class Controller:
     def read_state(self, device):
         qj_m = device.joints.positions
         vj_m = device.joints.velocities
-        x_m = np.concatenate([qj_m[3:6], vj_m[3:6]])
+        x_m = np.concatenate([qj_m, vj_m])
         return {"qj_m": qj_m, "vj_m": vj_m, "x_m": x_m}
 
     def compute_torque(self, m):
@@ -424,9 +424,9 @@ class Controller:
         #         m["x_m"][self.pd.nq :] - self.mpc_result.xs[0][self.pd.nq :],
         #     ]
         # )
-        # x_diff = self.mpc_result.xs[0] - m["x_m"]
-        # tau = self.mpc_result.us[0] + np.dot(self.mpc_result.K[0], x_diff)
-        tau = self.mpc_result.us[0]
+        x_diff = self.mpc_result.xs[0] - m["x_m"]
+        tau = self.mpc_result.us[0] + np.dot(self.mpc_result.K[0], x_diff)
+        # tau = self.mpc_result.us[0]
         return tau
 
     def integrate_x(self, m):
diff --git a/python/quadruped_reactive_walking/WB_MPC/ProblemData.py b/python/quadruped_reactive_walking/WB_MPC/ProblemData.py
index 1ec400f9..60c2979b 100644
--- a/python/quadruped_reactive_walking/WB_MPC/ProblemData.py
+++ b/python/quadruped_reactive_walking/WB_MPC/ProblemData.py
@@ -23,7 +23,8 @@ class problemDataAbstract:
 
         self.frozen_names = frozen_names
         if frozen_names:
-            self.frozen_idxs = [self.model.getJointId(id) for id in frozen_names]
+            self.frozen_idxs = [self.model.getJointId(
+                id) for id in frozen_names]
             self.freeze()
 
         self.nq = self.model.nq
@@ -162,17 +163,7 @@ class ProblemData(problemDataAbstract):
 class ProblemDataFull(problemDataAbstract):
     def __init__(self, param):
         frozen_names = [
-            "root_joint",
-            "FL_HAA",
-            "FL_HFE",
-            "FL_KFE",
-            "HL_HAA",
-            "HL_HFE",
-            "HL_KFE",
-            "HR_HAA",
-            "HR_HFE",
-            "HR_KFE",
-        ]
+            "root_joint"]
 
         super().__init__(param, frozen_names)
 
@@ -185,11 +176,17 @@ class ProblemDataFull(problemDataAbstract):
         # self.friction_cone_w = 1e3 * 0
         self.control_bound_w = 1e3
         self.control_reg_w = 1e0
-        self.state_reg_w = np.array([1e-5] * 3 + [1e0] * 3)
-        self.terminal_velocity_w = np.array([0] * 3 + [1e3] * 3)
-
-        self.q0_reduced = self.q0[10:13]
-        self.v0_reduced = np.zeros(self.nq)
+        self.state_reg_w = np.array([1e0] * 3
+                                    + [1e-5] * 3
+                                    + [1e0] * 6
+                                    + [1e1] * 3
+                                    + [1e0] * 3
+                                    + [1e1] * 6
+                                    )
+        self.terminal_velocity_w = np.array([0] * 12 + [1e3] * 12)
+
+        self.q0_reduced = self.q0[7:]
+        self.v0_reduced = np.zeros(self.nv)
         self.x0_reduced = np.concatenate([self.q0_reduced, self.v0_reduced])
 
         self.xref = self.x0_reduced
diff --git a/python/quadruped_reactive_walking/tools/LoggerControl.py b/python/quadruped_reactive_walking/tools/LoggerControl.py
index ce900e6c..06a6cc21 100644
--- a/python/quadruped_reactive_walking/tools/LoggerControl.py
+++ b/python/quadruped_reactive_walking/tools/LoggerControl.py
@@ -209,7 +209,7 @@ class LoggerControl:
     def plot_target(self, save=False, fileName="/tmp"):
         import matplotlib.pyplot as plt
 
-        x_mes = np.concatenate([self.q_mes[:, 3:6], self.v_mes[:, 3:6]], axis=1)
+        x_mes = np.concatenate([self.q_mes, self.v_mes], axis=1)
 
         horizon = int(self.ocp_xs.shape[0] / self.pd.mpc_wbc_ratio)
         t_scale = np.linspace(
-- 
GitLab