Skip to content

Commit 67e8d05

Browse files
committed
feat: use db orchestrator in gateway + remove mutable reference to self
1 parent 4ab27f1 commit 67e8d05

File tree

8 files changed

+134
-108
lines changed

8 files changed

+134
-108
lines changed

aggregation_mode/Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

aggregation_mode/db/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "0.1.0"
44
edition = "2021"
55

66
[dependencies]
7+
serde = { workspace = true }
78
tokio = { version = "1"}
89
sqlx = { version = "0.8", features = [ "runtime-tokio", "postgres", "migrate" ] }
910
backon = "1.2.0"

aggregation_mode/db/src/orchestrator.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,27 +93,23 @@ impl DbOrchestartor {
9393
})
9494
}
9595

96-
pub async fn write<T, Q, Fut>(&mut self, query: Q) -> Result<T, sqlx::Error>
96+
pub async fn write<T, Q, Fut>(&self, query: Q) -> Result<T, sqlx::Error>
9797
where
9898
Q: Fn(Pool<Postgres>) -> Fut,
9999
Fut: Future<Output = Result<T, sqlx::Error>>,
100100
{
101101
self.query::<T, Q, Fut>(query, Operation::Write).await
102102
}
103103

104-
pub async fn read<T, Q, Fut>(&mut self, query: Q) -> Result<T, sqlx::Error>
104+
pub async fn read<T, Q, Fut>(&self, query: Q) -> Result<T, sqlx::Error>
105105
where
106106
Q: Fn(Pool<Postgres>) -> Fut,
107107
Fut: Future<Output = Result<T, sqlx::Error>>,
108108
{
109109
self.query::<T, Q, Fut>(query, Operation::Read).await
110110
}
111111

112-
async fn query<T, Q, Fut>(
113-
&mut self,
114-
query_fn: Q,
115-
operation: Operation,
116-
) -> Result<T, sqlx::Error>
112+
async fn query<T, Q, Fut>(&self, query_fn: Q, operation: Operation) -> Result<T, sqlx::Error>
117113
where
118114
Q: Fn(Pool<Postgres>) -> Fut,
119115
Fut: Future<Output = Result<T, sqlx::Error>>,
@@ -151,7 +147,7 @@ impl DbOrchestartor {
151147
}
152148

153149
async fn execute_once<T, Q, Fut>(
154-
&mut self,
150+
&self,
155151
query_fn: &Q,
156152
operation: Operation,
157153
) -> Result<T, RetryError<sqlx::Error>>

aggregation_mode/db/src/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use sqlx::{
44
Type,
55
};
66

7-
#[derive(Debug, Clone, Copy, PartialEq, Eq, Type)]
7+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Type, serde::Serialize)]
88
#[sqlx(type_name = "task_status", rename_all = "lowercase")]
99
pub enum TaskStatus {
1010
Pending,

aggregation_mode/gateway/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ serde_yaml = { workspace = true }
1010
agg_mode_sdk = { path = "../sdk"}
1111
aligned-sdk = { workspace = true }
1212
sp1-sdk = { workspace = true }
13+
db = { workspace = true }
1314
tracing = { version = "0.1", features = ["log"] }
1415
tracing-subscriber = { version = "0.3.0", features = ["env-filter"] }
1516
bincode = "1.3.3"

aggregation_mode/gateway/src/db.rs

Lines changed: 115 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,121 @@
1-
use sqlx::{
2-
postgres::PgPoolOptions,
3-
types::{BigDecimal, Uuid},
4-
Pool, Postgres,
5-
};
1+
use db::{orchestrator::DbOrchestartor, retry::RetryConfig};
2+
use sqlx::types::{BigDecimal, Uuid};
3+
4+
use crate::types::Receipt;
65

76
#[derive(Clone, Debug)]
87
pub struct Db {
9-
pool: Pool<Postgres>,
8+
orchestrator: DbOrchestartor,
109
}
1110

1211
#[derive(Debug, Clone)]
1312
pub enum DbError {
1413
ConnectError(String),
1514
}
1615

17-
#[derive(Debug, Clone, sqlx::Type, serde::Serialize)]
18-
#[sqlx(type_name = "task_status")]
19-
#[sqlx(rename_all = "lowercase")]
20-
pub enum TaskStatus {
21-
Pending,
22-
Processing,
23-
Verified,
24-
}
25-
26-
#[derive(Debug, Clone, sqlx::FromRow, sqlx::Type, serde::Serialize)]
27-
pub struct Receipt {
28-
pub status: TaskStatus,
29-
pub merkle_path: Option<Vec<u8>>,
30-
pub nonce: i64,
31-
pub address: String,
32-
}
33-
3416
impl Db {
35-
pub async fn try_new(connection_url: &str) -> Result<Self, DbError> {
36-
let pool = PgPoolOptions::new()
37-
.max_connections(5)
38-
.connect(connection_url)
39-
.await
40-
.map_err(|e| DbError::ConnectError(e.to_string()))?;
17+
pub async fn try_new(connection_urls: &[&str]) -> Result<Self, DbError> {
18+
let orchestrator = DbOrchestartor::try_new(
19+
connection_urls,
20+
RetryConfig {
21+
factor: 0.0,
22+
max_delay_seconds: 0,
23+
max_times: 0,
24+
min_delay_millis: 0,
25+
},
26+
)
27+
.map_err(|e| DbError::ConnectError(e.to_string()))?;
4128

42-
Ok(Self { pool })
29+
Ok(Self { orchestrator })
4330
}
4431

4532
pub async fn count_tasks_by_address(&self, address: &str) -> Result<i64, sqlx::Error> {
46-
let (count,) = sqlx::query_as::<_, (i64,)>("SELECT COUNT(*) FROM tasks WHERE address = $1")
47-
.bind(address.to_lowercase())
48-
.fetch_one(&self.pool)
49-
.await?;
33+
self.orchestrator
34+
.read(async |pool| {
35+
let (count,) =
36+
sqlx::query_as::<_, (i64,)>("SELECT COUNT(*) FROM tasks WHERE address = $1")
37+
.bind(address.to_lowercase())
38+
.fetch_one(&pool)
39+
.await?;
5040

51-
Ok(count)
41+
Ok(count)
42+
})
43+
.await
5244
}
5345

5446
pub async fn get_merkle_path_by_task_id(
5547
&self,
5648
task_id: Uuid,
5749
) -> Result<Option<Vec<u8>>, sqlx::Error> {
58-
sqlx::query_scalar::<_, Option<Vec<u8>>>("SELECT merkle_path FROM tasks WHERE task_id = $1")
59-
.bind(task_id)
60-
.fetch_optional(&self.pool)
50+
self.orchestrator
51+
.read(async |pool| {
52+
sqlx::query_scalar::<_, Option<Vec<u8>>>(
53+
"SELECT merkle_path FROM tasks WHERE task_id = $1",
54+
)
55+
.bind(task_id)
56+
.fetch_optional(&pool)
57+
.await
58+
.map(|res| res.flatten())
59+
})
6160
.await
62-
.map(|res| res.flatten())
6361
}
6462

6563
pub async fn get_tasks_by_address_and_nonce(
6664
&self,
6765
address: &str,
6866
nonce: i64,
6967
) -> Result<Vec<Receipt>, sqlx::Error> {
70-
sqlx::query_as::<_, Receipt>(
71-
"SELECT status,merkle_path,nonce,address FROM tasks
72-
WHERE address = $1
73-
AND nonce = $2
74-
ORDER BY nonce DESC",
75-
)
76-
.bind(address.to_lowercase())
77-
.bind(nonce)
78-
.fetch_all(&self.pool)
79-
.await
68+
self.orchestrator
69+
.read(async |pool| {
70+
sqlx::query_as::<_, Receipt>(
71+
"SELECT status,merkle_path,nonce,address FROM tasks
72+
WHERE address = $1
73+
AND nonce = $2
74+
ORDER BY nonce DESC",
75+
)
76+
.bind(address.to_lowercase())
77+
.bind(nonce)
78+
.fetch_all(&pool)
79+
.await
80+
})
81+
.await
8082
}
8183

8284
pub async fn get_tasks_by_address_with_limit(
8385
&self,
8486
address: &str,
8587
limit: i64,
8688
) -> Result<Vec<Receipt>, sqlx::Error> {
87-
sqlx::query_as::<_, Receipt>(
88-
"SELECT status,merkle_path,nonce,address FROM tasks
89-
WHERE address = $1
90-
ORDER BY nonce DESC
91-
LIMIT $2",
92-
)
93-
.bind(address.to_lowercase())
94-
.bind(limit)
95-
.fetch_all(&self.pool)
96-
.await
89+
self.orchestrator
90+
.read(async |pool| {
91+
sqlx::query_as::<_, Receipt>(
92+
"SELECT status,merkle_path,nonce,address FROM tasks
93+
WHERE address = $1
94+
ORDER BY nonce DESC
95+
LIMIT $2",
96+
)
97+
.bind(address.to_lowercase())
98+
.bind(limit)
99+
.fetch_all(&pool)
100+
.await
101+
})
102+
.await
97103
}
98104

99105
pub async fn get_daily_tasks_by_address(&self, address: &str) -> Result<i64, sqlx::Error> {
100-
sqlx::query_scalar::<_, i64>(
101-
"SELECT COUNT(*)
102-
FROM tasks
103-
WHERE address = $1
104-
AND inserted_at::date = CURRENT_DATE",
105-
)
106-
.bind(address.to_lowercase())
107-
.fetch_one(&self.pool)
108-
.await
106+
self.orchestrator
107+
.read(async |pool| {
108+
sqlx::query_scalar::<_, i64>(
109+
"SELECT COUNT(*)
110+
FROM tasks
111+
WHERE address = $1
112+
AND inserted_at::date = CURRENT_DATE",
113+
)
114+
.bind(address.to_lowercase())
115+
.fetch_one(&pool)
116+
.await
117+
})
118+
.await
109119
}
110120

111121
pub async fn insert_task(
@@ -117,41 +127,49 @@ impl Db {
117127
merkle_path: Option<&[u8]>,
118128
nonce: i64,
119129
) -> Result<Uuid, sqlx::Error> {
120-
sqlx::query_scalar::<_, Uuid>(
121-
"INSERT INTO tasks (
122-
address,
123-
proving_system_id,
124-
proof,
125-
program_commitment,
126-
merkle_path,
127-
nonce
128-
) VALUES ($1, $2, $3, $4, $5, $6)
129-
RETURNING task_id",
130-
)
131-
.bind(address.to_lowercase())
132-
.bind(proving_system_id)
133-
.bind(proof)
134-
.bind(program_commitment)
135-
.bind(merkle_path)
136-
.bind(nonce)
137-
.fetch_one(&self.pool)
138-
.await
130+
self.orchestrator
131+
.write(async |pool| {
132+
sqlx::query_scalar::<_, Uuid>(
133+
"INSERT INTO tasks (
134+
address,
135+
proving_system_id,
136+
proof,
137+
program_commitment,
138+
merkle_path,
139+
nonce
140+
) VALUES ($1, $2, $3, $4, $5, $6)
141+
RETURNING task_id",
142+
)
143+
.bind(address.to_lowercase())
144+
.bind(proving_system_id)
145+
.bind(proof)
146+
.bind(program_commitment)
147+
.bind(merkle_path)
148+
.bind(nonce)
149+
.fetch_one(&pool)
150+
.await
151+
})
152+
.await
139153
}
140154

141155
pub async fn has_active_payment_event(
142156
&self,
143157
address: &str,
144158
epoch: BigDecimal,
145159
) -> Result<bool, sqlx::Error> {
146-
sqlx::query_scalar::<_, bool>(
147-
"SELECT EXISTS (
148-
SELECT 1 FROM payment_events
149-
WHERE address = $1 AND started_at < $2 AND $2 < valid_until
150-
)",
151-
)
152-
.bind(address.to_lowercase())
153-
.bind(epoch)
154-
.fetch_one(&self.pool)
155-
.await
160+
self.orchestrator
161+
.read(async |pool| {
162+
sqlx::query_scalar::<_, bool>(
163+
"SELECT EXISTS (
164+
SELECT 1 FROM payment_events
165+
WHERE address = $1 AND started_at < $2 AND $2 < valid_until
166+
)",
167+
)
168+
.bind(address.to_lowercase())
169+
.bind(&epoch)
170+
.fetch_one(&pool)
171+
.await
172+
})
173+
.await
156174
}
157175
}

aggregation_mode/gateway/src/http.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ impl GatewayServer {
8484
.json(AppResponse::new_unsucessfull("Internal server error", 500));
8585
};
8686

87+
// TODO: how to fix the mutable thing
8788
let state = state.get_ref();
8889
match state.db.count_tasks_by_address(&address).await {
8990
Ok(count) => HttpResponse::Ok().json(AppResponse::new_sucessfull(serde_json::json!(

aggregation_mode/gateway/src/types.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
use actix_multipart::form::{tempfile::TempFile, text::Text, MultipartForm};
2+
use db::types::TaskStatus;
23
use serde::{Deserialize, Serialize};
34
use serde_json::Value;
45

5-
use crate::db::TaskStatus;
6-
76
#[derive(Serialize, Deserialize)]
87
pub(super) struct AppResponse {
98
status: u16,
@@ -62,3 +61,11 @@ pub struct GetReceiptsResponse {
6261
pub nonce: i64,
6362
pub address: String,
6463
}
64+
65+
#[derive(Debug, Clone, sqlx::FromRow, sqlx::Type, serde::Serialize)]
66+
pub struct Receipt {
67+
pub status: TaskStatus,
68+
pub merkle_path: Option<Vec<u8>>,
69+
pub nonce: i64,
70+
pub address: String,
71+
}

0 commit comments

Comments
 (0)