From ee73d4bca9033b634a3d2b201815f63e34330aa1 Mon Sep 17 00:00:00 2001
From: Ale <alessandroassirell98@gmail.com>
Date: Fri, 19 Aug 2022 15:22:50 +0200
Subject: [PATCH] working on implementation

---
 .../quadruped_reactive_walking/Controller.py  | 39 +++++++++++--------
 .../WB_MPC_Wrapper.py                         |  8 ++--
 2 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/python/quadruped_reactive_walking/Controller.py b/python/quadruped_reactive_walking/Controller.py
index 2d06dd9f..1533c2a4 100644
--- a/python/quadruped_reactive_walking/Controller.py
+++ b/python/quadruped_reactive_walking/Controller.py
@@ -48,7 +48,8 @@ class Interpolator:
             self.beta = self.v1
             self.gamma = self.q0
         elif self.type == 1:  # Quadratic fixed velocity
-            self.alpha = 2 * (self.q1 - self.q0 - self.v0 * self.dt) / self.dt**2
+            self.alpha = 2 * (self.q1 - self.q0 -
+                              self.v0 * self.dt) / self.dt**2
             self.beta = self.v0
             self.gamma = self.q0
         elif self.type == 2:  # Quadratic time variable
@@ -104,7 +105,8 @@ class Interpolator:
             plt.scatter(y=self.q0[i], x=0.0, color="violet", marker="+")
             plt.scatter(y=self.q1[i], x=self.dt, color="violet", marker="+")
             if self.type == 3 and self.q2 is not None:
-                plt.scatter(y=self.q2[i], x=2 * self.dt, color="violet", marker="+")
+                plt.scatter(y=self.q2[i], x=2 * self.dt,
+                            color="violet", marker="+")
 
             plt.subplot(3, 2, (i * 2) + 2)
             plt.title("Velocity interpolation")
@@ -112,10 +114,12 @@ class Interpolator:
             plt.scatter(y=self.v0[i], x=0.0, color="violet", marker="+")
             plt.scatter(y=self.v1[i], x=self.dt, color="violet", marker="+")
             if self.type == 3 and self.v2 is not None:
-                plt.scatter(y=self.v2[i], x=2 * self.dt, color="violet", marker="+")
+                plt.scatter(y=self.v2[i], x=2 * self.dt,
+                            color="violet", marker="+")
 
         plt.show()
 
+
 class DummyDevice:
     def __init__(self):
         self.imu = self.IMU()
@@ -171,9 +175,10 @@ class Controller:
         self.result.FF = self.params.Kff_main * np.ones(12)
 
         self.target = Target(params)
-        self.velocity_task =  self.target.velocity_task
+        self.velocity_task = self.target.velocity_task
 
-        self.mpc = WB_MPC_Wrapper.MPC_Wrapper(pd, params, self.velocity_task, self.gait)
+        self.mpc = WB_MPC_Wrapper.MPC_Wrapper(
+            pd, params, self.velocity_task, self.gait)
         self.mpc_solved = False
         self.k_result = 0
         self.k_solve = 0
@@ -183,8 +188,8 @@ class Controller:
             )
         try:
             file = np.load("/tmp/init_guess.npy", allow_pickle=True).item()
-            self.xs_init = list(file["xs"])
-            self.us_init = list(file["us"])
+            self.guess = {'xs': list(file['xs']), 'us': list(file['us']),
+                          'acs': file['acs'], 'fs': file['fs']}
             print("Initial guess loaded \n")
         except:
             self.xs_init = None
@@ -223,8 +228,7 @@ class Controller:
                     m["x_m"],
                     self.velocity_task.copy(),
                     self.gait,
-                    self.xs_init,
-                    self.us_init,
+                    self.guess
                 )
                 # if self.initialized:
                 #     self.mpc.solve(
@@ -279,13 +283,12 @@ class Controller:
                 q, v = self.interpolator.interpolate(t)
             else:
                 q, v = self.integrate_x(m)
-                q = q[7: ]
-                v = v[6 :]
+                q = q[7:]
+                v = v[6:]
 
             # q = xs[1][7: self.pd.nq]
             # v = xs[1][6 + self.pd.nq:]
 
-
             self.result.q_des = q[:]
             self.result.v_des = v[:]
 
@@ -299,8 +302,11 @@ class Controller:
                     )
                     for i in range(3)
                 ]
-            self.xs_init = self.mpc_result.xs[1:] + [self.mpc_result.xs[-1]]
-            self.us_init = self.mpc_result.us[1:] + [self.mpc_result.us[-1]]
+
+            self.guess["xs"] = self.mpc_result.xs[1:] + [self.mpc_result.xs[-1]]
+            self.guess["us"] = self.mpc_result.us[1:] + [self.mpc_result.us[-1]]
+            self.guess["acs"] = self.mpc_result.acs[1:] + [self.mpc_result.acs[-1]]
+            self.guess["us"] = self.mpc_result.fs[1:] + [self.mpc_result.fs[-1]]
 
         t_send = time.time()
         self.t_send = t_send - t_mpc
@@ -464,7 +470,8 @@ class Controller:
 
         self.q[:3] = self.estimator.get_q_estimate()[:3]
         self.q[6:] = self.estimator.get_q_estimate()[7:]
-        self.q[3:6] = quaternionToRPY(self.estimator.get_q_estimate()[3:7]).ravel()
+        self.q[3:6] = quaternionToRPY(
+            self.estimator.get_q_estimate()[3:7]).ravel()
         self.v = self.estimator.get_v_reference()
 
         return oRh, hRb, oTh
@@ -505,7 +512,7 @@ class Controller:
         feedforward torque
         """
         q0 = m["x_m"].copy()[: 19]
-        v0 = m["x_m"].copy()[19 :]
+        v0 = m["x_m"].copy()[19:]
         tau = np.concatenate([np.zeros(6), self.result.tau_ff.copy()])
 
         a = pin.aba(self.pd.model, self.pd.rdata, q0, v0, tau)
diff --git a/python/quadruped_reactive_walking/WB_MPC_Wrapper.py b/python/quadruped_reactive_walking/WB_MPC_Wrapper.py
index f86c33a7..99aaf0e4 100644
--- a/python/quadruped_reactive_walking/WB_MPC_Wrapper.py
+++ b/python/quadruped_reactive_walking/WB_MPC_Wrapper.py
@@ -55,7 +55,7 @@ class MPC_Wrapper:
         self.last_available_result = Result(pd)
         self.new_result = Value("b", False)
 
-    def solve(self, k, x0, tasks, gait, xs=None, us=None):
+    def solve(self, k, x0, tasks, gait, guess={}):
         """
         Call either the asynchronous MPC or the synchronous MPC depending on the value
         of multiprocessing during the creation of the wrapper
@@ -66,7 +66,7 @@ class MPC_Wrapper:
         if self.multiprocessing:
             self.run_MPC_asynchronous(k, x0, tasks, gait, xs, us)
         else:
-            self.run_MPC_synchronous(x0, tasks, gait, xs, us)
+            self.run_MPC_synchronous(x0, tasks, gait, guess)
 
     def get_latest_result(self):
         """
@@ -91,11 +91,11 @@ class MPC_Wrapper:
 
         return self.last_available_result
 
-    def run_MPC_synchronous(self, x0, tasks, gait, xs, us):
+    def run_MPC_synchronous(self, x0, tasks, gait, guess):
         """
         Run the MPC (synchronous version)
         """
-        self.ocp.solve(x0, tasks, gait, xs, us)
+        self.ocp.solve(x0, tasks, gait, guess)
         (
             self.last_available_result.xs,
             self.last_available_result.us,
-- 
GitLab