diff --git a/src/ucs/sys/sock.c b/src/ucs/sys/sock.c index 2810a6c4861..014de33ea6d 100644 --- a/src/ucs/sys/sock.c +++ b/src/ucs/sys/sock.c @@ -106,7 +106,7 @@ ucs_status_t ucs_socket_setopt(int fd, int level, int optname, return UCS_OK; } -static const char *ucs_socket_getname_str(int fd, char *str, size_t max_size) +const char *ucs_socket_getname_str(int fd, char *str, size_t max_size) { struct sockaddr_storage sock_addr = {0}; /* Suppress Clang false-positive */ socklen_t addr_size; diff --git a/src/ucs/sys/sock.h b/src/ucs/sys/sock.h index e8bf35af4d3..a1165030765 100644 --- a/src/ucs/sys/sock.h +++ b/src/ucs/sys/sock.h @@ -336,6 +336,19 @@ const char* ucs_sockaddr_str(const struct sockaddr *sock_addr, char *str, size_t max_size); +/** + * Extract the IP address from a given socket fd and return it as a string. + * + * @param [in] fd Socket fd. + * @param [out] str A string filled with the IP address. + * @param [in] max_size Size of the string buffer (including the terminating '\0' symbol) + * + * @return A string with the IP address of the socket fd if it is valid or 'Invalid address' + * otherwise. + */ +const char *ucs_socket_getname_str(int fd, char *str, size_t max_size); + + /** * Return a value indicating the relationships between passed sockaddr structures. * diff --git a/src/uct/tcp/tcp.h b/src/uct/tcp/tcp.h index c2eb7e78e5b..7f8816c89cc 100644 --- a/src/uct/tcp/tcp.h +++ b/src/uct/tcp/tcp.h @@ -21,6 +21,9 @@ #define UCT_TCP_CONFIG_PREFIX "TCP_" +/* Magic number that is used by TCP to identify its peers */ +#define UCT_TCP_MAGIC_NUMBER 0xCAFEBABE12345678lu + /* Maximum number of events to wait on event set */ #define UCT_TCP_MAX_EVENTS 16 @@ -86,9 +89,13 @@ typedef enum uct_tcp_ep_conn_state { * After it is done, it sends `UCT_TCP_CM_CONN_REQ` to the peer. * All AM operations return `UCS_ERR_NO_RESOURCE` error to a caller. */ UCT_TCP_EP_CONN_STATE_CONNECTING, + /* EP is receiving the magic number in order to verify a peer. 
EP is moved + * to this state after accept() completed. */ + UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER, /* EP is accepting connection from a peer, i.e. accept() returns socket fd * on which a connection was accepted, this EP was created using this socket - * and now it is waiting for `UCT_TCP_CM_CONN_REQ` from a peer. */ + * fd and the magic number was received and verified by EP and now it is + * waiting for `UCT_TCP_CM_CONN_REQ` from a peer. */ UCT_TCP_EP_CONN_STATE_ACCEPTING, /* EP is waiting for `UCT_TCP_CM_CONN_ACK` message from a peer after sending * `UCT_TCP_CM_CONN_REQ`. @@ -144,6 +151,7 @@ KHASH_INIT(uct_tcp_cm_eps, struct sockaddr_in, ucs_list_link_t*, typedef struct uct_tcp_cm_state { const char *name; /* CM state name */ uct_tcp_ep_progress_t tx_progress; /* TX progress function */ + uct_tcp_ep_progress_t rx_progress; /* RX progress function */ } uct_tcp_cm_state_t; @@ -386,10 +394,6 @@ void uct_tcp_iface_remove_ep(uct_tcp_ep_t *ep); ucs_status_t uct_tcp_ep_handle_dropped_connect(uct_tcp_ep_t *ep, int io_errno); -unsigned uct_tcp_ep_progress_am_rx(uct_tcp_ep_t *ep); - -unsigned uct_tcp_ep_progress_put_rx(uct_tcp_ep_t *ep); - ucs_status_t uct_tcp_ep_init(uct_tcp_iface_t *iface, int fd, const struct sockaddr_in *dest_addr, uct_tcp_ep_t **ep_p); @@ -480,20 +484,6 @@ ucs_status_t uct_tcp_cm_handle_incoming_conn(uct_tcp_iface_t *iface, ucs_status_t uct_tcp_cm_conn_start(uct_tcp_ep_t *ep); -static inline unsigned uct_tcp_ep_progress_tx(uct_tcp_ep_t *ep) -{ - return uct_tcp_ep_cm_state[ep->conn_state].tx_progress(ep); -} - -static inline unsigned uct_tcp_ep_progress_rx(uct_tcp_ep_t *ep) -{ - if (!(ep->ctx_caps & UCS_BIT(UCT_TCP_EP_CTX_TYPE_PUT_RX))) { - return uct_tcp_ep_progress_am_rx(ep); - } else { - return uct_tcp_ep_progress_put_rx(ep); - } -} - static inline void uct_tcp_iface_outstanding_inc(uct_tcp_iface_t *iface) { iface->outstanding++; diff --git a/src/uct/tcp/tcp_cm.c b/src/uct/tcp/tcp_cm.c index 428c5a2a5ee..45e357d0f33 100644 --- 
a/src/uct/tcp/tcp_cm.c +++ b/src/uct/tcp/tcp_cm.c @@ -24,8 +24,6 @@ void uct_tcp_cm_change_conn_state(uct_tcp_ep_t *ep, switch(ep->conn_state) { case UCT_TCP_EP_CONN_STATE_CONNECTING: - ucs_assertv(iface->config.conn_nb, "ep=%p", ep); - /* Fall through */ case UCT_TCP_EP_CONN_STATE_WAITING_ACK: if (old_conn_state == UCT_TCP_EP_CONN_STATE_CLOSED) { uct_tcp_iface_outstanding_inc(iface); @@ -61,13 +59,15 @@ void uct_tcp_cm_change_conn_state(uct_tcp_ep_t *ep, (old_conn_state == UCT_TCP_EP_CONN_STATE_WAITING_ACK) || (old_conn_state == UCT_TCP_EP_CONN_STATE_WAITING_REQ)) { uct_tcp_iface_outstanding_dec(iface); - } else if (old_conn_state == UCT_TCP_EP_CONN_STATE_ACCEPTING) { + } else if ((old_conn_state == UCT_TCP_EP_CONN_STATE_ACCEPTING) || + (old_conn_state == UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER)) { /* Since ep::peer_addr is 0'ed, we have to print w/o peer's address */ full_log = 0; } break; default: - ucs_assert(ep->conn_state == UCT_TCP_EP_CONN_STATE_ACCEPTING); + ucs_assert((ep->conn_state == UCT_TCP_EP_CONN_STATE_ACCEPTING) || + (ep->conn_state == UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER)); /* Since ep::peer_addr is 0'ed and client's * has already been logged, print w/o peer's address */ full_log = 0; @@ -143,8 +143,9 @@ static void uct_tcp_cm_trace_conn_pkt(const uct_tcp_ep_t *ep, ucs_status_t uct_tcp_cm_send_event(uct_tcp_ep_t *ep, uct_tcp_cm_conn_event_t event) { - uct_tcp_iface_t *iface = ucs_derived_of(ep->super.super.iface, - uct_tcp_iface_t); + uct_tcp_iface_t *iface = ucs_derived_of(ep->super.super.iface, + uct_tcp_iface_t); + size_t magic_number_length = 0; void *pkt_buf; size_t pkt_length, cm_pkt_length; uct_tcp_cm_conn_req_pkt_t *conn_pkt; @@ -160,20 +161,30 @@ ucs_status_t uct_tcp_cm_send_event(uct_tcp_ep_t *ep, uct_tcp_cm_conn_event_t eve (ep->conn_state != UCT_TCP_EP_CONN_STATE_CONNECTED), "ep=%p", ep); - pkt_length = sizeof(*pkt_hdr); + pkt_length = sizeof(*pkt_hdr); if (event == UCT_TCP_CM_CONN_REQ) { - cm_pkt_length = sizeof(*conn_pkt); + 
cm_pkt_length = sizeof(*conn_pkt); + + if (ep->conn_state == UCT_TCP_EP_CONN_STATE_CONNECTING) { + magic_number_length = sizeof(uint64_t); + } } else { - cm_pkt_length = sizeof(event); + cm_pkt_length = sizeof(event); } - pkt_length += cm_pkt_length; - pkt_buf = ucs_alloca(pkt_length); - pkt_hdr = (uct_tcp_am_hdr_t*)pkt_buf; + pkt_length += cm_pkt_length + magic_number_length; + pkt_buf = ucs_alloca(pkt_length); + pkt_hdr = (uct_tcp_am_hdr_t*)(UCS_PTR_BYTE_OFFSET(pkt_buf, + magic_number_length)); pkt_hdr->am_id = UCT_AM_ID_MAX; pkt_hdr->length = cm_pkt_length; if (event == UCT_TCP_CM_CONN_REQ) { + if (ep->conn_state == UCT_TCP_EP_CONN_STATE_CONNECTING) { + ucs_assert(magic_number_length == sizeof(uint64_t)); + *(uint64_t*)pkt_buf = UCT_TCP_MAGIC_NUMBER; + } + conn_pkt = (uct_tcp_cm_conn_req_pkt_t*)(pkt_hdr + 1); conn_pkt->event = UCT_TCP_CM_CONN_REQ; conn_pkt->iface_addr = iface->config.ifaddr; @@ -500,19 +511,14 @@ unsigned uct_tcp_cm_handle_conn_pkt(uct_tcp_ep_t **ep_p, void *pkt, uint32_t len return 0; } -unsigned uct_tcp_cm_conn_progress(uct_tcp_ep_t *ep) +static ucs_status_t uct_tcp_cm_conn_complete(uct_tcp_ep_t *ep, + unsigned *progress_count_p) { ucs_status_t status; - if (!ucs_socket_is_connected(ep->fd)) { - ucs_error("tcp_ep %p: connection establishment for " - "socket fd %d was unsuccessful", ep, ep->fd); - goto err; - } - status = uct_tcp_cm_send_event(ep, UCT_TCP_CM_CONN_REQ); if (status != UCS_OK) { - return 0; + goto out; } uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_WAITING_ACK); @@ -520,7 +526,25 @@ unsigned uct_tcp_cm_conn_progress(uct_tcp_ep_t *ep) ucs_assertv((ep->tx.length == 0) && (ep->tx.offset == 0) && (ep->tx.buf == NULL), "ep=%p", ep); - return 1; +out: + if (progress_count_p != NULL) { + *progress_count_p = (status == UCS_OK); + } + return status; +} + +unsigned uct_tcp_cm_conn_progress(uct_tcp_ep_t *ep) +{ + unsigned progress_count; + + if (!ucs_socket_is_connected(ep->fd)) { + ucs_error("tcp_ep %p: connection establishment 
for " + "socket fd %d was unsuccessful", ep, ep->fd); + goto err; + } + + uct_tcp_cm_conn_complete(ep, &progress_count); + return progress_count; err: uct_tcp_ep_set_failed(ep); @@ -539,13 +563,13 @@ ucs_status_t uct_tcp_cm_conn_start(uct_tcp_ep_t *ep) return UCS_ERR_TIMED_OUT; } + uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_CONNECTING); + status = ucs_socket_connect(ep->fd, (const struct sockaddr*)&ep->peer_addr); if (UCS_STATUS_IS_ERR(status)) { return status; } else if (status == UCS_INPROGRESS) { - uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_CONNECTING); uct_tcp_ep_mod_events(ep, UCS_EVENT_SET_EVWRITE, 0); - return UCS_OK; } @@ -558,15 +582,7 @@ ucs_status_t uct_tcp_cm_conn_start(uct_tcp_ep_t *ep) } } - status = uct_tcp_cm_send_event(ep, UCT_TCP_CM_CONN_REQ); - if (status != UCS_OK) { - return status; - } - - uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_WAITING_ACK); - uct_tcp_ep_mod_events(ep, UCS_EVENT_SET_EVREAD, 0); - - return UCS_OK; + return uct_tcp_cm_conn_complete(ep, NULL); } /* This function is called from async thread */ @@ -594,7 +610,7 @@ ucs_status_t uct_tcp_cm_handle_incoming_conn(uct_tcp_iface_t *iface, return status; } - uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_ACCEPTING); + uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER); uct_tcp_ep_mod_events(ep, UCS_EVENT_SET_EVREAD, 0); ucs_debug("tcp_iface %p: accepted connection from " diff --git a/src/uct/tcp/tcp_ep.c b/src/uct/tcp/tcp_ep.c index 08f9b4b21af..f143412b196 100644 --- a/src/uct/tcp/tcp_ep.c +++ b/src/uct/tcp/tcp_ep.c @@ -12,33 +12,46 @@ #include -/* Forward declaration */ +/* Forward declarations */ static unsigned uct_tcp_ep_progress_data_tx(uct_tcp_ep_t *ep); +static unsigned uct_tcp_ep_progress_data_rx(uct_tcp_ep_t *ep); +static unsigned uct_tcp_ep_progress_magic_number_rx(uct_tcp_ep_t *ep); const uct_tcp_cm_state_t uct_tcp_ep_cm_state[] = { [UCT_TCP_EP_CONN_STATE_CLOSED] = { .name = "CLOSED", - .tx_progress = 
(uct_tcp_ep_progress_t)ucs_empty_function_return_zero + .tx_progress = (uct_tcp_ep_progress_t)ucs_empty_function_return_zero, + .rx_progress = (uct_tcp_ep_progress_t)ucs_empty_function_return_zero }, [UCT_TCP_EP_CONN_STATE_CONNECTING] = { .name = "CONNECTING", - .tx_progress = uct_tcp_cm_conn_progress + .tx_progress = uct_tcp_cm_conn_progress, + .rx_progress = uct_tcp_ep_progress_data_rx }, [UCT_TCP_EP_CONN_STATE_WAITING_ACK] = { .name = "WAITING_ACK", - .tx_progress = (uct_tcp_ep_progress_t)ucs_empty_function_return_zero + .tx_progress = (uct_tcp_ep_progress_t)ucs_empty_function_return_zero, + .rx_progress = uct_tcp_ep_progress_data_rx + }, + [UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER] = { + .name = "RECV_MAGIC_NUMBER", + .tx_progress = (uct_tcp_ep_progress_t)ucs_empty_function_return_zero, + .rx_progress = uct_tcp_ep_progress_magic_number_rx }, [UCT_TCP_EP_CONN_STATE_ACCEPTING] = { .name = "ACCEPTING", - .tx_progress = (uct_tcp_ep_progress_t)ucs_empty_function_return_zero + .tx_progress = (uct_tcp_ep_progress_t)ucs_empty_function_return_zero, + .rx_progress = uct_tcp_ep_progress_data_rx }, [UCT_TCP_EP_CONN_STATE_WAITING_REQ] = { .name = "WAITING_REQ", - .tx_progress = (uct_tcp_ep_progress_t)ucs_empty_function_return_zero + .tx_progress = (uct_tcp_ep_progress_t)ucs_empty_function_return_zero, + .rx_progress = uct_tcp_ep_progress_data_rx }, [UCT_TCP_EP_CONN_STATE_CONNECTED] = { .name = "CONNECTED", - .tx_progress = uct_tcp_ep_progress_data_tx + .tx_progress = uct_tcp_ep_progress_data_tx, + .rx_progress = uct_tcp_ep_progress_data_rx } }; @@ -135,11 +148,11 @@ static void uct_tcp_ep_cleanup(uct_tcp_ep_t *ep) { uct_tcp_ep_addr_cleanup(&ep->peer_addr); - if (ep->tx.buf) { + if (ep->tx.buf != NULL) { uct_tcp_ep_ctx_reset(&ep->tx); } - if (ep->rx.buf) { + if (ep->rx.buf != NULL) { uct_tcp_ep_ctx_reset(&ep->rx); } @@ -552,8 +565,9 @@ static void uct_tcp_ep_handle_disconnected(uct_tcp_ep_t *ep, uct_tcp_ep_mod_events(ep, 0, ep->events); uct_tcp_ep_close_fd(&ep->fd); - } else 
if (ep->ctx_caps & UCS_BIT(UCT_TCP_EP_CTX_TYPE_RX)) { - /* If the EP supports RX only, destroy it */ + } else if ((ep->ctx_caps == 0) || + (ep->ctx_caps & UCS_BIT(UCT_TCP_EP_CTX_TYPE_RX))) { + /* If the EP supports RX only or no capabilities set, destroy it */ uct_tcp_ep_destroy_internal(&ep->super.super); } } @@ -676,8 +690,9 @@ static ucs_status_t uct_tcp_ep_io_err_handler_cb(void *arg, int io_errno) char str_remote_addr[UCS_SOCKADDR_STRING_LEN]; if ((io_errno == ECONNRESET) && - (ep->conn_state == UCT_TCP_EP_CONN_STATE_CONNECTED) && - (ep->ctx_caps == UCS_BIT(UCT_TCP_EP_CTX_TYPE_RX)) /* only RX cap */) { + ((ep->conn_state == UCT_TCP_EP_CONN_STATE_ACCEPTING) || + ((ep->conn_state == UCT_TCP_EP_CONN_STATE_CONNECTED) && + (ep->ctx_caps == UCS_BIT(UCT_TCP_EP_CTX_TYPE_RX)) /* only RX cap */))) { ucs_debug("tcp_ep %p: detected %d (%s) error, the [%s <-> %s] " "connection was dropped by the peer", ep, io_errno, strerror(io_errno), @@ -698,7 +713,7 @@ static inline void uct_tcp_ep_handle_recv_err(uct_tcp_ep_t *ep, /* If no data were read to the allocated buffer, * we can safely reset it for futher re-use and to * avoid overwriting this buffer, because `rx::length == 0` */ - if (!ep->rx.length) { + if (ep->rx.length == 0) { uct_tcp_ep_ctx_reset(&ep->rx); } } else { @@ -714,7 +729,8 @@ static inline unsigned uct_tcp_ep_recv(uct_tcp_ep_t *ep, size_t recv_length) ucs_assertv(recv_length, "ep=%p", ep); - status = ucs_socket_recv_nb(ep->fd, UCS_PTR_BYTE_OFFSET(ep->rx.buf, ep->rx.length), + status = ucs_socket_recv_nb(ep->fd, UCS_PTR_BYTE_OFFSET(ep->rx.buf, + ep->rx.length), &recv_length, uct_tcp_ep_io_err_handler_cb, ep); if (ucs_unlikely(status != UCS_OK)) { uct_tcp_ep_handle_recv_err(ep, status); @@ -840,7 +856,7 @@ static inline void uct_tcp_ep_handle_put_req(uct_tcp_ep_t *ep, ep->ctx_caps |= UCS_BIT(UCT_TCP_EP_CTX_TYPE_PUT_RX); } -unsigned uct_tcp_ep_progress_am_rx(uct_tcp_ep_t *ep) +static unsigned uct_tcp_ep_progress_am_rx(uct_tcp_ep_t *ep) { uct_tcp_iface_t *iface = 
ucs_derived_of(ep->super.super.iface, uct_tcp_iface_t); @@ -852,7 +868,7 @@ unsigned uct_tcp_ep_progress_am_rx(uct_tcp_ep_t *ep) ucs_trace_func("ep=%p", ep); if (!uct_tcp_ep_ctx_buf_need_progress(&ep->rx)) { - ucs_assert(!ep->rx.buf); + ucs_assert(ep->rx.buf == NULL); ep->rx.buf = ucs_mpool_get_inline(&iface->rx_mpool); if (ucs_unlikely(ep->rx.buf == NULL)) { ucs_warn("tcp_ep %p: unable to get a buffer from RX memory pool", ep); @@ -978,7 +994,7 @@ uct_tcp_ep_am_prepare(uct_tcp_iface_t *iface, uct_tcp_ep_t *ep, return UCS_ERR_NO_RESOURCE; } -unsigned uct_tcp_ep_progress_put_rx(uct_tcp_ep_t *ep) +static unsigned uct_tcp_ep_progress_put_rx(uct_tcp_ep_t *ep) { uct_tcp_ep_put_req_hdr_t *put_req; size_t recv_length; @@ -1001,6 +1017,65 @@ unsigned uct_tcp_ep_progress_put_rx(uct_tcp_ep_t *ep) return 1; } +static unsigned uct_tcp_ep_progress_data_rx(uct_tcp_ep_t *ep) +{ + if (!(ep->ctx_caps & UCS_BIT(UCT_TCP_EP_CTX_TYPE_PUT_RX))) { + return uct_tcp_ep_progress_am_rx(ep); + } else { + return uct_tcp_ep_progress_put_rx(ep); + } +} + +static unsigned uct_tcp_ep_progress_magic_number_rx(uct_tcp_ep_t *ep) +{ + uct_tcp_iface_t *iface = ucs_derived_of(ep->super.super.iface, + uct_tcp_iface_t); + char str_local_addr[UCS_SOCKADDR_STRING_LEN]; + char str_remote_addr[UCS_SOCKADDR_STRING_LEN]; + size_t recv_length, prev_length; + uint64_t magic_number; + + if (ep->rx.buf == NULL) { + ep->rx.buf = ucs_mpool_get_inline(&iface->rx_mpool); + if (ucs_unlikely(ep->rx.buf == NULL)) { + ucs_warn("tcp_ep %p: unable to get a buffer from RX memory pool", ep); + return 0; + } + } + + prev_length = ep->rx.length; + recv_length = sizeof(magic_number) - ep->rx.length; + + if (!uct_tcp_ep_recv(ep, recv_length) || + (ep->rx.length < sizeof(magic_number))) { + return ((ep->rx.length - prev_length) > 0); + } + + magic_number = *(uint64_t*)ep->rx.buf; + + if (magic_number != UCT_TCP_MAGIC_NUMBER) { + /* Silently close this connection and destroy its EP */ + ucs_debug("tcp_iface %p (%s): received wrong 
magic number (expected: " + "%zu, received: %zu) for ep=%p (fd=%d) from %s", iface, + ucs_sockaddr_str((const struct sockaddr*)&iface->config.ifaddr, + str_local_addr, UCS_SOCKADDR_STRING_LEN), + UCT_TCP_MAGIC_NUMBER, magic_number, ep, + ep->fd, ucs_socket_getname_str(ep->fd, str_remote_addr, + UCS_SOCKADDR_STRING_LEN)); + goto err; + } + + uct_tcp_ep_ctx_reset(&ep->rx); + + uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_ACCEPTING); + + return 1; + +err: + uct_tcp_ep_destroy_internal(&ep->super.super); + return 0; +} + static inline void uct_tcp_ep_set_outstanding_zcopy(uct_tcp_iface_t *iface, uct_tcp_ep_t *ep, uct_tcp_ep_zcopy_tx_t *ctx, const void *header, diff --git a/src/uct/tcp/tcp_iface.c b/src/uct/tcp/tcp_iface.c index 7b6ceaf3d44..075ce0b6c0c 100644 --- a/src/uct/tcp/tcp_iface.c +++ b/src/uct/tcp/tcp_iface.c @@ -191,10 +191,10 @@ static void uct_tcp_iface_handle_events(void *callback_data, ucs_assertv(ep->conn_state != UCT_TCP_EP_CONN_STATE_CLOSED, "ep=%p", ep); if (events & UCS_EVENT_SET_EVREAD) { - *count += uct_tcp_ep_progress_rx(ep); + *count += uct_tcp_ep_cm_state[ep->conn_state].rx_progress(ep); } if (events & UCS_EVENT_SET_EVWRITE) { - *count += uct_tcp_ep_progress_tx(ep); + *count += uct_tcp_ep_cm_state[ep->conn_state].tx_progress(ep); } } diff --git a/test/gtest/Makefile.am b/test/gtest/Makefile.am index ecdb38cc35d..71eff0a0e1d 100644 --- a/test/gtest/Makefile.am +++ b/test/gtest/Makefile.am @@ -111,6 +111,7 @@ gtest_SOURCES = \ ucs/test_stats_filter.cc \ uct/test_peer_failure.cc \ uct/test_tag.cc \ + uct/tcp/test_tcp.cc \ \ ucp/test_ucp_am.cc \ ucp/test_ucp_stream.cc \ diff --git a/test/gtest/uct/tcp/test_tcp.cc b/test/gtest/uct/tcp/test_tcp.cc new file mode 100644 index 00000000000..c7db74679dd --- /dev/null +++ b/test/gtest/uct/tcp/test_tcp.cc @@ -0,0 +1,258 @@ +/** + * Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. + * See file LICENSE for terms. 
+ */ + +#include +#include + +extern "C" { +#include +#include +} + +class test_uct_tcp : public uct_test { +public: + void init() { + if (RUNNING_ON_VALGRIND) { + modify_config("TX_SEG_SIZE", "1kb"); + modify_config("RX_SEG_SIZE", "1kb"); + } + + uct_test::init(); + m_ent = uct_test::create_entity(0); + m_entities.push_back(m_ent); + m_tcp_iface = (uct_tcp_iface*)m_ent->iface(); + } + + size_t get_accepted_conn_num(entity& ent) { + size_t num = 0; + uct_tcp_ep_t *ep; + + ucs_list_for_each(ep, &m_tcp_iface->ep_list, list) { + num += (ep->conn_state == UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER); + } + + return num; + } + + ucs_status_t post_recv(int fd, bool nb = false) { + uint8_t msg; + size_t msg_size = sizeof(msg); + ucs_status_t status; + + scoped_log_handler slh(wrap_errors_logger); + if (nb) { + status = ucs_socket_recv_nb(fd, &msg, &msg_size, NULL, NULL); + } else { + status = ucs_socket_recv(fd, &msg, msg_size, NULL, NULL); + } + + return status; + } + + void post_send(int fd, const std::vector &buf) { + scoped_log_handler slh(wrap_errors_logger); + ucs_status_t status = ucs_socket_send(fd, &buf[0], + buf.size(), NULL, NULL); + // send can be OK or fail when a connection was closed by a peer + // before all data were sent + ASSERT_TRUE((status == UCS_OK) || + (status == UCS_ERR_IO_ERROR)); + } + + void detect_conn_reset(int fd) { + // Try to receive something on this socket fd - it has to be failed + ucs_status_t status = post_recv(fd); + ASSERT_TRUE((status == UCS_ERR_IO_ERROR) || + (status == UCS_ERR_CANCELED)); + EXPECT_EQ(0, ucs_socket_is_connected(fd)); + } + + void test_listener_flood(entity& test_entity, size_t max_conn, + size_t msg_size) { + std::vector fds; + std::vector buf; + + if (msg_size > 0) { + buf.resize(msg_size + sizeof(uct_tcp_am_hdr_t)); + std::fill(buf.begin(), buf.end(), 0); + init_data(&buf[0], buf.size()); + } + + setup_conns_to_entity(test_entity, max_conn, fds); + + size_t handled = 0; + for (std::vector::const_iterator iter = 
fds.begin(); + iter != fds.end(); ++iter) { + size_t sent_length = 0; + do { + if (msg_size > 0) { + post_send(*iter, buf); + sent_length += buf.size(); + } else { + close(*iter); + } + + // If it was sent >= the length of the magic number or sending + // is not required by the current test, wait until connection + // is destroyed. Otherwise, need to send more data + if ((msg_size == 0) || (sent_length >= sizeof(uint64_t))) { + handled++; + + while (get_accepted_conn_num(test_entity) != (max_conn - handled)) { + sched_yield(); + progress(); + } + } else { + // Peers still have to be connected + ucs_status_t status = post_recv(*iter, true); + EXPECT_TRUE((status == UCS_OK) || + (status == UCS_ERR_NO_PROGRESS)); + EXPECT_EQ(1, ucs_socket_is_connected(*iter)); + } + } while ((msg_size != 0) && (sent_length < sizeof(uint64_t))); + } + + // give a chance to close all connections + while (!ucs_list_is_empty(&m_tcp_iface->ep_list)) { + sched_yield(); + progress(); + } + + // TCP has to reject all connections and forget EPs that were + // created after accept(): + // - EP list has to be empty + EXPECT_EQ(1, ucs_list_is_empty(&m_tcp_iface->ep_list)); + // - all connections have to be destroyed (if wasn't closed + // yet by the clients) + if (msg_size > 0) { + // if we sent data during the test, close socket fd here + while (!fds.empty()) { + int fd = fds.back(); + fds.pop_back(); + detect_conn_reset(fd); + close(fd); + } + } + } + + void setup_conns_to_entity(entity& to, size_t max_conn, + std::vector &fds) { + for (size_t i = 0; i < max_conn; i++) { + int fd = setup_conn_to_entity(to, i + 1lu); + fds.push_back(fd); + + // give a chance to finish all connections + while (get_accepted_conn_num(to) != (i + 1lu)) { + sched_yield(); + progress(); + } + + EXPECT_EQ(1, ucs_socket_is_connected(fd)); + } + } + +private: + void init_data(void *buf, size_t msg_size) { + uct_tcp_am_hdr_t *tcp_am_hdr; + ASSERT_TRUE(msg_size >= sizeof(*tcp_am_hdr)); + tcp_am_hdr = static_cast(buf); + 
tcp_am_hdr->am_id = std::numeric_limits::max(); + tcp_am_hdr->length = msg_size; + } + + int connect_to_entity(entity& to) { + uct_device_addr_t *dev_addr; + uct_iface_addr_t *iface_addr; + ucs_status_t status; + + dev_addr = (uct_device_addr_t*)malloc(to.iface_attr().device_addr_len); + iface_addr = (uct_iface_addr_t*)malloc(to.iface_attr().iface_addr_len); + + status = uct_iface_get_device_address(to.iface(), dev_addr); + ASSERT_UCS_OK(status); + + status = uct_iface_get_address(to.iface(), iface_addr); + ASSERT_UCS_OK(status); + + struct sockaddr_in dest_addr; + dest_addr.sin_family = AF_INET; + dest_addr.sin_port = *(in_port_t*)iface_addr; + dest_addr.sin_addr = *(struct in_addr*)dev_addr; + + int fd; + status = ucs_socket_create(AF_INET, SOCK_STREAM, &fd); + ASSERT_UCS_OK(status); + + status = ucs_socket_connect(fd, (const struct sockaddr*)&dest_addr); + ASSERT_UCS_OK(status); + + status = ucs_sys_fcntl_modfl(fd, O_NONBLOCK, 0); + ASSERT_UCS_OK(status); + + free(iface_addr); + free(dev_addr); + + return fd; + } + + int setup_conn_to_entity(entity &to, size_t sn = 1) { + int fd = -1; + + do { + if (fd != -1) { + close(fd); + } + + fd = connect_to_entity(to); + EXPECT_NE(-1, fd); + + // give a chance to finish the connection + while (get_accepted_conn_num(to) != sn) { + sched_yield(); + progress(); + + ucs_status_t status = post_recv(fd, true); + if ((status != UCS_OK) && + (status != UCS_ERR_NO_PROGRESS)) { + break; + } + } + } while (!ucs_socket_is_connected(fd)); + + EXPECT_EQ(1, ucs_socket_is_connected(fd)); + + return fd; + } + +protected: + uct_tcp_iface *m_tcp_iface; + entity *m_ent; +}; + +UCS_TEST_P(test_uct_tcp, listener_flood_connect_and_send_large) { + const size_t max_conn = + ucs_min(static_cast(max_connections()), 128lu) / + ucs::test_time_multiplier(); + const size_t msg_size = m_tcp_iface->config.rx_seg_size * 4; + test_listener_flood(*m_ent, max_conn, msg_size); +} + +UCS_TEST_P(test_uct_tcp, listener_flood_connect_and_send_small) { + const 
size_t max_conn = + ucs_min(static_cast(max_connections()), 128lu) / + ucs::test_time_multiplier(); + // It should be less than the length of the magic number expected by TCP + const size_t msg_size = 1; + test_listener_flood(*m_ent, max_conn, msg_size); +} + +UCS_TEST_P(test_uct_tcp, listener_flood_connect_and_close) { + const size_t max_conn = + ucs_min(static_cast(max_connections()), 128lu) / + ucs::test_time_multiplier(); + test_listener_flood(*m_ent, max_conn, 0); +} + +_UCT_INSTANTIATE_TEST_CASE(test_uct_tcp, tcp)