soc: xtensa: partial fix of socket misuse and refine the code

1. Improve the firmware transfer reliability by fixing the misuse
of the socket. Fix the most frequent occurence of the common `recv()`
bug described here:

https://docs.python.org/3/howto/sockets.html#using-a-socket

The longer term fix is to switch to a higher level API like Python
Remote Objects.

2. Not rely on the client's command to disconnect. Previously we
rely on the SIGINT to send stop_command to the server, but it does
not work well in some environments. Refine the whole logic and the
sever disconnect service by checking if the client is alive or not.

These changes make the client-server-based cavstool more stable.

Fixes #46864

Signed-off-by: Enjia Mai <enjia.mai@intel.com>
This commit is contained in:
Enjia Mai 2022-06-26 03:51:36 +08:00 committed by Anas Nashif
commit faff3f7ecc
2 changed files with 135 additions and 76 deletions

View file

@ -14,6 +14,7 @@ import argparse
import socketserver import socketserver
import threading import threading
import netifaces import netifaces
import hashlib
# Global variable use to sync between log and request services. # Global variable use to sync between log and request services.
# When it is true, the adsp is able to start running. # When it is true, the adsp is able to start running.
@ -24,9 +25,17 @@ HOST = None
PORT_LOG = 9999 PORT_LOG = 9999
PORT_REQ = PORT_LOG + 1 PORT_REQ = PORT_LOG + 1
BUF_SIZE = 4096 BUF_SIZE = 4096
# Define the command and the max size
CMD_LOG_START = "start_log" CMD_LOG_START = "start_log"
CMD_LOG_STOP = "stop_log"
CMD_DOWNLOAD = "download" CMD_DOWNLOAD = "download"
MAX_CMD_SZ = 16
# Define the header format and size for
# transmiting the firmware
PACKET_HEADER_FORMAT_FW = 'I 42s 32s'
HEADER_SZ = 78
logging.basicConfig() logging.basicConfig()
log = logging.getLogger("cavs-fw") log = logging.getLogger("cavs-fw")
@ -625,6 +634,7 @@ def ipc_command(data, ext_data):
async def _main(server): async def _main(server):
#TODO this bit me, remove the globals, write a little FirmwareLoader class or something to contain. #TODO this bit me, remove the globals, write a little FirmwareLoader class or something to contain.
global hda, sd, dsp, hda_ostream_id, hda_streams global hda, sd, dsp, hda_ostream_id, hda_streams
global start_output
try: try:
(hda, sd, dsp, hda_ostream_id) = map_regs() (hda, sd, dsp, hda_ostream_id) = map_regs()
except Exception as e: except Exception as e:
@ -659,67 +669,81 @@ async def _main(server):
if dsp.HIPCTDR & 0x80000000: if dsp.HIPCTDR & 0x80000000:
ipc_command(dsp.HIPCTDR & ~0x80000000, dsp.HIPCTDD) ipc_command(dsp.HIPCTDR & ~0x80000000, dsp.HIPCTDD)
if server:
# Check if the client connection is alive.
if not is_connection_alive(server):
lock.acquire()
start_output = False
lock.release()
class adsp_request_handler(socketserver.BaseRequestHandler): class adsp_request_handler(socketserver.BaseRequestHandler):
""" """
The request handler class for control the actions of server. The request handler class for control the actions of server.
""" """
def receive_fw(self, filename): def receive_fw(self):
try:
with open(fw_file,'wb') as f:
cnt = 0
log.info("Receiving...") log.info("Receiving...")
# Receive the header first
d = self.request.recv(HEADER_SZ)
while True: # Unpacked the header data
l = self.request.recv(BUF_SIZE) # Include size(4), filename(42) and MD5(32)
ll = len(l) header = d[:HEADER_SZ]
cnt = cnt + ll total = d[HEADER_SZ:]
if not l: s = struct.Struct(PACKET_HEADER_FORMAT_FW)
break fsize, fname, md5_tx_b = s.unpack(header)
else: log.info(f'size:{fsize}, filename:{fname}, MD5:{md5_tx_b}')
f.write(l)
# Receive the firmware. We only receive the specified amount of bytes.
while len(total) < fsize:
data = self.request.recv(min(BUF_SIZE, fsize - len(total)))
if not data:
raise EOFError("truncated firmware file")
total += data
log.info(f"Done Receiving {len(total)}.")
try:
with open(fname,'wb') as f:
f.write(total)
except Exception as e: except Exception as e:
log.error(f"Get exception {e} during FW transfer.") log.error(f"Get exception {e} during FW transfer.")
return 1 return None
log.info(f"Done Receiving {cnt}.") # Check the MD5 of the firmware
md5_rx = hashlib.md5(total).hexdigest()
md5_tx = md5_tx_b.decode('utf-8')
if md5_tx != md5_rx:
log.error(f'MD5 mismatch: {md5_tx} vs. {md5_rx}')
return None
return fname
def handle(self): def handle(self):
global start_output, fw_file global start_output, fw_file
cmd = self.request.recv(BUF_SIZE) cmd = self.request.recv(MAX_CMD_SZ)
log.info(f"{self.client_address[0]} wrote: {cmd}") log.info(f"{self.client_address[0]} wrote: {cmd}")
action = cmd.decode("utf-8") action = cmd.decode("utf-8")
log.debug(f'load {action}') log.debug(f'load {action}')
if action == CMD_DOWNLOAD: if action == CMD_DOWNLOAD:
self.request.sendall(cmd) self.request.sendall(cmd)
recv_fn = self.request.recv(BUF_SIZE) recv_file = self.receive_fw()
log.info(f"{self.client_address[0]} wrote: {recv_fn}")
try: if recv_file:
tmp_file = recv_fn.decode("utf-8") self.request.sendall("success".encode('utf-8'))
except UnicodeDecodeError: log.info("Firmware well received. Ready to download.")
tmp_file = "zephyr.ri.decode_error" else:
log.info(f'did not receive a correct filename') self.request.sendall("failed".encode('utf-8'))
log.error("Receive firmware failed.")
lock.acquire() lock.acquire()
fw_file = tmp_file fw_file = recv_file
ret = self.receive_fw(fw_file)
if not ret:
start_output = True start_output = True
lock.release() lock.release()
log.debug(f'{recv_fn}, {fw_file}, {start_output}')
elif action == CMD_LOG_STOP:
self.request.sendall(cmd)
lock.acquire()
start_output = False
if fw_file:
os.remove(fw_file)
fw_file = None
lock.release()
else: else:
log.error("incorrect load communitcation!") log.error("incorrect load communitcation!")
@ -732,35 +756,50 @@ class adsp_log_handler(socketserver.BaseRequestHandler):
self.loop.run_until_complete(_main(self)) self.loop.run_until_complete(_main(self))
def handle(self): def handle(self):
global start_output, fw_file cmd = self.request.recv(MAX_CMD_SZ)
cmd = self.request.recv(BUF_SIZE)
log.info(f"{self.client_address[0]} wrote: {cmd}") log.info(f"{self.client_address[0]} wrote: {cmd}")
action = cmd.decode("utf-8") action = cmd.decode("utf-8")
log.debug(f'monitor {action}') log.debug(f'monitor {action}')
if action == CMD_LOG_START: if action == CMD_LOG_START:
global start_output, fw_file
self.request.sendall(cmd) self.request.sendall(cmd)
log.info(f"Waiting for instruction...") log.info(f"Waiting for instruction...")
while start_output is False: while start_output is False:
time.sleep(1) time.sleep(1)
if not is_connection_alive(self):
break
if fw_file:
log.info(f"Loaded FW {fw_file} and running...") log.info(f"Loaded FW {fw_file} and running...")
if os.path.exists(fw_file): if os.path.exists(fw_file):
self.run_adsp() self.run_adsp()
self.request.sendall("service complete.".encode())
log.info("service complete.") log.info("service complete.")
else: else:
log.error("cannot find the FW file") log.error("Cannot find the FW file.")
lock.acquire() lock.acquire()
fw_file = None
start_output = False start_output = False
if fw_file:
os.remove(fw_file)
fw_file = None
lock.release() lock.release()
else: else:
log.error("incorrect monitor communitcation!") log.error("incorrect monitor communitcation!")
log.info("Wait for next service...")
def is_connection_alive(server):
try:
server.request.sendall(b' ')
except (BrokenPipeError, ConnectionResetError):
log.info("Client is disconnect.")
return False
return True
def adsp_log(output, server): def adsp_log(output, server):
if server: if server:
@ -794,7 +833,7 @@ ap.add_argument("-l", "--log-only", action="store_true",
ap.add_argument("-n", "--no-history", action="store_true", ap.add_argument("-n", "--no-history", action="store_true",
help="No current log buffer at start, just new output") help="No current log buffer at start, just new output")
ap.add_argument("-s", "--server-addr", ap.add_argument("-s", "--server-addr",
help="No current log buffer at start, just new output") help="Specify the IP address that the server to active")
ap.add_argument("fw_file", nargs="?", help="Firmware file") ap.add_argument("fw_file", nargs="?", help="Firmware file")
args = ap.parse_args() args = ap.parse_args()

View file

@ -2,20 +2,29 @@
# Copyright(c) 2022 Intel Corporation. All rights reserved. # Copyright(c) 2022 Intel Corporation. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
import sys
import logging import logging
import time import time
import argparse import argparse
import socket import socket
import signal import struct
import hashlib
RET = 0
HOST = None HOST = None
PORT_LOG = 9999 PORT_LOG = 9999
PORT_REQ = PORT_LOG + 1 PORT_REQ = PORT_LOG + 1
BUF_SIZE = 4096 BUF_SIZE = 4096
# Define the command and its
# possible max size
CMD_LOG_START = "start_log" CMD_LOG_START = "start_log"
CMD_LOG_STOP = "stop_log"
CMD_DOWNLOAD = "download" CMD_DOWNLOAD = "download"
MAX_CMD_SZ = 16
# Define the header format and size for
# transmiting the firmware
PACKET_HEADER_FORMAT_FW = 'I 42s 32s'
logging.basicConfig() logging.basicConfig()
log = logging.getLogger("cavs-client") log = logging.getLogger("cavs-client")
@ -36,32 +45,52 @@ class cavstool_client():
self.sock.connect((self.host, self.port)) self.sock.connect((self.host, self.port))
self.sock.sendall(cmd.encode("utf-8")) self.sock.sendall(cmd.encode("utf-8"))
log.info(f"Sent: {cmd}") log.info(f"Sent: {cmd}")
ack = str(self.sock.recv(BUF_SIZE), "utf-8") ack = str(self.sock.recv(MAX_CMD_SZ), "utf-8")
log.info(f"Receive: {ack}") log.info(f"Receive: {ack}")
if ack == CMD_LOG_START: if ack == CMD_LOG_START:
self.monitor_log() self.monitor_log()
elif ack == CMD_LOG_STOP:
log.info(f"Stop output.")
elif ack == CMD_DOWNLOAD: elif ack == CMD_DOWNLOAD:
self.run() self.run()
else: else:
log.error(f"Receive incorrect msg:{ack} expect:{cmd}") log.error(f"Receive incorrect msg:{ack} expect:{cmd}")
def download(self, filename): def uploading(self, filename):
# Send the FW to server # Send the FW to server
fname = os.path.basename(filename)
fsize = os.path.getsize(filename)
md5_tx = hashlib.md5(open(filename,'rb').read()).hexdigest()
# Pack the header and the expecting packed size is 78 bytes.
# The header by convention includes:
# size(4), filename(42), MD5(32)
values = (fsize, fname.encode('utf-8'), md5_tx.encode('utf-8'))
log.info(f'filename:{fname}, size:{fsize}, md5:{md5_tx}')
s = struct.Struct(PACKET_HEADER_FORMAT_FW)
header_data = s.pack(*values)
header_size = s.size
log.info(f'header size: {header_size}')
with open(filename,'rb') as f: with open(filename,'rb') as f:
log.info('Sending...') log.info(f'Sending...')
ret = self.sock.sendfile(f)
log.info(f"Done Sending ({ret}).") total = self.sock.send(header_data)
total += self.sock.sendfile(f)
log.info(f"Done Sending ({total}).")
rck = self.sock.recv(MAX_CMD_SZ).decode("utf-8")
log.info(f"RCK ({rck}).")
if not rck == "success":
global RET
RET = -1
log.error(f"Firmware uploading failed")
def run(self): def run(self):
filename = str(self.args.fw_file) filename = str(self.args.fw_file)
send_fn = os.path.basename(filename) self.uploading(filename)
self.sock.sendall(send_fn.encode("utf-8"))
log.info(f"Sent fw: {send_fn}, {filename}")
self.download(filename)
def monitor_log(self): def monitor_log(self):
log.info(f"Start to monitor log output...") log.info(f"Start to monitor log output...")
@ -75,28 +104,19 @@ class cavstool_client():
def __del__(self): def __del__(self):
self.sock.close() self.sock.close()
def cleanup():
client = cavstool_client(HOST, PORT_REQ, args)
client.send_cmd(CMD_LOG_STOP)
def main(): def main():
if args.log_only: if args.log_only:
log.info("Monitor process") log.info("Monitor process")
signal.signal(signal.SIGTERM, cleanup)
try: try:
client = cavstool_client(HOST, PORT_LOG, args) client = cavstool_client(HOST, PORT_LOG, args)
client.send_cmd(CMD_LOG_START) client.send_cmd(CMD_LOG_START)
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
finally:
cleanup()
elif args.kill_cmd:
log.info("Stop monitor log")
cleanup()
else: else:
log.info("Download process") log.info("Uploading process")
client = cavstool_client(HOST, PORT_REQ, args) client = cavstool_client(HOST, PORT_REQ, args)
client.send_cmd(CMD_DOWNLOAD) client.send_cmd(CMD_DOWNLOAD)
@ -107,8 +127,6 @@ ap.add_argument("-l", "--log-only", action="store_true",
help="Don't load firmware, just show log output") help="Don't load firmware, just show log output")
ap.add_argument("-s", "--server-addr", default="localhost", ap.add_argument("-s", "--server-addr", default="localhost",
help="Specify the adsp server address") help="Specify the adsp server address")
ap.add_argument("-k", "--kill-cmd", action="store_true",
help="No current log buffer at start, just new output")
ap.add_argument("fw_file", nargs="?", help="Firmware file") ap.add_argument("fw_file", nargs="?", help="Firmware file")
args = ap.parse_args() args = ap.parse_args()
@ -119,3 +137,5 @@ HOST = args.server_addr
if __name__ == "__main__": if __name__ == "__main__":
main() main()
sys.exit(RET)