姚凯文 姜嘉琪
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

166 lines
5.5 KiB

import contextlib
import os
import platform
import re
import shutil
import sysconfig
from pathlib import Path
from typing import List

import setuptools
from setuptools.command import build_ext
  10. PYTHON_INCLUDE_PATH_PLACEHOLDER = "<PYTHON_INCLUDE_PATH>"
  11. IS_WINDOWS = platform.system() == "Windows"
  12. IS_MAC = platform.system() == "Darwin"
  13. def _get_long_description(fp: str) -> str:
  14. with open(fp, "r", encoding="utf-8") as f:
  15. return f.read()
  16. def _get_version(fp: str) -> str:
  17. """Parse a version string from a file."""
  18. with open(fp, "r") as f:
  19. for line in f:
  20. if "__version__" in line:
  21. delim = '"'
  22. return line.split(delim)[1]
  23. raise RuntimeError(f"could not find a version string in file {fp!r}.")
  24. def _parse_requirements(fp: str) -> List[str]:
  25. with open(fp) as requirements:
  26. return [
  27. line.rstrip()
  28. for line in requirements
  29. if not (line.isspace() or line.startswith("#"))
  30. ]
  31. @contextlib.contextmanager
  32. def temp_fill_include_path(fp: str):
  33. """Temporarily set the Python include path in a file."""
  34. with open(fp, "r+") as f:
  35. try:
  36. content = f.read()
  37. replaced = content.replace(
  38. PYTHON_INCLUDE_PATH_PLACEHOLDER,
  39. Path(sysconfig.get_paths()['include']).as_posix(),
  40. )
  41. f.seek(0)
  42. f.write(replaced)
  43. f.truncate()
  44. yield
  45. finally:
  46. # revert to the original content after exit
  47. f.seek(0)
  48. f.write(content)
  49. f.truncate()
  50. class BazelExtension(setuptools.Extension):
  51. """A C/C++ extension that is defined as a Bazel BUILD target."""
  52. def __init__(self, name: str, bazel_target: str):
  53. super().__init__(name=name, sources=[])
  54. self.bazel_target = bazel_target
  55. stripped_target = bazel_target.split("//")[-1]
  56. self.relpath, self.target_name = stripped_target.split(":")
  57. class BuildBazelExtension(build_ext.build_ext):
  58. """A command that runs Bazel to build a C/C++ extension."""
  59. def run(self):
  60. for ext in self.extensions:
  61. self.bazel_build(ext)
  62. build_ext.build_ext.run(self)
  63. def bazel_build(self, ext: BazelExtension):
  64. """Runs the bazel build to create the package."""
  65. with temp_fill_include_path("WORKSPACE"):
  66. temp_path = Path(self.build_temp)
  67. bazel_argv = [
  68. "bazel",
  69. "build",
  70. ext.bazel_target,
  71. f"--symlink_prefix={temp_path / 'bazel-'}",
  72. f"--compilation_mode={'dbg' if self.debug else 'opt'}",
  73. # C++17 is required by nanobind
  74. f"--cxxopt={'/std:c++17' if IS_WINDOWS else '-std=c++17'}",
  75. ]
  76. if IS_WINDOWS:
  77. # Link with python*.lib.
  78. for library_dir in self.library_dirs:
  79. bazel_argv.append("--linkopt=/LIBPATH:" + library_dir)
  80. elif IS_MAC:
  81. if platform.machine() == "x86_64":
  82. # C++17 needs macOS 10.14 at minimum
  83. bazel_argv.append("--macos_minimum_os=10.14")
  84. # cross-compilation for Mac ARM64 on GitHub Mac x86 runners.
  85. # ARCHFLAGS is set by cibuildwheel before macOS wheel builds.
  86. archflags = os.getenv("ARCHFLAGS", "")
  87. if "arm64" in archflags:
  88. bazel_argv.append("--cpu=darwin_arm64")
  89. bazel_argv.append("--macos_cpus=arm64")
  90. elif platform.machine() == "arm64":
  91. bazel_argv.append("--macos_minimum_os=11.0")
  92. self.spawn(bazel_argv)
  93. shared_lib_suffix = '.dll' if IS_WINDOWS else '.so'
  94. ext_name = ext.target_name + shared_lib_suffix
  95. ext_bazel_bin_path = temp_path / 'bazel-bin' / ext.relpath / ext_name
  96. ext_dest_path = Path(self.get_ext_fullpath(ext.name))
  97. shutil.copyfile(ext_bazel_bin_path, ext_dest_path)
  98. # explicitly call `bazel shutdown` for graceful exit
  99. self.spawn(["bazel", "shutdown"])
  100. setuptools.setup(
  101. name="google_benchmark",
  102. version=_get_version("bindings/python/google_benchmark/__init__.py"),
  103. url="https://github.com/google/benchmark",
  104. description="A library to benchmark code snippets.",
  105. long_description=_get_long_description("README.md"),
  106. long_description_content_type="text/markdown",
  107. author="Google",
  108. author_email="benchmark-py@google.com",
  109. # Contained modules and scripts.
  110. package_dir={"": "bindings/python"},
  111. packages=setuptools.find_packages("bindings/python"),
  112. install_requires=_parse_requirements("bindings/python/requirements.txt"),
  113. cmdclass=dict(build_ext=BuildBazelExtension),
  114. ext_modules=[
  115. BazelExtension(
  116. "google_benchmark._benchmark",
  117. "//bindings/python/google_benchmark:_benchmark",
  118. )
  119. ],
  120. zip_safe=False,
  121. # PyPI package information.
  122. classifiers=[
  123. "Development Status :: 4 - Beta",
  124. "Intended Audience :: Developers",
  125. "Intended Audience :: Science/Research",
  126. "License :: OSI Approved :: Apache Software License",
  127. "Programming Language :: Python :: 3.8",
  128. "Programming Language :: Python :: 3.9",
  129. "Programming Language :: Python :: 3.10",
  130. "Programming Language :: Python :: 3.11",
  131. "Topic :: Software Development :: Testing",
  132. "Topic :: System :: Benchmark",
  133. ],
  134. license="Apache 2.0",
  135. keywords="benchmark",
  136. )