diff --git a/drivers/wifi/esp_at/esp.c b/drivers/wifi/esp_at/esp.c index bfc8b448356..6b63d3b3b15 100644 --- a/drivers/wifi/esp_at/esp.c +++ b/drivers/wifi/esp_at/esp.c @@ -996,11 +996,71 @@ out: esp_flags_clear(dev, EDF_STA_CONNECTING); } +static int esp_conn_cmd_append(struct esp_data *data, size_t *off, + const char *chunk, size_t chunk_len) +{ + char *str_end = &data->conn_cmd[sizeof(data->conn_cmd)]; + char *str = &data->conn_cmd[*off]; + const char *chunk_end = chunk + chunk_len; + + for (; chunk < chunk_end; chunk++) { + if (str_end - str < 1) { + return -ENOSPC; + } + + *str = *chunk; + str++; + } + + *off = str - data->conn_cmd; + + return 0; +} + +#define esp_conn_cmd_append_literal(data, off, chunk) \ + esp_conn_cmd_append(data, off, chunk, sizeof(chunk) - 1) + +static int esp_conn_cmd_escape_and_append(struct esp_data *data, size_t *off, + const char *chunk, size_t chunk_len) +{ + char *str_end = &data->conn_cmd[sizeof(data->conn_cmd)]; + char *str = &data->conn_cmd[*off]; + const char *chunk_end = chunk + chunk_len; + + for (; chunk < chunk_end; chunk++) { + switch (*chunk) { + case ',': + case '\\': + case '"': + if (str_end - str < 2) { + return -ENOSPC; + } + + *str = '\\'; + str++; + + break; + } + + if (str_end - str < 1) { + return -ENOSPC; + } + + *str = *chunk; + str++; + } + + *off = str - data->conn_cmd; + + return 0; +} + static int esp_mgmt_connect(const struct device *dev, struct wifi_connect_req_params *params) { struct esp_data *data = dev->data; - int len; + size_t off = 0; + int err; if (!net_if_is_carrier_ok(data->net_iface) || !net_if_is_admin_up(data->net_iface)) { @@ -1013,21 +1073,34 @@ static int esp_mgmt_connect(const struct device *dev, esp_flags_set(data, EDF_STA_CONNECTING); - len = snprintk(data->conn_cmd, sizeof(data->conn_cmd), - "AT+"_CWJAP"=\""); - memcpy(&data->conn_cmd[len], params->ssid, params->ssid_length); - len += params->ssid_length; - - len += snprintk(&data->conn_cmd[len], - sizeof(data->conn_cmd) - len, "\",\""); - - if (params->security == WIFI_SECURITY_TYPE_PSK) { - memcpy(&data->conn_cmd[len], params->psk, params->psk_length); - len += params->psk_length; + err = esp_conn_cmd_append_literal(data, &off, "AT+"_CWJAP"=\""); + if (err) { + return err; } - len += snprintk(&data->conn_cmd[len], sizeof(data->conn_cmd) - len, - "\""); + err = esp_conn_cmd_escape_and_append(data, &off, + params->ssid, params->ssid_length); + if (err) { + return err; + } + + err = esp_conn_cmd_append_literal(data, &off, "\",\""); + if (err) { + return err; + } + + if (params->security == WIFI_SECURITY_TYPE_PSK) { + err = esp_conn_cmd_escape_and_append(data, &off, + params->psk, params->psk_length); + if (err) { + return err; + } + } + + err = esp_conn_cmd_append_literal(data, &off, "\""); + if (err) { + return err; + } k_work_submit_to_queue(&data->workq, &data->connect_work); diff --git a/drivers/wifi/esp_at/esp.h b/drivers/wifi/esp_at/esp.h index bda67cb7ebf..c775bf1f14a 100644 --- a/drivers/wifi/esp_at/esp.h +++ b/drivers/wifi/esp_at/esp.h @@ -77,7 +77,7 @@ extern "C" { STRINGIFY(_UART_BAUD)",8,1,0,"_FLOW_CONTROL #define CONN_CMD_MAX_LEN (sizeof("AT+"_CWJAP"=\"\",\"\"") + \ - WIFI_SSID_MAX_LEN + WIFI_PSK_MAX_LEN) + WIFI_SSID_MAX_LEN * 2 + WIFI_PSK_MAX_LEN * 2) #if defined(CONFIG_WIFI_ESP_AT_DNS_USE) #define ESP_MAX_DNS MIN(3, CONFIG_DNS_RESOLVER_MAX_SERVERS)