diff --git a/httplib.h b/httplib.h index 25e1b2f1ab..224689acf9 100644 --- a/httplib.h +++ b/httplib.h @@ -1546,7 +1546,7 @@ class ClientImpl { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT void set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path = std::string()); - void set_ca_cert_store(X509_STORE *ca_cert_store); + virtual void set_ca_cert_store(X509_STORE *ca_cert_store); X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; #endif @@ -8984,7 +8984,9 @@ inline bool ClientImpl::create_redirect_client( } // Handle CA certificate store and paths if available - if (ca_cert_store_) { redirect_client.set_ca_cert_store(ca_cert_store_); } + if (ca_cert_store_ && X509_STORE_up_ref(ca_cert_store_)) { + redirect_client.set_ca_cert_store(ca_cert_store_); + } if (!ca_cert_file_path_.empty()) { redirect_client.set_ca_cert_path(ca_cert_file_path_, ca_cert_dir_path_); } @@ -10871,6 +10873,7 @@ inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) { // Free memory allocated for old cert and use new store `ca_cert_store` SSL_CTX_set_cert_store(ctx_, ca_cert_store); + ca_cert_store_ = ca_cert_store; } } else { X509_STORE_free(ca_cert_store); @@ -11857,7 +11860,7 @@ inline void Client::set_ca_cert_path(const std::string &ca_cert_file_path, inline void Client::set_ca_cert_store(X509_STORE *ca_cert_store) { if (is_ssl_) { - static_cast(*cli_).set_ca_cert_store(ca_cert_store); + dynamic_cast(*cli_).set_ca_cert_store(ca_cert_store); } else { cli_->set_ca_cert_store(ca_cert_store); } @@ -11869,13 +11872,13 @@ inline void Client::load_ca_cert_store(const char *ca_cert, std::size_t size) { inline long Client::get_openssl_verify_result() const { if (is_ssl_) { - return static_cast(*cli_).get_openssl_verify_result(); + return dynamic_cast(*cli_).get_openssl_verify_result(); } return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? } inline SSL_CTX *Client::ssl_context() const { - if (is_ssl_) { return static_cast(*cli_).ssl_context(); } + if (is_ssl_) { return dynamic_cast(*cli_).ssl_context(); } return nullptr; } #endif diff --git a/test/test.cc b/test/test.cc index 0e4fd42d58..3a2bd4791c 100644 --- a/test/test.cc +++ b/test/test.cc @@ -8963,6 +8963,45 @@ TEST(HttpToHttpsRedirectTest, CertFile) { ASSERT_EQ(StatusCode::OK_200, res->status); } +TEST(SSLClientRedirectTest, CertFile) { + SSLServer ssl_svr1(SERVER_CERT2_FILE, SERVER_PRIVATE_KEY_FILE); + ASSERT_TRUE(ssl_svr1.is_valid()); + ssl_svr1.Get("/index", [&](const Request &, Response &res) { + res.set_redirect("https://127.0.0.1:1235/index"); + ssl_svr1.stop(); + }); + + SSLServer ssl_svr2(SERVER_CERT2_FILE, SERVER_PRIVATE_KEY_FILE); + ASSERT_TRUE(ssl_svr2.is_valid()); + ssl_svr2.Get("/index", [&](const Request &, Response &res) { + res.set_content("test", "text/plain"); + ssl_svr2.stop(); + }); + + thread t = thread([&]() { ASSERT_TRUE(ssl_svr1.listen("127.0.0.1", PORT)); }); + thread t2 = thread([&]() { ASSERT_TRUE(ssl_svr2.listen("127.0.0.1", 1235)); }); + auto se = detail::scope_exit([&] { + t2.join(); + t.join(); + ASSERT_FALSE(ssl_svr1.is_running()); + }); + + ssl_svr1.wait_until_ready(); + ssl_svr2.wait_until_ready(); + + SSLClient cli("127.0.0.1", PORT); + std::string cert; + read_file(SERVER_CERT2_FILE, cert); + cli.load_ca_cert_store(cert.c_str(), cert.size()); + cli.enable_server_certificate_verification(true); + cli.set_follow_location(true); + cli.set_connection_timeout(30); + + auto res = cli.Get("/index"); + ASSERT_TRUE(res); + ASSERT_EQ(StatusCode::OK_200, res->status); +} + TEST(MultipartFormDataTest, LargeData) { SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);