mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-29 08:32:02 +00:00
array_t: make c_style/f_style work for array creation
Currently if you construct an `array_t<T, array::f_style>` with a shape but not strides you get a C-style array; the only way to get F-style strides was to calculate the strides manually. This commit fixes that by adding logic to use f_style strides when the flag is set. This also simplifies the existing c_style stride logic.
This commit is contained in:
parent
129a7256a9
commit
41f8da4a95
@ -530,7 +530,7 @@ public:
|
|||||||
const void *ptr = nullptr, handle base = handle()) {
|
const void *ptr = nullptr, handle base = handle()) {
|
||||||
|
|
||||||
if (strides->empty())
|
if (strides->empty())
|
||||||
*strides = default_strides(*shape, dt.itemsize());
|
*strides = c_strides(*shape, dt.itemsize());
|
||||||
|
|
||||||
auto ndim = shape->size();
|
auto ndim = shape->size();
|
||||||
if (ndim != strides->size())
|
if (ndim != strides->size())
|
||||||
@ -758,15 +758,21 @@ protected:
|
|||||||
throw std::domain_error("array is not writeable");
|
throw std::domain_error("array is not writeable");
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<ssize_t> default_strides(const std::vector<ssize_t>& shape, ssize_t itemsize) {
|
// Default, C-style strides
|
||||||
|
static std::vector<ssize_t> c_strides(const std::vector<ssize_t> &shape, ssize_t itemsize) {
|
||||||
auto ndim = shape.size();
|
auto ndim = shape.size();
|
||||||
std::vector<ssize_t> strides(ndim);
|
std::vector<ssize_t> strides(ndim, itemsize);
|
||||||
if (ndim) {
|
for (size_t i = ndim - 1; i > 0; --i)
|
||||||
std::fill(strides.begin(), strides.end(), itemsize);
|
strides[i - 1] = strides[i] * shape[i];
|
||||||
for (size_t i = 0; i < ndim - 1; i++)
|
return strides;
|
||||||
for (size_t j = 0; j < ndim - 1 - i; j++)
|
}
|
||||||
strides[j] *= shape[ndim - 1 - i];
|
|
||||||
}
|
// F-style strides; default when constructing an array_t with `ExtraFlags & f_style`
|
||||||
|
static std::vector<ssize_t> f_strides(const std::vector<ssize_t> &shape, ssize_t itemsize) {
|
||||||
|
auto ndim = shape.size();
|
||||||
|
std::vector<ssize_t> strides(ndim, itemsize);
|
||||||
|
for (size_t i = 1; i < ndim; ++i)
|
||||||
|
strides[i] = strides[i - 1] * shape[i - 1];
|
||||||
return strides;
|
return strides;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -797,6 +803,11 @@ protected:
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
|
||||||
|
private:
|
||||||
|
struct private_ctor {};
|
||||||
|
// Delegating constructor needed when both moving and accessing in the same constructor
|
||||||
|
array_t(private_ctor, ShapeContainer &&shape, StridesContainer &&strides, const T *ptr, handle base)
|
||||||
|
: array(std::move(shape), std::move(strides), ptr, base) {}
|
||||||
public:
|
public:
|
||||||
static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");
|
static_assert(!detail::array_info<T>::is_array, "Array types cannot be used with array_t");
|
||||||
|
|
||||||
@ -822,7 +833,9 @@ public:
|
|||||||
: array(std::move(shape), std::move(strides), ptr, base) { }
|
: array(std::move(shape), std::move(strides), ptr, base) { }
|
||||||
|
|
||||||
explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
|
explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
|
||||||
: array(std::move(shape), ptr, base) { }
|
: array_t(private_ctor{}, std::move(shape),
|
||||||
|
ExtraFlags & f_style ? f_strides(*shape, itemsize()) : c_strides(*shape, itemsize()),
|
||||||
|
ptr, base) { }
|
||||||
|
|
||||||
explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
|
explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
|
||||||
: array({count}, {}, ptr, base) { }
|
: array({count}, {}, ptr, base) { }
|
||||||
|
Loading…
Reference in New Issue
Block a user