From f2053d74dce2a59da76f25d0971ebf92f573ef7e Mon Sep 17 00:00:00 2001
From: Justin Carpentier <justin.carpentier@inria.fr>
Date: Sun, 19 Feb 2023 19:52:18 +0100
Subject: [PATCH] core: start reworking EigenFromPyConverter in preparation for
 Eigen::Tensor

---
 include/eigenpy/eigen-from-python.hpp | 76 +++++++++++++++++++++------
 1 file changed, 59 insertions(+), 17 deletions(-)

diff --git a/include/eigenpy/eigen-from-python.hpp b/include/eigenpy/eigen-from-python.hpp
index f03ae26..27f066c 100644
--- a/include/eigenpy/eigen-from-python.hpp
+++ b/include/eigenpy/eigen-from-python.hpp
@@ -12,18 +12,27 @@
 
 namespace eigenpy {
 
-template <typename C>
+template <typename EigenType,
+          typename BaseType = typename get_eigen_base_type<EigenType>::type>
 struct expected_pytype_for_arg {};
 
-template <typename Scalar, int Rows, int Cols, int Options, int MaxRows,
-          int MaxCols>
-struct expected_pytype_for_arg<
-    Eigen::Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > {
+template <typename MatType>
+struct expected_pytype_for_arg<MatType, Eigen::MatrixBase<MatType> > {
+  static PyTypeObject const *get_pytype() {
+    PyTypeObject const *py_type = eigenpy::getPyArrayType();
+    return py_type;
+  }
+};
+
+#ifdef EIGENPY_WITH_TENSOR_SUPPORT
+template <typename TensorType>
+struct expected_pytype_for_arg<TensorType, Eigen::TensorBase<TensorType> > {
   static PyTypeObject const *get_pytype() {
     PyTypeObject const *py_type = eigenpy::getPyArrayType();
     return py_type;
   }
 };
+#endif
 
 }  // namespace eigenpy
 
@@ -38,6 +47,13 @@ struct expected_pytype_for_arg<
     : eigenpy::expected_pytype_for_arg<
           Eigen::Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > {};
 
+#ifdef EIGENPY_WITH_TENSOR_SUPPORT
+template <typename Scalar, int Rank, int Options, typename IndexType>
+struct expected_pytype_for_arg<Eigen::Tensor<Scalar, Rank, Options, IndexType> >
+    : eigenpy::expected_pytype_for_arg<
+          Eigen::Tensor<Scalar, Rank, Options, IndexType> > {};
+#endif
+
 }  // namespace converter
 }  // namespace python
 }  // namespace boost
@@ -269,8 +285,23 @@ void eigen_from_py_construct(
   memory->convertible = storage->storage.bytes;
 }
 
-template <typename MatType, typename _Scalar>
-struct EigenFromPy {
+template <typename EigenType,
+          typename BaseType = typename get_eigen_base_type<EigenType>::type>
+struct eigen_from_py_impl {
+  typedef typename EigenType::Scalar Scalar;
+
+  /// \brief Determine if pyObj can be converted into a MatType object
+  static void *convertible(PyObject *pyObj);
+
+  /// \brief Allocate memory and copy pyObj in the new storage
+  static void construct(PyObject *pyObj,
+                        bp::converter::rvalue_from_python_stage1_data *memory);
+
+  static void registration();
+};
+
+template <typename MatType>
+struct eigen_from_py_impl<MatType, Eigen::MatrixBase<MatType> > {
   typedef typename MatType::Scalar Scalar;
 
   /// \brief Determine if pyObj can be converted into a MatType object
@@ -283,8 +314,12 @@ struct EigenFromPy {
   static void registration();
 };
 
-template <typename MatType, typename _Scalar>
-void *EigenFromPy<MatType, _Scalar>::convertible(PyObject *pyObj) {
+template <typename EigenType, typename _Scalar>
+struct EigenFromPy : eigen_from_py_impl<EigenType> {};
+
+template <typename MatType>
+void *eigen_from_py_impl<MatType, Eigen::MatrixBase<MatType> >::convertible(
+    PyObject *pyObj) {
   if (!call_PyArray_Check(reinterpret_cast<PyObject *>(pyObj))) return 0;
 
   PyArrayObject *pyArray = reinterpret_cast<PyArrayObject *>(pyObj);
@@ -384,26 +419,33 @@ void *EigenFromPy<MatType, _Scalar>::convertible(PyObject *pyObj) {
   return pyArray;
 }
 
-template <typename MatType, typename _Scalar>
-void EigenFromPy<MatType, _Scalar>::construct(
+template <typename MatType>
+void eigen_from_py_impl<MatType, Eigen::MatrixBase<MatType> >::construct(
     PyObject *pyObj, bp::converter::rvalue_from_python_stage1_data *memory) {
   eigen_from_py_construct<MatType>(pyObj, memory);
 }
 
-template <typename MatType, typename _Scalar>
-void EigenFromPy<MatType, _Scalar>::registration() {
+template <typename MatType>
+void eigen_from_py_impl<MatType, Eigen::MatrixBase<MatType> >::registration() {
   bp::converter::registry::push_back(
-      reinterpret_cast<void *(*)(_object *)>(&EigenFromPy::convertible),
-      &EigenFromPy::construct, bp::type_id<MatType>()
+      reinterpret_cast<void *(*)(_object *)>(&eigen_from_py_impl::convertible),
+      &eigen_from_py_impl::construct, bp::type_id<MatType>()
 #ifndef BOOST_PYTHON_NO_PY_SIGNATURES
-                                   ,
+                                          ,
       &eigenpy::expected_pytype_for_arg<MatType>::get_pytype
 #endif
   );
 }
 
+template <typename EigenType,
+          typename BaseType = typename get_eigen_base_type<EigenType>::type>
+struct eigen_from_py_converter_impl;
+
+template <typename EigenType>
+struct EigenFromPyConverter : eigen_from_py_converter_impl<EigenType> {};
+
 template <typename MatType>
-struct EigenFromPyConverter {
+struct eigen_from_py_converter_impl<MatType, Eigen::MatrixBase<MatType> > {
   static void registration() {
     EigenFromPy<MatType>::registration();
 
-- 
GitLab