|
1 | | -use std::{borrow::Cow, collections::HashMap}; |
| 1 | +use std::{borrow::Cow, collections::HashMap, str::FromStr}; |
2 | 2 |
|
3 | 3 | use actix_web::http::StatusCode; |
4 | 4 | use actix_web_httpauth::headers::authorization::Basic; |
| 5 | +use awc::http::Method; |
5 | 6 | use base64::Engine; |
6 | 7 | use mime_guess::{mime::APPLICATION_OCTET_STREAM, Mime}; |
7 | 8 | use sqlparser::ast::FunctionArg; |
@@ -51,6 +52,7 @@ pub(super) enum StmtParam { |
51 | 52 | ReadFileAsText(Box<StmtParam>), |
52 | 53 | ReadFileAsDataUrl(Box<StmtParam>), |
53 | 54 | RunSql(Box<StmtParam>), |
| 55 | + Fetch(Box<StmtParam>), |
54 | 56 | Path, |
55 | 57 | Protocol, |
56 | 58 | } |
@@ -140,6 +142,7 @@ pub(super) fn func_call_to_param(func_name: &str, arguments: &mut [FunctionArg]) |
140 | 142 | extract_variable_argument("read_file_as_data_url", arguments), |
141 | 143 | )), |
142 | 144 | "run_sql" => StmtParam::RunSql(Box::new(extract_variable_argument("run_sql", arguments))), |
| 145 | + "fetch" => StmtParam::Fetch(Box::new(extract_variable_argument("fetch", arguments))), |
143 | 146 | unknown_name => StmtParam::Error(format!( |
144 | 147 | "Unknown function {unknown_name}({})", |
145 | 148 | FormatArguments(arguments) |
@@ -389,6 +392,90 @@ async fn run_sql<'a>( |
389 | 392 | Ok(Some(Cow::Owned(String::from_utf8(json_results_bytes)?))) |
390 | 393 | } |
391 | 394 |
|
| 395 | +type HeaderVec<'a> = Vec<(Cow<'a, str>, Cow<'a, str>)>; |
| 396 | +#[derive(serde::Deserialize)] |
| 397 | +struct Req<'b> { |
| 398 | + #[serde(borrow)] |
| 399 | + url: Cow<'b, str>, |
| 400 | + #[serde(borrow)] |
| 401 | + method: Option<Cow<'b, str>>, |
| 402 | + #[serde(borrow, deserialize_with = "deserialize_map_to_vec_pairs")] |
| 403 | + headers: HeaderVec<'b>, |
| 404 | + #[serde(borrow)] |
| 405 | + body: Option<&'b serde_json::value::RawValue>, |
| 406 | +} |
| 407 | + |
| 408 | +fn deserialize_map_to_vec_pairs<'de, D: serde::Deserializer<'de>>( |
| 409 | + deserializer: D, |
| 410 | +) -> Result<HeaderVec<'de>, D::Error> { |
| 411 | + struct Visitor; |
| 412 | + |
| 413 | + impl<'de> serde::de::Visitor<'de> for Visitor { |
| 414 | + type Value = Vec<(Cow<'de, str>, Cow<'de, str>)>; |
| 415 | + |
| 416 | + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { |
| 417 | + formatter.write_str("a map") |
| 418 | + } |
| 419 | + |
| 420 | + fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error> |
| 421 | + where |
| 422 | + A: serde::de::MapAccess<'de>, |
| 423 | + { |
| 424 | + let mut vec = Vec::new(); |
| 425 | + while let Some((key, value)) = map.next_entry()? { |
| 426 | + vec.push((key, value)); |
| 427 | + } |
| 428 | + Ok(vec) |
| 429 | + } |
| 430 | + } |
| 431 | + |
| 432 | + deserializer.deserialize_map(Visitor) |
| 433 | +} |
| 434 | + |
| 435 | +async fn fetch<'a>( |
| 436 | + param0: &StmtParam, |
| 437 | + request: &'a RequestInfo, |
| 438 | +) -> Result<Option<Cow<'a, str>>, anyhow::Error> { |
| 439 | + let Some(fetch_target) = Box::pin(extract_req_param(param0, request)).await? else { |
| 440 | + log::debug!("fetch: first argument is NULL, returning NULL"); |
| 441 | + return Ok(None); |
| 442 | + }; |
| 443 | + let client = awc::Client::default(); |
| 444 | + let res = if fetch_target.starts_with("http") { |
| 445 | + client.get(fetch_target.as_ref()).send() |
| 446 | + } else { |
| 447 | + let r = serde_json::from_str::<'_, Req<'_>>(&fetch_target) |
| 448 | + .with_context(|| format!("Invalid request: {fetch_target}"))?; |
| 449 | + let method = if let Some(method) = r.method { |
| 450 | + Method::from_str(&method)? |
| 451 | + } else { |
| 452 | + Method::GET |
| 453 | + }; |
| 454 | + let mut req = client.request(method, r.url.as_ref()); |
| 455 | + for (k, v) in r.headers { |
| 456 | + req = req.insert_header((k.as_ref(), v.as_ref())); |
| 457 | + } |
| 458 | + if let Some(body) = r.body { |
| 459 | + let val = body.get(); |
| 460 | + // The body can be either json, or a string representing a raw body |
| 461 | + let body = if val.starts_with('"') { |
| 462 | + serde_json::from_str::<'_, String>(val)? |
| 463 | + } else { |
| 464 | + req = req.content_type("application/json"); |
| 465 | + val.to_owned() |
| 466 | + }; |
| 467 | + req.send_body(body) |
| 468 | + } else { |
| 469 | + req.send() |
| 470 | + } |
| 471 | + }; |
| 472 | + let mut res = res |
| 473 | + .await |
| 474 | + .map_err(|e| anyhow!("Unable to fetch {fetch_target}: {e}"))?; |
| 475 | + let body = res.body().await?.to_vec(); |
| 476 | + Ok(Some(String::from_utf8(body)?.into())) |
| 477 | +} |
| 478 | + |
392 | 479 | fn mime_from_upload<'a>(param0: &StmtParam, request: &'a RequestInfo) -> Option<&'a Mime> { |
393 | 480 | if let StmtParam::UploadedFilePath(name) | StmtParam::UploadedFileMimeType(name) = param0 { |
394 | 481 | request.uploaded_files.get(name)?.content_type.as_ref() |
@@ -429,6 +516,7 @@ pub(super) async fn extract_req_param<'a>( |
429 | 516 | StmtParam::ReadFileAsText(inner) => read_file_as_text(inner, request).await?, |
430 | 517 | StmtParam::ReadFileAsDataUrl(inner) => read_file_as_data_url(inner, request).await?, |
431 | 518 | StmtParam::RunSql(inner) => run_sql(inner, request).await?, |
| 519 | + StmtParam::Fetch(inner) => fetch(inner, request).await?, |
432 | 520 | StmtParam::PersistUploadedFile { |
433 | 521 | field_name, |
434 | 522 | folder, |
|
0 commit comments