From dddbf6d2db938e42ad9fcf8511de1952175dded1 Mon Sep 17 00:00:00 2001 From: Justin Carpentier <justin.carpentier@inria.fr> Date: Wed, 7 Dec 2022 19:29:23 +0100 Subject: [PATCH] test/ref: enforce testing of vector blocks --- unittest/eigen_ref.cpp | 9 +++++++++ unittest/python/test_eigen_ref.py | 15 +++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/unittest/eigen_ref.cpp b/unittest/eigen_ref.cpp index 9c334d10..6934f715 100644 --- a/unittest/eigen_ref.cpp +++ b/unittest/eigen_ref.cpp @@ -30,6 +30,12 @@ void setOnes(Eigen::Ref<MatType> mat) { mat.setOnes(); } +template <typename VecType> +VecType copyVectorFromConstRef(const Eigen::Ref<const VecType> vec) { + std::cout << "copyVectorFromConstRef::vec: " << vec.transpose() << std::endl; + return VecType(vec); +} + template <typename MatType> Eigen::Ref<MatType> getBlock(Eigen::Ref<MatType> mat, Eigen::DenseIndex i, Eigen::DenseIndex j, Eigen::DenseIndex n, @@ -120,6 +126,9 @@ BOOST_PYTHON_MODULE(eigen_ref) { bp::def("getBlock", &getBlock<MatrixXd>); bp::def("editBlock", &editBlock<MatrixXd>); + bp::def("copyVectorFromConstRef", ©VectorFromConstRef<VectorXd>); + bp::def("copyRowVectorFromConstRef", ©VectorFromConstRef<RowVectorXd>); + bp::class_<modify_block_wrap, boost::noncopyable>("modify_block", bp::init<>()) .def_readonly("J", &modify_block::J) diff --git a/unittest/python/test_eigen_ref.py b/unittest/python/test_eigen_ref.py index fe43e501..8d466f1d 100644 --- a/unittest/python/test_eigen_ref.py +++ b/unittest/python/test_eigen_ref.py @@ -9,6 +9,8 @@ from eigen_ref import ( editBlock, modify_block, has_ref_member, + copyVectorFromConstRef, + copyRowVectorFromConstRef, ) @@ -44,6 +46,18 @@ def test_create_ref_to_static(mat): assert np.array_equal(A_ref, A_ref2) +def test_read_block(): + data = np.array([[0, 0.2, 0.3, 0.4], [0, 1, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0]]) + + data_strided = data[:, 0] + + data_strided_copy = copyVectorFromConstRef(data_strided) + assert np.all(data_strided == data_strided_copy) + + data_strided_copy = copyRowVectorFromConstRef(data_strided) + assert np.all(data_strided == data_strided_copy) + + def test_create_ref(mat): print("[asRef(mat)]") ref = asRef(mat) @@ -116,6 +130,7 @@ def do_test(mat): test_create_const_ref(mat) test_create_ref(mat) test_edit_block(rows, cols) + test_read_block() print("=" * 10) -- GitLab