diff --git a/samples/net/sockets/big_http_download/Kconfig b/samples/net/sockets/big_http_download/Kconfig index 4f7f5fdbce2..b229074ecda 100644 --- a/samples/net/sockets/big_http_download/Kconfig +++ b/samples/net/sockets/big_http_download/Kconfig @@ -11,6 +11,10 @@ config SAMPLE_BIG_HTTP_DL_URL help URL to download. +config SAMPLE_BIG_HTTP_DL_MAX_URL_LENGTH + int "Maximum length of the download URL" + default 256 + config SAMPLE_BIG_HTTP_DL_NUM_ITER int "Limit number of iterations" default 0 diff --git a/samples/net/sockets/big_http_download/src/big_http_download.c b/samples/net/sockets/big_http_download/src/big_http_download.c index c9ea70ea578..4203de9d2c3 100644 --- a/samples/net/sockets/big_http_download/src/big_http_download.c +++ b/samples/net/sockets/big_http_download/src/big_http_download.c @@ -8,6 +8,7 @@ #include #include #include +#include #include "mbedtls/md.h" @@ -36,8 +37,14 @@ #define bytes2KiB(Bytes) (Bytes / (1024u)) #define bytes2MiB(Bytes) (Bytes / (1024u * 1024u)) +#if defined(CONFIG_SAMPLE_BIG_HTTP_DL_MAX_URL_LENGTH) +#define MAX_URL_LENGTH CONFIG_SAMPLE_BIG_HTTP_DL_MAX_URL_LENGTH +#else +#define MAX_URL_LENGTH 256 +#endif + /* This URL is parsed in-place, so buffer must be non-const. */ -static char download_url[] = +static char download_url[MAX_URL_LENGTH] = #if defined(CONFIG_SAMPLE_BIG_HTTP_DL_URL) CONFIG_SAMPLE_BIG_HTTP_DL_URL; #else @@ -99,30 +106,132 @@ ssize_t sendall(int sock, const void *buf, size_t len) return 0; } -int skip_headers(int sock) +static int parse_status(bool *redirect) +{ + char *ptr; + int code; + + ptr = strstr(response, "HTTP"); + if (ptr == NULL) { + return -1; + } + + ptr = strstr(response, " "); + if (ptr == NULL) { + return -1; + } + + ptr++; + + code = atoi(ptr); + if (code >= 300 && code < 400) { + *redirect = true; + } + + return 0; +} + +static int parse_header(bool *location_found) +{ + char *ptr; + + ptr = strstr(response, ":"); + if (ptr == NULL) { + return 0; + } + + *ptr = '\0'; + ptr = response; + + while (*ptr != '\0') { + *ptr = tolower(*ptr); + ptr++; + } + + if (strcmp(response, "location") != 0) { + return 0; + } + + /* Skip whitespace */ + while (*(++ptr) == ' ') { + ; + } + + strncpy(download_url, ptr, sizeof(download_url)); + download_url[sizeof(download_url) - 1] = '\0'; + + /* Trim LF. */ + ptr = strstr(download_url, "\n"); + if (ptr == NULL) { + printf("Redirect URL too long or malformed\n"); + return -1; + } + + *ptr = '\0'; + + /* Trim CR if present. */ + ptr = strstr(download_url, "\r"); + if (ptr != NULL) { + *ptr = '\0'; + } + + *location_found = true; + + return 0; +} + +int skip_headers(int sock, bool *redirect) { int state = 0; + int i = 0; + bool status_line = true; + bool redirect_code = false; + bool location_found = false; while (1) { - char c; int st; - st = recv(sock, &c, 1, 0); + st = recv(sock, response + i, 1, 0); if (st <= 0) { return st; } - if (state == 0 && c == '\r') { + if (state == 0 && response[i] == '\r') { state++; - } else if (state == 1 && c == '\n') { + } else if ((state == 0 || state == 1) && response[i] == '\n') { + state = 2; + response[i + 1] = '\0'; + i = 0; + + if (status_line) { + if (parse_status(&redirect_code) < 0) { + return -1; + } + + status_line = false; + } else { + if (parse_header(&location_found) < 0) { + return -1; + } + } + + continue; + } else if (state == 2 && response[i] == '\r') { state++; - } else if (state == 2 && c == '\r') { - state++; - } else if (state == 3 && c == '\n') { + } else if ((state == 2 || state == 3) && response[i] == '\n') { break; } else { state = 0; } + + i++; + if (i >= sizeof(response) - 1) { + i = 0; + } + } + + if (redirect_code && location_found) { + *redirect = true; } return 1; @@ -135,7 +244,7 @@ void print_hex(const unsigned char *p, int len) } } -void download(struct addrinfo *ai, bool is_tls) +bool download(struct addrinfo *ai, bool is_tls, bool *redirect) { int sock; struct timeval timeout = { @@ -143,13 +252,14 @@ void download(struct addrinfo *ai, bool is_tls) }; cur_bytes = 0U; + *redirect = false; if (is_tls) { #if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) sock = socket(ai->ai_family, ai->ai_socktype, IPPROTO_TLS_1_2); # else printf("TLS not supported\n"); - return; + return false; #endif } else { sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); @@ -184,11 +294,16 @@ void download(struct addrinfo *ai, bool is_tls) sendall(sock, host, strlen(host)); sendall(sock, "\r\n\r\n", SSTRLEN("\r\n\r\n")); - if (skip_headers(sock) <= 0) { + if (skip_headers(sock, redirect) <= 0) { printf("EOF or error in response headers\n"); goto error; } + if (*redirect) { + printf("Server requested redirection to %s\n", download_url); + goto error; + } + mbedtls_md_starts(&hash_ctx); while (1) { @@ -233,6 +348,8 @@ void download(struct addrinfo *ai, bool is_tls) error: (void)close(sock); + + return redirect; } void main(void) @@ -245,6 +362,7 @@ void main(void) int resolve_attempts = 10; bool is_tls = false; unsigned int num_iterations = CONFIG_SAMPLE_BIG_HTTP_DL_NUM_ITER; + bool redirect = false; #if defined(CONFIG_NET_SOCKETS_SOCKOPT_TLS) tls_credential_add(CA_CERTIFICATE_TAG, TLS_CREDENTIAL_CA_CERTIFICATE, @@ -253,6 +371,7 @@ void main(void) setbuf(stdout, NULL); +redirect: if (strncmp(download_url, "http://", SSTRLEN("http://")) == 0) { port = "80"; p = download_url + SSTRLEN("http://"); @@ -335,7 +454,11 @@ void main(void) printf("\nIteration %u of %u:\n", current_iteration, total_iterations); } - download(res, is_tls); + + download(res, is_tls, &redirect); + if (redirect) { + goto redirect; + } total_bytes += cur_bytes; printf("Total downloaded so far: %u MiB\n",