Skip to content
Snippets Groups Projects
Commit 5c9b7917 authored by patavirt's avatar patavirt
Browse files

array: allow multi-index Array::part(idxs...)

parent de33c2d9
No related branches found
No related tags found
No related merge requests found
......@@ -30,7 +30,7 @@ struct Shape<Dim0,Dims...>
{
static const size_t head = Dim0;
typedef Shape<Dims...> Tail;
static const size_t ndim = 1 + Tail::ndim;
static const size_t ndim = 1 + sizeof...(Dims);
static const bool fixed = (Dim0 != Dynamic) && Tail::fixed;
static constexpr size_t dim(size_t i)
......@@ -55,6 +55,22 @@ struct Shape<Dim0,Dims...>
static constexpr std::array<size_t, ndim> dims() { return {Dim0, Dims...}; }
};
template <size_t I, typename Shape>
struct TailNth;
template <typename Shape>
struct TailNth<0, Shape>
{
using type = Shape;
};
template <size_t I, typename Shape>
struct TailNth
{
using type = typename TailNth<I-1, typename Shape::Tail>::type;
};
/*! Simple data-by-reference strided array class a la Fortran
*/
template <typename Scalar, typename Shape, typename Storage = std::vector<Scalar> >
......@@ -86,20 +102,18 @@ private:
void init_strides()
{
for (size_t i = Shape::ndim; i > 0; --i) {
if (i == Shape::ndim)
if (Shape::stride(i-1) != Dynamic)
stride_[i-1] = Shape::stride(i-1);
else if (i == Shape::ndim)
stride_[i-1] = 1;
else
stride_[i-1] = shape_[i] * stride_[i];
}
for (size_t i = 0; i < Shape::ndim; ++i) {
std::cout << "stride "<< i << ": " << stride_[i] << " vs " << Shape::stride(i) << std::endl;
}
}
template <size_t axis>
static constexpr Eigen::Index eigen_shape()
{
static_assert(axis < Shape::ndim && axis < 2, "invalid axis");
size_t n = (axis == 0) ? Shape::head : Shape::Tail::head;
return (n == Dynamic) ? Eigen::Dynamic : n;
}
......@@ -115,7 +129,7 @@ public:
Array(Storage& data)
: data_(data), offset_(0)
{
static_assert(Shape::fixed, "array shape is not fixed");
static_assert(Shape::fixed, "array shape is not compile-time fixed");
shape_ = Shape::dims();
init_strides();
data_bounds_check();
......@@ -127,23 +141,6 @@ public:
template <typename... Idx>
const Scalar&& operator()(Idx... idxs) const { return data_[index(idxs...)]; }
Array<Scalar, typename Shape::Tail> part(size_t pos) const
{
std::array<size_t, Shape::ndim-1> shape;
size_t offset;
#ifndef NO_BOUNDS_CHECK
if (pos >= shape_[0])
throw std::out_of_range("index out of bounds");
#endif
offset = offset_ + stride_[0] * pos;
for (size_t i = 1; i < Shape::ndim; ++i)
shape[i-1] = shape_[i];
return {data_, shape, offset};
}
Storage& data() { return data_; }
template <size_t axis>
......@@ -195,17 +192,45 @@ public:
return idx;
}
template <typename... Idx>
Array<Scalar, typename TailNth<sizeof...(Idx), Shape>::type > part(Idx... idxs) const
{
constexpr size_t nidxs = sizeof...(idxs);
static_assert(nidxs < Shape::ndim,
"number of indices must be small than the number of dimensions");
const std::array<size_t, Shape::ndim> m{idxs...};
std::array<size_t, Shape::ndim - nidxs> shape;
size_t offset = offset_;
for (size_t i = 0; i < nidxs; ++i) {
if (Shape::stride(i) != Dynamic) {
offset += Shape::stride(i) * m[i];
} else {
offset += stride_[i] * m[i];
}
#ifndef NO_BOUNDS_CHECK
if (m[i] >= shape_[i])
throw std::out_of_range("index out of bounds");
#endif
}
for (size_t i = nidxs; i < Shape::ndim; ++i)
shape[i - nidxs] = shape_[i];
return {data_, shape, offset};
}
typedef Eigen::Matrix<Scalar, eigen_shape<0>(), eigen_shape<1>(), Eigen::RowMajor> EigenMatrix;
typedef Eigen::Map<EigenMatrix> EigenMap;
EigenMap to_matrix() const
EigenMap matrix() const
{
static_assert(Shape::ndim == 2, "matrix must be two-dimensional");
return {(Scalar *)data_.data() + offset_,
static_cast<Eigen::Index>(shape<0>()), static_cast<Eigen::Index>(shape<1>())};
}
operator EigenMap() const { return to_matrix(); }
operator EigenMap() const { return matrix(); }
};
......
......@@ -24,7 +24,7 @@ int main()
{
std::vector<ADComplex> v(2*2*2);
Array<ADComplex,Shape<Dynamic,2,2> > Q0(v, {2,2,2});
auto Q = Q0.part(1);
auto Q = Q0.part(1u);
std::cout << Q0.index(1u,0u,0u) << std::endl;
......@@ -32,7 +32,7 @@ int main()
std::cout << Q.index(0u,0u) << std::endl;
auto mat = Q.to_matrix();
auto mat = Q.matrix();
mat = mat * mat * mat * adcomplex(1j);
mat(0,0) = mat(1,1) = mat.trace();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment