From 0bbd6364e35061e87e2ccabb6a951e5ed752e2e0 Mon Sep 17 00:00:00 2001
From: cbusato <busatoclement@gmail.com>
Date: Mon, 27 Feb 2023 12:06:46 +0100
Subject: [PATCH] Switch beack from xs optimization to dxs optimization

---
 .../WB_MPC/CasadiOCP.py                       | 24 +++++++++++--------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/python/quadruped_reactive_walking/WB_MPC/CasadiOCP.py b/python/quadruped_reactive_walking/WB_MPC/CasadiOCP.py
index cb1ac42b..2f196596 100644
--- a/python/quadruped_reactive_walking/WB_MPC/CasadiOCP.py
+++ b/python/quadruped_reactive_walking/WB_MPC/CasadiOCP.py
@@ -163,15 +163,17 @@ class OCP:
         print("-- End of solve() --")
         """from IPython import embed
         embed()"""
-    
+
     def my_debug(self, i):
         print("---" + str(i) + "---")
         for j in range(4):
             print(self.opti.debug.value(self.datas[0].f[j]))
 
         for cost in self.datas[0].costs:
-            self.log_costs[cost].append([self.opti.value(data.costs[cost]) for data in self.datas])
-    
+            self.log_costs[cost].append(
+                [self.opti.value(data.costs[cost]) for data in self.datas]
+            )
+
         """for data in self.datas:
             for cost in data.costs:
                 self.log_costs[cost].append(self.opti.value(data.costs[cost]))"""
@@ -210,7 +212,10 @@ class OCP:
         ]
         self.acs = [opti.variable(self.pd.nv) for _ in self.runningModels]
         self.us = [opti.variable(self.pd.nu) for _ in self.runningModels]
-        self.xs = [opti.variable(self.pd.nx) for _ in (self.runningModels + [self.terminalModel])]
+        self.xs = [
+            m.integrate(x0, dx)
+            for m, dx in zip(self.runningModels + [self.terminalModel], self.dxs)
+        ]
         self.fs = []
         for m in self.runningModels:
             f_tmp = [opti.variable(3) for _ in range(len(m.contactIds))]
@@ -234,10 +239,11 @@ class OCP:
         eq = []
         # First running node is in initial state of the robot
         eq.append(self.runningModels[0].difference(self.xs[0], x0))
+        # Could also be eq.append(self.dxs[0]) to force first dxs to 0?
 
         # Set targets in xref of ProblemData
-        self.pd.xref[self.pd.nq:(self.pd.nq + 3)] = targets[0]
-        self.pd.xref[(self.pd.nq + 3):(self.pd.nq + 6)] = targets[1]
+        self.pd.xref[self.pd.nq : (self.pd.nq + 3)] = targets[0]
+        self.pd.xref[(self.pd.nq + 3) : (self.pd.nq + 6)] = targets[1]
 
         # Gather costs and equality constraints for each running node
         for t in range(self.pd.T):
@@ -310,8 +316,8 @@ class OCP:
                     ]
                 )
 
-            for x, xg in zip(self.xs, xs_g):
-                self.opti.set_initial(x, xg)
+            for x, xg in zip(self.dxs, xs_g):
+                self.opti.set_initial(x, xdiff(x0, xg))
             for a, ag in zip(self.acs, acs_g):
                 self.opti.set_initial(a, ag)
             for u, ug in zip(self.us, us_g):
@@ -324,8 +330,6 @@ class OCP:
             print("Got warm start")
         except:
             print("Can't load warm start")
-            for x in self.xs:
-                self.opti.set_initial(x, x0)
 
     def get_results(self):
         xs_sol = [self.opti.value(x) for x in self.xs]
-- 
GitLab