testing.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. import contextlib
  2. import io
  3. import os
  4. import shlex
  5. import shutil
  6. import sys
  7. import tempfile
  8. import typing as t
  9. from types import TracebackType
  10. from . import _compat
  11. from . import formatting
  12. from . import termui
  13. from . import utils
  14. from ._compat import _find_binary_reader
  15. if t.TYPE_CHECKING:
  16. from .core import BaseCommand
  class EchoingStdin:
      """Binary input stream wrapper that copies every chunk it reads to a
      second binary stream.

      Reads are forwarded to ``input``; unless ``_paused`` is set, the bytes
      returned are also written to ``output``. Any attribute not defined on
      this class is looked up on the wrapped ``input`` stream.
      """

      def __init__(self, input: t.BinaryIO, output: t.BinaryIO) -> None:
          self._paused = False
          self._input = input
          self._output = output

      def __getattr__(self, x: str) -> t.Any:
          # Only reached for names not found normally (e.g. ``closed``);
          # delegate them to the wrapped stream.
          return getattr(self._input, x)

      def _echo(self, chunk: bytes) -> bytes:
          """Write *chunk* to the output stream unless paused, then return it."""
          if self._paused:
              return chunk
          self._output.write(chunk)
          return chunk

      def read(self, n: int = -1) -> bytes:
          data = self._input.read(n)
          return self._echo(data)

      def read1(self, n: int = -1) -> bytes:
          data = self._input.read1(n)  # type: ignore
          return self._echo(data)

      def readline(self, n: int = -1) -> bytes:
          line = self._input.readline(n)
          return self._echo(line)

      def readlines(self) -> t.List[bytes]:
          return list(map(self._echo, self._input.readlines()))

      def __iter__(self) -> t.Iterator[bytes]:
          # Lazy: each line is echoed only when the consumer pulls it.
          return map(self._echo, self._input)

      def __repr__(self) -> str:
          return repr(self._input)
  40. @contextlib.contextmanager
  41. def _pause_echo(stream: t.Optional[EchoingStdin]) -> t.Iterator[None]:
  42. if stream is None:
  43. yield
  44. else:
  45. stream._paused = True
  46. yield
  47. stream._paused = False
  48. class _NamedTextIOWrapper(io.TextIOWrapper):
  49. def __init__(
  50. self, buffer: t.BinaryIO, name: str, mode: str, **kwargs: t.Any
  51. ) -> None:
  52. super().__init__(buffer, **kwargs)
  53. self._name = name
  54. self._mode = mode
  55. @property
  56. def name(self) -> str:
  57. return self._name
  58. @property
  59. def mode(self) -> str:
  60. return self._mode
  def make_input_stream(
      input: t.Optional[t.Union[str, bytes, t.IO[t.Any]]], charset: str
  ) -> t.BinaryIO:
      """Turn *input* into a binary stream suitable for use as stdin.

      :param input: an existing stream (anything with ``read``), a ``str``
          (encoded with *charset*), ``bytes``, or ``None`` (empty input).
      :param charset: encoding applied when *input* is a ``str``.
      :raises TypeError: if *input* is a stream with no usable binary reader.
      """
      if not hasattr(input, "read"):
          # Plain data: wrap it in an in-memory byte stream.
          if isinstance(input, str):
              data = input.encode(charset)
          else:
              data = b"" if input is None else input
          return io.BytesIO(data)

      reader = _find_binary_reader(t.cast(t.IO[t.Any], input))
      if reader is None:
          raise TypeError("Could not find binary reader for input stream.")
      return reader
  class Result:
      """Holds the captured result of an invoked CLI script."""

      def __init__(
          self,
          runner: "CliRunner",
          stdout_bytes: bytes,
          stderr_bytes: t.Optional[bytes],
          return_value: t.Any,
          exit_code: int,
          exception: t.Optional[BaseException],
          exc_info: t.Optional[
              t.Tuple[t.Type[BaseException], BaseException, TracebackType]
          ] = None,
      ):
          #: The runner that created the result
          self.runner = runner
          #: The standard output as bytes.
          self.stdout_bytes = stdout_bytes
          #: The standard error as bytes, or None if not available
          self.stderr_bytes = stderr_bytes
          #: The value returned from the invoked command.
          #:
          #: .. versionadded:: 8.0
          self.return_value = return_value
          #: The exit code as integer.
          self.exit_code = exit_code
          #: The exception that happened if one did.
          self.exception = exception
          #: The traceback
          self.exc_info = exc_info

      def _decode(self, data: bytes) -> str:
          # Decode with the runner's charset (undecodable bytes become U+FFFD)
          # and normalize Windows line endings to "\n".
          text = data.decode(self.runner.charset, "replace")
          return text.replace("\r\n", "\n")

      @property
      def output(self) -> str:
          """The (standard) output as unicode string."""
          return self.stdout

      @property
      def stdout(self) -> str:
          """The standard output as unicode string."""
          return self._decode(self.stdout_bytes)

      @property
      def stderr(self) -> str:
          """The standard error as unicode string.

          :raises ValueError: if stderr was not captured separately.
          """
          data = self.stderr_bytes
          if data is None:
              raise ValueError("stderr not separately captured")
          return self._decode(data)

      def __repr__(self) -> str:
          if self.exception:
              summary = repr(self.exception)
          else:
              summary = "okay"
          return f"<{type(self).__name__} {summary}>"
  class CliRunner:
      """The CLI runner provides functionality to invoke a Click command line
      script for unittesting purposes in an isolated environment. This only
      works in single-threaded systems without any concurrency as it changes the
      global interpreter state.

      :param charset: the character set for the input and output data.
      :param env: a dictionary with environment variables for overriding.
      :param echo_stdin: if this is set to `True`, then reading from stdin writes
                         to stdout. This is useful for showing examples in
                         some circumstances. Note that regular prompts
                         will automatically echo the input.
      :param mix_stderr: if this is set to `False`, then stdout and stderr are
                         preserved as independent streams. This is useful for
                         Unix-philosophy apps that have predictable stdout and
                         noisy stderr, such that each may be measured
                         independently.
      """

      def __init__(
          self,
          charset: str = "utf-8",
          env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
          echo_stdin: bool = False,
          mix_stderr: bool = True,
      ) -> None:
          self.charset = charset
          self.env: t.Mapping[str, t.Optional[str]] = env or {}
          self.echo_stdin = echo_stdin
          self.mix_stderr = mix_stderr

      def get_default_prog_name(self, cli: "BaseCommand") -> str:
          """Given a command object it will return the default program name
          for it. The default is the `name` attribute or ``"root"`` if not
          set.
          """
          return cli.name or "root"

      def make_env(
          self, overrides: t.Optional[t.Mapping[str, t.Optional[str]]] = None
      ) -> t.Mapping[str, t.Optional[str]]:
          """Returns the environment overrides for invoking a script.

          The runner's own ``env`` is copied and *overrides* are layered on
          top; ``self.env`` itself is never modified.
          """
          rv = dict(self.env)
          if overrides:
              rv.update(overrides)
          return rv

      @contextlib.contextmanager
      def isolation(
          self,
          input: t.Optional[t.Union[str, bytes, t.IO[t.Any]]] = None,
          env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
          color: bool = False,
      ) -> t.Iterator[t.Tuple[io.BytesIO, t.Optional[io.BytesIO]]]:
          """A context manager that sets up the isolation for invoking of a
          command line tool. This sets up stdin with the given input data
          and `os.environ` with the overrides from the given dictionary.
          This also rebinds some internals in Click to be mocked (like the
          prompt functionality).

          This is automatically done in the :meth:`invoke` method.

          :param input: the input stream to put into sys.stdin.
          :param env: the environment overrides as dictionary. A value of
                      ``None`` removes the variable for the duration.
          :param color: whether the output should contain color codes. The
                        application can still override this explicitly.
          :return: yields ``(bytes_output, bytes_error)``; ``bytes_error`` is
                   ``None`` when ``mix_stderr`` is enabled.

          .. versionchanged:: 8.0
              ``stderr`` is opened with ``errors="backslashreplace"``
              instead of the default ``"strict"``.

          .. versionchanged:: 4.0
              Added the ``color`` parameter.
          """
          # Build every replacement object first. Nothing global has been
          # touched yet, so a failure here (bad input type, unknown charset,
          # malformed env overrides) leaves interpreter state intact.
          bytes_input = make_input_stream(input, self.charset)
          echo_input = None
          env = self.make_env(env)
          bytes_output = io.BytesIO()

          if self.echo_stdin:
              bytes_input = echo_input = t.cast(
                  t.BinaryIO, EchoingStdin(bytes_input, bytes_output)
              )

          text_input = _NamedTextIOWrapper(
              bytes_input, encoding=self.charset, name="<stdin>", mode="r"
          )

          if self.echo_stdin:
              # Force unbuffered reads, otherwise TextIOWrapper reads a
              # large chunk which is echoed early.
              text_input._CHUNK_SIZE = 1  # type: ignore

          text_output = _NamedTextIOWrapper(
              bytes_output, encoding=self.charset, name="<stdout>", mode="w"
          )

          bytes_error = None
          if self.mix_stderr:
              text_error = text_output
          else:
              bytes_error = io.BytesIO()
              text_error = _NamedTextIOWrapper(
                  bytes_error,
                  encoding=self.charset,
                  name="<stderr>",
                  mode="w",
                  errors="backslashreplace",
              )

          @_pause_echo(echo_input)  # type: ignore
          def visible_input(prompt: t.Optional[str] = None) -> str:
              sys.stdout.write(prompt or "")
              val = text_input.readline().rstrip("\r\n")
              sys.stdout.write(f"{val}\n")
              sys.stdout.flush()
              return val

          @_pause_echo(echo_input)  # type: ignore
          def hidden_input(prompt: t.Optional[str] = None) -> str:
              sys.stdout.write(f"{prompt or ''}\n")
              sys.stdout.flush()
              return text_input.readline().rstrip("\r\n")

          @_pause_echo(echo_input)  # type: ignore
          def _getchar(echo: bool) -> str:
              char = sys.stdin.read(1)

              if echo:
                  sys.stdout.write(char)
                  sys.stdout.flush()

              return char

          default_color = color

          def should_strip_ansi(
              stream: t.Optional[t.IO[t.Any]] = None, color: t.Optional[bool] = None
          ) -> bool:
              if color is None:
                  return not default_color
              return not color

          old_stdin = sys.stdin
          old_stdout = sys.stdout
          old_stderr = sys.stderr
          old_forced_width = formatting.FORCED_WIDTH
          old_visible_prompt_func = termui.visible_prompt_func
          old_hidden_prompt_func = termui.hidden_prompt_func
          old__getchar_func = termui._getchar
          old_should_strip_ansi = utils.should_strip_ansi  # type: ignore
          old__compat_should_strip_ansi = _compat.should_strip_ansi
          old_env: t.Dict[str, t.Optional[str]] = {}

          try:
              # Every global mutation happens inside the try so the finally
              # block always undoes whatever was applied, even on failure.
              formatting.FORCED_WIDTH = 80
              sys.stdin = text_input
              sys.stdout = text_output
              sys.stderr = text_error
              termui.visible_prompt_func = visible_input
              termui.hidden_prompt_func = hidden_input
              termui._getchar = _getchar
              utils.should_strip_ansi = should_strip_ansi  # type: ignore
              _compat.should_strip_ansi = should_strip_ansi

              for key, value in env.items():
                  old_env[key] = os.environ.get(key)
                  if value is None:
                      # Best effort: the variable may be absent, and
                      # unsetting some names can be rejected by the OS.
                      with contextlib.suppress(Exception):
                          del os.environ[key]
                  else:
                      os.environ[key] = value
              yield (bytes_output, bytes_error)
          finally:
              for key, value in old_env.items():
                  if value is None:
                      with contextlib.suppress(Exception):
                          del os.environ[key]
                  else:
                      os.environ[key] = value
              sys.stdout = old_stdout
              sys.stderr = old_stderr
              sys.stdin = old_stdin
              termui.visible_prompt_func = old_visible_prompt_func
              termui.hidden_prompt_func = old_hidden_prompt_func
              termui._getchar = old__getchar_func
              utils.should_strip_ansi = old_should_strip_ansi  # type: ignore
              _compat.should_strip_ansi = old__compat_should_strip_ansi
              formatting.FORCED_WIDTH = old_forced_width

      def invoke(
          self,
          cli: "BaseCommand",
          args: t.Optional[t.Union[str, t.Sequence[str]]] = None,
          input: t.Optional[t.Union[str, bytes, t.IO[t.Any]]] = None,
          env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
          catch_exceptions: bool = True,
          color: bool = False,
          **extra: t.Any,
      ) -> "Result":
          """Invokes a command in an isolated environment. The arguments are
          forwarded directly to the command line script, the `extra` keyword
          arguments are passed to the :meth:`~clickpkg.Command.main` function of
          the command.

          This returns a :class:`Result` object.

          :param cli: the command to invoke
          :param args: the arguments to invoke. It may be given as an iterable
                       or a string. When given as string it will be interpreted
                       as a Unix shell command. More details at
                       :func:`shlex.split`.
          :param input: the input data for `sys.stdin`.
          :param env: the environment overrides.
          :param catch_exceptions: Whether to catch any other exceptions than
                                   ``SystemExit``.
          :param extra: the keyword arguments to pass to :meth:`main`.
          :param color: whether the output should contain color codes. The
                        application can still override this explicitly.

          .. versionchanged:: 8.0
              The result object has the ``return_value`` attribute with
              the value returned from the invoked command.

          .. versionchanged:: 4.0
              Added the ``color`` parameter.

          .. versionchanged:: 3.0
              Added the ``catch_exceptions`` parameter.

          .. versionchanged:: 3.0
              The result object has the ``exc_info`` attribute with the
              traceback if available.
          """
          exc_info = None
          with self.isolation(input=input, env=env, color=color) as outstreams:
              return_value = None
              exception: t.Optional[BaseException] = None
              exit_code = 0

              if isinstance(args, str):
                  args = shlex.split(args)

              try:
                  prog_name = extra.pop("prog_name")
              except KeyError:
                  prog_name = self.get_default_prog_name(cli)

              try:
                  return_value = cli.main(args=args or (), prog_name=prog_name, **extra)
              except SystemExit as e:
                  exc_info = sys.exc_info()
                  e_code = t.cast(t.Optional[t.Union[int, t.Any]], e.code)

                  if e_code is None:
                      e_code = 0

                  if e_code != 0:
                      exception = e

                  if not isinstance(e_code, int):
                      # A non-integer exit "code" is a message: show it and
                      # report failure, mirroring how sys.exit("msg") behaves.
                      sys.stdout.write(str(e_code))
                      sys.stdout.write("\n")
                      e_code = 1

                  exit_code = e_code

              except Exception as e:
                  if not catch_exceptions:
                      raise
                  exception = e
                  exit_code = 1
                  exc_info = sys.exc_info()
              finally:
                  # Both text wrappers buffer writes internally. Flush them so
                  # everything the command wrote reaches the BytesIO objects;
                  # with mix_stderr=False stderr is a separate wrapper and
                  # would otherwise lose any unflushed output.
                  sys.stdout.flush()
                  sys.stderr.flush()
                  stdout = outstreams[0].getvalue()
                  if self.mix_stderr:
                      stderr = None
                  else:
                      stderr = outstreams[1].getvalue()  # type: ignore

          return Result(
              runner=self,
              stdout_bytes=stdout,
              stderr_bytes=stderr,
              return_value=return_value,
              exit_code=exit_code,
              exception=exception,
              exc_info=exc_info,  # type: ignore
          )

      @contextlib.contextmanager
      def isolated_filesystem(
          self, temp_dir: t.Optional[t.Union[str, "os.PathLike[str]"]] = None
      ) -> t.Iterator[str]:
          """A context manager that creates a temporary directory and
          changes the current working directory to it. This isolates tests
          that affect the contents of the CWD to prevent them from
          interfering with each other.

          :param temp_dir: Create the temporary directory under this
              directory. If given, the created directory is not removed
              when exiting.

          .. versionchanged:: 8.0
              Added the ``temp_dir`` parameter.
          """
          cwd = os.getcwd()
          dt = tempfile.mkdtemp(dir=temp_dir)

          try:
              # chdir inside the try so the directory is still cleaned up
              # if changing into it fails.
              os.chdir(dt)
              yield dt
          finally:
              os.chdir(cwd)

              if temp_dir is None:
                  # Best-effort cleanup; leftover files are not fatal.
                  with contextlib.suppress(OSError):
                      shutil.rmtree(dt)