Source code for caterpillar.context

# Copyright (C) MatrixEditor 2023-2024
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

import operator
import sys

from typing import Callable, Any, Union, Self
from types import FrameType
from dataclasses import dataclass


from caterpillar.abc import _ContextLambda, _ContextLike
from caterpillar.exception import StructException
from caterpillar.registry import to_struct

CTX_PARENT = "_parent"
CTX_OBJECT = "_obj"
CTX_OFFSETS = "_offsets"
CTX_STREAM = "_io"
CTX_FIELD = "_field"
CTX_VALUE = "_value"
CTX_POS = "_pos"
CTX_INDEX = "_index"
CTX_PATH = "_path"
CTX_SEQ = "_is_seq"
CTX_ARCH = "_arch"


[docs] class Context(dict): """Represents a context object with attribute-style access.""" def __setattr__(self, key: str, value) -> None: """ Sets an attribute in the context. :param key: The attribute key. :param value: The value to be set. """ self[key] = value def __getattribute__(self, key: str): """ Retrieves an attribute from the context. :param key: The attribute key. :return: The value associated with the key. """ try: return object.__getattribute__(self, key) except AttributeError: return self.__context_getattr__(key) def __context_getattr__(self, path: str): """ Retrieves an attribute from the context. :param key: The attribute key. :return: The value associated with the key. """ nodes = path.split(".") obj = ( self[nodes[0]] if nodes[0] in self else object.__getattribute__(self, nodes[0]) ) for i in range(1, len(nodes)): obj = getattr(obj, nodes[i]) return obj def __context_setattr__(self, path: str, value: Any) -> None: nodes = path.rsplit(".", 1) if len(nodes) == 1: self[path] = value else: obj = self.__context_getattr__(nodes[0]) setattr(obj, nodes[1], value) @property def _root(self) -> _ContextLike: current = self while CTX_PARENT in current: # dict-like access is much faster parent = current[CTX_PARENT] if parent is None: break current = parent return current
[docs] class ExprMixin: """ A mixin class providing methods for creating binary and unary expressions. """ def __add__(self, other) -> ExprMixin: return BinaryExpression(operator.add, self, other) def __sub__(self, other) -> ExprMixin: return BinaryExpression(operator.sub, self, other) def __mul__(self, other) -> ExprMixin: return BinaryExpression(operator.mul, self, other) def __floordiv__(self, other) -> ExprMixin: return BinaryExpression(operator.floordiv, self, other) def __truediv__(self, other) -> ExprMixin: return BinaryExpression(operator.truediv, self, other) def __mod__(self, other) -> ExprMixin: return BinaryExpression(operator.mod, self, other) def __pow__(self, other) -> ExprMixin: return BinaryExpression(operator.pow, self, other) def __xor__(self, other) -> ExprMixin: return BinaryExpression(operator.xor, self, other) def __and__(self, other) -> ExprMixin: return BinaryExpression(operator.and_, self, other) def __or__(self, other) -> ExprMixin: return BinaryExpression(operator.or_, self, other) def __rshift__(self, other) -> ExprMixin: return BinaryExpression(operator.rshift, self, other) def __lshift__(self, other) -> ExprMixin: return BinaryExpression(operator.lshift, self, other) __div__ = __truediv__ def __radd__(self, other) -> ExprMixin: return BinaryExpression(operator.add, other, self) def __rsub__(self, other) -> ExprMixin: return BinaryExpression(operator.sub, other, self) def __rmul__(self, other) -> ExprMixin: return BinaryExpression(operator.mul, other, self) def __rfloordiv__(self, other) -> ExprMixin: return BinaryExpression(operator.floordiv, other, self) def __rtruediv__(self, other) -> ExprMixin: return BinaryExpression(operator.truediv, other, self) def __rmod__(self, other) -> ExprMixin: return BinaryExpression(operator.mod, other, self) def __rpow__(self, other) -> ExprMixin: return BinaryExpression(operator.pow, other, self) def __rxor__(self, other) -> ExprMixin: return BinaryExpression(operator.xor, other, self) def __rand__(self, other) -> ExprMixin: return BinaryExpression(operator.and_, other, self) def __ror__(self, other) -> ExprMixin: return BinaryExpression(operator.or_, other, self) def __rrshift__(self, other) -> ExprMixin: return BinaryExpression(operator.rshift, other, self) def __rlshift__(self, other) -> ExprMixin: return BinaryExpression(operator.lshift, other, self) def __neg__(self) -> ExprMixin: return UnaryExpression("neg", operator.neg, self) def __pos__(self) -> ExprMixin: return UnaryExpression("pos", operator.pos, self) def __invert__(self) -> ExprMixin: return UnaryExpression("invert", operator.not_, self) def __contains__(self, other) -> ExprMixin: return BinaryExpression(operator.contains, self, other) def __gt__(self, other) -> ExprMixin: return BinaryExpression(operator.gt, self, other) def __ge__(self, other) -> ExprMixin: return BinaryExpression(operator.ge, self, other) def __lt__(self, other) -> ExprMixin: return BinaryExpression(operator.lt, self, other) def __le__(self, other) -> ExprMixin: return BinaryExpression(operator.le, self, other) def __eq__(self, other) -> ExprMixin: return BinaryExpression(operator.eq, self, other) def __ne__(self, other) -> ExprMixin: return BinaryExpression(operator.ne, self, other)
[docs] class ConditionContext: """Class implementation of an inline condition. Use this class to automatically apply a condition to multiple field definitions. Note that this class will only work if it has access to the parent stack frame. .. code-block:: python @struct class Format: magic: b"MGK" length: uint32 with this.length > 32: # other field definitions here foo: uint8 This class will **replace** any existing fields! :param condition: a context lambda or constant boolean value :type condition: Union[_ContextLambda, bool] """ __slots__ = "func", "annotations", "namelist", "depth" def __init__(self, condition: Union[_ContextLambda, bool], depth=2): self.func = condition self.annotations = None self.namelist = None self.depth = depth def getframe(self, num: int, msg=None) -> FrameType: try: return sys._getframe(num) except AttributeError as exc: raise StructException(msg) from exc def __enter__(self) -> Self: frame = self.getframe(self.depth, "Could not enter condition context!") # keep track of all annotations try: self.annotations = frame.f_locals["__annotations__"] except AttributeError as exc: module = frame.f_locals.get("__module__") qualname = frame.f_locals.get("__qualname__") msg = f"Could not get annotations in {module} (context={qualname!r})" raise StructException(msg) from exc # store names before new fields are added self.namelist = list(self.annotations) return self def __exit__(self, *_) -> None: # pylint: disable-next=import-outside-toplevel from caterpillar.fields import Field new_names = set(self.annotations) - set(self.namelist) for name in new_names: # modify newly created fields field = self.annotations[name] if isinstance(field, Field): # field already defined/created -> check for condition if field.has_condition(): # the field's condition AND this one must be true field.condition = BinaryExpression( operator.and_, field.condition, self.func ) else: field //= self.func else: # create a field (other attributes will be modified later) # ISSUE #15: The annotation must be converted to a _StructLike # object. In case we have struct classes, the special __struct__ # attribute must be used. struct_obj = to_struct(field) if not isinstance(struct_obj, Field): struct_obj = Field(struct_obj) struct_obj.condition = self.func self.annotations[name] = struct_obj self.annotations = None self.namelist = None
[docs] @dataclass(repr=False) class BinaryExpression(ExprMixin): """ Represents a binary expression. :param operand: The binary operator function. :param left: The left operand. :param right: The right operand. """ operand: Callable[[Any, Any], Any] left: Union[Any, _ContextLambda] right: Union[Any, _ContextLambda] def __call__(self, context: Context, **kwds): lhs = self.left(context, **kwds) if callable(self.left) else self.left rhs = self.right(context, **kwds) if callable(self.right) else self.right return self.operand(lhs, rhs) def __repr__(self) -> str: return f"{self.operand.__name__}{{{self.left!r}, {self.right!r}}}" def __enter__(self): # pylint: disable-next=attribute-defined-outside-init self._cond = ConditionContext(self, depth=3) self._cond.__enter__() return self def __exit__(self, *_): self._cond.__exit__(*_)
[docs] @dataclass class UnaryExpression: """ Represents a unary expression. :param name: The name of the unary operator. :param operand: The unary operator function. :param value: The operand. """ name: str operand: Callable[[Any], Any] value: Union[Any, _ContextLambda] def __call__(self, context: Context, **kwds): value = self.value(context, **kwds) if callable(self.value) else self.value return self.operand(value) def __repr__(self) -> str: return f"{self.operand.__name__}{{{self.value!r}}}" def __enter__(self): # pylint: disable-next=attribute-defined-outside-init self._cond = ConditionContext(self, depth=3) self._cond.__enter__() return self def __exit__(self, *_): self._cond.__exit__(*_)
[docs] class ContextPath(ExprMixin): """ Represents a lambda function for retrieving a value from a Context based on a specified path. """ def __init__(self, path: str = None) -> None: """ Initializes a ContextPath instance with an optional path. :param path: The path to use when retrieving a value from a Context. """ self.path = path self._ops_ = [] self.call_kwargs = None self.getitem_args = None def __call__(self, context: _ContextLike = None, **kwds): """ Calls the lambda function to retrieve a value from a Context. :param context: The Context from which to retrieve the value. :param kwds: Additional keyword arguments. :return: The value retrieved from the Context based on the path. """ if context is None: self._ops_.append((operator.call, (), kwds)) return self value = context.__context_getattr__(self.path) for operation, args, kwargs in self._ops_: value = operation(value, *args, **kwargs) return value def __getitem__(self, key) -> Self: self._ops_.append((operator.getitem, (key,), {})) return self def __type__(self) -> type: return Any def __getattribute__(self, key: str) -> ContextPath: """ Gets an attribute from the ContextPath, creating a new instance if needed. :param key: The attribute key. :return: A new ContextPath instance with an updated path. """ try: return super().__getattribute__(key) except AttributeError: if not self.path: return ContextPath(key) return ContextPath(".".join([self.path, key])) def __repr__(self) -> str: """ Returns a string representation of the ContextPath. :return: A string representation. """ extra = [] for operation, args, kwargs in self._ops_: data = [] if len(args) > 0: data.append(*map(repr, args)) if len(kwargs) > 0: data.append(*[f"{x}={y!r}" for x, y in kwargs.items()]) extra.append(f"{operation.__name__}({', '.join(data)})") if len(extra) == 0: return f"Path({self.path!r})" return f"Path({self.path!r}, {', '.join(extra)})" def __str__(self) -> str: """ Returns a string representation of the path. :return: A string representation of the path. """ return self.path @property def parent(self) -> ContextPath: path = f"{CTX_PARENT}.{CTX_OBJECT}" if not self.path: return ContextPath(path) return ContextPath(".".join([self.path, path]))
[docs] class ContextLength(ExprMixin): def __init__(self, path: ContextPath) -> None: self.path = path def __call__(self, context: Context = None, **kwds): """ Calls the lambda function to retrieve a value from a Context. :param context: The Context from which to retrieve the value. :param kwds: Additional keyword arguments (ignored in this implementation). :return: The value retrieved from the Context based on the path. """ return len(self.path(context)) def __repr__(self) -> str: return f"len({self.path!r})"
this = ContextPath(CTX_OBJECT) ctx = ContextPath() parent = ContextPath(".".join([CTX_PARENT, CTX_OBJECT]))