From 14b4732faddcaaecda434ad31b14c95fa5e64fbe Mon Sep 17 00:00:00 2001
From: Pauli Virtanen <pauli.t.virtanen@jyu.fi>
Date: Mon, 8 Aug 2022 15:19:30 +0300
Subject: [PATCH] pythonutil: deal with const + adjust to changes

---
 src/pythonutil.hpp | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/src/pythonutil.hpp b/src/pythonutil.hpp
index 53bfd24..4d54cd8 100644
--- a/src/pythonutil.hpp
+++ b/src/pythonutil.hpp
@@ -26,6 +26,7 @@ namespace array { namespace python PY_VISIBILITY {
     {
     private:
         typedef ArrayView<Scalar, Shape, Alignment, BoundsCheck> BaseType;
+        static constexpr bool read_only = std::is_const<Scalar>::value;
 
         pybind11::object storage_;
 
@@ -68,7 +69,7 @@ namespace array { namespace python PY_VISIBILITY {
 
         ~PyArrayView()
             {
-                if (storage_)
+                if (storage_ && !read_only)
                     PyArray_ResolveWritebackIfCopy(reinterpret_cast<PyArrayObject *>(storage_.ptr()));
             }
 
@@ -94,6 +95,7 @@ namespace array { namespace python PY_VISIBILITY {
         struct array_loader {
         private:
             typedef PyArrayView<Scalar, Shape, Alignment, BoundsCheck> ArrayType;
+            static constexpr bool read_only = std::is_const<Scalar>::value;
             static constexpr auto array_type_name = pybind11::detail::const_name<ArrayType>();
 
         public:
@@ -106,9 +108,9 @@ namespace array { namespace python PY_VISIBILITY {
                     pybind11::object obj = pybind11::reinterpret_steal<pybind11::object>(
                         PyArray_FromAny(
                             source,
-                            PyArray_DescrFromType(detail::NumpyTypeMap<Scalar>::numpy_type),
+                            PyArray_DescrFromType(detail::NumpyTypeMap<typename std::remove_const<Scalar>::type>::numpy_type),
                             Shape::ndim, Shape::ndim,
-                            NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_WRITEBACKIFCOPY,
+                            NPY_ARRAY_C_CONTIGUOUS | (read_only ? 0 : (NPY_ARRAY_WRITEABLE | NPY_ARRAY_WRITEBACKIFCOPY)),
                             NULL));
                     if (!obj)
                         return false;
@@ -116,7 +118,8 @@ namespace array { namespace python PY_VISIBILITY {
                     try {
                         value = ArrayType(obj);
                     } catch (const shape_error& err) {
-                        PyArray_DiscardWritebackIfCopy(reinterpret_cast<PyArrayObject *>(obj.ptr()));
+                        if (!read_only)
+                            PyArray_DiscardWritebackIfCopy(reinterpret_cast<PyArrayObject *>(obj.ptr()));
                         return false;
                     }
 
@@ -124,10 +127,15 @@ namespace array { namespace python PY_VISIBILITY {
                 }
         };
 
-        template <typename Scalar, typename Shape, size_t Alignment=0, bool BoundsCheck=true, typename Storage=std::vector<Scalar> >
+        template <typename Scalar,
+                  typename Shape,
+                  typename Storage = typename Array<Scalar, Shape>::Storage,
+                  size_t Alignment = Array<Scalar, Shape, Storage>::Alignment,
+                  bool BoundsCheck = Array<Scalar, Shape, Storage, Alignment>::BoundsCheck>
         struct array_caster {
         private:
-            typedef StoredArray<Scalar, Shape, Alignment, BoundsCheck, Storage> ArrayType;
+            typedef Array<Scalar, Shape, Storage, Alignment, BoundsCheck> ArrayType;
+            static constexpr bool read_only = std::is_const<Scalar>::value;
             static constexpr auto array_type_name = pybind11::detail::const_name<ArrayType>();
 
         public:
@@ -148,7 +156,7 @@ namespace array { namespace python PY_VISIBILITY {
                     pybind11::capsule base(store, [](void *o) { delete static_cast<Storage*>(o); });
 
                     PyObject *obj = PyArray_New(&PyArray_Type, Shape::ndim, shape, NumpyTypeMap<Scalar>::numpy_type,
-                                                     NULL, store->data(), 0, 0, NULL);
+                                                NULL, store->data(), 0, NPY_ARRAY_C_CONTIGUOUS | (read_only ? 0 : NPY_ARRAY_WRITEABLE), NULL);
                     if (!obj)
                         throw pybind11::error_already_set();
 
@@ -163,7 +171,7 @@ namespace array { namespace python PY_VISIBILITY {
         using is_py_array_view = std::is_base_of<PyArrayView<typename T::Scalar, typename T::Shape, T::Alignment, T::BoundsCheck>, T>;
 
         template <typename T>
-        using is_stored_array = std::is_base_of<StoredArray<typename T::Scalar, typename T::Shape, T::Alignment, T::BoundsCheck, typename T::Storage>, T>;
+        using is_stored_array = std::is_base_of<Array<typename T::Scalar, typename T::Shape, typename T::Storage, T::Alignment, T::BoundsCheck>, T>;
     }
 } } // namespace array::python
 
@@ -176,7 +184,7 @@ namespace pybind11 { namespace detail {
 
     template <typename Type>
     struct type_caster<Type, enable_if_t<array::python::detail::is_stored_array<Type>::value> >
-        : public array::python::detail::array_caster<typename Type::Scalar, typename Type::Shape, Type::Alignment, Type::BoundsCheck, typename Type::Storage>
+        : public array::python::detail::array_caster<typename Type::Scalar, typename Type::Shape, typename Type::Storage, Type::Alignment, Type::BoundsCheck>
     {
     };
 } } // namespace pybind11::detail
-- 
GitLab