Commit 4d7cb520 by Justin Carpentier

### core: handle matmul operation from NumPy

parent 99454d62
 // // Copyright (c) 2020 INRIA // Copyright (c) 2020-2021 INRIA // code aptapted from https://github.com/numpy/numpy/blob/41977b24ae011a51f64faa75cb524c7350fdedd9/numpy/core/src/umath/_rational_tests.c.src // #ifndef __eigenpy_ufunc_hpp__ #define __eigenpy_ufunc_hpp__ #include "eigenpy/register.hpp" #include "eigenpy/user-type.hpp" namespace eigenpy { ... ... @@ -18,6 +20,77 @@ namespace eigenpy #define EIGENPY_NPY_CONST_UFUNC_ARG #endif template void matrix_multiply(char **args, npy_intp const *dimensions, npy_intp const *steps) { /* pointers to data for input and output arrays */ char *ip1 = args[0]; char *ip2 = args[1]; char *op = args[2]; /* lengths of core dimensions */ npy_intp dm = dimensions[0]; npy_intp dn = dimensions[1]; npy_intp dp = dimensions[2]; /* striding over core dimensions */ npy_intp is1_m = steps[0]; npy_intp is1_n = steps[1]; npy_intp is2_n = steps[2]; npy_intp is2_p = steps[3]; npy_intp os_m = steps[4]; npy_intp os_p = steps[5]; /* core dimensions counters */ npy_intp m, p; /* calculate dot product for each row/column vector pair */ for (m = 0; m < dm; m++) { for (p = 0; p < dp; p++) { SpecialMethods::dotfunc(ip1, is1_n, ip2, is2_n, op, dn, NULL); /* advance to next column of 2nd input array and output array */ ip2 += is2_p; op += os_p; } /* reset to first column of 2nd input array and output array */ ip2 -= is2_p * p; op -= os_p * p; /* advance to next row of 1st input array and output array */ ip1 += is1_m; op += os_m; } } template void gufunc_matrix_multiply(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func)) { /* outer dimensions counter */ npy_intp N_; /* length of flattened outer dimensions */ npy_intp dN = dimensions[0]; /* striding over flattened outer dimensions for input and output arrays */ npy_intp s0 = steps[0]; npy_intp s1 = steps[1]; npy_intp s2 = steps[2]; /* * loop through outer dimensions, performing matrix multiply on * core dimensions for each loop */ for (N_ = 0; N_ < dN; N_++, args[0] += s0, args[1] += s1, args[2] += s2) { matrix_multiply(args, dimensions+1, steps+3); } } #define EIGENPY_REGISTER_BINARY_OPERATOR(name,op) \ template \ void binary_op_##name(char** args, EIGENPY_NPY_CONST_UFUNC_ARG npy_intp * dimensions, EIGENPY_NPY_CONST_UFUNC_ARG npy_intp * steps, void * /*data*/) \ ... ... @@ -127,7 +200,7 @@ namespace eigenpy template void registerCommonUfunc() { const int code = Register::getTypeCode(); const int type_code = Register::getTypeCode(); PyObject* numpy_str; #if PY_MAJOR_VERSION >= 3 ... ... @@ -140,23 +213,48 @@ namespace eigenpy Py_DECREF(numpy_str); import_ufunc(); // Matrix multiply { int types[3] = {type_code,type_code,type_code}; std::stringstream ss; ss << "return result of multiplying two matrices of "; ss << bp::type_info(typeid(Scalar)).name(); PyUFuncObject* ufunc = (PyUFuncObject*)PyObject_GetAttrString(numpy, "matmul"); if(!ufunc) { std::stringstream ss; ss << "Impossible to define matrix_multiply for given type " << bp::type_info(typeid(Scalar)).name() << std::endl; eigenpy::Exception(ss.str()); } if(PyUFunc_RegisterLoopForType((PyUFuncObject*)ufunc, type_code, &internal::gufunc_matrix_multiply, types, 0) < 0) { std::stringstream ss; ss << "Impossible to register matrix_multiply for given type " << bp::type_info(typeid(Scalar)).name() << std::endl; eigenpy::Exception(ss.str()); } Py_DECREF(ufunc); } // Binary operators EIGENPY_REGISTER_BINARY_UFUNC(add,code,Scalar,Scalar,Scalar); EIGENPY_REGISTER_BINARY_UFUNC(subtract,code,Scalar,Scalar,Scalar); EIGENPY_REGISTER_BINARY_UFUNC(multiply,code,Scalar,Scalar,Scalar); EIGENPY_REGISTER_BINARY_UFUNC(divide,code,Scalar,Scalar,Scalar); EIGENPY_REGISTER_BINARY_UFUNC(add,type_code,Scalar,Scalar,Scalar); EIGENPY_REGISTER_BINARY_UFUNC(subtract,type_code,Scalar,Scalar,Scalar); EIGENPY_REGISTER_BINARY_UFUNC(multiply,type_code,Scalar,Scalar,Scalar); EIGENPY_REGISTER_BINARY_UFUNC(divide,type_code,Scalar,Scalar,Scalar); // Comparison operators EIGENPY_REGISTER_BINARY_UFUNC(equal,code,Scalar,Scalar,bool); EIGENPY_REGISTER_BINARY_UFUNC(not_equal,code,Scalar,Scalar,bool); EIGENPY_REGISTER_BINARY_UFUNC(greater,code,Scalar,Scalar,bool); EIGENPY_REGISTER_BINARY_UFUNC(less,code,Scalar,Scalar,bool); EIGENPY_REGISTER_BINARY_UFUNC(greater_equal,code,Scalar,Scalar,bool); EIGENPY_REGISTER_BINARY_UFUNC(less_equal,code,Scalar,Scalar,bool); EIGENPY_REGISTER_BINARY_UFUNC(equal,type_code,Scalar,Scalar,bool); EIGENPY_REGISTER_BINARY_UFUNC(not_equal,type_code,Scalar,Scalar,bool); EIGENPY_REGISTER_BINARY_UFUNC(greater,type_code,Scalar,Scalar,bool); EIGENPY_REGISTER_BINARY_UFUNC(less,type_code,Scalar,Scalar,bool); EIGENPY_REGISTER_BINARY_UFUNC(greater_equal,type_code,Scalar,Scalar,bool); EIGENPY_REGISTER_BINARY_UFUNC(less_equal,type_code,Scalar,Scalar,bool); // Unary operators EIGENPY_REGISTER_UNARY_UFUNC(negative,code,Scalar,Scalar); EIGENPY_REGISTER_UNARY_UFUNC(negative,type_code,Scalar,Scalar); Py_DECREF(numpy); } ... ...
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!