From f9e0fed9d975aa4ec2d5e7b5ed6ee47efdbf2f27 Mon Sep 17 00:00:00 2001
From: Ale <alessandroassirell98@gmail.com>
Date: Mon, 11 Jul 2022 16:34:21 +0200
Subject: [PATCH] working on gait shift

---
 .../WB_MPC/CrocoddylOCP.py                    | 97 ++++++++++++-------
 1 file changed, 61 insertions(+), 36 deletions(-)

diff --git a/python/quadruped_reactive_walking/WB_MPC/CrocoddylOCP.py b/python/quadruped_reactive_walking/WB_MPC/CrocoddylOCP.py
index f4c192b4..68d70b60 100644
--- a/python/quadruped_reactive_walking/WB_MPC/CrocoddylOCP.py
+++ b/python/quadruped_reactive_walking/WB_MPC/CrocoddylOCP.py
@@ -14,15 +14,15 @@ class OCP:
         
         self.results = OcpResult()
         self.state = crocoddyl.StateMultibody(self.pd.model)
+        self.initialized = False
         self.initialize_models()
     
     def initialize_models(self):
         self.models = []
         for _ in range(self.pd.T):
             self.models.append(Model(self.pd, self.state)) # RunningModels
-        self.models.append(Model(self.pd, self.state, isTerminal=True)) #TerminalModel
+        self.models.append(Model(self.pd, self.state, isTerminal=True)) # TerminalModel
         
-
     def make_ocp(self, x0):
         """ Create a shooting problem for a simple walking gait.
 
@@ -34,6 +34,9 @@ class OCP:
         pin.forwardKinematics(self.pd.model, self.pd.rdata, q0)
         pin.updateFramePlacements(self.pd.model, self.pd.rdata)
 
+        if self.initialized:
+            self.models = self.models[1:] + [self.models[-2]]
+
         for t in range(self.pd.T):
             target = self.target.evaluate_in_t(t)
             freeIds = [idf for idf in self.pd.allContactIds if idf not in self.target.contactSequence[t]]
@@ -41,15 +44,16 @@ class OCP:
 
         freeIds = [idf for idf in self.pd.allContactIds if idf not in self.target.contactSequence[self.pd.T]]
         #contactIds = self.target.contactSequence[self.pd.T]
-        self.appendTargetToModel(self.models[self.pd.T], self.target.evaluate_in_t(self.pd.T), freeIds)
+        self.appendTargetToModel(self.models[self.pd.T], self.target.evaluate_in_t(self.pd.T), freeIds, True)
 
         problem = crocoddyl.ShootingProblem(x0, 
                                             [m.model for m in self.models[:-1]], 
                                             self.models[-1].model)
+        self.initialized = True
 
         return problem
 
-    def appendTargetToModel(self, model, target, swingFootIds):
+    def appendTargetToModel(self, model, target, swingFootIds, isTerminal=False):
         """ Action models for a footstep phase.
         :param numKnots: number of knots for the footstep phase
         :param supportFootIds: Ids of the supporting feet
@@ -65,7 +69,7 @@ class OCP:
             except:
                 pass
         
-        model.tracking_cost(swingFootTask)
+        model.update_model(swingFootTask, isTerminal=isTerminal)
 
 
 # Solve
@@ -143,6 +147,10 @@ class Model:
         self.nu = self.actuation.nu
 
         self.createStandardModel()
+        if isTerminal:
+            self.make_terminal_model()
+        else:
+            self.make_running_model()
 
     def createStandardModel(self):
         """ Action model for a swing foot phase.
@@ -153,10 +161,6 @@ class Model:
         :param swingFootTask: swinging foot task
         :return action model for a swing foot phase
         """
-        # Creating a 3D multi-contact model, and then including the supporting
-        # foot
-
-        
         self.contactModel = crocoddyl.ContactModelMultiple(self.state, self.nu)
         for i in self.supportFootIds:
             supportContactModel = crocoddyl.ContactModel3D(self.state, i, np.array([0., 0., 0.]), self.nu,
@@ -166,44 +170,59 @@ class Model:
         # Creating the cost model for a contact phase
         costModel = crocoddyl.CostModelSum(self.state, self.nu)
 
-        if not self.isTerminal:
-
-            for i in self.supportFootIds:
-                cone = crocoddyl.FrictionCone(self.pd.Rsurf, self.pd.mu, 4, False)
-                coneResidual = crocoddyl.ResidualModelContactFrictionCone(self.state, i, cone, self.nu)
-                coneActivation = crocoddyl.ActivationModelQuadraticBarrier(crocoddyl.ActivationBounds(cone.lb, cone.ub))
-                frictionCone = crocoddyl.CostModelResidual(self.state, coneActivation, coneResidual)
-                costModel.addCost(self.pd.model.frames[i].name + "_frictionCone", frictionCone, self.pd.friction_cone_w)
-
-            ctrlResidual = crocoddyl.ResidualModelControl(self.state, self.pd.uref)
-            ctrlReg = crocoddyl.CostModelResidual(self.state, ctrlResidual)
-            costModel.addCost("ctrlReg", ctrlReg, self.pd.control_reg_w)
-
-            ctrl_bound_residual = crocoddyl.ResidualModelControl(self.state, self.nu)
-            ctrl_bound_activation = crocoddyl.ActivationModelQuadraticBarrier(crocoddyl.ActivationBounds(-self.pd.effort_limit, self.pd.effort_limit))
-            ctrl_bound = crocoddyl.CostModelResidual(self.state, ctrl_bound_activation, ctrl_bound_residual)
-            costModel.addCost("ctrlBound", ctrl_bound, self.pd.control_bound_w)
-
         stateResidual = crocoddyl.ResidualModelState(self.state, self.pd.xref, self.nu)
         stateActivation = crocoddyl.ActivationModelWeightedQuad(self.pd.state_reg_w**2)
         stateReg = crocoddyl.CostModelResidual(self.state, stateActivation, stateResidual)
         costModel.addCost("stateReg", stateReg, 1)
 
-        if self.isTerminal:
-            stateResidual = crocoddyl.ResidualModelState(self.state, self.pd.xref, self.nu)
-            stateActivation = crocoddyl.ActivationModelWeightedQuad(self.pd.terminal_velocity_w**2)
-            stateReg = crocoddyl.CostModelResidual(self.state, stateActivation, stateResidual)
-            costModel.addCost("terminalVelocity", stateReg, 1)
-
         self.costModel = costModel
 
         self.dmodel = crocoddyl.DifferentialActionModelContactFwdDynamics(self.state, self.actuation, self.contactModel,
                                                                     self.costModel, 0., True)
         self.model = crocoddyl.IntegratedActionModelEuler(self.dmodel, self.control, self.pd.dt)
+    
+    def make_terminal_model(self):
+        self.remove_running_costs()  
+        self.isTerminal=True
+        stateResidual = crocoddyl.ResidualModelState(self.state, self.pd.xref, self.nu)
+        stateActivation = crocoddyl.ActivationModelWeightedQuad(self.pd.terminal_velocity_w**2)
+        stateReg = crocoddyl.CostModelResidual(self.state, stateActivation, stateResidual)
+        self.costModel.addCost("terminalVelocity", stateReg, 1)
 
+    def make_running_model(self):
+        self.remove_terminal_cost()
+
+        self.isTerminal = False
+        for i in self.supportFootIds:
+            cone = crocoddyl.FrictionCone(self.pd.Rsurf, self.pd.mu, 4, False)
+            coneResidual = crocoddyl.ResidualModelContactFrictionCone(self.state, i, cone, self.nu)
+            coneActivation = crocoddyl.ActivationModelQuadraticBarrier(crocoddyl.ActivationBounds(cone.lb, cone.ub))
+            frictionCone = crocoddyl.CostModelResidual(self.state, coneActivation, coneResidual)
+            self.costModel.addCost(self.pd.model.frames[i].name + "_frictionCone", frictionCone, self.pd.friction_cone_w)
+
+        ctrlResidual = crocoddyl.ResidualModelControl(self.state, self.pd.uref)
+        ctrlReg = crocoddyl.CostModelResidual(self.state, ctrlResidual)
+        self.costModel.addCost("ctrlReg", ctrlReg, self.pd.control_reg_w)
+
+        ctrl_bound_residual = crocoddyl.ResidualModelControl(self.state, self.nu)
+        ctrl_bound_activation = crocoddyl.ActivationModelQuadraticBarrier(crocoddyl.ActivationBounds(-self.pd.effort_limit, self.pd.effort_limit))
+        ctrl_bound = crocoddyl.CostModelResidual(self.state, ctrl_bound_activation, ctrl_bound_residual)
+        self.costModel.addCost("ctrlBound", ctrl_bound, self.pd.control_bound_w)
+    
+    def remove_running_costs(self):
+        runningCosts = self.dmodel.costs.active.tolist()
+        idx = runningCosts.index("stateReg")
+        runningCosts.pop(idx)
+        for cost in runningCosts:
+            if cost in self.dmodel.costs.active.tolist():
+                self.dmodel.costs.removeCost(cost)
+
+    def remove_terminal_cost(self):
+        if "terminalVelocity" in self.dmodel.costs.active.tolist():
+                self.dmodel.costs.removeCost("terminalVelocity")
 
     def tracking_cost(self, swingFootTask):
-        if swingFootTask is not None and not self.isTerminal:
+        if swingFootTask is not None:
             for i in swingFootTask:
                 frameTranslationResidual = crocoddyl.ResidualModelFrameTranslation(self.state, i[0], i[1].translation,self.nu)
                 footTrack = crocoddyl.CostModelResidual(self.state, frameTranslationResidual)
@@ -211,5 +230,11 @@ class Model:
                     self.dmodel.costs.removeCost(self.pd.model.frames[i[0]].name + "_footTrack")
                 self.costModel.addCost(self.pd.model.frames[i[0]].name + "_footTrack", footTrack, self.pd.foot_tracking_w)
 
-
-
+    def update_model(self, swingFootTask=[], isTerminal = False):
+        if isTerminal:
+            self.make_terminal_model()
+        elif self.isTerminal:
+            self.make_running_model()
+            self.tracking_cost(swingFootTask)
+        else:
+            self.tracking_cost(swingFootTask)
-- 
GitLab