Skip to content

Server

SusServer

This class is responsible for managing the server.

Source code in sus/server/server.py
class SusServer:
    """
    This class is responsible for managing the server.

    It loads the server's static X25519 key pair, serves all clients through a
    single UDP endpoint handled by ``OnePortProtocol`` and periodically evicts
    inactive clients.
    """
    # Protocol instance created by start(); None until the endpoint exists.
    __protocol: OnePortProtocol | None = None

    def __init__(self, addr: Address, psks: str):
        """
        Initializes the server.

        Also writes the hex-encoded public key to ``server.pub`` in the current
        working directory, overwriting any existing file.
        :param addr: Tuple containing the address and port to listen on
        :param psks: Hex encoded private key
        :raises ValueError: if ``psks`` is not valid hex
        """
        self.__addr = addr
        self.__logger = logging.getLogger("gatekeeper")

        self.__psks = X25519PrivateKey.from_private_bytes(bytes.fromhex(psks))
        self.__ppks = self.__psks.public_key()

        with open("server.pub", "w", encoding="utf-8") as f:
            f.write(self.public_key)

    @property
    def public_key(self) -> str:
        """Server public key, hex-encoded from its raw bytes."""
        return self.__ppks.public_bytes(Encoding.Raw, PublicFormat.Raw).hex()

    @property
    def address(self):
        """Server address"""
        return self.__addr

    async def __garbage_collector(self):
        """
        Periodically evict inactive clients until the protocol closes.

        Every 10 seconds asks the protocol to drop clients that are no longer
        alive. Cancellation (start() cancels it on exit) ends the loop quietly.
        """
        try:
            while True:
                if self.__protocol.closed.is_set():
                    break
                await asyncio.sleep(10)
                self.__protocol.clean()
        except asyncio.CancelledError:
            return

    async def start(self, message_handlers: Iterable[MessageHandler] | None = None):
        """
        This coroutine is responsible for starting the server.

        It runs until the protocol's ``closed`` event is set (for example after
        :meth:`stop`) or until the coroutine itself is cancelled.
        :param message_handlers: An iterable of message handlers, which are called when a message is received.
        """
        self.__logger.info("Starting server")
        self.__logger.info(f"public key: {self.public_key}")

        wallet = Wallet(ppks=self.__ppks, psks=self.__psks)
        # Materialize the handlers once: each client handler builds its own set
        # from this iterable, so a one-shot iterator would only reach the first.
        handlers = list(message_handlers or ())

        # create a protocol instance, this will handle all incoming packets
        _, self.__protocol = await asyncio.get_running_loop().create_datagram_endpoint(
            lambda: OnePortProtocol(wallet, handlers),
            self.__addr)

        # start the garbage collector.
        gc_task = asyncio.create_task(self.__garbage_collector())

        # we're done here, wait for the protocol to close.
        try:
            await self.__protocol.closed.wait()
        except asyncio.CancelledError:
            self.__logger.info("Server stopped")
        finally:
            gc_task.cancel()
            self.__protocol.close()

    async def send(self, addr: Address, msg: bytes):
        """
        Sends a message to a client.
        :param addr: Client address
        :param msg: Message to send
        :raises RuntimeError: if the server has not been started yet
        """
        if self.__protocol is None:
            raise RuntimeError("Server has not been started")
        await self.__protocol.send(msg, addr)

    async def stop(self):
        """
        Stops the server.

        Closes the protocol's transport, which sets its ``closed`` event and
        lets :meth:`start` return. Before :meth:`start` has created the
        endpoint this only logs.
        """
        self.__logger.warning("Shutting down")
        if self.__protocol is not None:
            self.__protocol.close()

address property

Server address

public_key property

Server public key

__garbage_collector() async

This coroutine is responsible for cleaning up inactive clients.

Source code in sus/server/server.py
async def __garbage_collector(self):
    """
    Periodically evict inactive clients until the protocol closes.

    Every 10 seconds asks the protocol to drop clients that are no longer
    alive. Cancellation (start() cancels it on exit) ends the loop quietly.
    """
    try:
        while True:
            if self.__protocol.closed.is_set():
                break
            await asyncio.sleep(10)
            self.__protocol.clean()
    except asyncio.CancelledError:
        return

__init__(addr, psks)

Initializes the server.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `addr` | `Address` | Tuple containing the address and port to listen on | *required* |
| `psks` | `str` | Hex-encoded X25519 private key | *required* |

Source code in sus/server/server.py
def __init__(self, addr: Address, psks: str):
    """
    Initializes the server.

    Also writes the hex-encoded public key to ``server.pub`` in the current
    working directory, overwriting any existing file.
    :param addr: Tuple containing the address and port to listen on
    :param psks: Hex encoded private key
    :raises ValueError: if ``psks`` is not valid hex
    """
    self.__addr = addr
    self.__logger = logging.getLogger("gatekeeper")

    self.__psks = X25519PrivateKey.from_private_bytes(bytes.fromhex(psks))
    self.__ppks = self.__psks.public_key()

    with open("server.pub", "w", encoding="utf-8") as f:
        f.write(self.public_key)

send(addr, msg) async

Sends a message to a client.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `addr` | `Address` | Client address | *required* |
| `msg` | `bytes` | Message to send | *required* |

Source code in sus/server/server.py
async def send(self, addr: Address, msg: bytes):
    """
    Sends a message to a client.
    :param addr: Client address
    :param msg: Message to send
    :raises RuntimeError: if the server has not been started yet
    """
    if self.__protocol is None:
        raise RuntimeError("Server has not been started")
    await self.__protocol.send(msg, addr)

start(message_handlers=None) async

This coroutine is responsible for starting the server.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `message_handlers` | `Iterable[MessageHandler]` | An iterable of message handlers, which are called when a message is received. | `None` |
Source code in sus/server/server.py
async def start(self, message_handlers: Iterable[MessageHandler] | None = None):
    """
    This coroutine is responsible for starting the server.

    It runs until the protocol's ``closed`` event is set (for example after
    :meth:`stop`) or until the coroutine itself is cancelled.
    :param message_handlers: An iterable of message handlers, which are called when a message is received.
    """
    self.__logger.info("Starting server")
    self.__logger.info(f"public key: {self.public_key}")

    wallet = Wallet(ppks=self.__ppks, psks=self.__psks)
    # Materialize the handlers once: each client handler builds its own set
    # from this iterable, so a one-shot iterator would only reach the first.
    handlers = list(message_handlers or ())

    # create a protocol instance, this will handle all incoming packets
    _, self.__protocol = await asyncio.get_running_loop().create_datagram_endpoint(
        lambda: OnePortProtocol(wallet, handlers),
        self.__addr)

    # start the garbage collector.
    gc_task = asyncio.create_task(self.__garbage_collector())

    # we're done here, wait for the protocol to close.
    try:
        await self.__protocol.closed.wait()
    except asyncio.CancelledError:
        self.__logger.info("Server stopped")
    finally:
        gc_task.cancel()
        self.__protocol.close()

stop() async

Stops the server.

Source code in sus/server/server.py
async def stop(self):
    """
    Stops the server.

    Closes the protocol's transport, which sets its ``closed`` event and lets
    :meth:`start` return. Before :meth:`start` has created the endpoint this
    only logs.
    """
    self.__logger.warning("Shutting down")
    if self.__protocol is not None:
        self.__protocol.close()

Client handler

This class is responsible for handling clients. One instance of this class is created for each client.

Source code in sus/server/handler.py
class ClientHandler:
    """
    This class is responsible for handling clients.
    One instance of this class is created for each client.

    Incoming packets are dispatched on the connection state: INITIAL expects
    the client's ephemeral public key and nonce, HANDSHAKE expects the token
    packet that completes the key exchange, and CONNECTED carries encrypted,
    Poly1305-tagged messages that are forwarded to the message handlers.
    """
    # ChaCha20 stream contexts created in __key_exchange: the *_enc streams
    # encrypt/decrypt message bytes, the *_mac streams yield a fresh 32-byte
    # Poly1305 key for every packet.
    __cl_enc: AEADDecryptionContext
    __sr_enc: AEADEncryptionContext
    __cl_mac: AEADDecryptionContext
    __sr_mac: AEADEncryptionContext

    # Protocol identifier the client sent in its handshake packet.
    __protocol: bytes

    def __init__(self, addr: tuple[str, int], transport: asyncio.DatagramTransport, wallet: Wallet,
                 message_handlers: Iterable[MessageHandler]):
        """
        :param addr: client address, also used as the reply destination
        :param transport: datagram transport used to send replies
        :param wallet: server key material; copied, the caller's instance is not modified
        :param message_handlers: awaitable handlers called as handler(addr, message_id, message)
        """
        import copy

        self.__last_seen = now()
        self.__addr = addr
        self.__transport = transport
        self.__message_handlers = set(message_handlers)
        self.__state = ConnectionState.INITIAL

        self.__logger = logging.getLogger(f"{addr[0]}:{addr[1]}")

        self.__logger.info(f"New client {addr}")

        # The same Wallet may be handed to every client (OnePortProtocol does
        # this), while this handler stores per-connection secrets on it
        # (ephemeral keys, nonces, token, shared secret). Work on a private copy
        # so concurrent handshakes cannot overwrite each other's state.
        wallet = copy.copy(wallet)
        wallet.esks = X25519PrivateKey.generate()
        wallet.epks = wallet.esks.public_key()
        wallet.ns = urandom(8)

        self.__wallet = wallet

        # Sequence number passed to the handlers with the next received message.
        self.__client_message_id = 0
        self.__server_message_id = 0
        # Packet id of the most recently received client packet.
        self.__client_packet_id = 0
        self.__server_packet_id = 0

        self.__mtu_estimate = 1500
        # Handler dispatches still in flight; referenced here so they are not
        # garbage-collected before they finish.
        self.__pending_dispatches = set()

    @property
    def is_alive(self):
        """
        True if the client is not in an error or disconnected state and has been seen in the last 5 seconds.
        """
        return self.__state not in (
            ConnectionState.ERROR, ConnectionState.DISCONNECTED
        ) and now() - self.__last_seen < 5

    @property
    def addr(self):
        """Client address."""
        return self.__addr

    @property
    def last_seen(self):
        """Last time the client was seen, as returned by ``now()``."""
        return self.__last_seen

    @property
    def protocol(self):
        """Protocol identifier received in the handshake (unset until the handshake completes)."""
        return self.__protocol

    def __key_exchange(self):
        """
        Derive the session secret and create the four ChaCha20 stream contexts.

        The shared secret is BLAKE3 over both X25519 results (ephemeral and
        static server key against the client's ephemeral key), both nonces and
        the three public keys.
        """
        wallet = self.__wallet
        eces = wallet.esks.exchange(wallet.epkc)
        ecps = wallet.psks.exchange(wallet.epkc)
        wallet.shared_secret = blake3(
            eces + ecps + wallet.nc + wallet.ns +
            wallet.ppks.public_bytes(Encoding.Raw, PublicFormat.Raw) +
            wallet.epks.public_bytes(Encoding.Raw, PublicFormat.Raw) +
            wallet.epkc.public_bytes(Encoding.Raw, PublicFormat.Raw)).digest()
        self.__logger.debug(f"shared_secret: {wallet.shared_secret.hex()}")
        # noinspection DuplicatedCode
        self.__cl_enc = Cipher(ChaCha20(wallet.shared_secret, b"\x00" * 8 + CLIENT_ENC_NONCE), None).decryptor()
        self.__sr_enc = Cipher(ChaCha20(wallet.shared_secret, b"\x00" * 8 + SERVER_ENC_NONCE), None).encryptor()
        self.__cl_mac = Cipher(ChaCha20(wallet.shared_secret, b"\x00" * 8 + CLIENT_MAC_NONCE), None).decryptor()
        self.__sr_mac = Cipher(ChaCha20(wallet.shared_secret, b"\x00" * 8 + SERVER_MAC_NONCE), None).encryptor()

    def __verify_and_decrypt(self, data: bytes) -> bytes | None:
        """
        Authenticate and decrypt one client packet.

        Layout: 8-byte little-endian packet id, payload, 16-byte Poly1305 tag
        over id + payload. When the packet id is 0 the payload additionally
        starts with the 32-byte handshake token, which is skipped.
        :returns: the decrypted message, or None if the tag does not verify
        """
        try:
            p_id = int.from_bytes(data[:8], "little")
            # NOTE(review): the MAC key stream advances even when verification
            # fails below — confirm the client expects this.
            key = self.__cl_mac.update(b"\x00" * 32)
            payload = data[8:-16]
            tag = data[-16:]
            poly1305.Poly1305.verify_tag(key, data[:8] + payload, tag)
        except InvalidSignature:
            self.__logger.error("Invalid signature")
            return None

        # special case for first packet
        if p_id == 0:
            payload = payload[32:]
            self.__logger.debug(f"--- {trail_off(payload.hex())}")

        message_bytes = self.__cl_enc.update(payload)
        message_length = int.from_bytes(message_bytes[:4], "little")
        message = message_bytes[4:message_length + 4]
        self.__logger.info(f"Received message {p_id} ({message_length} bytes)")
        self.__client_packet_id = p_id
        return message

    def __encrypt_and_tag(self, data: bytes) -> list[bytes]:
        """
        Encrypt a message and split it into tagged packets.

        The message is prefixed with its 4-byte little-endian length,
        zero-padded to a whole number of ``mtu_estimate - 24``-byte payloads,
        encrypted, and each payload is framed as id (8) + payload + tag (16).
        """
        message_bytes = len(data).to_bytes(4, "little") + data
        packet_length = self.__mtu_estimate - 24
        # `-n % k` is 0 when n is already a multiple of k; `k - n % k` would
        # append a whole extra packet of padding in that case.
        padding = -len(message_bytes) % packet_length
        padded_message_bytes = message_bytes + b"\x00" * padding

        ciphertext = self.__sr_enc.update(padded_message_bytes)
        self.__logger.debug(f"--- {trail_off(ciphertext.hex())}")

        packets = []
        for payload in self.__split_message(ciphertext):
            key = self.__sr_mac.update(b"\x00" * 32)
            # NOTE(review): outgoing packets are numbered from the *client*
            # packet counter (overwritten on every received packet), while
            # __server_packet_id stays unused. Looks unintended, but it is
            # wire-visible — confirm against the client before changing.
            p_id = self.__client_packet_id.to_bytes(8, "little")
            frame = p_id + payload
            tag = poly1305.Poly1305.generate_tag(key, frame)
            packets.append(frame + tag)
            self.__client_packet_id += 1
        return packets

    def __split_message(self, data: bytes) -> list[bytes]:
        """Split ``data`` into consecutive chunks of ``mtu_estimate - 24`` bytes."""
        packet_length = self.__mtu_estimate - 24
        return [data[i:i + packet_length] for i in range(0, len(data), packet_length)]

    def __initial(self, data):
        """
        Process the client's first packet: 32-byte ephemeral public key followed
        by an 8-byte nonce. Replies with the server's ephemeral key and nonce.
        :raises MalformedPacket: if the packet is not exactly 40 bytes long
        """
        if len(data) != 40:
            self.__logger.error("Invalid handshake packet")
            self.__state = ConnectionState.ERROR
            raise MalformedPacket("Invalid handshake packet")
        wallet = self.__wallet
        wallet.epkc = X25519PublicKey.from_public_bytes(data[:32])
        wallet.nc = data[32:]

        self.__state = ConnectionState.HANDSHAKE
        self.__transport.sendto((wallet.epks.public_bytes(Encoding.Raw, PublicFormat.Raw) + wallet.ns), self.__addr)

    def __handshake(self, data) -> None:
        """
        Process the client's handshake packet and complete the key exchange.

        Bytes 8..40 must equal the token (BLAKE3 over both ephemeral public keys
        and both nonces). The packet is then verified and decrypted like any
        other packet, and its plaintext is stored as the protocol identifier.
        :raises MalformedPacket: if the packet is shorter than 40 bytes
        :raises HandsakeError: on token mismatch or failed authentication
        """
        if len(data) < 40:
            self.__logger.error("Invalid handshake packet")
            self.__state = ConnectionState.ERROR
            raise MalformedPacket("Invalid handshake packet")

        wallet = self.__wallet

        wallet.token = blake3(wallet.epkc.public_bytes(Encoding.Raw, PublicFormat.Raw) +
                              wallet.epks.public_bytes(Encoding.Raw, PublicFormat.Raw) +
                              wallet.nc + wallet.ns).digest()
        client_token = data[8:40]

        self.__logger.debug(f"token: {client_token.hex()}")

        if client_token != wallet.token:
            self.__logger.debug(f"ours : {self.__wallet.token.hex()}")
            self.__logger.error("token mismatch!")
            self.__state = ConnectionState.ERROR
            raise HandsakeError("Token mismatch")

        self.__logger.debug("token: OK")

        self.__key_exchange()

        message = self.__verify_and_decrypt(data)
        if message is None:
            self.__state = ConnectionState.ERROR
            raise HandsakeError("Invalid handshake packet (missing protocol)")

        self.__state = ConnectionState.CONNECTED
        self.__protocol = message
        self.__logger.info("Handshake complete")
        # errors="replace": a non-UTF-8 identifier must not abort the handshake.
        self.__logger.debug(f"protocol: {message.decode('utf-8', errors='replace')}")

    def __connected(self, data):
        """Decrypt a data packet and dispatch its message to every handler."""
        message = self.__verify_and_decrypt(data)
        self.__logger.info(f">>> {trail_off(message.decode('utf-8', errors='replace')) if message else None}")

        if not message:
            return

        message_id = self.__client_message_id
        self.__client_message_id += 1

        # send to all handlers; keep the gathered future referenced until it
        # finishes, and log handler failures instead of dropping them.
        dispatch = asyncio.gather(*[
            handler(self.addr, message_id, message)
            for handler in self.__message_handlers
        ], return_exceptions=True)
        self.__pending_dispatches.add(dispatch)
        dispatch.add_done_callback(self.__on_dispatch_done)

    def __on_dispatch_done(self, dispatch) -> None:
        """Forget a finished handler dispatch and log any handler exceptions."""
        self.__pending_dispatches.discard(dispatch)
        if dispatch.cancelled():
            return
        for result in dispatch.result():
            if isinstance(result, BaseException):
                self.__logger.error("Message handler failed", exc_info=result)

    def __disconnected(self, data):
        # Packets after disconnection are not supported yet.
        raise NotImplementedError

    def __error(self, data):
        # Packets in the error state are not supported yet.
        raise NotImplementedError

    def handle(self, data: bytes):
        """
        Handles incoming packets.

        Refreshes the last-seen timestamp, then routes the packet to the handler
        for the current connection state.
        :param data: packet data
        """
        self.__last_seen = now()

        state = self.__state
        if state == ConnectionState.INITIAL:
            self.__initial(data)
        elif state == ConnectionState.HANDSHAKE:
            self.__handshake(data)
        elif state == ConnectionState.CONNECTED:
            self.__connected(data)
        elif state == ConnectionState.DISCONNECTED:
            self.__disconnected(data)
        elif state == ConnectionState.ERROR:
            self.__error(data)

    async def send(self, data: bytes):
        """
        Sends a message to the client.

        Does nothing unless the connection is CONNECTED: the session ciphers are
        only guaranteed to exist once the handshake has completed.
        :param data: data to send
        """
        if self.__state != ConnectionState.CONNECTED:
            return
        self.__logger.info(f"<<< {trail_off(data.decode('utf-8', errors='replace'))}")
        packets = self.__encrypt_and_tag(data)
        self.__logger.info(f"Sending {len(data)} bytes in {len(packets)} packets")
        for packet in packets:
            self.__transport.sendto(packet, self.__addr)

    def add_message_handler(self, handler: MessageHandler):
        """
        Adds a message handler. This handler will be called when a message is received.
        :param handler: Awaitable handler function
        """
        self.__message_handlers.add(handler)

addr property

Client address.

is_alive property

True if the client is not in an error or disconnected state and has been seen in the last 5 seconds.

last_seen property

Last time the client was seen, in seconds since the epoch.

add_message_handler(handler)

Adds a message handler. This handler will be called when a message is received.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `handler` | `MessageHandler` | Awaitable handler function | *required* |

Source code in sus/server/handler.py
def add_message_handler(self, handler: MessageHandler):
    """
    Register an extra message handler for this client.

    The handler is awaited with every message received after registration.
    :param handler: Awaitable handler function
    """
    self.__message_handlers.update((handler,))

handle(data)

Handles incoming packets.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data` | `bytes` | packet data | *required* |

Source code in sus/server/handler.py
def handle(self, data: bytes):
    """
    Handles incoming packets.

    Refreshes the last-seen timestamp, then routes the packet to the handler
    for the current connection state.
    :param data: packet data
    """
    self.__last_seen = now()

    state = self.__state
    if state == ConnectionState.INITIAL:
        self.__initial(data)
    elif state == ConnectionState.HANDSHAKE:
        self.__handshake(data)
    elif state == ConnectionState.CONNECTED:
        self.__connected(data)
    elif state == ConnectionState.DISCONNECTED:
        self.__disconnected(data)
    elif state == ConnectionState.ERROR:
        self.__error(data)

send(data) async

Sends a message to the client.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data` | `bytes` | data to send | *required* |

Source code in sus/server/handler.py
async def send(self, data: bytes):
    """
    Sends a message to the client.

    Does nothing unless the connection is CONNECTED: the session ciphers are
    only guaranteed to exist once the handshake has completed.
    :param data: data to send
    """
    if self.__state != ConnectionState.CONNECTED:
        return
    self.__logger.info(f"<<< {trail_off(data.decode('utf-8', errors='replace'))}")
    packets = self.__encrypt_and_tag(data)
    self.__logger.info(f"Sending {len(data)} bytes in {len(packets)} packets")
    for packet in packets:
        self.__transport.sendto(packet, self.__addr)

Protocol

Bases: DatagramProtocol

This class is responsible for handling the UDP protocol. It matches incoming packets to clients and handles the handshake.

Source code in sus/server/protocol.py
class OnePortProtocol(asyncio.DatagramProtocol):
    """
    This class is responsible for handling the UDP protocol.
    It matches incoming packets to clients and handles the handshake.

    Clients are keyed by their source address; a ClientHandler is created on
    the first datagram from an unknown address.
    """
    __transport: asyncio.DatagramTransport

    def __init__(self, wallet: Wallet, message_handlers: Iterable[MessageHandler]):
        super().__init__()
        self.__wallet = wallet
        # Materialized once: every ClientHandler builds its own set from this,
        # so a one-shot iterator would otherwise only reach the first client.
        self.__message_handlers = tuple(message_handlers)

        self.__clients: dict[tuple[str, int], ClientHandler] = {}
        self.__logger = logging.getLogger("OnePort")

        # Set by connection_lost() once the transport has been closed.
        self.closed = asyncio.Event()

    def connection_made(self, transport: asyncio.DatagramTransport):
        self.__transport = transport
        self.__logger.info(f"Listening on port {transport.get_extra_info('sockname')[1]}")

    def error_received(self, exc):
        # Called outside an except block, so logger.exception() would have no
        # active exception to report; attach `exc` explicitly instead.
        self.__logger.error(f"Error received: {exc!r}", exc_info=exc)

    def datagram_received(self, data, addr):
        handler = self.__clients.get(addr)
        if handler is None:
            handler = ClientHandler(addr, self.__transport, self.__wallet, self.__message_handlers)
            self.__clients[addr] = handler

        try:
            handler.handle(data)
        except HandsakeError:
            # Drop only the offending client. Closing the transport here would
            # let any peer shut the whole server down with a bad handshake.
            self.__logger.error(f"Handshake failed with {addr}")
            self.__clients.pop(addr, None)
        except MalformedPacket:
            self.__logger.error(f"Malformed packet from {addr}")
            self.__clients.pop(addr, None)

    async def send(self, data: bytes, addr: tuple[str, int]):
        """Send ``data`` to the connected client at ``addr``; logs and returns if unknown."""
        client = self.__clients.get(addr)
        if client is None:
            self.__logger.error(f"Attempted to send to {addr} but they are not connected")
            return
        await client.send(data)

    def add_message_handler(self, handler: MessageHandler, addr: tuple[str, int]):
        """Add ``handler`` to the client at ``addr`` (KeyError if unknown)."""
        self.__clients[addr].add_message_handler(handler)

    def clean(self):
        """
        Removes inactive clients.

        Every client whose ``is_alive`` is false is dropped from the client table.
        """
        stale = [addr for addr, client in self.__clients.items() if not client.is_alive]
        for addr in stale:
            del self.__clients[addr]

    def close(self):
        """Close the underlying transport; connection_lost() then sets ``closed``."""
        self.__transport.close()

    def connection_lost(self, exc):
        self.__logger.info("Connection closed")
        self.closed.set()

clean()

Removes inactive clients.

Source code in sus/server/protocol.py
def clean(self):
    """
    Removes inactive clients.

    Every client whose ``is_alive`` is false is dropped from the client table.
    """
    stale = [addr for addr, client in self.__clients.items() if not client.is_alive]
    for addr in stale:
        del self.__clients[addr]