mirror of
https://github.com/pybind/pybind11.git
synced 2024-11-21 20:55:11 +00:00
enum_: move most functionality to a non-template implementation
This commit addresses an inefficiency in how enums are created in pybind11. Most of the enum_<> implementation is completely generic -- however, being a template class, it ended up instantiating vast amounts of essentially identical code in larger projects with many enums. This commit introduces a generic non-templated helper class that is compatible with any kind of enumeration. enum_ then becomes a thin wrapper around this new class. The new enum_<> API is designed to be 100% compatible with the old one.
This commit is contained in:
parent
b4b2292488
commit
f4245181ae
@ -22,6 +22,11 @@ v2.3.0 (Not yet released)
|
||||
* Added support for write only properties.
|
||||
`#1144 <https://github.com/pybind/pybind11/pull/1144>`_.
|
||||
|
||||
* Python type wrappers (``py::handle``, ``py::object``, etc.)
|
||||
now support map Python's number protocol onto C++ arithmetic
|
||||
operators such as ``operator+``, ``operator/=``, etc.
|
||||
`#1511 <https://github.com/pybind/pybind11/pull/1511>`_.
|
||||
|
||||
* A number of improvements related to enumerations:
|
||||
|
||||
1. The ``enum_`` implementation was rewritten from scratch to reduce
|
||||
|
@ -1360,6 +1360,146 @@ detail::initimpl::pickle_factory<GetState, SetState> pickle(GetState &&g, SetSta
|
||||
return {std::forward<GetState>(g), std::forward<SetState>(s)};
|
||||
}
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
struct enum_base {
|
||||
enum_base(handle base, handle parent) : m_base(base), m_parent(parent) { }
|
||||
|
||||
PYBIND11_NOINLINE void init(bool is_arithmetic, bool is_convertible) {
|
||||
m_base.attr("__entries") = dict();
|
||||
auto property = handle((PyObject *) &PyProperty_Type);
|
||||
auto static_property = handle((PyObject *) get_internals().static_property_type);
|
||||
|
||||
m_base.attr("__repr__") = cpp_function(
|
||||
[](handle arg) -> str {
|
||||
handle type = arg.get_type();
|
||||
object type_name = type.attr("__name__");
|
||||
dict entries = type.attr("__entries");
|
||||
for (const auto &kv : entries) {
|
||||
object other = kv.second[int_(0)];
|
||||
if (other.equal(arg))
|
||||
return pybind11::str("{}.{}").format(type_name, kv.first);
|
||||
}
|
||||
return pybind11::str("{}.???").format(type_name);
|
||||
}, is_method(m_base)
|
||||
);
|
||||
|
||||
m_base.attr("name") = property(cpp_function(
|
||||
[](handle arg) -> str {
|
||||
dict entries = arg.get_type().attr("__entries");
|
||||
for (const auto &kv : entries) {
|
||||
if (handle(kv.second[int_(0)]).equal(arg))
|
||||
return pybind11::str(kv.first);
|
||||
}
|
||||
return "???";
|
||||
}, is_method(m_base)
|
||||
));
|
||||
|
||||
m_base.attr("__doc__") = static_property(cpp_function(
|
||||
[](handle arg) -> std::string {
|
||||
std::string docstring;
|
||||
dict entries = arg.attr("__entries");
|
||||
if (((PyTypeObject *) arg.ptr())->tp_doc)
|
||||
docstring += std::string(((PyTypeObject *) arg.ptr())->tp_doc) + "\n\n";
|
||||
docstring += "Members:";
|
||||
for (const auto &kv : entries) {
|
||||
auto key = std::string(pybind11::str(kv.first));
|
||||
auto comment = kv.second[int_(1)];
|
||||
docstring += "\n\n " + key;
|
||||
if (!comment.is_none())
|
||||
docstring += " : " + (std::string) pybind11::str(comment);
|
||||
}
|
||||
return docstring;
|
||||
}
|
||||
), none(), none(), "");
|
||||
|
||||
m_base.attr("__members__") = static_property(cpp_function(
|
||||
[](handle arg) -> dict {
|
||||
dict entries = arg.attr("__entries"), m;
|
||||
for (const auto &kv : entries)
|
||||
m[kv.first] = kv.second[int_(0)];
|
||||
return m;
|
||||
}), none(), none(), ""
|
||||
);
|
||||
|
||||
#define PYBIND11_ENUM_OP_STRICT(op, expr) \
|
||||
m_base.attr(op) = cpp_function( \
|
||||
[](object a, object b) { \
|
||||
if (!a.get_type().is(b.get_type())) \
|
||||
throw type_error("Expected an enumeration of matching type!"); \
|
||||
return expr; \
|
||||
}, \
|
||||
is_method(m_base))
|
||||
|
||||
#define PYBIND11_ENUM_OP_CONV(op, expr) \
|
||||
m_base.attr(op) = cpp_function( \
|
||||
[](object a_, object b_) { \
|
||||
int_ a(a_), b(b_); \
|
||||
return expr; \
|
||||
}, \
|
||||
is_method(m_base))
|
||||
|
||||
if (is_convertible) {
|
||||
PYBIND11_ENUM_OP_CONV("__eq__", !b.is_none() && a.equal(b));
|
||||
PYBIND11_ENUM_OP_CONV("__ne__", b.is_none() || !a.equal(b));
|
||||
|
||||
if (is_arithmetic) {
|
||||
PYBIND11_ENUM_OP_CONV("__lt__", a < b);
|
||||
PYBIND11_ENUM_OP_CONV("__gt__", a > b);
|
||||
PYBIND11_ENUM_OP_CONV("__le__", a <= b);
|
||||
PYBIND11_ENUM_OP_CONV("__ge__", a >= b);
|
||||
PYBIND11_ENUM_OP_CONV("__and__", a & b);
|
||||
PYBIND11_ENUM_OP_CONV("__rand__", a & b);
|
||||
PYBIND11_ENUM_OP_CONV("__or__", a | b);
|
||||
PYBIND11_ENUM_OP_CONV("__ror__", a | b);
|
||||
PYBIND11_ENUM_OP_CONV("__xor__", a ^ b);
|
||||
PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b);
|
||||
}
|
||||
} else {
|
||||
PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)));
|
||||
PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)));
|
||||
|
||||
if (is_arithmetic) {
|
||||
PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b));
|
||||
PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b));
|
||||
PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b));
|
||||
PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b));
|
||||
}
|
||||
}
|
||||
|
||||
#undef PYBIND11_ENUM_OP_CONV
|
||||
#undef PYBIND11_ENUM_OP_STRICT
|
||||
|
||||
object getstate = cpp_function(
|
||||
[](object arg) { return int_(arg); }, is_method(m_base));
|
||||
|
||||
m_base.attr("__getstate__") = getstate;
|
||||
m_base.attr("__hash__") = getstate;
|
||||
}
|
||||
|
||||
PYBIND11_NOINLINE void value(char const* name_, object value, const char *doc = nullptr) {
|
||||
dict entries = m_base.attr("__entries");
|
||||
str name(name_);
|
||||
if (entries.contains(name)) {
|
||||
std::string type_name = (std::string) str(m_base.attr("__name__"));
|
||||
throw value_error(type_name + ": element \"" + std::string(name_) + "\" already exists!");
|
||||
}
|
||||
|
||||
entries[name] = std::make_pair(value, doc);
|
||||
m_base.attr(name) = value;
|
||||
}
|
||||
|
||||
PYBIND11_NOINLINE void export_values() {
|
||||
dict entries = m_base.attr("__entries");
|
||||
for (const auto &kv : entries)
|
||||
m_parent.attr(kv.first) = kv.second[int_(0)];
|
||||
}
|
||||
|
||||
handle m_base;
|
||||
handle m_parent;
|
||||
};
|
||||
|
||||
NAMESPACE_END(detail)
|
||||
|
||||
/// Binds C++ enumerations and enumeration classes to Python
|
||||
template <typename Type> class enum_ : public class_<Type> {
|
||||
public:
|
||||
@ -1370,109 +1510,33 @@ public:
|
||||
|
||||
template <typename... Extra>
|
||||
enum_(const handle &scope, const char *name, const Extra&... extra)
|
||||
: class_<Type>(scope, name, extra...), m_entries(), m_parent(scope) {
|
||||
|
||||
: class_<Type>(scope, name, extra...), m_base(*this, scope) {
|
||||
constexpr bool is_arithmetic = detail::any_of<std::is_same<arithmetic, Extra>...>::value;
|
||||
constexpr bool is_convertible = std::is_convertible<Type, Scalar>::value;
|
||||
m_base.init(is_arithmetic, is_convertible);
|
||||
|
||||
auto m_entries_ptr = m_entries.inc_ref().ptr();
|
||||
def("__repr__", [name, m_entries_ptr](Type value) -> pybind11::str {
|
||||
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr)) {
|
||||
if (pybind11::cast<Type>(kv.second[int_(0)]) == value)
|
||||
return pybind11::str("{}.{}").format(name, kv.first);
|
||||
}
|
||||
return pybind11::str("{}.???").format(name);
|
||||
});
|
||||
def_property_readonly("name", [m_entries_ptr](Type value) -> pybind11::str {
|
||||
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr)) {
|
||||
if (pybind11::cast<Type>(kv.second[int_(0)]) == value)
|
||||
return pybind11::str(kv.first);
|
||||
}
|
||||
return pybind11::str("???");
|
||||
});
|
||||
def_property_readonly_static("__doc__", [m_entries_ptr](handle self_) {
|
||||
std::string docstring;
|
||||
const char *tp_doc = ((PyTypeObject *) self_.ptr())->tp_doc;
|
||||
if (tp_doc)
|
||||
docstring += std::string(tp_doc) + "\n\n";
|
||||
docstring += "Members:";
|
||||
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr)) {
|
||||
auto key = std::string(pybind11::str(kv.first));
|
||||
auto comment = kv.second[int_(1)];
|
||||
docstring += "\n\n " + key;
|
||||
if (!comment.is_none())
|
||||
docstring += " : " + (std::string) pybind11::str(comment);
|
||||
}
|
||||
return docstring;
|
||||
});
|
||||
def_property_readonly_static("__members__", [m_entries_ptr](handle /* self_ */) {
|
||||
dict m;
|
||||
for (const auto &kv : reinterpret_borrow<dict>(m_entries_ptr))
|
||||
m[kv.first] = kv.second[int_(0)];
|
||||
return m;
|
||||
}, return_value_policy::copy);
|
||||
def(init([](Scalar i) { return static_cast<Type>(i); }));
|
||||
def("__int__", [](Type value) { return (Scalar) value; });
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
def("__long__", [](Type value) { return (Scalar) value; });
|
||||
#endif
|
||||
def("__eq__", [](const Type &value, Type *value2) { return value2 && value == *value2; });
|
||||
def("__ne__", [](const Type &value, Type *value2) { return !value2 || value != *value2; });
|
||||
if (is_arithmetic) {
|
||||
def("__lt__", [](const Type &value, Type *value2) { return value2 && value < *value2; });
|
||||
def("__gt__", [](const Type &value, Type *value2) { return value2 && value > *value2; });
|
||||
def("__le__", [](const Type &value, Type *value2) { return value2 && value <= *value2; });
|
||||
def("__ge__", [](const Type &value, Type *value2) { return value2 && value >= *value2; });
|
||||
}
|
||||
if (std::is_convertible<Type, Scalar>::value) {
|
||||
// Don't provide comparison with the underlying type if the enum isn't convertible,
|
||||
// i.e. if Type is a scoped enum, mirroring the C++ behaviour. (NB: we explicitly
|
||||
// convert Type to Scalar below anyway because this needs to compile).
|
||||
def("__eq__", [](const Type &value, Scalar value2) { return (Scalar) value == value2; });
|
||||
def("__ne__", [](const Type &value, Scalar value2) { return (Scalar) value != value2; });
|
||||
if (is_arithmetic) {
|
||||
def("__lt__", [](const Type &value, Scalar value2) { return (Scalar) value < value2; });
|
||||
def("__gt__", [](const Type &value, Scalar value2) { return (Scalar) value > value2; });
|
||||
def("__le__", [](const Type &value, Scalar value2) { return (Scalar) value <= value2; });
|
||||
def("__ge__", [](const Type &value, Scalar value2) { return (Scalar) value >= value2; });
|
||||
def("__invert__", [](const Type &value) { return ~((Scalar) value); });
|
||||
def("__and__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; });
|
||||
def("__or__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; });
|
||||
def("__xor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; });
|
||||
def("__rand__", [](const Type &value, Scalar value2) { return (Scalar) value & value2; });
|
||||
def("__ror__", [](const Type &value, Scalar value2) { return (Scalar) value | value2; });
|
||||
def("__rxor__", [](const Type &value, Scalar value2) { return (Scalar) value ^ value2; });
|
||||
def("__and__", [](const Type &value, const Type &value2) { return (Scalar) value & (Scalar) value2; });
|
||||
def("__or__", [](const Type &value, const Type &value2) { return (Scalar) value | (Scalar) value2; });
|
||||
def("__xor__", [](const Type &value, const Type &value2) { return (Scalar) value ^ (Scalar) value2; });
|
||||
}
|
||||
}
|
||||
def("__hash__", [](const Type &value) { return (Scalar) value; });
|
||||
// Pickling and unpickling -- needed for use with the 'multiprocessing' module
|
||||
def(pickle([](const Type &value) { return pybind11::make_tuple((Scalar) value); },
|
||||
[](tuple t) { return static_cast<Type>(t[0].cast<Scalar>()); }));
|
||||
def("__setstate__", [](Type &value, Scalar arg) { value = static_cast<Type>(arg); });
|
||||
}
|
||||
|
||||
/// Export enumeration entries into the parent scope
|
||||
enum_& export_values() {
|
||||
for (const auto &kv : m_entries)
|
||||
m_parent.attr(kv.first) = kv.second[int_(0)];
|
||||
m_base.export_values();
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Add an enumeration entry
|
||||
enum_& value(char const* name, Type value, const char *doc = nullptr) {
|
||||
auto v = pybind11::cast(value, return_value_policy::copy);
|
||||
this->attr(name) = v;
|
||||
auto name_converted = pybind11::str(name);
|
||||
if (m_entries.contains(name_converted))
|
||||
throw value_error("Enum error - element with name: " + std::string(name) + " already exists");
|
||||
m_entries[name_converted] = std::make_pair(v, doc);
|
||||
m_base.value(name, pybind11::cast(value, return_value_policy::copy), doc);
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
dict m_entries;
|
||||
handle m_parent;
|
||||
detail::enum_base m_base;
|
||||
};
|
||||
|
||||
NAMESPACE_BEGIN(detail)
|
||||
|
@ -153,4 +153,4 @@ def test_enum_to_int():
|
||||
def test_duplicate_enum_name():
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
m.register_bad_enum()
|
||||
assert str(excinfo.value) == "Enum error - element with name: ONE already exists"
|
||||
assert str(excinfo.value) == 'SimpleEnum: element "ONE" already exists!'
|
||||
|
@ -34,3 +34,9 @@ def test_roundtrip_with_dict(cls_name):
|
||||
assert p2.value == p.value
|
||||
assert p2.extra == p.extra
|
||||
assert p2.dynamic == p.dynamic
|
||||
|
||||
|
||||
def test_enum_pickle():
|
||||
from pybind11_tests import enums as e
|
||||
data = pickle.dumps(e.EOne, 2)
|
||||
assert e.EOne == pickle.loads(data)
|
||||
|
Loading…
Reference in New Issue
Block a user