Skip to content

Commit e2c844a

Browse files
authored
Merge pull request #1886 from fermyon/postgres-resources
Use resources for Postgres API
2 parents 88373dd + 4573dce commit e2c844a

File tree

12 files changed

+189
-86
lines changed

12 files changed

+189
-86
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/outbound-pg/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ native-tls = "0.2.11"
1313
postgres-native-tls = "0.5.0"
1414
spin-core = { path = "../core" }
1515
spin-world = { path = "../world" }
16-
tokio = { version = "1", features = [ "rt-multi-thread" ] }
16+
table = { path = "../table" }
17+
tokio = { version = "1", features = ["rt-multi-thread"] }
1718
tokio-postgres = { version = "0.7.7" }
1819
tracing = { workspace = true }

crates/outbound-pg/src/lib.rs

Lines changed: 94 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
use anyhow::{anyhow, Result};
22
use native_tls::TlsConnector;
33
use postgres_native_tls::MakeTlsConnector;
4-
use spin_core::{async_trait, HostComponent};
4+
use spin_core::{async_trait, wasmtime::component::Resource, HostComponent};
55
use spin_world::v1::{
6-
postgres::{self, PgError},
6+
postgres as v1,
77
rdbms_types::{Column, DbDataType, DbValue, ParameterValue, RowSet},
88
};
9-
use std::collections::HashMap;
9+
use spin_world::v2::postgres::{self as v2, Connection};
1010
use tokio_postgres::{
1111
config::SslMode,
1212
types::{ToSql, Type},
@@ -16,7 +16,15 @@ use tokio_postgres::{
1616
/// A simple implementation to support outbound pg connection
1717
#[derive(Default)]
1818
pub struct OutboundPg {
19-
pub connections: HashMap<String, Client>,
19+
pub connections: table::Table<Client>,
20+
}
21+
22+
impl OutboundPg {
23+
async fn get_client(&mut self, connection: Resource<Connection>) -> Result<&Client, v2::Error> {
24+
self.connections
25+
.get(connection.rep())
26+
.ok_or_else(|| v2::Error::ConnectionFailed("no connection found".into()))
27+
}
2028
}
2129

2230
impl HostComponent for OutboundPg {
@@ -26,7 +34,8 @@ impl HostComponent for OutboundPg {
2634
linker: &mut spin_core::Linker<T>,
2735
get: impl Fn(&mut spin_core::Data<T>) -> &mut Self::Data + Send + Sync + Copy + 'static,
2836
) -> anyhow::Result<()> {
29-
postgres::add_to_linker(linker, get)
37+
v1::add_to_linker(linker, get)?;
38+
v2::add_to_linker(linker, get)
3039
}
3140

3241
fn build_data(&self) -> Self::Data {
@@ -35,27 +44,43 @@ impl HostComponent for OutboundPg {
3544
}
3645

3746
#[async_trait]
38-
impl postgres::Host for OutboundPg {
47+
impl v2::Host for OutboundPg {}
48+
49+
#[async_trait]
50+
impl v2::HostConnection for OutboundPg {
51+
async fn open(&mut self, address: String) -> Result<Result<Resource<Connection>, v2::Error>> {
52+
Ok(async {
53+
self.connections
54+
.push(
55+
build_client(&address)
56+
.await
57+
.map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?,
58+
)
59+
.map_err(|_| v2::Error::Other("too many connections".into()))
60+
.map(Resource::new_own)
61+
}
62+
.await)
63+
}
64+
3965
async fn execute(
4066
&mut self,
41-
address: String,
67+
connection: Resource<Connection>,
4268
statement: String,
4369
params: Vec<ParameterValue>,
44-
) -> Result<Result<u64, PgError>> {
70+
) -> Result<Result<u64, v2::Error>> {
4571
Ok(async {
4672
let params: Vec<&(dyn ToSql + Sync)> = params
4773
.iter()
4874
.map(to_sql_parameter)
4975
.collect::<anyhow::Result<Vec<_>>>()
50-
.map_err(|e| PgError::ValueConversionFailed(format!("{:?}", e)))?;
76+
.map_err(|e| v2::Error::ValueConversionFailed(format!("{:?}", e)))?;
5177

5278
let nrow = self
53-
.get_client(&address)
54-
.await
55-
.map_err(|e| PgError::ConnectionFailed(format!("{:?}", e)))?
79+
.get_client(connection)
80+
.await?
5681
.execute(&statement, params.as_slice())
5782
.await
58-
.map_err(|e| PgError::QueryFailed(format!("{:?}", e)))?;
83+
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;
5984

6085
Ok(nrow)
6186
}
@@ -64,24 +89,23 @@ impl postgres::Host for OutboundPg {
6489

6590
async fn query(
6691
&mut self,
67-
address: String,
92+
connection: Resource<Connection>,
6893
statement: String,
6994
params: Vec<ParameterValue>,
70-
) -> Result<Result<RowSet, PgError>> {
95+
) -> Result<Result<RowSet, v2::Error>> {
7196
Ok(async {
7297
let params: Vec<&(dyn ToSql + Sync)> = params
7398
.iter()
7499
.map(to_sql_parameter)
75100
.collect::<anyhow::Result<Vec<_>>>()
76-
.map_err(|e| PgError::BadParameter(format!("{:?}", e)))?;
101+
.map_err(|e| v2::Error::BadParameter(format!("{:?}", e)))?;
77102

78103
let results = self
79-
.get_client(&address)
80-
.await
81-
.map_err(|e| PgError::ConnectionFailed(format!("{:?}", e)))?
104+
.get_client(connection)
105+
.await?
82106
.query(&statement, params.as_slice())
83107
.await
84-
.map_err(|e| PgError::QueryFailed(format!("{:?}", e)))?;
108+
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;
85109

86110
if results.is_empty() {
87111
return Ok(RowSet {
@@ -95,12 +119,17 @@ impl postgres::Host for OutboundPg {
95119
.iter()
96120
.map(convert_row)
97121
.collect::<Result<Vec<_>, _>>()
98-
.map_err(|e| PgError::QueryFailed(format!("{:?}", e)))?;
122+
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;
99123

100124
Ok(RowSet { columns, rows })
101125
}
102126
.await)
103127
}
128+
129+
fn drop(&mut self, connection: Resource<Connection>) -> anyhow::Result<()> {
130+
self.connections.remove(connection.rep());
131+
Ok(())
132+
}
104133
}
105134

106135
fn to_sql_parameter(value: &ParameterValue) -> anyhow::Result<&(dyn ToSql + Sync)> {
@@ -233,16 +262,6 @@ fn convert_entry(row: &Row, index: usize) -> Result<DbValue, tokio_postgres::Err
233262
Ok(value)
234263
}
235264

236-
impl OutboundPg {
237-
async fn get_client(&mut self, address: &str) -> anyhow::Result<&Client> {
238-
let client = match self.connections.entry(address.to_owned()) {
239-
std::collections::hash_map::Entry::Occupied(o) => o.into_mut(),
240-
std::collections::hash_map::Entry::Vacant(v) => v.insert(build_client(address).await?),
241-
};
242-
Ok(client)
243-
}
244-
}
245-
246265
async fn build_client(address: &str) -> anyhow::Result<Client> {
247266
let config = address.parse::<tokio_postgres::Config>()?;
248267

@@ -325,3 +344,47 @@ impl std::fmt::Debug for PgNull {
325344
f.debug_struct("NULL").finish()
326345
}
327346
}
347+
348+
/// Delegate a function call to the v2::HostConnection implementation
349+
macro_rules! delegate {
350+
($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{
351+
let connection = match <Self as v2::HostConnection>::open($self, $address).await? {
352+
Ok(c) => c,
353+
Err(e) => return Ok(Err(to_legacy_error(e))),
354+
};
355+
Ok(<Self as v2::HostConnection>::$name($self, connection, $($arg),*)
356+
.await?
357+
.map_err(|e| to_legacy_error(e)))
358+
}};
359+
}
360+
361+
#[async_trait]
362+
impl v1::Host for OutboundPg {
363+
async fn execute(
364+
&mut self,
365+
address: String,
366+
statement: String,
367+
params: Vec<ParameterValue>,
368+
) -> Result<Result<u64, v1::PgError>> {
369+
delegate!(self.execute(address, statement, params))
370+
}
371+
372+
async fn query(
373+
&mut self,
374+
address: String,
375+
statement: String,
376+
params: Vec<ParameterValue>,
377+
) -> Result<Result<RowSet, v1::PgError>> {
378+
delegate!(self.query(address, statement, params))
379+
}
380+
}
381+
382+
fn to_legacy_error(error: v2::Error) -> v1::PgError {
383+
match error {
384+
v2::Error::ConnectionFailed(e) => v1::PgError::ConnectionFailed(e),
385+
v2::Error::BadParameter(e) => v1::PgError::BadParameter(e),
386+
v2::Error::QueryFailed(e) => v1::PgError::QueryFailed(e),
387+
v2::Error::ValueConversionFailed(e) => v1::PgError::ValueConversionFailed(e),
388+
v2::Error::Other(e) => v1::PgError::OtherError(e),
389+
}
390+
}

crates/table/Cargo.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,4 @@ version.workspace = true
44
authors.workspace = true
55
edition.workspace = true
66

7-
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
8-
97
[dependencies]

crates/table/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ pub struct Table<V> {
1313
tuples: HashMap<u32, V>,
1414
}
1515

16+
impl<V> Default for Table<V> {
17+
fn default() -> Self {
18+
Self::new(1024)
19+
}
20+
}
21+
1622
impl<V> Table<V> {
1723
/// Create a new, empty table with the specified capacity.
1824
pub fn new(capacity: u32) -> Self {

examples/rust-outbound-pg/src/lib.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,10 @@ fn process(req: Request) -> Result<Response> {
5353

5454
fn read(_req: Request) -> Result<Response> {
5555
let address = std::env::var(DB_URL_ENV)?;
56+
let conn = pg::Connection::open(&address)?;
5657

5758
let sql = "SELECT id, title, content, authorname, coauthor FROM articletest";
58-
let rowset = pg::query(&address, sql, &[])?;
59+
let rowset = conn.query(sql, &[])?;
5960

6061
let column_summary = rowset
6162
.columns
@@ -89,14 +90,15 @@ fn read(_req: Request) -> Result<Response> {
8990

9091
fn write(_req: Request) -> Result<Response> {
9192
let address = std::env::var(DB_URL_ENV)?;
93+
let conn = pg::Connection::open(&address)?;
9294

9395
let sql = "INSERT INTO articletest (title, content, authorname) VALUES ('aaa', 'bbb', 'ccc')";
94-
let nrow_executed = pg::execute(&address, sql, &[])?;
96+
let nrow_executed = conn.execute(sql, &[])?;
9597

9698
println!("nrow_executed: {}", nrow_executed);
9799

98100
let sql = "SELECT COUNT(id) FROM articletest";
99-
let rowset = pg::query(&address, sql, &[])?;
101+
let rowset = conn.query(sql, &[])?;
100102
let row = &rowset.rows[0];
101103
let count = i64::decode(&row[0])?;
102104
let response = format!("Count: {}\n", count);
@@ -108,10 +110,11 @@ fn write(_req: Request) -> Result<Response> {
108110

109111
fn pg_backend_pid(_req: Request) -> Result<Response> {
110112
let address = std::env::var(DB_URL_ENV)?;
113+
let conn = pg::Connection::open(&address)?;
111114
let sql = "SELECT pg_backend_pid()";
112115

113116
let get_pid = || {
114-
let rowset = pg::query(&address, sql, &[])?;
117+
let rowset = conn.query(sql, &[])?;
115118
let row = &rowset.rows[0];
116119

117120
i32::decode(&row[0])

examples/spin-timer/Cargo.lock

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sdk/rust/src/pg.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
//! | `String` | str(string) | VARCHAR, CHAR(N), TEXT |
1414
//! | `Vec<u8>` | binary(list\<u8\>) | BYTEA |
1515
16-
pub use super::wit::v1::postgres::{execute, query, PgError};
16+
#[doc(inline)]
1717
pub use super::wit::v1::rdbms_types::*;
18+
#[doc(inline)]
19+
pub use super::wit::v2::postgres::{Connection, Error as PgError};
1820

1921
/// A pg error
2022
#[derive(Debug, thiserror::Error)]
@@ -23,7 +25,7 @@ pub enum Error {
2325
#[error("error value decoding: {0}")]
2426
Decode(String),
2527
/// Pg query failed with an error
26-
#[error("{0}")]
28+
#[error(transparent)]
2729
PgError(#[from] PgError),
2830
}
2931

0 commit comments

Comments
 (0)