Skip to content

Commit 890417e

Browse files
committed
refactor: Consolidate feature combination macros for encoding, decoding, and indexing
This commit refactors the macros used to generate the `AnyEncode`, `AnyDecode`, and `AnyColumnIndex` traits in the sqlx-core library. The new structure enhances maintainability by streamlining the generation of trait implementations based on enabled features, reducing code duplication while ensuring comprehensive support for various database combinations.
1 parent eae840b commit 890417e

File tree

3 files changed

+174
-322
lines changed

3 files changed

+174
-322
lines changed

sqlx-core/src/any/column.rs

Lines changed: 56 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -88,148 +88,66 @@ impl Column for AnyColumn {
8888
}
8989
}
9090

91-
// Macro to generate AnyColumnIndex trait and impl based on enabled features
92-
macro_rules! define_any_column_index {
93-
(
94-
// List all possible feature combinations with their corresponding bounds
95-
$(
96-
#[cfg($($cfg:tt)*)]
97-
[$($bounds:tt)*]
98-
),* $(,)?
99-
) => {
100-
$(
101-
#[cfg($($cfg)*)]
102-
pub trait AnyColumnIndex: $($bounds)* {}
103-
104-
#[cfg($($cfg)*)]
105-
impl<I: ?Sized> AnyColumnIndex for I where I: $($bounds)* {}
106-
)*
91+
// Macro to generate all feature combinations for column index
92+
macro_rules! for_all_feature_combinations {
93+
// Entry point
94+
( $callback:ident ) => {
95+
for_all_feature_combinations!(@parse_databases [
96+
("postgres", PgRow, PgStatement),
97+
("mysql", MySqlRow, MySqlStatement),
98+
("mssql", MssqlRow, MssqlStatement),
99+
("sqlite", SqliteRow, SqliteStatement),
100+
("odbc", OdbcRow, OdbcStatement)
101+
] $callback);
107102
};
108-
}
109-
110-
// Define all combinations in a compact format
111-
define_any_column_index! {
112-
// 5 databases
113-
#[cfg(all(feature = "postgres", feature = "mysql", feature = "mssql", feature = "sqlite", feature = "odbc"))]
114-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
115-
116-
// 4 databases - missing postgres
117-
#[cfg(all(not(feature = "postgres"), feature = "mysql", feature = "mssql", feature = "sqlite", feature = "odbc"))]
118-
[ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
119-
120-
// 4 databases - missing mysql
121-
#[cfg(all(feature = "postgres", not(feature = "mysql"), feature = "mssql", feature = "sqlite", feature = "odbc"))]
122-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
123-
124-
// 4 databases - missing mssql
125-
#[cfg(all(feature = "postgres", feature = "mysql", not(feature = "mssql"), feature = "sqlite", feature = "odbc"))]
126-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
127-
128-
// 4 databases - missing sqlite
129-
#[cfg(all(feature = "postgres", feature = "mysql", feature = "mssql", not(feature = "sqlite"), feature = "odbc"))]
130-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
131-
132-
// 4 databases - missing odbc
133-
#[cfg(all(feature = "postgres", feature = "mysql", feature = "mssql", feature = "sqlite", not(feature = "odbc")))]
134-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>>],
135-
136-
// 3 databases - postgres, mysql, mssql
137-
#[cfg(all(feature = "postgres", feature = "mysql", feature = "mssql", not(any(feature = "sqlite", feature = "odbc"))))]
138-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>>],
139-
140-
// 3 databases - postgres, mysql, sqlite
141-
#[cfg(all(feature = "postgres", feature = "mysql", feature = "sqlite", not(any(feature = "mssql", feature = "odbc"))))]
142-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>>],
143-
144-
// 3 databases - postgres, mysql, odbc
145-
#[cfg(all(feature = "postgres", feature = "mysql", feature = "odbc", not(any(feature = "mssql", feature = "sqlite"))))]
146-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
147-
148-
// 3 databases - postgres, mssql, sqlite
149-
#[cfg(all(feature = "postgres", feature = "mssql", feature = "sqlite", not(any(feature = "mysql", feature = "odbc"))))]
150-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>>],
151-
152-
// 3 databases - postgres, mssql, odbc
153-
#[cfg(all(feature = "postgres", feature = "mssql", feature = "odbc", not(any(feature = "mysql", feature = "sqlite"))))]
154-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
155-
156-
// 3 databases - postgres, sqlite, odbc
157-
#[cfg(all(feature = "postgres", feature = "sqlite", feature = "odbc", not(any(feature = "mysql", feature = "mssql"))))]
158-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
159-
160-
// 3 databases - mysql, mssql, sqlite
161-
#[cfg(all(feature = "mysql", feature = "mssql", feature = "sqlite", not(any(feature = "postgres", feature = "odbc"))))]
162-
[ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>>],
163-
164-
// 3 databases - mysql, mssql, odbc
165-
#[cfg(all(feature = "mysql", feature = "mssql", feature = "odbc", not(any(feature = "postgres", feature = "sqlite"))))]
166-
[ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
167-
168-
// 3 databases - mysql, sqlite, odbc
169-
#[cfg(all(feature = "mysql", feature = "sqlite", feature = "odbc", not(any(feature = "postgres", feature = "mssql"))))]
170-
[ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
171-
172-
// 3 databases - mssql, sqlite, odbc
173-
#[cfg(all(feature = "mssql", feature = "sqlite", feature = "odbc", not(any(feature = "postgres", feature = "mysql"))))]
174-
[ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
175103

176-
// 2 databases - postgres, mysql
177-
#[cfg(all(feature = "postgres", feature = "mysql", not(any(feature = "mssql", feature = "sqlite", feature = "odbc"))))]
178-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>>],
179-
180-
// 2 databases - postgres, mssql
181-
#[cfg(all(feature = "postgres", feature = "mssql", not(any(feature = "mysql", feature = "sqlite", feature = "odbc"))))]
182-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>>],
183-
184-
// 2 databases - postgres, sqlite
185-
#[cfg(all(feature = "postgres", feature = "sqlite", not(any(feature = "mysql", feature = "mssql", feature = "odbc"))))]
186-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>>],
187-
188-
// 2 databases - postgres, odbc
189-
#[cfg(all(feature = "postgres", feature = "odbc", not(any(feature = "mysql", feature = "mssql", feature = "sqlite"))))]
190-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
191-
192-
// 2 databases - mysql, mssql
193-
#[cfg(all(feature = "mysql", feature = "mssql", not(any(feature = "postgres", feature = "sqlite", feature = "odbc"))))]
194-
[ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>>],
195-
196-
// 2 databases - mysql, sqlite
197-
#[cfg(all(feature = "mysql", feature = "sqlite", not(any(feature = "postgres", feature = "mssql", feature = "odbc"))))]
198-
[ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>>],
199-
200-
// 2 databases - mysql, odbc
201-
#[cfg(all(feature = "mysql", feature = "odbc", not(any(feature = "postgres", feature = "mssql", feature = "sqlite"))))]
202-
[ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
203-
204-
// 2 databases - mssql, sqlite
205-
#[cfg(all(feature = "mssql", feature = "sqlite", not(any(feature = "postgres", feature = "mysql", feature = "odbc"))))]
206-
[ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>> + ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>>],
207-
208-
// 2 databases - mssql, odbc
209-
#[cfg(all(feature = "mssql", feature = "odbc", not(any(feature = "postgres", feature = "mysql", feature = "sqlite"))))]
210-
[ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
211-
212-
// 2 databases - sqlite, odbc
213-
#[cfg(all(feature = "sqlite", feature = "odbc", not(any(feature = "postgres", feature = "mysql", feature = "mssql"))))]
214-
[ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>> + ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
215-
216-
// 1 database - postgres
217-
#[cfg(all(feature = "postgres", not(any(feature = "mysql", feature = "mssql", feature = "sqlite", feature = "odbc"))))]
218-
[ColumnIndex<PgRow> + for<'q> ColumnIndex<PgStatement<'q>>],
104+
// Convert the database list format to tokens suitable for recursion
105+
(@parse_databases [ $(($feat:literal, $row:ident, $stmt:ident)),* ] $callback:ident) => {
106+
for_all_feature_combinations!(@recurse [] [] [$( ($feat, $row, $stmt) )*] $callback);
107+
};
219108

220-
// 1 database - mysql
221-
#[cfg(all(feature = "mysql", not(any(feature = "postgres", feature = "mssql", feature = "sqlite", feature = "odbc"))))]
222-
[ColumnIndex<MySqlRow> + for<'q> ColumnIndex<MySqlStatement<'q>>],
109+
// Recursive case: process each database
110+
(@recurse [$($yes:tt)*] [$($no:tt)*] [($feat:literal, $row:ident, $stmt:ident) $($rest:tt)*] $callback:ident) => {
111+
// Include this database
112+
for_all_feature_combinations!(@recurse
113+
[$($yes)* ($feat, $row, $stmt)]
114+
[$($no)*]
115+
[$($rest)*]
116+
$callback
117+
);
118+
119+
// Exclude this database
120+
for_all_feature_combinations!(@recurse
121+
[$($yes)*]
122+
[$($no)* $feat]
123+
[$($rest)*]
124+
$callback
125+
);
126+
};
223127

224-
// 1 database - mssql
225-
#[cfg(all(feature = "mssql", not(any(feature = "postgres", feature = "mysql", feature = "sqlite", feature = "odbc"))))]
226-
[ColumnIndex<MssqlRow> + for<'q> ColumnIndex<MssqlStatement<'q>>],
128+
// Base case: no more databases, generate the implementation if we have at least one
129+
(@recurse [$(($feat:literal, $row:ident, $stmt:ident))+] [$($no:literal)*] [] $callback:ident) => {
130+
#[cfg(all($(feature = $feat),+ $(, not(feature = $no))*))]
131+
$callback! { $(($row, $stmt)),+ }
132+
};
133+
134+
// Base case: no databases selected, skip
135+
(@recurse [] [$($no:literal)*] [] $callback:ident) => {
136+
// Don't generate anything for zero databases
137+
};
138+
}
227139

228-
// 1 database - sqlite
229-
#[cfg(all(feature = "sqlite", not(any(feature = "postgres", feature = "mysql", feature = "mssql", feature = "odbc"))))]
230-
[ColumnIndex<SqliteRow> + for<'q> ColumnIndex<SqliteStatement<'q>>],
140+
// Callback macro that generates the actual trait and impl
141+
macro_rules! impl_any_column_index_for_databases {
142+
($(($row:ident, $stmt:ident)),+) => {
143+
pub trait AnyColumnIndex: $(ColumnIndex<$row> + for<'q> ColumnIndex<$stmt<'q>> +)+ Sized {}
231144

232-
// 1 database - odbc
233-
#[cfg(all(feature = "odbc", not(any(feature = "postgres", feature = "mysql", feature = "mssql", feature = "sqlite"))))]
234-
[ColumnIndex<OdbcRow> + for<'q> ColumnIndex<OdbcStatement<'q>>],
145+
impl<I: ?Sized> AnyColumnIndex for I
146+
where
147+
I: $(ColumnIndex<$row> + for<'q> ColumnIndex<$stmt<'q>> +)+ Sized
148+
{}
149+
};
235150
}
151+
152+
// Generate all combinations
153+
for_all_feature_combinations!(impl_any_column_index_for_databases);

0 commit comments

Comments
 (0)