feat: parallel compiler (#2521)

This commit is contained in:
Henry Schreiner 2020-10-02 10:03:35 -04:00 committed by GitHub
parent 07b069a55b
commit b9d00273ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 131 additions and 2 deletions

View File

@ -68,6 +68,23 @@ that is supported via a ``build_ext`` command override; it will only affect
ext_modules=ext_modules ext_modules=ext_modules
) )
Since pybind11 does not require NumPy when building, a light-weight replacement
for NumPy's parallel compilation distutils tool is included. Use it like this:
from pybind11.setup_helpers import ParallelCompile
# Optional multithreaded build
ParallelCompile("NPY_NUM_BUILD_JOBS").install()
setup(...
The argument is the name of an environment variable to control the number of
threads, such as ``NPY_NUM_BUILD_JOBS`` (as used by NumPy), though you can set
something different if you want. You can also pass ``default=N`` to set the
default number of threads (0 will take the number of threads available) and
``max=N``, the maximum number of threads; if you have a large extension you may
want set this to a memory dependent number.
.. _setup_helpers-pep518: .. _setup_helpers-pep518:
PEP 518 requirements (Pip 10+ required) PEP 518 requirements (Pip 10+ required)

View File

@ -49,6 +49,7 @@ except ImportError:
from distutils.extension import Extension as _Extension from distutils.extension import Extension as _Extension
import distutils.errors import distutils.errors
import distutils.ccompiler
WIN = sys.platform.startswith("win32") WIN = sys.platform.startswith("win32")
@ -279,3 +280,108 @@ class build_ext(_build_ext): # noqa: N801
# Python 2 doesn't allow super here, since distutils uses old-style # Python 2 doesn't allow super here, since distutils uses old-style
# classes! # classes!
_build_ext.build_extensions(self) _build_ext.build_extensions(self)
# Optional parallel compile utility
# inspired by: http://stackoverflow.com/questions/11013851/speeding-up-build-process-with-distutils
# and: https://github.com/tbenthompson/cppimport/blob/stable/cppimport/build_module.py
# and NumPy's parallel distutils module:
# https://github.com/numpy/numpy/blob/master/numpy/distutils/ccompiler.py
class ParallelCompile(object):
"""
Make a parallel compile function. Inspired by
numpy.distutils.ccompiler.CCompiler_compile and cppimport.
This takes several arguments that allow you to customize the compile
function created:
envvar: Set an environment variable to control the compilation threads, like NPY_NUM_BUILD_JOBS
default: 0 will automatically multithread, or 1 will only multithread if the envvar is set.
max: The limit for automatic multithreading if non-zero
To use:
ParallelCompile("NPY_NUM_BUILD_JOBS").install()
or:
with ParallelCompile("NPY_NUM_BUILD_JOBS"):
setup(...)
"""
__slots__ = ("envvar", "default", "max", "old")
def __init__(self, envvar=None, default=0, max=0):
self.envvar = envvar
self.default = default
self.max = max
self.old = []
def function(self):
"""
Builds a function object usable as distutils.ccompiler.CCompiler.compile.
"""
def compile_function(
compiler,
sources,
output_dir=None,
macros=None,
include_dirs=None,
debug=0,
extra_preargs=None,
extra_postargs=None,
depends=None,
):
# These lines are directly from distutils.ccompiler.CCompiler
macros, objects, extra_postargs, pp_opts, build = compiler._setup_compile(
output_dir, macros, include_dirs, sources, depends, extra_postargs
)
cc_args = compiler._get_cc_args(pp_opts, debug, extra_preargs)
# The number of threads; start with default.
threads = self.default
# Determine the number of compilation threads, unless set by an environment variable.
if self.envvar is not None:
threads = int(os.environ.get(self.envvar, self.default))
def _single_compile(obj):
try:
src, ext = build[obj]
except KeyError:
return
compiler._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)
try:
import multiprocessing
from multiprocessing.pool import ThreadPool
except ImportError:
threads = 1
if threads == 0:
try:
threads = multiprocessing.cpu_count()
threads = self.max if self.max and self.max < threads else threads
except NotImplementedError:
threads = 1
if threads > 1:
for _ in ThreadPool(threads).imap_unordered(_single_compile, objects):
pass
else:
for ob in objects:
_single_compile(ob)
return objects
return compile_function
def install(self):
distutils.ccompiler.CCompiler.compile = self.function()
return self
def __enter__(self):
self.old.append(distutils.ccompiler.CCompiler.compile)
return self.install()
def __exit__(self, *args):
distutils.ccompiler.CCompiler.compile = self.old.pop()

View File

@ -10,8 +10,9 @@ DIR = os.path.abspath(os.path.dirname(__file__))
MAIN_DIR = os.path.dirname(os.path.dirname(DIR)) MAIN_DIR = os.path.dirname(os.path.dirname(DIR))
@pytest.mark.parametrize("parallel", [False, True])
@pytest.mark.parametrize("std", [11, 0]) @pytest.mark.parametrize("std", [11, 0])
def test_simple_setup_py(monkeypatch, tmpdir, std): def test_simple_setup_py(monkeypatch, tmpdir, parallel, std):
monkeypatch.chdir(tmpdir) monkeypatch.chdir(tmpdir)
monkeypatch.syspath_prepend(MAIN_DIR) monkeypatch.syspath_prepend(MAIN_DIR)
@ -39,13 +40,18 @@ def test_simple_setup_py(monkeypatch, tmpdir, std):
cmdclass["build_ext"] = build_ext cmdclass["build_ext"] = build_ext
parallel = {parallel}
if parallel:
from pybind11.setup_helpers import ParallelCompile
ParallelCompile().install()
setup( setup(
name="simple_setup_package", name="simple_setup_package",
cmdclass=cmdclass, cmdclass=cmdclass,
ext_modules=ext_modules, ext_modules=ext_modules,
) )
""" """
).format(MAIN_DIR=MAIN_DIR, std=std), ).format(MAIN_DIR=MAIN_DIR, std=std, parallel=parallel),
encoding="ascii", encoding="ascii",
) )