soc: xtensa: partial fix of socket misuse and refine the code

1. Improve the firmware transfer reliability by fixing the misuse
of the socket. Fix the most frequent occurence of the common `recv()`
bug described here:

https://docs.python.org/3/howto/sockets.html#using-a-socket

The longer term fix is to switch to a higher level API like Python
Remote Objects.

2. Not rely on the client's command to disconnect. Previously we
rely on the SIGINT to send stop_command to the server, but it does
not work well in some environments. Refine the whole logic and the
sever disconnect service by checking if the client is alive or not.

These changes make the client-server-based cavstool more stable.

Fixes #46864

Signed-off-by: Enjia Mai <enjia.mai@intel.com>
This commit is contained in:
Enjia Mai 2022-06-26 03:51:36 +08:00 committed by Anas Nashif
commit faff3f7ecc
2 changed files with 135 additions and 76 deletions

View file

@ -14,6 +14,7 @@ import argparse
import socketserver
import threading
import netifaces
import hashlib
# Global variable use to sync between log and request services.
# When it is true, the adsp is able to start running.
@ -24,9 +25,17 @@ HOST = None
PORT_LOG = 9999
PORT_REQ = PORT_LOG + 1
BUF_SIZE = 4096
# Define the command and the max size
CMD_LOG_START = "start_log"
CMD_LOG_STOP = "stop_log"
CMD_DOWNLOAD = "download"
MAX_CMD_SZ = 16
# Define the header format and size for
# transmiting the firmware
PACKET_HEADER_FORMAT_FW = 'I 42s 32s'
HEADER_SZ = 78
logging.basicConfig()
log = logging.getLogger("cavs-fw")
@ -625,6 +634,7 @@ def ipc_command(data, ext_data):
async def _main(server):
#TODO this bit me, remove the globals, write a little FirmwareLoader class or something to contain.
global hda, sd, dsp, hda_ostream_id, hda_streams
global start_output
try:
(hda, sd, dsp, hda_ostream_id) = map_regs()
except Exception as e:
@ -659,67 +669,81 @@ async def _main(server):
if dsp.HIPCTDR & 0x80000000:
ipc_command(dsp.HIPCTDR & ~0x80000000, dsp.HIPCTDD)
if server:
# Check if the client connection is alive.
if not is_connection_alive(server):
lock.acquire()
start_output = False
lock.release()
class adsp_request_handler(socketserver.BaseRequestHandler):
"""
The request handler class for control the actions of server.
"""
def receive_fw(self, filename):
try:
with open(fw_file,'wb') as f:
cnt = 0
log.info("Receiving...")
def receive_fw(self):
log.info("Receiving...")
# Receive the header first
d = self.request.recv(HEADER_SZ)
while True:
l = self.request.recv(BUF_SIZE)
ll = len(l)
cnt = cnt + ll
if not l:
break
else:
f.write(l)
# Unpacked the header data
# Include size(4), filename(42) and MD5(32)
header = d[:HEADER_SZ]
total = d[HEADER_SZ:]
s = struct.Struct(PACKET_HEADER_FORMAT_FW)
fsize, fname, md5_tx_b = s.unpack(header)
log.info(f'size:{fsize}, filename:{fname}, MD5:{md5_tx_b}')
# Receive the firmware. We only receive the specified amount of bytes.
while len(total) < fsize:
data = self.request.recv(min(BUF_SIZE, fsize - len(total)))
if not data:
raise EOFError("truncated firmware file")
total += data
log.info(f"Done Receiving {len(total)}.")
try:
with open(fname,'wb') as f:
f.write(total)
except Exception as e:
log.error(f"Get exception {e} during FW transfer.")
return 1
return None
log.info(f"Done Receiving {cnt}.")
# Check the MD5 of the firmware
md5_rx = hashlib.md5(total).hexdigest()
md5_tx = md5_tx_b.decode('utf-8')
if md5_tx != md5_rx:
log.error(f'MD5 mismatch: {md5_tx} vs. {md5_rx}')
return None
return fname
def handle(self):
global start_output, fw_file
cmd = self.request.recv(BUF_SIZE)
cmd = self.request.recv(MAX_CMD_SZ)
log.info(f"{self.client_address[0]} wrote: {cmd}")
action = cmd.decode("utf-8")
log.debug(f'load {action}')
if action == CMD_DOWNLOAD:
self.request.sendall(cmd)
recv_fn = self.request.recv(BUF_SIZE)
log.info(f"{self.client_address[0]} wrote: {recv_fn}")
recv_file = self.receive_fw()
try:
tmp_file = recv_fn.decode("utf-8")
except UnicodeDecodeError:
tmp_file = "zephyr.ri.decode_error"
log.info(f'did not receive a correct filename')
if recv_file:
self.request.sendall("success".encode('utf-8'))
log.info("Firmware well received. Ready to download.")
else:
self.request.sendall("failed".encode('utf-8'))
log.error("Receive firmware failed.")
lock.acquire()
fw_file = tmp_file
ret = self.receive_fw(fw_file)
if not ret:
start_output = True
fw_file = recv_file
start_output = True
lock.release()
log.debug(f'{recv_fn}, {fw_file}, {start_output}')
elif action == CMD_LOG_STOP:
self.request.sendall(cmd)
lock.acquire()
start_output = False
if fw_file:
os.remove(fw_file)
fw_file = None
lock.release()
else:
log.error("incorrect load communitcation!")
@ -732,35 +756,50 @@ class adsp_log_handler(socketserver.BaseRequestHandler):
self.loop.run_until_complete(_main(self))
def handle(self):
global start_output, fw_file
cmd = self.request.recv(BUF_SIZE)
cmd = self.request.recv(MAX_CMD_SZ)
log.info(f"{self.client_address[0]} wrote: {cmd}")
action = cmd.decode("utf-8")
log.debug(f'monitor {action}')
if action == CMD_LOG_START:
global start_output, fw_file
self.request.sendall(cmd)
log.info(f"Waiting for instruction...")
while start_output is False:
time.sleep(1)
if not is_connection_alive(self):
break
log.info(f"Loaded FW {fw_file} and running...")
if os.path.exists(fw_file):
self.run_adsp()
self.request.sendall("service complete.".encode())
log.info("service complete.")
else:
log.error("cannot find the FW file")
if fw_file:
log.info(f"Loaded FW {fw_file} and running...")
if os.path.exists(fw_file):
self.run_adsp()
log.info("service complete.")
else:
log.error("Cannot find the FW file.")
lock.acquire()
fw_file = None
start_output = False
if fw_file:
os.remove(fw_file)
fw_file = None
lock.release()
else:
log.error("incorrect monitor communitcation!")
log.info("Wait for next service...")
def is_connection_alive(server):
try:
server.request.sendall(b' ')
except (BrokenPipeError, ConnectionResetError):
log.info("Client is disconnect.")
return False
return True
def adsp_log(output, server):
if server:
@ -794,7 +833,7 @@ ap.add_argument("-l", "--log-only", action="store_true",
ap.add_argument("-n", "--no-history", action="store_true",
help="No current log buffer at start, just new output")
ap.add_argument("-s", "--server-addr",
help="No current log buffer at start, just new output")
help="Specify the IP address that the server to active")
ap.add_argument("fw_file", nargs="?", help="Firmware file")
args = ap.parse_args()