Skip to content
Snippets Groups Projects
Commit c572cc90 authored by patavirt's avatar patavirt
Browse files

Slicing

parent 9d6dc870
No related branches found
No related tags found
No related merge requests found
...@@ -2,12 +2,15 @@ ...@@ -2,12 +2,15 @@
#define ACTION_HPP_ #define ACTION_HPP_
#include "common.hpp" #include "common.hpp"
#include "array.hpp"
inline ADComplex S_2(Array<ADComplex,3> Q) template <typename Scalar>
inline Scalar S_2(Array<Scalar,2> Q)
{ {
size_t i = 2; Q(0u,0u) = 123;
return Q(0u, i, 2u); Q(1u,1u) = 123;
return 0;
} }
......
#ifndef ARRAY_HPP_
#define ARRAY_HPP_
/*! Simple data-by-reference strided array class
*/
template <typename Scalar, int NDim>
class Array
{
private:
std::vector<Scalar>& data_;
std::array<size_t, NDim> shape_;
std::array<size_t, NDim> stride_;
size_t offset_ = 0;
public:
Array(Array<Scalar,NDim>&& array)
: data_(array.data_), shape_(array.shape_), stride_(array.stride_), offset_(array.offset_)
{}
Array(std::vector<Scalar>& data, const std::array<size_t, NDim> shape, const std::array<size_t, NDim> stride, const size_t offset=0)
: data_(data), shape_(shape), stride_(stride), offset_(offset)
{}
Array(std::vector<Scalar>& data, const std::array<size_t, NDim> shape)
: data_(data), shape_(shape), offset_(0)
{
// row-major strides
for (size_t i = NDim; i > 0; --i) {
if (i == NDim)
stride_[i-1] = 1;
else
stride_[i-1] = shape[i] * stride_[i];
}
}
template <typename... Idx>
Scalar& operator()(Idx... idxs) { return data_[index(idxs...)]; }
template <typename... Idx>
const Scalar&& operator()(Idx... idxs) const { return data_[index(idxs...)]; }
template <size_t axis>
Array<Scalar, NDim-1> slice(size_t pos=0) const
{
std::array<size_t, NDim-1> shape;
std::array<size_t, NDim-1> stride;
size_t offset;
offset = stride[axis] * pos;
for (size_t i = 0, j = 0; i < NDim; ++i) {
if (i == axis)
continue;
shape[j] = shape_[i];
stride[j] = stride_[i];
++j;
}
return Array<Scalar,NDim-1>(data_, shape, stride, offset);
}
template <size_t axis>
Array<Scalar, NDim> slice(size_t begin, size_t end) const
{
std::array<size_t, NDim> shape;
std::array<size_t, NDim> stride;
size_t offset;
offset = stride[axis] * begin;
for (size_t i = 0; i < NDim; ++i) {
if (i == axis)
shape[i] = end - begin;
else
shape[i] = shape_[i];
stride[i] = stride_[i];
}
return Array<Scalar,NDim>(data_, shape, stride, offset);
}
std::vector<Scalar>& data() { return data_; }
template <size_t axis>
const size_t shape() const { return shape_[axis]; }
template <size_t axis>
const size_t stride() const { return stride_[axis]; }
size_t offset() const { return offset_; }
template <typename... Idx>
size_t index(Idx... idxs) const
{
const std::array<size_t, NDim> m{idxs...};
size_t idx = offset_;
for (size_t i = 0; i < NDim; ++i)
idx += stride_[i] * m[i];
return idx;
}
};
#endif
#ifndef COMMON_H_ #ifndef COMMON_HPP_
#define COMMON_H_ #define COMMON_HPP_
#include <array> #include <array>
#include <tuple> #include <tuple>
...@@ -13,50 +13,4 @@ ...@@ -13,50 +13,4 @@
typedef CppAD::AD<double> ADDouble; typedef CppAD::AD<double> ADDouble;
typedef std::complex<ADDouble> ADComplex; typedef std::complex<ADDouble> ADComplex;
template <typename Scalar, int NDim>
class Array
{
private:
std::vector<Scalar>& data_;
std::array<size_t, NDim> shape_;
std::array<size_t, NDim> stride_;
protected:
template <typename... Idx>
size_t index(Idx... idxs) const
{
const std::array<size_t,NDim> m{idxs...};
size_t idx = 0;
for (size_t i = 0; i < NDim; ++i)
idx += stride_[i] * m[i];
return idx;
}
public:
Array(std::vector<Scalar>& data, std::array<size_t, NDim> shape) : data_(data), shape_(shape)
{
if (NDim > 0) {
size_t i = NDim - 1;
stride_[i] = 1;
while (i-- > 0)
stride_[i] = shape[i+1] * stride_[i+1];
}
}
template <typename... Idx>
Scalar& operator()(Idx... idxs)
{
return data_[index(idxs...)];
}
template <typename... Idx>
const Scalar&& operator()(Idx... idxs) const
{
return data_[index(idxs...)];
}
};
#endif #endif
...@@ -4,20 +4,14 @@ ...@@ -4,20 +4,14 @@
int main() int main()
{ {
std::vector<double> v(3*3*3); std::vector<double> v(2*2*2);
Array<double,3> Q(v,{3,3,3}); Array<double,3> Q(v,{2,2,2});
for (size_t i = 0; i < 3; ++i) { S_2(Q.slice<3>());
for (size_t j = 0; i < 3; ++i) {
for (size_t k = 0; i < 3; ++i) {
Q(i,j,k) = i+j+k;
}
}
}
for (size_t i = 0; i < 3; ++i) { for (size_t i = 0; i < Q.shape<0>(); ++i) {
for (size_t j = 0; i < 3; ++i) { for (size_t j = 0; j < Q.shape<1>(); ++j) {
for (size_t k = 0; i < 3; ++i) { for (size_t k = 0; k < Q.shape<2>(); ++k) {
std::cout << Q(i,j,k) << std::endl; std::cout << Q(i,j,k) << std::endl;
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment