From ca8582ca5a214741454ff41147556fe4573653cc Mon Sep 17 00:00:00 2001
From: Ale <alessandroassirell98@gmail.com>
Date: Mon, 1 Aug 2022 15:26:58 +0200
Subject: [PATCH] interpolation finished: 1 quadratic with mistmatch of
 derivatives, 2 linear 3 good one

---
 .../quadruped_reactive_walking/Controller.py  | 132 ++++++++++++------
 .../WB_MPC/CrocoddylOCP.py                    |  31 ++--
 .../WB_MPC/Target.py                          |   8 +-
 .../main_solo12_control.py                    |   2 +-
 .../tools/LoggerControl.py                    |  41 +++---
 5 files changed, 139 insertions(+), 75 deletions(-)

diff --git a/python/quadruped_reactive_walking/Controller.py b/python/quadruped_reactive_walking/Controller.py
index a2316bca..f32b3039 100644
--- a/python/quadruped_reactive_walking/Controller.py
+++ b/python/quadruped_reactive_walking/Controller.py
@@ -23,6 +23,72 @@ class Result:
         self.tau_ff = np.zeros(12)
 
 
+class Interpolation:
+    def __init__(self):
+        pass
+
+    def load_data(self, q, v):
+        self.v0 = v[0, :]
+        self.q0 = q[0, :]
+        self.v1 = v[1, :]
+        self.q1 = q[1, :]
+
+    def interpolate(self, t):
+        # Perfect match, but wrong
+        # if (self.q1-self.q0 == 0).any():
+        # alpha = np.zeros(len(self.q0))
+        # else:
+        # alpha = 2 * 1/2* (self.v1**2 - self.v0**2)/(self.q1 - self.q0)
+
+        # beta = self.v0
+        # gamma = self.q0
+
+        # v_t = beta + alpha * t
+        # q_t = gamma + beta * t + alpha * t**2
+
+        # Linear
+        # beta = self.v1
+        # gamma = self.q0
+
+        # v_t = beta
+        # q_t = gamma + beta * t
+
+        # Quadratic
+        if (self.q1-self.q0 == 0).any():
+            alpha = np.zeros(len(self.q0))
+        else:
+            alpha = self.v1*(self.v1 - self.v0)/(self.q1 - self.q0)
+
+        beta = self.v0
+        gamma = self.q0
+
+        v_t = beta + alpha * t
+        q_t = gamma + beta * t + 1/2 * alpha * t**2
+
+        return q_t, v_t
+
+    def plot_interpolation(self, n, dt):
+        import matplotlib.pyplot as plt
+        plt.style.use("seaborn")
+        t = np.linspace(0, n*dt, n + 1)
+        q_t = np.array([self.interpolate((i) * dt)[0] for i in range(n+1)])
+        v_t = np.array([self.interpolate((i) * dt)[1] for i in range(n+1)])
+        for i in range(3):
+            plt.subplot(3, 2, (i*2) + 1)
+            plt.title("Position interpolation")
+            plt.plot(t, q_t[:, i])
+            plt.scatter(y=self.q0[i], x=t[0], color="violet", marker="+")
+            plt.scatter(y=self.q1[i], x=t[-1], color="violet", marker="+")
+
+            plt.subplot(3, 2, (i*2) + 2)
+            plt.title("Velocity interpolation")
+            plt.plot(t, v_t[:, i])
+            plt.scatter(y=self.v0[i], x=t[0], color="violet", marker="+")
+            plt.scatter(y=self.v1[i], x=t[-1], color="violet", marker="+")
+
+        plt.show()
+
+
 class DummyDevice:
     def __init__(self):
         self.imu = self.IMU()
@@ -44,7 +110,7 @@ class DummyDevice:
         def __init__(self):
             self.positions = np.zeros(12)
             self.velocities = np.zeros(12)
-90
+
 
 class Controller:
     def __init__(self, pd, target, params, q_init, t):
@@ -70,7 +136,7 @@ class Controller:
         self.cnt_wbc = 0
         self.error = False
         self.initialized = False
-        self.interpolated = False
+        self.interpolator = Interpolation()
         self.result = Result(params)
         self.q = self.pd.q0[7:].copy()
         self.v = self.pd.v0[6:].copy()
@@ -117,13 +183,13 @@ class Controller:
 
                 # Trajectory tracking
                 # if self.initialized:
-                    # self.mpc.solve(
-                        # self.k, self.mpc_result.xs[1], self.xs_init, self.us_init)
+                # self.mpc.solve(
+                # self.k, self.mpc_result.xs[1], self.xs_init, self.us_init)
                 # else:
-                    # self.mpc.solve(self.k, m["x_m"],
-                                #    self.xs_init, self.us_init)
+                # self.mpc.solve(self.k, m["x_m"],
+                #    self.xs_init, self.us_init)
 
-                self.cnt_mpc += 1        
+                self.cnt_mpc += 1
             except ValueError:
                 self.error = True
                 print("MPC Problem")
@@ -133,27 +199,35 @@ class Controller:
 
         if not self.error:
             self.mpc_result = self.mpc.get_latest_result()
+            if self.cnt_wbc == 0:
+                x = np.array(self.mpc_result.xs)
+                self.interpolator.load_data(
+                    x[:, : self.pd.nq], x[:, self.pd.nq:])
+            #self.interpolator.plot_interpolation(15, 0.001)
 
             # ## ONLY IF YOU WANT TO STORE THE FIRST SOLUTION TO WARM-START THE INITIAL Problem ###
-            #if not self.initialized:
+            # if not self.initialized:
             #   np.save(open('/tmp/init_guess.npy', "wb"), {"xs": self.mpc_result.xs, "us": self.mpc_result.us} )
             #   print("Initial guess saved")
 
             # Keep only the actuated joints and set the other to default values
-            q_interpolated, v_interpolated = self.interpolate_x(self.cnt_wbc * self.pd.dt_wbc)
+            q_interpolated, v_interpolated = self.interpolator.interpolate(
+                (self.cnt_wbc + 1) * self.pd.dt_wbc)  # +1 because the first one matches the actual state
             self.q[3:6] = q_interpolated
             self.v[3:6] = v_interpolated
 
-            # self.result.P = np.array(self.params.Kp_main.tolist() * 4)
-            # self.result.D = np.array(self.params.Kd_main.tolist() * 4)
+            self.result.P = np.array(self.params.Kp_main.tolist() * 4)
+            self.result.D = np.array(self.params.Kd_main.tolist() * 4)
             self.result.FF = self.params.Kff_main * np.ones(12)
+
             self.result.q_des = self.q
             self.result.v_des = self.v
-            actuated_tau_ff = self.mpc_result.us[0] + np.dot(self.mpc_result.K[0], 
-                                                         np.concatenate([pin.difference(self.pd.model, m["x_m"][: self.pd.nq],
-                                                                                        self.mpc_result.xs[0][: self.pd.nq]), 
-                                                                        m["x_m"][self.pd.nq] -  self.mpc_result.xs[0][self.pd.nq:] ]) )
-            self.result.tau_ff = np.array([0] * 3 + list(actuated_tau_ff) + [0] * 6)
+            actuated_tau_ff = self.mpc_result.us[0] + np.dot(self.mpc_result.K[0],
+                                                             np.concatenate([pin.difference(self.pd.model, m["x_m"][: self.pd.nq],
+                                                                                            self.mpc_result.xs[0][: self.pd.nq]),
+                                                                             m["x_m"][self.pd.nq] - self.mpc_result.xs[0][self.pd.nq:]]))
+            self.result.tau_ff = np.array(
+                [0] * 3 + list(actuated_tau_ff) + [0] * 6)
 
             self.xs_init = self.mpc_result.xs[1:] + [self.mpc_result.xs[-1]]
             self.us_init = self.mpc_result.us[1:] + [self.mpc_result.us[-1]]
@@ -161,8 +235,8 @@ class Controller:
         t_send = time.time()
         self.t_send = t_send - t_mpc
 
-        #self.clamp_result(device)
-        #self.security_check(m)
+        self.clamp_result(device)
+        self.security_check(m)
 
         if self.error:
             self.set_null_control()
@@ -287,28 +361,6 @@ class Controller:
 
         return {"qj_m": qj_m, "vj_m": vj_m, "x_m": x_m}
 
-    def interpolate_x(self, t):
-        q = np.array(self.mpc_result.xs)[:, : self.pd.nq]
-        v = np.array(self.mpc_result.xs)[:, self.pd.nq :]
-        v0 = v[0, :]
-        q0 = q[0, :]
-        v1 = v[1, :]
-        q1 = q[1, :]
-    
-        if (q1-q0 == 0).any():
-            alpha = np.zeros(len(q0))
-        else:
-            alpha = (v1**2 - v0**2)/(q1 - q0)
-
-        beta = v0
-        gamma = q0
-
-        v_t = beta + alpha * t
-        q_t = gamma + beta *t + 1/2 * alpha * t**2
-
-        return q_t, v_t
-        
-
     def tuple_to_array(self, tup):
         a = np.array([element for tupl in tup for element in tupl])
         return a
diff --git a/python/quadruped_reactive_walking/WB_MPC/CrocoddylOCP.py b/python/quadruped_reactive_walking/WB_MPC/CrocoddylOCP.py
index 5bdecdd1..d2ce4ed0 100644
--- a/python/quadruped_reactive_walking/WB_MPC/CrocoddylOCP.py
+++ b/python/quadruped_reactive_walking/WB_MPC/CrocoddylOCP.py
@@ -28,7 +28,6 @@ class OCP:
         )
         self.ddp = crocoddyl.SolverFDDP(self.problem)
 
-
     def initialize_models(self):
         self.nodes = []
         for t in range(self.pd.T):
@@ -66,7 +65,8 @@ class OCP:
                 self.target.evaluate_in_t(self.pd.T - 1),
                 self.target.contactSequence[self.pd.T - 1],
             )  # model without contact for this task
-            self.nodes[0].update_model(self.target.contactSequence[self.pd.T - 1], task)
+            self.nodes[0].update_model(
+                self.target.contactSequence[self.pd.T - 1], task)
 
             t_update_last_model = time()
             self.t_update_last_model = t_update_last_model - t_FK
@@ -164,7 +164,8 @@ class OCP:
 
     def get_croco_acc(self):
         acc = []
-        [acc.append(m.differential.xout) for m in self.ddp.problem.runningDatas]
+        [acc.append(m.differential.xout)
+         for m in self.ddp.problem.runningDatas]
         return acc
 
 
@@ -180,7 +181,8 @@ class Node:
             self.actuation = crocoddyl.ActuationModelFloatingBase(self.state)
         else:
             self.actuation = crocoddyl.ActuationModelFull(self.state)
-        self.control = crocoddyl.ControlParametrizationModelPolyZero(self.actuation.nu)
+        self.control = crocoddyl.ControlParametrizationModelPolyZero(
+            self.actuation.nu)
         self.nu = self.actuation.nu
 
         self.createStandardModel(supportFootIds)
@@ -202,7 +204,8 @@ class Node:
         self.contactModel = crocoddyl.ContactModelMultiple(self.state, self.nu)
         for i in supportFootIds:
             supportContactModel = crocoddyl.ContactModel3D(
-                self.state, i, np.array([0.0, 0.0, 0.0]), self.nu, np.array([0.0, 0.0])
+                self.state, i, np.array(
+                    [0.0, 0.0, 0.0]), self.nu, np.array([0.0, 0.0])
             )
             self.contactModel.addContact(
                 self.pd.model.frames[i].name + "_contact", supportContactModel
@@ -211,7 +214,8 @@ class Node:
         # Creating the cost model for a contact phase
         costModel = crocoddyl.CostModelSum(self.state, self.nu)
 
-        stateResidual = crocoddyl.ResidualModelState(self.state, self.pd.xref, self.nu)
+        stateResidual = crocoddyl.ResidualModelState(
+            self.state, self.pd.xref, self.nu)
         stateActivation = crocoddyl.ActivationModelWeightedQuad(
             self.pd.state_reg_w**2
         )
@@ -234,7 +238,8 @@ class Node:
         self.contactModel = crocoddyl.ContactModelMultiple(self.state, self.nu)
         for i in supportFootIds:
             supportContactModel = crocoddyl.ContactModel3D(
-                self.state, i, np.array([0.0, 0.0, 0.0]), self.nu, np.array([0.0, 0.0])
+                self.state, i, np.array(
+                    [0.0, 0.0, 0.0]), self.nu, np.array([0.0, 0.0])
             )
             self.dmodel.contacts.addContact(
                 self.pd.model.frames[i].name + "_contact", supportContactModel
@@ -242,7 +247,8 @@ class Node:
 
     def make_terminal_model(self):
         self.isTerminal = True
-        stateResidual = crocoddyl.ResidualModelState(self.state, self.pd.xref, self.nu)
+        stateResidual = crocoddyl.ResidualModelState(
+            self.state, self.pd.xref, self.nu)
         stateActivation = crocoddyl.ActivationModelWeightedQuad(
             self.pd.terminal_velocity_w**2
         )
@@ -273,14 +279,17 @@ class Node:
         ctrlReg = crocoddyl.CostModelResidual(self.state, ctrlResidual)
         self.costModel.addCost("ctrlReg", ctrlReg, self.pd.control_reg_w)
 
-        ctrl_bound_residual = crocoddyl.ResidualModelControl(self.state, self.nu)
+        ctrl_bound_residual = crocoddyl.ResidualModelControl(
+            self.state, self.nu)
         ctrl_bound_activation = crocoddyl.ActivationModelQuadraticBarrier(
-            crocoddyl.ActivationBounds(-self.pd.effort_limit, self.pd.effort_limit)
+            crocoddyl.ActivationBounds(-self.pd.effort_limit,
+                                       self.pd.effort_limit)
         )
         ctrl_bound = crocoddyl.CostModelResidual(
             self.state, ctrl_bound_activation, ctrl_bound_residual
         )
-        self.costModel.addCost("ctrlBound", ctrl_bound, self.pd.control_bound_w)
+        self.costModel.addCost("ctrlBound", ctrl_bound,
+                               self.pd.control_bound_w)
 
         self.tracking_cost(swingFootTask)
 
diff --git a/python/quadruped_reactive_walking/WB_MPC/Target.py b/python/quadruped_reactive_walking/WB_MPC/Target.py
index 8ca4d2c5..b924f952 100644
--- a/python/quadruped_reactive_walking/WB_MPC/Target.py
+++ b/python/quadruped_reactive_walking/WB_MPC/Target.py
@@ -8,9 +8,9 @@ class Target:
         self.pd = pd
         self.dt = pd.dt
 
-        self.gait = ([] + \
-                    [[0, 0, 0, 0]] * pd.init_steps + \
-                    [[0, 0, 0, 0]] * pd.target_steps )
+        self.gait = ([] +
+                     [[0, 0, 0, 0]] * pd.init_steps +
+                     [[0, 0, 0, 0]] * pd.target_steps)
 
         self.T = pd.T
         self.contactSequence = [self.patternToId(p) for p in self.gait]
@@ -23,7 +23,7 @@ class Target:
         self.FR_foot0 = pd.rdata.oMf[pd.rfFootId].translation.copy()
         self.A = np.array([0, 0.03, 0.03])
         self.offset = np.array([0.05, -0.02, 0.06])
-        self.freq = np.array([0, 0.5*0 , 0.5*0 ])
+        self.freq = np.array([0, 0.5 * 0, 0.5 * 0])
         self.phase = np.array([0, np.pi / 2, 0])
 
     def patternToId(self, gait):
diff --git a/python/quadruped_reactive_walking/main_solo12_control.py b/python/quadruped_reactive_walking/main_solo12_control.py
index 4773623a..609b997e 100644
--- a/python/quadruped_reactive_walking/main_solo12_control.py
+++ b/python/quadruped_reactive_walking/main_solo12_control.py
@@ -85,7 +85,7 @@ def check_position_error(device, controller):
         device (robot wrapper): a wrapper to communicate with the robot
         controller (array): the controller storing the desired position
     """
-    if np.max(np.abs(controller.result.q_des - device.joints.positions)) > 15:
+    if np.max(np.abs(controller.result.q_des - device.joints.positions)) > 0.15:
         print("DIFFERENCE: ", controller.result.q_des - device.joints.positions)
         print("q_des: ", controller.result.q_des)
         print("q_mes: ", device.joints.positions)
diff --git a/python/quadruped_reactive_walking/tools/LoggerControl.py b/python/quadruped_reactive_walking/tools/LoggerControl.py
index 7a9ba978..fbfd1ccf 100644
--- a/python/quadruped_reactive_walking/tools/LoggerControl.py
+++ b/python/quadruped_reactive_walking/tools/LoggerControl.py
@@ -122,7 +122,7 @@ class LoggerControl:
         self.t_mpc[self.i] = controller.t_mpc
         self.t_send[self.i] = controller.t_send
         self.t_loop[self.i] = controller.t_loop
-        
+
         self.t_ocp_ddp[self.i] = controller.mpc_result.solving_duration
         if not self.params.enable_multiprocessing:
             self.t_ocp_update[self.i] = controller.mpc.ocp.t_update
@@ -163,7 +163,7 @@ class LoggerControl:
 
         plt.show()
 
-    def plot_states(self, save = False, fileName = '/tmp'):
+    def plot_states(self, save=False, fileName='/tmp'):
         import matplotlib.pyplot as plt
 
         legend = ["Hip", "Shoulder", "Knee"]
@@ -199,7 +199,7 @@ class LoggerControl:
         if save:
             plt.savefig(fileName + "/joint_velocities")
 
-    def plot_torques(self, save=False, fileName = '/tmp'):
+    def plot_torques(self, save=False, fileName='/tmp'):
         import matplotlib.pyplot as plt
 
         legend = ["Hip", "Shoulder", "Knee"]
@@ -209,7 +209,8 @@ class LoggerControl:
             plt.subplot(2, 2, i + 1)
             plt.title("Joint torques of " + str(i))
             [
-                plt.plot(np.array(self.torquesFromCurrentMeasurment)[:, (3 * i + jj)])
+                plt.plot(np.array(self.torquesFromCurrentMeasurment)
+                         [:, (3 * i + jj)])
                 for jj in range(3)
             ]
             plt.ylabel("Torque [Nm]")
@@ -222,7 +223,8 @@ class LoggerControl:
     def plot_target(self, save=False, fileName='/tmp'):
         import matplotlib.pyplot as plt
 
-        x_mes = np.concatenate([self.q_mes[:, 3:6], self.v_mes[:, 3:6]], axis=1)
+        x_mes = np.concatenate(
+            [self.q_mes[:, 3:6], self.v_mes[:, 3:6]], axis=1)
 
         horizon = self.ocp_xs.shape[0]
 
@@ -232,7 +234,8 @@ class LoggerControl:
 
         # Feet positions calcuilated by every ocp
         all_ocp_feet_p_log = {
-            idx: [get_translation_array(self.pd, x, idx)[0] for x in self.ocp_xs]
+            idx: [get_translation_array(self.pd, x, idx)[0]
+                  for x in self.ocp_xs]
             for idx in self.pd.allContactIds
         }
         for foot in all_ocp_feet_p_log:
@@ -263,7 +266,6 @@ class LoggerControl:
         if save:
             plt.savefig(fileName + "/target")
 
-
         """ legend = ['x', 'y', 'z']
         plt.figure(figsize=(12, 18), dpi = 90)
         for p in range(3):
@@ -276,15 +278,14 @@ class LoggerControl:
                     plt.plot(t[j:j+2], y[j:j+2], color='royalblue', linewidth = 3, marker='o' ,alpha=max([1 - j/len(y), 0]))
             plt.plot(self.target[:, p]) """
 
-
     def plot_riccati_gains(self, n, save=False, fileName='/tmp'):
         import matplotlib.pyplot as plt
 
         # Equivalent Stiffness Damping plots
         legend = ["Hip", "Shoulder", "Knee"]
-        plt.figure(figsize=(12, 18), dpi = 90)
+        plt.figure(figsize=(12, 18), dpi=90)
         for p in range(3):
-            plt.subplot(3,1, p+1)
+            plt.subplot(3, 1, p+1)
             plt.title('Joint:  ' + legend[p])
             plt.plot(self.MPC_equivalent_Kp[:, p])
             plt.plot(self.MPC_equivalent_Kd[:, p])
@@ -294,9 +295,8 @@ class LoggerControl:
         if save:
             plt.savefig(fileName + "/diagonal_Riccati_gains")
 
-
         # Riccati gains
-        plt.figure(figsize=(12, 18), dpi = 90)
+        plt.figure(figsize=(12, 18), dpi=90)
         plt.title("Riccati gains at step: " + str(n))
         plt.imshow(self.ocp_K[n])
         plt.colorbar()
@@ -306,7 +306,8 @@ class LoggerControl:
     def plot_controller_times(self):
         import matplotlib.pyplot as plt
 
-        t_range = np.array([k * self.pd.dt for k in range(self.tstamps.shape[0])])
+        t_range = np.array(
+            [k * self.pd.dt for k in range(self.tstamps.shape[0])])
 
         plt.figure()
         plt.plot(t_range, self.t_measures, "r+")
@@ -319,11 +320,12 @@ class LoggerControl:
         plt.legend(lgd)
         plt.xlabel("Time [s]")
         plt.ylabel("Time [s]")
-        
+
     def plot_OCP_times(self):
         import matplotlib.pyplot as plt
 
-        t_range = np.array([k * self.pd.dt for k in range(self.tstamps.shape[0])])
+        t_range = np.array(
+            [k * self.pd.dt for k in range(self.tstamps.shape[0])])
 
         plt.figure()
         plt.plot(t_range, self.t_ocp_update, "r+")
@@ -339,7 +341,8 @@ class LoggerControl:
     def plot_OCP_update_times(self):
         import matplotlib.pyplot as plt
 
-        t_range = np.array([k * self.pd.dt for k in range(self.tstamps.shape[0])])
+        t_range = np.array(
+            [k * self.pd.dt for k in range(self.tstamps.shape[0])])
 
         plt.figure()
         plt.plot(t_range, self.t_ocp_update_FK, "r+")
@@ -364,9 +367,9 @@ class LoggerControl:
             name,
             ocp_xs=self.ocp_xs,
             ocp_us=self.ocp_us,
-            ocp_K = self.ocp_K,
-            MPC_equivalent_Kp = self.MPC_equivalent_Kp,
-            MPC_equivalent_Kd = self.MPC_equivalent_Kd,
+            ocp_K=self.ocp_K,
+            MPC_equivalent_Kp=self.MPC_equivalent_Kp,
+            MPC_equivalent_Kd=self.MPC_equivalent_Kd,
             t_measures=self.t_measures,
             t_mpc=self.t_mpc,
             t_send=self.t_send,
-- 
GitLab