Source code for icspacket.proto.tpkt

# This file is part of icspacket.
# Copyright (C) 2025-present  MatrixEditor @ github
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
# pyright: reportInvalidTypeForm=false
import socket
import logging
import queue

from typing_extensions import override

from caterpillar.shortcuts import pack, struct, BigEndian, this, unpack
from caterpillar.fields import Bytes, uint8, uint16

from icspacket.core.logger import TRACE


logger = logging.getLogger(__name__)


# [RFC 1006] - ISO Transport Service on top of TCP
@struct(order=BigEndian)
class TPKT:
    """Four-octet TPKT framing header from [RFC 1006], section 6.

    RFC 1006 transports ISO TPDUs over TCP by prefixing every TPDU with this
    small header. Its ``length`` field allows a receiver to recover packet
    boundaries from the continuous TCP byte stream.
    """

    # fmt: off
    vrsn     : uint8 = 3
    """
    TPKT protocol version. Always ``3``; a packet carrying any other
    value should be treated as invalid.
    """

    reserved : uint8 = 0
    """Reserved octet (defaults to ``0``)."""

    length   : uint16 = 0
    """Size of the whole TPKT in octets, counting the 4-octet header."""

    tpdu     : Bytes(this.length - 4) = b""
    """
    Encapsulated TPDU payload; its size is ``length - 4`` octets.
    """
    # fmt: on

    @staticmethod
    def from_octets(octets: bytes) -> "TPKT":
        """Decode a single TPKT from a raw byte buffer.

        Bytes past the first packet's ``length`` are presumably left
        unconsumed by the parser (callers pass buffers that may hold more
        than one packet) -- TODO confirm against caterpillar's ``unpack``.

        :param octets: Buffer beginning with one complete TPKT.
        :type octets: bytes
        :raises ValueError: If the decoded payload plus header does not add
            up to the ``length`` field, e.g. because the buffer was truncated.
        :return: The decoded packet.
        :rtype: TPKT
        """
        packet = unpack(TPKT, octets)
        decoded_size = len(packet.tpdu) + 4
        if decoded_size != packet.length:
            raise ValueError(
                f"Invalid length: expected {packet.length}, got {decoded_size}. "
                "This error could indicate a buffer size being too small."
            )
        return packet

    def build(self) -> bytes:
        """Encode this packet into octets ready to be written to TCP.

        ``length`` is recomputed from the current TPDU first, so callers
        never have to maintain it by hand.

        :return: Header followed by the TPDU.
        :rtype: bytes
        """
        header_size = 4  # vrsn + reserved + 16-bit length
        self.length = header_size + len(self.tpdu)
        return pack(self, TPKT)
#: Convenience constant for decoding 16-bit unsigned integers in #: big-endian byte order. Used internally for parsing TPKT headers. #: #: .. versionadded:: 0.2.4 _U16_BE = BigEndian + uint16
class tpktsock(socket.socket):
    """Socket wrapper that transparently applies TPKT encapsulation.

    This class extends :class:`socket.socket` to provide automatic encoding
    and decoding of **RFC 1006 (TPKT)** headers for connection-oriented
    transport protocols. Applications can use :class:`tpktsock` as a drop-in
    replacement for raw sockets when working with TPKT-based communication.

    **Enhancements since 0.2.4:**

    - An internal :class:`queue.Queue` (``in_queue``) is now used to buffer
      multiple consecutive TPKT PDUs.
    - Any number of PDUs arriving in a single TCP segment is handled; all
      PDUs after the first are queued for later retrieval.
    - A trailing partial header or partial packet is completed by reading
      the missing bytes from the socket, so stream framing is preserved.

    .. versionchanged:: 0.2.4
        Added internal buffering and support for multiple PDUs per TCP segment.
    """

    #: Internal queue for buffered TPKT PDUs awaiting delivery.
    #:
    #: .. versionadded:: 0.2.4
    in_queue: queue.Queue[bytes]

    def __init__(
        self,
        family: int = -1,
        type: int = -1,
        proto: int = -1,
        fileno: int | None = None,
    ) -> None:
        """Create the socket; arguments are forwarded to :class:`socket.socket`."""
        super().__init__(family, type, proto, fileno)
        self.in_queue = queue.Queue()

    def __del__(self):
        # ``in_queue`` is missing if ``__init__`` failed before assigning it.
        pending = getattr(self, "in_queue", None)
        if pending is not None and not pending.empty():
            logger.warning("Leaking %d TPKTs", pending.qsize())
        # Defining ``__del__`` replaces the base socket finalizer, which is
        # what closes (and warns about) an unclosed descriptor -- chain to it.
        base_finalizer = getattr(super(), "__del__", None)
        if base_finalizer is not None:
            base_finalizer()

    def unpack_tpkt(self, octets: bytes) -> bytes:
        """Unpack a TPKT-encapsulated buffer.

        Only the first TPKT in *octets* is decoded; any further packets in
        the buffer are reported with a warning and discarded.

        :param octets: Raw bytes received from the socket.
        :type octets: bytes
        :return: Extracted TPDU payload, or ``b""`` for an empty buffer.
        :rtype: bytes
        :raises ValueError: If the buffer holds less than one full TPKT.
        """
        if not octets:
            return b""

        logger.log(TRACE, "Received %d bytes from socket", len(octets))
        tpkt = TPKT.from_octets(octets)
        if tpkt.length < len(octets):
            logger.warning(
                "Received more than one TPKT: %s < %s. Second packet will be discarded",
                tpkt.length,
                len(octets),
            )
        logger.log(TRACE, "Header complete (message size = %d)", len(tpkt.tpdu))
        return tpkt.tpdu

    @override
    def recv(self, bufsize: int, flags: int = 0, /) -> bytes:
        """
        Receive a TPKT-encapsulated payload.

        Previously buffered PDUs are returned first. Otherwise up to
        *bufsize* bytes are read and split into every TPKT they contain; if
        the data ends inside a header or a packet, the missing bytes are
        read from the socket so the stream stays aligned. All decoded TPDUs
        are stored in :attr:`in_queue` and the first one is returned.

        :param bufsize: Maximum number of bytes for the initial read.
        :type bufsize: int
        :param flags: Optional flags passed through to the underlying
            :func:`socket.socket.recv`.
        :type flags: int
        :return: The payload of a single decoded TPKT PDU, or ``b""`` if the
            peer closed the connection.
        :rtype: bytes
        :raises ValueError: If a header carries a length below 4 octets, if
            the connection closes in the middle of a TPKT, or if a packet
            fails validation.

        .. versionchanged:: 0.2.4
            Now returns buffered PDUs if available and supports handling
            multiple PDUs per TCP segment.
        """
        if not self.in_queue.empty():
            return self.in_queue.get()

        data = super().recv(bufsize, flags)
        if not data:
            return b""

        logger.log(TRACE, "Received %d bytes from socket", len(data))
        self._enqueue_tpkts(data, flags)
        return self.in_queue.get()

    def _enqueue_tpkts(self, data: bytes, flags: int) -> None:
        """Split *data* into complete TPKTs and queue every TPDU.

        Reads from the socket when *data* ends with a partial header or
        partial packet. NOTE(review): on non-blocking/timeout sockets a
        follow-up read may raise and the partial data is then lost.
        """
        offset = 0
        while offset < len(data):
            available = len(data) - offset
            if available < 4:
                logger.log(TRACE, "Received more than one TPKT (partial header)")
                data += self._recv_exactly(4 - available, flags)

            length = unpack(_U16_BE, data[offset + 2 : offset + 4])
            if length < 4:
                raise ValueError(
                    f"Invalid TPKT length: {length} (must include the 4-byte header)"
                )

            end = offset + length
            if len(data) < end:
                data += self._recv_exactly(end - len(data), flags)

            tpkt = TPKT.from_octets(data[offset:end])
            self.in_queue.put(tpkt.tpdu)
            logger.log(TRACE, "Header complete (message size = %d)", len(tpkt.tpdu))
            offset = end

    def _recv_exactly(self, count: int, flags: int) -> bytes:
        """Read exactly *count* bytes from the underlying (raw) socket."""
        chunks = []
        missing = count
        while missing > 0:
            chunk = super().recv(missing, flags)
            if not chunk:
                raise ValueError(
                    f"Connection closed with {missing} of {count} TPKT bytes outstanding"
                )
            chunks.append(chunk)
            missing -= len(chunk)
        return b"".join(chunks)

    @override
    def recvfrom(self, bufsize: int, flags: int = 0, /) -> tuple[bytes, tuple]:
        """Receive a TPKT packet along with the sender's address.

        Decoding is done by :meth:`unpack_tpkt`, so only the first TPKT in
        the received buffer is returned.

        :param bufsize: Maximum number of bytes to receive.
        :type bufsize: int
        :param flags: Optional socket flags.
        :type flags: int, optional
        :return: A tuple of (TPDU bytes, sender address).
        :rtype: tuple[bytes, tuple]
        """
        data, address = super().recvfrom(bufsize, flags)
        return (self.unpack_tpkt(data), address)

    @override
    def send(self, data: bytes, flags: int = 0, /) -> int:
        """Send TPDU bytes, encapsulated into a TPKT.

        :param data: The raw TPDU bytes to send.
        :type data: bytes
        :param flags: Optional socket flags.
        :type flags: int, optional
        :return: Number of bytes sent, counting the 4-byte TPKT header. May
            be less than the full packet, as with :func:`socket.socket.send`.
        :rtype: int
        """
        pkt = TPKT(tpdu=data)
        logger.log(TRACE, "Sending %d+4 bytes to socket", len(data))
        return super().send(pkt.build(), flags)

    @override
    def sendall(self, data: bytes, flags: int = 0, /) -> None:
        """Send all TPDU bytes, encapsulated into a TPKT.

        Ensures the entire buffer is transmitted, as in the standard
        :func:`socket.socket.sendall`.

        :param data: The raw TPDU bytes to send.
        :type data: bytes
        :param flags: Optional socket flags.
        :type flags: int, optional
        """
        pkt = TPKT(tpdu=data)
        logger.log(TRACE, "Sending %d+4 bytes to socket", len(data))
        return super().sendall(pkt.build(), flags)

    @override
    def sendto(self, data: bytes, address: tuple, /) -> int:
        """Send TPDU bytes, encapsulated into a TPKT, to a specific address.

        :param data: The raw TPDU bytes to send.
        :type data: bytes
        :param address: Destination address (host, port).
        :type address: tuple
        :return: Number of bytes sent, counting the 4-byte TPKT header.
        :rtype: int
        """
        pkt = TPKT(tpdu=data)
        logger.log(TRACE, "Sending %d+4 bytes to socket", len(data))
        return super().sendto(pkt.build(), address)