array_t: make c_style/f_style work for array creation

Currently if you construct an `array_t<T, array::f_style>` with a shape
but not strides you get a C-style array; the only way to get F-style
strides was to calculate the strides manually.  This commit fixes that
by adding logic to use f_style strides when the flag is set.

This also simplifies the existing c_style stride logic.
This commit is contained in:
Jason Rhinelander 2017-03-27 11:03:16 -03:00
parent 129a7256a9
commit 41f8da4a95

View File

@ -530,7 +530,7 @@ public:
const void *ptr = nullptr, handle base = handle()) {
if (strides->empty())
*strides = default_strides(*shape, dt.itemsize());
*strides = c_strides(*shape, dt.itemsize());
auto ndim = shape->size();
if (ndim != strides->size())
@ -758,15 +758,21 @@ protected:
throw std::domain_error("array is not writeable");
}
static std::vector<ssize_t> default_strides(const std::vector<ssize_t>& shape, ssize_t itemsize) {
// Default, C-style strides
static std::vector<ssize_t> c_strides(const std::vector<ssize_t> &shape, ssize_t itemsize) {
auto ndim = shape.size();
std::vector<ssize_t> strides(ndim);
if (ndim) {
std::fill(strides.begin(), strides.end(), itemsize);
for (size_t i = 0; i < ndim - 1; i++)
for (size_t j = 0; j < ndim - 1 - i; j++)
strides[j] *= shape[ndim - 1 - i];
}
std::vector<ssize_t> strides(ndim, itemsize);
for (size_t i = ndim - 1; i > 0; --i)
strides[i - 1] = strides[i] * shape[i];
return strides;
}
// F-style strides; default when constructing an array_t with `ExtraFlags & f_style`
static std::vector<ssize_t> f_strides(const std::vector<ssize_t> &shape, ssize_t itemsize) {
auto ndim = shape.size();
std::vector<ssize_t> strides(ndim, itemsize);
for (size_t i = 1; i < ndim; ++i)
strides[i] = strides[i - 1] * shape[i - 1];
return strides;
}
@ -797,6 +803,11 @@ protected:
};
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
private:
struct private_ctor {};
// Delegating constructor needed when both moving and accessing in the same constructor
array_t(private_ctor, ShapeContainer &&shape, StridesContainer &&strides, const T *ptr, handle base)
: array(std::move(shape), std::move(strides), ptr, base) {}
public:
static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");
@ -822,7 +833,9 @@ public:
: array(std::move(shape), std::move(strides), ptr, base) { }
explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
: array(std::move(shape), ptr, base) { }
: array_t(private_ctor{}, std::move(shape),
ExtraFlags & f_style ? f_strides(*shape, itemsize()) : c_strides(*shape, itemsize()),
ptr, base) { }
explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
: array({count}, {}, ptr, base) { }