@@ -36,6 +36,7 @@ using json = nlohmann::json;
3636struct server_params
3737{
3838 std::string hostname = " 127.0.0.1" ;
39+ std::string api_key;
3940 std::string public_path = " examples/server/public" ;
4041 int32_t port = 8080 ;
4142 int32_t read_timeout = 600 ;
@@ -1953,6 +1954,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
19531954 printf (" --host ip address to listen (default (default: %s)\n " , sparams.hostname .c_str ());
19541955 printf (" --port PORT port to listen (default (default: %d)\n " , sparams.port );
19551956 printf (" --path PUBLIC_PATH path from which to serve static files (default %s)\n " , sparams.public_path .c_str ());
1957+ printf (" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n " );
19561958 printf (" -to N, --timeout N server read/write timeout in seconds (default: %d)\n " , sparams.read_timeout );
19571959 printf (" --embedding enable embedding vector output (default: %s)\n " , params.embedding ? " enabled" : " disabled" );
19581960 printf (" -np N, --parallel N number of slots for process requests (default: %d)\n " , params.n_parallel );
@@ -2002,6 +2004,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
20022004 }
20032005 sparams.public_path = argv[i];
20042006 }
2007+ else if (arg == " --api-key" )
2008+ {
2009+ if (++i >= argc)
2010+ {
2011+ invalid_param = true ;
2012+ break ;
2013+ }
2014+ sparams.api_key = argv[i];
2015+ }
20052016 else if (arg == " --timeout" || arg == " -to" )
20062017 {
20072018 if (++i >= argc)
@@ -2669,6 +2680,32 @@ int main(int argc, char **argv)
26692680
26702681 httplib::Server svr;
26712682
2683+ // Middleware for API key validation
2684+ auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
2685+ // If API key is not set, skip validation
2686+ if (sparams.api_key .empty ()) {
2687+ return true ;
2688+ }
2689+
2690+ // Check for API key in the header
2691+ auto auth_header = req.get_header_value (" Authorization" );
2692+ std::string prefix = " Bearer " ;
2693+ if (auth_header.substr (0 , prefix.size ()) == prefix) {
2694+ std::string received_api_key = auth_header.substr (prefix.size ());
2695+ if (received_api_key == sparams.api_key ) {
2696+ return true ; // API key is valid
2697+ }
2698+ }
2699+
2700+ // API key is invalid or not provided
2701+ res.set_content (" Unauthorized: Invalid API Key" , " text/plain" );
2702+ res.status = 401 ; // Unauthorized
2703+
2704+ LOG_WARNING (" Unauthorized: Invalid API Key" , {});
2705+
2706+ return false ;
2707+ };
2708+
26722709 svr.set_default_headers ({{" Server" , " llama.cpp" },
26732710 {" Access-Control-Allow-Origin" , " *" },
26742711 {" Access-Control-Allow-Headers" , " content-type" }});
@@ -2711,8 +2748,11 @@ int main(int argc, char **argv)
27112748 res.set_content (data.dump (), " application/json" );
27122749 });
27132750
2714- svr.Post (" /completion" , [&llama](const httplib::Request &req, httplib::Response &res)
2751+ svr.Post (" /completion" , [&llama, &validate_api_key ](const httplib::Request &req, httplib::Response &res)
27152752 {
2753+ if (!validate_api_key (req, res)) {
2754+ return ;
2755+ }
27162756 json data = json::parse (req.body );
27172757 const int task_id = llama.request_completion (data, false , false , -1 );
27182758 if (!json_value (data, " stream" , false )) {
@@ -2799,8 +2839,11 @@ int main(int argc, char **argv)
27992839 });
28002840
28012841 // TODO: add mount point without "/v1" prefix -- how?
2802- svr.Post (" /v1/chat/completions" , [&llama](const httplib::Request &req, httplib::Response &res)
2842+ svr.Post (" /v1/chat/completions" , [&llama, &validate_api_key ](const httplib::Request &req, httplib::Response &res)
28032843 {
2844+ if (!validate_api_key (req, res)) {
2845+ return ;
2846+ }
28042847 json data = oaicompat_completion_params_parse (json::parse (req.body ));
28052848
28062849 const int task_id = llama.request_completion (data, false , false , -1 );
@@ -2869,8 +2912,11 @@ int main(int argc, char **argv)
28692912 }
28702913 });
28712914
2872- svr.Post (" /infill" , [&llama](const httplib::Request &req, httplib::Response &res)
2915+ svr.Post (" /infill" , [&llama, &validate_api_key ](const httplib::Request &req, httplib::Response &res)
28732916 {
2917+ if (!validate_api_key (req, res)) {
2918+ return ;
2919+ }
28742920 json data = json::parse (req.body );
28752921 const int task_id = llama.request_completion (data, true , false , -1 );
28762922 if (!json_value (data, " stream" , false )) {
@@ -3005,11 +3051,15 @@ int main(int argc, char **argv)
30053051
30063052 svr.set_error_handler ([](const httplib::Request &, httplib::Response &res)
30073053 {
3054+ if (res.status == 401 )
3055+ {
3056+ res.set_content (" Unauthorized" , " text/plain" );
3057+ }
30083058 if (res.status == 400 )
30093059 {
30103060 res.set_content (" Invalid request" , " text/plain" );
30113061 }
3012- else if (res.status != 500 )
3062+ else if (res.status == 404 )
30133063 {
30143064 res.set_content (" File Not Found" , " text/plain" );
30153065 res.status = 404 ;
@@ -3032,11 +3082,15 @@ int main(int argc, char **argv)
30323082 // to make it ctrl+clickable:
30333083 LOG_TEE (" \n llama server listening at http://%s:%d\n\n " , sparams.hostname .c_str (), sparams.port );
30343084
3035- LOG_INFO (" HTTP server listening" , {
3036- {" hostname" , sparams.hostname },
3037- {" port" , sparams.port },
3038- });
3085+ std::unordered_map<std::string, std::string> log_data;
3086+ log_data[" hostname" ] = sparams.hostname ;
3087+ log_data[" port" ] = std::to_string (sparams.port );
3088+
3089+ if (!sparams.api_key .empty ()) {
3090+ log_data[" api_key" ] = " api_key: ****" + sparams.api_key .substr (sparams.api_key .length () - 4 );
3091+ }
30393092
3093+ LOG_INFO (" HTTP server listening" , log_data);
30403094 // run the HTTP server in a thread - see comment below
30413095 std::thread t ([&]()
30423096 {
0 commit comments