From 038302c0a8488d3fe6747f4c5b5659a7a96a3a69 Mon Sep 17 00:00:00 2001
From: Fanny Risbourg <frisbourg@laas.fr>
Date: Wed, 3 Aug 2022 17:09:42 +0200
Subject: [PATCH] clean times

---
 config/walk_parameters.yaml                   |   7 +-
 include/qrw/Params.hpp                        |   3 +-
 python/Params.cpp                             |   1 +
 .../quadruped_reactive_walking/Controller.py  | 210 ++++++++----------
 .../WB_MPC/CrocoddylOCP.py                    |   6 +-
 .../WB_MPC/ProblemData.py                     | 160 +++++++++----
 .../WB_MPC_Wrapper.py                         |  25 +--
 .../tools/LoggerControl.py                    |  44 ++--
 src/Params.cpp                                |   4 +
 9 files changed, 253 insertions(+), 207 deletions(-)

diff --git a/config/walk_parameters.yaml b/config/walk_parameters.yaml
index 53bd57e3..5a67c1fd 100644
--- a/config/walk_parameters.yaml
+++ b/config/walk_parameters.yaml
@@ -7,13 +7,13 @@ robot:
     PLOTTING: true  # Enable/disable automatic plotting at the end of the experiment
     DEMONSTRATION: false  # Enable/disable demonstration functionalities
     SIMULATION: true  # Enable/disable PyBullet simulation or running on real robot
-    enable_pyb_GUI: true  # Enable/disable PyBullet GUI
+    enable_pyb_GUI: false  # Enable/disable PyBullet GUI
     envID: 0  # Identifier of the environment to choose in which one the simulation will happen
     use_flat_plane: true  # If True the ground is flat, otherwise it has bumps
     predefined_vel: true  # If we are using a predefined reference velocity (True) or a joystick (False)
     N_SIMULATION: 5000  # Number of simulated wbc time steps
     enable_corba_viewer: false  # Enable/disable Corba Viewer
-    enable_multiprocessing: true  # Enable/disable running the MPC in another process in parallel of the main loop
+    enable_multiprocessing: false  # Enable/disable running the MPC in another process in parallel of the main loop
     perfect_estimator: true  # Enable/disable perfect estimator by using data directly from PyBullet
 
     # General control parameters
@@ -22,8 +22,9 @@ robot:
     # q_init: [0.0, 0.764, -1.407, 0.0, 0.76407, -1.4, 0.0, 0.76407, -1.407, 0.0, 0.764, -1.407]  # h_com = 0.218
     q_init: [0.0, 0.7, -1.4, 0.0, 0.7, -1.4, 0.0, -0.7, 1.4, 0.0, -0.7, 1.4]  # Initial articular positions
     dt_wbc: 0.001  # Time step of the whole body control
-    dt_mpc: 0.01  # Time step of the model predictive control
+    dt_mpc: 0.001  # Time step of the model predictive control
     type_MPC: 3  # Which MPC solver you want to use: 0 for OSQP MPC, 1, 2, 3 for Crocoddyl MPCs
+    save_guess: false # true to interpolate the impedance quantities between nodes of the MPC
     interpolate_mpc: true # true to interpolate the impedance quantities between nodes of the MPC
     interpolation_type: 0 # 0,1,2,3 decide which kind of interpolation is used
 #     Kp_main: [0.0, 0.0, 0.0]  # Proportional gains for the PD+
diff --git a/include/qrw/Params.hpp b/include/qrw/Params.hpp
index f594cfcc..f0366016 100644
--- a/include/qrw/Params.hpp
+++ b/include/qrw/Params.hpp
@@ -92,8 +92,9 @@ class Params {
   double dt_mpc;                // Time step of the model predictive control
   int N_periods;                // Number of gait periods in the MPC prediction horizon
   int type_MPC;                 // Which MPC solver you want to use: 0 for OSQP MPC, 1, 2, 3 for Crocoddyl MPCs
+  bool save_guess;              // true to save the initial result of the mpc
   bool interpolate_mpc;         // true to interpolate the impedance quantities, otherwise integrate
-  int interpolation_type;         // true to interpolate the impedance quantities, otherwise integrate
+  int interpolation_type;       // type of interpolation used
   bool kf_enabled;              // Use complementary filter (False) or kalman filter (True) for the estimator
   std::vector<double> Kp_main;  // Proportional gains for the PD+
   std::vector<double> Kd_main;  // Derivative gains for the PD+
diff --git a/python/Params.cpp b/python/Params.cpp
index 33b0327a..0a7b5cd8 100644
--- a/python/Params.cpp
+++ b/python/Params.cpp
@@ -26,6 +26,7 @@ struct ParamsVisitor : public bp::def_visitor<ParamsVisitor<Params>> {
         .def_readwrite("type_MPC", &Params::type_MPC)
         .def_readwrite("use_flat_plane", &Params::use_flat_plane)
         .def_readwrite("predefined_vel", &Params::predefined_vel)
+        .def_readwrite("save_guess", &Params::save_guess)
         .def_readwrite("interpolate_mpc", &Params::interpolate_mpc)
         .def_readwrite("interpolation_type", &Params::interpolation_type)
         .def_readwrite("kf_enabled", &Params::kf_enabled)
diff --git a/python/quadruped_reactive_walking/Controller.py b/python/quadruped_reactive_walking/Controller.py
index fe705107..6fe861d9 100644
--- a/python/quadruped_reactive_walking/Controller.py
+++ b/python/quadruped_reactive_walking/Controller.py
@@ -23,81 +23,75 @@ class Result:
         self.tau_ff = np.zeros(12)
 
 
-class Interpolation:
+class Interpolator:
     def __init__(self, params):
-        self.params = params
-
-    def load_data(self, q, v):
-        self.v0 = v[0, :]
-        self.q0 = q[0, :]
-        self.v1 = v[1, :]
-        self.q1 = q[1, :]
+        self.dt = params.dt_mpc
+        self.type = params.interpolation_type
+
+        self.q0 = np.zeros(3)
+        self.q1 = np.zeros(3)
+        self.v0 = np.zeros(3)
+        self.v1 = np.zeros(3)
+        self.alpha = np.zeros(3)
+        self.beta = np.zeros(3)
+        self.gamma = np.zeros(3)
+        self.delta = 1.0
+
+    def update(self, x0, x1):
+        self.q0 = x0[:3]
+        self.q1 = x1[:3]
+        self.v0 = x0[3:]
+        self.v1 = x1[3:]
+        if self.type == 0:  # Linear
+            self.alpha = 0.0
+            self.beta = self.v1
+            self.gamma = self.q0
+        elif self.type == 1:  # Quadratic fixed velocity
+            self.alpha = 2 * (self.q1 - self.q0 - self.v0 * self.dt) / self.dt**2
+            self.beta = self.v0
+            self.gamma = self.q0
+        elif self.type == 2:  # Quadratic time variable
+            for i in range(3):
+                q0 = self.q0[i]
+                v0 = self.v0[i]
+                q1 = self.q1[i]
+                v1 = self.v1[i]
+                if (q1 == q0) or (v1 == -v0):
+                    self.alpha[i] = 0.0
+                    self.beta[i] = 0.0
+                    self.gamma[i] = q1
+                    self.delta = 1.0
+                else:
+                    self.alpha[i] = (v1**2 - v0**2) / (2 * (q1 - q0))
+                    self.beta[i] = v0
+                    self.gamma[i] = q0
+                    self.delta = 2 * (q1 - q0) / (v1 + v0) / self.dt
 
     def interpolate(self, t):
-        # Linear
-        if self.params.interpolation_type == 0:
-            beta = self.v1
-            gamma = self.q0
-
-            v_t = beta
-            q_t = gamma + beta * t
-
-        # Linear Wrong
-        if self.params.interpolation_type == 1:
-            beta = self.v1
-            gamma = self.q0
-            alpha = 1/2 * self.v1 * (self.v1 - self.v0) / (self.q1 - self.q0)
-
-            v_t = beta
-            q_t = gamma + beta * t
-
-        # Perfect match, but wrong
-        if self.params.interpolation_type == 2:
-            if (self.q1 - self.q0 == 0).any():
-                alpha = np.zeros(len(self.q0))
-            else:
-                alpha = (self.v1**2 - self.v0**2) / (self.q1 - self.q0)
-
-            beta = self.v0
-            gamma = self.q0
-
-            v_t = beta + alpha * t
-            q_t = gamma + beta * t + alpha * t**2
-
-        # Quadratic
-        if self.params.interpolation_type == 3:
-            if (self.q1 - self.q0 == 0).any():
-                alpha = np.zeros(len(self.q0))
-            else:
-                alpha = self.v1 * (self.v1 - self.v0) / (self.q1 - self.q0)
-
-            beta = self.v0
-            gamma = self.q0
+        if self.type == 2:
+            t *= self.delta
+        q = 1 / 2 * self.alpha * t**2 + self.beta * t + self.gamma
+        v = self.v1 if self.type == 2 else self.alpha * t + self.beta
 
-            v_t = beta + alpha * t
-            q_t = gamma + beta * t + 1 / 2 * alpha * t**2
-
-        return q_t, v_t
+        return q, v
 
-    def plot_interpolation(self, n, dt):
+    def plot(self, n, dt):
         import matplotlib.pyplot as plt
 
+        t = np.linspace(0.0, 2 * self.dt, 2 * n + 1)
         plt.style.use("seaborn")
-        t = np.linspace(0, n * dt, n + 1)
-        q_t = np.array([self.interpolate((i) * dt)[0] for i in range(n + 1)])
-        v_t = np.array([self.interpolate((i) * dt)[1] for i in range(n + 1)])
         for i in range(3):
             plt.subplot(3, 2, (i * 2) + 1)
             plt.title("Position interpolation")
-            plt.plot(t, q_t[:, i])
-            plt.scatter(y=self.q0[i], x=t[0], color="violet", marker="+")
-            plt.scatter(y=self.q1[i], x=t[-1], color="violet", marker="+")
+            plt.plot(t, [self.compute_q(t * dt / n)[i] for t in range(n + 1)])
+            plt.scatter(y=self.q0[i], x=0.0, color="violet", marker="+")
+            plt.scatter(y=self.q1[i], x=self.dt, color="violet", marker="+")
 
             plt.subplot(3, 2, (i * 2) + 2)
             plt.title("Velocity interpolation")
-            plt.plot(t, v_t[:, i])
-            plt.scatter(y=self.v0[i], x=t[0], color="violet", marker="+")
-            plt.scatter(y=self.v1[i], x=t[-1], color="violet", marker="+")
+            plt.plot(t, [self.compute_v(t * dt / n)[i] for t in range(n + 1)])
+            plt.scatter(y=self.v0[i], x=0.0, color="violet", marker="+")
+            plt.scatter(y=self.v1[i], x=self.dt, color="violet", marker="+")
 
         plt.show()
 
@@ -127,8 +121,8 @@ class DummyDevice:
 
 class Controller:
     def __init__(self, pd, target, params, q_init, t):
-        """Function that runs a simulation scenario based on a reference velocity profile, an environment and
-        various parameters to define the gait
+        """
+        Function that computes the reference control (tau, q_des, v_des and gains)
 
         Args:
             params (Params object): store parameters
@@ -137,23 +131,25 @@ class Controller:
         """
         self.q_security = np.array([1.2, 2.1, 3.14] * 4)
 
-        self.mpc = WB_MPC_Wrapper.MPC_Wrapper(pd, target, params)
         self.pd = pd
         self.target = target
-        self.point_target = []
         self.params = params
         self.q_init = pd.q0
 
         self.k = 0
-        self.cnt_mpc = 0
-        self.cnt_wbc = 0
         self.error = False
         self.initialized = False
-        self.first_step = False
-        self.interpolator = Interpolation(params)
+
         self.result = Result(params)
         self.result.q_des = self.pd.q0[7:].copy()
         self.result.v_des = self.pd.v0[6:].copy()
+
+        self.mpc = WB_MPC_Wrapper.MPC_Wrapper(pd, target, params)
+        self.mpc_solved = False
+        self.k_result = 0
+        self.k_solve = 0
+        if self.params.interpolate_mpc:
+            self.interpolator = Interpolator(params)
         try:
             file = np.load("/tmp/init_guess.npy", allow_pickle=True).item()
             self.xs_init = list(file["xs"])
@@ -178,21 +174,18 @@ class Controller:
         t_start = time.time()
 
         m = self.read_state(device)
-
         t_measures = time.time()
         self.t_measures = t_measures - t_start
 
-        self.point_target = self.target.evaluate_in_t(1)[self.pd.rfFootId]
-        if self.k % int(self.params.dt_mpc / self.params.dt_wbc) == 0:
-            try:
-                self.target.update(self.cnt_mpc)
-                self.target.shift_gait()
-                if not self.params.enable_multiprocessing:
-                    self.cnt_wbc = 0
+        if self.k % self.pd.mpc_wbc_ratio == 0:
+            if self.mpc_solved:
+                self.k_solve = self.k
+                self.mpc_solved = False
 
+            self.target.update(self.k // self.pd.mpc_wbc_ratio)
+            self.target.shift_gait()
+            try:
                 self.mpc.solve(self.k, m["x_m"], self.xs_init, self.us_init)
-
-                self.cnt_mpc += 1
             except ValueError:
                 self.error = True
                 print("MPC Problem")
@@ -202,41 +195,24 @@ class Controller:
 
         if not self.error:
             self.mpc_result = self.mpc.get_latest_result()
-            if self.params.enable_multiprocessing:
-                if self.mpc_result.new_result:
-                    print("new result! at iter: ", str(self.cnt_wbc))
-                    self.cnt_wbc = 0
-
-            print(
-                "MPC iter: ",
-                self.cnt_mpc,
-                " / Counter value: ",
-                self.cnt_wbc,
-                " / k value: ",
-                self.k,
-            )
-            # ## ONLY IF YOU WANT TO STORE THE FIRST SOLUTION TO WARM-START THE INITIAL Problem ###
-            # if not self.initialized:
-            #   np.save(open('/tmp/init_guess.npy', "wb"), {"xs": self.mpc_result.xs, "us": self.mpc_result.us} )
-            #   print("Initial guess saved")
+            if self.mpc_result.new_result:
+                self.mpc_solved = True
+                self.k_new = self.k
+                print(f"MPC solved in {self.k - self.k_solve} iterations")
 
-            # Keep only the actuated joints and set the other to default values
-            self.result.FF = self.params.Kff_main * np.ones(12)
+            if not self.initialized and self.params.save_guess:
+                self.save_guess()
 
-            actuated_tau_ff = self.compute_torque(m)
-            self.result.tau_ff = np.array([0] * 3 + list(actuated_tau_ff) + [0] * 6)
+            self.result.FF = self.params.Kff_main * np.ones(12)
+            self.result.tau_ff[3:6] = self.compute_torque(m)[:]
 
             if self.params.interpolate_mpc:
-                # load the data to be interpolated only once per mpc solution
-                if self.cnt_wbc == 0:
-                    x = np.array(self.mpc_result.xs)
-                    self.interpolator.load_data(x[:, : self.pd.nq], x[:, self.pd.nq :])
-
-                q, v = self.interpolator.interpolate(
-                    (self.k % int(self.params.dt_mpc / self.params.dt_wbc) + self.cnt_wbc + 1) * self.pd.dt_wbc
-                )
+                if self.mpc_result.new_result:
+                    self.interpolator.update(self.mpc_result.xs[0], self.mpc_result.xs[1])
+                    # self.interpolator.plot(self.pd.mpc_wbc_ratio, self.pd.dt_wbc)
 
-                # self.interpolator.plot_interpolation(self.pd.r1, self.pd.dt_wbc)
+                t = (self.k - self.k_solve + 1) * self.pd.dt_wbc
+                q, v = self.interpolator.interpolate(t)
             else:
                 q, v = self.integrate_x(m)
 
@@ -259,7 +235,6 @@ class Controller:
 
         self.t_loop = time.time() - t_start
         self.k += 1
-        self.cnt_wbc += 1
         self.initialized = True
 
         return self.error
@@ -362,11 +337,20 @@ class Controller:
         self.result.v_des[:] = np.zeros(12)
         self.result.tau_ff[:] = np.zeros(12)
 
+    def save_guess(self):
+        """
+        Save the result of the MPC in a file called /tmp/init_guess.npy
+        """
+        np.save(
+            open("/tmp/init_guess.npy", "wb"),
+            {"xs": self.mpc_result.xs, "us": self.mpc_result.us},
+        )
+        print("Initial guess saved")
+
     def read_state(self, device):
         qj_m = device.joints.positions
         vj_m = device.joints.velocities
         x_m = np.concatenate([qj_m[3:6], vj_m[3:6]])
-
         return {"qj_m": qj_m, "vj_m": vj_m, "x_m": x_m}
 
     def compute_torque(self, m):
@@ -402,7 +386,3 @@ class Controller:
         q = q0 + v * self.params.dt_wbc
 
         return q, v
-
-    def tuple_to_array(self, tup):
-        a = np.array([element for tupl in tup for element in tupl])
-        return a
diff --git a/python/quadruped_reactive_walking/WB_MPC/CrocoddylOCP.py b/python/quadruped_reactive_walking/WB_MPC/CrocoddylOCP.py
index 07eccab1..b8c054ca 100644
--- a/python/quadruped_reactive_walking/WB_MPC/CrocoddylOCP.py
+++ b/python/quadruped_reactive_walking/WB_MPC/CrocoddylOCP.py
@@ -12,7 +12,7 @@ class OCP:
     def __init__(self, pd: ProblemData, target: Target):
         self.pd = pd
         self.target = target
-        self.max_iter = 10
+        self.max_iter = 1
 
         self.state = crocoddyl.StateMultibody(self.pd.model)
         self.initialized = False
@@ -107,9 +107,7 @@ class OCP:
 
         t_warm_start = time()
         self.t_warm_start = t_warm_start - t_update
-        print("CROCODDYL START")
-        print("CROCODDYL INITIAL POINT: ", x0)
-        self.ddp.setCallbacks([crocoddyl.CallbackVerbose()])
+        # self.ddp.setCallbacks([crocoddyl.CallbackVerbose()])
         self.ddp.solve(xs, us, self.max_iter, False)
 
         t_ddp = time()
diff --git a/python/quadruped_reactive_walking/WB_MPC/ProblemData.py b/python/quadruped_reactive_walking/WB_MPC/ProblemData.py
index 758e8b6b..1ec400f9 100644
--- a/python/quadruped_reactive_walking/WB_MPC/ProblemData.py
+++ b/python/quadruped_reactive_walking/WB_MPC/ProblemData.py
@@ -2,14 +2,15 @@ import numpy as np
 import example_robot_data as erd
 import pinocchio as pin
 
+
 class problemDataAbstract:
-    def __init__(self, param, frozen_names = []):
-        self.dt = param.dt_mpc # OCP dt
+    def __init__(self, param, frozen_names=[]):
+        self.dt = param.dt_mpc  # OCP dt
         self.dt_wbc = param.dt_wbc
-        self.r1 = int(self.dt/self.dt_wbc)
+        self.mpc_wbc_ratio = int(self.dt / self.dt_wbc)
         self.init_steps = 0
-        self.target_steps =  150
-        self.T = self.init_steps + self.target_steps -1
+        self.target_steps = 150
+        self.T = self.init_steps + self.target_steps - 1
 
         self.robot = erd.load("solo12")
         self.q0 = self.robot.q0
@@ -28,11 +29,13 @@ class problemDataAbstract:
         self.nq = self.model.nq
         self.nv = self.model.nv
         self.nx = self.nq + self.nv
-        self.ndx = 2*self.nv
-        self.nu = 12 - len(frozen_names)  + 1 if len(frozen_names) != 0 else 12 # -1 to take into account the freeflyer
+        self.ndx = 2 * self.nv
+        self.nu = (
+            12 - len(frozen_names) + 1 if len(frozen_names) != 0 else 12
+        )  # -1 to take into account the freeflyer
         self.ntau = self.nv
 
-        self.effort_limit = np.ones(self.nu) *3   
+        self.effort_limit = np.ones(self.nu) * 3
 
         self.v0 = np.zeros(18)
         self.x0 = np.concatenate([self.q0, self.v0])
@@ -40,10 +43,15 @@ class problemDataAbstract:
 
         self.xref = self.x0
         self.uref = self.u0
-                 
-        self.lfFoot, self.rfFoot, self.lhFoot, self.rhFoot = 'FL_FOOT', 'FR_FOOT', 'HL_FOOT', 'HR_FOOT'
+
+        self.lfFoot, self.rfFoot, self.lhFoot, self.rhFoot = (
+            "FL_FOOT",
+            "FR_FOOT",
+            "HL_FOOT",
+            "HR_FOOT",
+        )
         self.cnames = [self.lfFoot, self.rfFoot, self.lhFoot, self.rhFoot]
-        self.allContactIds = [ self.model.getFrameId(f) for f in self.cnames]
+        self.allContactIds = [self.model.getFrameId(f) for f in self.cnames]
         self.lfFootId = self.model.getFrameId(self.lfFoot)
         self.rfFootId = self.model.getFrameId(self.rfFoot)
         self.lhFootId = self.model.getFrameId(self.lhFoot)
@@ -54,18 +62,20 @@ class problemDataAbstract:
     def freeze(self):
         geom_models = [self.visual_model, self.collision_model]
         self.model, geometric_models_reduced = pin.buildReducedModel(
-                                                self.model,
-                                                list_of_geom_models=geom_models,
-                                                list_of_joints_to_lock=self.frozen_idxs,
-                                                reference_configuration=self.q0) 
+            self.model,
+            list_of_geom_models=geom_models,
+            list_of_joints_to_lock=self.frozen_idxs,
+            reference_configuration=self.q0,
+        )
         self.rdata = self.model.createData()
         self.visual_model = geometric_models_reduced[0]
         self.collision_model = geometric_models_reduced[1]
 
+
 class ProblemData(problemDataAbstract):
     def __init__(self, param):
         super().__init__(param)
-        
+
         self.useFixedBase = 0
         # Cost function weights
         self.mu = 0.7
@@ -73,58 +83,114 @@ class ProblemData(problemDataAbstract):
         self.friction_cone_w = 1e3
         self.control_bound_w = 1e3
         self.control_reg_w = 1e0
-        self.state_reg_w = np.array([0] * 3 \
-                            + [1e1] * 3 \
-                            + [1e0] * 3 \
-                            + [1e-3] * 3\
-                            + [1e0] * 6
-                            + [0] * 6 \
-                            + [1e1] * 3 \
-                            + [3*1e-1] * 3\
-                            + [1e1] * 6 ) 
-        self.terminal_velocity_w = np.array([0] * 18 + [1e3] * 18 )
+        self.state_reg_w = np.array(
+            [0] * 3
+            + [1e1] * 3
+            + [1e0] * 3
+            + [1e-3] * 3
+            + [1e0] * 6
+            + [0] * 6
+            + [1e1] * 3
+            + [3 * 1e-1] * 3
+            + [1e1] * 6
+        )
+        self.terminal_velocity_w = np.array([0] * 18 + [1e3] * 18)
         self.control_bound_w = 1e3
 
-        self.x0 = np.array([ 0.0, 0.0, 0.2607495, 0, 0, 0, 1,
-                             0,  0.7, -1.4,  
-                             0. ,  0.7, -1.4,  
-                             0. , -0.7,  1.4,  
-                             0. , -0.7, 1.4,
-                             0, 0, 0,  0,  0,  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) # x0 got from PyBullet
-                            
-        self.u0 = np.array([-0.02615051, -0.25848605,  0.51696646,  
-                            0.0285894 , -0.25720605, 0.51441775, 
-                            -0.02614404, 0.25848271, -0.51697107,  
-                            0.02859587, 0.25720939, -0.51441314]) # quasi static control
+        self.x0 = np.array(
+            [
+                0.0,
+                0.0,
+                0.2607495,
+                0,
+                0,
+                0,
+                1,
+                0,
+                0.7,
+                -1.4,
+                0.0,
+                0.7,
+                -1.4,
+                0.0,
+                -0.7,
+                1.4,
+                0.0,
+                -0.7,
+                1.4,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+                0,
+            ]
+        )  # x0 got from PyBullet
+
+        self.u0 = np.array(
+            [
+                -0.02615051,
+                -0.25848605,
+                0.51696646,
+                0.0285894,
+                -0.25720605,
+                0.51441775,
+                -0.02614404,
+                0.25848271,
+                -0.51697107,
+                0.02859587,
+                0.25720939,
+                -0.51441314,
+            ]
+        )  # quasi static control
         self.xref = self.x0
         self.uref = self.u0
 
 
 class ProblemDataFull(problemDataAbstract):
     def __init__(self, param):
-        frozen_names = ["root_joint", "FL_HAA", "FL_HFE", "FL_KFE",
-                        "HL_HAA", "HL_HFE", "HL_KFE",
-                        "HR_HAA", "HR_HFE", "HR_KFE" ]
-
+        frozen_names = [
+            "root_joint",
+            "FL_HAA",
+            "FL_HFE",
+            "FL_KFE",
+            "HL_HAA",
+            "HL_HFE",
+            "HL_KFE",
+            "HR_HAA",
+            "HR_HFE",
+            "HR_KFE",
+        ]
 
         super().__init__(param, frozen_names)
-        
+
         self.useFixedBase = 1
 
         # Cost function weights
         # Cost function weights
         self.mu = 0.7
         self.foot_tracking_w = 1e3
-        #self.friction_cone_w = 1e3 * 0
+        # self.friction_cone_w = 1e3 * 0
         self.control_bound_w = 1e3
         self.control_reg_w = 1e0
-        self.state_reg_w = np.array([1e-5] * 3 + [1e0]*3)
-        self.terminal_velocity_w = np.array([0] * 3 + [1e3] * 3 )
+        self.state_reg_w = np.array([1e-5] * 3 + [1e0] * 3)
+        self.terminal_velocity_w = np.array([0] * 3 + [1e3] * 3)
 
-        self.q0_reduced = self.q0[10 : 13]
+        self.q0_reduced = self.q0[10:13]
         self.v0_reduced = np.zeros(self.nq)
         self.x0_reduced = np.concatenate([self.q0_reduced, self.v0_reduced])
 
         self.xref = self.x0_reduced
         self.uref = self.u0
-    
\ No newline at end of file
diff --git a/python/quadruped_reactive_walking/WB_MPC_Wrapper.py b/python/quadruped_reactive_walking/WB_MPC_Wrapper.py
index 874e55c0..3fd35a7f 100644
--- a/python/quadruped_reactive_walking/WB_MPC_Wrapper.py
+++ b/python/quadruped_reactive_walking/WB_MPC_Wrapper.py
@@ -18,7 +18,6 @@ class Result:
         self.K = list(np.zeros([pd.T, pd.nu, pd.nx]))
         self.solving_duration = 0.0
         self.new_result = False
-        
 
 
 class MPC_Wrapper:
@@ -28,7 +27,6 @@ class MPC_Wrapper:
     """
 
     def __init__(self, pd, target, params):
-        self.initialized = False
         self.params = params
         self.pd = pd
         self.target = target
@@ -37,7 +35,6 @@ class MPC_Wrapper:
 
         if self.multiprocessing:
             self.new_data = Value("b", False)
-            self.new_result = Value("b", False)
             self.running = Value("b", True)
             self.in_k = Value("i", 0)
             self.in_x0 = Array("d", [0] * pd.nx)
@@ -52,6 +49,7 @@ class MPC_Wrapper:
             self.ocp = OCP(pd, target)
 
         self.last_available_result = Result(pd)
+        self.new_result = Value("b", False)
 
     def solve(self, k, x0, xs=None, us=None):
         """
@@ -73,23 +71,20 @@ class MPC_Wrapper:
         If a new result is available, return the new result.
         Otherwise return the old result again.
         """
-        if self.initialized:
-            if self.multiprocessing and self.new_result.value:
-                self.new_result.value = False
+        if self.new_result.value:
+            if self.multiprocessing:
                 (
                     self.last_available_result.xs,
                     self.last_available_result.us,
                     self.last_available_result.K,
                     self.last_available_result.solving_duration,
                 ) = self.decompress_dataOut()
-                self.last_available_result.new_result = True
-                
 
-            elif self.multiprocessing and not self.new_result.value:
-                self.last_available_result.new_result = False
-                
+            self.last_available_result.new_result = True
+            self.new_result.value = False
         else:
-            self.initialized = True
+            self.last_available_result.new_result = False
+
         return self.last_available_result
 
     def run_MPC_synchronous(self, x0, xs, us):
@@ -101,8 +96,9 @@ class MPC_Wrapper:
             self.last_available_result.xs,
             self.last_available_result.us,
             self.last_available_result.K,
-            self.last_available_result.solving_duration
+            self.last_available_result.solving_duration,
         ) = self.ocp.get_results()
+        self.new_result.value = True
 
     def run_MPC_asynchronous(self, k, x0, xs, us):
         """
@@ -110,7 +106,7 @@ class MPC_Wrapper:
         """
         print("Call to solve")
         if k == 0:
-            self.last_available_result.xs = [x0 for _ in range (self.pd.T + 1)]
+            self.last_available_result.xs = [x0 for _ in range(self.pd.T + 1)]
             p = Process(target=self.MPC_asynchronous)
             p.start()
         self.add_new_data(k, x0, xs, us)
@@ -212,7 +208,6 @@ class MPC_Wrapper:
             )[:, :, :] = np.array(K)
         self.out_solving_time.value = solving_time
 
-
     def decompress_dataOut(self):
         """
         Return the result of the asynchronous MPC (desired contact forces) that is
diff --git a/python/quadruped_reactive_walking/tools/LoggerControl.py b/python/quadruped_reactive_walking/tools/LoggerControl.py
index e0106f5b..8a97e17f 100644
--- a/python/quadruped_reactive_walking/tools/LoggerControl.py
+++ b/python/quadruped_reactive_walking/tools/LoggerControl.py
@@ -105,7 +105,7 @@ class LoggerControl:
             self.mocapOrientationQuat[self.i] = device.baseState[1]
 
         # Controller timings: MPC time, ...
-        self.target[self.i] = controller.point_target
+        self.target[self.i] = controller.target.evaluate_in_t(1)[self.pd.rfFootId]
         self.t_mpc[self.i] = controller.t_mpc
         self.t_send[self.i] = controller.t_send
         self.t_loop[self.i] = controller.t_loop
@@ -163,7 +163,7 @@ class LoggerControl:
 
         plt.show()
 
-    def plot_states(self, save=False, fileName='/tmp'):
+    def plot_states(self, save=False, fileName="/tmp"):
         import matplotlib.pyplot as plt
 
         legend = ["Hip", "Shoulder", "Knee"]
@@ -199,7 +199,7 @@ class LoggerControl:
         if save:
             plt.savefig(fileName + "/joint_velocities")
 
-    def plot_torques(self, save=False, fileName='/tmp'):
+    def plot_torques(self, save=False, fileName="/tmp"):
         import matplotlib.pyplot as plt
 
         legend = ["Hip", "Shoulder", "Knee"]
@@ -209,8 +209,7 @@ class LoggerControl:
             plt.subplot(2, 2, i + 1)
             plt.title("Joint torques of " + str(i))
             [
-                plt.plot(np.array(self.torquesFromCurrentMeasurment)
-                         [:, (3 * i + jj)])
+                plt.plot(np.array(self.torquesFromCurrentMeasurment)[:, (3 * i + jj)])
                 for jj in range(3)
             ]
             plt.ylabel("Torque [Nm]")
@@ -220,14 +219,15 @@ class LoggerControl:
         if save:
             plt.savefig(fileName + "/joint_torques")
 
-    def plot_target(self, save=False, fileName='/tmp'):
+    def plot_target(self, save=False, fileName="/tmp"):
         import matplotlib.pyplot as plt
 
-        x_mes = np.concatenate(
-            [self.q_mes[:, 3:6], self.v_mes[:, 3:6]], axis=1)
+        x_mes = np.concatenate([self.q_mes[:, 3:6], self.v_mes[:, 3:6]], axis=1)
 
-        horizon = int(self.ocp_xs.shape[0] / self.pd.r1)
-        t_scale = np.linspace(0, (horizon)*self.pd.dt, (horizon)*self.pd.r1)
+        horizon = int(self.ocp_xs.shape[0] / self.pd.mpc_wbc_ratio)
+        t_scale = np.linspace(
+            0, (horizon) * self.pd.dt, (horizon) * self.pd.mpc_wbc_ratio
+        )
 
         x_mpc = [self.ocp_xs[0][0, :]]
         [x_mpc.append(x[1, :]) for x in self.ocp_xs[:-1]]
@@ -235,8 +235,12 @@ class LoggerControl:
 
         # Feet positions calcuilated by every ocp
         all_ocp_feet_p_log = {
-            idx: [get_translation_array(self.pd, self.ocp_xs[i * self.pd.r1], idx)[0]
-                  for i in range(horizon)]
+            idx: [
+                get_translation_array(
+                    self.pd, self.ocp_xs[i * self.pd.mpc_wbc_ratio], idx
+                )[0]
+                for i in range(horizon)
+            ]
             for idx in self.pd.allContactIds
         }
         for foot in all_ocp_feet_p_log:
@@ -279,17 +283,16 @@ class LoggerControl:
         #         y = all_ocp_feet_p_log[self.pd.rfFootId][i][:,p]
         #         for j in range(len(y) - 1):
         #             plt.plot(t[j:j+2], y[j:j+2], color='royalblue', linewidth = 3, marker='o' ,alpha=max([1 - j/len(y), 0]))
-            
 
-    def plot_riccati_gains(self, n, save=False, fileName='/tmp'):
+    def plot_riccati_gains(self, n, save=False, fileName="/tmp"):
         import matplotlib.pyplot as plt
 
         # Equivalent Stiffness Damping plots
         legend = ["Hip", "Shoulder", "Knee"]
         plt.figure(figsize=(12, 18), dpi=90)
         for p in range(3):
-            plt.subplot(3, 1, p+1)
-            plt.title('Joint:  ' + legend[p])
+            plt.subplot(3, 1, p + 1)
+            plt.title("Joint:  " + legend[p])
             plt.plot(self.MPC_equivalent_Kp[:, p])
             plt.plot(self.MPC_equivalent_Kd[:, p])
             plt.legend(["Stiffness", "Damping"])
@@ -309,8 +312,7 @@ class LoggerControl:
     def plot_controller_times(self):
         import matplotlib.pyplot as plt
 
-        t_range = np.array(
-            [k * self.pd.dt for k in range(self.tstamps.shape[0])])
+        t_range = np.array([k * self.pd.dt for k in range(self.tstamps.shape[0])])
 
         plt.figure()
         plt.plot(t_range, self.t_measures, "r+")
@@ -327,8 +329,7 @@ class LoggerControl:
     def plot_OCP_times(self):
         import matplotlib.pyplot as plt
 
-        t_range = np.array(
-            [k * self.pd.dt for k in range(self.tstamps.shape[0])])
+        t_range = np.array([k * self.pd.dt for k in range(self.tstamps.shape[0])])
 
         plt.figure()
         plt.plot(t_range, self.t_ocp_update, "r+")
@@ -344,8 +345,7 @@ class LoggerControl:
     def plot_OCP_update_times(self):
         import matplotlib.pyplot as plt
 
-        t_range = np.array(
-            [k * self.pd.dt for k in range(self.tstamps.shape[0])])
+        t_range = np.array([k * self.pd.dt for k in range(self.tstamps.shape[0])])
 
         plt.figure()
         plt.plot(t_range, self.t_ocp_update_FK, "r+")
diff --git a/src/Params.cpp b/src/Params.cpp
index 84a1d06a..fdea1b25 100644
--- a/src/Params.cpp
+++ b/src/Params.cpp
@@ -23,6 +23,7 @@ Params::Params()
       dt_mpc(0.0),
       N_periods(0),
       type_MPC(0),
+      save_guess(false),
       interpolate_mpc(true),
       interpolation_type(0),
       kf_enabled(false),
@@ -140,6 +141,9 @@ void Params::initialize(const std::string& file_path) {
   assert_yaml_parsing(robot_node, "robot", "perfect_estimator");
   perfect_estimator = robot_node["perfect_estimator"].as<bool>();
 
+  assert_yaml_parsing(robot_node, "robot", "save_guess");
+  save_guess = robot_node["save_guess"].as<bool>();
+
   assert_yaml_parsing(robot_node, "robot", "interpolate_mpc");
   interpolate_mpc = robot_node["interpolate_mpc"].as<bool>();
 
-- 
GitLab