From 558e3919c89fc46c33a304963dd3a6624e5f9680 Mon Sep 17 00:00:00 2001
From: Joris Vaillant <joris.vaillant@inria.fr>
Date: Mon, 5 Feb 2024 16:48:36 +0100
Subject: [PATCH] unique_ptr: Manage string and complex type

---
 CMakeLists.txt                          |  3 ++-
 include/eigenpy/std-unique-ptr.hpp      | 16 ++++++-------
 include/eigenpy/ufunc.hpp               |  7 ++----
 include/eigenpy/utils/python-compat.hpp | 23 ++++++++++++++++++
 include/eigenpy/utils/traits.hpp        | 32 +++++++++++++++++++++----
 include/eigenpy/variant.hpp             | 10 ++++----
 unittest/python/test_std_unique_ptr.py  | 25 ++++++++++++++++++-
 unittest/std_unique_ptr.cpp             | 25 ++++++++++++++++++-
 8 files changed, 115 insertions(+), 26 deletions(-)
 create mode 100644 include/eigenpy/utils/python-compat.hpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7002b5af..e48c9105 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -104,7 +104,8 @@ search_for_boost_python(REQUIRED)
 # ----------------------------------------------------
 set(${PROJECT_NAME}_UTILS_HEADERS
     include/eigenpy/utils/scalar-name.hpp include/eigenpy/utils/is-approx.hpp
-    include/eigenpy/utils/is-aligned.hpp)
+    include/eigenpy/utils/is-aligned.hpp include/eigenpy/utils/traits.hpp
+    include/eigenpy/utils/python-compat.hpp)
 
 set(${PROJECT_NAME}_SOLVERS_HEADERS
     include/eigenpy/solvers/solvers.hpp
diff --git a/include/eigenpy/std-unique-ptr.hpp b/include/eigenpy/std-unique-ptr.hpp
index dcb95dc3..dfc9f5a1 100644
--- a/include/eigenpy/std-unique-ptr.hpp
+++ b/include/eigenpy/std-unique-ptr.hpp
@@ -7,6 +7,7 @@
 
 #include "eigenpy/fwd.hpp"
 #include "eigenpy/utils/traits.hpp"
+#include "eigenpy/utils/python-compat.hpp"
 
 #include <boost/python.hpp>
 
@@ -19,8 +20,7 @@ namespace details {
 
 /// Transfer std::unique_ptr ownership to an owning holder
 template <typename T>
-typename std::enable_if<is_class_or_union_remove_cvref<T>::value,
-                        PyObject*>::type
+typename std::enable_if<!is_python_primitive_type<T>::value, PyObject*>::type
 unique_ptr_to_python(std::unique_ptr<T>&& x) {
   typedef bp::objects::pointer_holder<std::unique_ptr<T>, T> holder_t;
   if (!x) {
@@ -32,8 +32,7 @@ unique_ptr_to_python(std::unique_ptr<T>&& x) {
 
 /// Convert and copy the primitive value to python
 template <typename T>
-typename std::enable_if<!is_class_or_union_remove_cvref<T>::value,
-                        PyObject*>::type
+typename std::enable_if<is_python_primitive_type<T>::value, PyObject*>::type
 unique_ptr_to_python(std::unique_ptr<T>&& x) {
   if (!x) {
     return bp::detail::none();
@@ -45,8 +44,7 @@ unique_ptr_to_python(std::unique_ptr<T>&& x) {
 /// std::unique_ptr keep the ownership but a reference to the std::unique_ptr
 /// value is created
 template <typename T>
-typename std::enable_if<is_class_or_union_remove_cvref<T>::value,
-                        PyObject*>::type
+typename std::enable_if<!is_python_primitive_type<T>::value, PyObject*>::type
 internal_unique_ptr_to_python(std::unique_ptr<T>& x) {
   if (!x) {
     return bp::detail::none();
@@ -57,8 +55,7 @@ internal_unique_ptr_to_python(std::unique_ptr<T>& x) {
 
 /// Convert and copy the primitive value to python
 template <typename T>
-typename std::enable_if<!is_class_or_union_remove_cvref<T>::value,
-                        PyObject*>::type
+typename std::enable_if<is_python_primitive_type<T>::value, PyObject*>::type
 internal_unique_ptr_to_python(std::unique_ptr<T>& x) {
   if (!x) {
     return bp::detail::none();
@@ -123,7 +120,8 @@ struct ReturnInternalStdUniquePtr : bp::return_internal_reference<> {
   template <class ArgumentPackage>
   static PyObject* postcall(ArgumentPackage const& args_, PyObject* result) {
     // Don't run return_internal_reference postcall on primitive type
-    if (PyLong_Check(result) || PyBool_Check(result) || PyFloat_Check(result)) {
+    if (PyInt_Check(result) || PyBool_Check(result) || PyFloat_Check(result) ||
+        PyStr_Check(result) || PyComplex_Check(result)) {
       return result;
     }
     return bp::return_internal_reference<>::postcall(args_, result);
diff --git a/include/eigenpy/ufunc.hpp b/include/eigenpy/ufunc.hpp
index cb6695ac..129438cf 100644
--- a/include/eigenpy/ufunc.hpp
+++ b/include/eigenpy/ufunc.hpp
@@ -9,6 +9,7 @@
 
 #include "eigenpy/register.hpp"
 #include "eigenpy/user-type.hpp"
+#include "eigenpy/utils/python-compat.hpp"
 
 namespace eigenpy {
 namespace internal {
@@ -207,11 +208,7 @@ void registerCommonUfunc() {
   const int type_code = Register::getTypeCode<Scalar>();
 
   PyObject *numpy_str;
-#if PY_MAJOR_VERSION >= 3
-  numpy_str = PyUnicode_FromString("numpy");
-#else
-  numpy_str = PyString_FromString("numpy");
-#endif
+  numpy_str = PyStr_FromString("numpy");
   PyObject *numpy;
   numpy = PyImport_Import(numpy_str);
   Py_DECREF(numpy_str);
diff --git a/include/eigenpy/utils/python-compat.hpp b/include/eigenpy/utils/python-compat.hpp
new file mode 100644
index 00000000..7ffbc9de
--- /dev/null
+++ b/include/eigenpy/utils/python-compat.hpp
@@ -0,0 +1,23 @@
+//
+// Copyright (c) 2024 INRIA
+//
+//
+
+#ifndef __eigenpy_utils_python_compat_hpp__
+#define __eigenpy_utils_python_compat_hpp__
+
+#if PY_MAJOR_VERSION >= 3
+
+#define PyInt_Check PyLong_Check
+
+#define PyStr_Check PyUnicode_Check
+#define PyStr_FromString PyUnicode_FromString
+
+#else
+
+#define PyStr_Check PyString_Check
+#define PyStr_FromString PyString_FromString
+
+#endif
+
+#endif  // ifndef __eigenpy_utils_python_compat_hpp__
diff --git a/include/eigenpy/utils/traits.hpp b/include/eigenpy/utils/traits.hpp
index 9b8e020c..b7525dea 100644
--- a/include/eigenpy/utils/traits.hpp
+++ b/include/eigenpy/utils/traits.hpp
@@ -7,25 +7,47 @@
 #define __eigenpy_utils_traits_hpp__
 
 #include <type_traits>
+#include <string>
+#include <complex>
 
 namespace eigenpy {
 
 namespace details {
 
+/// Trait to remove const&
+template <typename T>
+struct remove_cvref : std::remove_cv<typename std::remove_reference<T>::type> {
+};
+
 /// Trait to detect if T is a class or an union
 template <typename T>
 struct is_class_or_union
     : std::integral_constant<bool, std::is_class<T>::value ||
                                        std::is_union<T>::value> {};
 
+/// trait to detect if T is a std::complex managed by Boost Python
 template <typename T>
-struct remove_cvref : std::remove_cv<typename std::remove_reference<T>::type> {
-};
+struct is_python_complex : std::false_type {};
+
+/// From boost/python/converter/builtin_converters
+template <>
+struct is_python_complex<std::complex<float> > : std::true_type {};
+template <>
+struct is_python_complex<std::complex<double> > : std::true_type {};
+template <>
+struct is_python_complex<std::complex<long double> > : std::true_type {};
+
+template <typename T>
+struct is_python_primitive_type_helper
+    : std::integral_constant<bool, !is_class_or_union<T>::value ||
+                                       std::is_same<T, std::string>::value ||
+                                       std::is_same<T, std::wstring>::value ||
+                                       is_python_complex<T>::value> {};
 
-/// Trait to remove cvref and call is_class_or_union
+/// Trait to detect if T is a Python primitive type
 template <typename T>
-struct is_class_or_union_remove_cvref
-    : is_class_or_union<typename remove_cvref<T>::type> {};
+struct is_python_primitive_type
+    : is_python_primitive_type_helper<typename remove_cvref<T>::type> {};
 
 }  // namespace details
 
diff --git a/include/eigenpy/variant.hpp b/include/eigenpy/variant.hpp
index f989a0f9..bf787c3f 100644
--- a/include/eigenpy/variant.hpp
+++ b/include/eigenpy/variant.hpp
@@ -7,6 +7,7 @@
 
 #include "eigenpy/fwd.hpp"
 #include "eigenpy/utils/traits.hpp"
+#include "eigenpy/utils/python-compat.hpp"
 
 #include <boost/python.hpp>
 #include <boost/variant.hpp>
@@ -147,7 +148,7 @@ struct NumericConvertibleImpl<
                                std::is_integral<T>::value>::type> {
   static void* convertible(PyObject* obj) {
     // PyLong return true for bool type
-    return (PyLong_Check(obj) && !PyBool_Check(obj)) ? obj : nullptr;
+    return (PyInt_Check(obj) && !PyBool_Check(obj)) ? obj : nullptr;
   }
 
   static PyTypeObject const* expected_pytype() { return &PyLong_Type; }
@@ -220,14 +221,14 @@ struct VariantRefToObject : VariantVisitorType<PyObject*, Variant> {
   }
 
   template <typename T,
-            typename std::enable_if<!is_class_or_union_remove_cvref<T>::value,
+            typename std::enable_if<is_python_primitive_type<T>::value,
                                     bool>::type = true>
   result_type operator()(T t) const {
     return bp::incref(bp::object(t).ptr());
   }
 
   template <typename T,
-            typename std::enable_if<is_class_or_union_remove_cvref<T>::value,
+            typename std::enable_if<!is_python_primitive_type<T>::value,
                                     bool>::type = true>
   result_type operator()(T& t) const {
     return bp::detail::make_reference_holder::execute(&t);
@@ -301,7 +302,8 @@ struct ReturnInternalVariant : bp::return_internal_reference<> {
   template <class ArgumentPackage>
   static PyObject* postcall(ArgumentPackage const& args_, PyObject* result) {
     // Don't run return_internal_reference postcall on primitive type
-    if (PyLong_Check(result) || PyBool_Check(result) || PyFloat_Check(result)) {
+    if (PyInt_Check(result) || PyBool_Check(result) || PyFloat_Check(result) ||
+        PyStr_Check(result) || PyComplex_Check(result)) {
       return result;
     }
     return bp::return_internal_reference<>::postcall(args_, result);
diff --git a/unittest/python/test_std_unique_ptr.py b/unittest/python/test_std_unique_ptr.py
index 6feb408f..8b56460f 100644
--- a/unittest/python/test_std_unique_ptr.py
+++ b/unittest/python/test_std_unique_ptr.py
@@ -2,6 +2,8 @@ from std_unique_ptr import (
     make_unique_int,
     make_unique_v1,
     make_unique_null,
+    make_unique_str,
+    make_unique_complex,
     V1,
     UniquePtrHolder,
 )
@@ -17,6 +19,14 @@ assert v.v == 10
 v = make_unique_null()
 assert v is None
 
+v = make_unique_str()
+assert isinstance(v, str)
+assert v == "str"
+
+v = make_unique_complex()
+assert isinstance(v, complex)
+assert v == 1 + 0j
+
 unique_ptr_holder = UniquePtrHolder()
 
 v = unique_ptr_holder.int_ptr
@@ -33,6 +43,19 @@ assert v.v == 200
 v.v = 10
 assert unique_ptr_holder.v1_ptr.v == 10
 
-
 v = unique_ptr_holder.null_ptr
 assert v is None
+
+v = unique_ptr_holder.str_ptr
+assert isinstance(v, str)
+assert v == "str"
+# v is a copy, str_ptr will not be updated
+v = "str_updated"
+assert unique_ptr_holder.str_ptr == "str"
+
+v = unique_ptr_holder.complex_ptr
+assert isinstance(v, complex)
+assert v == 1 + 0j
+# v is a copy, complex_ptr will not be updated
+v = 1 + 2j
+assert unique_ptr_holder.complex_ptr == 1 + 0j
diff --git a/unittest/std_unique_ptr.cpp b/unittest/std_unique_ptr.cpp
index ab99a4da..a95a5d24 100644
--- a/unittest/std_unique_ptr.cpp
+++ b/unittest/std_unique_ptr.cpp
@@ -5,6 +5,8 @@
 #include <eigenpy/std-unique-ptr.hpp>
 
 #include <memory>
+#include <string>
+#include <complex>
 
 namespace bp = boost::python;
 
@@ -21,13 +23,26 @@ std::unique_ptr<V1> make_unique_v1() { return std::make_unique<V1>(10); }
 
 std::unique_ptr<V1> make_unique_null() { return nullptr; }
 
+std::unique_ptr<std::string> make_unique_str() {
+  return std::make_unique<std::string>("str");
+}
+
+std::unique_ptr<std::complex<double> > make_unique_complex() {
+  return std::make_unique<std::complex<double> >(1., 0.);
+}
+
 struct UniquePtrHolder {
   UniquePtrHolder()
-      : int_ptr(std::make_unique<int>(20)), v1_ptr(std::make_unique<V1>(200)) {}
+      : int_ptr(std::make_unique<int>(20)),
+        v1_ptr(std::make_unique<V1>(200)),
+        str_ptr(std::make_unique<std::string>("str")),
+        complex_ptr(std::make_unique<std::complex<double> >(1., 0.)) {}
 
   std::unique_ptr<int> int_ptr;
   std::unique_ptr<V1> v1_ptr;
   std::unique_ptr<V1> null_ptr;
+  std::unique_ptr<std::string> str_ptr;
+  std::unique_ptr<std::complex<double> > complex_ptr;
 };
 
 BOOST_PYTHON_MODULE(std_unique_ptr) {
@@ -39,6 +54,8 @@ BOOST_PYTHON_MODULE(std_unique_ptr) {
   bp::def("make_unique_v1", make_unique_v1);
   bp::def("make_unique_null", make_unique_null,
           eigenpy::StdUniquePtrCallPolicies());
+  bp::def("make_unique_str", make_unique_str);
+  bp::def("make_unique_complex", make_unique_complex);
 
   boost::python::class_<UniquePtrHolder, boost::noncopyable>("UniquePtrHolder",
                                                              bp::init<>())
@@ -50,5 +67,11 @@ BOOST_PYTHON_MODULE(std_unique_ptr) {
                                     eigenpy::ReturnInternalStdUniquePtr()))
       .add_property("null_ptr",
                     bp::make_getter(&UniquePtrHolder::null_ptr,
+                                    eigenpy::ReturnInternalStdUniquePtr()))
+      .add_property("str_ptr",
+                    bp::make_getter(&UniquePtrHolder::str_ptr,
+                                    eigenpy::ReturnInternalStdUniquePtr()))
+      .add_property("complex_ptr",
+                    bp::make_getter(&UniquePtrHolder::complex_ptr,
                                     eigenpy::ReturnInternalStdUniquePtr()));
 }
-- 
GitLab