/* usadelndsoc * * Copyright © 2022 Pauli Virtanen * @author Pauli Virtanen <pauli.t.virtanen@jyu.fi> * * SPDX-License-Identifier: AGPL-3.0-or-later */ #ifndef PYTHONUTIL_HPP_ #define PYTHONUTIL_HPP_ #include <stdexcept> #include <array> #include <pybind11/pybind11.h> #include <numpy/ndarrayobject.h> #include "array.hpp" #ifdef __GNUG__ #define PY_VISIBILITY __attribute__((visibility("hidden"))) #else #define PY_VISIBILITY #endif namespace array { namespace python PY_VISIBILITY { struct shape_error : public std::runtime_error { shape_error(const char *msg) : std::runtime_error(msg) {} }; template <typename Scalar, typename Shape, size_t Alignment=0> class PyArrayView : public ArrayView<Scalar, Shape, Alignment> { private: typedef ArrayView<Scalar, Shape, Alignment> BaseType; static constexpr bool read_only = std::is_const<Scalar>::value; pybind11::object storage_; static std::array<size_t, Shape::ndim> get_shape(PyArrayObject *obj) { std::array<size_t, Shape::ndim> shape; if (PyArray_NDIM(obj) != Shape::ndim) throw shape_error("input array has wrong number of dimensions"); for (size_t i = 0; i < Shape::ndim; ++i) { npy_intp dim = PyArray_DIM(obj, i); if (dim < 0 || !(Shape::dim(i) == Dynamic || Shape::dim(i) == (size_t)dim)) throw shape_error("input array has invalid dimensions"); shape[i] = dim; } return shape; } public: PyArrayView() : BaseType(nullptr, Shape::size, Shape::shape()) {} /* Only move constructor / assignment: resolve writeback only once */ PyArrayView(PyArrayView&& other) : BaseType(other.data(), other.size(), other.shape()), storage_(pybind11::reinterpret_steal<pybind11::object>(other.storage_.release())) {} PyArrayView& operator=(PyArrayView&& other) { BaseType::operator=(other); storage_ = pybind11::reinterpret_steal<pybind11::object>(other.storage_.release()); return *this; } ~PyArrayView() { if (storage_ && !read_only) PyArray_ResolveWritebackIfCopy(reinterpret_cast<PyArrayObject *>(storage_.ptr())); } PyArrayView(pybind11::object obj) : BaseType(reinterpret_cast<Scalar *>(PyArray_DATA(reinterpret_cast<PyArrayObject *>(obj.ptr()))), PyArray_SIZE(reinterpret_cast<PyArrayObject *>(obj.ptr())), get_shape(reinterpret_cast<PyArrayObject *>(obj.ptr()))), storage_(obj) { /* steals reference */ } }; namespace detail { template <class Scalar> struct NumpyTypeMap {}; template <> struct NumpyTypeMap<std::complex<double> > { const static int numpy_type = NPY_CDOUBLE; }; template <> struct NumpyTypeMap<double> { const static int numpy_type = NPY_DOUBLE; }; template <> struct NumpyTypeMap<int> { const static int numpy_type = NPY_INT; }; template <> struct NumpyTypeMap<long> { const static int numpy_type = NPY_LONG; }; template <> struct NumpyTypeMap<short> { const static int numpy_type = NPY_SHORT; }; template <> struct NumpyTypeMap<unsigned int> { const static int numpy_type = NPY_UINT; }; template <> struct NumpyTypeMap<unsigned long> { const static int numpy_type = NPY_ULONG; }; template <> struct NumpyTypeMap<unsigned short> { const static int numpy_type = NPY_USHORT; }; template <> struct NumpyTypeMap<Mask> { const static int numpy_type = NPY_UINT8; }; template <typename Scalar, typename Shape, size_t Alignment=0> struct array_loader { private: typedef PyArrayView<Scalar, Shape, Alignment> ArrayType; static constexpr bool read_only = std::is_const<Scalar>::value; static constexpr auto array_type_name = pybind11::detail::const_name<ArrayType>(); public: PYBIND11_TYPE_CASTER(ArrayType, array_type_name); bool load(pybind11::handle src, bool) { /* Extract PyObject from handle */ PyObject *source = src.ptr(); pybind11::object obj = pybind11::reinterpret_steal<pybind11::object>( PyArray_FromAny( source, PyArray_DescrFromType(detail::NumpyTypeMap<typename std::remove_const<Scalar>::type>::numpy_type), Shape::ndim, Shape::ndim, NPY_ARRAY_C_CONTIGUOUS | (read_only ? 0 : (NPY_ARRAY_WRITEABLE | NPY_ARRAY_WRITEBACKIFCOPY)), NULL)); if (!obj) return false; try { value = ArrayType(obj); } catch (const shape_error& err) { if (!read_only) PyArray_DiscardWritebackIfCopy(reinterpret_cast<PyArrayObject *>(obj.ptr())); return false; } return true; } }; template <typename Scalar, typename Shape, typename Storage = typename Array<Scalar, Shape>::Storage, size_t Alignment = Array<Scalar, Shape, Storage>::Alignment> struct array_caster { private: typedef Array<Scalar, Shape, Storage, Alignment> ArrayType; static constexpr bool read_only = std::is_const<Scalar>::value; static constexpr auto array_type_name = pybind11::detail::const_name<ArrayType>(); public: PYBIND11_TYPE_CASTER(ArrayType, array_type_name); static pybind11::handle cast(ArrayType&& src, pybind11::return_value_policy /* policy */, pybind11::handle /* parent */) { ArrayType moved = std::move(src); Storage *store = new Storage(std::move(moved.storage())); npy_intp shape[Shape::ndim]; for (size_t i = 0; i < Shape::ndim; ++i) shape[i] = moved.dim(i); pybind11::capsule base(store, [](void *o) { delete static_cast<Storage*>(o); }); PyObject *obj = PyArray_New(&PyArray_Type, Shape::ndim, shape, NumpyTypeMap<Scalar>::numpy_type, NULL, store->data(), 0, NPY_ARRAY_C_CONTIGUOUS | (read_only ? 0 : NPY_ARRAY_WRITEABLE), NULL); if (!obj) throw pybind11::error_already_set(); PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj), base.inc_ref().ptr()); return obj; } }; template <typename T> using is_py_array_view = std::is_base_of<PyArrayView<typename T::Scalar, typename T::Shape, T::Alignment>, T>; template <typename T> using is_stored_array = std::is_base_of<Array<typename T::Scalar, typename T::Shape, typename T::Storage, T::Alignment>, T>; } } } // namespace array::python namespace pybind11 { namespace detail { template <typename Type> struct type_caster<Type, enable_if_t<array::python::detail::is_py_array_view<Type>::value> > : public array::python::detail::array_loader<typename Type::Scalar, typename Type::Shape, Type::Alignment> { }; template <typename Type> struct type_caster<Type, enable_if_t<array::python::detail::is_stored_array<Type>::value> > : public array::python::detail::array_caster<typename Type::Scalar, typename Type::Shape, typename Type::Storage, Type::Alignment> { }; } } // namespace pybind11::detail #endif // PYTHONUTIL_HPP_