diff --git a/include/pybind11/array_iterator.h b/include/pybind11/array_iterator.h new file mode 100644 index 000000000..bbf36910c --- /dev/null +++ b/include/pybind11/array_iterator.h @@ -0,0 +1,172 @@ +/* + pybind11/array_iter.h: Array iteration support + + Copyright (c) 2016 Johan Mabille + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ +#pragma once + +#include "pybind11.h" +#include "short_vector.h" + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +#endif + +NAMESPACE_BEGIN(pybind11) + +NAMESPACE_BEGIN(detail) + +template +class common_iterator +{ + +public: + + common_iterator() : p_ptr(0), m_strides() {} + common_iterator(void* ptr, const S& strides, const std::vector& shape) + : p_ptr(reinterpret_cast(ptr), m_strides(strides.size()) + { + using value_type = typename S::value_type; + using size_type = typename S::size_type; + + m_strides.back() = static_cast(strides.back()); + for (size_type i = m_strides.size() - 1; --i; i != 0) + { + size_type j = i - 1; + value_type s = static_cast(shape[i]); + m_strides[j] = strides[j] - ((s - 1) * strides[i] + m_strides[i]); + } + } + + void increment(size_type dim) + { + p_ptr += m_strides[dim]; + } + + void* data() const + { + return p_ptr; + } + +private: + + char* p_ptr; + S m_strides; +}; + +template +struct rebind_container_impl; + +template +struct rebind_container_impl, U> +{ + typedef std::vector type; +}; + +template +struct rebind_container_impl, U> +{ + typedef short_vector type; +}; + +template +using rebind_container = typename rebind_container_impl; + +NAMESPACE_END(detail) + +template +class multi_array_iterator +{ + + using int_container = rebind_container; + +public: + + multi_array_iterator(const std::array& buffers, + const std::vector& strides, + const std::vector& shape) + : m_shape(shape.size()), m_index(shape.size(), 0), m_common_iterator() + { + std::copy(shape.begin(), shape.end(), m_shape.begin()); + + int_container new_strides(strides.size()); + for (size_t i = 0; i < N; ++i) + { + init_common_iterator(buffers[i], strides, shape, m_common_iter[i], new_strides); + } + } + + multi_array_iterator& operator++() + { + for (size_t j = m_index.size(); --j; j != 0) + { + size_t i = j - 1; + if (++m_index[i] != m_shape[i]) + { + increment_common_iterator(m_index[i]); + break; + } + else + { + m_index[i] = 0; + } + } + return *this; + } + + template + const T& data() const + { + return *reinterpret_cast(m_common_iterator[K].data()); + } + +private: + + void init_common_iterator(const buffer_info& buffer, const std::vector& strides, const std::vector& shape, common_iter& iterator, int_container& new_strides) + { + auto buffer_shape_iter = buffer.shape.rbegin(); + auto strides_iter = strides.rbegin(); + auto shape_iter = shape.rbegin(); + auto new_strides_iter = new_strides.rbegin(); + + while (buffer_shape_iter != buffer.shape.rend()) + { + if (*shape_iter == *buffer_shape_iter) + *new_stride_iter = static_cast(*strides_iter); + else + *new_strides_iter = 0; + + ++buffer_shape_iter; + ++strides_iter; + ++shape_iter; + ++new_strides_iter; + } + + std::fill(new_strides_iter, strides.rend(), 0); + + iterator = common_iter(buffer.ptr, new_strides, shape); + } + + void increment_common_iterator(int dim) + { + std::for_each(m_common_iterator.begin(), m_common_iterator.end(), [=](const common_iter& iter) + { + iter.increment(dim); + }); + } + + S m_shape; + S m_index; + using common_iter = detail::common_iterator; + std::array m_common_iterator; +}; + +NAMESPACE_END(pybind11) + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/include/pybind11/short_vector.h b/include/pybind11/short_vector.h new file mode 100644 index 000000000..62d97e4af --- /dev/null +++ b/include/pybind11/short_vector.h @@ -0,0 +1,97 @@ +/* + pybind11/short_vector.h: similar to std::array but with dynamic size + + Copyright (c) 2016 Johan Mabille + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "pybind11.h" +#include +#include + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant +#endif + +NAMESPACE_BEGIN(pybind11) + +template +class short_vector +{ + +public: + + using std::array data_type; + using value_type = typename data_type::value_type; + using size_type = typename data_type::size_type; + using difference_type = typename data_type::difference_type; + using reference = typename data_type::reference; + using const_reference = typename data_type::const_reference; + using pointer = typename data_type::pointer; + using const_pointer = typename data_type::const_pointer; + using iterator = typename data_type::iterator; + using const_iterator = typename data_type::const_iterator; + using reverse_iterator = typename data_type::reverse_iterator; + using const_reverse_iterator = typename data_type::const_reverse_iterator; + + short_vector() : m_data(), m_size(0) {} + explicit short_vector(size_type size) : m_data(), m_size(size) {} + short_vector(size_type size, const_reference t) : m_data(), m_size(size) + { + std::fill(begin(), end(), t); + } + + size_type size() const noexcept { return m_size; } + constexpr size_type max_size() noexcept { return m_data.max_size(); } + bool empty() const noexcept { return m_size == 0; } + + void resize(size_type size) { m_size = size; } + void resize(size_type size, const_reference t) + { + size_type old_size = m_size; + resize(size); std::fill(begin() + old_size, end(), t); + } + + reference operator[](size_type i) { return m_data[i]; } + const_reference operator[](size_type i) const { return m_data[i]; } + + reference front() { return m_data[0]; } + const_reference front() const { return m_data[0]; } + + reference back() { return m_data[m_size - 1]; } + const_reference back() const { return m_data[m_size - 1]; } + + void fill(const_reference t) { std::fill(begin(), end(), t); } + + iterator begin() noexcept { return m_data.begin(); } + const_iterator begin() const noexcept { return m_data.begin(); } + iterator end() noexcept { return begin() + m_size; } + const_iterator end() const noexcept { return begin() + m_size(); } + + reverse_iterator rbegin() noexcept { return reverse_iterator(end()); } + const_reverse_iterator rbegin() const noexcept { return const_reverse_iterator(end()); } + reverse_iterator rend() noexcept { return reverse_iterator(begin()); } + const_reverse_iterator rend() const noexcept { return const_reverse_iterator(begin()); } + + const_iterator cbegin() const noexcept { return begin(); } + const_iterator cend() const noexcept { return end(); } + const_reverse_iterator crbegin() const noexcept { return rbegin(); } + const_reverse_iterator crend() const noexcept { return rend(); } + +private: + + data_type m_data; + size_type m_size; + +}; + +NAMESPACE_END(pybind11) + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/setup.py b/setup.py index ae1ca8bf2..8027088a5 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,9 @@ setup( 'include/pybind11/functional.h', 'include/pybind11/operators.h', 'include/pybind11/pytypes.h', - 'include/pybind11/typeid.h' + 'include/pybind11/typeid.h', + 'include/pybind11/short_vector.h', + 'include/pybind11/array_iterator.h' ], classifiers=[ 'Development Status :: 5 - Production/Stable',