# test_shooting.py: unit tests for crocoddyl.ShootingProblem.
import sys
import unittest
from random import randint
import numpy as np
import crocoddyl
from crocoddyl.utils import UnicycleDerived


class ShootingProblemTestCase(unittest.TestCase):
    """Compare two shooting problems built from equivalent action models.

    Subclasses set ``MODEL`` and ``MODEL_DER`` to two action models that are
    expected to behave identically. Each test builds one shooting problem per
    model, with the same random horizon, initial state and controls, and
    checks that both problems produce matching results.

    This base class does not set the models. Its tests are skipped (instead of
    erroring in ``setUp``) when a test runner discovers it directly.
    """

    MODEL = None  # reference action model; set by concrete subclasses
    MODEL_DER = None  # model expected to reproduce MODEL's results

    def setUp(self):
        # Test discovery (pytest, `python -m unittest discover`) also collects
        # this abstract base class; skip it rather than failing on None.
        if self.MODEL is None or self.MODEL_DER is None:
            self.skipTest("abstract test case: MODEL / MODEL_DER not set")

        # Random horizon length (number of running nodes), 1..101 inclusive.
        self.T = randint(1, 101)
        state = self.MODEL.state

        # Random trajectory: T + 1 states (x0 .. xT) and T controls.
        # Each control is a column matrix of shape (nu, 1).
        self.xs = [state.rand()]
        self.us = []
        for _ in range(self.T):
            self.xs.append(state.rand())
            self.us.append(np.matrix(np.random.rand(self.MODEL.nu)).T)

        # Both problems start from the same x0. The running model is repeated
        # T times, and the same model also serves as the terminal model.
        self.PROBLEM = crocoddyl.ShootingProblem(self.xs[0], [self.MODEL] * self.T, self.MODEL)
        self.PROBLEM_DER = crocoddyl.ShootingProblem(self.xs[0], [self.MODEL_DER] * self.T, self.MODEL_DER)

    def test_number_of_nodes(self):
        """The problem exposes the horizon length it was built with."""
        self.assertEqual(self.T, self.PROBLEM.T, "Wrong number of nodes")

    def test_calc(self):
        """calc gives the same total cost and next states for both models."""
        # Running calc functions
        cost = self.PROBLEM.calc(self.xs, self.us)
        costDer = self.PROBLEM_DER.calc(self.xs, self.us)
        self.assertAlmostEqual(cost, costDer, 10, "Wrong cost value")
        for d1, d2 in zip(self.PROBLEM.runningDatas, self.PROBLEM_DER.runningDatas):
            self.assertTrue(np.allclose(d1.xnext, d2.xnext, atol=1e-9), "Next state doesn't match.")

    def test_calcDiff(self):
        """calcDiff gives matching cost, next states and derivatives."""
        # Run calc first. Depending on the crocoddyl version, calcDiff may
        # reuse the cost / next state that calc stored in the data. Without
        # this call, the comparisons below could only check freshly
        # initialised data.
        self.PROBLEM.calc(self.xs, self.us)
        self.PROBLEM_DER.calc(self.xs, self.us)

        # Running calcDiff functions
        cost = self.PROBLEM.calcDiff(self.xs, self.us)
        costDer = self.PROBLEM_DER.calcDiff(self.xs, self.us)
        self.assertAlmostEqual(cost, costDer, 10, "Wrong cost value")
        for d1, d2 in zip(self.PROBLEM.runningDatas, self.PROBLEM_DER.runningDatas):
            self.assertTrue(np.allclose(d1.xnext, d2.xnext, atol=1e-9), "Next state doesn't match.")
            self.assertTrue(np.allclose(d1.Lx, d2.Lx, atol=1e-9), "Lx doesn't match.")
            self.assertTrue(np.allclose(d1.Lu, d2.Lu, atol=1e-9), "Lu doesn't match.")
            self.assertTrue(np.allclose(d1.Lxx, d2.Lxx, atol=1e-9), "Lxx doesn't match.")
            self.assertTrue(np.allclose(d1.Lxu, d2.Lxu, atol=1e-9), "Lxu doesn't match.")
            self.assertTrue(np.allclose(d1.Luu, d2.Luu, atol=1e-9), "Luu doesn't match.")
            self.assertTrue(np.allclose(d1.Fx, d2.Fx, atol=1e-9), "Fx doesn't match.")
            self.assertTrue(np.allclose(d1.Fu, d2.Fu, atol=1e-9), "Fu doesn't match.")

    def test_rollout(self):
        """Rolling out the same controls gives the same state trajectory."""
        xs = self.PROBLEM.rollout(self.us)
        xsDer = self.PROBLEM_DER.rollout(self.us)
        # zip() would silently truncate trajectories of different lengths.
        self.assertEqual(len(xs), len(xsDer), "The rollout length doesn't match.")
        for x1, x2 in zip(xs, xsDer):
            self.assertTrue(np.allclose(x1, x2, atol=1e-9), "The rollout state doesn't match.")

class UnicycleShootingTest(ShootingProblemTestCase):
    """Compare crocoddyl.ActionModelUnicycle with crocoddyl.utils.UnicycleDerived."""

    MODEL = crocoddyl.ActionModelUnicycle()
    MODEL_DER = UnicycleDerived()


if __name__ == '__main__':
    # Run only the concrete test cases; the abstract ShootingProblemTestCase
    # base class is deliberately not listed here.
    test_classes_to_run = [UnicycleShootingTest]
    loader = unittest.TestLoader()
    big_suite = unittest.TestSuite(loader.loadTestsFromTestCase(test_class) for test_class in test_classes_to_run)
    runner = unittest.TextTestRunner()
    results = runner.run(big_suite)
    # sys.exit(False) exits with status 0, sys.exit(True) with status 1.
    sys.exit(not results.wasSuccessful())