Skip to content

Commit e77e1f9

Browse files
authored
Introduce query builder DSL (#4812)
* Introduce query builder DSL * Add haddocks * Add CHANGELOG entry * Typo
1 parent ca27f28 commit e77e1f9

File tree

11 files changed

+421
-212
lines changed

11 files changed

+421
-212
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add postgres dynamic query builder
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module Wire.PaginationState (PaginationState (..)) where
2+
3+
import Data.Time.Clock
4+
import Imports
5+
6+
data PaginationState id
7+
= PaginationSortByName (Maybe (Text, id))
8+
| PaginationSortByCreatedAt (Maybe (UTCTime, id))

libs/wire-subsystems/src/Wire/Postgres.hs

Lines changed: 209 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,55 @@
1-
module Wire.Postgres where
1+
module Wire.Postgres
2+
( -- | This module provides a composable DSL for constructing postgres
3+
-- statements. Queries are assembled from smaller 'QueryFragment's that
4+
-- carry both their SQL text and parameter encoders.
5+
--
6+
-- Typical usage involves combining fragments with monoidal operators and
7+
-- building a final 'Statement' using 'buildStatement'.
8+
--
9+
-- Example:
10+
--
11+
-- > let q =
12+
-- > literal "select * from users"
13+
-- > <> where_ [like "name" "alice"]
14+
-- > <> orderBy [("created_at", Desc)]
15+
-- > <> limit (10 :: Int)
16+
-- > in buildStatement q userDecoder
17+
--
18+
-- Note that the encoders are specialised to the specific values passed when
19+
-- constructing the fragments, so they don't require further values. The
20+
-- resulting statement can be run with something like @runStatement ()@.
221

3-
import Hasql.Pipeline
22+
-- * Runners
23+
runStatement,
24+
runTransaction,
25+
runPipeline,
26+
27+
-- * Query builder
28+
QueryFragment,
29+
literal,
30+
where_,
31+
like,
32+
Clause,
33+
mkClause,
34+
clause,
35+
clause1,
36+
orderBy,
37+
limit,
38+
buildStatement,
39+
)
40+
where
41+
42+
import Control.Monad.Trans.State
43+
import Data.Functor.Contravariant
44+
import Data.Id
45+
import Data.Text qualified as T
46+
import Data.Text.Encoding qualified as T
47+
import Data.Time.Clock
48+
import Hasql.Decoders qualified as Dec
49+
import Hasql.Encoders qualified as Enc
50+
import Hasql.Pipeline (Pipeline)
451
import Hasql.Pool
5-
import Hasql.Session qualified as Session
52+
import Hasql.Session
653
import Hasql.Statement
754
import Hasql.Transaction (Transaction)
855
import Hasql.Transaction.Sessions
@@ -11,6 +58,7 @@ import Imports
1158
import Polysemy
1259
import Polysemy.Error (Error, throw)
1360
import Polysemy.Input
61+
import Wire.API.Pagination
1462

1563
runStatement ::
1664
( Member (Input Pool) r,
@@ -22,7 +70,7 @@ runStatement ::
2270
Sem r b
2371
runStatement a stmt = do
2472
pool <- input
25-
liftIO (use pool (Session.statement a stmt)) >>= either throw pure
73+
liftIO (use pool (statement a stmt)) >>= either throw pure
2674

2775
runTransaction ::
2876
(Member (Input Pool) r, Member (Embed IO) r, Member (Error UsageError) r) =>
@@ -43,4 +91,160 @@ runPipeline ::
4391
Sem r a
4492
runPipeline p = do
4593
pool <- input
46-
liftIO (use pool $ Session.pipeline p) >>= either throw pure
94+
liftIO (use pool $ pipeline p) >>= either throw pure
95+
96+
class PostgresValue a where
97+
postgresType :: Text
98+
postgresValue :: a -> Enc.Value ()
99+
100+
instance PostgresValue (Id a) where
101+
postgresType = "uuid"
102+
postgresValue u = const (toUUID u) >$< Enc.uuid
103+
104+
instance PostgresValue Text where
105+
postgresType = "text"
106+
postgresValue x = const x >$< Enc.text
107+
108+
instance PostgresValue UTCTime where
109+
postgresType = "timestamptz"
110+
postgresValue t = const t >$< Enc.timestamptz
111+
112+
instance PostgresValue Int32 where
113+
postgresType = "int"
114+
postgresValue n = const n >$< Enc.int4
115+
116+
--------------------------------------------------------------------------------
117+
-- Query builder DSL
118+
119+
data QueryFragment = QueryFragment
120+
{ query :: State Int Text,
121+
encoder :: Enc.Params ()
122+
}
123+
124+
joinFragments :: Text -> QueryFragment -> QueryFragment -> QueryFragment
125+
joinFragments sep f1 f2 =
126+
QueryFragment
127+
{ query = separate <$> f1.query <*> f2.query,
128+
encoder = f1.encoder <> f2.encoder
129+
}
130+
where
131+
separate "" q = q
132+
separate q "" = q
133+
separate q1 q2 = q1 <> sep <> q2
134+
135+
instance Semigroup QueryFragment where
136+
(<>) = joinFragments " "
137+
138+
instance Monoid QueryFragment where
139+
mempty =
140+
QueryFragment
141+
{ query = pure "",
142+
encoder = mempty
143+
}
144+
145+
literal :: Text -> QueryFragment
146+
literal q =
147+
QueryFragment
148+
{ query = pure q,
149+
encoder = mempty
150+
}
151+
152+
-- | Construct a WHERE clause from a list of fragments.
153+
where_ :: [QueryFragment] -> QueryFragment
154+
where_ frags = literal "where" <> foldr (joinFragments " and ") mempty frags
155+
156+
like :: Text -> Text -> QueryFragment
157+
like field pat =
158+
QueryFragment
159+
{ query = do
160+
i <- nextIndex
161+
pure $ field <> " ilike ($" <> T.pack (show i) <> " :: text)",
162+
encoder = const (fuzzy pat) >$< Enc.param (Enc.nonNullable Enc.text)
163+
}
164+
165+
-- | A portion of a WHERE clause with multiple values. The monoidal operation
166+
-- of this type can be used to combine values into one clause. For example:
167+
--
168+
-- > clause "=" (mkClause "foo" 3 <> mkClause "bar" 4)
169+
--
170+
-- generates a pattern that will end up being expanded as @"(foo, bar) = (3, 4)"@.
171+
data Clause = Clause
172+
{ fields :: [Text],
173+
types :: [Text],
174+
encoder :: Enc.Params ()
175+
}
176+
177+
instance Semigroup Clause where
178+
cl1 <> cl2 =
179+
Clause
180+
{ fields = cl1.fields <> cl2.fields,
181+
types = cl1.types <> cl2.types,
182+
encoder = cl1.encoder <> cl2.encoder
183+
}
184+
185+
instance Monoid Clause where
186+
mempty =
187+
Clause
188+
{ fields = mempty,
189+
types = mempty,
190+
encoder = mempty
191+
}
192+
193+
mkClause :: forall a. (PostgresValue a) => Text -> a -> Clause
194+
mkClause field value =
195+
Clause
196+
{ fields = [field],
197+
types = [postgresType @a],
198+
encoder = Enc.param (Enc.nonNullable (postgresValue value))
199+
}
200+
201+
-- | Convert a 'Clause' to a 'QueryFragment'.
202+
clause :: Text -> Clause -> QueryFragment
203+
clause op cl =
204+
QueryFragment
205+
{ query = do
206+
types <-
207+
fmap wrap $
208+
for cl.types $ \ty -> do
209+
i <- nextIndex
210+
pure $ "$" <> T.pack (show i) <> " :: " <> ty <> ""
211+
let fields = wrap cl.fields
212+
pure $ fields <> " " <> op <> " " <> types,
213+
encoder = cl.encoder
214+
}
215+
where
216+
wrap :: [Text] -> Text
217+
wrap xs = "(" <> T.intercalate ", " xs <> ")"
218+
219+
-- | Fragment for a clause with a single value.
220+
clause1 :: forall a. (PostgresValue a) => Text -> Text -> a -> QueryFragment
221+
clause1 field op value = clause op (mkClause field value)
222+
223+
orderBy :: [(Text, SortOrder)] -> QueryFragment
224+
orderBy os =
225+
literal $
226+
"order by "
227+
<> T.intercalate ", " (map (\(field, o) -> field <> " " <> sortOrderClause o) os)
228+
229+
limit :: forall a. (PostgresValue a) => a -> QueryFragment
230+
limit n =
231+
QueryFragment
232+
{ query = do
233+
i <- nextIndex
234+
pure $ "limit ($" <> T.pack (show i) <> " :: " <> postgresType @a <> ")",
235+
encoder = Enc.param (Enc.nonNullable (postgresValue n))
236+
}
237+
238+
buildStatement :: QueryFragment -> Dec.Result b -> Statement () b
239+
buildStatement frag dec =
240+
Statement
241+
(T.encodeUtf8 (evalState frag.query 1))
242+
frag.encoder
243+
dec
244+
True
245+
246+
nextIndex :: State Int Int
247+
nextIndex = get <* modify succ
248+
249+
fuzzy :: Text -> Text
250+
fuzzy x = "%" <> x <> "%"

libs/wire-subsystems/src/Wire/UserGroupStore.hs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,30 @@ module Wire.UserGroupStore where
44

55
import Data.Id
66
import Data.Json.Util
7+
import Data.Time.Clock
78
import Data.Vector
89
import Imports
910
import Polysemy
1011
import Wire.API.Pagination
1112
import Wire.API.User.Profile
1213
import Wire.API.UserGroup
1314
import Wire.API.UserGroup.Pagination
15+
import Wire.PaginationState
1416

1517
data UserGroupPageRequest = UserGroupPageRequest
1618
{ team :: TeamId,
1719
searchString :: Maybe Text,
18-
paginationState :: PaginationState,
20+
paginationState :: PaginationState UserGroupId,
1921
sortOrder :: SortOrder,
2022
pageSize :: PageSize,
2123
includeMemberCount :: Bool,
2224
includeChannels :: Bool
2325
}
2426

25-
data PaginationState = PaginationSortByName (Maybe (UserGroupName, UserGroupId)) | PaginationSortByCreatedAt (Maybe (UTCTimeMillis, UserGroupId))
27+
userGroupCreatedAtPaginationState :: UserGroup_ f -> (UTCTime, UserGroupId)
28+
userGroupCreatedAtPaginationState ug = (fromUTCTimeMillis ug.createdAt, ug.id_)
2629

27-
userGroupCreatedAtPaginationState :: UserGroup_ f -> (UTCTimeMillis, UserGroupId)
28-
userGroupCreatedAtPaginationState ug = (ug.createdAt, ug.id_)
29-
30-
toSortBy :: PaginationState -> SortBy
30+
toSortBy :: PaginationState UserGroupId -> SortBy
3131
toSortBy = \case
3232
PaginationSortByName _ -> SortByName
3333
PaginationSortByCreatedAt _ -> SortByCreatedAt

0 commit comments

Comments
 (0)