_reloader.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. from __future__ import annotations
  2. import fnmatch
  3. import os
  4. import subprocess
  5. import sys
  6. import threading
  7. import time
  8. import typing as t
  9. from itertools import chain
  10. from pathlib import PurePath
  11. from ._internal import _log
  12. # The various system prefixes where imports are found. Base values are
  13. # different when running in a virtualenv. All reloaders will ignore the
  14. # base paths (usually the system installation). The stat reloader won't
  15. # scan the virtualenv paths, it will only include modules that are
  16. # already imported.
  17. _ignore_always = tuple({sys.base_prefix, sys.base_exec_prefix})
  18. prefix = {*_ignore_always, sys.prefix, sys.exec_prefix}
  19. if hasattr(sys, "real_prefix"):
  20. # virtualenv < 20
  21. prefix.add(sys.real_prefix)
  22. _stat_ignore_scan = tuple(prefix)
  23. del prefix
  24. _ignore_common_dirs = {
  25. "__pycache__",
  26. ".git",
  27. ".hg",
  28. ".tox",
  29. ".nox",
  30. ".pytest_cache",
  31. ".mypy_cache",
  32. }
  33. def _iter_module_paths() -> t.Iterator[str]:
  34. """Find the filesystem paths associated with imported modules."""
  35. # List is in case the value is modified by the app while updating.
  36. for module in list(sys.modules.values()):
  37. name = getattr(module, "__file__", None)
  38. if name is None or name.startswith(_ignore_always):
  39. continue
  40. while not os.path.isfile(name):
  41. # Zip file, find the base file without the module path.
  42. old = name
  43. name = os.path.dirname(name)
  44. if name == old: # skip if it was all directories somehow
  45. break
  46. else:
  47. yield name
  48. def _remove_by_pattern(paths: set[str], exclude_patterns: set[str]) -> None:
  49. for pattern in exclude_patterns:
  50. paths.difference_update(fnmatch.filter(paths, pattern))
  51. def _find_stat_paths(
  52. extra_files: set[str], exclude_patterns: set[str]
  53. ) -> t.Iterable[str]:
  54. """Find paths for the stat reloader to watch. Returns imported
  55. module files, Python files under non-system paths. Extra files and
  56. Python files under extra directories can also be scanned.
  57. System paths have to be excluded for efficiency. Non-system paths,
  58. such as a project root or ``sys.path.insert``, should be the paths
  59. of interest to the user anyway.
  60. """
  61. paths = set()
  62. for path in chain(list(sys.path), extra_files):
  63. path = os.path.abspath(path)
  64. if os.path.isfile(path):
  65. # zip file on sys.path, or extra file
  66. paths.add(path)
  67. continue
  68. parent_has_py = {os.path.dirname(path): True}
  69. for root, dirs, files in os.walk(path):
  70. # Optimizations: ignore system prefixes, __pycache__ will
  71. # have a py or pyc module at the import path, ignore some
  72. # common known dirs such as version control and tool caches.
  73. if (
  74. root.startswith(_stat_ignore_scan)
  75. or os.path.basename(root) in _ignore_common_dirs
  76. ):
  77. dirs.clear()
  78. continue
  79. has_py = False
  80. for name in files:
  81. if name.endswith((".py", ".pyc")):
  82. has_py = True
  83. paths.add(os.path.join(root, name))
  84. # Optimization: stop scanning a directory if neither it nor
  85. # its parent contained Python files.
  86. if not (has_py or parent_has_py[os.path.dirname(root)]):
  87. dirs.clear()
  88. continue
  89. parent_has_py[root] = has_py
  90. paths.update(_iter_module_paths())
  91. _remove_by_pattern(paths, exclude_patterns)
  92. return paths
  93. def _find_watchdog_paths(
  94. extra_files: set[str], exclude_patterns: set[str]
  95. ) -> t.Iterable[str]:
  96. """Find paths for the stat reloader to watch. Looks at the same
  97. sources as the stat reloader, but watches everything under
  98. directories instead of individual files.
  99. """
  100. dirs = set()
  101. for name in chain(list(sys.path), extra_files):
  102. name = os.path.abspath(name)
  103. if os.path.isfile(name):
  104. name = os.path.dirname(name)
  105. dirs.add(name)
  106. for name in _iter_module_paths():
  107. dirs.add(os.path.dirname(name))
  108. _remove_by_pattern(dirs, exclude_patterns)
  109. return _find_common_roots(dirs)
  110. def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]:
  111. root: dict[str, dict[str, t.Any]] = {}
  112. for chunks in sorted((PurePath(x).parts for x in paths), key=len, reverse=True):
  113. node = root
  114. for chunk in chunks:
  115. node = node.setdefault(chunk, {})
  116. node.clear()
  117. rv = set()
  118. def _walk(node: t.Mapping[str, dict[str, t.Any]], path: tuple[str, ...]) -> None:
  119. for prefix, child in node.items():
  120. _walk(child, path + (prefix,))
  121. # If there are no more nodes, and a path has been accumulated, add it.
  122. # Path may be empty if the "" entry is in sys.path.
  123. if not node and path:
  124. rv.add(os.path.join(*path))
  125. _walk(root, ())
  126. return rv
  127. def _get_args_for_reloading() -> list[str]:
  128. """Determine how the script was executed, and return the args needed
  129. to execute it again in a new process.
  130. """
  131. if sys.version_info >= (3, 10):
  132. # sys.orig_argv, added in Python 3.10, contains the exact args used to invoke
  133. # Python. Still replace argv[0] with sys.executable for accuracy.
  134. return [sys.executable, *sys.orig_argv[1:]]
  135. rv = [sys.executable]
  136. py_script = sys.argv[0]
  137. args = sys.argv[1:]
  138. # Need to look at main module to determine how it was executed.
  139. __main__ = sys.modules["__main__"]
  140. # The value of __package__ indicates how Python was called. It may
  141. # not exist if a setuptools script is installed as an egg. It may be
  142. # set incorrectly for entry points created with pip on Windows.
  143. if getattr(__main__, "__package__", None) is None or (
  144. os.name == "nt"
  145. and __main__.__package__ == ""
  146. and not os.path.exists(py_script)
  147. and os.path.exists(f"{py_script}.exe")
  148. ):
  149. # Executed a file, like "python app.py".
  150. py_script = os.path.abspath(py_script)
  151. if os.name == "nt":
  152. # Windows entry points have ".exe" extension and should be
  153. # called directly.
  154. if not os.path.exists(py_script) and os.path.exists(f"{py_script}.exe"):
  155. py_script += ".exe"
  156. if (
  157. os.path.splitext(sys.executable)[1] == ".exe"
  158. and os.path.splitext(py_script)[1] == ".exe"
  159. ):
  160. rv.pop(0)
  161. rv.append(py_script)
  162. else:
  163. # Executed a module, like "python -m werkzeug.serving".
  164. if os.path.isfile(py_script):
  165. # Rewritten by Python from "-m script" to "/path/to/script.py".
  166. py_module = t.cast(str, __main__.__package__)
  167. name = os.path.splitext(os.path.basename(py_script))[0]
  168. if name != "__main__":
  169. py_module += f".{name}"
  170. else:
  171. # Incorrectly rewritten by pydevd debugger from "-m script" to "script".
  172. py_module = py_script
  173. rv.extend(("-m", py_module.lstrip(".")))
  174. rv.extend(args)
  175. return rv
  176. class ReloaderLoop:
  177. name = ""
  178. def __init__(
  179. self,
  180. extra_files: t.Iterable[str] | None = None,
  181. exclude_patterns: t.Iterable[str] | None = None,
  182. interval: int | float = 1,
  183. ) -> None:
  184. self.extra_files: set[str] = {os.path.abspath(x) for x in extra_files or ()}
  185. self.exclude_patterns: set[str] = set(exclude_patterns or ())
  186. self.interval = interval
  187. def __enter__(self) -> ReloaderLoop:
  188. """Do any setup, then run one step of the watch to populate the
  189. initial filesystem state.
  190. """
  191. self.run_step()
  192. return self
  193. def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore
  194. """Clean up any resources associated with the reloader."""
  195. pass
  196. def run(self) -> None:
  197. """Continually run the watch step, sleeping for the configured
  198. interval after each step.
  199. """
  200. while True:
  201. self.run_step()
  202. time.sleep(self.interval)
  203. def run_step(self) -> None:
  204. """Run one step for watching the filesystem. Called once to set
  205. up initial state, then repeatedly to update it.
  206. """
  207. pass
  208. def restart_with_reloader(self) -> int:
  209. """Spawn a new Python interpreter with the same arguments as the
  210. current one, but running the reloader thread.
  211. """
  212. while True:
  213. _log("info", f" * Restarting with {self.name}")
  214. args = _get_args_for_reloading()
  215. new_environ = os.environ.copy()
  216. new_environ["WERKZEUG_RUN_MAIN"] = "true"
  217. exit_code = subprocess.call(args, env=new_environ, close_fds=False)
  218. if exit_code != 3:
  219. return exit_code
  220. def trigger_reload(self, filename: str) -> None:
  221. self.log_reload(filename)
  222. sys.exit(3)
  223. def log_reload(self, filename: str | bytes) -> None:
  224. filename = os.path.abspath(filename)
  225. _log("info", f" * Detected change in {filename!r}, reloading")
  226. class StatReloaderLoop(ReloaderLoop):
  227. name = "stat"
  228. def __enter__(self) -> ReloaderLoop:
  229. self.mtimes: dict[str, float] = {}
  230. return super().__enter__()
  231. def run_step(self) -> None:
  232. for name in _find_stat_paths(self.extra_files, self.exclude_patterns):
  233. try:
  234. mtime = os.stat(name).st_mtime
  235. except OSError:
  236. continue
  237. old_time = self.mtimes.get(name)
  238. if old_time is None:
  239. self.mtimes[name] = mtime
  240. continue
  241. if mtime > old_time:
  242. self.trigger_reload(name)
class WatchdogReloaderLoop(ReloaderLoop):
    """Reloader that receives filesystem events from a watchdog observer
    instead of polling modification times.

    Event callbacks only set ``should_reload``; the ``run`` loop notices
    the flag and exits with code 3.
    """

    def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
        """Create the observer and the event handler. Arguments are
        passed through unchanged to :class:`ReloaderLoop`.
        """
        # Imported here rather than at module level, so this module can
        # still be loaded when watchdog is not installed.
        from watchdog.events import EVENT_TYPE_CLOSED
        from watchdog.events import EVENT_TYPE_CREATED
        from watchdog.events import EVENT_TYPE_DELETED
        from watchdog.events import EVENT_TYPE_MODIFIED
        from watchdog.events import EVENT_TYPE_MOVED
        from watchdog.events import FileModifiedEvent
        from watchdog.events import PatternMatchingEventHandler
        from watchdog.observers import Observer

        super().__init__(*args, **kwargs)
        # Bound method captured for the nested handler class below, whose
        # own ``self`` is the handler and not this reloader.
        trigger_reload = self.trigger_reload

        class EventHandler(PatternMatchingEventHandler):
            def on_any_event(self, event: FileModifiedEvent):  # type: ignore
                if event.event_type not in {
                    EVENT_TYPE_CLOSED,
                    EVENT_TYPE_CREATED,
                    EVENT_TYPE_DELETED,
                    EVENT_TYPE_MODIFIED,
                    EVENT_TYPE_MOVED,
                }:
                    # skip events that don't involve changes to the file
                    return

                trigger_reload(event.src_path)

        # Derive a short label from the observer class name by lowercasing
        # it and dropping a trailing "observer" (8 characters).
        reloader_name = Observer.__name__.lower()  # type: ignore[attr-defined]

        if reloader_name.endswith("observer"):
            reloader_name = reloader_name[:-8]

        self.name = f"watchdog ({reloader_name})"
        self.observer = Observer()
        # Extra patterns can be non-Python files, match them in addition
        # to all Python files in default and extra directories. Ignore
        # __pycache__ since a change there will always have a change to
        # the source file (or initial pyc file) as well. Ignore Git and
        # Mercurial internal changes.
        extra_patterns = [p for p in self.extra_files if not os.path.isdir(p)]
        self.event_handler = EventHandler(
            patterns=["*.py", "*.pyc", "*.zip", *extra_patterns],
            ignore_patterns=[
                *[f"*/{d}/*" for d in _ignore_common_dirs],
                *self.exclude_patterns,
            ],
        )
        # Set by trigger_reload, polled by run.
        self.should_reload = False

    def trigger_reload(self, filename: str | bytes) -> None:
        """Flag that a reload is needed and log the file that changed."""
        # This is called inside an event handler, which means throwing
        # SystemExit has no effect.
        # https://github.com/gorakhargosh/watchdog/issues/294
        self.should_reload = True
        self.log_reload(filename)

    def __enter__(self) -> ReloaderLoop:
        """Start the observer, then let the base class run the first step,
        which schedules the initial watches.
        """
        # Maps a watched root directory to the handle returned by
        # ``observer.schedule``, or to None if scheduling it failed.
        self.watches: dict[str, t.Any] = {}
        self.observer.start()
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        """Stop the observer and wait for it to finish."""
        self.observer.stop()
        self.observer.join()

    def run(self) -> None:
        """Refresh the watches every ``interval`` seconds until an event
        has requested a reload, then exit with code 3.
        """
        while not self.should_reload:
            self.run_step()
            time.sleep(self.interval)

        sys.exit(3)

    def run_step(self) -> None:
        """Bring the scheduled watches in line with the directories that
        currently need watching: schedule new ones, unschedule stale ones.
        """
        # Start by assuming every existing watch is stale; each path that
        # is still wanted gets removed from this set below.
        to_delete = set(self.watches)

        for path in _find_watchdog_paths(self.extra_files, self.exclude_patterns):
            if path not in self.watches:
                try:
                    self.watches[path] = self.observer.schedule(
                        self.event_handler, path, recursive=True
                    )
                except OSError:
                    # Record the path with no watch handle, so scheduling
                    # is not attempted (and does not fail) again in the
                    # next iteration.
                    self.watches[path] = None

            to_delete.discard(path)

        for path in to_delete:
            watch = self.watches.pop(path, None)

            # None marks a path whose scheduling failed; nothing to undo.
            if watch is not None:
                self.observer.unschedule(watch)
  322. reloader_loops: dict[str, type[ReloaderLoop]] = {
  323. "stat": StatReloaderLoop,
  324. "watchdog": WatchdogReloaderLoop,
  325. }
  326. try:
  327. __import__("watchdog.observers")
  328. except ImportError:
  329. reloader_loops["auto"] = reloader_loops["stat"]
  330. else:
  331. reloader_loops["auto"] = reloader_loops["watchdog"]
  332. def ensure_echo_on() -> None:
  333. """Ensure that echo mode is enabled. Some tools such as PDB disable
  334. it which causes usability issues after a reload."""
  335. # tcgetattr will fail if stdin isn't a tty
  336. if sys.stdin is None or not sys.stdin.isatty():
  337. return
  338. try:
  339. import termios
  340. except ImportError:
  341. return
  342. attributes = termios.tcgetattr(sys.stdin)
  343. if not attributes[3] & termios.ECHO:
  344. attributes[3] |= termios.ECHO
  345. termios.tcsetattr(sys.stdin, termios.TCSANOW, attributes)
  346. def run_with_reloader(
  347. main_func: t.Callable[[], None],
  348. extra_files: t.Iterable[str] | None = None,
  349. exclude_patterns: t.Iterable[str] | None = None,
  350. interval: int | float = 1,
  351. reloader_type: str = "auto",
  352. ) -> None:
  353. """Run the given function in an independent Python interpreter."""
  354. import signal
  355. signal.signal(signal.SIGTERM, lambda *args: sys.exit(0))
  356. reloader = reloader_loops[reloader_type](
  357. extra_files=extra_files, exclude_patterns=exclude_patterns, interval=interval
  358. )
  359. try:
  360. if os.environ.get("WERKZEUG_RUN_MAIN") == "true":
  361. ensure_echo_on()
  362. t = threading.Thread(target=main_func, args=())
  363. t.daemon = True
  364. # Enter the reloader to set up initial state, then start
  365. # the app thread and reloader update loop.
  366. with reloader:
  367. t.start()
  368. reloader.run()
  369. else:
  370. sys.exit(reloader.restart_with_reloader())
  371. except KeyboardInterrupt:
  372. pass