user-type.hpp 7.19 KB
Newer Older
1
2
3
4
5
6
7
8
9
//
// Copyright (c) 2020 INRIA
//

#ifndef __eigenpy_user_type_hpp__
#define __eigenpy_user_type_hpp__

#include "eigenpy/fwd.hpp"
#include "eigenpy/numpy-type.hpp"
Justin Carpentier's avatar
Justin Carpentier committed
10
#include "eigenpy/register.hpp"
11
12
13
14
15
16
17
18

namespace eigenpy
{
  namespace internal
  {
    /// Generic declaration of the NumPy array-function callbacks (the slots of
    /// PyArray_ArrFuncs) for a scalar type T whose NumPy type code is type_code.
    ///
    /// The bodies are commented out: these members are declarations only, and
    /// the NPY_USERDEF specialization is the one providing definitions.
    /// NOTE(review): presumably these are never called for NumPy-native types
    /// (registerNewType returns early for them) — confirm no translation unit
    /// odr-uses them, since an inline function must be defined wherever it is used.
    template<typename T, int type_code = NumpyEquivalentType<T>::type_code>
    struct SpecialMethods
    {
      inline static void copyswap(void * /*dst*/, void * /*src*/, int /*swap*/, void * /*arr*/) /*{}*/;
      inline static PyObject * getitem(void * /*ip*/, void * /*ap*/) /*{ return NULL; }*/;
      inline static int setitem(PyObject * /*op*/, void * /*ov*/, void * /*ap*/) /*{ return -1; }*/;
      // Strides and count are npy_intp to match PyArray_CopySwapNFunc exactly
      // (long is narrower than npy_intp on LLP64 platforms such as 64-bit Windows).
      inline static void copyswapn(void * /*dest*/, npy_intp /*dstride*/, void * /*src*/,
                                   npy_intp /*sstride*/, npy_intp /*n*/, int /*swap*/, void * /*arr*/) /*{}*/;
      inline static npy_bool nonzero(void * /*ip*/, void * /*array*/) /*{ return (npy_bool)false; }*/;
      inline static void dotfunc(void * /*ip0_*/, npy_intp /*is0*/, void * /*ip1_*/, npy_intp /*is1*/,
                                 void * /*op*/, npy_intp /*n*/, void * /*arr*/);
//      static void cast(void * /*from*/, void * /*to*/, npy_intp /*n*/, void * /*fromarr*/, void * /*toarr*/) {};
    };
  
    template<typename T>
    struct SpecialMethods<T,NPY_USERDEF>
    {
Justin Carpentier's avatar
Justin Carpentier committed
33
      inline static void copyswap(void * dst, void * src, int swap, void * /*arr*/)
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
      {
//        std::cout << "copyswap" << std::endl;
        if (src != NULL)
        {
          T & t1 = *static_cast<T*>(dst);
          T & t2 = *static_cast<T*>(src);
          t1 = t2;
        }
          
        if(swap)
        {
          T & t1 = *static_cast<T*>(dst);
          T & t2 = *static_cast<T*>(src);
          std::swap(t1,t2);
        }
      }
      
Justin Carpentier's avatar
Justin Carpentier committed
51
      inline static PyObject * getitem(void * ip, void * ap)
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
      {
//        std::cout << "getitem" << std::endl;
        PyArrayObject * py_array = static_cast<PyArrayObject *>(ap);
        if((py_array==NULL) || PyArray_ISBEHAVED_RO(py_array))
        {
          T * elt_ptr = static_cast<T*>(ip);
          bp::object m(boost::ref(*elt_ptr));
          Py_INCREF(m.ptr());
          return m.ptr();
        }
        else
        {
          T * elt_ptr = static_cast<T*>(ip);
          bp::object m(boost::ref(*elt_ptr));
          Py_INCREF(m.ptr());
          return m.ptr();
        }
      }
      
Justin Carpentier's avatar
Justin Carpentier committed
71
      inline static int setitem(PyObject * src_obj, void * dest_ptr, void * array)
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
      {
//        std::cout << "setitem" << std::endl;
        if(array == NULL)
        {
          eigenpy::Exception("Cannot retrieve the type stored in the array.");
          return -1;
        }
        PyArrayObject * py_array = static_cast<PyArrayObject *>(array);
        PyArray_Descr * descr = PyArray_DTYPE(py_array);
        PyTypeObject * array_scalar_type = descr->typeobj;
        PyTypeObject * src_obj_type = Py_TYPE(src_obj);
        
        if(array_scalar_type != src_obj_type)
        {
          return -1;
        }
        
        bp::extract<T&> extract_src_obj(src_obj);
        if(!extract_src_obj.check())
        {
          std::stringstream ss;
          ss << "The input type is of wrong type. ";
          ss << "The expected type is " << bp::type_info(typeid(T)).name() << std::endl;
          eigenpy::Exception(ss.str());
          return -1;
        }
        
        const T & src = extract_src_obj();
        T & dest = *static_cast<T*>(dest_ptr);
        dest = src;

        return 0;
      }
      
Justin Carpentier's avatar
Justin Carpentier committed
106
107
      inline static void copyswapn(void * dst, long dstride, void * src, long sstride,
                                   long n, int swap, void * array)
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
      {
//        std::cout << "copyswapn" << std::endl;
        
        char *dstptr = static_cast<char*>(dst);
        char *srcptr = static_cast<char*>(src);
        
        PyArrayObject * py_array = static_cast<PyArrayObject *>(array);
        PyArray_CopySwapFunc * copyswap = PyArray_DESCR(py_array)->f->copyswap;
        
        for (npy_intp i = 0; i < n; i++)
        {
          copyswap(dstptr, srcptr, swap, array);
          dstptr += dstride;
          srcptr += sstride;
        }
      }
      
Justin Carpentier's avatar
Justin Carpentier committed
125
      inline static npy_bool nonzero(void * ip, void * array)
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
      {
//        std::cout << "nonzero" << std::endl;
        static const T ZeroValue = T(0);
        PyArrayObject * py_array = static_cast<PyArrayObject *>(array);
        if(py_array == NULL || PyArray_ISBEHAVED_RO(py_array))
        {
          const T & value = *static_cast<T*>(ip);
          return (npy_bool)(value != ZeroValue);
        }
        else
        {
          T tmp_value;
          PyArray_DESCR(py_array)->f->copyswap(&tmp_value, ip, PyArray_ISBYTESWAPPED(py_array),
                                               array);
          return (npy_bool)(tmp_value != ZeroValue);
        }
      }
      
Justin Carpentier's avatar
Justin Carpentier committed
144
145
      inline static void dotfunc(void * ip0_, npy_intp is0, void * ip1_, npy_intp is1,
                                 void * op, npy_intp n, void * /*arr*/)
Justin Carpentier's avatar
Justin Carpentier committed
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
      {
          T res = T(0);
          char *ip0 = (char*)ip0_, *ip1 = (char*)ip1_;
          npy_intp i;
          for(i = 0; i < n; i++)
          {
            
            res += *static_cast<T*>(static_cast<void*>(ip0))
            * *static_cast<T*>(static_cast<void*>(ip1));
            ip0 += is0;
            ip1 += is1;
          }
          *static_cast<T*>(op) = res;
      }
      
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
//      static void cast(void * from, void * to, npy_intp n, void * fromarr, void * toarr)
//      {
//      }

    };
  
  } // namespace internal

  /// \brief Register Scalar as a NumPy user-defined type and return its type code.
  ///
  /// \param[in] py_type_ptr Python type to associate with Scalar. When NULL, it is
  ///                        retrieved with Register::getPyType<Scalar>().
  ///
  /// \returns One of the following:
  ///          - NumpyEquivalentType<Scalar>::type_code for NumPy-native scalars;
  ///          - the code already assigned when py_type_ptr is already registered;
  ///          - otherwise, the code returned by Register::registerNewType.
  ///          In the last case, call_PyArray_RegisterCanCast is also invoked from
  ///          NPY_OBJECT to the new type code with NPY_NOSCALAR.
  template<typename Scalar>
  int registerNewType(PyTypeObject * py_type_ptr = NULL)
  {
    typedef internal::SpecialMethods<Scalar> Methods;

    // NumPy already handles its native scalar types: nothing to register.
    if(isNumpyNativeType<Scalar>())
      return NumpyEquivalentType<Scalar>::type_code;

    // Without an explicit Python type, use the one known to Boost.Python.
    if(py_type_ptr == NULL)
      py_type_ptr = Register::getPyType<Scalar>();

    // Reuse the type code from a previous registration.
    if(Register::isRegistered(py_type_ptr))
      return Register::getTypeCode(py_type_ptr);

    PyArray_GetItemFunc * getitem_fn = &Methods::getitem;
    PyArray_SetItemFunc * setitem_fn = &Methods::setitem;
    PyArray_NonzeroFunc * nonzero_fn = &Methods::nonzero;
    PyArray_CopySwapFunc * copyswap_fn = &Methods::copyswap;
    // The cast lets copyswapn be registered even when its stride/count parameters
    // are declared with a type other than npy_intp.
    // NOTE(review): calling through the cast pointer is only well defined when the
    // declared signature really matches PyArray_CopySwapNFunc.
    PyArray_CopySwapNFunc * copyswapn_fn
      = reinterpret_cast<PyArray_CopySwapNFunc*>(&Methods::copyswapn);
    PyArray_DotFunc * dotfunc_fn = &Methods::dotfunc;
//    PyArray_CastFunc * cast = &internal::SpecialMethods<Scalar>::cast;

    const int type_code = Register::registerNewType(py_type_ptr,
                                                    &typeid(Scalar),
                                                    sizeof(Scalar),
                                                    getitem_fn, setitem_fn, nonzero_fn,
                                                    copyswap_fn, copyswapn_fn,
                                                    dotfunc_fn);

    // Presumably forwards to PyArray_RegisterCanCast: declares NPY_OBJECT -> new type.
    call_PyArray_RegisterCanCast(call_PyArray_DescrFromType(NPY_OBJECT),
                                 type_code, NPY_NOSCALAR);

    return type_code;
  }
  
} // namespace eigenpy

#endif // __eigenpy_user_type_hpp__