# Copyright (C) MatrixEditor 2023-2026
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# pyright: reportPrivateUsage=false
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from typing_extensions import override
from caterpillar.exception import UnsupportedOperation
from caterpillar.exception import InvalidValueError
from caterpillar.context import CTX_STREAM
from caterpillar.abc import (
_ContextLambda,
_ContextLike,
_GreedyType,
_StructLike,
_IT,
_ArgType,
)
from .common import Memory, Bytes
from ._mixin import get_args, get_kwargs
if TYPE_CHECKING:
from cryptography.hazmat.primitives.ciphers import (
modes,
CipherAlgorithm,
CipherContext,
)
from cryptography.hazmat.primitives.padding import PaddingContext
@runtime_checkable
class Padding(Protocol):
    """Structural interface of a padding scheme.

    Any object that provides both a ``padder()`` and an ``unpadder()``
    method satisfies this protocol.
    """

    def padder(self) -> "PaddingContext":
        """Abstract method to get a padder for padding."""

    def unpadder(self) -> "PaddingContext":
        """Abstract method to get an unpadder for padding."""
class Encrypted(Memory):
    """Struct that is able to encrypt/decrypt blocks of memory.

    The algorithm, mode and padding can each be supplied as a ready-to-use
    instance or as a class/callable.  In the latter case a new object is
    created on every pack/unpack call from the matching ``*_args``: a
    ``dict`` is resolved with ``get_kwargs`` and forwarded as keyword
    arguments, anything else is resolved with ``get_args`` and forwarded as
    positional arguments (see :meth:`get_instance`).

    :param length: Length of the encrypted data.
    :type length: Union[int, _GreedyType, _ContextLambda]
    :param algorithm: Encryption algorithm.
    :type algorithm: Type[algorithms.CipherAlgorithm]
    :param mode: Encryption mode.
    :type mode: Union[Type[modes.Mode], modes.Mode]
    :param padding: Padding scheme for encryption.
    :type padding: Union[Padding, Type[Padding]], optional
    :param algo_args: Additional arguments for the encryption algorithm.
    :type algo_args: Optional[Union[Dict[str, _ArgType], Iterable[_ArgType]]], optional
    :param mode_args: Additional arguments for the encryption mode.
    :type mode_args: Optional[Union[Dict[str, _ArgType], Iterable[_ArgType]]], optional
    :param padding_args: Additional arguments for the padding scheme.
    :type padding_args: Optional[Union[Dict[str, _ArgType], Iterable[_ArgType]]], optional
    :param post: Post-processing structure (this class only stores it).
    :type post: Optional[_StructLike], optional
    :raises UnsupportedOperation: if the ``cryptography`` package is missing.
    """

    # REVISIT: this constructor looks ugly
    def __init__(
        self,
        length: int | _GreedyType | _ContextLambda[int],
        algorithm: type["CipherAlgorithm"],
        mode: "type[modes.Mode] | modes.Mode",
        padding: Padding | type[Padding] | None = None,
        algo_args: dict[str, _ArgType] | Iterable[_ArgType] | None = None,
        mode_args: dict[str, _ArgType] | Iterable[_ArgType] | None = None,
        padding_args: dict[str, _ArgType] | Iterable[_ArgType] | None = None,
        post: _StructLike | None = None,
    ) -> None:
        try:
            # Imported only to verify that the optional dependency is
            # installed; the name itself is not used in the constructor.
            from cryptography.hazmat.primitives.ciphers import Cipher  # noqa: F401
        except ImportError as exc:
            # Chain the original error so the failing import stays visible.
            raise UnsupportedOperation(
                (
                    "To use encryption with this framework, the module 'cryptography' "
                    "is required! You can install it via pip or use the packaging "
                    "extra 'crypto' that is available with this library."
                )
            ) from exc

        super().__init__(length)
        self._algo: "type[CipherAlgorithm]" = algorithm
        self._algo_args: dict[str, _ArgType] | Iterable[_ArgType] | None = algo_args
        self._mode: "type[modes.Mode] | modes.Mode" = mode
        self._mode_args: dict[str, _ArgType] | Iterable[_ArgType] | None = mode_args
        self._padding: Padding | type[Padding] | None = padding
        self._padding_args: dict[str, _ArgType] | Iterable[_ArgType] | None = padding_args
        self.post: _StructLike | None = post

    def algorithm(self, context: _ContextLike) -> "CipherAlgorithm":
        """
        Get the encryption algorithm instance.

        :param context: The current operation context.
        :type context: _ContextLike
        :return: An instance of the encryption algorithm.
        :rtype: algorithms.CipherAlgorithm
        """
        # fmt: off
        from cryptography.hazmat.primitives.ciphers import CipherAlgorithm
        return self.get_instance(CipherAlgorithm, self._algo, self._algo_args, context)  # pyright: ignore[reportReturnType]

    def mode(self, context: _ContextLike) -> "modes.Mode | None":
        """
        Get the encryption mode instance.

        :param context: The current operation context.
        :type context: _ContextLike
        :return: An instance of the encryption mode, or the configured value
            unchanged if it is falsy (e.g. ``None``).
        :rtype: modes.Mode
        """
        from cryptography.hazmat.primitives.ciphers import modes

        return self.get_instance(modes.Mode, self._mode, self._mode_args, context)

    def padding(self, context: _ContextLike) -> Padding | None:
        """
        Get the padding scheme instance.

        :param context: The current operation context.
        :type context: _ContextLike
        :return: An instance of the padding scheme, or ``None`` if no padding
            was configured.
        :rtype: Padding
        """
        return self.get_instance(Padding, self._padding, self._padding_args, context)

    def get_instance(
        self,
        type_: type[_IT],
        field: _IT | Any | None,
        args: dict[str, _ArgType] | Iterable[_ArgType] | None,
        context: _ContextLike,
    ) -> _IT | None:
        """
        Get an instance of a specified type.

        ``field`` is returned unchanged when it is falsy or already an
        instance of ``type_``.  Otherwise it is called to create the
        instance: with keyword arguments if ``args`` is a ``dict``, with
        positional arguments in every other case.

        :param type_: The desired type of the instance.
        :type type_: type
        :param field: The field or instance.
        :type field: Any
        :param args: Additional arguments for the instance.
        :type args: Any
        :param context: The current operation context.
        :type context: _ContextLambda
        :return: An instance of the specified type.
        :rtype: Any
        """
        # A class object is never the wanted instance.  The explicit guard is
        # required for runtime-checkable protocols such as Padding: their
        # isinstance() check only tests that the member names exist, which is
        # also true for a padding *class*, so the class would otherwise be
        # returned uninstantiated and its args would be ignored.
        if not isinstance(field, type) and (isinstance(field, type_) or not field):
            return field

        if isinstance(args, dict):
            args, kwargs = (), get_kwargs(args, context)
        else:
            args, kwargs = get_args(args, context), {}
        return field(*args, **kwargs)  # pyright: ignore[reportCallIssue]

    @override
    def pack_single(
        self, obj: bytes | memoryview | bytearray, context: _ContextLike
    ) -> None:
        """
        Pack a single element.

        The data is padded first (if a padding scheme is configured), then
        encrypted, and the ciphertext is handed to the parent implementation.

        :param obj: The element to pack.
        :type obj: Any
        :param context: The current operation context.
        :type context: _ContextLike
        """
        # fmt: off
        from cryptography.hazmat.primitives.ciphers import Cipher

        cipher = Cipher(self.algorithm(context), self.mode(context))
        padding = self.padding(context)

        # the padder/encryptor are fed bytes objects only
        data = obj if isinstance(obj, bytes) else bytes(obj)
        if padding:
            padder = padding.padder()
            data = padder.update(data) + padder.finalize()

        encryptor: "CipherContext" = cipher.encryptor()  # pyright: ignore[reportAttributeAccessIssue]
        super().pack_single(encryptor.update(data) + encryptor.finalize(), context)

    @override
    def unpack_single(self, context: _ContextLike) -> memoryview:
        """
        Unpack a single element.

        The raw value is read by the parent implementation, decrypted and
        finally unpadded (if a padding scheme is configured).

        :param context: The current operation context.
        :type context: _ContextLike
        :return: The unpacked element as a memoryview.
        :rtype: memoryview
        """
        # fmt: off
        from cryptography.hazmat.primitives.ciphers import Cipher

        value = super().unpack_single(context)
        cipher = Cipher(self.algorithm(context), self.mode(context))
        decryptor: "CipherContext" = cipher.decryptor()  # pyright: ignore[reportAttributeAccessIssue]
        data: bytes = decryptor.update(bytes(value)) + decryptor.finalize()

        padding = self.padding(context)
        if padding:
            unpadder = padding.unpadder()
            data = unpadder.update(data) + unpadder.finalize()
        return memoryview(data)
# Static key values accepted by KeyCipher; KeyCipher.set_key converts each of
# them to bytes (str via str.encode(), int via bytes([value])).
_KeyType = int | str | bytes
class KeyCipher(Bytes):
# key: bytes
# """The key that should be applied.
# It will be converted automatically to bytes if not given.
# """
# key_length: int
# """Internal attribute to keep track of the key's length"""
__slots__: tuple[str, ...] = "key", "key_length", "key_fn"
def __init__(
self,
key: _KeyType | _ContextLambda[_KeyType],
length: _GreedyType | _ContextLambda[int] | int | None = None,
) -> None:
super().__init__(length or ...)
self.key: bytes = b""
self.key_fn: _ContextLambda[_KeyType] | None = None
self.key_length: int = -1
self.set_key(key)
def is_lazy(self) -> bool:
return self.key_fn is not None
def set_key(
self,
key: _KeyType | _ContextLambda[_KeyType],
context: _ContextLike | None = None,
) -> None:
if callable(key) and context is None:
# context lambda indicates the key will be computed at runtime
self.key_fn = key
self.key_length = -1
return
self.key_fn = None
match key:
case str():
self.key = key.encode()
case int():
self.key = bytes([key])
case bytes():
self.key = key
case _:
raise InvalidValueError(
f"Expected a valid key type, got {key!r}", context
)
self.key_length = len(self.key)
def process(self, obj: bytes, context: _ContextLike) -> bytes:
length = len(obj)
data = bytearray(length)
if self.key_fn:
self.set_key(self.key_fn(context), context)
self._do_process(obj, data)
return bytes(data)
def _do_process(self, src: bytes, dest: bytearray) -> None:
raise NotImplementedError
@override
def pack_single(self, obj: bytes, context: _ContextLike) -> None:
context[CTX_STREAM].write(self.process(obj, context))
@override
def unpack_single(self, context: _ContextLike) -> bytes:
obj: bytes = super().unpack_single(context)
return self.process(obj, context)
class Xor(KeyCipher):
    """Key transform that XORs every input byte with the repeating key."""

    __slots__: tuple[()] = ()

    @override
    def _do_process(self, src: bytes, dest: bytearray) -> None:
        """Store ``src[i] ^ key[i % key_length]`` in ``dest[i]`` for each byte."""
        key, period = self.key, self.key_length
        for offset, value in enumerate(src):
            dest[offset] = value ^ key[offset % period]
class Or(KeyCipher):
    """Key transform that ORs every input byte with the repeating key."""

    __slots__: tuple[()] = ()

    @override
    def _do_process(self, src: bytes, dest: bytearray) -> None:
        """Store ``src[i] | key[i % key_length]`` in ``dest[i]`` for each byte."""
        key, period = self.key, self.key_length
        for offset, value in enumerate(src):
            dest[offset] = value | key[offset % period]
class And(KeyCipher):
    """Key transform that ANDs every input byte with the repeating key."""

    __slots__: tuple[()] = ()

    @override
    def _do_process(self, src: bytes, dest: bytearray) -> None:
        """Store ``src[i] & key[i % key_length]`` in ``dest[i]`` for each byte."""
        key, period = self.key, self.key_length
        for offset, value in enumerate(src):
            dest[offset] = value & key[offset % period]