# test_user_type.py
1
import user_type
2
import numpy as np
3
#from packaging import version
4
5
6
7

# Shape (rows, cols) of every matrix built by test() and test_cast().
rows, cols = 10, 20

8
9
def test(dtype):
  """Exercise copy, comparison and arithmetic on a matrix of *dtype*.

  Builds a ``rows`` x ``cols`` matrix of ones and checks that:
  a copy compares equal element-wise; ``+``, ``-``, ``*``, ``/``, unary
  ``-`` and ``dot`` produce the values/shape expected for an all-ones
  input; and the ordering comparisons are consistent.

  :param dtype: NumPy dtype or scalar type to test (e.g. a type from the
      ``user_type`` module such as ``user_type.CustomDouble``).
  """
  mat = np.ones((rows, cols), dtype=dtype)

  # Copy and equality. Use .any() for the negative checks: the previous
  # ``not (...).all()`` form only failed if *every* element differed.
  mat_copy = mat.copy()
  assert (mat == mat_copy).all()
  assert not (mat != mat_copy).any()

  # TODO: filling with a non-trivial value and re-comparing is disabled;
  # re-enable once it is confirmed to work on newer NumPy (>= 1.21.0 was
  # the candidate). Needs `from packaging import version`.
  # if version.parse(np.__version__) >= version.parse("1.21.0"):
  #   mat.fill(mat.dtype.type(20.))
  #   mat_copy = mat.copy()
  #   assert (mat == mat_copy).all()
  #   assert not (mat != mat_copy).any()

  # Arithmetic on an all-ones matrix. Results are checked only against
  # arrays of the same dtype, so no cross-dtype casting is required.
  mat_sum = mat + mat
  # Mixed memory layouts (Fortran + C order) must give the same result.
  assert (mat_sum == mat.copy(order='F') + mat.copy(order='C')).all()
  assert (mat_sum != mat).all()            # 1 + 1 != 1

  mat_diff = mat - mat
  assert (mat_diff == mat + (-mat)).all()  # a - b == a + (-b)
  assert (mat_diff != mat).all()           # 1 - 1 != 1

  assert (mat * mat == mat).all()          # 1 * 1 == 1
  assert (mat / mat == mat).all()          # 1 / 1 == 1
  assert mat.dot(mat.T).shape == (rows, rows)

  # Ordering comparisons, against itself and against distinct values.
  assert (mat >= mat).all()
  assert (mat <= mat).all()
  assert not (mat > mat).any()
  assert not (mat < mat).any()
  assert (mat_sum > mat).all()
  assert (mat_diff < mat).all()

35
36
def test_cast(from_dtype, to_dtype):
  """Check that a ``rows`` x ``cols`` zero matrix converts between dtypes.

  ``np.can_cast`` is called only to make sure the query itself does not
  raise for these (possibly user-registered) dtypes. Its result is not
  asserted: with the default ``casting='safe'``, some pairs exercised by
  this file (e.g. a double type to ``np.int32``) may legitimately report
  ``False``.

  :param from_dtype: source NumPy dtype or scalar type.
  :param to_dtype: destination NumPy dtype or scalar type.
  """
  np.can_cast(from_dtype, to_dtype)

  from_mat = np.zeros((rows, cols), dtype=from_dtype)
  to_mat = from_mat.astype(dtype=to_dtype)

  # astype must honour the requested dtype and preserve the shape.
  assert to_mat.dtype == np.dtype(to_dtype)
  assert to_mat.shape == from_mat.shape
  # Zero must convert to zero in the target type.
  assert (to_mat == np.zeros(from_mat.shape, dtype=to_dtype)).all()
  
# Driver: check user_type.CustomDouble, its casts to and from a few
# built-in NumPy numeric types, then user_type.CustomFloat.
test(user_type.CustomDouble)

for _builtin in (np.double, np.int64, np.int32):
  test_cast(user_type.CustomDouble, _builtin)
  test_cast(_builtin, user_type.CustomDouble)

test(user_type.CustomFloat)

# Wrapping a CustomDouble scalar with np.array must yield an array whose
# dtype's scalar type is exactly the scalar's type.
v = user_type.CustomDouble(1)
a = np.array(v)
assert type(v) == a.dtype.type