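"""Checks for the ``user_type`` extension module (test_user_type.py).

Exercises NumPy copy/comparison/arithmetic on matrices built by
``user_type.create_double`` and ``user_type.create_float``, and checks the
dtype NumPy infers for a ``user_type.CustomDouble`` scalar.
"""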
import numpy as np

import user_type

# from packaging import version  # only needed by the disabled NumPy-version check in test()

# Dimensions passed to the user_type matrix factories further down.
rows = 10
cols = 20

def test(mat):
    """Run NumPy copy, comparison and arithmetic checks on ``mat``.

    Every element of ``mat`` is first overwritten in place with ``10``
    converted through ``mat.dtype.type``, so the caller's array is modified.
    Comparison results are asserted. Arithmetic results are only required to
    be computable without raising.

    Args:
        mat: NumPy array to exercise (``mat.T`` and ``mat.dot`` are used);
            presumably one whose dtype is a user-defined scalar type.

    Raises:
        AssertionError: if copying or a comparison operator misbehaves.
    """
    # Fill through the dtype's own scalar type so that the conversion
    # from a Python float into that type is exercised as well.
    mat[:] = mat.dtype.type(10.0)

    mat_copy = mat.copy()
    assert (mat == mat_copy).all()
    # `.any()`, not `.all()`: a faithful copy must differ in *no* element.
    assert not (mat != mat_copy).any()

    # TODO: re-enable once ndarray.fill() is confirmed to work with user
    # dtypes on newer NumPy (requires `from packaging import version`):
    # if version.parse(np.__version__) >= version.parse("1.21.0"):
    #     mat.fill(mat.dtype.type(20.0))
    #     mat_copy = mat.copy()
    #     assert (mat == mat_copy).all()
    #     assert not (mat != mat_copy).any()

    # Smoke-test the arithmetic operators and dot product. Only the absence
    # of an exception is checked, so the results are deliberately discarded.
    _ = mat + mat
    # Mixed memory layouts (Fortran- vs C-ordered operands).
    _ = mat.copy(order="F") + mat.copy(order="C")
    _ = mat - mat
    _ = mat * mat
    _ = mat.dot(mat.T)
    _ = mat / mat
    _ = -mat

    # Non-strict comparisons of an array with itself hold everywhere...
    assert (mat >= mat).all()
    assert (mat <= mat).all()
    # ...and strict ones hold nowhere.
    assert not (mat > mat).any()
    assert not (mat < mat).any()


# Not a pytest test: this is a helper invoked explicitly at module level.
# Without this flag pytest collects it and errors on a missing `mat` fixture.
test.__test__ = False

# Run the checks on matrices from both user_type factory functions.
mat = user_type.create_double(rows, cols)
test(mat)

mat = user_type.create_float(rows, cols)
test(mat)

# An array built from a CustomDouble scalar must use exactly that scalar
# type as its dtype's type (identity check, not mere equality).
v = user_type.CustomDouble(1)
a = np.array(v)
assert a.dtype.type is type(v), f"{a.dtype.type!r} is not {type(v)!r}"