From c4e23b7f4918761b1a001b4105e16f2ef9ab6257 Mon Sep 17 00:00:00 2001
From: Justin Carpentier <justin.carpentier@inria.fr>
Date: Fri, 15 Nov 2019 21:05:48 +0100
Subject: [PATCH] core: add initEigenObject runner

---
 include/eigenpy/details.hpp | 38 +++++++++++++++++++++++++++++++++----
 1 file changed, 34 insertions(+), 4 deletions(-)

diff --git a/include/eigenpy/details.hpp b/include/eigenpy/details.hpp
index 1cbe269f..222b361f 100644
--- a/include/eigenpy/details.hpp
+++ b/include/eigenpy/details.hpp
@@ -172,6 +172,39 @@ namespace eigenpy
     bp::object NumpyArrayObject; PyTypeObject * NumpyArrayType;
     
   };
+
+  template<typename MatType, bool IsVectorAtCompileTime = MatType::IsVectorAtCompileTime>
+  struct initEigenObject
+  {
+    static MatType * run(PyArrayObject * pyArray, void * storage)
+    {
+      assert(PyArray_NDIM(pyArray) == 2);
+
+      const int rows = (int)PyArray_DIMS(pyArray)[0];
+      const int cols = (int)PyArray_DIMS(pyArray)[1];
+      
+      return new (storage) MatType(rows,cols);
+    }
+  };
+
+  template<typename MatType>
+  struct initEigenObject<MatType,true>
+  {
+    static MatType * run(PyArrayObject * pyArray, void * storage)
+    {
+      if(PyArray_NDIM(pyArray) == 1)
+      {
+        const int rows_or_cols = (int)PyArray_DIMS(pyArray)[0];
+        return new (storage) MatType(rows_or_cols);
+      }
+      else
+      {
+        const int rows = (int)PyArray_DIMS(pyArray)[0];
+        const int cols = (int)PyArray_DIMS(pyArray)[1];
+        return new (storage) MatType(rows,cols);
+      }
+    }
+  };
   
   template<typename MatType>
   struct EigenObjectAllocator
@@ -181,10 +214,7 @@ namespace eigenpy
     
     static void allocate(PyArrayObject * pyArray, void * storage)
     {
-      const int rows = (int)PyArray_DIMS(pyArray)[0];
-      const int cols = (int)PyArray_DIMS(pyArray)[1];
-      
-      Type * mat_ptr = new (storage) Type(rows,cols);
+      Type * mat_ptr = initEigenObject<Type>::run(pyArray,storage);
       
       if(NumpyEquivalentType<Scalar>::type_code == GET_PY_ARRAY_TYPE(pyArray))
       {
-- 
GitLab