From a4ba7e8fc8d454b2e34191215e2a2364ce0c4ea9 Mon Sep 17 00:00:00 2001
From: Justin Carpentier <justin.carpentier@inria.fr>
Date: Tue, 5 May 2020 15:42:24 +0200
Subject: [PATCH] test: add test for user types

---
 unittest/CMakeLists.txt           |   2 +
 unittest/python/test_user_type.py |  31 ++++++++
 unittest/user_type.cpp            | 115 ++++++++++++++++++++++++++++++
 3 files changed, 148 insertions(+)
 create mode 100644 unittest/python/test_user_type.py
 create mode 100644 unittest/user_type.cpp

diff --git a/unittest/CMakeLists.txt b/unittest/CMakeLists.txt
index d6518c3..c1a30b0 100644
--- a/unittest/CMakeLists.txt
+++ b/unittest/CMakeLists.txt
@@ -36,12 +36,14 @@ ADD_LIB_UNIT_TEST(return_by_ref "eigen3")
 IF(NOT ${EIGEN3_VERSION} VERSION_LESS "3.2.0")
   ADD_LIB_UNIT_TEST(eigen_ref "eigen3")
 ENDIF()
+ADD_LIB_UNIT_TEST(user_type "eigen3")
 
 ADD_PYTHON_UNIT_TEST("py-matrix" "unittest/python/test_matrix.py" "unittest")
 ADD_PYTHON_UNIT_TEST("py-geometry" "unittest/python/test_geometry.py" "unittest")
 ADD_PYTHON_UNIT_TEST("py-complex" "unittest/python/test_complex.py" "unittest")
 ADD_PYTHON_UNIT_TEST("py-return-by-ref" "unittest/python/test_return_by_ref.py" "unittest")
 ADD_PYTHON_UNIT_TEST("py-eigen-ref" "unittest/python/test_eigen_ref.py" "unittest")
+ADD_PYTHON_UNIT_TEST("py-user-type" "unittest/python/test_user_type.py" "unittest")
 
 ADD_PYTHON_UNIT_TEST("py-switch" "unittest/python/test_switch.py" "python/eigenpy")
 SET_TESTS_PROPERTIES("py-switch" PROPERTIES DEPENDS ${PYWRAP})
diff --git a/unittest/python/test_user_type.py b/unittest/python/test_user_type.py
new file mode 100644
index 0000000..057cfe9
--- /dev/null
+++ b/unittest/python/test_user_type.py
@@ -0,0 +1,31 @@
+import user_type
+
+rows = 10
+cols = 20
+
+def test(mat):
+  mat.fill(mat.dtype.type(10.))
+  mat_copy = mat.copy()
+  assert (mat == mat_copy).all()
+  assert not (mat != mat_copy).all()
+
+  mat_op = mat + mat
+  mat_op = mat.copy(order='F') + mat.copy(order='C')
+  
+  mat_op = mat - mat
+  mat_op = mat * mat
+  mat_op = mat.dot(mat.T)
+  mat_op = mat / mat
+
+  mat_op = -mat;
+
+  assert (mat >= mat).all()
+  assert (mat <= mat).all()
+  assert not (mat > mat).all()
+  assert not (mat < mat).all()
+
+mat = user_type.create_double(rows,cols)
+test(mat)
+
+mat = user_type.create_float(rows,cols)
+test(mat)
diff --git a/unittest/user_type.cpp b/unittest/user_type.cpp
new file mode 100644
index 0000000..b128648
--- /dev/null
+++ b/unittest/user_type.cpp
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2020 INRIA
+ */
+
+#include "eigenpy/eigenpy.hpp"
+#include "eigenpy/user-type.hpp"
+#include "eigenpy/ufunc.hpp"
+
+#include <iostream>
+#include <sstream>
+
+template<typename Scalar>
+struct CustomType
+{
+  CustomType() {}
+  
+  explicit CustomType(const Scalar & value)
+  : m_value(value)
+  {}
+  
+  CustomType operator*(const CustomType & other) const { return CustomType(m_value * other.m_value); }
+  CustomType operator+(const CustomType & other) const { return CustomType(m_value + other.m_value); }
+  CustomType operator-(const CustomType & other) const { return CustomType(m_value - other.m_value); }
+  CustomType operator/(const CustomType & other) const { return CustomType(m_value / other.m_value); }
+  
+  void operator+=(const CustomType & other) { m_value += other.m_value; }
+  void operator-=(const CustomType & other) { m_value -= other.m_value; }
+  void operator*=(const CustomType & other) { m_value *= other.m_value; }
+  void operator/=(const CustomType & other) { m_value /= other.m_value; }
+  
+  void operator=(const Scalar & value) { m_value = value; }
+  
+  bool operator==(const CustomType & other) const { return m_value == other.m_value; }
+  bool operator!=(const CustomType & other) const { return m_value != other.m_value; }
+  
+  bool operator<=(const CustomType & other) const { return m_value <= other.m_value; }
+  bool operator<(const CustomType & other) const { return m_value < other.m_value; }
+  bool operator>=(const CustomType & other) const { return m_value >= other.m_value; }
+  bool operator>(const CustomType & other) const { return m_value > other.m_value; }
+  
+  CustomType operator-() const { return CustomType(-m_value); }
+  
+  std::string print() const
+  {
+    std::stringstream ss;
+    ss << "value: " << m_value << std::endl;
+    return ss.str();
+  }
+ 
+protected:
+  
+  Scalar m_value;
+};
+
+template<typename Scalar>
+Eigen::Matrix<CustomType<Scalar>,Eigen::Dynamic,Eigen::Dynamic> create(int rows, int cols)
+{
+  typedef Eigen::Matrix<CustomType<Scalar>,Eigen::Dynamic,Eigen::Dynamic> Matrix;
+  return Matrix(rows,cols);
+}
+
+template<typename Scalar>
+Eigen::Matrix<Scalar,Eigen::Dynamic,Eigen::Dynamic> build_matrix(int rows, int cols)
+{
+  typedef Eigen::Matrix<Scalar,Eigen::Dynamic,Eigen::Dynamic> Matrix;
+  return Matrix(rows,cols);
+}
+
+template<typename Scalar>
+void expose_custom_type(const std::string & name)
+{
+  using namespace Eigen;
+  namespace bp = boost::python;
+  
+  typedef CustomType<Scalar> Type;
+  
+  bp::class_<Type>(name.c_str(),bp::init<Scalar>(bp::arg("value")))
+  
+  .def(bp::self + bp::self)
+  .def(bp::self - bp::self)
+  .def(bp::self * bp::self)
+  .def(bp::self / bp::self)
+  
+  .def(bp::self += bp::self)
+  .def(bp::self -= bp::self)
+  .def(bp::self *= bp::self)
+  .def(bp::self /= bp::self)
+  
+  .def("__repr__",&Type::print)
+  ;
+  
+  eigenpy::registerNewType<Type>();
+  eigenpy::registerCommonUfunc<Type>();
+}
+
+BOOST_PYTHON_MODULE(user_type)
+{
+  using namespace Eigen;
+  namespace bp = boost::python;
+  eigenpy::enableEigenPy();
+  
+  expose_custom_type<double>("CustomDouble");
+  typedef CustomType<double> DoubleType;
+  typedef Eigen::Matrix<DoubleType,Eigen::Dynamic,Eigen::Dynamic> DoubleMatrix;
+  eigenpy::EigenToPyConverter<DoubleMatrix>::registration();
+  bp::def("create_double",create<double>);
+  
+  expose_custom_type<float>("CustomFloat");
+  typedef CustomType<float> FloatType;
+  typedef Eigen::Matrix<FloatType,Eigen::Dynamic,Eigen::Dynamic> FloatMatrix;
+  eigenpy::EigenToPyConverter<FloatMatrix>::registration();
+  bp::def("create_float",create<float>);
+  
+  bp::def("build_matrix",build_matrix<double>);
+}
-- 
GitLab