Skip to content
Snippets Groups Projects
Verified Commit ce3ad460 authored by Justin Carpentier's avatar Justin Carpentier
Browse files

core: fix issue with Eigen::TenosrBase on older Eigen versions

parent 5fa71e4c
No related branches found
No related tags found
No related merge requests found
...@@ -131,7 +131,29 @@ struct check_swap_impl<TensorType, Eigen::TensorBase<TensorType> > ...@@ -131,7 +131,29 @@ struct check_swap_impl<TensorType, Eigen::TensorBase<TensorType> >
: check_swap_impl_tensor<TensorType> {}; : check_swap_impl_tensor<TensorType> {};
#endif #endif
// template <typename MatType>
// struct cast_impl_matrix;
//
// template <typename EigenType,
// typename BaseType = typename get_eigen_base_type<EigenType>::type>
// struct cast_impl;
//
// template <typename MatType>
// struct cast_impl<MatType, Eigen::MatrixBase<MatType> >
// : cast_impl_matrix<MatType> {};
//
// template <typename MatType>
// struct cast_impl_matrix
//{
// template <typename NewScalar, typename MatrixIn, typename MatrixOut>
// static void run(const Eigen::MatrixBase<MatrixIn> &input,
// const Eigen::MatrixBase<MatrixOut> &dest) {
// dest.const_cast_derived() = input.template cast<NewScalar>();
// }
// };
template <typename Scalar, typename NewScalar, template <typename Scalar, typename NewScalar,
template <typename D> class EigenBase = Eigen::MatrixBase,
bool cast_is_valid = FromTypeToType<Scalar, NewScalar>::value> bool cast_is_valid = FromTypeToType<Scalar, NewScalar>::value>
struct cast { struct cast {
template <typename MatrixIn, typename MatrixOut> template <typename MatrixIn, typename MatrixOut>
...@@ -139,34 +161,26 @@ struct cast { ...@@ -139,34 +161,26 @@ struct cast {
const Eigen::MatrixBase<MatrixOut> &dest) { const Eigen::MatrixBase<MatrixOut> &dest) {
dest.const_cast_derived() = input.template cast<NewScalar>(); dest.const_cast_derived() = input.template cast<NewScalar>();
} }
};
#ifdef EIGENPY_WITH_TENSOR_SUPPORT #ifdef EIGENPY_WITH_TENSOR_SUPPORT
template <typename Scalar, typename NewScalar>
struct cast<Scalar, NewScalar, Eigen::TensorRef, true> {
template <typename TensorIn, typename TensorOut> template <typename TensorIn, typename TensorOut>
static void run(const Eigen::TensorBase<TensorIn> &input, static void run(const TensorIn &input, TensorOut &dest) {
const Eigen::TensorBase<TensorOut> &dest) { dest = input.template cast<NewScalar>();
const_cast<TensorOut &>(static_cast<const TensorOut &>(dest)) =
input.template cast<NewScalar>();
} }
#endif
}; };
#endif
template <typename Scalar, typename NewScalar> template <typename Scalar, typename NewScalar,
struct cast<Scalar, NewScalar, false> { template <typename D> class EigenBase>
struct cast<Scalar, NewScalar, EigenBase, false> {
template <typename MatrixIn, typename MatrixOut> template <typename MatrixIn, typename MatrixOut>
static void run(const Eigen::MatrixBase<MatrixIn> & /*input*/, static void run(const MatrixIn /*input*/, const MatrixOut /*dest*/) {
const Eigen::MatrixBase<MatrixOut> & /*dest*/) {
// do nothing // do nothing
assert(false && "Must never happened"); assert(false && "Must never happened");
} }
#ifdef EIGENPY_WITH_TENSOR_SUPPORT
template <typename TensorIn, typename TensorOut>
static void run(const Eigen::TensorBase<TensorIn> & /*input*/,
const Eigen::TensorBase<TensorOut> & /*dest*/) {
// do nothing
assert(false && "Must never happened");
}
#endif
}; };
} // namespace details } // namespace details
...@@ -358,12 +372,15 @@ struct eigen_allocator_impl_tensor { ...@@ -358,12 +372,15 @@ struct eigen_allocator_impl_tensor {
copy(pyArray, tensor); copy(pyArray, tensor);
} }
#define EIGENPY_CAST_FROM_PYARRAY_TO_EIGEN_TENSOR(TensorType, Scalar, \ #define EIGENPY_CAST_FROM_PYARRAY_TO_EIGEN_TENSOR(TensorType, Scalar, \
NewScalar, pyArray, tensor) \ NewScalar, pyArray, tensor) \
details::cast<Scalar, NewScalar>::run( \ { \
NumpyMap<TensorType, Scalar>::map(pyArray, \ typename NumpyMap<TensorType, Scalar>::EigenMap pyArray_map = \
details::check_swap(pyArray, tensor)), \ NumpyMap<TensorType, Scalar>::map( \
tensor) pyArray, details::check_swap(pyArray, tensor)); \
details::cast<Scalar, NewScalar, Eigen::TensorRef>::run(pyArray_map, \
tensor); \
}
/// \brief Copy Python array into the input matrix mat. /// \brief Copy Python array into the input matrix mat.
template <typename TensorDerived> template <typename TensorDerived>
...@@ -417,9 +434,13 @@ struct eigen_allocator_impl_tensor { ...@@ -417,9 +434,13 @@ struct eigen_allocator_impl_tensor {
#define EIGENPY_CAST_FROM_EIGEN_TENSOR_TO_PYARRAY(TensorType, Scalar, \ #define EIGENPY_CAST_FROM_EIGEN_TENSOR_TO_PYARRAY(TensorType, Scalar, \
NewScalar, tensor, pyArray) \ NewScalar, tensor, pyArray) \
details::cast<Scalar, NewScalar>::run( \ { \
tensor, NumpyMap<TensorType, NewScalar>::map( \ typename NumpyMap<TensorType, NewScalar>::EigenMap pyArray_map = \
pyArray, details::check_swap(pyArray, tensor))) NumpyMap<TensorType, NewScalar>::map( \
pyArray, details::check_swap(pyArray, tensor)); \
details::cast<Scalar, NewScalar, Eigen::TensorRef>::run(tensor, \
pyArray_map); \
}
/// \brief Copy mat into the Python array using Eigen::Map /// \brief Copy mat into the Python array using Eigen::Map
static void copy(const TensorType &tensor, PyArrayObject *pyArray) { static void copy(const TensorType &tensor, PyArrayObject *pyArray) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment