# Source code for umbrella.trailing_objects

# Copyright (c) 2023 MatrixEditor
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import typing as t
import inspect

import construct as cs
import dataclasses as dc
import construct_dataclasses as csd

from umbrella.runtime import sizeof

# Variadic type parameters of ``TrailingObjects[...]``: one entry per
# trailing type declared by a subclass.
TrailingTys = t.TypeVarTuple("TrailingTys")

# Alias used in method signatures so that a bare class object (for example
# ``Bar``) can be passed as a lookup "token".
Token = t.Type

# Type variable for the value associated with a token. It is meant to cover
# construct or dataclass types, although the declared bound is only
# ``cs.Construct``.
T = t.TypeVar("T", bound=cs.Construct)

[docs] class TrailingObjects(t.Generic[*TrailingTys]): """Python implementation of TrailingObjects found in the Swift ABI. This class uses typing to inspect annotated types and create trailer information based on them. You can think of "trailing objects" when using this class, for instance: >>> @dataclass ... class Foo: ... a: int = csfield(Int32ul) ... Now, we want to have additional trailing objects based on runtime conditions at the end of each ``Foo`` object. We now create another class that stores our new "trailing" information: >>> @dataclass ... class Bar: ... x: int = csfield(Int32ul) ... By annotating the previously defined class *Foo* with this class, one or multiple *Bar* objects will be placed as trailing objects at the end of each *Foo* instance virtually. >>> @dataclass ... class Foo(TrailingObjects[Bar]): ... a: int = csfield(Int32ul) ... # in order to get the amount of trailing Bar objects, a function ... # must be defined, otherwise 1 is used ... def _num_bar(self) -> int: ... return self.a Each trailing object can be retrieved by calling ``getTrailingObject(s)``. The return value will be a tuple or single value based on the provided count. >>> foo = Foo(a=2) # retrieve Foo instance >>> foo.getTrailingObjects(Bar) (Bar(x=1), Bar(x=2)) >>> foo.getTrailingObject(Bar) Bar(x=1) The amount of trailing objects will be computed per type. Each type may have a method mapped to it. The naming convention of each method is as follows: >>> nameof(Foo) '_num_foo' >>> nameof(FooBar) '_num_foo_bar' """ # The internal representation of this structure that stores all # trailing structs # __trailing__: t.Tuple[*TrailingTys] # A list of length values that must be added to the base address of # this object # __sizes__: t.List[int] def __new__(cls, **kwargs) -> t.Self: super_new = super().__new__ obj = super_new(cls) super(TrailingObjects, obj).__init__() assert dc.is_dataclass(cls), "Only dataclasses are supported!" 
# First, setup additional attributes that are fixed: fields = list(map(lambda f: f.name, dc.fields(obj))) for name, value in kwargs.items(): if name not in fields: raise AttributeError(f"Invalid attr '{name}' - not found") setattr(obj, name, value) # In order to retrieve our trailing object types, we have to # inspect the type arguments of our class bases = list(obj.__orig_bases__) # If there are more than one base class, ensure we use the # right one base = None for base_type in bases: # The target type is generic, so we have to use the __origin__ # attribute origin = t.get_origin(base_type) if origin and origin == TrailingObjects: base = base_type break # See TrailingGenericContextObjects for more details if getattr(origin, "__generic_root__", False): base = base_type break if base is None: raise ValueError("Could not determine base class!") trailing = [] # Get all trailing objects and validate we have # 1. no duplicate types # 2. only instances of the DataclassStruct class for argument in t.get_args(base): if not isinstance(argument, csd.DataclassStruct): # 1 # BUT: we can create structs if the arguments are # dataclasses if dc.is_dataclass(argument): argument = csd.DataclassStruct(argument) else: raise TypeError( f"Only constructs are approved as parameters - got {argument}" ) if argument in trailing: # 2 raise ValueError(f"Duplicate trailing type found: {argument}") trailing.append(argument) # Create a new internal reference setattr(obj, "__trailing__", trailing) setattr(obj, "__sizes__", []) # Intended for future use setattr(obj, "__struct__", csd.DataclassStruct(obj.__class__)) setattr(obj, "__cached__", [None] * len(trailing)) return obj def _index_of(self, token: Token[T]) -> int: base_ty = token # Try to search for the given token. 
the default return # value shoule be -1 to mitigate possible exceptions for i, ty in enumerate(self.__trailing__): if ty == base_ty: return i if isinstance(ty, csd.DataclassStruct): # Case: our dataclass has been turned into a DataclassStruct # using csd.DataclassStruct. We can use the model field to # identify the used type. if ty.model == base_ty: return i return -1 def _get_address(self, until: int) -> int: if until == -1: until = len(self.__trailing__) # REVISIT: maybe hardcode the sizeof value somehow struct = self.__struct__ base_address = self._address + sizeof(struct) for i in range(until): ty_struct: csd.DataclassStruct = self.__trailing__[i] count = self._count(ty_struct) # Using Construct.sizeof(...) to calculate size length = count * sizeof(ty_struct) base_address += length return base_address def _count(self, ty_struct) -> int: # Try to resolve possible dynamic-sized elements by calling # an internal function with the following signature: # - _num_<class-name>(self, fp) -> int if not isinstance(ty_struct, csd.DataclassStruct): name = self._map_name(ty_struct.__class__.__name__) else: name = self._map_name(ty_struct.model.__name__) fn_name = f"_num_{name}" if not hasattr(self, fn_name): # We automatically assume 1 as default count count = 1 else: # The attribute may also be a fixed integer fn = getattr(self, fn_name) if isinstance(fn, int): count = fn elif isinstance(fn, t.Callable): # REVISIT: maybe use try: ... except: ... here # Check the number of arguments first args, *_ = inspect.getargs(fn.__code__) assert len(args) <= 2, f"Invalid arg-count on funciton: '{fn}'" if len(args) == 1: # only 'self' argument count = fn() elif len(args) == 2: # self + fp argument count = fn(self._fp) else: raise TypeError(f"Invalid count type - {type(fn)}") return count def _map_name(self, name: str) -> str: result = "" for ch in name: result += ch if ch.islower() else f"_{ch}" return result.removeprefix("_").lower()
[docs] def getTrailingObject(self, token: Token[T]) -> t.Optional[T]: """Retrieves a single trailing object based on the provided token. :param token: the token annotated with the trailing type :type token: Token[T] :return: the parsed trailing object :rtype: t.Optional[T] """ value = self.getTrailingObjects(token) if value is None: return None # Extract single value and discard all other values rvalue, *_ = value return rvalue
[docs] def getTrailingObjects(self, token: Token[T]) -> t.Optional[t.Tuple[T]]: """Retrieves one or multiple trailing objects. If a trailing object is optional, the count should be set to zero, which in result leads to a null value returned by this method. :param token: the token annotated with the trailing type :type token: Token[T] :raises TypeError: if the provided type is not present in the defined trailing objects :return: the parsed objects as tuple (always as tuple) :rtype: t.Optional[t.Tuple[T]] """ # We shall not support any invalid type references. DataclassStructs # will be detected automatically. index = self._index_of(token) if index < 0: raise TypeError(f"Could not find trailing object of type '{token}'") struct = self.__trailing__[index] cached = self.__cached__[index] if cached is not None: if isinstance(cached, t.Iterable): return tuple(cached) return (cached,) address = self._get_address(index) count = self._count(struct) if count == 1: value = self._fp.read_struct(struct, address) self.__cached__[index] = value return (value,) elif count == 0: # Default return value for trailing objects that are # configured to be optional return None size = sizeof(struct) values = [] for _ in range(count): # REVISIT: more read_struct calls, but we have the # struct's memoy address values.append(self._fp.read_struct(struct, address)) address += size self.__cached__[index] = values return tuple(values)
[docs] def sizeof(self, **contextkw) -> int: """Method that calculates the size of this struct""" end_address = self._get_address(-1) return end_address - self._address