Skip to content
Snippets Groups Projects
pythonutil.hpp 8.04 KiB
/* usadelndsoc
 *
 * Copyright © 2022 Pauli Virtanen
 *    @author Pauli Virtanen <pauli.t.virtanen@jyu.fi>
 *
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */
#ifndef PYTHONUTIL_HPP_
#define PYTHONUTIL_HPP_

#include <array>
#include <complex>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>

#include <pybind11/pybind11.h>
#include <numpy/ndarrayobject.h>

#include "array.hpp"

/* Attribute applied to the array::python namespace: on GCC-compatible
   compilers its symbols get hidden visibility, so they are not exported
   from the shared library this header is compiled into.  Expands to
   nothing elsewhere. */
#if defined(__GNUG__)
#  define PY_VISIBILITY __attribute__((visibility("hidden")))
#else
#  define PY_VISIBILITY
#endif

namespace array { namespace python PY_VISIBILITY {
    /* Thrown when a NumPy array's number of dimensions or extents do not
       match the static Shape expected by PyArrayView. */
    struct shape_error : std::runtime_error
    {
        shape_error(const char *message)
            : std::runtime_error{message}
        {
        }
    };

    /*
     * ArrayView over the data buffer of a NumPy array.
     *
     * The view holds a reference to the NumPy array object in storage_, so
     * the buffer stays alive as long as the view does.  For non-const
     * Scalar, array_loader requests the array with NPY_ARRAY_WRITEBACKIFCOPY;
     * any pending writeback is resolved exactly once, when this view releases
     * its storage (on destruction, or when another view is move-assigned
     * over it).
     *
     * Only move construction/assignment are declared (copying is implicitly
     * deleted), so a given writeback is never resolved twice.
     *
     * NOTE(review): resolving the writeback calls the NumPy C API, so the GIL
     * presumably must be held wherever a non-empty view is destroyed.
     */
    template <typename Scalar, typename Shape, size_t Alignment=0>
    class PyArrayView : public ArrayView<Scalar, Shape, Alignment>
    {
    private:
        typedef ArrayView<Scalar, Shape, Alignment> BaseType;
        static constexpr bool read_only = std::is_const<Scalar>::value;

        /* The NumPy array whose buffer this view points into (may be empty). */
        pybind11::object storage_;

        /* Return the extents of obj after checking them against Shape.
           Throws shape_error if the rank differs from Shape::ndim, an extent
           is negative, or a fixed (non-Dynamic) extent of Shape differs. */
        static std::array<size_t, Shape::ndim> get_shape(PyArrayObject *obj)
            {
                std::array<size_t, Shape::ndim> shape;

                if (PyArray_NDIM(obj) != Shape::ndim)
                    throw shape_error("input array has wrong number of dimensions");

                for (size_t i = 0; i < Shape::ndim; ++i) {
                    npy_intp dim = PyArray_DIM(obj, i);

                    if (dim < 0 || !(Shape::dim(i) == Dynamic || Shape::dim(i) == (size_t)dim))
                        throw shape_error("input array has invalid dimensions");
                    shape[i] = dim;
                }

                return shape;
            }

        /* Copy data back to the writeback target of the held array, if it
           has one.  No-op for read-only views and views without storage. */
        void resolve_writeback()
            {
                if (storage_ && !read_only)
                    PyArray_ResolveWritebackIfCopy(reinterpret_cast<PyArrayObject *>(storage_.ptr()));
            }

    public:
        /* Empty view: null data pointer, static shape, no NumPy array held. */
        PyArrayView()
            : BaseType(nullptr, Shape::size, Shape::shape())
            {}

        /* Take over other's NumPy array.  other is left without storage, so
           its destructor does not resolve the same writeback again. */
        PyArrayView(PyArrayView&& other)
            : BaseType(other.data(), other.size(), other.shape()),
              storage_(std::move(other.storage_))
            {}

        /* Release the currently held array (resolving its writeback), then
           take over other's array. */
        PyArrayView& operator=(PyArrayView&& other)
            {
                if (this != &other) {
                    /* Resolve the array being replaced first; otherwise it
                       would be dropped with WRITEBACKIFCOPY still pending. */
                    resolve_writeback();
                    BaseType::operator=(other);
                    storage_ = std::move(other.storage_);
                }
                return *this;
            }

        ~PyArrayView()
            {
                resolve_writeback();
            }

        /* View the buffer of the NumPy array obj and keep a reference to it.
           obj must refer to a PyArrayObject; no type check is done here
           (array_loader passes the result of PyArray_FromAny).  Throws
           shape_error if obj's dimensions do not match Shape; in that case
           no reference is retained. */
        PyArrayView(pybind11::object obj)
            : BaseType(reinterpret_cast<Scalar *>(PyArray_DATA(reinterpret_cast<PyArrayObject *>(obj.ptr()))),
                       PyArray_SIZE(reinterpret_cast<PyArrayObject *>(obj.ptr())),
                       get_shape(reinterpret_cast<PyArrayObject *>(obj.ptr()))),
              /* The base is initialized first, so obj can be moved from here. */
              storage_(std::move(obj))
            {}
    };

    namespace detail {
        template <class Scalar> struct NumpyTypeMap {};
        template <> struct NumpyTypeMap<std::complex<double> > { const static int numpy_type = NPY_CDOUBLE; };
        template <> struct NumpyTypeMap<double> { const static int numpy_type = NPY_DOUBLE; };
        template <> struct NumpyTypeMap<int> { const static int numpy_type = NPY_INT; };
        template <> struct NumpyTypeMap<long> { const static int numpy_type = NPY_LONG; };
        template <> struct NumpyTypeMap<short> { const static int numpy_type = NPY_SHORT; };
        template <> struct NumpyTypeMap<unsigned int> { const static int numpy_type = NPY_UINT; };
        template <> struct NumpyTypeMap<unsigned long> { const static int numpy_type = NPY_ULONG; };
        template <> struct NumpyTypeMap<unsigned short> { const static int numpy_type = NPY_USHORT; };
        template <> struct NumpyTypeMap<Mask> { const static int numpy_type = NPY_UINT8; };

        /*
         * pybind11 type_caster load() implementation for PyArrayView types
         * (Python -> C++).
         *
         * The source object is converted with PyArray_FromAny to a
         * C-contiguous NumPy array with the dtype given by NumpyTypeMap and
         * exactly Shape::ndim dimensions.  For non-const Scalar the array is
         * requested writeable with NPY_ARRAY_WRITEBACKIFCOPY.  If a copy had
         * to be made, changes are written back to the source when the
         * PyArrayView resolves the writeback.
         *
         * NOTE(review): the `convert` flag is ignored, so conversions (dtype
         * casts, copies) are also accepted during pybind11's no-convert
         * overload pass.
         */
        template <typename Scalar, typename Shape, size_t Alignment=0>
        struct array_loader {
        private:
            typedef PyArrayView<Scalar, Shape, Alignment> ArrayType;
            static constexpr bool read_only = std::is_const<Scalar>::value;
            static constexpr auto array_type_name = pybind11::detail::const_name<ArrayType>();

        public:
            PYBIND11_TYPE_CASTER(ArrayType, array_type_name);

            /* Returns false (with no Python error set) if src cannot be
               converted or has the wrong shape. */
            bool load(pybind11::handle src, bool /* convert */)
                {
                    /* Extract PyObject from handle */
                    PyObject *source = src.ptr();
                    pybind11::object obj = pybind11::reinterpret_steal<pybind11::object>(
                        PyArray_FromAny(
                            source,
                            PyArray_DescrFromType(detail::NumpyTypeMap<typename std::remove_const<Scalar>::type>::numpy_type),
                            Shape::ndim, Shape::ndim,
                            NPY_ARRAY_C_CONTIGUOUS | (read_only ? 0 : (NPY_ARRAY_WRITEABLE | NPY_ARRAY_WRITEBACKIFCOPY)),
                            NULL));
                    if (!obj) {
                        /* A failed load must not leave the Python error
                           indicator set: pybind11 may go on to a later overload
                           that succeeds, and returning a result with an error
                           set makes CPython raise SystemError. */
                        PyErr_Clear();
                        return false;
                    }

                    try {
                        /* Pass obj by copy: if construction throws, obj still
                           holds the array so the writeback can be discarded. */
                        value = ArrayType(obj);
                    } catch (const shape_error&) {
                        if (!read_only)
                            PyArray_DiscardWritebackIfCopy(reinterpret_cast<PyArrayObject *>(obj.ptr()));
                        return false;
                    }

                    return true;
                }
        };
        /*
         * pybind11 type_caster cast() implementation for Array types
         * (C++ -> Python, rvalues only).
         *
         * The array's Storage is moved into a heap-allocated Storage owned by
         * a pybind11 capsule.  A NumPy array is then created pointing at that
         * Storage's data(), with the capsule as its base object, so the
         * Storage is deleted when the capsule is released.  Whether element
         * data is copied depends on Storage's move constructor (presumably
         * not; TODO confirm).
         *
         * Throws pybind11::error_already_set if NumPy fails to create the
         * array or to attach the base object.
         */
        template <typename Scalar,
                  typename Shape,
                  typename Storage = typename Array<Scalar, Shape>::Storage,
                  size_t Alignment = Array<Scalar, Shape, Storage>::Alignment>
        struct array_caster {
        private:
            typedef Array<Scalar, Shape, Storage, Alignment> ArrayType;
            static constexpr bool read_only = std::is_const<Scalar>::value;
            static constexpr auto array_type_name = pybind11::detail::const_name<ArrayType>();

        public:
            PYBIND11_TYPE_CASTER(ArrayType, array_type_name);

            static pybind11::handle cast(ArrayType&& src,
                                         pybind11::return_value_policy /* policy */,
                                         pybind11::handle /* parent */)
                {
                    /* NumpyTypeMap is specialized for unqualified types only
                       (as in array_loader), so strip const for the lookup. */
                    typedef typename std::remove_const<Scalar>::type value_type;

                    ArrayType moved = std::move(src);

                    /* Own the Storage until the capsule takes it over, so it
                       is not leaked if creating the capsule throws. */
                    std::unique_ptr<Storage> owned(new Storage(std::move(moved.storage())));

                    /* std::array also handles ndim == 0, where a C array of
                       size zero would be ill-formed. */
                    std::array<npy_intp, Shape::ndim> shape;
                    for (size_t i = 0; i < Shape::ndim; ++i)
                        shape[i] = static_cast<npy_intp>(moved.dim(i));

                    pybind11::capsule base(owned.get(), [](void *o) { delete static_cast<Storage*>(o); });
                    /* The capsule now owns the Storage. */
                    Storage *store = owned.release();

                    PyObject *obj = PyArray_New(&PyArray_Type, Shape::ndim, shape.data(),
                                                NumpyTypeMap<value_type>::numpy_type,
                                                nullptr, store->data(), 0,
                                                NPY_ARRAY_C_CONTIGUOUS | (read_only ? 0 : NPY_ARRAY_WRITEABLE),
                                                nullptr);
                    if (!obj)
                        throw pybind11::error_already_set();

                    /* PyArray_SetBaseObject steals the capsule reference, even
                       on failure.  On failure, drop the array (it does not own
                       its data); `base` then frees the Storage during unwinding. */
                    if (PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj),
                                              base.inc_ref().ptr()) < 0) {
                        pybind11::error_already_set err;  /* fetch before DECREF */
                        Py_DECREF(obj);
                        throw err;
                    }

                    return obj;
                }
        };

        template <typename T>
        using is_py_array_view = std::is_base_of<PyArrayView<typename T::Scalar, typename T::Shape, T::Alignment>, T>;

        template <typename T>
        using is_stored_array = std::is_base_of<Array<typename T::Scalar, typename T::Shape, typename T::Storage, T::Alignment>, T>;
    }
} } // namespace array::python

namespace pybind11 { namespace detail {
    /*
     * pybind11 type_caster hooks for the array types.
     *
     * Each partial specialization is selected through enable_if_t on one of
     * the traits from array::python::detail.  Those traits name
     * Type::Scalar / Type::Shape / Type::Alignment (and Type::Storage), so
     * for unrelated types the substitution fails and the specialization is
     * simply not considered.
     */

    /* Types deriving from array::python::PyArrayView: loading Python
       arguments (Python -> C++) is delegated to array_loader. */
    template <typename Type>
    struct type_caster<Type, enable_if_t<array::python::detail::is_py_array_view<Type>::value> >
        : public array::python::detail::array_loader<typename Type::Scalar, typename Type::Shape, Type::Alignment>
    {
    };

    /* Types deriving from the matching Array<Scalar, Shape, Storage,
       Alignment>: converting returned rvalues (C++ -> Python) is delegated
       to array_caster, which moves the storage into a new NumPy array. */
    template <typename Type>
    struct type_caster<Type, enable_if_t<array::python::detail::is_stored_array<Type>::value> >
        : public array::python::detail::array_caster<typename Type::Scalar, typename Type::Shape, typename Type::Storage, Type::Alignment>
    {
    };
} } // namespace pybind11::detail

#endif // PYTHONUTIL_HPP_