From b97d1d2a23cc2adf5b391650883cc726a9b48808 Mon Sep 17 00:00:00 2001
From: Joris Vaillant <joris.vaillant@inria.fr>
Date: Wed, 31 Jan 2024 14:41:12 +0100
Subject: [PATCH] core: Avoid ambiguity with numeric type conversion

---
 include/eigenpy/variant.hpp        | 120 ++++++++++++++++++++++-------
 unittest/python/test_variant.py.in |  80 ++++++++-----------
 unittest/variant.cpp.in            |  64 +++------------
 3 files changed, 139 insertions(+), 125 deletions(-)

diff --git a/include/eigenpy/variant.hpp b/include/eigenpy/variant.hpp
index d2e8b10a..4028f856 100644
--- a/include/eigenpy/variant.hpp
+++ b/include/eigenpy/variant.hpp
@@ -36,32 +36,6 @@ struct empty_variant {};
 template <typename T>
 struct is_empty_variant : std::false_type {};
 
-/// Convert None to a {boost,std}::variant with boost::blank or std::monostate
-/// value
-template <typename Variant>
-struct EmptyConvertible {
-  static void registration() {
-    bp::converter::registry::push_back(convertible, construct,
-                                       bp::type_id<Variant>());
-  }
-
-  // convertible only for None
-  static void* convertible(PyObject* obj) {
-    return (obj == Py_None) ? obj : nullptr;
-  };
-
-  // construct in place
-  static void construct(PyObject*,
-                        bp::converter::rvalue_from_python_stage1_data* data) {
-    void* storage =
-        reinterpret_cast<bp::converter::rvalue_from_python_storage<Variant>*>(
-            data)
-            ->storage.bytes;
-    new (storage) Variant(typename empty_variant<Variant>::type());
-    data->convertible = storage;
-  };
-};
-
 #ifdef EIGENPY_WITH_CXX17_SUPPORT
 
 /// std::variant implementation
@@ -126,6 +100,90 @@ struct empty_variant<boost::variant<Alternatives...> > {
 template <>
 struct is_empty_variant<boost::blank> : std::true_type {};
 
+/// Convert None to a {boost,std}::variant with boost::blank or std::monostate
+/// value
+template <typename Variant>
+struct EmptyConvertible {
+  static void registration() {
+    bp::converter::registry::push_back(convertible, construct,
+                                       bp::type_id<Variant>());
+  }
+
+  // convertible only for None
+  static void* convertible(PyObject* obj) {
+    return (obj == Py_None) ? obj : nullptr;
+  };
+
+  // construct in place
+  static void construct(PyObject*,
+                        bp::converter::rvalue_from_python_stage1_data* data) {
+    void* storage =
+        reinterpret_cast<bp::converter::rvalue_from_python_storage<Variant>*>(
+            data)
+            ->storage.bytes;
+    new (storage) Variant(typename empty_variant<Variant>::type());
+    data->convertible = storage;
+  };
+};
+
+/// Implement convertible and expected_pytype for bool, integer and float
+template <typename T, class Enable = void>
+struct NumericConvertibleImpl {};
+
+template <typename T>
+struct NumericConvertibleImpl<
+    T, typename std::enable_if<std::is_same<T, bool>::value>::type> {
+  static void* convertible(PyObject* obj) {
+    return PyBool_Check(obj) ? obj : nullptr;
+  }
+
+  static PyTypeObject const* expected_pytype() { return &PyBool_Type; }
+};
+
+template <typename T>
+struct NumericConvertibleImpl<
+    T, typename std::enable_if<!std::is_same<T, bool>::value &&
+                               std::is_integral<T>::value>::type> {
+  static void* convertible(PyObject* obj) {
+    // PyLong return true for bool type
+    return (PyLong_Check(obj) && !PyBool_Check(obj)) ? obj : nullptr;
+  }
+
+  static PyTypeObject const* expected_pytype() { return &PyLong_Type; }
+};
+
+template <typename T>
+struct NumericConvertibleImpl<
+    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
+  static void* convertible(PyObject* obj) {
+    return PyFloat_Check(obj) ? obj : nullptr;
+  }
+
+  static PyTypeObject const* expected_pytype() { return &PyFloat_Type; }
+};
+
+/// Convert numeric type to Variant without ambiguity
+template <typename T, typename Variant>
+struct NumericConvertible {
+  static void registration() {
+    bp::converter::registry::push_back(
+        &convertible, &bp::converter::implicit<T, Variant>::construct,
+        bp::type_id<Variant>()
+#ifndef BOOST_PYTHON_NO_PY_SIGNATURES
+            ,
+        &expected_pytype
+#endif
+    );
+  }
+
+  static void* convertible(PyObject* obj) {
+    return NumericConvertibleImpl<T>::convertible(obj);
+  }
+  static PyTypeObject const* expected_pytype() {
+    return NumericConvertibleImpl<T>::expected_pytype();
+  }
+};
+
 /// Convert {boost,std}::variant<class...> alternative to a Python object.
 /// This converter copy the alternative.
 template <typename Variant>
@@ -225,7 +283,15 @@ struct VariantConvertible {
     EmptyConvertible<variant_type>::registration();
   }
 
-  template <class T, typename std::enable_if<!is_empty_variant<T>::value,
+  template <class T, typename std::enable_if<!is_empty_variant<T>::value &&
+                                                 std::is_arithmetic<T>::value,
+                                             bool>::type = true>
+  void operator()(T) {
+    NumericConvertible<T, variant_type>::registration();
+  }
+
+  template <class T, typename std::enable_if<!is_empty_variant<T>::value &&
+                                                 !std::is_arithmetic<T>::value,
                                              bool>::type = true>
   void operator()(T) {
     bp::implicitly_convertible<T, variant_type>();
diff --git a/unittest/python/test_variant.py.in b/unittest/python/test_variant.py.in
index 300b1d81..b019514c 100644
--- a/unittest/python/test_variant.py.in
+++ b/unittest/python/test_variant.py.in
@@ -4,13 +4,9 @@ variant_module = importlib.import_module("@MODNAME@")
 V1 = variant_module.V1
 V2 = variant_module.V2
 VariantHolder = variant_module.VariantHolder
-VariantNoneHolder = variant_module.VariantNoneHolder
-VariantArithmeticHolder = variant_module.VariantArithmeticHolder
-VariantBoolHolder = variant_module.VariantBoolHolder
+VariantFullHolder = variant_module.VariantFullHolder
 make_variant = variant_module.make_variant
-make_variant_none = variant_module.make_variant_none
-make_variant_arithmetic = variant_module.make_variant_arithmetic
-make_variant_bool = variant_module.make_variant_bool
+make_variant_full = variant_module.make_variant_full
 
 variant = make_variant()
 assert isinstance(variant, V1)
@@ -48,48 +44,40 @@ assert isinstance(variant_holder.variant, V2)
 assert variant_holder.variant.v == v2.v
 
 # Test variant that hold a None value
-v_none = make_variant_none()
-assert v_none is None
+v_full = make_variant_full()
+assert v_full is None
+
+variant_full_holder = VariantFullHolder()
 
-variant_none_holder = VariantNoneHolder()
-v_none = variant_none_holder.variant
+# Test None
+v_none = variant_full_holder.variant
+assert v_none is None
+variant_full_holder.variant = None
 assert v_none is None
 
+# Test V1
 v1 = V1()
-v1.v = 1
-variant_none_holder.variant = v1
-assert variant_none_holder.variant.v == 1
-v1 = variant_none_holder.variant
 v1.v = 10
-assert variant_none_holder.variant.v == 10
-variant_none_holder.variant = None
-
-
-# Test variant that hold base type
-v_arithmetic = make_variant_arithmetic()
-assert isinstance(v_arithmetic, int)
-
-variant_arithmetic_holder = VariantArithmeticHolder()
-assert isinstance(variant_arithmetic_holder.variant, int)
-variant_arithmetic_holder.variant = 2
-# Raise an exception if return_internal_postcall is called
-assert variant_arithmetic_holder.variant == 2
-
-variant_arithmetic_holder.variant = 2.0
-assert isinstance(variant_arithmetic_holder.variant, float)
-# Raise an exception if return_internal_postcall is called
-assert variant_arithmetic_holder.variant == 2.0
-
-v_bool = make_variant_bool()
-assert isinstance(v_bool, bool)
-
-variant_bool_holder = VariantBoolHolder()
-assert isinstance(variant_bool_holder.variant, bool)
-variant_bool_holder.variant = False
-# Raise an exception if return_internal_postcall is called
-assert not variant_bool_holder.variant
-
-variant_bool_holder.variant = 2.0
-assert isinstance(variant_bool_holder.variant, float)
-# Raise an exception if return_internal_postcall is called
-assert variant_bool_holder.variant == 2.0
+variant_full_holder.variant = v1
+assert variant_full_holder.variant.v == 10
+assert isinstance(variant_full_holder.variant, V1)
+# Test V1 ref
+v1 = variant_full_holder.variant
+v1.v = 100
+assert variant_full_holder.variant.v == 100
+variant_full_holder.variant = None
+
+# Test bool
+variant_full_holder.variant = True
+assert variant_full_holder.variant
+assert isinstance(variant_full_holder.variant, bool)
+
+# Test int
+variant_full_holder.variant = 3
+assert variant_full_holder.variant == 3
+assert isinstance(variant_full_holder.variant, int)
+
+# Test float
+variant_full_holder.variant = 3.14
+assert variant_full_holder.variant == 3.14
+assert isinstance(variant_full_holder.variant, float)
diff --git a/unittest/variant.cpp.in b/unittest/variant.cpp.in
index 2f8d8dbd..12f66993 100644
--- a/unittest/variant.cpp.in
+++ b/unittest/variant.cpp.in
@@ -32,35 +32,19 @@ struct MyVariantNoneHelper<std::variant<Alternatives...> > {
 };
 #endif
 
-typedef typename MyVariantNoneHelper<VARIANT<V1> >::type MyVariantNone;
-
-typedef VARIANT<int, double> MyVariantArithmetic;
-
-// There is a conversion conflict between int and bool
-typedef VARIANT<bool, double> MyVariantBool;
+typedef typename MyVariantNoneHelper<VARIANT<V1, bool, int, double> >::type
+    MyVariantFull;
 
 MyVariant make_variant() { return V1(); }
 
-MyVariantNone make_variant_none() { return MyVariantNone(); }
-
-MyVariantArithmetic make_variant_arithmetic() { return MyVariantArithmetic(); }
-
-MyVariantBool make_variant_bool() { return MyVariantBool(); }
+MyVariantFull make_variant_full() { return MyVariantFull(); }
 
 struct VariantHolder {
   MyVariant variant;
 };
 
-struct VariantNoneHolder {
-  MyVariantNone variant;
-};
-
-struct VariantArithmeticHolder {
-  MyVariantArithmetic variant;
-};
-
-struct VariantBoolHolder {
-  MyVariantBool variant;
+struct VariantFullHolder {
+  MyVariantFull variant;
 };
 
 BOOST_PYTHON_MODULE(@MODNAME@) {
@@ -82,37 +66,13 @@ BOOST_PYTHON_MODULE(@MODNAME@) {
                                     Converter::return_internal_reference()),
                     bp::make_setter(&VariantHolder::variant));
 
-  typedef eigenpy::VariantConverter<MyVariantNone> ConverterNone;
-  ConverterNone::registration();
-  bp::def("make_variant_none", make_variant_none);
+  typedef eigenpy::VariantConverter<MyVariantFull> ConverterFull;
+  ConverterFull::registration();
+  bp::def("make_variant_full", make_variant_full);
 
-  boost::python::class_<VariantNoneHolder>("VariantNoneHolder", bp::init<>())
+  boost::python::class_<VariantFullHolder>("VariantFullHolder", bp::init<>())
       .add_property("variant",
-                    bp::make_getter(&VariantNoneHolder::variant,
-                                    ConverterNone::return_internal_reference()),
-                    bp::make_setter(&VariantNoneHolder::variant));
-
-  typedef eigenpy::VariantConverter<MyVariantArithmetic> ConverterArithmetic;
-  ConverterArithmetic::registration();
-  bp::def("make_variant_arithmetic", make_variant_arithmetic);
-
-  boost::python::class_<VariantArithmeticHolder>("VariantArithmeticHolder",
-                                                 bp::init<>())
-      .add_property(
-          "variant",
-          bp::make_getter(&VariantArithmeticHolder::variant,
-                          ConverterArithmetic::return_internal_reference()),
-          bp::make_setter(&VariantArithmeticHolder::variant));
-
-  typedef eigenpy::VariantConverter<MyVariantBool> ConverterBool;
-  ConverterBool::registration();
-  bp::def("make_variant_bool", make_variant_bool);
-
-  boost::python::class_<VariantBoolHolder>("VariantBoolHolder",
-                                                 bp::init<>())
-      .add_property(
-          "variant",
-          bp::make_getter(&VariantBoolHolder::variant,
-                          ConverterBool::return_internal_reference()),
-          bp::make_setter(&VariantBoolHolder::variant));
+                    bp::make_getter(&VariantFullHolder::variant,
+                                    ConverterFull::return_internal_reference()),
+                    bp::make_setter(&VariantFullHolder::variant));
 }
-- 
GitLab