// pythonutil.hpp (8.04 KiB)
/* usadelndsoc
*
* Copyright © 2022 Pauli Virtanen
* @author Pauli Virtanen <pauli.t.virtanen@jyu.fi>
*
* SPDX-License-Identifier: AGPL-3.0-or-later
*/
#ifndef PYTHONUTIL_HPP_
#define PYTHONUTIL_HPP_
#include <stdexcept>
#include <array>
#include <pybind11/pybind11.h>
#include <numpy/ndarrayobject.h>
#include "array.hpp"
/*
 * PY_VISIBILITY: attribute attached to the declaration of the
 * array::python namespace in this header.
 *
 * When __GNUG__ is defined it expands to
 * __attribute__((visibility("hidden"))); otherwise it expands to nothing.
 *
 * NOTE(review): presumably this is meant to match the hidden symbol
 * visibility pybind11 uses for its own types -- confirm.
 */
#ifdef __GNUG__
#define PY_VISIBILITY __attribute__((visibility("hidden")))
#else
#define PY_VISIBILITY
#endif
namespace array { namespace python PY_VISIBILITY {
/**
 * Exception thrown when the dimensions of an input array do not match
 * the shape required by the target array type.
 *
 * Derives from std::runtime_error, so what() returns the message given
 * at construction.
 */
class shape_error : public std::runtime_error
{
public:
    /** Create the exception with the explanatory text @p msg. */
    shape_error(const char *msg)
        : std::runtime_error{msg}
    {}
};
/**
 * View onto the data buffer of a NumPy array object.
 *
 * The ArrayView base class is initialised with the data pointer, element
 * count and shape of the array, and a reference to the Python object is
 * kept in storage_ for as long as the view holds it.
 *
 * For non-const Scalar, PyArray_ResolveWritebackIfCopy() is called on the
 * held object when the view is destroyed or when it is replaced by move
 * assignment.
 *
 * The type is move-only, so that the write-back of a given array object is
 * resolved by exactly one view.
 */
template <typename Scalar, typename Shape, size_t Alignment=0>
class PyArrayView : public ArrayView<Scalar, Shape, Alignment>
{
private:
    typedef ArrayView<Scalar, Shape, Alignment> BaseType;

    /* True when Scalar is const-qualified; such views never resolve write-back. */
    static constexpr bool read_only = std::is_const<Scalar>::value;

    /* Reference to the viewed array object; null for empty and moved-from views. */
    pybind11::object storage_;

    /**
     * Read the dimensions of @p obj and validate them against Shape.
     *
     * @throws shape_error if the number of dimensions differs from
     *     Shape::ndim, or if a dimension is negative or differs from the
     *     corresponding non-Dynamic dimension of Shape.
     */
    static std::array<size_t, Shape::ndim> get_shape(PyArrayObject *obj)
    {
        std::array<size_t, Shape::ndim> shape;

        if (PyArray_NDIM(obj) != Shape::ndim)
            throw shape_error("input array has wrong number of dimensions");

        for (size_t i = 0; i < Shape::ndim; ++i) {
            npy_intp dim = PyArray_DIM(obj, i);
            if (dim < 0 || !(Shape::dim(i) == Dynamic || Shape::dim(i) == static_cast<size_t>(dim)))
                throw shape_error("input array has invalid dimensions");
            shape[i] = dim;
        }

        return shape;
    }

    /*
     * Resolve a pending write-back copy on the held array, if there is one.
     * The return value is ignored, as this is also called from the destructor
     * where an error cannot be propagated.
     */
    void resolve_writeback()
    {
        if (storage_ && !read_only)
            PyArray_ResolveWritebackIfCopy(reinterpret_cast<PyArrayObject *>(storage_.ptr()));
    }

public:
    /* Empty view: null data pointer, no Python object held. */
    PyArrayView()
        : BaseType(nullptr, Shape::size, Shape::shape())
    {}

    /* Only move constructor / assignment: resolve writeback only once */
    PyArrayView(const PyArrayView&) = delete;
    PyArrayView& operator=(const PyArrayView&) = delete;

    /* Take over the view and the object reference of other; other no longer
       holds an object afterwards, so its destructor resolves nothing. */
    PyArrayView(PyArrayView&& other)
        : BaseType(other.data(), other.size(), other.shape()),
          storage_(pybind11::reinterpret_steal<pybind11::object>(other.storage_.release()))
    {}

    PyArrayView& operator=(PyArrayView&& other)
    {
        /* Self-assignment must not resolve or drop the held array. */
        if (this == &other)
            return *this;

        /* NOTE(review): relies on the assignment semantics of ArrayView,
           which are defined outside this file. */
        BaseType::operator=(other);

        /* The array held so far is dropped below: resolve its write-back
           first, exactly as the destructor would have done. */
        resolve_writeback();

        storage_ = pybind11::reinterpret_steal<pybind11::object>(other.storage_.release());
        return *this;
    }

    ~PyArrayView()
    {
        resolve_writeback();
    }

    /**
     * Create a view onto the data of the NumPy array @p obj.
     *
     * @p obj is reinterpreted as a PyArrayObject without a type check, and
     * its data pointer is reinterpreted as Scalar without a dtype check;
     * callers must pass an array of the matching type.
     *
     * The view stores its own reference to the object (a copy of @p obj),
     * it does not steal the caller's reference.
     *
     * @throws shape_error if the dimensions of the array do not match Shape.
     */
    PyArrayView(pybind11::object obj)
        : BaseType(reinterpret_cast<Scalar *>(PyArray_DATA(reinterpret_cast<PyArrayObject *>(obj.ptr()))),
                   PyArray_SIZE(reinterpret_cast<PyArrayObject *>(obj.ptr())),
                   get_shape(reinterpret_cast<PyArrayObject *>(obj.ptr()))),
          storage_(obj)
    {}
};
namespace detail {
/**
 * Compile-time map from a C++ element type to a NumPy type number,
 * available as the static member numpy_type.
 *
 * The primary template has no members, so using an element type without a
 * specialization below fails to compile where numpy_type is accessed.
 */
template <class T>
struct NumpyTypeMap
{};

template <>
struct NumpyTypeMap<std::complex<double> >
{
    static constexpr int numpy_type = NPY_CDOUBLE;
};

template <>
struct NumpyTypeMap<double>
{
    static constexpr int numpy_type = NPY_DOUBLE;
};

template <>
struct NumpyTypeMap<int>
{
    static constexpr int numpy_type = NPY_INT;
};

template <>
struct NumpyTypeMap<long>
{
    static constexpr int numpy_type = NPY_LONG;
};

template <>
struct NumpyTypeMap<short>
{
    static constexpr int numpy_type = NPY_SHORT;
};

template <>
struct NumpyTypeMap<unsigned int>
{
    static constexpr int numpy_type = NPY_UINT;
};

template <>
struct NumpyTypeMap<unsigned long>
{
    static constexpr int numpy_type = NPY_ULONG;
};

template <>
struct NumpyTypeMap<unsigned short>
{
    static constexpr int numpy_type = NPY_USHORT;
};

/* Mask is declared outside this file; it is mapped to 8-bit unsigned integers. */
template <>
struct NumpyTypeMap<Mask>
{
    static constexpr int numpy_type = NPY_UINT8;
};
template <typename Scalar, typename Shape, size_t Alignment=0>
struct array_loader {
private:
typedef PyArrayView<Scalar, Shape, Alignment> ArrayType;
static constexpr bool read_only = std::is_const<Scalar>::value;
static constexpr auto array_type_name = pybind11::detail::const_name<ArrayType>();
public:
PYBIND11_TYPE_CASTER(ArrayType, array_type_name);
bool load(pybind11::handle src, bool)
{
/* Extract PyObject from handle */
PyObject *source = src.ptr();
pybind11::object obj = pybind11::reinterpret_steal<pybind11::object>(
PyArray_FromAny(
source,
PyArray_DescrFromType(detail::NumpyTypeMap<typename std::remove_const<Scalar>::type>::numpy_type),
Shape::ndim, Shape::ndim,
NPY_ARRAY_C_CONTIGUOUS | (read_only ? 0 : (NPY_ARRAY_WRITEABLE | NPY_ARRAY_WRITEBACKIFCOPY)),
NULL));
if (!obj)
return false;
try {
value = ArrayType(obj);
} catch (const shape_error& err) {
if (!read_only)
PyArray_DiscardWritebackIfCopy(reinterpret_cast<PyArrayObject *>(obj.ptr()));
return false;
}
return true;
}
};
/**
 * pybind11 type caster implementation that converts an
 * Array<Scalar, Shape, Storage, Alignment> return value into a NumPy array.
 *
 * Only the C++-to-Python direction (cast from an rvalue) is implemented
 * here.
 */
template <typename Scalar,
          typename Shape,
          typename Storage = typename Array<Scalar, Shape>::Storage,
          size_t Alignment = Array<Scalar, Shape, Storage>::Alignment>
struct array_caster {
private:
    typedef Array<Scalar, Shape, Storage, Alignment> ArrayType;

    /* const Scalar: the NumPy array is created without NPY_ARRAY_WRITEABLE. */
    static constexpr bool read_only = std::is_const<Scalar>::value;

    static constexpr auto array_type_name = pybind11::detail::const_name<ArrayType>();

public:
    PYBIND11_TYPE_CASTER(ArrayType, array_type_name);

    /**
     * Create a NumPy array that uses the storage of @p src without copying
     * the elements.
     *
     * The storage is moved into a heap-allocated Storage object owned by a
     * capsule, and the capsule is set as the base object of the new array,
     * so the storage is deleted when the capsule is released.
     *
     * The policy and parent arguments are ignored.
     *
     * @return handle to the new array object (new reference).
     * @throws pybind11::error_already_set if creating the array or
     *     attaching the base object fails.
     */
    static pybind11::handle cast(ArrayType&& src,
                                 pybind11::return_value_policy /* policy */,
                                 pybind11::handle /* parent */)
    {
        ArrayType moved = std::move(src);

        /* Read the dimensions while the array still has its storage. */
        npy_intp shape[Shape::ndim];
        for (size_t i = 0; i < Shape::ndim; ++i)
            shape[i] = moved.dim(i);

        Storage *store = new Storage(std::move(moved.storage()));
        pybind11::capsule base(store, [](void *o) { delete static_cast<Storage*>(o); });

        /* Strip const for the type lookup, as array_loader does. */
        const int type_num = NumpyTypeMap<typename std::remove_const<Scalar>::type>::numpy_type;

        /* Holding the array in an object releases it on every error path. */
        pybind11::object arr = pybind11::reinterpret_steal<pybind11::object>(
            PyArray_New(&PyArray_Type, Shape::ndim, shape, type_num,
                        NULL, store->data(), 0,
                        NPY_ARRAY_C_CONTIGUOUS | (read_only ? 0 : NPY_ARRAY_WRITEABLE), NULL));
        if (!arr)
            throw pybind11::error_already_set();

        /* PyArray_SetBaseObject steals a reference, hence the inc_ref. If it
           fails the array must not be returned: it would point to storage
           that is freed when `base` goes out of scope. */
        if (PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(arr.ptr()),
                                  base.inc_ref().ptr()) < 0)
            throw pybind11::error_already_set();

        return arr.release();
    }
};
/*
 * Traits selecting the casters above: true when the candidate type is, or
 * derives from, the PyArrayView / Array instantiation named by its own
 * Scalar, Shape, (Storage) and Alignment members.
 *
 * These are alias templates on purpose: a candidate lacking those members
 * then causes a substitution failure at the point of use instead of a hard
 * error.
 */
template <typename Candidate>
using is_py_array_view = std::is_base_of<
    PyArrayView<typename Candidate::Scalar,
                typename Candidate::Shape,
                Candidate::Alignment>,
    Candidate>;

template <typename Candidate>
using is_stored_array = std::is_base_of<
    Array<typename Candidate::Scalar,
          typename Candidate::Shape,
          typename Candidate::Storage,
          Candidate::Alignment>,
    Candidate>;
}
} } // namespace array::python
namespace pybind11 { namespace detail {
template <typename Type>
struct type_caster<Type, enable_if_t<array::python::detail::is_py_array_view<Type>::value> >
: public array::python::detail::array_loader<typename Type::Scalar, typename Type::Shape, Type::Alignment>
{
};
template <typename Type>
struct type_caster<Type, enable_if_t<array::python::detail::is_stored_array<Type>::value> >
: public array::python::detail::array_caster<typename Type::Scalar, typename Type::Shape, typename Type::Storage, Type::Alignment>
{
};
} } // namespace pybind11::detail
#endif // PYTHONUTIL_HPP_