diff --git a/src/Client/Connection.hpp b/src/Client/Connection.hpp index d9626f0ba..b2c6b8fe2 100644 --- a/src/Client/Connection.hpp +++ b/src/Client/Connection.hpp @@ -79,8 +79,24 @@ struct ConnectionImpl void ref(); void unref(); + typename NetProvider::Stream_t &get_strm() { return strm; } + const typename NetProvider::Stream_t &get_strm() const { return strm; } + + void setError(const std::string &msg, int errno_ = 0); + bool hasError() const; + + size_t getFutureCount() const; + + BUFFER &getInBuf(); + BUFFER &getOutBuf(); + + void prepare_auth(std::string_view user, std::string_view passwd); + void commit_auth(std::string_view user, std::string_view passwd); + Connector &connector; BUFFER inBuf; + static constexpr size_t GC_STEP_CNT = 100; + size_t gc_step = 0; BUFFER outBuf; RequestEncoder enc; ResponseDecoder dec; @@ -111,8 +127,8 @@ template ConnectionImpl::~ConnectionImpl() { assert(refs == 0); - if (!strm.has_status(SS_DEAD)) { - connector.close(*this); + if (strm.is_open()) { + connector.close(this); } } @@ -133,6 +149,56 @@ ConnectionImpl::unref() delete this; } +template +void +ConnectionImpl::setError(const std::string &msg, int errno_) +{ + error.emplace(msg, errno_); +} + +template +bool +ConnectionImpl::hasError() const +{ + return error.has_value(); +} + +template +size_t +ConnectionImpl::getFutureCount() const +{ + return futures.size(); +} + +template +BUFFER & +ConnectionImpl::getInBuf() +{ + return inBuf; +} + +template +BUFFER & +ConnectionImpl::getOutBuf() +{ + return outBuf; +} + +template +void +ConnectionImpl::prepare_auth(std::string_view user, std::string_view passwd) +{ + enc.encodeAuth(user, passwd, greeting); +} + +template +void +ConnectionImpl::commit_auth(std::string_view user, std::string_view passwd) +{ + enc.reencodeAuth(user, passwd, greeting); + connector.readyToSend(this); +} + /** Each connection is supposed to be bound to a single socket. */ template class Connection @@ -212,44 +278,11 @@ class Connection BUFFER& getInBuf(); BUFFER& getOutBuf(); - template - friend - void hasSentBytes(Connection &conn, size_t bytes); - - template - friend - void hasNotRecvBytes(Connection &conn, size_t bytes); - - template - friend - bool hasDataToSend(Connection &conn); - - template - friend - bool hasDataToDecode(Connection &conn); - - template - friend - enum DecodeStatus processResponse(Connection &conn, int req_sync, Response *result); - - template - friend - void inputBufGC(Connection &conn); - - template - friend - int decodeGreeting(Connection &conn); - - rid_t prepare_auth(std::string_view user, - std::string_view passwd); - - rid_t commit_auth(std::string_view user, - std::string_view passwd); + void prepare_auth(std::string_view user, std::string_view passwd); + void commit_auth(std::string_view user, std::string_view passwd); private: ConnectionImpl *impl; - static constexpr size_t GC_STEP_CNT = 100; - size_t gc_step = 0; template rid_t insert(const T &tuple, uint32_t space_id); @@ -435,21 +468,21 @@ template size_t Connection::getFutureCount() const { - return impl->futures.size(); + return impl->getFutureCount(); } template void Connection::setError(const std::string &msg, int errno_) { - impl->error.emplace(msg, errno_); + impl->setError(msg, errno_); } template bool Connection::hasError() const { - return impl->error.has_value(); + return impl->hasError(); } template @@ -471,73 +504,79 @@ template BUFFER& Connection::getInBuf() { - return impl->inBuf; + return impl->getInBuf(); } template BUFFER& Connection::getOutBuf() { - return impl->outBuf; + return impl->getOutBuf(); } -template +template void -hasSentBytes(Connection &conn, size_t bytes) +hasSentBytes(ConnectionImpl *conn, size_t bytes) { //dropBack()/dropFront() interfaces require number of bytes be greater //than zero so let's check it first. if (bytes > 0) - conn.impl->outBuf.dropFront(bytes); + conn->getOutBuf().dropFront(bytes); } -template +template void -hasNotRecvBytes(Connection &conn, size_t bytes) +hasNotRecvBytes(ConnectionImpl *conn, size_t bytes) { if (bytes > 0) - conn.impl->inBuf.dropBack(bytes); + conn->getInBuf().dropBack(bytes); } -template +template bool -hasDataToSend(Connection &conn) +hasDataToSend(ConnectionImpl *conn) { //We drop content of input buffer once it has been sent. So to detect //if there's any data to send it's enough to check buffer's emptiness. - return !conn.impl->outBuf.empty(); + return !conn->getOutBuf().empty(); } -template +template bool hasDataToDecode(Connection &conn) { - assert(conn.impl->endDecoded < conn.impl->inBuf.end() || - conn.impl->endDecoded == conn.impl->inBuf.end()); - return conn.impl->endDecoded != conn.impl->inBuf.end(); + return hasDataToDecode(conn.getImpl()); } -template +template +bool +hasDataToDecode(ConnectionImpl *conn) +{ + assert(conn->endDecoded < conn->getInBuf().end() || conn->endDecoded == conn->getInBuf().end()); + return conn->endDecoded != conn->getInBuf().end(); +} + +template static void -inputBufGC(Connection &conn) +inputBufGC(ConnectionImpl *conn) { - if ((conn.gc_step++ % Connection::GC_STEP_CNT) == 0) { - LOG_DEBUG("Flushed input buffer of the connection %p", &conn); - conn.impl->inBuf.flush(); + if (conn->gc_step++ % ConnectionImpl::GC_STEP_CNT == 0) { + LOG_DEBUG("Flushed input buffer of the connection %p", conn); + conn->getInBuf().flush(); } } -template +template DecodeStatus -processResponse(Connection &conn, int req_sync, Response *result) +processResponse(ConnectionImpl *conn, int req_sync, Response *result) { //Decode response. In case of success - fill in feature map //and adjust end-of-decoded data pointer. Call GC if needed. - if (! conn.impl->inBuf.has(conn.impl->endDecoded, MP_RESPONSE_SIZE)) + if (!conn->getInBuf().has(conn->endDecoded, MP_RESPONSE_SIZE)) return DECODE_NEEDMORE; Response response; - response.size = conn.impl->dec.decodeResponseSize(); + response.size = conn->dec.decodeResponseSize(); if (response.size < 0) { LOG_ERROR("Failed to decode response size"); //In case of corrupted response size all other data in the buffer @@ -548,15 +587,15 @@ processResponse(Connection &conn, int req_sync, ResponseinBuf.has(conn.impl->endDecoded, response.size)) { + if (!conn->getInBuf().has(conn->endDecoded, response.size)) { //Response was received only partially. Reset decoder position //to the start of response to make this function re-entered. - conn.impl->dec.reset(conn.impl->endDecoded); + conn->dec.reset(conn->endDecoded); return DECODE_NEEDMORE; } - if (conn.impl->dec.decodeResponse(response) != 0) { - conn.setError("Failed to decode response, skipping bytes.."); - conn.impl->endDecoded += response.size; + if (conn->dec.decodeResponse(response) != 0) { + conn->setError("Failed to decode response, skipping bytes.."); + conn->endDecoded += response.size; return DECODE_ERR; } LOG_DEBUG("Header: sync=", response.header.sync, ", code=", @@ -564,39 +603,37 @@ processResponse(Connection &conn, int req_sync, Responsefutures.insert({response.header.sync, - std::move(response)}); + conn->futures.insert({response.header.sync, std::move(response)}); } - conn.impl->endDecoded += response.size; + conn->endDecoded += response.size; inputBufGC(conn); return DECODE_SUCC; } -template +template int -decodeGreeting(Connection &conn) +decodeGreeting(ConnectionImpl *conn) { //TODO: that's not zero-copy, should be rewritten in that pattern. - assert(conn.getInBuf().has(conn.impl->endDecoded, Iproto::GREETING_SIZE)); + assert(conn->getInBuf().has(conn->endDecoded, Iproto::GREETING_SIZE)); char greeting_buf[Iproto::GREETING_SIZE]; - conn.impl->endDecoded.read({greeting_buf, sizeof(greeting_buf)}); - conn.impl->dec.reset(conn.impl->endDecoded); - if (parseGreeting(std::string_view{greeting_buf, Iproto::GREETING_SIZE}, - conn.impl->greeting) != 0) + conn->endDecoded.read({greeting_buf, sizeof(greeting_buf)}); + conn->dec.reset(conn->endDecoded); + if (parseGreeting(std::string_view {greeting_buf, Iproto::GREETING_SIZE}, conn->greeting) != 0) return -1; - conn.impl->is_greeting_received = true; - LOG_DEBUG("Version: ", conn.impl->greeting.version_id); + conn->is_greeting_received = true; + LOG_DEBUG("Version: ", conn->greeting.version_id); #ifndef NDEBUG //print salt in hex format. char hex_salt[Iproto::MAX_SALT_SIZE * 2 + 1]; const char *hex = "0123456789abcdef"; - for (size_t i = 0; i < conn.impl->greeting.salt_size; i++) { - uint8_t u = conn.impl->greeting.salt[i]; + for (size_t i = 0; i < conn->greeting.salt_size; i++) { + uint8_t u = conn->greeting.salt[i]; hex_salt[i * 2] = hex[u / 16]; hex_salt[i * 2 + 1] = hex[u % 16]; } - hex_salt[conn.impl->greeting.salt_size * 2] = 0; + hex_salt[conn->greeting.salt_size * 2] = 0; LOG_DEBUG("Salt: ", hex_salt); #endif return 0; @@ -717,21 +754,16 @@ Connection::select(const T &key, uint32_t space_id, return impl->enc.getSync(); } -template -rid_t -Connection::prepare_auth(std::string_view user, - std::string_view passwd) +template +void +Connection::prepare_auth(std::string_view user, std::string_view passwd) { - impl->enc.encodeAuth(user, passwd, impl->greeting); - return 0; + impl->prepare_auth(user, passwd); } -template -rid_t -Connection::commit_auth(std::string_view user, - std::string_view passwd) +template +void +Connection::commit_auth(std::string_view user, std::string_view passwd) { - impl->enc.reencodeAuth(user, passwd, impl->greeting);; - impl->connector.readyToSend(*this); - return 0; + impl->commit_auth(user, passwd); } diff --git a/src/Client/Connector.hpp b/src/Client/Connector.hpp index f0d52ded1..f2b9b6033 100644 --- a/src/Client/Connector.hpp +++ b/src/Client/Connector.hpp @@ -82,13 +82,14 @@ class Connector size_t feature_count, int timeout = -1); std::optional> waitAny(int timeout = -1); ////////////////////////////Service interfaces////////////////////////// - void readyToDecode(const Connection &conn); - void readyToSend(const Connection &conn); - void finishSend(const Connection &conn); + void readyToDecode(ConnectionImpl *conn); + void readyToSend(ConnectionImpl *conn); + void readyToSend(Connection &conn); + void finishSend(ConnectionImpl *conn); - std::set> m_ReadyToSend; + std::set *> m_ReadyToSend; void close(Connection &conn); - void close(ConnectionImpl &conn); + void close(ConnectionImpl *conn); private: /** @@ -98,6 +99,8 @@ class Connector * `req_sync` sync. If `result` is `nullptr` - `req_sync` is ignored. * Returns -1 in the case of any error, 0 on success. */ + int connectionDecodeResponses(ConnectionImpl *conn, int req_sync = -1, + Response *result = nullptr); int connectionDecodeResponses(Connection &conn, int req_sync = -1, Response *result = nullptr); /** @@ -125,7 +128,11 @@ class Connector * Shouldn't be modified directly - is managed by methods `readyToDecode` * and `connectionDecodeResponses`. */ - std::set> m_ReadyToDecode; + std::set *> m_ReadyToDecode; + /** + * Set of active connections owned by connector. + */ + std::set *> m_Connections; }; template @@ -145,7 +152,7 @@ Connector::connect(Connection &conn, { //Make sure that connection is not yet established. assert(conn.get_strm().has_status(SS_DEAD)); - if (m_NetProvider.connect(conn, opts) != 0) { + if (m_NetProvider.connect(conn.getImpl(), opts) != 0) { LOG_ERROR("Failed to connect to ", opts.address, ':', opts.service); return -1; @@ -154,10 +161,11 @@ Connector::connect(Connection &conn, conn.getImpl()->is_auth_required = !opts.user.empty(); if (conn.getImpl()->is_auth_required) { // Encode auth request to reserve space in buffer. - conn.prepare_auth(opts.user, opts.passwd); + conn.getImpl()->prepare_auth(opts.user, opts.passwd); } LOG_DEBUG("Connection to ", opts.address, ':', opts.service, " has been established"); + m_Connections.insert(conn.getImpl()); return 0; } @@ -178,20 +186,24 @@ template void Connector::close(Connection &conn) { - return close(*conn.getImpl()); + return close(conn.getImpl()); } -template +template void -Connector::close(ConnectionImpl &conn) +Connector::close(ConnectionImpl *conn) { - assert(!conn.strm.has_status(SS_DEAD)); - m_NetProvider.close(conn.strm); + if (conn->get_strm().is_open()) { + m_NetProvider.close(conn->get_strm()); + m_ReadyToSend.erase(conn); + m_ReadyToDecode.erase(conn); + m_Connections.erase(conn); + } } -template +template int -Connector::connectionDecodeResponses(Connection &conn, int req_sync, +Connector::connectionDecodeResponses(ConnectionImpl *conn, int req_sync, Response *result) { if (!hasDataToDecode(conn)) @@ -222,6 +234,14 @@ Connector::connectionDecodeResponses(Connection +int +Connector::connectionDecodeResponses(Connection &conn, int req_sync, + Response *result) +{ + return connectionDecodeResponses(conn.getImpl(), req_sync, result); +} + template int Connector::connectionCheckResponsesReadiness(Connection &conn, @@ -342,9 +362,28 @@ template std::optional> Connector::waitAny(int timeout) { + if (m_Connections.empty()) { + LOG_DEBUG("waitAny() called on connector without connections"); + return std::nullopt; + } + for (auto *conn : m_Connections) { + if (conn->getFutureCount() != 0) + return conn; + } Timer timer{timeout}; timer.start(); while (m_ReadyToDecode.empty()) { + bool has_alive_conn = false; + for (auto *conn : m_Connections) { + if (!conn->hasError()) { + has_alive_conn = true; + break; + } + } + if (!has_alive_conn) { + LOG_ERROR("All connections have an error"); + return std::nullopt; + } if (m_NetProvider.wait(timer.timeLeft()) != 0) { LOG_ERROR("Failed to poll connections: ", strerror(errno)); return std::nullopt; @@ -356,7 +395,7 @@ Connector::waitAny(int timeout) LOG_DEBUG("wait() has been timed out! No responses are received"); return std::nullopt; } - Connection conn = *m_ReadyToDecode.begin(); + auto *conn = *m_ReadyToDecode.begin(); assert(hasDataToDecode(conn)); if (connectionDecodeResponses(conn) != 0) return std::nullopt; @@ -399,28 +438,34 @@ Connector::waitCount(Connection &conn, return -1; } -template +template void -Connector::readyToSend(const Connection &conn) +Connector::readyToSend(ConnectionImpl *conn) { - if (conn.getImpl()->is_auth_required && - !conn.getImpl()->is_greeting_received) { + if (conn->is_auth_required && !conn->is_greeting_received) { // Need to receive greeting first. return; } m_ReadyToSend.insert(conn); } -template +template +void +Connector::readyToSend(Connection &conn) +{ + readyToSend(conn.getImpl()); +} + +template void -Connector::readyToDecode(const Connection &conn) +Connector::readyToDecode(ConnectionImpl *conn) { m_ReadyToDecode.insert(conn); } -template +template void -Connector::finishSend(const Connection &conn) +Connector::finishSend(ConnectionImpl *conn) { m_ReadyToSend.erase(conn); } diff --git a/src/Client/EpollNetProvider.hpp b/src/Client/EpollNetProvider.hpp index ccd6aa826..2999678cc 100644 --- a/src/Client/EpollNetProvider.hpp +++ b/src/Client/EpollNetProvider.hpp @@ -53,11 +53,11 @@ class EpollNetProvider { using Buffer_t = BUFFER; using Stream_t = Stream; using NetProvider_t = EpollNetProvider; - using Conn_t = Connection; + using ConnImpl_t = ConnectionImpl; using Connector_t = Connector; EpollNetProvider(Connector_t &connector); ~EpollNetProvider(); - int connect(Conn_t &conn, const ConnectOptions &opts); + int connect(ConnImpl_t *conn, const ConnectOptions &opts); void close(Stream_t &strm); /** Read and write to sockets; polling using epoll. */ int wait(int timeout); @@ -69,11 +69,11 @@ class EpollNetProvider { //return 0 if all data from buffer was processed (sent or read); //return -1 in case of errors; //return 1 in case socket is blocked. - int send(Conn_t &conn); - int recv(Conn_t &conn); + int send(ConnImpl_t *conn); + int recv(ConnImpl_t *conn); - void setPollSetting(Conn_t &conn, int setting); - void registerEpoll(Conn_t &conn); + void setPollSetting(ConnImpl_t *conn, int setting); + void registerEpoll(ConnImpl_t *conn); /** map. Contains both ready to read/send connections */ Connector_t &m_Connector; @@ -98,17 +98,16 @@ EpollNetProvider::~EpollNetProvider() m_EpollFd = -1; } -template +template void -EpollNetProvider::registerEpoll(Conn_t &conn) +EpollNetProvider::registerEpoll(ConnImpl_t *conn) { /* Configure epoll with new socket. */ assert(m_EpollFd >= 0); struct epoll_event event; event.events = EPOLLIN; - event.data.ptr = conn.getImpl(); - if (epoll_ctl(m_EpollFd, EPOLL_CTL_ADD, conn.get_strm().get_fd(), - &event) != 0) { + event.data.ptr = conn; + if (epoll_ctl(m_EpollFd, EPOLL_CTL_ADD, conn->get_strm().get_fd(), &event) != 0) { LOG_ERROR("Failed to add socket to epoll: " "epoll_ctl() returned with errno: ", strerror(errno)); @@ -116,14 +115,14 @@ EpollNetProvider::registerEpoll(Conn_t &conn) } } -template +template void -EpollNetProvider::setPollSetting(Conn_t &conn, int setting) { +EpollNetProvider::setPollSetting(ConnImpl_t *conn, int setting) +{ struct epoll_event event; event.events = setting; - event.data.ptr = conn.getImpl(); - if (epoll_ctl(m_EpollFd, EPOLL_CTL_MOD, conn.get_strm().get_fd(), - &event) != 0) { + event.data.ptr = conn; + if (epoll_ctl(m_EpollFd, EPOLL_CTL_MOD, conn->get_strm().get_fd(), &event) != 0) { LOG_ERROR("Failed to change epoll mode: " "epoll_ctl() returned with errno: ", strerror(errno)); @@ -131,15 +130,13 @@ EpollNetProvider::setPollSetting(Conn_t &conn, int setting) { } } -template +template int -EpollNetProvider::connect(Conn_t &conn, - const ConnectOptions &opts) +EpollNetProvider::connect(ConnImpl_t *conn, const ConnectOptions &opts) { - auto &strm = conn.get_strm(); + auto &strm = conn->get_strm(); if (strm.connect(opts) < 0) { - conn.setError("Failed to establish connection to " + - opts.address); + conn->setError("Failed to establish connection to " + opts.address); return -1; } LOG_DEBUG("Connected to ", opts.address, ", socket is ", strm.get_fd()); @@ -181,69 +178,66 @@ EpollNetProvider::close(Stream_t& strm) epoll_ctl(m_EpollFd, EPOLL_CTL_DEL, was_fd, nullptr); } -template +template int -EpollNetProvider::recv(Conn_t &conn) +EpollNetProvider::recv(ConnImpl_t *conn) { - auto &buf = conn.getInBuf(); + auto &buf = conn->getInBuf(); auto itr = buf.template end(); buf.write({CONN_READAHEAD}); struct iovec iov[IOVEC_MAX_SIZE]; size_t iov_cnt = buf.getIOV(itr, iov, IOVEC_MAX_SIZE); - ssize_t rcvd = conn.get_strm().recv(iov, iov_cnt); + ssize_t rcvd = conn->get_strm().recv(iov, iov_cnt); hasNotRecvBytes(conn, CONN_READAHEAD - (rcvd < 0 ? 0 : rcvd)); if (rcvd < 0) { - conn.setError(std::string("Failed to receive response: ") + - strerror(errno), errno); + conn->setError(std::string("Failed to receive response: ") + strerror(errno), errno); return -1; } if (rcvd == 0) { - assert(conn.get_strm().has_status(SS_NEED_EVENT_FOR_READ)); - if (conn.get_strm().has_status(SS_NEED_WRITE_EVENT_FOR_READ)) + assert(conn->get_strm().has_status(SS_NEED_EVENT_FOR_READ)); + if (conn->get_strm().has_status(SS_NEED_WRITE_EVENT_FOR_READ)) setPollSetting(conn, EPOLLIN | EPOLLOUT); } - if (!conn.getImpl()->is_greeting_received) { + if (!conn->is_greeting_received) { if ((size_t) rcvd < Iproto::GREETING_SIZE) return 0; /* Receive and decode greetings. */ LOG_DEBUG("Greetings are received, read bytes ", rcvd); if (decodeGreeting(conn) != 0) { - conn.setError("Failed to decode greetings"); + conn->setError("Failed to decode greetings"); return -1; } LOG_DEBUG("Greetings are decoded"); rcvd -= Iproto::GREETING_SIZE; - if (conn.getImpl()->is_auth_required) { + if (conn->is_auth_required) { // Finalize auth request in buffer. - conn.commit_auth(conn.get_strm().get_opts().user, - conn.get_strm().get_opts().passwd); + conn->commit_auth(conn->get_strm().get_opts().user, conn->get_strm().get_opts().passwd); } } return 0; } -template +template int -EpollNetProvider::send(Conn_t &conn) +EpollNetProvider::send(ConnImpl_t *conn) { while (hasDataToSend(conn)) { struct iovec iov[IOVEC_MAX_SIZE]; - auto &buf = conn.getOutBuf(); + auto &buf = conn->getOutBuf(); size_t iov_cnt = buf.getIOV(buf.template begin(), iov, IOVEC_MAX_SIZE); - ssize_t sent = conn.get_strm().send(iov, iov_cnt); + ssize_t sent = conn->get_strm().send(iov, iov_cnt); if (sent < 0) { - conn.setError(std::string("Failed to send request: ") + - strerror(errno), errno); + conn->setError(std::string("Failed to send request: ") + strerror(errno), errno); return -1; } else if (sent == 0) { - assert(conn.get_strm().has_status(SS_NEED_EVENT_FOR_WRITE)); - if (conn.get_strm().has_status(SS_NEED_WRITE_EVENT_FOR_WRITE)) + assert(conn->get_strm().has_status(SS_NEED_EVENT_FOR_WRITE)); + if (conn->get_strm().has_status(SS_NEED_WRITE_EVENT_FOR_WRITE)) setPollSetting(conn, EPOLLIN | EPOLLOUT); return 1; } else { @@ -265,8 +259,7 @@ EpollNetProvider::wait(int timeout) /* Send pending requests. */ for (auto conn = m_Connector.m_ReadyToSend.begin(); conn != m_Connector.m_ReadyToSend.end();) { - Conn_t to_be_send(*conn); - (void) send(to_be_send); + (void)send(*conn); conn = m_Connector.m_ReadyToSend.erase(conn); } @@ -280,12 +273,11 @@ EpollNetProvider::wait(int timeout) return -1; } for (int i = 0; i < event_cnt; ++i) { - Conn_t conn((typename Conn_t::Impl_t *)events[i].data.ptr); + auto *conn = reinterpret_cast(events[i].data.ptr); if ((events[i].events & EPOLLIN) != 0) { - LOG_DEBUG("Registered poll event ", i, ": ", - conn.get_strm().get_fd(), + LOG_DEBUG("Registered poll event ", i, ": ", conn->get_strm().get_fd(), " socket is ready to read"); - if (conn.get_strm().has_status(SS_NEED_READ_EVENT_FOR_WRITE)) { + if (conn->get_strm().has_status(SS_NEED_READ_EVENT_FOR_WRITE)) { int rc = send(conn); if (rc < 0) return -1; @@ -302,10 +294,9 @@ EpollNetProvider::wait(int timeout) } if ((events[i].events & EPOLLOUT) != 0) { - LOG_DEBUG("Registered poll event ", i, ": ", - conn.get_strm().get_fd(), + LOG_DEBUG("Registered poll event ", i, ": ", conn->get_strm().get_fd(), " socket is ready to write"); - if (conn.get_strm().has_status(SS_NEED_WRITE_EVENT_FOR_READ)) { + if (conn->get_strm().has_status(SS_NEED_WRITE_EVENT_FOR_READ)) { int rc = recv(conn); if (rc < 0) return -1; diff --git a/src/Client/LibevNetProvider.hpp b/src/Client/LibevNetProvider.hpp index 431d1c126..cbd039051 100644 --- a/src/Client/LibevNetProvider.hpp +++ b/src/Client/LibevNetProvider.hpp @@ -53,21 +53,22 @@ class LibevNetProvider; template struct WaitWatcher { WaitWatcher(Connector> *client, - Connection> conn, - struct ev_timer *t); + ConnectionImpl> *conn, struct ev_timer *t); struct ev_io in; struct ev_io out; Connector> *connector; - Connection> connection; + ConnectionImpl> *connection; struct ev_timer *timer; }; -template +template WaitWatcher::WaitWatcher(Connector> *client, - Connection> conn, - struct ev_timer *t) : connector(client), connection(conn), - timer(t) + ConnectionImpl> *conn, + struct ev_timer *t) + : connector(client) + , connection(conn) + , timer(t) { in.data = this; out.data = this; @@ -86,11 +87,11 @@ class LibevNetProvider { using Buffer_t = BUFFER; using Stream_t = Stream; using NetProvider_t = LibevNetProvider; - using Conn_t = Connection; + using ConnImpl_t = ConnectionImpl; using Connector_t = Connector; LibevNetProvider(Connector_t &connector, struct ev_loop *loop = nullptr); - int connect(Conn_t &conn, const ConnectOptions &opts); + int connect(ConnImpl_t *conn, const ConnectOptions &opts); void close(Stream_t &strm); int wait(int timeout); @@ -99,7 +100,7 @@ class LibevNetProvider { private: static constexpr float MILLISECONDS = 1000.f; - void registerWatchers(Conn_t &conn, int fd); + void registerWatchers(ConnImpl_t *conn, int fd); void stopWatchers(WaitWatcher *watcher); /** Callback for libev timeout. */ static void timeout_cb(EV_P_ ev_timer *w, int revents); @@ -120,39 +121,37 @@ LibevNetProvider::stopWatchers(struct WaitWatcherout); } -template +template static inline int -connectionReceive(Connection> &conn) +connectionReceive(ConnectionImpl> *conn) { - auto &buf = conn.getInBuf(); + auto &buf = conn->inBuf; auto itr = buf.template end(); buf.write({CONN_READAHEAD}); struct iovec iov[IOVEC_MAX_SIZE]; size_t iov_cnt = buf.getIOV(itr, iov, IOVEC_MAX_SIZE); - ssize_t rcvd = conn.get_strm().recv(iov, iov_cnt); + ssize_t rcvd = conn->get_strm().recv(iov, iov_cnt); hasNotRecvBytes(conn, CONN_READAHEAD - (rcvd < 0 ? 0 : rcvd)); if (rcvd < 0) { - conn.setError(std::string("Failed to receive response: ") + - strerror(errno), errno); + conn->setError(std::string("Failed to receive response: ") + strerror(errno), errno); return -1; } - if (!conn.getImpl()->is_greeting_received) { + if (!conn->is_greeting_received) { if ((size_t) rcvd < Iproto::GREETING_SIZE) return 0; /* Receive and decode greetings. */ LOG_DEBUG("Greetings are received, read bytes ", rcvd); if (decodeGreeting(conn) != 0) { - conn.setError("Failed to decode greetings"); + conn->setError("Failed to decode greetings"); return -1; } LOG_DEBUG("Greetings are decoded"); rcvd -= Iproto::GREETING_SIZE; - if (conn.getImpl()->is_auth_required) { + if (conn->is_auth_required) { // Finalize auth request in buffer. - conn.commit_auth(conn.get_strm().get_opts().user, - conn.get_strm().get_opts().passwd); + conn->commit_auth(conn->get_strm().get_opts().user, conn->get_strm().get_opts().passwd); } } @@ -163,17 +162,16 @@ template static void recv_cb(struct ev_loop *loop, struct ev_io *watcher, int /* revents */) { - using NetProvider_t = LibevNetProvider; using Connector_t = Connector>; struct WaitWatcher *waitWatcher = reinterpret_cast *>(watcher->data); assert(&waitWatcher->in == watcher); - Connection conn = waitWatcher->connection; - assert(waitWatcher->in.fd == conn.get_strm().get_fd()); + auto *conn = waitWatcher->connection; + assert(waitWatcher->in.fd == conn->get_strm().get_fd()); - if (conn.get_strm().has_status(SS_NEED_READ_EVENT_FOR_WRITE)) - ev_feed_fd_event(loop, conn.get_strm().get_fd(), EV_WRITE); + if (conn->get_strm().has_status(SS_NEED_READ_EVENT_FOR_WRITE)) + ev_feed_fd_event(loop, conn->get_strm().get_fd(), EV_WRITE); timerDisable(loop, waitWatcher->timer); int rc = connectionReceive(conn); @@ -182,7 +180,7 @@ recv_cb(struct ev_loop *loop, struct ev_io *watcher, int /* revents */) return; if (rc > 0) { /* Recv is not complete, setting the write watcher if needed. */ - if (conn.get_strm().has_status(SS_NEED_WRITE_EVENT_FOR_READ)) + if (conn->get_strm().has_status(SS_NEED_WRITE_EVENT_FOR_READ)) if (!ev_is_active(&waitWatcher->out)) ev_io_start(loop, &waitWatcher->out); return; @@ -191,28 +189,26 @@ recv_cb(struct ev_loop *loop, struct ev_io *watcher, int /* revents */) connector->readyToDecode(conn); } -template +template static inline int -connectionSend(Connection> &conn) +connectionSend(ConnectionImpl> *conn) { - if (conn.getImpl()->is_auth_required && - !conn.getImpl()->is_greeting_received) { + if (conn->is_auth_required && !conn->is_greeting_received) { // Need to receive greeting first. return 0; } while (hasDataToSend(conn)) { struct iovec iov[IOVEC_MAX_SIZE]; - auto &buf = conn.getOutBuf(); + auto &buf = conn->getOutBuf(); size_t iov_cnt = buf.getIOV(buf.template begin(), iov, IOVEC_MAX_SIZE); - ssize_t sent = conn.get_strm().send(iov, iov_cnt); + ssize_t sent = conn->get_strm().send(iov, iov_cnt); if (sent < 0) { - conn.setError(std::string("Failed to send request: ") + - strerror(errno), errno); + conn->setError(std::string("Failed to send request: ") + strerror(errno), errno); return -1; } else if (sent == 0) { - assert(conn.get_strm().has_status(SS_NEED_EVENT_FOR_WRITE)); + assert(conn->get_strm().has_status(SS_NEED_EVENT_FOR_WRITE)); return 1; } else { hasSentBytes(conn, sent); @@ -226,18 +222,17 @@ template static void send_cb(struct ev_loop *loop, struct ev_io *watcher, int /* revents */) { - using NetProvider_t = LibevNetProvider; using Connector_t = Connector>; struct WaitWatcher *waitWatcher = reinterpret_cast *>(watcher->data); assert(&waitWatcher->out == watcher); Connector_t *connector = waitWatcher->connector; - Connection &conn = waitWatcher->connection; - assert(watcher->fd == conn.get_strm().get_fd()); + auto *conn = waitWatcher->connection; + assert(watcher->fd == conn->get_strm().get_fd()); - if (conn.get_strm().has_status(SS_NEED_WRITE_EVENT_FOR_READ)) - ev_feed_fd_event(loop, conn.get_strm().get_fd(), EV_READ); + if (conn->get_strm().has_status(SS_NEED_WRITE_EVENT_FOR_READ)) + ev_feed_fd_event(loop, conn->get_strm().get_fd(), EV_READ); timerDisable(loop, waitWatcher->timer); int rc = connectionSend(conn); @@ -248,7 +243,7 @@ send_cb(struct ev_loop *loop, struct ev_io *watcher, int /* revents */) if (rc > 0) { /* Send is not complete, setting the write watcher. */ LOG_DEBUG("Send is not complete, setting the write watcher"); - if (conn.get_strm().has_status(SS_NEED_WRITE_EVENT_FOR_WRITE)) + if (conn->get_strm().has_status(SS_NEED_WRITE_EVENT_FOR_WRITE)) if (!ev_is_active(&waitWatcher->out)) ev_io_start(loop, &waitWatcher->out); return; @@ -283,7 +278,7 @@ LibevNetProvider::~LibevNetProvider() for (auto w = m_Watchers.begin(); w != m_Watchers.end();) { WaitWatcher *to_delete = w->second; stopWatchers(to_delete); - assert(to_delete->connection.get_strm().get_fd() == w->first); + assert(to_delete->connection->get_strm().get_fd() == w->first); w = m_Watchers.erase(w); delete to_delete; } @@ -294,9 +289,9 @@ LibevNetProvider::~LibevNetProvider() m_Loop = nullptr; } -template +template void -LibevNetProvider::registerWatchers(Conn_t &conn, int fd) +LibevNetProvider::registerWatchers(ConnImpl_t *conn, int fd) { WaitWatcher *watcher = new (std::nothrow) WaitWatcher(&m_Connector, @@ -314,15 +309,13 @@ LibevNetProvider::registerWatchers(Conn_t &conn, int fd) ev_io_start(m_Loop ,&watcher->out); } -template +template int -LibevNetProvider::connect(Conn_t &conn, - const ConnectOptions &opts) +LibevNetProvider::connect(ConnImpl_t *conn, const ConnectOptions &opts) { - auto &strm = conn.get_strm(); + auto &strm = conn->get_strm(); if (strm.connect(opts) < 0) { - conn.setError("Failed to establish connection to " + - opts.address); + conn->setError("Failed to establish connection to " + opts.address); return -1; } LOG_DEBUG("Connected to ", opts.address, ", socket is ", strm.get_fd()); @@ -371,7 +364,7 @@ LibevNetProvider::wait(int timeout) /* Queue pending connections to be send. */ for (auto conn = m_Connector.m_ReadyToSend.begin(); conn != m_Connector.m_ReadyToSend.end();) { - auto w = m_Watchers.find(conn->get_strm().get_fd()); + auto w = m_Watchers.find((*conn)->get_strm().get_fd()); if (w != m_Watchers.end()) { if (!ev_is_active(&w->second->out)) ev_feed_event(m_Loop, &w->second->out, EV_WRITE); diff --git a/src/Client/UnixStream.hpp b/src/Client/UnixStream.hpp index 5e8bc589b..289689c0b 100644 --- a/src/Client/UnixStream.hpp +++ b/src/Client/UnixStream.hpp @@ -66,6 +66,9 @@ class UnixStream : public Stream { /** Get internal file descriptor of the socket. */ int get_fd() const { return fd; } + /** Check whether the socket is open. */ + bool is_open() const { return fd != -1; } + protected: /** Log helpers. */ template @@ -168,6 +171,7 @@ UnixStream::connect(const ConnectOptions &opts_arg) addr_info.last_error()); int socket_errno = 0, connect_errno = 0; for (auto &inf: addr_info) { + socket_errno = connect_errno = 0; fd = ::socket(inf.ai_family, inf.ai_socktype, inf.ai_protocol); if (fd < 0) { socket_errno = errno; @@ -189,10 +193,7 @@ UnixStream::connect(const ConnectOptions &opts_arg) return US_TELL(SS_ESTABLISHED, "Connected", opts); } else if (errno == EINPROGRESS || errno == EAGAIN) { - // TODO remove timeout and #include - //return US_TELL(SS_CONNECT_PENDING, - // "Connect pending", opts); - + errno = 0; set_status(SS_CONNECT_PENDING); struct pollfd fds; fds.fd = fd; @@ -206,7 +207,10 @@ UnixStream::connect(const ConnectOptions &opts_arg) } } while (errno == EINTR); close(); - connect_errno = errno; + if (connect_errno == 0) + connect_errno = errno; + else + assert(connect_errno == ETIMEDOUT && errno == 0); } if (connect_errno != 0) return US_DIE("Failed to connect", strerror(connect_errno)); diff --git a/test/ClientMultithreadTest.cpp b/test/ClientMultithreadTest.cpp index 5c4c068d3..ade6be056 100644 --- a/test/ClientMultithreadTest.cpp +++ b/test/ClientMultithreadTest.cpp @@ -47,6 +47,7 @@ test_connect(Connector &client, Connection &conn, const std::string &addr, unsig .address = addr, .service = service, .transport = transport, + .connect_timeout = 10, .user = user, .passwd = passwd, }); diff --git a/test/ClientTest.cpp b/test/ClientTest.cpp index 41e1424fd..6a8ae828f 100644 --- a/test/ClientTest.cpp +++ b/test/ClientTest.cpp @@ -182,6 +182,8 @@ trivial(Connector &client) TEST_CASE("Connect timeout"); rc = test_connect(client, conn, "8.8.8.8", port); fail_unless(rc != 0); + TEST_CASE("Close of non-established connection (gh-142)"); + client.close(conn); } /** Single connection, separate/sequence pings, no errors */ @@ -235,6 +237,9 @@ single_conn_ping(Connector &client) fail_unless(response->body.error_stack == std::nullopt); } client.close(conn); + + TEST_CASE("Double close of connection (gh-142)"); + client.close(conn); } template @@ -261,6 +266,23 @@ auto_close(Connector &client) std::optional> response = conn.getResponse(f); fail_unless(response != std::nullopt); } + + TEST_CASE("Waiting after connection is automatically closed (gh-140)"); + { + Connection conn(client); + fail_unless(test_connect(client, conn, localhost, port) == 0); + rid_t f = conn.ping(); + fail_unless(!conn.futureIsReady(f)); + } + fail_unless(client.waitAny() == std::nullopt); + { + Connection conn(client); + fail_unless(test_connect(client, conn, localhost, port) == 0); + rid_t f = conn.ping(); + fail_unless(!conn.futureIsReady(f)); + client.wait(conn, f, 0); + } + fail_unless(client.waitAny() == std::nullopt); } /** Several connection, separate/sequence pings, no errors */ @@ -1065,6 +1087,10 @@ test_sigpipe(void) fail_unless(saved_errno == EPIPE); #endif fail_if(conn.futureIsReady(f)); + + TEST_CASE("Close of connection with error (gh-142)"); + fail_unless(conn.hasError()); + client.close(conn); } /** Single connection, wait response from closed connection. */ @@ -1092,11 +1118,9 @@ test_dead_connection_wait(void) fail_if(client.waitCount(conn, 1) == 0); fail_if(conn.futureIsReady(f)); - /* FIXME(gh-51) */ -#if 0 + TEST_CASE("waitAny() correctly handles case when all connections have an error (gh-51"); fail_if(client.waitAny() != std::nullopt); fail_if(conn.futureIsReady(f)); -#endif } /** @@ -1381,13 +1405,10 @@ test_wait(Connector &client) /* FIXME(gh-143): test solely that we check future readiness before waiting. */ fail_unless(client.waitCount(conn, 0) == 0); conn.getResponse(f); - /* FIXME(gh-132): waitAny does not check connections for ready futures. */ -#if 0 f = conn.ping(); fail_unless(client.wait(conn, f, WAIT_TIMEOUT) == 0); - fail_unless(client.waitAny(conn).has_value()); + fail_unless(client.waitAny().has_value()); conn.getResponse(f); -#endif #ifdef __linux__ TEST_CASE("wait methods internal wait failure (gh-121)"); @@ -1425,6 +1446,9 @@ test_wait(Connector &client) #endif /* __linux__ */ client.close(conn); + + TEST_CASE("waitAny() correctly handles case when there are no connections (gh-51"); + fail_if(client.waitAny() != std::nullopt); } int main()