Skip to content
Snippets Groups Projects
Commit 760c46e8 authored by patavirt's avatar patavirt
Browse files

array: better error checking etc

parent c572cc90
No related branches found
No related tags found
No related merge requests found
...@@ -10,7 +10,7 @@ inline Scalar S_2(Array<Scalar,2> Q) ...@@ -10,7 +10,7 @@ inline Scalar S_2(Array<Scalar,2> Q)
{ {
Q(0u,0u) = 123; Q(0u,0u) = 123;
Q(1u,1u) = 123; Q(1u,1u) = 123;
return 0; return Scalar(0);
} }
......
#ifndef ARRAY_HPP_ #ifndef ARRAY_HPP_
#define ARRAY_HPP_ #define ARRAY_HPP_
#include <vector>
#include <array>
#include <type_traits>
#include <stdexcept>
/*! Simple data-by-reference strided array class /*! Simple data-by-reference strided array class
*/ */
template <typename Scalar, int NDim> template <typename Scalar, int NDim, typename Storage = std::vector<Scalar> >
class Array class Array
{ {
private: private:
std::vector<Scalar>& data_; Storage& data_;
std::array<size_t, NDim> shape_; std::array<size_t, NDim> shape_;
std::array<size_t, NDim> stride_; std::array<size_t, NDim> stride_;
size_t offset_ = 0; size_t offset_ = 0;
...@@ -17,11 +23,11 @@ public: ...@@ -17,11 +23,11 @@ public:
: data_(array.data_), shape_(array.shape_), stride_(array.stride_), offset_(array.offset_) : data_(array.data_), shape_(array.shape_), stride_(array.stride_), offset_(array.offset_)
{} {}
Array(std::vector<Scalar>& data, const std::array<size_t, NDim> shape, const std::array<size_t, NDim> stride, const size_t offset=0) Array(Storage& data, const std::array<size_t, NDim> shape, const std::array<size_t, NDim> stride, const size_t offset=0)
: data_(data), shape_(shape), stride_(stride), offset_(offset) : data_(data), shape_(shape), stride_(stride), offset_(offset)
{} {}
Array(std::vector<Scalar>& data, const std::array<size_t, NDim> shape) Array(Storage& data, const std::array<size_t, NDim> shape)
: data_(data), shape_(shape), offset_(0) : data_(data), shape_(shape), offset_(0)
{ {
// row-major strides // row-major strides
...@@ -42,6 +48,8 @@ public: ...@@ -42,6 +48,8 @@ public:
template <size_t axis> template <size_t axis>
Array<Scalar, NDim-1> slice(size_t pos=0) const Array<Scalar, NDim-1> slice(size_t pos=0) const
{ {
static_assert(axis < NDim, "invalid slice index");
std::array<size_t, NDim-1> shape; std::array<size_t, NDim-1> shape;
std::array<size_t, NDim-1> stride; std::array<size_t, NDim-1> stride;
size_t offset; size_t offset;
...@@ -56,12 +64,14 @@ public: ...@@ -56,12 +64,14 @@ public:
++j; ++j;
} }
return Array<Scalar,NDim-1>(data_, shape, stride, offset); return {data_, shape, stride, offset};
} }
template <size_t axis> template <size_t axis>
Array<Scalar, NDim> slice(size_t begin, size_t end) const Array<Scalar, NDim> slice(size_t begin, size_t end) const
{ {
static_assert(axis < NDim, "invalid slice index");
std::array<size_t, NDim> shape; std::array<size_t, NDim> shape;
std::array<size_t, NDim> stride; std::array<size_t, NDim> stride;
size_t offset; size_t offset;
...@@ -76,10 +86,10 @@ public: ...@@ -76,10 +86,10 @@ public:
stride[i] = stride_[i]; stride[i] = stride_[i];
} }
return Array<Scalar,NDim>(data_, shape, stride, offset); return {data_, shape, stride, offset};
} }
std::vector<Scalar>& data() { return data_; } Storage& data() { return data_; }
template <size_t axis> template <size_t axis>
const size_t shape() const { return shape_[axis]; } const size_t shape() const { return shape_[axis]; }
...@@ -92,11 +102,19 @@ public: ...@@ -92,11 +102,19 @@ public:
template <typename... Idx> template <typename... Idx>
size_t index(Idx... idxs) const size_t index(Idx... idxs) const
{ {
static_assert(sizeof...(idxs) == NDim,
"number of indices must equal the number of dimensions");
const std::array<size_t, NDim> m{idxs...}; const std::array<size_t, NDim> m{idxs...};
size_t idx = offset_; size_t idx = offset_;
for (size_t i = 0; i < NDim; ++i) for (size_t i = 0; i < NDim; ++i) {
idx += stride_[i] * m[i]; idx += stride_[i] * m[i];
#ifndef NO_BOUNDS_CHECK
if (m[i] >= shape_[i])
throw std::out_of_range("index out of bounds");
#endif
}
return idx; return idx;
} }
......
#ifndef COMMON_HPP_ #ifndef COMMON_HPP_
#define COMMON_HPP_ #define COMMON_HPP_
#include <array>
#include <tuple>
#include <iostream> #include <iostream>
#include <cppad/cppad.hpp> #include <cppad/cppad.hpp>
......
...@@ -4,10 +4,10 @@ ...@@ -4,10 +4,10 @@
int main() int main()
{ {
std::vector<double> v(2*2*2); std::vector<ADComplex> v(2*2*2);
Array<double,3> Q(v,{2,2,2}); Array<ADComplex,3> Q(v,{2,2,2});
S_2(Q.slice<3>()); S_2(Q.slice<2>());
for (size_t i = 0; i < Q.shape<0>(); ++i) { for (size_t i = 0; i < Q.shape<0>(); ++i) {
for (size_t j = 0; j < Q.shape<1>(); ++j) { for (size_t j = 0; j < Q.shape<1>(); ++j) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment