Propagate exceptions in sequence::size() (#2076)

This commit is contained in:
Nicholas Musolino 2020-01-26 11:49:32 -05:00 committed by Wenzel Jakob
parent 805c5862b6
commit 02c83dba0f
3 changed files with 28 additions and 1 deletions

View File

@ -1242,7 +1242,12 @@ private:
class sequence : public object { class sequence : public object {
public: public:
PYBIND11_OBJECT_DEFAULT(sequence, object, PySequence_Check) PYBIND11_OBJECT_DEFAULT(sequence, object, PySequence_Check)
size_t size() const { return (size_t) PySequence_Size(m_ptr); } size_t size() const {
ssize_t result = PySequence_Size(m_ptr);
if (result == -1)
throw error_already_set();
return (size_t) result;
}
bool empty() const { return size() == 0; } bool empty() const { return size() == 0; }
detail::sequence_accessor operator[](size_t index) const { return {*this, index}; } detail::sequence_accessor operator[](size_t index) const { return {*this, index}; }
detail::item_accessor operator[](handle h) const { return object::operator[](h); } detail::item_accessor operator[](handle h) const { return object::operator[](h); }

View File

@ -319,6 +319,9 @@ TEST_SUBMODULE(sequences_and_iterators, m) {
return l; return l;
}); });
// test_sequence_length: check that Python sequences can be converted to py::sequence.
m.def("sequence_length", [](py::sequence seq) { return seq.size(); });
// Make sure that py::iterator works with std algorithms // Make sure that py::iterator works with std algorithms
m.def("count_none", [](py::object o) { m.def("count_none", [](py::object o) {
return std::count_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); }); return std::count_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); });

View File

@ -100,6 +100,25 @@ def test_sequence():
assert cstats.move_assignments == 0 assert cstats.move_assignments == 0
def test_sequence_length():
"""#2076: Exception raised by len(arg) should be propagated """
class BadLen(RuntimeError):
pass
class SequenceLike():
def __getitem__(self, i):
return None
def __len__(self):
raise BadLen()
with pytest.raises(BadLen):
m.sequence_length(SequenceLike())
assert m.sequence_length([1, 2, 3]) == 3
assert m.sequence_length("hello") == 5
def test_map_iterator(): def test_map_iterator():
sm = m.StringMap({'hi': 'bye', 'black': 'white'}) sm = m.StringMap({'hi': 'bye', 'black': 'white'})
assert sm['hi'] == 'bye' assert sm['hi'] == 'bye'