soc: xtensa: partial fix of socket misuse and refine the code
1. Improve the firmware transfer reliability by fixing the misuse of the socket. Fix the most frequent occurence of the common `recv()` bug described here: https://docs.python.org/3/howto/sockets.html#using-a-socket The longer term fix is to switch to a higher level API like Python Remote Objects. 2. Not rely on the client's command to disconnect. Previously we rely on the SIGINT to send stop_command to the server, but it does not work well in some environments. Refine the whole logic and the sever disconnect service by checking if the client is alive or not. These changes make the client-server-based cavstool more stable. Fixes #46864 Signed-off-by: Enjia Mai <enjia.mai@intel.com>
This commit is contained in:
parent
fe9526313b
commit
faff3f7ecc
2 changed files with 135 additions and 76 deletions
|
@ -14,6 +14,7 @@ import argparse
|
||||||
import socketserver
|
import socketserver
|
||||||
import threading
|
import threading
|
||||||
import netifaces
|
import netifaces
|
||||||
|
import hashlib
|
||||||
|
|
||||||
# Global variable use to sync between log and request services.
|
# Global variable use to sync between log and request services.
|
||||||
# When it is true, the adsp is able to start running.
|
# When it is true, the adsp is able to start running.
|
||||||
|
@ -24,9 +25,17 @@ HOST = None
|
||||||
PORT_LOG = 9999
|
PORT_LOG = 9999
|
||||||
PORT_REQ = PORT_LOG + 1
|
PORT_REQ = PORT_LOG + 1
|
||||||
BUF_SIZE = 4096
|
BUF_SIZE = 4096
|
||||||
|
|
||||||
|
# Define the command and the max size
|
||||||
CMD_LOG_START = "start_log"
|
CMD_LOG_START = "start_log"
|
||||||
CMD_LOG_STOP = "stop_log"
|
|
||||||
CMD_DOWNLOAD = "download"
|
CMD_DOWNLOAD = "download"
|
||||||
|
MAX_CMD_SZ = 16
|
||||||
|
|
||||||
|
# Define the header format and size for
|
||||||
|
# transmiting the firmware
|
||||||
|
PACKET_HEADER_FORMAT_FW = 'I 42s 32s'
|
||||||
|
HEADER_SZ = 78
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig()
|
logging.basicConfig()
|
||||||
log = logging.getLogger("cavs-fw")
|
log = logging.getLogger("cavs-fw")
|
||||||
|
@ -625,6 +634,7 @@ def ipc_command(data, ext_data):
|
||||||
async def _main(server):
|
async def _main(server):
|
||||||
#TODO this bit me, remove the globals, write a little FirmwareLoader class or something to contain.
|
#TODO this bit me, remove the globals, write a little FirmwareLoader class or something to contain.
|
||||||
global hda, sd, dsp, hda_ostream_id, hda_streams
|
global hda, sd, dsp, hda_ostream_id, hda_streams
|
||||||
|
global start_output
|
||||||
try:
|
try:
|
||||||
(hda, sd, dsp, hda_ostream_id) = map_regs()
|
(hda, sd, dsp, hda_ostream_id) = map_regs()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -659,67 +669,81 @@ async def _main(server):
|
||||||
if dsp.HIPCTDR & 0x80000000:
|
if dsp.HIPCTDR & 0x80000000:
|
||||||
ipc_command(dsp.HIPCTDR & ~0x80000000, dsp.HIPCTDD)
|
ipc_command(dsp.HIPCTDR & ~0x80000000, dsp.HIPCTDD)
|
||||||
|
|
||||||
|
if server:
|
||||||
|
# Check if the client connection is alive.
|
||||||
|
if not is_connection_alive(server):
|
||||||
|
lock.acquire()
|
||||||
|
start_output = False
|
||||||
|
lock.release()
|
||||||
|
|
||||||
class adsp_request_handler(socketserver.BaseRequestHandler):
|
class adsp_request_handler(socketserver.BaseRequestHandler):
|
||||||
"""
|
"""
|
||||||
The request handler class for control the actions of server.
|
The request handler class for control the actions of server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def receive_fw(self, filename):
|
def receive_fw(self):
|
||||||
try:
|
log.info("Receiving...")
|
||||||
with open(fw_file,'wb') as f:
|
# Receive the header first
|
||||||
cnt = 0
|
d = self.request.recv(HEADER_SZ)
|
||||||
log.info("Receiving...")
|
|
||||||
|
|
||||||
while True:
|
# Unpacked the header data
|
||||||
l = self.request.recv(BUF_SIZE)
|
# Include size(4), filename(42) and MD5(32)
|
||||||
ll = len(l)
|
header = d[:HEADER_SZ]
|
||||||
cnt = cnt + ll
|
total = d[HEADER_SZ:]
|
||||||
if not l:
|
s = struct.Struct(PACKET_HEADER_FORMAT_FW)
|
||||||
break
|
fsize, fname, md5_tx_b = s.unpack(header)
|
||||||
else:
|
log.info(f'size:{fsize}, filename:{fname}, MD5:{md5_tx_b}')
|
||||||
f.write(l)
|
|
||||||
|
# Receive the firmware. We only receive the specified amount of bytes.
|
||||||
|
while len(total) < fsize:
|
||||||
|
data = self.request.recv(min(BUF_SIZE, fsize - len(total)))
|
||||||
|
if not data:
|
||||||
|
raise EOFError("truncated firmware file")
|
||||||
|
total += data
|
||||||
|
|
||||||
|
log.info(f"Done Receiving {len(total)}.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(fname,'wb') as f:
|
||||||
|
f.write(total)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Get exception {e} during FW transfer.")
|
log.error(f"Get exception {e} during FW transfer.")
|
||||||
return 1
|
return None
|
||||||
|
|
||||||
log.info(f"Done Receiving {cnt}.")
|
# Check the MD5 of the firmware
|
||||||
|
md5_rx = hashlib.md5(total).hexdigest()
|
||||||
|
md5_tx = md5_tx_b.decode('utf-8')
|
||||||
|
|
||||||
|
if md5_tx != md5_rx:
|
||||||
|
log.error(f'MD5 mismatch: {md5_tx} vs. {md5_rx}')
|
||||||
|
return None
|
||||||
|
|
||||||
|
return fname
|
||||||
|
|
||||||
def handle(self):
|
def handle(self):
|
||||||
global start_output, fw_file
|
global start_output, fw_file
|
||||||
|
|
||||||
cmd = self.request.recv(BUF_SIZE)
|
cmd = self.request.recv(MAX_CMD_SZ)
|
||||||
log.info(f"{self.client_address[0]} wrote: {cmd}")
|
log.info(f"{self.client_address[0]} wrote: {cmd}")
|
||||||
action = cmd.decode("utf-8")
|
action = cmd.decode("utf-8")
|
||||||
log.debug(f'load {action}')
|
log.debug(f'load {action}')
|
||||||
|
|
||||||
if action == CMD_DOWNLOAD:
|
if action == CMD_DOWNLOAD:
|
||||||
self.request.sendall(cmd)
|
self.request.sendall(cmd)
|
||||||
recv_fn = self.request.recv(BUF_SIZE)
|
recv_file = self.receive_fw()
|
||||||
log.info(f"{self.client_address[0]} wrote: {recv_fn}")
|
|
||||||
|
|
||||||
try:
|
if recv_file:
|
||||||
tmp_file = recv_fn.decode("utf-8")
|
self.request.sendall("success".encode('utf-8'))
|
||||||
except UnicodeDecodeError:
|
log.info("Firmware well received. Ready to download.")
|
||||||
tmp_file = "zephyr.ri.decode_error"
|
else:
|
||||||
log.info(f'did not receive a correct filename')
|
self.request.sendall("failed".encode('utf-8'))
|
||||||
|
log.error("Receive firmware failed.")
|
||||||
|
|
||||||
lock.acquire()
|
lock.acquire()
|
||||||
fw_file = tmp_file
|
fw_file = recv_file
|
||||||
ret = self.receive_fw(fw_file)
|
start_output = True
|
||||||
if not ret:
|
|
||||||
start_output = True
|
|
||||||
lock.release()
|
lock.release()
|
||||||
|
|
||||||
log.debug(f'{recv_fn}, {fw_file}, {start_output}')
|
|
||||||
|
|
||||||
elif action == CMD_LOG_STOP:
|
|
||||||
self.request.sendall(cmd)
|
|
||||||
lock.acquire()
|
|
||||||
start_output = False
|
|
||||||
if fw_file:
|
|
||||||
os.remove(fw_file)
|
|
||||||
fw_file = None
|
|
||||||
lock.release()
|
|
||||||
else:
|
else:
|
||||||
log.error("incorrect load communitcation!")
|
log.error("incorrect load communitcation!")
|
||||||
|
|
||||||
|
@ -732,35 +756,50 @@ class adsp_log_handler(socketserver.BaseRequestHandler):
|
||||||
self.loop.run_until_complete(_main(self))
|
self.loop.run_until_complete(_main(self))
|
||||||
|
|
||||||
def handle(self):
|
def handle(self):
|
||||||
global start_output, fw_file
|
cmd = self.request.recv(MAX_CMD_SZ)
|
||||||
|
|
||||||
cmd = self.request.recv(BUF_SIZE)
|
|
||||||
log.info(f"{self.client_address[0]} wrote: {cmd}")
|
log.info(f"{self.client_address[0]} wrote: {cmd}")
|
||||||
action = cmd.decode("utf-8")
|
action = cmd.decode("utf-8")
|
||||||
log.debug(f'monitor {action}')
|
log.debug(f'monitor {action}')
|
||||||
|
|
||||||
if action == CMD_LOG_START:
|
if action == CMD_LOG_START:
|
||||||
|
global start_output, fw_file
|
||||||
|
|
||||||
self.request.sendall(cmd)
|
self.request.sendall(cmd)
|
||||||
|
|
||||||
log.info(f"Waiting for instruction...")
|
log.info(f"Waiting for instruction...")
|
||||||
while start_output is False:
|
while start_output is False:
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
if not is_connection_alive(self):
|
||||||
|
break
|
||||||
|
|
||||||
log.info(f"Loaded FW {fw_file} and running...")
|
if fw_file:
|
||||||
if os.path.exists(fw_file):
|
log.info(f"Loaded FW {fw_file} and running...")
|
||||||
self.run_adsp()
|
if os.path.exists(fw_file):
|
||||||
self.request.sendall("service complete.".encode())
|
self.run_adsp()
|
||||||
log.info("service complete.")
|
log.info("service complete.")
|
||||||
else:
|
else:
|
||||||
log.error("cannot find the FW file")
|
log.error("Cannot find the FW file.")
|
||||||
|
|
||||||
lock.acquire()
|
lock.acquire()
|
||||||
fw_file = None
|
|
||||||
start_output = False
|
start_output = False
|
||||||
|
if fw_file:
|
||||||
|
os.remove(fw_file)
|
||||||
|
fw_file = None
|
||||||
lock.release()
|
lock.release()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
log.error("incorrect monitor communitcation!")
|
log.error("incorrect monitor communitcation!")
|
||||||
|
|
||||||
|
log.info("Wait for next service...")
|
||||||
|
|
||||||
|
def is_connection_alive(server):
|
||||||
|
try:
|
||||||
|
server.request.sendall(b' ')
|
||||||
|
except (BrokenPipeError, ConnectionResetError):
|
||||||
|
log.info("Client is disconnect.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def adsp_log(output, server):
|
def adsp_log(output, server):
|
||||||
if server:
|
if server:
|
||||||
|
@ -794,7 +833,7 @@ ap.add_argument("-l", "--log-only", action="store_true",
|
||||||
ap.add_argument("-n", "--no-history", action="store_true",
|
ap.add_argument("-n", "--no-history", action="store_true",
|
||||||
help="No current log buffer at start, just new output")
|
help="No current log buffer at start, just new output")
|
||||||
ap.add_argument("-s", "--server-addr",
|
ap.add_argument("-s", "--server-addr",
|
||||||
help="No current log buffer at start, just new output")
|
help="Specify the IP address that the server to active")
|
||||||
ap.add_argument("fw_file", nargs="?", help="Firmware file")
|
ap.add_argument("fw_file", nargs="?", help="Firmware file")
|
||||||
|
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
|
|
|
@ -2,20 +2,29 @@
|
||||||
# Copyright(c) 2022 Intel Corporation. All rights reserved.
|
# Copyright(c) 2022 Intel Corporation. All rights reserved.
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
import socket
|
import socket
|
||||||
import signal
|
import struct
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
RET = 0
|
||||||
HOST = None
|
HOST = None
|
||||||
PORT_LOG = 9999
|
PORT_LOG = 9999
|
||||||
PORT_REQ = PORT_LOG + 1
|
PORT_REQ = PORT_LOG + 1
|
||||||
BUF_SIZE = 4096
|
BUF_SIZE = 4096
|
||||||
|
|
||||||
|
# Define the command and its
|
||||||
|
# possible max size
|
||||||
CMD_LOG_START = "start_log"
|
CMD_LOG_START = "start_log"
|
||||||
CMD_LOG_STOP = "stop_log"
|
|
||||||
CMD_DOWNLOAD = "download"
|
CMD_DOWNLOAD = "download"
|
||||||
|
MAX_CMD_SZ = 16
|
||||||
|
|
||||||
|
# Define the header format and size for
|
||||||
|
# transmiting the firmware
|
||||||
|
PACKET_HEADER_FORMAT_FW = 'I 42s 32s'
|
||||||
|
|
||||||
logging.basicConfig()
|
logging.basicConfig()
|
||||||
log = logging.getLogger("cavs-client")
|
log = logging.getLogger("cavs-client")
|
||||||
|
@ -36,32 +45,52 @@ class cavstool_client():
|
||||||
self.sock.connect((self.host, self.port))
|
self.sock.connect((self.host, self.port))
|
||||||
self.sock.sendall(cmd.encode("utf-8"))
|
self.sock.sendall(cmd.encode("utf-8"))
|
||||||
log.info(f"Sent: {cmd}")
|
log.info(f"Sent: {cmd}")
|
||||||
ack = str(self.sock.recv(BUF_SIZE), "utf-8")
|
ack = str(self.sock.recv(MAX_CMD_SZ), "utf-8")
|
||||||
log.info(f"Receive: {ack}")
|
log.info(f"Receive: {ack}")
|
||||||
|
|
||||||
if ack == CMD_LOG_START:
|
if ack == CMD_LOG_START:
|
||||||
self.monitor_log()
|
self.monitor_log()
|
||||||
elif ack == CMD_LOG_STOP:
|
|
||||||
log.info(f"Stop output.")
|
|
||||||
elif ack == CMD_DOWNLOAD:
|
elif ack == CMD_DOWNLOAD:
|
||||||
self.run()
|
self.run()
|
||||||
else:
|
else:
|
||||||
log.error(f"Receive incorrect msg:{ack} expect:{cmd}")
|
log.error(f"Receive incorrect msg:{ack} expect:{cmd}")
|
||||||
|
|
||||||
def download(self, filename):
|
def uploading(self, filename):
|
||||||
# Send the FW to server
|
# Send the FW to server
|
||||||
|
fname = os.path.basename(filename)
|
||||||
|
fsize = os.path.getsize(filename)
|
||||||
|
|
||||||
|
md5_tx = hashlib.md5(open(filename,'rb').read()).hexdigest()
|
||||||
|
|
||||||
|
# Pack the header and the expecting packed size is 78 bytes.
|
||||||
|
# The header by convention includes:
|
||||||
|
# size(4), filename(42), MD5(32)
|
||||||
|
values = (fsize, fname.encode('utf-8'), md5_tx.encode('utf-8'))
|
||||||
|
log.info(f'filename:{fname}, size:{fsize}, md5:{md5_tx}')
|
||||||
|
|
||||||
|
s = struct.Struct(PACKET_HEADER_FORMAT_FW)
|
||||||
|
header_data = s.pack(*values)
|
||||||
|
header_size = s.size
|
||||||
|
log.info(f'header size: {header_size}')
|
||||||
|
|
||||||
with open(filename,'rb') as f:
|
with open(filename,'rb') as f:
|
||||||
log.info('Sending...')
|
log.info(f'Sending...')
|
||||||
ret = self.sock.sendfile(f)
|
|
||||||
log.info(f"Done Sending ({ret}).")
|
total = self.sock.send(header_data)
|
||||||
|
total += self.sock.sendfile(f)
|
||||||
|
|
||||||
|
log.info(f"Done Sending ({total}).")
|
||||||
|
|
||||||
|
rck = self.sock.recv(MAX_CMD_SZ).decode("utf-8")
|
||||||
|
log.info(f"RCK ({rck}).")
|
||||||
|
if not rck == "success":
|
||||||
|
global RET
|
||||||
|
RET = -1
|
||||||
|
log.error(f"Firmware uploading failed")
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
filename = str(self.args.fw_file)
|
filename = str(self.args.fw_file)
|
||||||
send_fn = os.path.basename(filename)
|
self.uploading(filename)
|
||||||
self.sock.sendall(send_fn.encode("utf-8"))
|
|
||||||
log.info(f"Sent fw: {send_fn}, {filename}")
|
|
||||||
|
|
||||||
self.download(filename)
|
|
||||||
|
|
||||||
def monitor_log(self):
|
def monitor_log(self):
|
||||||
log.info(f"Start to monitor log output...")
|
log.info(f"Start to monitor log output...")
|
||||||
|
@ -75,28 +104,19 @@ class cavstool_client():
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.sock.close()
|
self.sock.close()
|
||||||
|
|
||||||
def cleanup():
|
|
||||||
client = cavstool_client(HOST, PORT_REQ, args)
|
|
||||||
client.send_cmd(CMD_LOG_STOP)
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if args.log_only:
|
if args.log_only:
|
||||||
log.info("Monitor process")
|
log.info("Monitor process")
|
||||||
signal.signal(signal.SIGTERM, cleanup)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
client = cavstool_client(HOST, PORT_LOG, args)
|
client = cavstool_client(HOST, PORT_LOG, args)
|
||||||
client.send_cmd(CMD_LOG_START)
|
client.send_cmd(CMD_LOG_START)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
finally:
|
|
||||||
cleanup()
|
|
||||||
|
|
||||||
elif args.kill_cmd:
|
|
||||||
log.info("Stop monitor log")
|
|
||||||
cleanup()
|
|
||||||
else:
|
else:
|
||||||
log.info("Download process")
|
log.info("Uploading process")
|
||||||
client = cavstool_client(HOST, PORT_REQ, args)
|
client = cavstool_client(HOST, PORT_REQ, args)
|
||||||
client.send_cmd(CMD_DOWNLOAD)
|
client.send_cmd(CMD_DOWNLOAD)
|
||||||
|
|
||||||
|
@ -107,8 +127,6 @@ ap.add_argument("-l", "--log-only", action="store_true",
|
||||||
help="Don't load firmware, just show log output")
|
help="Don't load firmware, just show log output")
|
||||||
ap.add_argument("-s", "--server-addr", default="localhost",
|
ap.add_argument("-s", "--server-addr", default="localhost",
|
||||||
help="Specify the adsp server address")
|
help="Specify the adsp server address")
|
||||||
ap.add_argument("-k", "--kill-cmd", action="store_true",
|
|
||||||
help="No current log buffer at start, just new output")
|
|
||||||
ap.add_argument("fw_file", nargs="?", help="Firmware file")
|
ap.add_argument("fw_file", nargs="?", help="Firmware file")
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
@ -119,3 +137,5 @@ HOST = args.server_addr
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
sys.exit(RET)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue