From 5fa71e4c18e77dda278684b82ce27652ee1d32c6 Mon Sep 17 00:00:00 2001 From: Justin Carpentier <justin.carpentier@inria.fr> Date: Tue, 21 Feb 2023 17:21:59 +0100 Subject: [PATCH] core: fix check_swap --- include/eigenpy/eigen-allocator.hpp | 44 ++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/include/eigenpy/eigen-allocator.hpp b/include/eigenpy/eigen-allocator.hpp index cdcec2c0..91edf0cd 100644 --- a/include/eigenpy/eigen-allocator.hpp +++ b/include/eigenpy/eigen-allocator.hpp @@ -91,20 +91,44 @@ struct init_tensor { #endif template <typename MatType> -bool check_swap(PyArrayObject *pyArray, const Eigen::MatrixBase<MatType> &mat) { - if (PyArray_NDIM(pyArray) == 0) return false; - if (mat.rows() == PyArray_DIMS(pyArray)[0]) - return false; - else - return true; +struct check_swap_impl_matrix; + +template <typename EigenType, + typename BaseType = typename get_eigen_base_type<EigenType>::type> +struct check_swap_impl; + +template <typename MatType> +struct check_swap_impl<MatType, Eigen::MatrixBase<MatType> > + : check_swap_impl_matrix<MatType> {}; + +template <typename MatType> +struct check_swap_impl_matrix { + static bool run(PyArrayObject *pyArray, + const Eigen::MatrixBase<MatType> &mat) { + if (PyArray_NDIM(pyArray) == 0) return false; + if (mat.rows() == PyArray_DIMS(pyArray)[0]) + return false; + else + return true; + } +}; + +template <typename EigenType> +bool check_swap(PyArrayObject *pyArray, const EigenType &mat) { + return check_swap_impl<EigenType>::run(pyArray, mat); } #ifdef EIGENPY_WITH_TENSOR_SUPPORT template <typename TensorType> -bool check_swap(PyArrayObject * /*pyArray*/, - const Eigen::TensorBase<TensorType> & /*tensor*/) { - return false; -} +struct check_swap_impl_tensor { + static bool run(PyArrayObject * /*pyArray*/, const TensorType & /*tensor*/) { + return false; + } +}; + +template <typename TensorType> +struct check_swap_impl<TensorType, Eigen::TensorBase<TensorType> > + : check_swap_impl_tensor<TensorType> {}; #endif template <typename Scalar, typename NewScalar, -- GitLab