net: socket: Add locking to prevent concurrent access

The BSD API calls were not thread safe. Add locking to fix this.

Fixes #27032

Signed-off-by: Jukka Rissanen <jukka.rissanen@linux.intel.com>
This commit is contained in:
Jukka Rissanen 2020-07-23 13:49:06 +03:00 committed by Kumar Gala
commit dde03c6770
5 changed files with 201 additions and 36 deletions

View file

@ -274,6 +274,13 @@ __net_socket struct net_context {
struct k_fifo accept_q;
};
struct {
/** Condition variable used when receiving data */
struct k_condvar recv;
/** Mutex used by condition variable */
struct k_mutex *lock;
} cond;
#endif /* CONFIG_NET_SOCKETS */
#if defined(CONFIG_NET_OFFLOAD)

View file

@ -96,10 +96,15 @@ void *z_get_fd_obj(int fd, const struct fd_op_vtable *vtable, int err);
*
* @param fd File descriptor previously returned by z_reserve_fd()
* @param vtable A pointer to a pointer variable to store the vtable
* @param lock An optional pointer to a pointer variable to store the mutex
* preventing concurrent descriptor access. The lock is not taken,
* it is just returned for the caller to use if necessary. Pass NULL
* if the lock is not needed by the caller.
*
* @return Object pointer or NULL, with errno set
*/
void *z_get_fd_obj_and_vtable(int fd, const struct fd_op_vtable **vtable);
void *z_get_fd_obj_and_vtable(int fd, const struct fd_op_vtable **vtable,
struct k_mutex **lock);
/**
* @brief Call ioctl vmethod on an object using varargs.
@ -141,6 +146,7 @@ enum {
ZFD_IOCTL_POLL_PREPARE,
ZFD_IOCTL_POLL_UPDATE,
ZFD_IOCTL_POLL_OFFLOAD,
ZFD_IOCTL_SET_LOCK,
};
#ifdef __cplusplus

View file

@ -25,6 +25,7 @@ struct fd_entry {
void *obj;
const struct fd_op_vtable *vtable;
atomic_t refcount;
struct k_mutex lock;
};
#ifdef CONFIG_POSIX_API
@ -136,7 +137,8 @@ void *z_get_fd_obj(int fd, const struct fd_op_vtable *vtable, int err)
return entry->obj;
}
void *z_get_fd_obj_and_vtable(int fd, const struct fd_op_vtable **vtable)
void *z_get_fd_obj_and_vtable(int fd, const struct fd_op_vtable **vtable,
struct k_mutex **lock)
{
struct fd_entry *entry;
@ -147,6 +149,10 @@ void *z_get_fd_obj_and_vtable(int fd, const struct fd_op_vtable **vtable)
entry = &fdtable[fd];
*vtable = entry->vtable;
if (lock) {
*lock = &entry->lock;
}
return entry->obj;
}
@ -162,6 +168,7 @@ int z_reserve_fd(void)
(void)z_fd_ref(fd);
fdtable[fd].obj = NULL;
fdtable[fd].vtable = NULL;
k_mutex_init(&fdtable[fd].lock);
}
k_mutex_unlock(&fdtable_lock);
@ -184,6 +191,15 @@ void z_finalize_fd(int fd, void *obj, const struct fd_op_vtable *vtable)
#endif
fdtable[fd].obj = obj;
fdtable[fd].vtable = vtable;
/* Let the object know about the lock just in case it needs it
* for something. For BSD sockets, the lock is used with condition
* variables to avoid keeping the lock for a long period of time.
*/
if (vtable && vtable->ioctl) {
(void)z_fdtable_call_ioctl(vtable, obj, ZFD_IOCTL_SET_LOCK,
&fdtable[fd].lock);
}
}
void z_free_fd(int fd)

View file

@ -32,26 +32,39 @@ LOG_MODULE_REGISTER(net_sock, CONFIG_NET_SOCKETS_LOG_LEVEL);
#define SET_ERRNO(x) \
{ int _err = x; if (_err < 0) { errno = -_err; return -1; } }
#define VTABLE_CALL(fn, sock, ...) \
do { \
const struct socket_op_vtable *vtable; \
void *ctx = get_sock_vtable(sock, &vtable); \
if (ctx == NULL || vtable->fn == NULL) { \
errno = EBADF; \
return -1; \
} \
return vtable->fn(ctx, __VA_ARGS__); \
#define VTABLE_CALL(fn, sock, ...) \
do { \
const struct socket_op_vtable *vtable; \
struct k_mutex *lock; \
void *obj; \
int ret; \
\
obj = get_sock_vtable(sock, &vtable, &lock); \
if (obj == NULL || vtable->fn == NULL) { \
errno = EBADF; \
return -1; \
} \
\
(void)k_mutex_lock(lock, K_FOREVER); \
\
ret = vtable->fn(obj, __VA_ARGS__); \
\
k_mutex_unlock(lock); \
\
return ret; \
} while (0)
const struct socket_op_vtable sock_fd_op_vtable;
static inline void *get_sock_vtable(
int sock, const struct socket_op_vtable **vtable)
static inline void *get_sock_vtable(int sock,
const struct socket_op_vtable **vtable,
struct k_mutex **lock)
{
void *ctx;
ctx = z_get_fd_obj_and_vtable(sock,
(const struct fd_op_vtable **)vtable);
(const struct fd_op_vtable **)vtable,
lock);
#ifdef CONFIG_USERSPACE
if (ctx != NULL && z_is_in_user_syscall()) {
@ -84,7 +97,7 @@ void *z_impl_zsock_get_context_object(int sock)
{
const struct socket_op_vtable *ignored;
return get_sock_vtable(sock, &ignored);
return get_sock_vtable(sock, &ignored, NULL);
}
#ifdef CONFIG_USERSPACE
@ -171,6 +184,11 @@ int zsock_socket_internal(int family, int type, int proto)
/* recv_q and accept_q are in union */
k_fifo_init(&ctx->recv_q);
/* Condition variable is used to avoid keeping lock for a long time
* when waiting data to be received
*/
k_condvar_init(&ctx->cond.recv);
/* TCP context is effectively owned by both application
* and the stack: stack may detect that peer closed/aborted
* connection, but it must not dispose of the context behind
@ -248,19 +266,25 @@ int zsock_close_ctx(struct net_context *ctx)
int z_impl_zsock_close(int sock)
{
const struct socket_op_vtable *vtable;
void *ctx = get_sock_vtable(sock, &vtable);
struct k_mutex *lock;
void *ctx;
int ret;
ctx = get_sock_vtable(sock, &vtable, &lock);
if (ctx == NULL) {
errno = EBADF;
return -1;
}
(void)k_mutex_lock(lock, K_FOREVER);
NET_DBG("close: ctx=%p, fd=%d", ctx, sock);
z_free_fd(sock);
ret = vtable->fd_vtable.close(ctx);
z_free_fd(sock);
k_mutex_unlock(lock);
return ret;
}
@ -307,6 +331,7 @@ static void zsock_accepted_cb(struct net_context *new_ctx,
(void)net_context_recv(new_ctx, zsock_received_cb, K_NO_WAIT,
NULL);
k_fifo_init(&new_ctx->recv_q);
k_condvar_init(&new_ctx->cond.recv);
k_fifo_put(&parent->accept_q, new_ctx);
}
@ -338,6 +363,8 @@ static void zsock_received_cb(struct net_context *ctx,
net_pkt_set_eof(last_pkt, true);
NET_DBG("Set EOF flag on pkt %p", last_pkt);
}
(void)k_condvar_signal(&ctx->cond.recv);
return;
}
@ -351,6 +378,9 @@ static void zsock_received_cb(struct net_context *ctx,
net_pkt_set_rx_stats_tick(pkt, k_cycle_get_32());
k_fifo_put(&ctx->recv_q, pkt);
/* Let reader to wake if it was sleeping */
(void)k_condvar_signal(&ctx->cond.recv);
}
int zsock_bind_ctx(struct net_context *ctx, const struct sockaddr *addr,
@ -925,6 +955,29 @@ void net_socket_update_tc_rx_time(struct net_pkt *pkt, uint32_t end_tick)
}
}
static int wait_data(struct net_context *ctx, k_timeout_t *timeout)
{
if (ctx->cond.lock == NULL) {
/* For some reason the lock pointer is not set properly
* when called by fdtable.c:z_finalize_fd()
* It is not practical to try to figure out the fdtable
* lock at this point so skip it.
*/
NET_WARN("No lock pointer set for context %p", ctx);
} else if (!k_fifo_peek_head(&ctx->recv_q)) {
int ret;
/* Wait for the data to arrive but without holding a lock */
ret = k_condvar_wait(&ctx->cond.recv, ctx->cond.lock, *timeout);
if (ret < 0) {
return ret;
}
}
return 0;
}
static inline ssize_t zsock_recv_dgram(struct net_context *ctx,
void *buf,
size_t max_len,
@ -941,7 +994,15 @@ static inline ssize_t zsock_recv_dgram(struct net_context *ctx,
if ((flags & ZSOCK_MSG_DONTWAIT) || sock_is_nonblock(ctx)) {
timeout = K_NO_WAIT;
} else {
int ret;
net_context_get_option(ctx, NET_OPT_RCVTIMEO, &timeout, NULL);
ret = wait_data(ctx, &timeout);
if (ret < 0) {
errno = -ret;
return -1;
}
}
if (flags & ZSOCK_MSG_PEEK) {
@ -1060,8 +1121,14 @@ static inline ssize_t zsock_recv_stream(struct net_context *ctx,
if ((flags & ZSOCK_MSG_DONTWAIT) || sock_is_nonblock(ctx)) {
timeout = K_NO_WAIT;
} else {
} else if (!sock_is_eof(ctx)) {
net_context_get_option(ctx, NET_OPT_RCVTIMEO, &timeout, NULL);
res = wait_data(ctx, &timeout);
if (res < 0) {
errno = -res;
return -1;
}
}
end = sys_clock_timeout_end_calc(timeout);
@ -1222,16 +1289,24 @@ ssize_t z_vrfy_zsock_recvfrom(int sock, void *buf, size_t max_len, int flags,
int z_impl_zsock_fcntl(int sock, int cmd, int flags)
{
const struct socket_op_vtable *vtable;
struct k_mutex *lock;
void *obj;
int ret;
obj = get_sock_vtable(sock, &vtable);
obj = get_sock_vtable(sock, &vtable, &lock);
if (obj == NULL) {
errno = EBADF;
return -1;
}
return z_fdtable_call_ioctl((const struct fd_op_vtable *)vtable,
obj, cmd, flags);
(void)k_mutex_lock(lock, K_FOREVER);
ret = z_fdtable_call_ioctl((const struct fd_op_vtable *)vtable,
obj, cmd, flags);
k_mutex_unlock(lock);
return ret;
}
#ifdef CONFIG_USERSPACE
@ -1311,6 +1386,7 @@ int z_impl_zsock_poll(struct zsock_pollfd *fds, int nfds, int poll_timeout)
struct k_poll_event *pev;
struct k_poll_event *pev_end = poll_events + ARRAY_SIZE(poll_events);
const struct fd_op_vtable *vtable;
struct k_mutex *lock;
k_timeout_t timeout;
uint64_t end;
bool offload = false;
@ -1337,12 +1413,15 @@ int z_impl_zsock_poll(struct zsock_pollfd *fds, int nfds, int poll_timeout)
}
ctx = get_sock_vtable(pfd->fd,
(const struct socket_op_vtable **)&vtable);
(const struct socket_op_vtable **)&vtable,
&lock);
if (ctx == NULL) {
/* Will set POLLNVAL in return loop */
continue;
}
(void)k_mutex_lock(lock, K_FOREVER);
result = z_fdtable_call_ioctl(vtable, ctx,
ZFD_IOCTL_POLL_PREPARE,
pfd, &pev, pev_end);
@ -1353,7 +1432,7 @@ int z_impl_zsock_poll(struct zsock_pollfd *fds, int nfds, int poll_timeout)
* as many events as possible, but without any wait.
*/
timeout = K_NO_WAIT;
continue;
result = 0;
} else if (result == -EXDEV) {
/* If POLL_PREPARE returned EXDEV, it means
* it detected an offloaded socket.
@ -1368,8 +1447,13 @@ int z_impl_zsock_poll(struct zsock_pollfd *fds, int nfds, int poll_timeout)
offl_vtable = vtable;
offl_ctx = ctx;
}
continue;
} else if (result != 0) {
result = 0;
}
k_mutex_unlock(lock);
if (result < 0) {
errno = -result;
return -1;
}
@ -1414,17 +1498,23 @@ int z_impl_zsock_poll(struct zsock_pollfd *fds, int nfds, int poll_timeout)
continue;
}
ctx = get_sock_vtable(pfd->fd,
(const struct socket_op_vtable **)&vtable);
ctx = get_sock_vtable(
pfd->fd,
(const struct socket_op_vtable **)&vtable,
&lock);
if (ctx == NULL) {
pfd->revents = ZSOCK_POLLNVAL;
ret++;
continue;
}
(void)k_mutex_lock(lock, K_FOREVER);
result = z_fdtable_call_ioctl(vtable, ctx,
ZFD_IOCTL_POLL_UPDATE,
pfd, &pev);
k_mutex_unlock(lock);
if (result == -EAGAIN) {
retry = true;
continue;
@ -1907,8 +1997,11 @@ int z_impl_zsock_getsockname(int sock, struct sockaddr *addr,
socklen_t *addrlen)
{
const struct socket_op_vtable *vtable;
void *ctx = get_sock_vtable(sock, &vtable);
struct k_mutex *lock;
void *ctx;
int ret;
ctx = get_sock_vtable(sock, &vtable, &lock);
if (ctx == NULL) {
errno = EBADF;
return -1;
@ -1916,7 +2009,13 @@ int z_impl_zsock_getsockname(int sock, struct sockaddr *addr,
NET_DBG("getsockname: ctx=%p, fd=%d", ctx, sock);
return vtable->getsockname(ctx, addr, addrlen);
(void)k_mutex_lock(lock, K_FOREVER);
ret = vtable->getsockname(ctx, addr, addrlen);
k_mutex_unlock(lock);
return ret;
}
#ifdef CONFIG_USERSPACE
@ -1959,6 +2058,11 @@ static ssize_t sock_write_vmeth(void *obj, const void *buffer, size_t count)
return zsock_sendto_ctx(obj, buffer, count, 0, NULL, 0);
}
static void zsock_ctx_set_lock(struct net_context *ctx, struct k_mutex *lock)
{
ctx->cond.lock = lock;
}
static int sock_ioctl_vmeth(void *obj, unsigned int request, va_list args)
{
switch (request) {
@ -2007,6 +2111,15 @@ static int sock_ioctl_vmeth(void *obj, unsigned int request, va_list args)
return zsock_poll_update_ctx(obj, pfd, pev);
}
case ZFD_IOCTL_SET_LOCK: {
struct k_mutex *lock;
lock = va_arg(args, struct k_mutex *);
zsock_ctx_set_lock(obj, lock);
return 0;
}
default:
errno = EOPNOTSUPP;
return -1;

View file

@ -1957,6 +1957,7 @@ static int ztls_poll_prepare_ctx(struct tls_context *ctx,
struct k_poll_event *pev_end)
{
const struct fd_op_vtable *vtable;
struct k_mutex *lock;
void *obj;
int ret;
short events = pfd->events;
@ -1980,12 +1981,14 @@ static int ztls_poll_prepare_ctx(struct tls_context *ctx,
}
obj = z_get_fd_obj_and_vtable(
ctx->sock, (const struct fd_op_vtable **)&vtable);
ctx->sock, (const struct fd_op_vtable **)&vtable, &lock);
if (obj == NULL) {
ret = -EBADF;
goto exit;
}
(void)k_mutex_lock(lock, K_FOREVER);
ret = z_fdtable_call_ioctl(vtable, obj, ZFD_IOCTL_POLL_PREPARE,
pfd, pev, pev_end);
if (ret != 0) {
@ -1999,6 +2002,9 @@ static int ztls_poll_prepare_ctx(struct tls_context *ctx,
exit:
/* Restore original events. */
pfd->events = events;
k_mutex_unlock(lock);
return ret;
}
@ -2056,16 +2062,19 @@ static int ztls_poll_update_ctx(struct tls_context *ctx,
struct k_poll_event **pev)
{
const struct fd_op_vtable *vtable;
struct k_mutex *lock;
void *obj;
int ret;
short events = pfd->events;
obj = z_get_fd_obj_and_vtable(
ctx->sock, (const struct fd_op_vtable **)&vtable);
ctx->sock, (const struct fd_op_vtable **)&vtable, &lock);
if (obj == NULL) {
return -EBADF;
}
(void)k_mutex_lock(lock, K_FOREVER);
/* Check if the socket was waiting for the handshake to complete. */
if ((pfd->events & ZSOCK_POLLIN) &&
((*pev)->obj == &ctx->tls_established)) {
@ -2077,14 +2086,15 @@ static int ztls_poll_update_ctx(struct tls_context *ctx,
ZFD_IOCTL_POLL_PREPARE,
pfd, pev, *pev + 1);
if (ret != 0 && ret != -EALREADY) {
return ret;
goto out;
}
/* Return -EAGAIN to signal to poll() that it should
* make another iteration with the event reconfigured
* above (if needed).
*/
return -EAGAIN;
ret = -EAGAIN;
goto out;
}
/* Handshake still not ready - skip ZSOCK_POLLIN verification
@ -2110,6 +2120,10 @@ static int ztls_poll_update_ctx(struct tls_context *ctx,
exit:
/* Restore original events. */
pfd->events = events;
out:
k_mutex_unlock(lock);
return ret;
}
@ -2152,7 +2166,8 @@ static inline int ztls_poll_offload(struct zsock_pollfd *fds, int nfds, int time
/* Get offloaded sockets vtable. */
ctx = z_get_fd_obj_and_vtable(fds[0].fd,
(const struct fd_op_vtable **)&vtable);
(const struct fd_op_vtable **)&vtable,
NULL);
if (ctx == NULL) {
errno = EINVAL;
goto exit;
@ -2366,17 +2381,25 @@ static int tls_sock_ioctl_vmeth(void *obj, unsigned int request, va_list args)
case F_GETFL:
case F_SETFL: {
const struct fd_op_vtable *vtable;
struct k_mutex *lock;
void *obj;
int ret;
obj = z_get_fd_obj_and_vtable(ctx->sock,
(const struct fd_op_vtable **)&vtable);
(const struct fd_op_vtable **)&vtable, &lock);
if (obj == NULL) {
errno = EBADF;
return -1;
}
(void)k_mutex_lock(lock, K_FOREVER);
/* Pass the call to the core socket implementation. */
return vtable->ioctl(obj, request, args);
ret = vtable->ioctl(obj, request, args);
k_mutex_unlock(lock);
return ret;
}
case ZFD_IOCTL_POLL_PREPARE: {