Add support for __await__, __aiter__, and __anext__ protocols (#1842)

This commit is contained in:
Jeremy Maitin-Shepard 2019-07-18 00:02:35 -07:00 committed by Wenzel Jakob
parent 9b3fb05326
commit a3f4a0e8ab
5 changed files with 65 additions and 0 deletions

View File

@ -586,6 +586,9 @@ inline PyObject* make_new_python_type(const type_record &rec) {
type->tp_as_number = &heap_type->as_number;
type->tp_as_sequence = &heap_type->as_sequence;
type->tp_as_mapping = &heap_type->as_mapping;
#if PY_VERSION_HEX >= 0x03050000
type->tp_as_async = &heap_type->as_async;
#endif
/* Flags */
type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;

View File

@ -26,6 +26,7 @@ endif()
# Full set of test files (you can override these; see below)
set(PYBIND11_TEST_FILES
test_async.cpp
test_buffers.cpp
test_builtin_casters.cpp
test_call_policies.cpp
@ -71,6 +72,13 @@ if (PYBIND11_TEST_OVERRIDE)
set(PYBIND11_TEST_FILES ${PYBIND11_TEST_OVERRIDE})
endif()
# Skip test_async for Python < 3.5
list(FIND PYBIND11_TEST_FILES test_async.cpp PYBIND11_TEST_FILES_ASYNC_I)
if((PYBIND11_TEST_FILES_ASYNC_I GREATER -1) AND ("${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}" VERSION_LESS 3.5))
message(STATUS "Skipping test_async because Python version ${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR} < 3.5")
list(REMOVE_AT PYBIND11_TEST_FILES ${PYBIND11_TEST_FILES_ASYNC_I})
endif()
string(REPLACE ".cpp" ".py" PYBIND11_PYTEST_FILES "${PYBIND11_TEST_FILES}")
# Contains the set of test files that require pybind11_cross_module_tests to be

View File

@ -17,6 +17,11 @@ _unicode_marker = re.compile(r'u(\'[^\']*\')')
_long_marker = re.compile(r'([0-9])L')
_hexadecimal = re.compile(r'0x[0-9a-fA-F]+')
# test_async.py requires support for async and await
collect_ignore = []
if sys.version_info[:2] < (3, 5):
collect_ignore.append("test_async.py")
def _strip_and_dedent(s):
"""For triple-quote strings"""

26
tests/test_async.cpp Normal file
View File

@ -0,0 +1,26 @@
/*
tests/test_async.cpp -- __await__ support
Copyright (c) 2019 Google Inc.
All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/
#include "pybind11_tests.h"
TEST_SUBMODULE(async_module, m) {
struct DoesNotSupportAsync {};
py::class_<DoesNotSupportAsync>(m, "DoesNotSupportAsync")
.def(py::init<>());
struct SupportsAsync {};
py::class_<SupportsAsync>(m, "SupportsAsync")
.def(py::init<>())
.def("__await__", [](const SupportsAsync& self) -> py::object {
static_cast<void>(self);
py::object loop = py::module::import("asyncio.events").attr("get_event_loop")();
py::object f = loop.attr("create_future")();
f.attr("set_result")(5);
return f.attr("__await__")();
});
}

23
tests/test_async.py Normal file
View File

@ -0,0 +1,23 @@
import asyncio
import pytest
from pybind11_tests import async_module as m
@pytest.fixture
def event_loop():
loop = asyncio.new_event_loop()
yield loop
loop.close()
async def get_await_result(x):
return await x
def test_await(event_loop):
assert 5 == event_loop.run_until_complete(get_await_result(m.SupportsAsync()))
def test_await_missing(event_loop):
with pytest.raises(TypeError):
event_loop.run_until_complete(get_await_result(m.DoesNotSupportAsync()))