register.cpp 3.75 KB
Newer Older
1
/*
 * Copyright 2020-2021 INRIA
 */

Justin Carpentier's avatar
Justin Carpentier committed
5
#include "eigenpy/register.hpp"
6
7
8
9

namespace eigenpy
{

10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
  // Look up the NumPy descriptor recorded for py_type_ptr by registerNewType.
  // Returns NULL when the Python type has no entry in the registry.
  PyArray_Descr * Register::getPyArrayDescr(PyTypeObject * py_type_ptr)
  {
    const MapDescr & descr_bindings = instance().py_array_descr_bindings;
    const MapDescr::const_iterator found = descr_bindings.find(py_type_ptr);
    if(found == descr_bindings.end())
      return NULL;
    return found->second;
  }
  
  // A Python type counts as registered exactly when a descriptor is stored for it.
  bool Register::isRegistered(PyTypeObject * py_type_ptr)
  {
    return getPyArrayDescr(py_type_ptr) != NULL;
  }
  
  // Return the NumPy type number recorded for py_type_ptr by registerNewType.
  // Types that are not in the registry fall back to PyArray_TypeNum(py_type_ptr).
  int Register::getTypeCode(PyTypeObject * py_type_ptr)
  {
    const MapCode & code_bindings = instance().py_array_code_bindings;
    const MapCode::const_iterator found = code_bindings.find(py_type_ptr);
    if(found == code_bindings.end())
      return PyArray_TypeNum(py_type_ptr);
    return found->second;
  }

  int Register::registerNewType(PyTypeObject * py_type_ptr,
                                const std::type_info * type_info_ptr,
                                const int type_size,
39
                                const int alignement,
40
41
42
43
44
                                PyArray_GetItemFunc * getitem,
                                PyArray_SetItemFunc * setitem,
                                PyArray_NonzeroFunc * nonzero,
                                PyArray_CopySwapFunc * copyswap,
                                PyArray_CopySwapNFunc * copyswapn,
45
                                PyArray_DotFunc * dotfunc,
Justin Carpentier's avatar
Justin Carpentier committed
46
                                PyArray_FillFunc * fill,
47
                                PyArray_FillWithScalarFunc * fillwithscalar)
48
  {
49
50
51
52
53
54
55
56
57
58
59
60
61
62
    namespace bp = boost::python;
    bp::list bases(bp::handle<>(bp::borrowed(py_type_ptr->tp_bases)));
    bases.append((bp::handle<>(bp::borrowed(&PyGenericArrType_Type))));

    bp::tuple tp_bases_extended(bases);
    Py_INCREF(tp_bases_extended.ptr());
    py_type_ptr->tp_bases = tp_bases_extended.ptr();

    py_type_ptr->tp_flags &= ~Py_TPFLAGS_READY; // to force the rebuild
    if(PyType_Ready(py_type_ptr) < 0) // Force rebuilding of the __bases__ and mro
    {
      throw std::invalid_argument("PyType_Ready fails to initialize input type.");
    }

63
64
65
66
67
    PyArray_Descr * descr_ptr = new PyArray_Descr(*call_PyArray_DescrFromType(NPY_OBJECT));
    PyArray_Descr & descr = *descr_ptr;
    descr.typeobj = py_type_ptr;
    descr.kind = 'V';
    descr.byteorder = '=';
68
    descr.type = 'r';
69
    descr.elsize = type_size;
70
71
72
73
    descr.flags = NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM;
    descr.type_num = 0;
    descr.names = 0;
    descr.fields = 0;
74
    descr.alignment = alignement; //call_PyArray_DescrFromType(NPY_OBJECT)->alignment;
75
76
77
78
79
80
81
82
83
84
85
    
    PyArray_ArrFuncs * funcs_ptr = new PyArray_ArrFuncs;
    PyArray_ArrFuncs & funcs = *funcs_ptr;
    descr.f = funcs_ptr;
    call_PyArray_InitArrFuncs(funcs_ptr);
    funcs.getitem = getitem;
    funcs.setitem = setitem;
    funcs.nonzero = nonzero;
    funcs.copyswap = copyswap;
    funcs.copyswapn = copyswapn;
    funcs.dotfunc = dotfunc;
Justin Carpentier's avatar
Justin Carpentier committed
86
    funcs.fill = fill;
87
    funcs.fillwithscalar = fillwithscalar;
88
89
90
91
92
    //      f->cast = cast;
    
    const int code = call_PyArray_RegisterDataType(descr_ptr);
    assert(code >= 0 && "The return code should be positive");
    PyArray_Descr * new_descr = call_PyArray_DescrFromType(code);
93
94
95
96
97

    if(PyDict_SetItemString(py_type_ptr->tp_dict,"dtype",(PyObject*)descr_ptr) < 0)
    {
      throw std::invalid_argument("PyDict_SetItemString fails.");
    }
98
99
100
101
102
103
104
105
106
107
108
109
110
111
    
    instance().type_to_py_type_bindings.insert(std::make_pair(type_info_ptr,py_type_ptr));
    instance().py_array_descr_bindings[py_type_ptr] = new_descr;
    instance().py_array_code_bindings[py_type_ptr] = code;
    
    //      PyArray_RegisterCanCast(descr,NPY_OBJECT,NPY_NOSCALAR);
    return code;
  }

  // Access the single Register object shared by all the static helpers above.
  // It is a function-local static, constructed on the first call.
  Register & Register::instance()
  {
    static Register the_registry;
    return the_registry;
  }
112
113

} // namespace eigenpy