net: socket: Implement SO_BINDTODEVICE socket option

Implement SO_BINDTODEVICE socket option which allows to bind an open
socket to a particular network interface. Once bound, the socket will
only send and receive packets through that interface.

For the TX path, simply avoid overwriting the interface pointer by
net_context_bind() in case it's already bound to an interface with an
option. For the RX path, drop the packet in case the connection handler
detects that the net_context associated with that connection is bound to
a different interface that the packet origin interface.

Signed-off-by: Robert Lubos <robert.lubos@nordicsemi.no>
This commit is contained in:
Robert Lubos 2021-03-24 09:46:24 +01:00 committed by Anas Nashif
commit 814fb71bf3
12 changed files with 124 additions and 19 deletions

View file

@ -65,6 +65,9 @@ enum net_context_state {
/** Is the socket closing / closed */ /** Is the socket closing / closed */
#define NET_CONTEXT_CLOSING_SOCK BIT(10) #define NET_CONTEXT_CLOSING_SOCK BIT(10)
/* Context is bound to a specific interface */
#define NET_CONTEXT_BOUND_TO_IFACE BIT(11)
struct net_context; struct net_context;
/** /**
@ -336,6 +339,13 @@ static inline bool net_context_is_used(struct net_context *context)
return context->flags & NET_CONTEXT_IN_USE; return context->flags & NET_CONTEXT_IN_USE;
} }
static inline bool net_context_is_bound_to_iface(struct net_context *context)
{
NET_ASSERT(context);
return context->flags & NET_CONTEXT_BOUND_TO_IFACE;
}
/** /**
* @brief Is this context is accepting data now. * @brief Is this context is accepting data now.
* *

View file

@ -803,6 +803,13 @@ static inline char *inet_ntop(sa_family_t family, const void *src, char *dst,
#define EAI_FAMILY DNS_EAI_FAMILY #define EAI_FAMILY DNS_EAI_FAMILY
#endif /* defined(CONFIG_NET_SOCKETS_POSIX_NAMES) */ #endif /* defined(CONFIG_NET_SOCKETS_POSIX_NAMES) */
#define IFNAMSIZ Z_DEVICE_MAX_NAME_LEN
/** Interface description structure */
struct ifreq {
char ifr_name[IFNAMSIZ]; /* Interface name */
};
/** sockopt: Socket-level option */ /** sockopt: Socket-level option */
#define SOL_SOCKET 1 #define SOL_SOCKET 1
@ -822,6 +829,9 @@ static inline char *inet_ntop(sa_family_t family, const void *src, char *dst,
/** sockopt: Send timeout */ /** sockopt: Send timeout */
#define SO_SNDTIMEO 21 #define SO_SNDTIMEO 21
/** sockopt: Bind a socket to an interface */
#define SO_BINDTODEVICE 25
/** sockopt: Timestamp TX packets */ /** sockopt: Timestamp TX packets */
#define SO_TIMESTAMPING 37 #define SO_TIMESTAMPING 37
/** sockopt: Protocol used with the socket */ /** sockopt: Protocol used with the socket */

View file

@ -245,6 +245,7 @@ int net_conn_register(uint16_t proto, uint8_t family,
const struct sockaddr *local_addr, const struct sockaddr *local_addr,
uint16_t remote_port, uint16_t remote_port,
uint16_t local_port, uint16_t local_port,
struct net_context *context,
net_conn_cb_t cb, net_conn_cb_t cb,
void *user_data, void *user_data,
struct net_conn_handle **handle) struct net_conn_handle **handle)
@ -348,6 +349,7 @@ int net_conn_register(uint16_t proto, uint8_t family,
conn->flags = flags; conn->flags = flags;
conn->proto = proto; conn->proto = proto;
conn->family = family; conn->family = family;
conn->context = context;
if (handle) { if (handle) {
*handle = (struct net_conn_handle *)conn; *handle = (struct net_conn_handle *)conn;
@ -622,6 +624,12 @@ enum net_verdict net_conn_input(struct net_pkt *pkt,
} }
SYS_SLIST_FOR_EACH_CONTAINER(&conn_used, conn, node) { SYS_SLIST_FOR_EACH_CONTAINER(&conn_used, conn, node) {
if (conn->context != NULL &&
net_context_is_bound_to_iface(conn->context) &&
net_pkt_iface(pkt) != net_context_get_iface(conn->context)) {
continue;
}
/* For packet socket data, the proto is set to ETH_P_ALL or IPPROTO_RAW /* For packet socket data, the proto is set to ETH_P_ALL or IPPROTO_RAW
* but the listener might have a specific protocol set. This is ok * but the listener might have a specific protocol set. This is ok
* and let the packet pass this check in this case. * and let the packet pass this check in this case.

View file

@ -17,6 +17,7 @@
#include <sys/util.h> #include <sys/util.h>
#include <net/net_context.h>
#include <net/net_core.h> #include <net/net_core.h>
#include <net/net_ip.h> #include <net/net_ip.h>
#include <net/net_pkt.h> #include <net/net_pkt.h>
@ -59,6 +60,11 @@ struct net_conn {
/** Callback to be called when matching UDP packet is received */ /** Callback to be called when matching UDP packet is received */
net_conn_cb_t cb; net_conn_cb_t cb;
/** A pointer to the net_context corresponding to the connection.
* Can be NULL if no net_context is associated.
*/
struct net_context *context;
/** Possible user to pass to the callback */ /** Possible user to pass to the callback */
void *user_data; void *user_data;
@ -83,6 +89,7 @@ struct net_conn {
* @param remote_port Remote port of the connection end point. * @param remote_port Remote port of the connection end point.
* @param local_port Local port of the connection end point. * @param local_port Local port of the connection end point.
* @param cb Callback to be called * @param cb Callback to be called
* @param context net_context structure related to the connection.
* @param user_data User data supplied by caller. * @param user_data User data supplied by caller.
* @param handle Connection handle that can be used when unregistering * @param handle Connection handle that can be used when unregistering
* *
@ -94,6 +101,7 @@ int net_conn_register(uint16_t proto, uint8_t family,
const struct sockaddr *local_addr, const struct sockaddr *local_addr,
uint16_t remote_port, uint16_t remote_port,
uint16_t local_port, uint16_t local_port,
struct net_context *context,
net_conn_cb_t cb, net_conn_cb_t cb,
void *user_data, void *user_data,
struct net_conn_handle **handle); struct net_conn_handle **handle);
@ -103,6 +111,7 @@ static inline int net_conn_register(uint16_t proto, uint8_t family,
const struct sockaddr *local_addr, const struct sockaddr *local_addr,
uint16_t remote_port, uint16_t remote_port,
uint16_t local_port, uint16_t local_port,
struct net_context *context,
net_conn_cb_t cb, net_conn_cb_t cb,
void *user_data, void *user_data,
struct net_conn_handle **handle) struct net_conn_handle **handle)
@ -114,6 +123,7 @@ static inline int net_conn_register(uint16_t proto, uint8_t family,
ARG_UNUSED(remote_port); ARG_UNUSED(remote_port);
ARG_UNUSED(local_port); ARG_UNUSED(local_port);
ARG_UNUSED(cb); ARG_UNUSED(cb);
ARG_UNUSED(context);
ARG_UNUSED(user_data); ARG_UNUSED(user_data);
ARG_UNUSED(handle); ARG_UNUSED(handle);

View file

@ -1265,7 +1265,7 @@ int net_dhcpv4_init(void)
ret = net_udp_register(AF_INET, NULL, &local_addr, ret = net_udp_register(AF_INET, NULL, &local_addr,
DHCPV4_SERVER_PORT, DHCPV4_SERVER_PORT,
DHCPV4_CLIENT_PORT, DHCPV4_CLIENT_PORT,
net_dhcpv4_input, NULL, NULL); NULL, net_dhcpv4_input, NULL, NULL);
if (ret < 0) { if (ret < 0) {
NET_DBG("UDP callback registration failed"); NET_DBG("UDP callback registration failed");
return ret; return ret;

View file

@ -528,6 +528,10 @@ int net_context_bind(struct net_context *context, const struct sockaddr *addr,
return -EINVAL; return -EINVAL;
} }
if (net_context_is_bound_to_iface(context)) {
iface = net_context_get_iface(context);
}
if (net_ipv6_is_addr_mcast(&addr6->sin6_addr)) { if (net_ipv6_is_addr_mcast(&addr6->sin6_addr)) {
struct net_if_mcast_addr *maddr; struct net_if_mcast_addr *maddr;
@ -540,15 +544,18 @@ int net_context_bind(struct net_context *context, const struct sockaddr *addr,
ptr = &maddr->address.in6_addr; ptr = &maddr->address.in6_addr;
} else if (net_ipv6_is_addr_unspecified(&addr6->sin6_addr)) { } else if (net_ipv6_is_addr_unspecified(&addr6->sin6_addr)) {
iface = net_if_ipv6_select_src_iface( if (iface == NULL) {
&net_sin6(&context->remote)->sin6_addr); iface = net_if_ipv6_select_src_iface(
&net_sin6(&context->remote)->sin6_addr);
}
ptr = (struct in6_addr *)net_ipv6_unspecified_address(); ptr = (struct in6_addr *)net_ipv6_unspecified_address();
} else { } else {
struct net_if_addr *ifaddr; struct net_if_addr *ifaddr;
ifaddr = net_if_ipv6_addr_lookup(&addr6->sin6_addr, ifaddr = net_if_ipv6_addr_lookup(
&iface); &addr6->sin6_addr,
iface == NULL ? &iface : NULL);
if (!ifaddr) { if (!ifaddr) {
return -ENOENT; return -ENOENT;
} }
@ -622,6 +629,10 @@ int net_context_bind(struct net_context *context, const struct sockaddr *addr,
return -EINVAL; return -EINVAL;
} }
if (net_context_is_bound_to_iface(context)) {
iface = net_context_get_iface(context);
}
if (net_ipv4_is_addr_mcast(&addr4->sin_addr)) { if (net_ipv4_is_addr_mcast(&addr4->sin_addr)) {
struct net_if_mcast_addr *maddr; struct net_if_mcast_addr *maddr;
@ -634,13 +645,16 @@ int net_context_bind(struct net_context *context, const struct sockaddr *addr,
ptr = &maddr->address.in_addr; ptr = &maddr->address.in_addr;
} else if (addr4->sin_addr.s_addr == INADDR_ANY) { } else if (addr4->sin_addr.s_addr == INADDR_ANY) {
iface = net_if_ipv4_select_src_iface( if (iface == NULL) {
&net_sin(&context->remote)->sin_addr); iface = net_if_ipv4_select_src_iface(
&net_sin(&context->remote)->sin_addr);
}
ptr = (struct in_addr *)net_ipv4_unspecified_address(); ptr = (struct in_addr *)net_ipv4_unspecified_address();
} else { } else {
ifaddr = net_if_ipv4_addr_lookup(&addr4->sin_addr, ifaddr = net_if_ipv4_addr_lookup(
&iface); &addr4->sin_addr,
iface == NULL ? &iface : NULL);
if (!ifaddr) { if (!ifaddr) {
return -ENOENT; return -ENOENT;
} }
@ -1473,7 +1487,8 @@ static int context_sendto(struct net_context *context,
* second or later network interface. * second or later network interface.
*/ */
if (net_ipv6_is_addr_unspecified( if (net_ipv6_is_addr_unspecified(
&net_sin6(&context->remote)->sin6_addr)) { &net_sin6(&context->remote)->sin6_addr) &&
!net_context_is_bound_to_iface(context)) {
iface = net_if_ipv6_select_src_iface(&addr6->sin6_addr); iface = net_if_ipv6_select_src_iface(&addr6->sin6_addr);
net_context_set_iface(context, iface); net_context_set_iface(context, iface);
} }
@ -1512,7 +1527,8 @@ static int context_sendto(struct net_context *context,
* network interfaces and we are trying to send data to * network interfaces and we are trying to send data to
* second or later network interface. * second or later network interface.
*/ */
if (net_sin(&context->remote)->sin_addr.s_addr == 0U) { if (net_sin(&context->remote)->sin_addr.s_addr == 0U &&
!net_context_is_bound_to_iface(context)) {
iface = net_if_ipv4_select_src_iface(&addr4->sin_addr); iface = net_if_ipv4_select_src_iface(&addr4->sin_addr);
net_context_set_iface(context, iface); net_context_set_iface(context, iface);
} }
@ -1962,6 +1978,7 @@ static int recv_udp(struct net_context *context,
laddr, laddr,
ntohs(net_sin(&context->remote)->sin_port), ntohs(net_sin(&context->remote)->sin_port),
ntohs(lport), ntohs(lport),
context,
net_context_packet_received, net_context_packet_received,
user_data, user_data,
&context->conn_handler); &context->conn_handler);
@ -2029,6 +2046,7 @@ static int recv_raw(struct net_context *context,
ret = net_conn_register(net_context_get_ip_proto(context), ret = net_conn_register(net_context_get_ip_proto(context),
net_context_get_family(context), net_context_get_family(context),
NULL, local_addr, 0, 0, NULL, local_addr, 0, 0,
context,
net_context_raw_packet_received, net_context_raw_packet_received,
user_data, user_data,
&context->conn_handler); &context->conn_handler);

View file

@ -1476,7 +1476,7 @@ static struct tcp *tcp_conn_new(struct net_pkt *pkt)
&context->remote, &local_addr, &context->remote, &local_addr,
ntohs(conn->dst.sin.sin_port),/* local port */ ntohs(conn->dst.sin.sin_port),/* local port */
ntohs(conn->src.sin.sin_port),/* remote port */ ntohs(conn->src.sin.sin_port),/* remote port */
tcp_recv, context, context, tcp_recv, context,
&context->conn_handler); &context->conn_handler);
if (ret < 0) { if (ret < 0) {
NET_ERR("net_conn_register(): %d", ret); NET_ERR("net_conn_register(): %d", ret);
@ -2235,7 +2235,7 @@ int net_tcp_connect(struct net_context *context,
net_context_get_family(context), net_context_get_family(context),
remote_addr, local_addr, remote_addr, local_addr,
ntohs(remote_port), ntohs(local_port), ntohs(remote_port), ntohs(local_port),
tcp_recv, context, context, tcp_recv, context,
&context->conn_handler); &context->conn_handler);
if (ret < 0) { if (ret < 0) {
goto out; goto out;
@ -2335,7 +2335,7 @@ int net_tcp_accept(struct net_context *context, net_tcp_accept_cb_t cb,
&context->remote : NULL, &context->remote : NULL,
&local_addr, &local_addr,
remote_port, local_port, remote_port, local_port,
tcp_recv, context, context, tcp_recv, context,
&context->conn_handler); &context->conn_handler);
} }
@ -2636,6 +2636,7 @@ static void test_cb_register(sa_family_t family, uint8_t proto, uint16_t remote_
&addr, /* local address */ &addr, /* local address */
local_port, local_port,
remote_port, remote_port,
NULL,
cb, cb,
NULL, /* user_data */ NULL, /* user_data */
&conn_handle); &conn_handle);

View file

@ -129,13 +129,14 @@ int net_udp_register(uint8_t family,
const struct sockaddr *local_addr, const struct sockaddr *local_addr,
uint16_t remote_port, uint16_t remote_port,
uint16_t local_port, uint16_t local_port,
struct net_context *context,
net_conn_cb_t cb, net_conn_cb_t cb,
void *user_data, void *user_data,
struct net_conn_handle **handle) struct net_conn_handle **handle)
{ {
return net_conn_register(IPPROTO_UDP, family, remote_addr, local_addr, return net_conn_register(IPPROTO_UDP, family, remote_addr, local_addr,
remote_port, local_port, cb, user_data, remote_port, local_port, context, cb,
handle); user_data, handle);
} }
int net_udp_unregister(struct net_conn_handle *handle) int net_udp_unregister(struct net_conn_handle *handle)

View file

@ -105,6 +105,7 @@ struct net_udp_hdr *net_udp_input(struct net_pkt *pkt,
* @param local_addr Local address of the connection end point. * @param local_addr Local address of the connection end point.
* @param remote_port Remote port of the connection end point. * @param remote_port Remote port of the connection end point.
* @param local_port Local port of the connection end point. * @param local_port Local port of the connection end point.
* @param context net_context structure related to the connection.
* @param cb Callback to be called * @param cb Callback to be called
* @param user_data User data supplied by caller. * @param user_data User data supplied by caller.
* @param handle UDP handle that can be used when unregistering * @param handle UDP handle that can be used when unregistering
@ -116,6 +117,7 @@ int net_udp_register(uint8_t family,
const struct sockaddr *local_addr, const struct sockaddr *local_addr,
uint16_t remote_port, uint16_t remote_port,
uint16_t local_port, uint16_t local_port,
struct net_context *context,
net_conn_cb_t cb, net_conn_cb_t cb,
void *user_data, void *user_data,
struct net_conn_handle **handle); struct net_conn_handle **handle);

View file

@ -1764,6 +1764,50 @@ int zsock_setsockopt_ctx(struct net_context *ctx, int level, int optname,
} }
break; break;
case SO_BINDTODEVICE: {
struct net_if *iface;
const struct device *dev;
struct ifreq *ifreq = (struct ifreq *)optval;
if (net_context_get_family(ctx) != AF_INET &&
net_context_get_family(ctx) != AF_INET6) {
errno = EAFNOSUPPORT;
return -1;
}
/* optlen equal to 0 or empty interface name should
* remove the binding.
*/
if ((optlen == 0) || (ifreq != NULL &&
strlen(ifreq->ifr_name) == 0)) {
ctx->flags &= ~NET_CONTEXT_BOUND_TO_IFACE;
return 0;
}
if ((ifreq == NULL) || (optlen != sizeof(*ifreq))) {
errno = EINVAL;
return -1;
}
dev = device_get_binding(ifreq->ifr_name);
if (dev == NULL) {
errno = ENODEV;
return -1;
}
iface = net_if_lookup_by_dev(dev);
if (iface == NULL) {
errno = ENODEV;
return -1;
}
net_context_set_iface(ctx, iface);
ctx->flags |= NET_CONTEXT_BOUND_TO_IFACE;
return 0;
}
} }
break; break;

View file

@ -1429,7 +1429,7 @@ static void setup_udp_handler(const struct in6_addr *raddr,
remote_addr.sa_family = AF_INET6; remote_addr.sa_family = AF_INET6;
ret = net_udp_register(AF_INET6, &remote_addr, &local_addr, ret = net_udp_register(AF_INET6, &remote_addr, &local_addr,
remote_port, local_port, udp_data_received, remote_port, local_port, NULL, udp_data_received,
NULL, &handle); NULL, &handle);
zassert_equal(ret, 0, "Cannot register UDP handler"); zassert_equal(ret, 0, "Cannot register UDP handler");
} }

View file

@ -509,7 +509,7 @@ void test_udp(void)
(struct sockaddr *)raddr, \ (struct sockaddr *)raddr, \
(struct sockaddr *)laddr, \ (struct sockaddr *)laddr, \
rport, lport, \ rport, lport, \
test_ok, &user_data, \ NULL, test_ok, &user_data, \
&handlers[i]); \ &handlers[i]); \
if (ret) { \ if (ret) { \
printk("UDP register %s failed (%d)\n", \ printk("UDP register %s failed (%d)\n", \
@ -525,7 +525,8 @@ void test_udp(void)
(struct sockaddr *)raddr, \ (struct sockaddr *)raddr, \
(struct sockaddr *)laddr, \ (struct sockaddr *)laddr, \
rport, lport, \ rport, lport, \
test_fail, INT_TO_POINTER(0), NULL); \ NULL, test_fail, INT_TO_POINTER(0), \
NULL); \
if (!ret) { \ if (!ret) { \
printk("UDP register invalid match %s failed\n", \ printk("UDP register invalid match %s failed\n", \
"DST="#raddr"-SRC="#laddr"-RP="#rport"-LP="#lport); \ "DST="#raddr"-SRC="#laddr"-RP="#rport"-LP="#lport); \