Skip to content

Commit 8d6cc18

Browse files
Allow HTTP Server to bind to a specific address, as opposed to listening on all addresses.
PiperOrigin-RevId: 710784609
1 parent 16cfb2b commit 8d6cc18

File tree

3 files changed

+54
-14
lines changed

3 files changed

+54
-14
lines changed

tensorflow_serving/util/net_http/server/internal/evhttp_server.cc

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -223,14 +223,25 @@ bool EvHTTPServer::StartAcceptingRequests() {
223223

224224
const int port = server_options_->ports().front();
225225

226-
// "::" => in6addr_any
227-
ev_uint16_t ev_port = static_cast<ev_uint16_t>(port);
228-
ev_listener_ = evhttp_bind_socket_with_handle(ev_http_, "::", ev_port);
229-
if (ev_listener_ == nullptr) {
230-
// in case ipv6 is not supported, fallback to inaddr_any
231-
ev_listener_ = evhttp_bind_socket_with_handle(ev_http_, nullptr, ev_port);
226+
if (server_options_->ip_addresses().empty()) {
227+
// "::" => in6addr_any
228+
ev_uint16_t ev_port = static_cast<ev_uint16_t>(port);
229+
ev_listener_ = evhttp_bind_socket_with_handle(ev_http_, "::", ev_port);
232230
if (ev_listener_ == nullptr) {
233-
NET_LOG(ERROR, "Couldn't bind to port %d", port);
231+
// in case ipv6 is not supported, fallback to inaddr_any
232+
ev_listener_ = evhttp_bind_socket_with_handle(ev_http_, nullptr, ev_port);
233+
if (ev_listener_ == nullptr) {
234+
NET_LOG(ERROR, "Couldn't bind to port %d", port);
235+
return false;
236+
}
237+
}
238+
} else {
239+
const std::string& ip_address = server_options_->ip_addresses().front();
240+
ev_listener_ =
241+
evhttp_bind_socket_with_handle(ev_http_, ip_address.c_str(), port);
242+
if (ev_listener_ == nullptr) {
243+
NET_LOG(ERROR, "Couldn't bind address %s to port %d", ip_address.c_str(),
244+
port);
234245
return false;
235246
}
236247
}

tensorflow_serving/util/net_http/server/internal/evhttp_server_test.cc

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,18 @@ class EvHTTPServerTest : public ::testing::Test {
5757
}
5858

5959
protected:
60-
std::unique_ptr<HTTPServerInterface> server;
61-
62-
private:
63-
void InitServer() {
60+
virtual std::unique_ptr<ServerOptions> GetOptions() {
6461
auto options = absl::make_unique<ServerOptions>();
6562
options->AddPort(0);
6663
options->SetExecutor(absl::make_unique<MyExecutor>(4));
64+
return options;
65+
}
6766

68-
server = CreateEvHTTPServer(std::move(options));
67+
std::unique_ptr<HTTPServerInterface> server;
6968

69+
private:
70+
void InitServer() {
71+
server = CreateEvHTTPServer(GetOptions());
7072
ASSERT_TRUE(server != nullptr);
7173
}
7274
};
@@ -339,7 +341,27 @@ TEST_F(EvHTTPServerTest, ActiveRequestCountInShutdown) {
339341
// response.status etc are undefined as the server is terminated
340342
}
341343

344+
class EvHTTPServerTestWithAddress : public EvHTTPServerTest {
345+
protected:
346+
std::unique_ptr<ServerOptions> GetOptions() override {
347+
auto options = EvHTTPServerTest::GetOptions();
348+
options->AddIPAddress("::1");
349+
return options;
350+
}
351+
};
352+
353+
TEST_F(EvHTTPServerTestWithAddress, BasicListen) {
354+
server->StartAcceptingRequests();
355+
356+
EXPECT_TRUE(server->is_accepting_requests());
357+
EXPECT_NE(server->listen_port(), 0);
358+
359+
server->Terminate();
360+
server->WaitForTermination();
361+
}
362+
342363
} // namespace
364+
343365
} // namespace net_http
344366
} // namespace serving
345367
} // namespace tensorflow

tensorflow_serving/util/net_http/server/public/httpserver_interface.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@ limitations under the License.
1919
#define TENSORFLOW_SERVING_UTIL_NET_HTTP_SERVER_PUBLIC_HTTPSERVER_INTERFACE_H_
2020

2121
#include <cassert>
22-
2322
#include <functional>
2423
#include <memory>
24+
#include <string>
2525
#include <vector>
2626

2727
#include "absl/strings/string_view.h"
2828
#include "absl/time/time.h"
29-
3029
#include "tensorflow_serving/util/net_http/server/public/server_request_interface.h"
3130

3231
namespace tensorflow {
@@ -62,18 +61,26 @@ class ServerOptions {
6261
ports_.emplace_back(port);
6362
}
6463

64+
// Add an IP address to listen on. If not specified, the server will listen
65+
// on all addresses.
66+
void AddIPAddress(absl::string_view ip_address) {
67+
ip_addresses_.emplace_back(ip_address);
68+
}
69+
6570
// The default executor for running I/O event polling.
6671
// This is a mandatory option.
6772
void SetExecutor(std::unique_ptr<EventExecutor> executor) {
6873
executor_ = std::move(executor);
6974
}
7075

76+
const std::vector<std::string>& ip_addresses() const { return ip_addresses_; }
7177
const std::vector<int>& ports() const { return ports_; }
7278

7379
EventExecutor* executor() const { return executor_.get(); }
7480

7581
private:
7682
std::vector<int> ports_;
83+
std::vector<std::string> ip_addresses_;
7784
std::unique_ptr<EventExecutor> executor_;
7885
};
7986

0 commit comments

Comments
 (0)