From 2dbf0297050533e697d065f4b92283e0ecd37581 Mon Sep 17 00:00:00 2001 From: Ivan Smirnov Date: Mon, 31 Oct 2016 14:11:10 +0000 Subject: [PATCH] Add public shared_data API NumPy internals are stored under "_numpy_internals" key. --- include/pybind11/common.h | 31 ++++++++++++++++++++++++++++++- include/pybind11/numpy.h | 14 +++++--------- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/include/pybind11/common.h b/include/pybind11/common.h index 27cd47bef..62198c341 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -323,7 +323,7 @@ struct internals { std::unordered_set, overload_hash> inactive_overload_cache; std::unordered_map> direct_conversions; std::forward_list registered_exception_translators; - std::unordered_map shared_data; + std::unordered_map shared_data; // Custom data to be shared across extensions #if defined(WITH_THREAD) decltype(PyThread_create_key()) tstate = 0; // Usually an int but a long on Cygwin64 with Python 3.x PyInterpreterState *istate = nullptr; @@ -428,6 +428,35 @@ inline void ignore_unused(const int *) { } NAMESPACE_END(detail) +/// Returns a named pointer that is shared among all extension modules (using the same +/// pybind11 version) running in the current interpreter. Names starting with underscores +/// are reserved for internal usage. Returns `nullptr` if no matching entry was found. +inline PYBIND11_NOINLINE void* get_shared_data(const std::string& name) { + auto& internals = detail::get_internals(); + auto it = internals.shared_data.find(name); + return it != internals.shared_data.end() ? it->second : nullptr; +} + +/// Set the shared data that can be later recovered by `get_shared_data()`. +inline PYBIND11_NOINLINE void *set_shared_data(const std::string& name, void *data) { + detail::get_internals().shared_data[name] = data; + return data; +} + +/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if +/// such entry exists. Otherwise, a new object of default-constructible type `T` is +/// added to the shared data under the given name and a reference to it is returned. +template T& get_or_create_shared_data(const std::string& name) { + auto& internals = detail::get_internals(); + auto it = internals.shared_data.find(name); + T* ptr = (T*) (it != internals.shared_data.end() ? it->second : nullptr); + if (!ptr) { + ptr = new T(); + internals.shared_data[name] = ptr; + } + return *ptr; +} + /// Fetch and hold an error which was already set in Python class error_already_set : public std::runtime_error { public: diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 19bff6359..b180cb296 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -91,18 +91,14 @@ struct numpy_internals { } }; -inline PYBIND11_NOINLINE numpy_internals* load_numpy_internals() { - auto& shared_data = detail::get_internals().shared_data; - auto it = shared_data.find("numpy_internals"); - if (it != shared_data.end()) - return (numpy_internals *)it->second; - auto ptr = new numpy_internals(); - shared_data["numpy_internals"] = ptr; - return ptr; +inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) { + ptr = &get_or_create_shared_data("_numpy_internals"); } inline numpy_internals& get_numpy_internals() { - static numpy_internals* ptr = load_numpy_internals(); + static numpy_internals* ptr = nullptr; + if (!ptr) + load_numpy_internals(ptr); return *ptr; }