mirror of
https://github.com/pybind/pybind11.git
synced 2025-02-24 01:19:23 +00:00
Multi iterator
This commit is contained in:
parent
f8584b630b
commit
7fc85f1d7c
172
include/pybind11/array_iterator.h
Normal file
172
include/pybind11/array_iterator.h
Normal file
@ -0,0 +1,172 @@
|
||||
/*
|
||||
pybind11/array_iter.h: Array iteration support
|
||||
|
||||
Copyright (c) 2016 Johan Mabille
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "pybind11.h"
|
||||
#include "short_vector.h"
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(pybind11)
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
||||
template <class S>
|
||||
class common_iterator
|
||||
{
|
||||
|
||||
public:
|
||||
|
||||
common_iterator() : p_ptr(0), m_strides() {}
|
||||
common_iterator(void* ptr, const S& strides, const std::vector<size_t>& shape)
|
||||
: p_ptr(reinterpret_cast<char*>(ptr), m_strides(strides.size())
|
||||
{
|
||||
using value_type = typename S::value_type;
|
||||
using size_type = typename S::size_type;
|
||||
|
||||
m_strides.back() = static_cast<value_type>(strides.back());
|
||||
for (size_type i = m_strides.size() - 1; --i; i != 0)
|
||||
{
|
||||
size_type j = i - 1;
|
||||
value_type s = static_cast<value_type>(shape[i]);
|
||||
m_strides[j] = strides[j] - ((s - 1) * strides[i] + m_strides[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void increment(size_type dim)
|
||||
{
|
||||
p_ptr += m_strides[dim];
|
||||
}
|
||||
|
||||
void* data() const
|
||||
{
|
||||
return p_ptr;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
char* p_ptr;
|
||||
S m_strides;
|
||||
};
|
||||
|
||||
template <class C, class U>
|
||||
struct rebind_container_impl;
|
||||
|
||||
template <class T, class A, class U>
|
||||
struct rebind_container_impl<std::vector<T, A>, U>
|
||||
{
|
||||
typedef std::vector<U, A> type;
|
||||
};
|
||||
|
||||
template <class T, size_t N, class U>
|
||||
struct rebind_container_impl<short_vector<T, N>, U>
|
||||
{
|
||||
typedef short_vector<U, N> type;
|
||||
};
|
||||
|
||||
template <class C, class U>
|
||||
using rebind_container = typename rebind_container_impl<C, U>;
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
template <class S, size_t N>
|
||||
class multi_array_iterator
|
||||
{
|
||||
|
||||
using int_container = rebind_container<S, int>;
|
||||
|
||||
public:
|
||||
|
||||
multi_array_iterator(const std::array<buffer_info, N>& buffers,
|
||||
const std::vector<size_t>& strides,
|
||||
const std::vector<size_t>& shape)
|
||||
: m_shape(shape.size()), m_index(shape.size(), 0), m_common_iterator()
|
||||
{
|
||||
std::copy(shape.begin(), shape.end(), m_shape.begin());
|
||||
|
||||
int_container new_strides(strides.size());
|
||||
for (size_t i = 0; i < N; ++i)
|
||||
{
|
||||
init_common_iterator(buffers[i], strides, shape, m_common_iter[i], new_strides);
|
||||
}
|
||||
}
|
||||
|
||||
multi_array_iterator& operator++()
|
||||
{
|
||||
for (size_t j = m_index.size(); --j; j != 0)
|
||||
{
|
||||
size_t i = j - 1;
|
||||
if (++m_index[i] != m_shape[i])
|
||||
{
|
||||
increment_common_iterator(m_index[i]);
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
m_index[i] = 0;
|
||||
}
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <size_t K, class T>
|
||||
const T& data() const
|
||||
{
|
||||
return *reinterpret_cast<T*>(m_common_iterator[K].data());
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
void init_common_iterator(const buffer_info& buffer, const std::vector<size_t>& strides, const std::vector<size_t>& shape, common_iter& iterator, int_container& new_strides)
|
||||
{
|
||||
auto buffer_shape_iter = buffer.shape.rbegin();
|
||||
auto strides_iter = strides.rbegin();
|
||||
auto shape_iter = shape.rbegin();
|
||||
auto new_strides_iter = new_strides.rbegin();
|
||||
|
||||
while (buffer_shape_iter != buffer.shape.rend())
|
||||
{
|
||||
if (*shape_iter == *buffer_shape_iter)
|
||||
*new_stride_iter = static_cast<int>(*strides_iter);
|
||||
else
|
||||
*new_strides_iter = 0;
|
||||
|
||||
++buffer_shape_iter;
|
||||
++strides_iter;
|
||||
++shape_iter;
|
||||
++new_strides_iter;
|
||||
}
|
||||
|
||||
std::fill(new_strides_iter, strides.rend(), 0);
|
||||
|
||||
iterator = common_iter(buffer.ptr, new_strides, shape);
|
||||
}
|
||||
|
||||
void increment_common_iterator(int dim)
|
||||
{
|
||||
std::for_each(m_common_iterator.begin(), m_common_iterator.end(), [=](const common_iter& iter)
|
||||
{
|
||||
iter.increment(dim);
|
||||
});
|
||||
}
|
||||
|
||||
S m_shape;
|
||||
S m_index;
|
||||
using common_iter = detail::common_iterator<int_container>;
|
||||
std::array<common_iter, N> m_common_iterator;
|
||||
};
|
||||
|
||||
NAMESPACE_END(pybind11)
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(pop)
|
||||
#endif
|
97
include/pybind11/short_vector.h
Normal file
97
include/pybind11/short_vector.h
Normal file
@ -0,0 +1,97 @@
|
||||
/*
|
||||
pybind11/short_vector.h: similar to std::array but with dynamic size
|
||||
|
||||
Copyright (c) 2016 Johan Mabille
|
||||
|
||||
All rights reserved. Use of this source code is governed by a
|
||||
BSD-style license that can be found in the LICENSE file.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pybind11.h"
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
|
||||
#endif
|
||||
|
||||
NAMESPACE_BEGIN(pybind11)
|
||||
|
||||
template <class T, size_t N = 3>
|
||||
class short_vector
|
||||
{
|
||||
|
||||
public:
|
||||
|
||||
using std::array<T, N> data_type;
|
||||
using value_type = typename data_type::value_type;
|
||||
using size_type = typename data_type::size_type;
|
||||
using difference_type = typename data_type::difference_type;
|
||||
using reference = typename data_type::reference;
|
||||
using const_reference = typename data_type::const_reference;
|
||||
using pointer = typename data_type::pointer;
|
||||
using const_pointer = typename data_type::const_pointer;
|
||||
using iterator = typename data_type::iterator;
|
||||
using const_iterator = typename data_type::const_iterator;
|
||||
using reverse_iterator = typename data_type::reverse_iterator;
|
||||
using const_reverse_iterator = typename data_type::const_reverse_iterator;
|
||||
|
||||
short_vector() : m_data(), m_size(0) {}
|
||||
explicit short_vector(size_type size) : m_data(), m_size(size) {}
|
||||
short_vector(size_type size, const_reference t) : m_data(), m_size(size)
|
||||
{
|
||||
std::fill(begin(), end(), t);
|
||||
}
|
||||
|
||||
size_type size() const noexcept { return m_size; }
|
||||
constexpr size_type max_size() noexcept { return m_data.max_size(); }
|
||||
bool empty() const noexcept { return m_size == 0; }
|
||||
|
||||
void resize(size_type size) { m_size = size; }
|
||||
void resize(size_type size, const_reference t)
|
||||
{
|
||||
size_type old_size = m_size;
|
||||
resize(size); std::fill(begin() + old_size, end(), t);
|
||||
}
|
||||
|
||||
reference operator[](size_type i) { return m_data[i]; }
|
||||
const_reference operator[](size_type i) const { return m_data[i]; }
|
||||
|
||||
reference front() { return m_data[0]; }
|
||||
const_reference front() const { return m_data[0]; }
|
||||
|
||||
reference back() { return m_data[m_size - 1]; }
|
||||
const_reference back() const { return m_data[m_size - 1]; }
|
||||
|
||||
void fill(const_reference t) { std::fill(begin(), end(), t); }
|
||||
|
||||
iterator begin() noexcept { return m_data.begin(); }
|
||||
const_iterator begin() const noexcept { return m_data.begin(); }
|
||||
iterator end() noexcept { return begin() + m_size; }
|
||||
const_iterator end() const noexcept { return begin() + m_size(); }
|
||||
|
||||
reverse_iterator rbegin() noexcept { return reverse_iterator(end()); }
|
||||
const_reverse_iterator rbegin() const noexcept { return const_reverse_iterator(end()); }
|
||||
reverse_iterator rend() noexcept { return reverse_iterator(begin()); }
|
||||
const_reverse_iterator rend() const noexcept { return const_reverse_iterator(begin()); }
|
||||
|
||||
const_iterator cbegin() const noexcept { return begin(); }
|
||||
const_iterator cend() const noexcept { return end(); }
|
||||
const_reverse_iterator crbegin() const noexcept { return rbegin(); }
|
||||
const_reverse_iterator crend() const noexcept { return rend(); }
|
||||
|
||||
private:
|
||||
|
||||
data_type m_data;
|
||||
size_type m_size;
|
||||
|
||||
};
|
||||
|
||||
NAMESPACE_END(pybind11)
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(pop)
|
||||
#endif
|
4
setup.py
4
setup.py
@ -27,7 +27,9 @@ setup(
|
||||
'include/pybind11/functional.h',
|
||||
'include/pybind11/operators.h',
|
||||
'include/pybind11/pytypes.h',
|
||||
'include/pybind11/typeid.h'
|
||||
'include/pybind11/typeid.h',
|
||||
'include/pybind11/short_vector.h',
|
||||
'include/pybind11/array_iterator.h'
|
||||
],
|
||||
classifiers=[
|
||||
'Development Status :: 5 - Production/Stable',
|
||||
|
Loading…
Reference in New Issue
Block a user