mqtt-rewrite/mqtt_rewrite
2024-03-29 13:48:43 +01:00

153 lines
4.4 KiB
Plaintext
Executable file

#!/usr/bin/micropython
"""Accepts MQTT connections and rewrites the PUBLISH messages before forwarding on.
This can be used to fix the invalid (all-zero) packet ID sent by Solax inverters."""
import asyncio
import ssl
import socket
import sys
from errno import EINPROGRESS
try:
    # CPython's asyncio has no `core` / `stream` submodules, so a successful
    # import means we are running under MicroPython.
    import asyncio.core
    import asyncio.stream
    IS_MICROPYTHON = True
except ImportError:
    IS_MICROPYTHON = False
try:
    from typing import Optional
except ImportError:
    # MicroPython doesn't have `typing`.
    pass
def _log(msg: str) -> None:
print(msg)
def _open_micropython_ssl_connection(host: str, port: int):
    """MicroPython 1.21 implementation of asyncio.open_connection over TLS.

    This is a generator rather than an ``async def``. It yields directly to
    MicroPython's asyncio I/O queue and is awaited by ``_open_ssl_connection``.

    Returns:
        A ``(reader, writer)`` pair; both are the same ``asyncio.stream.Stream``.

    Raises:
        OSError: if the connect fails with anything other than EINPROGRESS.
        The socket is closed before any exception propagates.
    """
    # NOTE(review): getaddrinfo() is a synchronous call, so name resolution
    # presumably blocks the event loop.
    ai = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)[0]
    s = socket.socket(ai[0], ai[1], ai[2])
    s.setblocking(False)
    try:
        try:
            s.connect(ai[-1])
        except OSError as er:
            # A non-blocking connect normally reports EINPROGRESS. Completion
            # is awaited below via the I/O queue.
            if er.errno != EINPROGRESS:
                raise er
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        # NOTE(review): certificate verification is disabled, so the upstream
        # broker's identity is not checked. Confirm this is intended.
        context.verify_mode = ssl.CERT_NONE
        # The TLS handshake is deferred, presumably because the non-blocking
        # TCP connect may not have completed yet.
        s = context.wrap_socket(
            s, server_hostname=host, do_handshake_on_connect=False
        )
    except Exception:
        # Don't leak the socket if the connect or TLS setup fails.
        s.close()
        raise
    s.setblocking(False)
    ss = asyncio.stream.Stream(s)
    # Wait (via MicroPython's I/O queue) for the socket to become writable,
    # which signals that the non-blocking connect has finished.
    yield asyncio.core._io_queue.queue_write(s)
    return ss, ss
async def _open_ssl_connection(host: str, port: int):
    """Open a TLS client connection to `host:port`; return `(reader, writer)`.

    Under MicroPython this uses the local generator-based helper. Otherwise it
    uses `asyncio.open_connection`.
    """
    if not IS_MICROPYTHON:
        # NOTE(review): ssl=True uses CPython's default SSL context, which
        # verifies the broker certificate, unlike the MicroPython path.
        return await asyncio.open_connection(host, port, ssl=True)
    return await _open_micropython_ssl_connection(host, port)
async def _try_get(reader, writer) -> Optional[int]:
"""Returns the next byte from `reader`, or None if closed."""
got = await reader.read(1)
if not got:
return None
writer.write(got)
return got[0]
async def _get(reader, writer) -> int:
    """Like `_try_get`, but treat the end of the stream as an error.

    Raises:
        EOFError: if `reader` is closed before another byte is available.
    """
    value = await _try_get(reader, writer)
    if value is not None:
        return value
    raise EOFError("Unexpected EOF")
async def _get_length(reader, writer) -> int:
    """Decode the variable-length "remaining length" field, echoing its bytes.

    Only one- and two-byte encodings (values up to 16383) are accepted.

    Raises:
        OSError: if a third length byte would follow. The oversized packet
            body is therefore never read.
        EOFError: propagated from `_get` if the stream ends mid-field.
    """
    low = await _get(reader, writer)
    if not low & 0x80:
        # Continuation bit clear: single-byte encoding.
        return low
    high = await _get(reader, writer)
    if high & 0x80:
        # Protect the system by faulting on messages of 16 KiB or more.
        raise OSError("Packet is too long")
    return (low & 0x7F) | (high << 7)
def _rewrite_publish(message: bytearray) -> None:
"""Rewrites a publish message by fixing the packet offset if invalid."""
topic_length = (message[0] << 8) | message[1]
packet_id_offset = topic_length + 2
if message[packet_id_offset] == 0 and message[packet_id_offset + 1] == 0:
message[packet_id_offset] = message[0]
message[packet_id_offset + 1] = message[1]
async def _rewrite(direction: str, reader, writer) -> None:
    """Relay packets from `reader` to `writer` until `reader` reaches EOF.

    The header bytes (type and remaining length) are echoed to `writer` as
    they are read. The body is then read whole and written, followed by a
    drain. Bodies of PUBLISH packets with QoS > 0 are first passed through
    `_rewrite_publish`. `direction` only labels the log line.
    """
    header = await _try_get(reader, writer)
    while header is not None:
        remaining = await _get_length(reader, writer)
        _log(f"{direction} type={header:x} length={remaining}")
        body = bytearray(await reader.readexactly(remaining))
        # QoS lives in bits 1-2 of the first header byte.
        qos = (header >> 1) & 0x03
        is_publish = (header & 0xF0) == 0x30
        if is_publish and qos > 0:
            _rewrite_publish(body)
        writer.write(body)
        await writer.drain()
        header = await _try_get(reader, writer)
async def _client(reader, writer, upstream: str, upstream_port: int) -> None:
    """Proxy one downstream connection to `upstream:upstream_port` over TLS.

    Packets are pumped in both directions with `_rewrite` until both
    directions finish or either one fails. Exceptions are logged, not
    propagated. Both connections are closed before returning.
    """
    try:
        ureader, uwriter = await _open_ssl_connection(upstream, upstream_port)
        try:
            tasks = (
                # Upstream broker -> downstream client.
                asyncio.create_task(_rewrite("<", ureader, writer)),
                # Downstream client -> upstream broker.
                asyncio.create_task(_rewrite(">", reader, uwriter)),
            )
            try:
                await asyncio.gather(*tasks)
            finally:
                # gather() does not cancel the remaining task when one raises.
                # Cancel it so it does not keep running on streams that are
                # about to be closed. Cancelling a finished task is a no-op.
                for task in tasks:
                    task.cancel()
        finally:
            uwriter.close()
    except Exception as ex:
        _log(f"Giving up due to exception: {type(ex)} {ex}")
    finally:
        writer.close()
async def _serve(
    listen: str, listen_port: int, upstream: str, upstream_port: int
) -> None:
    """Accept connections on `listen:listen_port` and proxy each one upstream.

    Every accepted connection is handled by `_client`, which forwards it to
    `upstream:upstream_port`. Runs until the server is closed.
    """

    async def on_connect(reader, writer):
        await _client(reader, writer, upstream, upstream_port)

    listener = await asyncio.start_server(on_connect, listen, listen_port)
    async with listener:
        await listener.wait_closed()
def main():
    """Parse positional arguments and run the proxy forever.

    Arguments: [UPSTREAM [UPSTREAM_PORT [LISTEN [LISTEN_PORT]]]]
    Missing or empty values fall back to "localhost", 8883, "0.0.0.0" and
    2901. Arguments beyond the fourth are ignored.
    """
    host_up, port_up, host_in, port_in = (sys.argv[1:] + [None] * 4)[:4]
    upstream = host_up or "localhost"
    upstream_port = int(port_up or 8883)
    listen = host_in or "0.0.0.0"
    listen_port = int(port_in or 2901)
    asyncio.run(_serve(listen, listen_port, upstream, upstream_port))


if __name__ == "__main__":
    main()