From 1bdc905260ccc7aa2b7f6232ffa10887bffccad0 Mon Sep 17 00:00:00 2001
From: Pierre Fernbach <pierre.fernbach@laas.fr>
Date: Tue, 17 Dec 2019 10:57:45 +0100
Subject: [PATCH] [Python][Tests] remove all unecessary reshape(-1,1) in python
 Solve #19

---
 python/test/test.py | 34 +++++++++++++++++-----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/python/test/test.py b/python/test/test.py
index 10783d8..a7dc241 100644
--- a/python/test/test.py
+++ b/python/test/test.py
@@ -260,20 +260,20 @@ class TestCurves(unittest.TestCase):
         min = 1.
         max = 2.5
         # a reshape is required as the inputs must be of shape (n,1) and not (n,)
-        # p0.reshape(-1,1) is equivalent to p0.reshape(len(p0),1)
-        polC0 = polynomial(p0.reshape(-1,1), p1.reshape(-1,1), min, max)
+        # p0 is equivalent to p0.reshape(len(p0),1)
+        polC0 = polynomial(p0, p1, min, max)
         self.assertEqual(polC0.min(), min)
         self.assertEqual(polC0.max(), max)
         # TODO: Why are thoso `.T[0]` needed ?
         self.assertTrue(array_equal(polC0((min + max) / 2.), 0.5 * p0 + 0.5 * p1))
-        polC1 = polynomial(p0.reshape(-1,1), dp0.reshape(-1,1), p1.reshape(-1,1), dp1.reshape(-1,1), min, max)
+        polC1 = polynomial(p0, dp0, p1, dp1, min, max)
         self.assertEqual(polC1.min(), min)
         self.assertEqual(polC1.max(), max)
         self.assertTrue(isclose(polC1(min), p0).all())
         self.assertTrue(isclose(polC1(max), p1).all())
         self.assertTrue(isclose(polC1.derivate(min, 1), dp0).all())
         self.assertTrue(isclose(polC1.derivate(max, 1), dp1).all())
-        polC2 = polynomial(p0.reshape(-1,1), dp0.reshape(-1,1), ddp0.reshape(-1,1), p1.reshape(-1,1), dp1.reshape(-1,1), ddp1.reshape(-1,1), min, max)
+        polC2 = polynomial(p0, dp0, ddp0, p1, dp1, ddp1, min, max)
         self.assertEqual(polC2.min(), min)
         self.assertEqual(polC2.max(), max)
         self.assertTrue(isclose(polC2(min), p0).all())
@@ -284,13 +284,13 @@ class TestCurves(unittest.TestCase):
         self.assertTrue(isclose(polC2.derivate(max, 2), ddp1).all())
         # check that the exception are correctly raised :
         with self.assertRaises(ValueError):
-            polC0 = polynomial(p0.reshape(-1,1), p1.reshape(-1,1), max, min)
+            polC0 = polynomial(p0, p1, max, min)
 
         with self.assertRaises(ValueError):
-            polC1 = polynomial(p0.reshape(-1,1), dp0.reshape(-1,1), p1.reshape(-1,1), dp1.reshape(-1,1), max, min)
+            polC1 = polynomial(p0, dp0, p1, dp1, max, min)
 
         with self.assertRaises(ValueError):
-            polC2 = polynomial(p0.reshape(-1,1), dp0.reshape(-1,1), ddp0.reshape(-1,1), p1.reshape(-1,1), dp1.reshape(-1,1), ddp1.reshape(-1,1), max, min)
+            polC2 = polynomial(p0, dp0, ddp0, p1, dp1, ddp1, max, min)
 
     def test_cubic_hermite_spline(self):
         print("test_cubic_hermite_spline")
@@ -366,17 +366,17 @@ class TestCurves(unittest.TestCase):
         end_point1 = array([1.,3.,5.,6.5,-2.])
         max1 = 2.5
         with self.assertRaises(RuntimeError): # cannot add final point in an empty curve
-          pc.append(end_point1.reshape(-1,1),max1)
+          pc.append(end_point1,max1)
         with self.assertRaises(ValueError):# a and end_point1 doesn't have the same dimension
           pc.append(a)
-          pc.append(end_point1.reshape(-1,1),max1)
+          pc.append(end_point1,max1)
 
         pc = piecewise_polynomial_curve()
         d = polynomial(waypoints3, 0., 1.2)
         self.assertEqual(pc.num_curves(),0)
         pc.append(d)
         self.assertEqual(pc.num_curves(),1)
-        pc.append(end_point1.reshape(-1,1),max1)
+        pc.append(end_point1,max1)
         self.assertEqual(pc.num_curves(),2)
         self.assertEqual(pc.min(),0.)
         self.assertEqual(pc.max(),max1)
@@ -595,7 +595,7 @@ class TestCurves(unittest.TestCase):
       # add another curve :
       end_pos2 = array([-2,0.2,1.6])
       max2 = 2.7
-      se3_2 = SE3Curve(translation(max).reshape(-1,1),end_pos2.reshape(-1,1),end_rot,end_rot,max,max2)
+      se3_2 = SE3Curve(translation(max),end_pos2,end_rot,end_rot,max,max2)
       pc.append(se3_2)
       self.assertEqual(pc.num_curves(),2)
       pmin2 = pc(max)
@@ -683,7 +683,7 @@ class TestCurves(unittest.TestCase):
         end_translation = array([-17., 3.7, 1.])
         end_pose = SE3.Identity()
         end_pose.rotation = end_rot
-        end_pose.translation = end_translation.reshape(-1,1)
+        end_pose.translation = end_translation
         max3 = 6.5
         pc.append(end_pose,max3)
         self.assertEqual(pc.num_curves(),3)
@@ -714,8 +714,8 @@ class TestCurves(unittest.TestCase):
             end_pose = SE3.Identity()
             init_pose.rotation = init_rot
             end_pose.rotation = end_rot
-            init_pose.translation = init_translation.reshape(-1,1)
-            end_pose.translation = end_translation.reshape(-1,1)
+            init_pose.translation = init_translation
+            end_pose.translation = end_translation
             min = 0.7
             max = 12.
             se3 = SE3Curve(init_pose, end_pose, min, max)
@@ -735,7 +735,7 @@ class TestCurves(unittest.TestCase):
             end_translation2 = array([-2., 1.6, -14.])
             end_pose2 = SE3.Identity()
             end_pose2.rotation = end_rot
-            end_pose2.translation = end_translation2.reshape(-1,1)
+            end_pose2.translation = end_translation2
             max2 = 23.9
             se3_2 = SE3Curve(end_pose, end_pose2, max, max2)
             pc.append(se3_2)
@@ -1088,8 +1088,8 @@ class TestCurves(unittest.TestCase):
             end_pose = SE3.Identity()
             init_pose.rotation = init_rot
             end_pose.rotation = end_rot
-            init_pose.translation = init_translation.reshape(-1,1)
-            end_pose.translation = end_translation.reshape(-1,1)
+            init_pose.translation = init_translation
+            end_pose.translation = end_translation
             min = 0.7
             max = 12.
             se3 = SE3Curve(init_pose, end_pose, min, max)
-- 
GitLab