diff --git a/unittest/eigen_ref.cpp b/unittest/eigen_ref.cpp index 9c334d10bbb748ed0bd21f94c6580f6787b3dd3e..6934f71575c460c79adb2353b931b6603aef27a9 100644 --- a/unittest/eigen_ref.cpp +++ b/unittest/eigen_ref.cpp @@ -30,6 +30,12 @@ void setOnes(Eigen::Ref<MatType> mat) { mat.setOnes(); } +template <typename VecType> +VecType copyVectorFromConstRef(const Eigen::Ref<const VecType> vec) { + std::cout << "copyVectorFromConstRef::vec: " << vec.transpose() << std::endl; + return VecType(vec); +} + template <typename MatType> Eigen::Ref<MatType> getBlock(Eigen::Ref<MatType> mat, Eigen::DenseIndex i, Eigen::DenseIndex j, Eigen::DenseIndex n, @@ -120,6 +126,9 @@ BOOST_PYTHON_MODULE(eigen_ref) { bp::def("getBlock", &getBlock<MatrixXd>); bp::def("editBlock", &editBlock<MatrixXd>); + bp::def("copyVectorFromConstRef", ©VectorFromConstRef<VectorXd>); + bp::def("copyRowVectorFromConstRef", ©VectorFromConstRef<RowVectorXd>); + bp::class_<modify_block_wrap, boost::noncopyable>("modify_block", bp::init<>()) .def_readonly("J", &modify_block::J) diff --git a/unittest/python/test_eigen_ref.py b/unittest/python/test_eigen_ref.py index fe43e501fd28abc13c7a4fab257f7ea7ccab1faa..8d466f1d85d2d32e9a3cf9dc80aecd19a15303f1 100644 --- a/unittest/python/test_eigen_ref.py +++ b/unittest/python/test_eigen_ref.py @@ -9,6 +9,8 @@ from eigen_ref import ( editBlock, modify_block, has_ref_member, + copyVectorFromConstRef, + copyRowVectorFromConstRef, ) @@ -44,6 +46,18 @@ def test_create_ref_to_static(mat): assert np.array_equal(A_ref, A_ref2) +def test_read_block(): + data = np.array([[0, 0.2, 0.3, 0.4], [0, 1, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0]]) + + data_strided = data[:, 0] + + data_strided_copy = copyVectorFromConstRef(data_strided) + assert np.all(data_strided == data_strided_copy) + + data_strided_copy = copyRowVectorFromConstRef(data_strided) + assert np.all(data_strided == data_strided_copy) + + def test_create_ref(mat): print("[asRef(mat)]") ref = asRef(mat) @@ -116,6 +130,7 @@ def do_test(mat): test_create_const_ref(mat) test_create_ref(mat) test_edit_block(rows, cols) + test_read_block() print("=" * 10)