idtracking.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. import typing as t
  2. from . import nodes
  3. from .visitor import NodeVisitor
  4. if t.TYPE_CHECKING:
  5. import typing_extensions as te
  # Load instructions stored as the first element of each ``Symbols.loads``
  # value, i.e. ``loads[identifier] = (instruction, argument)``.

  # Name declared via ``Symbols.declare_parameter``; argument is ``None``.
  VAR_LOAD_PARAMETER = "param"
  # Name to be looked up by its template name; argument is that name.
  VAR_LOAD_RESOLVE = "resolve"
  # Identifier that aliases an outer scope's identifier; argument is the
  # outer identifier.
  VAR_LOAD_ALIAS = "alias"
  # Stored name with no outer binding; argument is ``None``.
  VAR_LOAD_UNDEFINED = "undefined"
  def find_symbols(
      nodes: t.Iterable[nodes.Node], parent_symbols: t.Optional["Symbols"] = None
  ) -> "Symbols":
      """Run a frame-level symbol analysis over a sequence of nodes.

      A new :class:`Symbols` scope is created (as a child of
      ``parent_symbols`` when one is given). Every node in ``nodes`` is fed,
      in order, to a :class:`FrameSymbolVisitor` bound to that scope, and the
      populated scope is returned.
      """
      scope = Symbols(parent=parent_symbols)
      frame_visitor = FrameSymbolVisitor(scope)
      for child in nodes:
          frame_visitor.visit(child)
      return scope
  def symbols_for_node(
      node: nodes.Node, parent_symbols: t.Optional["Symbols"] = None
  ) -> "Symbols":
      """Analyze one scope-owning node and return the resulting symbols.

      Unlike :func:`find_symbols`, the node is passed to
      :meth:`Symbols.analyze_node`, which dispatches through
      :class:`RootVisitor`. The node's type therefore decides which of its
      children are inspected.
      """
      scope = Symbols(parent=parent_symbols)
      scope.analyze_node(node)
      return scope
  class Symbols:
      """Per-scope table mapping template names to generated identifiers.

      Attributes:
          level: nesting depth. It defaults to 0 for a scope without a parent
              and to ``parent.level + 1`` otherwise.
          parent: the enclosing scope, or ``None``.
          refs: template name -> generated identifier (``l_<level>_<name>``).
          loads: generated identifier -> ``(instruction, argument)``, where
              ``instruction`` is one of the module's ``VAR_LOAD_*`` strings.
          stores: names assigned or declared as parameters in this scope.
      """

      def __init__(
          self, parent: t.Optional["Symbols"] = None, level: t.Optional[int] = None
      ) -> None:
          if level is None:
              level = 0 if parent is None else parent.level + 1
          self.level: int = level
          self.parent = parent
          # template name -> generated identifier
          self.refs: t.Dict[str, str] = {}
          # generated identifier -> (VAR_LOAD_* instruction, argument)
          self.loads: t.Dict[str, t.Any] = {}
          # names written (assigned or declared as parameters) in this scope
          self.stores: t.Set[str] = set()

      def analyze_node(self, node: nodes.Node, **kwargs: t.Any) -> None:
          """Populate this scope from ``node`` using a :class:`RootVisitor`.

          ``kwargs`` are passed to ``RootVisitor.visit``. For example,
          ``RootVisitor.visit_For`` accepts ``for_branch``.
          """
          visitor = RootVisitor(self)
          visitor.visit(node, **kwargs)

      def _define_ref(
          self, name: str, load: t.Optional[t.Tuple[str, t.Optional[str]]] = None
      ) -> str:
          """Bind ``name`` to this scope's identifier for it and return it.

          When ``load`` is given, it is recorded as that identifier's load
          instruction.
          """
          ident = f"l_{self.level}_{name}"
          self.refs[name] = ident
          if load is not None:
              self.loads[ident] = load
          return ident

      def find_load(self, target: str) -> t.Optional[t.Any]:
          """Return the load instruction for identifier ``target``.

          This scope is checked first, then each ancestor in turn. Returns
          ``None`` if no scope records one.
          """
          if target in self.loads:
              return self.loads[target]
          if self.parent is not None:
              return self.parent.find_load(target)
          return None

      def find_ref(self, name: str) -> t.Optional[str]:
          """Return the identifier bound to ``name`` here or in an ancestor.

          Returns ``None`` if no scope in the chain binds ``name``.
          """
          if name in self.refs:
              return self.refs[name]
          if self.parent is not None:
              return self.parent.find_ref(name)
          return None

      def ref(self, name: str) -> str:
          """Like :meth:`find_ref`, but raise ``AssertionError`` when unbound."""
          rv = self.find_ref(name)
          if rv is None:
              raise AssertionError(
                  "Tried to resolve a name to a reference that was"
                  f" unknown to the frame ({name!r})"
              )
          return rv

      def copy(self) -> "te.Self":
          """Return a copy whose ``refs``, ``loads`` and ``stores`` are independent.

          All other attributes (such as ``level`` and ``parent``) are copied by
          reference from ``self``.
          """
          rv = object.__new__(self.__class__)
          rv.__dict__.update(self.__dict__)
          rv.refs = self.refs.copy()
          rv.loads = self.loads.copy()
          rv.stores = self.stores.copy()
          return rv

      def store(self, name: str) -> None:
          """Record an assignment to ``name`` in this scope."""
          self.stores.add(name)

          # If we have not seen the name referenced yet, we need to figure
          # out what to set it to.
          if name not in self.refs:
              # If there is a parent scope we check if the name has a
              # reference there. If it does it means we might have to alias
              # to a variable there.
              if self.parent is not None:
                  outer_ref = self.parent.find_ref(name)
                  if outer_ref is not None:
                      self._define_ref(name, load=(VAR_LOAD_ALIAS, outer_ref))
                      return

              # Otherwise we can just set it to undefined.
              self._define_ref(name, load=(VAR_LOAD_UNDEFINED, None))

      def declare_parameter(self, name: str) -> str:
          """Record ``name`` as a parameter of this scope and return its identifier."""
          self.stores.add(name)
          return self._define_ref(name, load=(VAR_LOAD_PARAMETER, None))

      def load(self, name: str) -> None:
          """Note a read of ``name``.

          If no scope in the chain binds ``name``, it gets a resolve-by-name
          load.
          """
          if self.find_ref(name) is None:
              self._define_ref(name, load=(VAR_LOAD_RESOLVE, name))

      def branch_update(self, branch_symbols: t.Sequence["Symbols"]) -> None:
          """Merge the scopes of alternative branches back into ``self``.

          ``FrameSymbolVisitor.visit_If`` calls this with copies of this scope,
          one per branch. Each branch's refs, loads and stores are merged in.
          Then every name that some branch stored, but that this scope had not
          stored before the merge, has its load rewritten:

          - it aliases the parent's identifier when the parent resolves the
            name;
          - otherwise it resolves by name.
          """
          # Names newly stored by at least one branch (computed pre-merge).
          stores: t.Set[str] = set()
          for branch in branch_symbols:
              stores.update(branch.stores)
          stores.difference_update(self.stores)

          for sym in branch_symbols:
              self.refs.update(sym.refs)
              self.loads.update(sym.loads)
              self.stores.update(sym.stores)

          # NOTE(review): presumably rewritten because the branch that stores
          # the name may not execute, so the merged load cannot rely on the
          # branch-local initial value — confirm against the code generator.
          for name in stores:
              target = self.find_ref(name)
              assert target is not None, "should not happen"

              if self.parent is not None:
                  outer_target = self.parent.find_ref(name)
                  if outer_target is not None:
                      self.loads[target] = (VAR_LOAD_ALIAS, outer_target)
                      continue
              self.loads[target] = (VAR_LOAD_RESOLVE, name)

      def dump_stores(self) -> t.Dict[str, str]:
          """Map each name stored here or in an ancestor to its identifier.

          Identifiers are resolved from ``self``, so the nearest binding wins.
          Keys are inserted scope by scope, starting with ``self``, and are
          sorted within each scope.
          """
          rv: t.Dict[str, str] = {}
          node: t.Optional[Symbols] = self

          while node is not None:
              for name in sorted(node.stores):
                  if name not in rv:
                      rv[name] = self.find_ref(name)  # type: ignore
              node = node.parent

          return rv

      def dump_param_targets(self) -> t.Set[str]:
          """Return the identifiers whose load in this scope is a parameter load.

          Only this scope's own ``loads`` are inspected.
          """
          # The previous body walked the parent chain but re-read
          # ``self.loads`` on every step, so ancestors never contributed. This
          # single pass returns the same set without the redundant iterations.
          # NOTE(review): if ancestor parameters were meant to be included,
          # read each ancestor's ``loads`` instead. That changes results, so
          # verify against callers first.
          return {
              target
              for target, (instr, _) in self.loads.items()
              if instr == VAR_LOAD_PARAMETER
          }
  class RootVisitor(NodeVisitor):
      """Entry-point visitor used by :meth:`Symbols.analyze_node`.

      It is run on the node that owns a scope and selects which of that
      node's children belong to the scope. Each selected child is handed to a
      :class:`FrameSymbolVisitor` bound to the target :class:`Symbols`. Node
      types without a ``visit_*`` method here are rejected by
      :meth:`generic_visit`.
      """

      def __init__(self, symbols: "Symbols") -> None:
          self.sym_visitor = FrameSymbolVisitor(symbols)

      def _visit_all(self, children: t.Iterable[nodes.Node]) -> None:
          # Feed each child, in order, to the frame-level visitor.
          for child in children:
              self.sym_visitor.visit(child)

      def _simple_visit(self, node: nodes.Node, **kwargs: t.Any) -> None:
          self._visit_all(node.iter_child_nodes())

      # These node types contribute all of their direct children.
      visit_Template = _simple_visit
      visit_Block = _simple_visit
      visit_Macro = _simple_visit
      visit_FilterBlock = _simple_visit
      visit_Scope = _simple_visit
      visit_If = _simple_visit
      visit_ScopedEvalContextModifier = _simple_visit

      def visit_AssignBlock(self, node: nodes.AssignBlock, **kwargs: t.Any) -> None:
          # Only the body; the assignment target is not visited here.
          self._visit_all(node.body)

      def visit_CallBlock(self, node: nodes.CallBlock, **kwargs: t.Any) -> None:
          # Every child except the ``call`` expression.
          self._visit_all(node.iter_child_nodes(exclude=("call",)))

      def visit_OverlayScope(self, node: nodes.OverlayScope, **kwargs: t.Any) -> None:
          self._visit_all(node.body)

      def visit_For(
          self, node: nodes.For, for_branch: str = "body", **kwargs: t.Any
      ) -> None:
          """Visit one section of a ``for`` loop, selected by ``for_branch``.

          - ``"body"``: the loop target (declared as parameters), then the body.
          - ``"else"``: the ``else_`` nodes only.
          - ``"test"``: the loop target (as parameters) and the optional filter
            test.

          Any other value raises ``RuntimeError``.
          """
          if for_branch == "test":
              self.sym_visitor.visit(node.target, store_as_param=True)
              if node.test is not None:
                  self.sym_visitor.visit(node.test)
              return

          if for_branch == "body":
              self.sym_visitor.visit(node.target, store_as_param=True)
              section = node.body
          elif for_branch == "else":
              section = node.else_
          else:
              raise RuntimeError("Unknown for branch")

          if section:
              self._visit_all(section)

      def visit_With(self, node: nodes.With, **kwargs: t.Any) -> None:
          # Targets first, then the body.
          self._visit_all(node.targets)
          self._visit_all(node.body)

      def generic_visit(self, node: nodes.Node, *args: t.Any, **kwargs: t.Any) -> None:
          raise NotImplementedError(f"Cannot find symbols for {type(node).__name__!r}")
  class FrameSymbolVisitor(NodeVisitor):
      """Visitor for `Frame.inspect`: records name usage into a Symbols scope.

      Loads, stores and parameter declarations of visited names are written
      to ``self.symbols``. Several node types are only partly visited here.
      Their remaining parts are handled by the matching ``RootVisitor``
      method.
      """

      def __init__(self, symbols: "Symbols") -> None:
          self.symbols = symbols

      def visit_Name(
          self, node: nodes.Name, store_as_param: bool = False, **kwargs: t.Any
      ) -> None:
          """Route a name occurrence to the matching ``Symbols`` operation.

          A truthy ``store_as_param`` forces a parameter declaration whatever
          ``node.ctx`` is. Contexts other than ``param``, ``store`` and
          ``load`` are ignored.
          """
          ctx = "param" if store_as_param else node.ctx
          if ctx == "param":
              self.symbols.declare_parameter(node.name)
          elif ctx == "store":
              self.symbols.store(node.name)
          elif ctx == "load":
              self.symbols.load(node.name)

      def visit_NSRef(self, node: nodes.NSRef, **kwargs: t.Any) -> None:
          # Only the referenced name is recorded, as a load.
          self.symbols.load(node.name)

      def visit_If(self, node: nodes.If, **kwargs: t.Any) -> None:
          """Visit the test, then each branch against its own scope copy.

          The per-branch scopes are merged back via ``Symbols.branch_update``.
          """
          self.visit(node.test, **kwargs)
          outer = self.symbols

          def analyze_branch(body: t.Iterable[nodes.Node]) -> "Symbols":
              # Temporarily swap in a private copy of the outer scope.
              self.symbols = branch_scope = outer.copy()
              for child in body:
                  self.visit(child, **kwargs)
              self.symbols = outer
              return branch_scope

          branches = [
              analyze_branch(node.body),
              analyze_branch(node.elif_),
              analyze_branch(node.else_ or ()),
          ]
          outer.branch_update(branches)

      def visit_Macro(self, node: nodes.Macro, **kwargs: t.Any) -> None:
          # The macro name is bound in this scope; its body is not visited.
          self.symbols.store(node.name)

      def visit_Import(self, node: nodes.Import, **kwargs: t.Any) -> None:
          self.generic_visit(node, **kwargs)
          self.symbols.store(node.target)

      def visit_FromImport(self, node: nodes.FromImport, **kwargs: t.Any) -> None:
          self.generic_visit(node, **kwargs)
          for entry in node.names:
              # Tuple entries bind their second element; plain entries bind as-is.
              self.symbols.store(entry[1] if isinstance(entry, tuple) else entry)

      def visit_Assign(self, node: nodes.Assign, **kwargs: t.Any) -> None:
          """Visit the assigned value before the target, matching evaluation order."""
          self.visit(node.node, **kwargs)
          self.visit(node.target, **kwargs)

      def visit_For(self, node: nodes.For, **kwargs: t.Any) -> None:
          """Visit only the iterable expression.

          The loop target and sections are handled by ``RootVisitor.visit_For``.
          """
          self.visit(node.iter, **kwargs)

      def visit_CallBlock(self, node: nodes.CallBlock, **kwargs: t.Any) -> None:
          # Only the call expression; RootVisitor.visit_CallBlock covers the rest.
          self.visit(node.call, **kwargs)

      def visit_FilterBlock(self, node: nodes.FilterBlock, **kwargs: t.Any) -> None:
          self.visit(node.filter, **kwargs)

      def visit_With(self, node: nodes.With, **kwargs: t.Any) -> None:
          # Only the values; targets and body are handled by RootVisitor.visit_With.
          for value in node.values:
              self.visit(value)

      def visit_AssignBlock(self, node: nodes.AssignBlock, **kwargs: t.Any) -> None:
          """Visit only the target; the block body is not descended into."""
          self.visit(node.target, **kwargs)

      def _skip(self, node: nodes.Node, **kwargs: t.Any) -> None:
          """Do not descend into this node.

          ``RootVisitor`` has handlers for these node types.
          """

      visit_Scope = _skip
      visit_Block = _skip
      visit_OverlayScope = _skip