# Copyright (c) 2023 MatrixEditor
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
__doc__ = """Generic (lazy) iterator classes."""
import abc
import typing as t
import construct as cs
from construct_dataclasses import DataclassStruct
from umbrella.runtime import Runtime
# All types in this file will be using this type var for an element type
E = t.TypeVar("E")


class CachingIterator(abc.ABC, t.Generic[E]):
    """Base class for all iterators that cache their parsed values.

    When iterating over elements, the following method order is applied:

    - ``__iter__`` to start the iteration
    - ``__next__`` to prepare the next element
    - ``_load`` performs the actual parsing of the next element

    Note that this iterator also supports list-like access. You can reference
    each element with its index position. The method ordering is as follows:

    - ``__getitem__`` gets called on item access
    - (*) ``__next__`` if the iterator has to parse additional values
    - (*) ``_load`` if additional elements have to be parsed

    Item access follows Python's sequence semantics: negative indices and
    open-ended or reversed slices are resolved against ``len(self)``, and a
    slice always returns a tuple. Only the elements needed to answer the
    request are parsed. If ``__len__`` raises ``IndexError`` (iterator without
    a fixed size), negative integer indices are rejected and slices whose
    result depends on the total length parse all remaining elements first.

    Examples:

    >>> iterator = ...
    >>> # Get all elements of an iterator (parsed ones)
    >>> elements = iterator.elements
    >>> # Get all elements (including elements that have to be parsed)
    >>> elements = iterator.all()
    >>> # Get an element by its index
    >>> element = iterator[4]
    >>> # Use slicing to get multiple elements
    >>> elements = iterator[2:5]
    """

    #: Additional field to mark the reset value for this iterator
    RESET = -1

    def __init__(self) -> None:
        super().__init__()
        # The current position, i.e. the index of the last parsed element.
        # Note the -1 here: nothing has been parsed yet.
        self.__pos = self.RESET
        # cached elements, in the order they were parsed
        self.__elements: t.List[E] = []

    @property
    def pos(self) -> int:
        """The current position as an integer value

        :return: the current position (may be -1)
        :rtype: int
        """
        return self.__pos

    @pos.setter
    def pos(self, value: int) -> None:
        """Sets the current position

        :param value: the new position
        :type value: int
        """
        self.__pos = value

    @property
    def elements(self) -> t.List[E]:
        """Returns all cached elements of this iterator.

        Note that this property will return all elements that have been
        parsed so far. Use ``all()`` for a list of all elements including
        the ones to be parsed.

        :return: all stored elements (the internal list, not a copy)
        :rtype: t.List[E]
        """
        return self.__elements

    def all(self) -> t.List[E]:
        """Returns a list of all elements including the ones to be parsed.

        All remaining elements are parsed first, so the result always starts
        at index 0 no matter how far this iterator has been consumed already
        (``list(iterator)`` would only yield the not yet parsed elements).

        :return: a new list of all elements.
        :rtype: t.List[E]
        """
        self._parse_remaining()
        return list(self.__elements)

    @abc.abstractmethod
    def _load(self, pos: int) -> E:
        """Parses an element at the current index position.

        Iterators without a fixed size should raise ``StopIteration`` here
        once there are no more elements.

        :param pos: the index position
        :type pos: int
        :return: the parsed element
        :rtype: E
        """
        pass

    def __len__(self) -> int:
        """Returns the length of this iterator (if present)

        .. hint::
            You can raise an ``IndexError`` if your iterator does not have a
            fixed size. ``_load`` then has to raise ``StopIteration`` once
            all elements have been parsed.

        :raises NotImplementedError: by default, raises an error
        :return: the size of this iterator
        :rtype: int
        """
        raise NotImplementedError(f"len() not applicable for {type(self)}")

    def _length_or_none(self) -> t.Optional[int]:
        """Returns ``len(self)``, or ``None`` if there is no fixed size.

        Iterators without a fixed size signal this by raising ``IndexError``
        from ``__len__``.
        """
        try:
            return len(self)
        except IndexError:
            return None

    def _parse_up_to(self, index: int) -> None:
        """Parses elements until ``index`` is cached or the iterator ends."""
        try:
            while self.pos < index:
                next(self)
        except StopIteration:
            pass

    def _parse_remaining(self) -> None:
        """Parses all elements that have not been parsed yet."""
        for _ in self:
            pass

    def __iter__(self) -> t.Generator[E, t.Any, None]:
        while True:
            try:
                element = next(self)
            except StopIteration:
                return
            # The returned element won't be null
            yield element

    def __next__(self) -> E:
        size = self._length_or_none()
        # Stops this iterator if it is positioned at the end. Iterators
        # without a fixed size are expected to raise StopIteration from
        # ``_load`` instead.
        if size is not None and self.pos >= size - 1:
            raise StopIteration
        element = self._load(self.pos + 1)  # pos starts at -1
        self.elements.append(element)
        self.pos += 1
        return element

    def __getitem__(
        self, key: t.Union[int, slice]
    ) -> t.Union[t.Tuple[E, ...], E]:
        size = self._length_or_none()
        if isinstance(key, slice):
            return self._get_slice(key, size)
        return self._get_item(key, size)

    def _get_item(self, index: int, size: t.Optional[int]) -> E:
        """Returns one element; negative indices count from the end."""
        position = index + size if (size is not None and index < 0) else index
        if position < 0 or (size is not None and position >= size):
            raise IndexError(f"index {index} out of bounds!")
        self._parse_up_to(position)
        if position >= len(self.__elements):
            # The iterator stopped before the requested element was parsed
            raise IndexError(f"index {index} out of bounds!")
        return self.__elements[position]

    def _get_slice(
        self, key: slice, size: t.Optional[int]
    ) -> t.Tuple[E, ...]:
        """Returns the elements selected by ``key`` as a tuple."""
        if size is not None:
            # slice.indices() resolves None, negative and out-of-range values
            # (and rejects a zero step, just like list slicing does)
            indices = range(*key.indices(size))
            if indices:
                # Covers negative steps too: parse up to the largest index
                self._parse_up_to(max(indices[0], indices[-1]))
            cached = self.__elements
            count = len(cached)
            # Indices past ``count`` only remain if parsing stopped early
            return tuple(cached[i] for i in indices if i < count)

        start, stop, step = key.start, key.stop, key.step
        if (
            (step is None or step > 0)
            and stop is not None
            and stop >= 0
            and (start is None or start >= 0)
        ):
            # Forward slice with absolute bounds: parsing up to stop - 1
            # is enough to slice the cached elements correctly
            self._parse_up_to(stop - 1)
        else:
            # The result depends on the (unknown) total length
            self._parse_remaining()
        return tuple(self.__elements[key])
class LazyIterator(CachingIterator[E]):
    """Partial implementation of a :class:`CachingIterator` that integrates a
    runtime object.

    Any additional variables needed by the iterator are kept in an internal
    context container (see :class:`ReflectionSectionIterator` for an example
    of how this can be used).

    The size of this iterator is derived from a *length* field: sub-classes
    must assign to ``length`` the name of a context variable. That variable
    may either hold an integer or a sized collection (list, dict, ...).

    >>> class Foo(LazyIterator):
    ...     length = "foo_length"

    Here, the referenced length variable ``"foo_length"`` must be set in
    the internal context within the ``_preload_context`` method.
    """

    # A field used by sub-classes to reference a context variable
    length: str

    def __init__(self, runtime: Runtime, **kwds) -> None:
        super().__init__()
        self._context = cs.Container()
        self._runtime = runtime
        # The context has to be filled before the iterator can operate, as
        # __len__ reads the ``length`` variable from it.
        self._preload_context(**kwds)

    def _preload_context(self, **kwds) -> None:
        """Prepares the internal context.

        Called once from ``__init__`` with its keyword arguments. This
        default implementation does nothing.

        :param kwds: the keyword arguments passed to ``__init__``
        """

    @property
    def runtime(self) -> Runtime:
        """The runtime object associated with this iterator.

        :return: the runtime object
        :rtype: Runtime
        """
        return self._runtime

    @property
    def context(self) -> cs.Container:
        """The internal context storing all iterator-related values.

        :return: the internal context with all relevant values
        :rtype: Container
        """
        return self._context

    def __len__(self) -> int:
        # Sub-classes have to name the context variable storing the length
        assert hasattr(self, "length"), "Missing 'length' field!"
        # The lookup fails if ``length`` references an unknown variable
        size = self.context[self.length]
        # Either a plain integer or something that supports len()
        return size if isinstance(size, int) else len(size)
class ReflectionSectionIterator(LazyIterator[E]):
    """Base iterator class used to iterate over structs in a reflection section.

    Using Python's type inspection, only the section's name has to be
    provided. Everything else is configured automatically. For example,

    >>> class Foo: ... # assume this class is a dataclass
    >>> class FooIterator(ReflectionSectionIterator[Foo]):
    ...     kind = "__foo_section"
    ...
    >>> it = FooIterator(runtime)
    >>> foo = it[0] # get next element

    The method ordering is a little bit different compared to the initial
    iterator class. When iterating, the following methods will be executed:

    1. ``_load`` to parse the current element
    2. ``_address_of`` retrieves the current load address
    3. ``load_at`` looks for cached elements first before parsing
    4. (*) ``_load_at`` parses the next element if not already cached

    Additional keyword arguments given to the constructor are forwarded to
    ``_preload_context``.
    """

    # Used to determine the annotated generic class
    __root__ = "ReflectionSectionIterator"

    length = "addresses"  # length field
    kind: str  # section name
    # The struct's type. If a sub-class does not set it, ``__init__`` stores
    # a DataclassStruct built from the generic type argument instead.
    struct: t.Type[E]

    def __init__(
        self,
        runtime: Runtime,
        pointer_ty: t.Optional[cs.Construct] = None,
        **kwds,
    ) -> None:
        # Must be set before the parent constructor runs, because that one
        # calls _preload_context(), which reads the pointer type.
        if pointer_ty is None:
            pointer_ty = cs.Int64ul
        self._pointer_type = pointer_ty
        # Forward extra keywords so that _preload_context() receives them
        super().__init__(runtime, **kwds)
        if getattr(self, "struct", None):
            # Skip if struct has already been set
            return

        # The struct's type can be retrieved by inspecting the type arguments
        # of the parameterized base class (e.g. ReflectionSectionIterator[Foo]).
        # If there is more than one base class, make sure to use the right one.
        base = None
        for base_type in self.__orig_bases__:
            origin = t.get_origin(base_type)
            if origin and origin.__name__ == self.__root__:
                base = base_type
                break

        assert (
            base is not None
        ), "Could not locate base class of 'ReflectionSectionIterator'"
        # Create the struct instance directly
        (struct,) = t.get_args(base)
        self.struct = DataclassStruct(struct)

    def _preload_context(self, **kwds) -> None:
        # Loads a pointer section named by 'kind'. As 'kind' is only
        # annotated on this class, getattr() is required to detect
        # sub-classes that did not define it (plain attribute access would
        # raise AttributeError before the assertion could report it).
        assert (
            getattr(self, "kind", None) is not None
        ), "Invalid kind for a reflection section"
        ptrs = self.runtime.read_ptr_section(
            self.kind, struct=self._pointer_type
        )
        # NOTE(review): keys appear to be the pointer locations and values
        # the referenced struct addresses (see _address_of) - confirm.
        self.context.addresses = list(ptrs.keys())
        self.context.ptrs = list(ptrs.values())
        # Additional attribute to map addresses to parsed elements (load_at)
        self.context.ptr2type = {}

    def _load_at(self, address: int, parent_address=0) -> E:
        # Just uses the runtime to parse the struct. ``parent_address`` is
        # forwarded by load_at() but unused by this default implementation.
        return self.runtime.read_struct(self.struct, address, fix=True)

    def _address_of(self, pos: int) -> int:
        # NOTE: this function assumes absolute pointers by default
        return self.context.ptrs[pos]

    def _load(self, pos: int) -> E:
        # Resolve the element's address and parse it through the caching
        # loader, so that elements parsed while iterating and those requested
        # via load_at() share the same ptr2type cache.
        address = self._address_of(pos)
        return self.load_at(address)

    def load_at(self, vaddress: int, parent_address=0) -> t.Optional[E]:
        """Loads a new struct at the given virtual address.

        Parsed structs are cached per address, so repeated calls (and the
        iteration itself, which goes through this method) return the same
        object without parsing it again.

        :param vaddress: the virtual memory address
        :type vaddress: int
        :param parent_address: forwarded to ``_load_at``; note that the
            cache is keyed by ``vaddress`` only
        :type parent_address: int
        :return: the parsed struct or nothing on an invalid address
        :rtype: t.Optional[E]
        """
        ptr2type = self.context.ptr2type
        # Membership test instead of a None check, so that a None result
        # is cached as well and not parsed again on every call
        if vaddress in ptr2type:
            return ptr2type[vaddress]
        obj = self._load_at(vaddress, parent_address=parent_address)
        ptr2type[vaddress] = obj
        return obj