153 lines
4.4 KiB
Plaintext
Executable file
153 lines
4.4 KiB
Plaintext
Executable file
#!/usr/bin/micropython
|
|
"""Accepts MQTT connections and rewrites the PUBLISH messages before forwarding on.
|
|
|
|
This can be used to fix the invalid (zero) packet ID sent by Solax inverters."""
|
|
import asyncio
|
|
import ssl
|
|
import socket
|
|
import sys
|
|
from errno import EINPROGRESS
|
|
|
|
# `asyncio.core` and `asyncio.stream` are MicroPython submodules that CPython's
# asyncio does not have, so whether they import identifies the runtime. The
# MicroPython TLS connection helper below relies on both of them.
try:
    import asyncio.core
    import asyncio.stream

    IS_MICROPYTHON = True
except ImportError:
    IS_MICROPYTHON = False

try:
    from typing import Optional
except ImportError:
    # MicroPython doesn't have `typing`.
    # NOTE(review): annotations using `Optional` are presumably never
    # evaluated on MicroPython — confirm on the target firmware.
    pass
|
|
|
|
|
|
def _log(msg: str) -> None:
|
|
print(msg)
|
|
|
|
|
|
def _open_micropython_ssl_connection(host: str, port: int):
    """MicroPython 1.21 implementation of asyncio.open_connection, with TLS.

    Written as a generator-based coroutine: the `yield` hands the socket to
    MicroPython's asyncio I/O queue (`queue_write`) and the caller resumes
    once the queue wakes it. Returns `(reader, writer)`, which are the same
    `asyncio.stream.Stream` object. Only used when `IS_MICROPYTHON` is true.

    If any step after socket creation fails, or the coroutine is cancelled or
    closed while waiting, the socket (raw or TLS-wrapped) is closed before the
    exception propagates, so a failed upstream connection does not leak it.

    Exceptions from address lookup, `connect` (other than EINPROGRESS) and
    TLS wrapping propagate to the caller.
    """
    ai = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)[0]
    s = socket.socket(ai[0], ai[1], ai[2])
    try:
        s.setblocking(False)
        try:
            s.connect(ai[-1])
        except OSError as er:
            # A non-blocking connect is expected to report EINPROGRESS; any
            # other error is a real failure.
            if er.errno != EINPROGRESS:
                raise er
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        # NOTE(review): certificate verification is disabled, so the upstream
        # broker is not authenticated and a man-in-the-middle is possible.
        # The CPython path (`ssl=True`) does verify certificates. Consider
        # loading a CA certificate and keeping CERT_REQUIRED if feasible.
        context.verify_mode = ssl.CERT_NONE
        # do_handshake_on_connect=False: no TLS handshake is performed here.
        s = context.wrap_socket(
            s, server_hostname=host, do_handshake_on_connect=False
        )
        s.setblocking(False)
        ss = asyncio.stream.Stream(s)
        yield asyncio.core._io_queue.queue_write(s)
    except BaseException:
        # Release the socket on failure, cancellation or generator close.
        s.close()
        raise
    return ss, ss
|
|
|
|
|
|
async def _open_ssl_connection(host: str, port: int):
    """Opens a TLS connection to `host`:`port` and returns `(reader, writer)`.

    On CPython this is `asyncio.open_connection(..., ssl=True)`; on
    MicroPython the dedicated helper is awaited instead.
    """
    if not IS_MICROPYTHON:
        return await asyncio.open_connection(host, port, ssl=True)
    return await _open_micropython_ssl_connection(host, port)
|
|
|
|
|
|
async def _try_get(reader, writer) -> Optional[int]:
|
|
"""Returns the next byte from `reader`, or None if closed."""
|
|
got = await reader.read(1)
|
|
if not got:
|
|
return None
|
|
writer.write(got)
|
|
return got[0]
|
|
|
|
|
|
async def _get(reader, writer) -> int:
    """Reads and echoes one byte like `_try_get`, but EOF is an error.

    Raises:
        EOFError: if `reader` has no more data.
    """
    value = await _try_get(reader, writer)
    if value is not None:
        return value
    raise EOFError("Unexpected EOF")
|
|
|
|
|
|
async def _get_length(reader, writer) -> int:
    """Reads an MQTT "remaining length" field of at most two bytes.

    Each byte holds 7 bits of the value, lowest group first; a set top bit
    means another byte follows. Only two bytes are accepted, so the largest
    length is 16383. Bytes are echoed to `writer` as they are read.

    Raises:
        EOFError: if `reader` ends inside the field.
        OSError: if a third length byte would be needed.
    """
    first = await _get(reader, writer)
    if not first & 0x80:
        return first
    second = await _get(reader, writer)
    if second & 0x80:
        # Protect the system by faulting on messages greater than ~16 KiB.
        raise OSError("Packet is too long")
    return (first & 0x7F) + (second << 7)
|
|
|
|
|
|
def _rewrite_publish(message: bytearray) -> None:
|
|
"""Rewrites a publish message by fixing the packet offset if invalid."""
|
|
topic_length = (message[0] << 8) | message[1]
|
|
packet_id_offset = topic_length + 2
|
|
|
|
if message[packet_id_offset] == 0 and message[packet_id_offset + 1] == 0:
|
|
message[packet_id_offset] = message[0]
|
|
message[packet_id_offset + 1] = message[1]
|
|
|
|
|
|
async def _rewrite(direction: str, reader, writer) -> None:
    """Relays MQTT packets from `reader` to `writer` until `reader` hits EOF.

    The type byte and remaining-length bytes of each packet are echoed to
    `writer` by the byte-reading helpers while they are parsed. The body is
    then read in full, passed through `_rewrite_publish` when the packet is a
    PUBLISH (type nibble 0x3) with a non-zero QoS, written, and drained.
    `direction` only labels the per-packet log line.

    Raises:
        EOFError: if the stream ends inside a packet header (and whatever
            `readexactly` raises if it ends inside a body).
        OSError: for packets whose length needs more than two bytes.
    """
    packet_type = await _try_get(reader, writer)
    while packet_type is not None:
        length = await _get_length(reader, writer)
        _log(f"{direction} type={packet_type:x} length={length}")
        body = bytearray(await reader.readexactly(length))

        is_publish = (packet_type & 0xF0) == 0x30
        qos = (packet_type >> 1) & 0x03
        if is_publish and qos > 0:
            _rewrite_publish(body)

        writer.write(body)
        await writer.drain()
        packet_type = await _try_get(reader, writer)
|
|
|
|
|
|
async def _client(reader, writer, upstream: str, upstream_port: int) -> None:
|
|
try:
|
|
ureader, uwriter = await _open_ssl_connection(upstream, upstream_port)
|
|
try:
|
|
await asyncio.gather(
|
|
asyncio.create_task(_rewrite("<", ureader, writer)),
|
|
asyncio.create_task(_rewrite(">", reader, uwriter)),
|
|
)
|
|
finally:
|
|
uwriter.close()
|
|
except Exception as ex:
|
|
_log(f"Giving up due to exception: {type(ex)} {ex}")
|
|
finally:
|
|
writer.close()
|
|
|
|
|
|
async def _serve(
    listen: str, listen_port: int, upstream: str, upstream_port: int
) -> None:
    """Accepts connections on `listen`:`listen_port` and proxies each one.

    Every accepted connection is handed to `_client` together with the
    upstream target. Runs until the server is closed.
    """

    async def _on_connection(conn_reader, conn_writer):
        # start_server only passes (reader, writer); bind the upstream here.
        return await _client(conn_reader, conn_writer, upstream, upstream_port)

    listener = await asyncio.start_server(_on_connection, listen, listen_port)
    async with listener:
        await listener.wait_closed()
|
|
|
|
|
|
def main():
    """Runs the proxy using up to four optional positional arguments.

    Arguments, in order: upstream host (default "localhost"), upstream port
    (default 8883), listen address (default "0.0.0.0"), listen port
    (default 2901). Missing or empty arguments take the default; extra
    arguments are ignored.
    """
    host_arg, port_arg, bind_arg, bind_port_arg = (sys.argv[1:] + [None] * 4)[:4]

    upstream = host_arg or "localhost"
    upstream_port = int(port_arg or 8883)
    listen = bind_arg or "0.0.0.0"
    listen_port = int(bind_port_arg or 2901)

    asyncio.run(_serve(listen, listen_port, upstream, upstream_port))
|
|
|
|
|
|
# Start the proxy only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()
|