diff --git a/src/pythonutil.hpp b/src/pythonutil.hpp index 53bfd248c144e2c172bbf499d4e94205b3d4026d..4d54cd87df0fd5d6c44e5922f2942479f738f9f8 100644 --- a/src/pythonutil.hpp +++ b/src/pythonutil.hpp @@ -26,6 +26,7 @@ namespace array { namespace python PY_VISIBILITY { { private: typedef ArrayView<Scalar, Shape, Alignment, BoundsCheck> BaseType; + static constexpr bool read_only = std::is_const<Scalar>::value; pybind11::object storage_; @@ -68,7 +69,7 @@ namespace array { namespace python PY_VISIBILITY { ~PyArrayView() { - if (storage_) + if (storage_ && !read_only) PyArray_ResolveWritebackIfCopy(reinterpret_cast<PyArrayObject *>(storage_.ptr())); } @@ -94,6 +95,7 @@ namespace array { namespace python PY_VISIBILITY { struct array_loader { private: typedef PyArrayView<Scalar, Shape, Alignment, BoundsCheck> ArrayType; + static constexpr bool read_only = std::is_const<Scalar>::value; static constexpr auto array_type_name = pybind11::detail::const_name<ArrayType>(); public: @@ -106,9 +108,9 @@ namespace array { namespace python PY_VISIBILITY { pybind11::object obj = pybind11::reinterpret_steal<pybind11::object>( PyArray_FromAny( source, - PyArray_DescrFromType(detail::NumpyTypeMap<Scalar>::numpy_type), + PyArray_DescrFromType(detail::NumpyTypeMap<typename std::remove_const<Scalar>::type>::numpy_type), Shape::ndim, Shape::ndim, - NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_WRITEBACKIFCOPY, + NPY_ARRAY_C_CONTIGUOUS | (read_only ? 0 : (NPY_ARRAY_WRITEABLE | NPY_ARRAY_WRITEBACKIFCOPY)), NULL)); if (!obj) return false; @@ -116,7 +118,8 @@ namespace array { namespace python PY_VISIBILITY { try { value = ArrayType(obj); } catch (const shape_error& err) { - PyArray_DiscardWritebackIfCopy(reinterpret_cast<PyArrayObject *>(obj.ptr())); + if (!read_only) + PyArray_DiscardWritebackIfCopy(reinterpret_cast<PyArrayObject *>(obj.ptr())); return false; } @@ -124,10 +127,15 @@ namespace array { namespace python PY_VISIBILITY { } }; - template <typename Scalar, typename Shape, size_t Alignment=0, bool BoundsCheck=true, typename Storage=std::vector<Scalar> > + template <typename Scalar, + typename Shape, + typename Storage = typename Array<Scalar, Shape>::Storage, + size_t Alignment = Array<Scalar, Shape, Storage>::Alignment, + bool BoundsCheck = Array<Scalar, Shape, Storage, Alignment>::BoundsCheck> struct array_caster { private: - typedef StoredArray<Scalar, Shape, Alignment, BoundsCheck, Storage> ArrayType; + typedef Array<Scalar, Shape, Storage, Alignment, BoundsCheck> ArrayType; + static constexpr bool read_only = std::is_const<Scalar>::value; static constexpr auto array_type_name = pybind11::detail::const_name<ArrayType>(); public: @@ -148,7 +156,7 @@ namespace array { namespace python PY_VISIBILITY { pybind11::capsule base(store, [](void *o) { delete static_cast<Storage*>(o); }); PyObject *obj = PyArray_New(&PyArray_Type, Shape::ndim, shape, NumpyTypeMap<Scalar>::numpy_type, - NULL, store->data(), 0, 0, NULL); + NULL, store->data(), 0, NPY_ARRAY_C_CONTIGUOUS | (read_only ? 0 : NPY_ARRAY_WRITEABLE), NULL); if (!obj) throw pybind11::error_already_set(); @@ -163,7 +171,7 @@ namespace array { namespace python PY_VISIBILITY { using is_py_array_view = std::is_base_of<PyArrayView<typename T::Scalar, typename T::Shape, T::Alignment, T::BoundsCheck>, T>; template <typename T> - using is_stored_array = std::is_base_of<StoredArray<typename T::Scalar, typename T::Shape, T::Alignment, T::BoundsCheck, typename T::Storage>, T>; + using is_stored_array = std::is_base_of<Array<typename T::Scalar, typename T::Shape, typename T::Storage, T::Alignment, T::BoundsCheck>, T>; } } } // namespace array::python @@ -176,7 +184,7 @@ namespace pybind11 { namespace detail { template <typename Type> struct type_caster<Type, enable_if_t<array::python::detail::is_stored_array<Type>::value> > - : public array::python::detail::array_caster<typename Type::Scalar, typename Type::Shape, Type::Alignment, Type::BoundsCheck, typename Type::Storage> + : public array::python::detail::array_caster<typename Type::Scalar, typename Type::Shape, typename Type::Storage, Type::Alignment, Type::BoundsCheck> { }; } } // namespace pybind11::detail