Skip to content
Snippets Groups Projects
Commit ca7e3898 authored by patavirt's avatar patavirt
Browse files

pythonutil: support returning stored arrays (move semantics)

parent df5075f0
No related branches found
No related tags found
No related merge requests found
......@@ -11,14 +11,18 @@ using namespace array::python;
typedef pybind11::detail::type_caster<array::ArrayView<double, array::Shape<4> > > xcaster;
double add(const array::python::PythonArray<double, Shape<4> > array, PythonArray<double, Shape<1> > out)
array::StoredArray<double, Shape<1> >
add(array::python::PyArrayView<double, Shape<4> > array)
{
StoredArray<double, Shape<1> > out;
double sum = 0;
for (size_t i = 0; i < 4; ++i)
sum += array(i);
out(0u) = -sum;
return sum;
out(0u) = sum;
return out;
}
PYBIND11_MODULE(_core, m) {
......
......@@ -22,7 +22,7 @@ namespace array { namespace PYTHON_NAMESPACE {
};
template <typename Scalar, typename Shape, size_t Alignment=0, bool BoundsCheck=true>
class PythonArray : public ArrayView<Scalar, Shape, Alignment, BoundsCheck>
class PyArrayView : public ArrayView<Scalar, Shape, Alignment, BoundsCheck>
{
private:
typedef ArrayView<Scalar, Shape, Alignment, BoundsCheck> BaseType;
......@@ -48,31 +48,31 @@ namespace array { namespace PYTHON_NAMESPACE {
}
public:
PythonArray()
PyArrayView()
: BaseType(nullptr, Shape::size, Shape::shape())
{}
/* Only move constructor / assignment: resolve writeback only once */
PythonArray(PythonArray&& other)
PyArrayView(PyArrayView&& other)
: BaseType(other.data(), other.size(), other.shape()),
storage_(pybind11::reinterpret_steal<pybind11::object>(other.storage_.release()))
{}
PythonArray& operator=(PythonArray&& other)
PyArrayView& operator=(PyArrayView&& other)
{
BaseType::operator=(other);
storage_ = pybind11::reinterpret_steal<pybind11::object>(other.storage_.release());
return *this;
}
~PythonArray()
~PyArrayView()
{
if (storage_)
PyArray_ResolveWritebackIfCopy(reinterpret_cast<PyArrayObject *>(storage_.ptr()));
}
PythonArray(pybind11::object obj)
PyArrayView(pybind11::object obj)
: BaseType(reinterpret_cast<Scalar *>(PyArray_DATA(reinterpret_cast<PyArrayObject *>(obj.ptr()))),
PyArray_SIZE(reinterpret_cast<PyArrayObject *>(obj.ptr())),
get_shape(reinterpret_cast<PyArrayObject *>(obj.ptr()))),
......@@ -80,8 +80,6 @@ namespace array { namespace PYTHON_NAMESPACE {
{
/* steals reference */
}
static constexpr auto name = pybind11::detail::const_name<BaseType>();
};
namespace detail {
......@@ -93,12 +91,13 @@ namespace array { namespace PYTHON_NAMESPACE {
template <> struct NumpyTypeMap<short> { const static int numpy_type = NPY_SHORT; };
template <typename Scalar, typename Shape, size_t Alignment=0, bool BoundsCheck=true>
struct array_caster {
struct array_loader {
private:
typedef PythonArray<Scalar, Shape, Alignment, BoundsCheck> ArrayType;
typedef PyArrayView<Scalar, Shape, Alignment, BoundsCheck> ArrayType;
static constexpr auto array_type_name = pybind11::detail::const_name<ArrayType>();
public:
PYBIND11_TYPE_CASTER(ArrayType, ArrayType::name);
PYBIND11_TYPE_CASTER(ArrayType, array_type_name);
bool load(pybind11::handle src, bool)
{
......@@ -125,14 +124,60 @@ namespace array { namespace PYTHON_NAMESPACE {
}
};
template <typename Scalar, typename Shape, size_t Alignment=0, bool BoundsCheck=true, typename StorageType=Scalar>
struct array_caster {
private:
typedef StoredArray<Scalar, Shape, Alignment, BoundsCheck> ArrayType;
static constexpr auto array_type_name = pybind11::detail::const_name<ArrayType>();
public:
PYBIND11_TYPE_CASTER(ArrayType, array_type_name);
static pybind11::handle cast(ArrayType&& src,
pybind11::return_value_policy /* policy */,
pybind11::handle /* parent */)
{
typedef std::vector<StorageType> Storage;
ArrayType moved = std::move(src);
Storage *store = new Storage(std::move(moved.storage()));
npy_intp shape[Shape::ndim];
for (size_t i = 0; i < Shape::ndim; ++i)
shape[i] = moved.dim(i);
pybind11::capsule base(store, [](void *o) { delete static_cast<Storage*>(o); });
PyObject *obj = PyArray_New(&PyArray_Type, Shape::ndim, shape, NumpyTypeMap<Scalar>::numpy_type,
NULL, store->data(), 0, 0, NULL);
if (!obj)
throw pybind11::error_already_set();
PyArray_SetBaseObject(reinterpret_cast<PyArrayObject *>(obj),
base.inc_ref().ptr());
return obj;
}
};
template <typename T>
using is_py_array = std::is_base_of<PythonArray<typename T::Scalar, typename T::Shape, T::Alignment, T::BoundsCheck>, T>;
using is_py_array_view = std::is_base_of<PyArrayView<typename T::Scalar, typename T::Shape, T::Alignment, T::BoundsCheck>, T>;
template <typename T>
using is_stored_array = std::is_base_of<StoredArray<typename T::Scalar, typename T::Shape, T::Alignment, T::BoundsCheck>, T>;
}
} } // namespace array::python
namespace pybind11 { namespace detail {
template <typename Type>
struct type_caster<Type, enable_if_t<array::python::detail::is_py_array<Type>::value> >
struct type_caster<Type, enable_if_t<array::python::detail::is_py_array_view<Type>::value> >
: public array::python::detail::array_loader<typename Type::Scalar, typename Type::Shape, Type::Alignment, Type::BoundsCheck>
{
};
template <typename Type>
struct type_caster<Type, enable_if_t<array::python::detail::is_stored_array<Type>::value> >
: public array::python::detail::array_caster<typename Type::Scalar, typename Type::Shape, Type::Alignment, Type::BoundsCheck>
{
};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment