Skip to content

Commit 2637d28

Browse files
authored
Merge pull request #5299 from sysown/v3.0_pg-cancel-terminate-backend-param-support_5298
Add parameterized PID support for pg_cancel_backend/pg_terminate_backend in extended query protocol
2 parents 3bcb9e0 + 67cbe46 commit 2637d28

File tree

5 files changed

+860
-48
lines changed

5 files changed

+860
-48
lines changed

include/PgSQL_Session.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class PgSQL_Describe_Message;
2020
class PgSQL_Close_Message;
2121
class PgSQL_Bind_Message;
2222
class PgSQL_Execute_Message;
23+
struct PgSQL_Param_Value;
2324

2425
#ifndef PROXYJSON
2526
#define PROXYJSON
@@ -580,6 +581,7 @@ class PgSQL_Session : public Base_Session<PgSQL_Session, PgSQL_Data_Stream, PgSQ
580581
void Memory_Stats();
581582
void create_new_session_and_reset_connection(PgSQL_Data_Stream* _myds) override;
582583
bool handle_command_query_kill(PtrSize_t*);
584+
583585
//void update_expired_conns(const std::vector<std::function<bool(PgSQL_Connection*)>>&);
584586
/**
585587
* @brief Performs the final operations after current query has finished to be executed. It updates the session
@@ -603,6 +605,12 @@ class PgSQL_Session : public Base_Session<PgSQL_Session, PgSQL_Data_Stream, PgSQ
603605
void set_previous_status_mode3(bool allow_execute = true);
604606
char* get_current_query(int max_length = -1);
605607

608+
private:
609+
int32_t extract_pid_from_param(const PgSQL_Param_Value& param, uint16_t format) const;
610+
void send_parameter_error_response(const char* error_message, PGSQL_ERROR_CODES code = PGSQL_ERROR_CODES::ERRCODE_INVALID_TEXT_REPRESENTATION);
611+
bool handle_kill_success(int32_t pid, int tki, const char* digest_text, PgSQL_Connection* mc, PtrSize_t* pkt);
612+
bool handle_literal_kill_query(PtrSize_t* pkt, PgSQL_Connection* mc);
613+
606614
#if defined(__clang__)
607615
template<typename SESS, typename DS, typename BE, typename THD>
608616
friend class Base_Session;

include/gen_utils.h

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,31 @@ inline T overflow_safe_multiply(T val) {
436436
return (val * FACTOR);
437437
}
438438

439+
/**
440+
* @brief Read a 64-bit unsigned integer from a big-endian byte buffer.
441+
*
442+
* Reads 8 bytes from the provided buffer and converts them from
443+
* big-endian (network byte order) into host byte order.
444+
*
445+
* @param pkt Pointer to at least 8 bytes of input data.
446+
* @param dst_p Pointer to the destination uint64_t where the result
447+
* will be stored.
448+
*
449+
* @return true Always returns true.
450+
*/
451+
inline bool get_uint64be(const unsigned char* pkt, uint64_t* dst_p) {
452+
*dst_p =
453+
((uint64_t)pkt[0] << 56) |
454+
((uint64_t)pkt[1] << 48) |
455+
((uint64_t)pkt[2] << 40) |
456+
((uint64_t)pkt[3] << 32) |
457+
((uint64_t)pkt[4] << 24) |
458+
((uint64_t)pkt[5] << 16) |
459+
((uint64_t)pkt[6] << 8) |
460+
((uint64_t)pkt[7]);
461+
return true;
462+
}
463+
439464
/*
440465
* @brief Reads and converts a big endian 32-bit unsigned integer from the provided packet buffer into the destination pointer.
441466
*
@@ -448,9 +473,9 @@ inline T overflow_safe_multiply(T val) {
448473
*/
449474
inline bool get_uint32be(const unsigned char* pkt, uint32_t* dst_p) {
450475
*dst_p = ((uint32_t)pkt[0] << 24) |
451-
((uint32_t)pkt[1] << 16) |
452-
((uint32_t)pkt[2] << 8) |
453-
((uint32_t)pkt[3]);
476+
((uint32_t)pkt[1] << 16) |
477+
((uint32_t)pkt[2] << 8) |
478+
((uint32_t)pkt[3]);
454479
return true;
455480
}
456481

lib/PgSQL_Session.cpp

Lines changed: 249 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4435,11 +4435,10 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___PGSQL_Q
44354435
}
44364436

44374437
// Handle KILL command
4438-
//if (prepared == false) {
44394438
if (handle_command_query_kill(pkt)) {
44404439
return true;
44414440
}
4442-
//
4441+
44434442
// Query cache handling
44444443
if (qpo->cache_ttl > 0 && stmt_type == PGSQL_EXTENDED_QUERY_TYPE_NOT_SET) {
44454444
const std::shared_ptr<PgSQL_QC_entry_t> pgsql_qc_entry = GloPgQC->get(
@@ -5181,55 +5180,249 @@ bool PgSQL_Session::handle_command_query_kill(PtrSize_t* pkt) {
51815180
if (!CurrentQuery.QueryParserArgs.digest_text)
51825181
return false;
51835182

5184-
if (client_myds && client_myds->myconn) {
5185-
PgSQL_Connection* mc = client_myds->myconn;
5186-
if (mc->userinfo && mc->userinfo->username) {
5187-
if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_CANCEL_BACKEND ||
5188-
CurrentQuery.PgQueryCmd == PGSQL_QUERY_TERMINATE_BACKEND) {
5189-
char* qu = pgsql_query_strip_comments((char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength,
5190-
pgsql_thread___query_digests_lowercase);
5191-
string nq = string(qu, strlen(qu));
5192-
re2::RE2::Options* opt2 = new re2::RE2::Options(RE2::Quiet);
5193-
opt2->set_case_sensitive(false);
5194-
char* pattern = (char*)"^SELECT\\s+(?:pg_catalog\\.)?PG_(TERMINATE|CANCEL)_BACKEND\\s*\\(\\s*(\\d+)\\s*\\)\\s*;?\\s*$";
5195-
re2::RE2* re = new RE2(pattern, *opt2);
5196-
string tk;
5197-
int id = 0;
5198-
RE2::FullMatch(nq, *re, &tk, &id);
5199-
delete re;
5200-
delete opt2;
5201-
proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "filtered query= \"%s\"\n", qu);
5202-
free(qu);
5203-
5204-
if (id) {
5205-
int tki = -1;
5206-
// Note: tk will capture "TERMINATE" or "CANCEL" (case insensitive match)
5207-
if (strcasecmp(tk.c_str(), "TERMINATE") == 0) {
5208-
tki = 0; // Connection terminate
5209-
} else if (strcasecmp(tk.c_str(), "CANCEL") == 0) {
5210-
tki = 1; // Query cancel
5211-
}
5212-
if (tki >= 0) {
5213-
proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "Killing %s %d\n", (tki == 0 ? "CONNECTION" : "QUERY"), id);
5214-
GloPTH->kill_connection_or_query(id, 0, mc->userinfo->username, (tki == 0 ? false : true));
5215-
client_myds->DSS = STATE_QUERY_SENT_NET;
5216-
5217-
std::unique_ptr<SQLite3_result> resultset = std::make_unique<SQLite3_result>(1);
5218-
resultset->add_column_definition(SQLITE_TEXT, tki == 0 ? "pg_terminate_backend" : "pg_cancel_backend");
5219-
char* pta[1];
5220-
pta[0] = (char*)"t";
5221-
resultset->add_row(pta);
5222-
bool send_ready_packet = is_extended_query_ready_for_query();
5223-
unsigned int nTxn = NumActiveTransactions();
5224-
char txn_state = (nTxn ? 'T' : 'I');
5225-
SQLite3_to_Postgres(client_myds->PSarrayOUT, resultset.get(), nullptr, 0, (const char*)pkt->ptr + 5, send_ready_packet, txn_state);
5183+
if (!client_myds ||
5184+
!client_myds->myconn ||
5185+
!client_myds->myconn->userinfo ||
5186+
!client_myds->myconn->userinfo->username) {
5187+
return false;
5188+
}
52265189

5227-
RequestEnd(NULL, false);
5190+
PgSQL_Connection* mc = client_myds->myconn;
5191+
if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_CANCEL_BACKEND ||
5192+
CurrentQuery.PgQueryCmd == PGSQL_QUERY_TERMINATE_BACKEND) {
5193+
5194+
if (cmd == 'Q') {
5195+
// Simple query protocol - only handle literal values
5196+
// Parameterized queries in simple protocol are invalid and will be handled by PostgreSQL
5197+
return handle_literal_kill_query(pkt, mc);
5198+
} else {
5199+
// cmd == 'E' - Execute phase of extended query protocol
5200+
// Check if this is a parameterized query (contains $1)
5201+
// Note: This simple check might have false positives if $1 appears in comments or string literals
5202+
// but those cases would fail later when checking bind_msg or parameter validation
5203+
const char* digest_text = CurrentQuery.QueryParserArgs.digest_text;
5204+
5205+
// Use protocol facts (Bind)
5206+
const PgSQL_Bind_Message* bind_msg = CurrentQuery.extended_query_info.bind_msg;
5207+
const bool is_parameterized = bind_msg && bind_msg->data().num_param_values > 0;
5208+
if (is_parameterized) {
5209+
// Check that we have exactly one parameter
5210+
if (bind_msg->data().num_param_values != 1) {
5211+
send_parameter_error_response("function requires exactly one parameter");
5212+
l_free(pkt->size, pkt->ptr);
5213+
return true;
5214+
}
5215+
auto param_reader = bind_msg->get_param_value_reader();
5216+
PgSQL_Param_Value param;
5217+
if (param_reader.next(&param)) {
5218+
// Get parameter format (default to text format 0)
5219+
uint16_t param_format = 0;
5220+
if (bind_msg->data().num_param_formats == 1) {
5221+
// Single format applies to all parameters
5222+
auto format_reader = bind_msg->get_param_format_reader();
5223+
format_reader.next(&param_format);
5224+
}
5225+
5226+
// Extract PID from parameter
5227+
int32_t pid = extract_pid_from_param(param, param_format);
5228+
if (pid > 0) {
5229+
// Determine if this is terminate or cancel
5230+
int tki = -1;
5231+
if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_TERMINATE_BACKEND) {
5232+
tki = 0; // Connection terminate
5233+
} else if (CurrentQuery.PgQueryCmd == PGSQL_QUERY_CANCEL_BACKEND) {
5234+
tki = 1; // Query cancel
5235+
}
5236+
5237+
if (tki >= 0) {
5238+
return handle_kill_success(pid, tki, digest_text, mc, pkt);
5239+
}
5240+
} else {
5241+
// Invalid parameter - send appropriate error response
5242+
if (pid == -2) {
5243+
// NULL parameter
5244+
send_parameter_error_response("NULL is not allowed", PGSQL_ERROR_CODES::ERRCODE_NULL_VALUE_NOT_ALLOWED);
5245+
} else if (pid == -1) {
5246+
// Invalid format (not a valid integer)
5247+
send_parameter_error_response("invalid input syntax for integer", PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE);
5248+
} else if (pid == 0) {
5249+
// PID <= 0 (non-positive)
5250+
send_parameter_error_response("PID must be a positive integer", PGSQL_ERROR_CODES::ERRCODE_INVALID_PARAMETER_VALUE);
5251+
}
52285252
l_free(pkt->size, pkt->ptr);
52295253
return true;
52305254
}
5255+
} else {
5256+
// No parameter available - this shouldn't happen
5257+
return false;
52315258
}
5259+
} else {
5260+
// Literal query in extended protocol
5261+
return handle_literal_kill_query(pkt, mc);
5262+
}
5263+
}
5264+
}
5265+
5266+
return false;
5267+
}
5268+
5269+
int32_t PgSQL_Session::extract_pid_from_param(const PgSQL_Param_Value& param, uint16_t format) const {
5270+
5271+
if (param.len == -1) {
5272+
// NULL parameter
5273+
return -2; // Special value for NULL
5274+
}
5275+
5276+
/* ---------------- TEXT FORMAT ---------------- */
5277+
if (format == 0) {
5278+
// Text format
5279+
if (param.len == 0) {
5280+
// Empty string
5281+
return -1;
5282+
}
5283+
5284+
// Convert text to integer
5285+
std::string str_val(reinterpret_cast<const char*>(param.value), param.len);
5286+
5287+
// Parse the integer (allow leading +/- and whitespace, then validate semantics)
5288+
char* endptr;
5289+
errno = 0;
5290+
long pid = strtol(str_val.c_str(), &endptr, 10);
5291+
5292+
// Require full consumption (ignoring trailing whitespace)
5293+
while (endptr && *endptr && isspace(static_cast<unsigned char>(*endptr))) endptr++;
5294+
if (endptr == str_val.c_str() || (endptr && *endptr) || errno == ERANGE) {
5295+
return -1;
5296+
}
5297+
5298+
// Check valid range
5299+
if (pid <= 0) {
5300+
return 0; // Special value for non-positive
5301+
}
5302+
if (pid > INT_MAX) {
5303+
return -1; // Out of range
5304+
}
5305+
5306+
return static_cast<int32_t>(pid);
5307+
}
5308+
5309+
/* ---------------- BINARY FORMAT ---------------- */
5310+
// PostgreSQL sends int4 or int8 for integer parameters
5311+
if (format == 1) { // Binary format (format == 1)
5312+
5313+
if (param.len == 4) {
5314+
// uint32 in network byte order
5315+
uint32_t host_u32;
5316+
get_uint32be(reinterpret_cast<const unsigned char*>(param.value), &host_u32);
5317+
if (host_u32 & 0x80000000u) { // negative int4
5318+
return 0;
5319+
}
5320+
int32_t pid = static_cast<int32_t>(host_u32);
5321+
return pid;
5322+
}
5323+
5324+
if (param.len == 8) {
5325+
// int64 in network byte order (PostgreSQL sends int8 for some integer types)
5326+
uint64_t host_u64 = 0;
5327+
get_uint64be(reinterpret_cast<const unsigned char*>(param.value), &host_u64);
5328+
if (host_u64 & 0x8000000000000000ull) { // negative int8
5329+
return 0;
5330+
}
5331+
if (host_u64 > static_cast<uint64_t>(INT32_MAX)) {
5332+
return -1; // out of range for PID
52325333
}
5334+
int64_t pid = static_cast<int64_t>(host_u64);
5335+
return static_cast<int32_t>(pid);
5336+
}
5337+
5338+
// Invalid integer width for Bind
5339+
return -1;
5340+
}
5341+
5342+
char buf[INET6_ADDRSTRLEN];
5343+
switch (client_myds->client_addr->sa_family) {
5344+
case AF_INET: {
5345+
struct sockaddr_in* ipv4 = (struct sockaddr_in*)client_myds->client_addr;
5346+
inet_ntop(client_myds->client_addr->sa_family, &ipv4->sin_addr, buf, INET_ADDRSTRLEN);
5347+
break;
5348+
}
5349+
case AF_INET6: {
5350+
struct sockaddr_in6* ipv6 = (struct sockaddr_in6*)client_myds->client_addr;
5351+
inet_ntop(client_myds->client_addr->sa_family, &ipv6->sin6_addr, buf, INET6_ADDRSTRLEN);
5352+
break;
5353+
}
5354+
default:
5355+
sprintf(buf, "localhost");
5356+
break;
5357+
}
5358+
// Unknown format code
5359+
proxy_error("Unknown parameter format code: %u received from client %s:%d", format, buf, client_myds->addr.port);
5360+
return -1;
5361+
}
5362+
5363+
void PgSQL_Session::send_parameter_error_response(const char* error_message, PGSQL_ERROR_CODES error_code) {
5364+
if (!client_myds) return;
5365+
5366+
// Create proper PostgreSQL error message
5367+
std::string full_error = std::string("invalid input syntax for integer: \"") +
5368+
(error_message ? error_message : "parameter error") + "\"";
5369+
client_myds->setDSS_STATE_QUERY_SENT_NET();
5370+
// Generate and send error packet using PostgreSQL protocol
5371+
client_myds->myprot.generate_error_packet(true, is_extended_query_ready_for_query(),
5372+
full_error.c_str(), error_code, false, true);
5373+
5374+
RequestEnd(NULL, true);
5375+
}
5376+
5377+
bool PgSQL_Session::handle_kill_success(int32_t pid, int tki, const char* digest_text, PgSQL_Connection* mc, PtrSize_t* pkt) {
5378+
5379+
proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "Killing %s %d\n",
5380+
(tki == 0 ? "CONNECTION" : "QUERY"), pid);
5381+
GloPTH->kill_connection_or_query(pid, 0, mc->userinfo->username, (tki == 0 ? false : true));
5382+
client_myds->DSS = STATE_QUERY_SENT_NET;
5383+
5384+
std::unique_ptr<SQLite3_result> resultset = std::make_unique<SQLite3_result>(1);
5385+
resultset->add_column_definition(SQLITE_TEXT, tki == 0 ? "pg_terminate_backend" : "pg_cancel_backend");
5386+
char* pta[1];
5387+
pta[0] = (char*)"t";
5388+
resultset->add_row(pta);
5389+
bool send_ready_packet = is_extended_query_ready_for_query();
5390+
unsigned int nTxn = NumActiveTransactions();
5391+
char txn_state = (nTxn ? 'T' : 'I');
5392+
SQLite3_to_Postgres(client_myds->PSarrayOUT, resultset.get(), nullptr, 0, digest_text, send_ready_packet, txn_state);
5393+
5394+
RequestEnd(NULL, false);
5395+
l_free(pkt->size, pkt->ptr);
5396+
return true;
5397+
}
5398+
5399+
bool PgSQL_Session::handle_literal_kill_query(PtrSize_t* pkt, PgSQL_Connection* mc) {
5400+
// Handle literal query (original implementation)
5401+
char* qu = pgsql_query_strip_comments((char*)CurrentQuery.QueryPointer, CurrentQuery.QueryLength,
5402+
pgsql_thread___query_digests_lowercase);
5403+
std::string nq(qu);
5404+
5405+
re2::RE2::Options opt2(RE2::Quiet);
5406+
opt2.set_case_sensitive(false);
5407+
const char* pattern = "^SELECT\\s+(?:pg_catalog\\.)?PG_(TERMINATE|CANCEL)_BACKEND\\s*\\(\\s*(\\d+)\\s*\\)\\s*;?\\s*$";
5408+
re2::RE2 re(pattern, opt2);
5409+
std::string tk;
5410+
uint32_t id = 0;
5411+
RE2::FullMatch(nq, re, &tk, &id);
5412+
5413+
proxy_debug(PROXY_DEBUG_MYSQL_QUERY_PROCESSOR, 2, "filtered query= \"%s\"\n", qu);
5414+
free(qu);
5415+
5416+
if (id > 0) {
5417+
int tki = -1;
5418+
// Note: tk will capture "TERMINATE" or "CANCEL" (case insensitive match)
5419+
if (strcasecmp(tk.c_str(), "TERMINATE") == 0) {
5420+
tki = 0; // Connection terminate
5421+
} else if (strcasecmp(tk.c_str(), "CANCEL") == 0) {
5422+
tki = 1; // Query cancel
5423+
}
5424+
if (tki >= 0) {
5425+
return handle_kill_success(id, tki, CurrentQuery.QueryParserArgs.digest_text, mc, pkt);
52335426
}
52345427
}
52355428
return false;
@@ -6134,6 +6327,17 @@ int PgSQL_Session::handle_post_sync_execute_message(PgSQL_Execute_Message* execu
61346327
// if we are here, it means we have handled the special command
61356328
return 0;
61366329
}
6330+
6331+
PGSQL_QUERY_command pg_query_cmd = extended_query_info.stmt_info->PgQueryCmd;
6332+
if (pg_query_cmd == PGSQL_QUERY_CANCEL_BACKEND ||
6333+
pg_query_cmd == PGSQL_QUERY_TERMINATE_BACKEND) {
6334+
CurrentQuery.PgQueryCmd = pg_query_cmd;
6335+
auto execute_pkt = execute_msg->get_raw_pkt(); // detach the packet from the describe message
6336+
if (handle_command_query_kill(&execute_pkt)) {
6337+
execute_msg->detach(); // detach the packet from the execute message
6338+
return 0;
6339+
}
6340+
}
61376341
}
61386342
current_hostgroup = previous_hostgroup; // reset current hostgroup to previous hostgroup
61396343
proxy_debug(PROXY_DEBUG_MYSQL_COM, 5, "Session=%p client_myds=%p. Using previous hostgroup '%d'\n",

0 commit comments

Comments
 (0)