#!/usr/bin/micropython
"""Accepts MQTT connections and rewrites the PUBLISH messages before forwarding on.

Each accepted connection is proxied to an upstream broker over TLS. This can be
used to fix the invalid packet ID sent by Solax inverters.

Usage: [upstream [upstream_port [listen [listen_port]]]]
(defaults: localhost 8883 0.0.0.0 2901).
"""

import asyncio
import ssl
import socket
import sys
from errno import EINPROGRESS

try:
    import asyncio.core
    import asyncio.stream

    IS_MICROPYTHON = True
except ImportError:
    # CPython's asyncio has no `core` / `stream` submodules.
    IS_MICROPYTHON = False

try:
    from typing import Optional
except ImportError:
    # MicroPython doesn't have `typing`.
    pass


def _log(msg: str) -> None:
    """Writes one diagnostic line to stdout."""
    print(msg)


def _open_micropython_ssl_connection(host: str, port: int):
    """MicroPython 1.21 implementation of asyncio.open_connection, over TLS.

    Mirrors MicroPython's own `asyncio.open_connection`, but wraps the socket
    in a TLS client context with `verify_mode = CERT_NONE`, so the server
    certificate is NOT verified.

    This is a generator-based coroutine: the `yield` hands control back to the
    MicroPython event loop until the socket is writable. Note that
    `socket.getaddrinfo` is a synchronous call and blocks the loop while
    resolving `host`.

    Returns a `(reader, writer)` pair; both are the same `asyncio.stream.Stream`.
    """
    ai = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)[0]
    s = socket.socket(ai[0], ai[1], ai[2])
    s.setblocking(False)
    try:
        s.connect(ai[-1])
    except OSError as er:
        # A non-blocking connect reports EINPROGRESS while it is under way.
        if er.errno != EINPROGRESS:
            raise er
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    context.verify_mode = ssl.CERT_NONE
    s = context.wrap_socket(s, server_hostname=host, do_handshake_on_connect=False)
    s.setblocking(False)
    ss = asyncio.stream.Stream(s)
    yield asyncio.core._io_queue.queue_write(s)
    return ss, ss


async def _open_ssl_connection(host: str, port: int):
    """Opens a TLS connection to `host`:`port` and returns `(reader, writer)`.

    On CPython, `ssl=True` makes asyncio use its default SSL context, so the
    server certificate IS verified there, unlike the MicroPython path.
    """
    if IS_MICROPYTHON:
        return await _open_micropython_ssl_connection(host, port)
    else:
        return await asyncio.open_connection(host, port, ssl=True)


async def _try_get(reader, writer) -> Optional[int]:
    """Reads one byte from `reader` and forwards it to `writer`.

    Returns the byte's value, or None if `reader` is at EOF.
    """
    got = await reader.read(1)
    if not got:
        return None
    writer.write(got)
    return got[0]


async def _get(reader, writer) -> int:
    """Like `_try_get`, but raises EOFError instead of returning None."""
    got = await _try_get(reader, writer)
    if got is None:
        raise EOFError("Unexpected EOF")
    return got


async def _get_length(reader, writer) -> int:
    """Reads and forwards an MQTT Remaining Length field, returning its value.

    Only the one- and two-byte encodings (values up to 16383) are accepted.

    Raises:
        EOFError: if `reader` closes part-way through the field.
        OSError: if the field would need more than two bytes.
    """
    lsb = await _get(reader, writer)
    if (lsb & 0x80) == 0:
        return lsb
    msb = await _get(reader, writer)
    if (msb & 0x80) != 0:
        # Protect the system by faulting on messages greater than ~16 KiB.
        raise OSError("Packet is too long")
    return (msb << 7) | (lsb & 0x7F)


def _rewrite_publish(message: bytearray) -> None:
    """Fixes, in place, a zero packet identifier in a QoS>0 PUBLISH body.

    `message` is everything after the fixed header. It holds a 2-byte
    big-endian topic length, the topic, and then the 2-byte packet identifier.
    A zero identifier is invalid for QoS 1/2. It is replaced by a copy of the
    two topic-length bytes, which are non-zero for any non-empty topic.

    A message too short to contain a packet identifier is left untouched and
    forwarded as-is, so the receiver can reject the malformed packet.

    NOTE(review): every rewritten message whose topic has a given length gets
    the same identifier, so overlapping in-flight QoS>0 messages cannot be
    told apart by id. Presumably the inverter never overlaps them — confirm.
    """
    if len(message) < 2:
        return
    topic_length = (message[0] << 8) | message[1]
    packet_id_offset = topic_length + 2
    if len(message) < packet_id_offset + 2:
        return
    if message[packet_id_offset] == 0 and message[packet_id_offset + 1] == 0:
        message[packet_id_offset] = message[0]
        message[packet_id_offset + 1] = message[1]


async def _rewrite(direction: str, reader, writer) -> None:
    """Copies MQTT packets from `reader` to `writer`, fixing QoS>0 PUBLISH ids.

    The fixed header (type byte and remaining length) is forwarded byte by
    byte as it is parsed. The rest of each packet is read whole, so PUBLISH
    packets can be patched by `_rewrite_publish` before being written.
    Returns when `reader` reaches EOF between packets; EOF inside a packet
    raises. `direction` is only used as a log prefix.
    """
    while True:
        packet_type = await _try_get(reader, writer)
        if packet_type is None:
            break
        length = await _get_length(reader, writer)
        _log(f"{direction} type={packet_type:x} length={length}")
        payload = bytearray(await reader.readexactly(length))
        # The upper nibble 0x3 is PUBLISH; bits 1-2 carry its QoS level.
        qos = (packet_type >> 1) & 0x03
        if (packet_type & 0xF0) == 0x30 and qos > 0:
            _rewrite_publish(payload)
        writer.write(payload)
        await writer.drain()


async def _close_stream(writer) -> None:
    """Closes `writer` and waits until its underlying socket is released.

    MicroPython's asyncio `Stream.close()` does not close the socket itself;
    `wait_closed()` does, so both calls are required there. On CPython,
    awaiting `wait_closed()` after `close()` is the documented pattern.
    OSErrors raised while closing an already-broken connection are logged
    rather than propagated, so teardown always completes.
    """
    try:
        writer.close()
        await writer.wait_closed()
    except OSError as ex:
        _log(f"Error while closing stream: {type(ex)} {ex}")


async def _client(reader, writer, upstream: str, upstream_port: int) -> None:
    """Proxies one accepted client connection to the upstream broker over TLS.

    Packets are pumped in both directions until both sides reach EOF, or
    until either side fails. Errors are logged rather than raised. Both
    connections are always closed, and their sockets released, before
    returning.
    """
    try:
        ureader, uwriter = await _open_ssl_connection(upstream, upstream_port)
        tasks = (
            asyncio.create_task(_rewrite("<", ureader, writer)),
            asyncio.create_task(_rewrite(">", reader, uwriter)),
        )
        try:
            await asyncio.gather(*tasks)
        finally:
            # gather() does not cancel the surviving direction when the other
            # one fails; stop it before its socket is closed underneath it.
            for task in tasks:
                task.cancel()
            await _close_stream(uwriter)
    except Exception as ex:
        _log(f"Giving up due to exception: {type(ex)} {ex}")
    finally:
        await _close_stream(writer)


async def _serve(
    listen: str, listen_port: int, upstream: str, upstream_port: int
) -> None:
    """Accepts clients on `listen`:`listen_port` and proxies each one upstream."""

    async def _wrapper(reader, writer):
        return await _client(reader, writer, upstream, upstream_port)

    server = await asyncio.start_server(_wrapper, listen, listen_port)
    async with server:
        await server.wait_closed()


def main():
    """Parses `[upstream [upstream_port [listen [listen_port]]]]` and serves."""
    # Pad with None so missing positional arguments fall back to defaults.
    args = sys.argv[1:] + [None] * 4
    upstream = args[0] or "localhost"
    upstream_port = int(args[1] or 8883)
    listen = args[2] or "0.0.0.0"
    listen_port = int(args[3] or 2901)
    asyncio.run(_serve(listen, listen_port, upstream, upstream_port))


if __name__ == "__main__":
    main()