From 4d7cb520fdc37cb7db8b5eaffa6d41c3275a775b Mon Sep 17 00:00:00 2001
From: Justin Carpentier <justin.carpentier@inria.fr>
Date: Tue, 3 Aug 2021 15:44:44 +0200
Subject: [PATCH] core: handle matmul operation from NumPy

---
 include/eigenpy/ufunc.hpp | 124 ++++++++++++++++++++++++++++++++++----
 1 file changed, 111 insertions(+), 13 deletions(-)

diff --git a/include/eigenpy/ufunc.hpp b/include/eigenpy/ufunc.hpp
index 2647f8cf..1dbe1966 100644
--- a/include/eigenpy/ufunc.hpp
+++ b/include/eigenpy/ufunc.hpp
@@ -1,11 +1,13 @@
 //
-// Copyright (c) 2020 INRIA
+// Copyright (c) 2020-2021 INRIA
+// code adapted from https://github.com/numpy/numpy/blob/41977b24ae011a51f64faa75cb524c7350fdedd9/numpy/core/src/umath/_rational_tests.c.src
 //
 
 #ifndef __eigenpy_ufunc_hpp__
 #define __eigenpy_ufunc_hpp__
 
 #include "eigenpy/register.hpp"
+#include "eigenpy/user-type.hpp"
 
 namespace eigenpy
 {
@@ -18,6 +20,77 @@ namespace eigenpy
   #define EIGENPY_NPY_CONST_UFUNC_ARG
 #endif
   
+  /// Core kernel of the matmul loop: fills an (m x p) output from an (m x n)
+  /// and an (n x p) input, one SpecialMethods<T>::dotfunc call per output entry.
+  /// args holds the two input pointers then the output pointer; dimensions holds
+  /// the core sizes (m, n, p); steps holds the six core strides, all in bytes.
+  template<typename T>
+  void matrix_multiply(char **args, npy_intp const *dimensions, npy_intp const *steps)
+  {
+    /* core sizes: lhs is (rows x inner), rhs is (inner x cols) */
+    const npy_intp rows = dimensions[0];
+    const npy_intp inner = dimensions[1];
+    const npy_intp cols = dimensions[2];
+    
+    /* strides of the first input */
+    const npy_intp lhs_row_step = steps[0];
+    const npy_intp lhs_inner_step = steps[1];
+    /* strides of the second input */
+    const npy_intp rhs_inner_step = steps[2];
+    const npy_intp rhs_col_step = steps[3];
+    /* strides of the output */
+    const npy_intp out_row_step = steps[4];
+    const npy_intp out_col_step = steps[5];
+    
+    /* cursors at the start of the current row of lhs and of the output */
+    char *lhs_row = args[0];
+    char *out_row = args[2];
+    
+    for (npy_intp r = 0; r < rows; ++r)
+    {
+      /* restart from the first column of rhs and of the current output row */
+      char *rhs_col = args[1];
+      char *out = out_row;
+      
+      for (npy_intp c = 0; c < cols; ++c)
+      {
+        /* out[r,c] = dot(lhs row r, rhs column c) over `inner` elements */
+        SpecialMethods<T>::dotfunc(lhs_row, lhs_inner_step, rhs_col, rhs_inner_step, out, inner, NULL);
+        rhs_col += rhs_col_step;
+        out += out_col_step;
+      }
+      
+      /* move on to the next row of lhs and of the output */
+      lhs_row += lhs_row_step;
+      out_row += out_row_step;
+    }
+  }
+  
+  /// gufunc inner loop: calls matrix_multiply<T> once per outer iteration, where
+  /// dimensions[0] is the outer count and steps[0..2] are the outer byte strides.
+  /// dimensions/steps are qualified with EIGENPY_NPY_CONST_UFUNC_ARG, exactly like
+  /// the binary_op_* loops of this header, so all loops share one signature.
+  template<typename T>
+  void gufunc_matrix_multiply(char **args,
+                              EIGENPY_NPY_CONST_UFUNC_ARG npy_intp *dimensions,
+                              EIGENPY_NPY_CONST_UFUNC_ARG npy_intp *steps,
+                              void *NPY_UNUSED(func))
+  {
+    /* length of flattened outer dimensions */
+    const npy_intp dN = dimensions[0];
+    
+    /* striding over flattened outer dimensions for input and output arrays */
+    const npy_intp s0 = steps[0];
+    const npy_intp s1 = steps[1];
+    const npy_intp s2 = steps[2];
+    
+    /* one core product per outer index; args[0..2] are advanced in place */
+    for (npy_intp N_ = 0; N_ < dN; N_++, args[0] += s0, args[1] += s1, args[2] += s2)
+    {
+      matrix_multiply<T>(args, dimensions+1, steps+3);
+    }
+  }
+  
 #define EIGENPY_REGISTER_BINARY_OPERATOR(name,op) \
     template<typename T1, typename T2, typename R> \
     void binary_op_##name(char** args, EIGENPY_NPY_CONST_UFUNC_ARG npy_intp * dimensions, EIGENPY_NPY_CONST_UFUNC_ARG npy_intp * steps, void * /*data*/) \
@@ -127,7 +200,7 @@ namespace eigenpy
   template<typename Scalar>
   void registerCommonUfunc()
   {
-    const int code = Register::getTypeCode<Scalar>();
+    const int type_code = Register::getTypeCode<Scalar>();
   
     PyObject* numpy_str;
 #if PY_MAJOR_VERSION >= 3
@@ -140,23 +213,48 @@ namespace eigenpy
     Py_DECREF(numpy_str);
     
     import_ufunc();
+    
+    // Register the matmul loop for Scalar on numpy.matmul. Failures must be
+    // thrown: a bare `eigenpy::Exception(...)` only builds a discarded temporary.
+    {
+      int types[3] = {type_code,type_code,type_code};
+
+      PyUFuncObject* ufunc = (PyUFuncObject*)PyObject_GetAttrString(numpy, "matmul");
+      if(!ufunc)
+      {
+        std::stringstream ss;
+        ss << "Impossible to define matrix_multiply for given type " << bp::type_info(typeid(Scalar)).name() << std::endl;
+        throw eigenpy::Exception(ss.str());
+      }
+      if(PyUFunc_RegisterLoopForType((PyUFuncObject*)ufunc, type_code,
+                                     &internal::gufunc_matrix_multiply<Scalar>, types, 0) < 0)
+      {
+        // Release the reference obtained above before leaving the scope.
+        Py_DECREF(ufunc);
+        std::stringstream ss;
+        ss << "Impossible to register matrix_multiply for given type " << bp::type_info(typeid(Scalar)).name() << std::endl;
+        throw eigenpy::Exception(ss.str());
+      }
+
+      Py_DECREF(ufunc);
+    }
 
     // Binary operators
-    EIGENPY_REGISTER_BINARY_UFUNC(add,code,Scalar,Scalar,Scalar);
-    EIGENPY_REGISTER_BINARY_UFUNC(subtract,code,Scalar,Scalar,Scalar);
-    EIGENPY_REGISTER_BINARY_UFUNC(multiply,code,Scalar,Scalar,Scalar);
-    EIGENPY_REGISTER_BINARY_UFUNC(divide,code,Scalar,Scalar,Scalar);
+    EIGENPY_REGISTER_BINARY_UFUNC(add,type_code,Scalar,Scalar,Scalar);
+    EIGENPY_REGISTER_BINARY_UFUNC(subtract,type_code,Scalar,Scalar,Scalar);
+    EIGENPY_REGISTER_BINARY_UFUNC(multiply,type_code,Scalar,Scalar,Scalar);
+    EIGENPY_REGISTER_BINARY_UFUNC(divide,type_code,Scalar,Scalar,Scalar);
   
     // Comparison operators
-    EIGENPY_REGISTER_BINARY_UFUNC(equal,code,Scalar,Scalar,bool);
-    EIGENPY_REGISTER_BINARY_UFUNC(not_equal,code,Scalar,Scalar,bool);
-    EIGENPY_REGISTER_BINARY_UFUNC(greater,code,Scalar,Scalar,bool);
-    EIGENPY_REGISTER_BINARY_UFUNC(less,code,Scalar,Scalar,bool);
-    EIGENPY_REGISTER_BINARY_UFUNC(greater_equal,code,Scalar,Scalar,bool);
-    EIGENPY_REGISTER_BINARY_UFUNC(less_equal,code,Scalar,Scalar,bool);
+    EIGENPY_REGISTER_BINARY_UFUNC(equal,type_code,Scalar,Scalar,bool);
+    EIGENPY_REGISTER_BINARY_UFUNC(not_equal,type_code,Scalar,Scalar,bool);
+    EIGENPY_REGISTER_BINARY_UFUNC(greater,type_code,Scalar,Scalar,bool);
+    EIGENPY_REGISTER_BINARY_UFUNC(less,type_code,Scalar,Scalar,bool);
+    EIGENPY_REGISTER_BINARY_UFUNC(greater_equal,type_code,Scalar,Scalar,bool);
+    EIGENPY_REGISTER_BINARY_UFUNC(less_equal,type_code,Scalar,Scalar,bool);
   
     // Unary operators
-    EIGENPY_REGISTER_UNARY_UFUNC(negative,code,Scalar,Scalar);
+    EIGENPY_REGISTER_UNARY_UFUNC(negative,type_code,Scalar,Scalar);
 
     Py_DECREF(numpy);
   }
-- 
GitLab