From d12fef38ae4bbae9a7d6859d7618cd10234287f5 Mon Sep 17 00:00:00 2001 From: Ashley Whetter Date: Sun, 29 Nov 2020 20:05:56 -0800 Subject: [PATCH] Standard library enum support via an optional include --- include/pybind11/stdlib_enum.h | 161 +++++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 include/pybind11/stdlib_enum.h diff --git a/include/pybind11/stdlib_enum.h b/include/pybind11/stdlib_enum.h new file mode 100644 index 000000000..8a009fd47 --- /dev/null +++ b/include/pybind11/stdlib_enum.h @@ -0,0 +1,161 @@ +/* + pybind11/stdlib_enum.h: Declaration and conversion enums as Enum objects. + + Copyright (c) 2020 Ashley Whetter + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#pragma once + +#include "detail/common.h" +#include "pybind11.h" + +PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE) +PYBIND11_NAMESPACE_BEGIN(detail) + +template +struct enum_mapper { + handle type = {}; + std::unordered_map values = {}; + + enum_mapper(handle type, const dict& values) : type(type) { + for (auto item : values) { + this->values[item.second.cast()] = type.attr(item.first); + } + } +}; + +template +struct type_caster::value>> { + using underlying_type = typename std::underlying_type::type; + + private: + using base_caster = type_caster_base; + using shared_info_type = typename std::unordered_map; + + static enum_mapper* enum_info() { + auto shared_enum_info = reinterpret_cast( + get_shared_data("_stdlib_enum_internals") + ); + if (shared_enum_info) { + auto it = shared_enum_info->find(std::type_index(typeid(T))); + if (it != shared_enum_info->end()) { + return reinterpret_cast*>(it->second); + } + } + + return nullptr; + } + + base_caster caster; + T value; + + public: + template using cast_op_type = pybind11::detail::cast_op_type; + + operator T*() { return enum_info() ? &value: static_cast(caster); } + operator T&() { return enum_info() ? value: static_cast(caster); } + + static constexpr auto name = base_caster::name; + + static handle cast(const T& src, return_value_policy policy, handle parent) { + enum_mapper* info = enum_info(); + if (info) { + auto it = info->values.find(static_cast(src)); + if (it != info->values.end()) { + return it->second.inc_ref(); + } + } + + return base_caster::cast(src, policy, parent); + } + + bool load(handle src, bool convert) { + if (!src) { + return false; + } + + enum_mapper* info = enum_info(); + if (info) { + if (!isinstance(src, info->type)) { + return false; + } + + value = static_cast(src.attr("value").cast()); + return true; + } + + return caster.load(src, convert); + } + + static void bind(handle type, const dict& values) { + enum_mapper* info = enum_info(); + delete info; + + auto shared_enum_info = &get_or_create_shared_data("_stdlib_enum_internals"); + (*shared_enum_info)[std::type_index(typeid(T))] = reinterpret_cast( + new enum_mapper(type, values) + ); + set_shared_data("_stdlib_enum_internals", shared_enum_info); + } +}; + +PYBIND11_NAMESPACE_END(detail) + +template +class stdlib_enum { + public: + using underlying_type = typename std::underlying_type::type; + + stdlib_enum(handle scope, const char* name) + : scope(scope), name(name) + { + kwargs["value"] = cast(name); + kwargs["names"] = entries; + if (scope) { + if (hasattr(scope, "__module__")) { + kwargs["module"] = scope.attr("__module__"); + } + else if (hasattr(scope, "__name__")) { + kwargs["module"] = scope.attr("__name__"); + } +#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 + if (hasattr(scope, "__qualname__")) { + kwargs["qualname"] = scope.attr("__qualname__").cast() + "." + name; + } +#endif + } + } + + ~stdlib_enum() { + object ctor = module::import("enum").attr("Enum"); + object unique = module::import("enum").attr("unique"); + object type = unique(ctor(**kwargs)); + setattr(scope, name, type); + detail::type_caster::bind(type, entries); + } + + stdlib_enum& value(const char* name, T value) & { + add_entry(name, value); + return *this; + } + + stdlib_enum&& value(const char* name, T value) && { + add_entry(name, value); + return std::move(*this); + } + + private: + handle scope; + const char* name; + dict entries; + dict kwargs; + + void add_entry(const char* name, T value) { + entries[name] = cast(static_cast(value)); + } +}; + +PYBIND11_NAMESPACE_END(pybind11)