Skip to content

Commit 6e0b01f

Browse files
committed
Allow Codec#product to behave correctly with parameter count
1 parent cb9e34b commit 6e0b01f

File tree

3 files changed

+100
-61
lines changed

3 files changed

+100
-61
lines changed

core/shared/src/main/scala/porcupine/codec.scala

Lines changed: 62 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -19,35 +19,48 @@ package porcupine
1919
import cats.Applicative
2020
import cats.ContravariantMonoidal
2121
import cats.InvariantMonoidal
22-
import cats.data.StateT
22+
import cats.data.{State, StateT}
2323
import cats.syntax.all.*
2424
import scodec.bits.ByteVector
25-
2625
import scala.deriving.Mirror
2726

2827
trait Encoder[A]:
2928
outer =>
3029

30+
def parameters: Int
31+
3132
def encode(a: A): List[LiteValue]
3233

33-
def either[B](right: Encoder[B]): Encoder[Either[A, B]] = new:
34-
def encode(aorb: Either[A, B]) = aorb match
35-
case Left(a) => outer.encode(a)
36-
case Right(b) => right.encode(b)
34+
// def either[B](right: Encoder[B]): Encoder[Either[A, B]] = new:
35+
// TODO figure out if this is reasonably implementable
36+
// def parameters: Int = ???
37+
//
38+
// def encode(aorb: Either[A, B]) = aorb match
39+
// case Left(a) => outer.encode(a)
40+
// case Right(b) => right.encode(b)
3741

3842
def opt: Encoder[Option[A]] =
39-
either(Codec.`null`).contramap(_.toLeft(None))
43+
// either(Codec.`null`).contramap(_.toLeft(None))
44+
new:
45+
def parameters = outer.parameters
46+
def encode(aopt: Option[A]) = aopt match
47+
case None => Codec.`null`.encode(None)
48+
case Some(a) => outer.encode(a)
4049

4150
object Encoder:
4251
given ContravariantMonoidal[Encoder] = new:
4352
def unit = Codec.unit
4453

4554
def product[A, B](fa: Encoder[A], fb: Encoder[B]) = new:
55+
def parameters =
56+
fa.parameters + fb.parameters
57+
4658
def encode(ab: (A, B)) =
4759
val (a, b) = ab
4860
fa.encode(a) ::: fb.encode(b)
4961

5062
def contramap[A, B](fa: Encoder[A])(f: B => A) = new:
63+
def parameters = fa.parameters
5164
def encode(b: B) = fa.encode(f(b))
5265

5366
trait Decoder[A]:
@@ -92,52 +105,54 @@ trait Codec[A] extends Encoder[A], Decoder[A]:
92105
def asEncoder: Encoder[A] = this
93106
def asDecoder: Decoder[A] = this
94107

95-
def either[B](right: Codec[B]): Codec[Either[A, B]] = new:
96-
def encode(aorb: Either[A, B]) =
97-
outer.asEncoder.either(right).encode(aorb)
108+
// def either[B](right: Codec[B]): Codec[Either[A, B]] = new:
109+
// def parameters: State[Int, String] =
110+
// outer.asEncoder.either(right).parameters
111+
//
112+
// def encode(aorb: Either[A, B]) =
113+
// outer.asEncoder.either(right).encode(aorb)
114+
//
115+
// def decode = outer.asDecoder.either(right).decode
98116

99-
def decode = outer.asDecoder.either(right).decode
100-
101-
override def opt: Codec[Option[A]] =
102-
either(Codec.`null`).imap(_.left.toOption)(_.toLeft(None))
117+
override def opt: Codec[Option[A]] = new:
118+
def parameters = outer.parameters
119+
def encode(aopt: Option[A]) = outer.asEncoder.opt.encode(aopt)
120+
def decode = outer.asDecoder.opt.decode
103121

104122
object Codec:
105-
val integer: Codec[Long] = new:
106-
def encode(l: Long) = LiteValue.Integer(l) :: Nil
107-
def decode = StateT {
108-
case LiteValue.Integer(l) :: tail => Right((tail, l))
109-
case other => Left(new RuntimeException(s"Expected integer, got ${other.headOption}"))
123+
extension [H](head: Codec[H])
124+
def *:[T <: Tuple](tail: Codec[T]): Codec[H *: T] = (head, tail).imapN(_ *: _) { case h *: t => (h, t) }
125+
126+
private final class Simple[T](
127+
name: String,
128+
apply: T => LiteValue,
129+
unapply: PartialFunction[LiteValue, T]
130+
) extends Codec[T] {
131+
override def parameters: Int = 1
132+
override def encode(a: T): List[LiteValue] = apply(a) :: Nil
133+
override def decode: StateT[Either[Throwable, *], List[LiteValue], T] = StateT {
134+
case unapply(l) :: tail => Right((tail, l))
135+
case other => Left(new RuntimeException(s"Expected $name, got ${other.headOption}"))
110136
}
137+
}
111138

112-
val real: Codec[Double] = new:
113-
def encode(d: Double) = LiteValue.Real(d) :: Nil
114-
def decode = StateT {
115-
case LiteValue.Real(d) :: tail => Right((tail, d))
116-
case other => Left(new RuntimeException(s"Expected real, got ${other.headOption}"))
117-
}
139+
val integer: Codec[Long] =
140+
new Simple("integer", LiteValue.Integer.apply, { case LiteValue.Integer(i) => i })
118141

119-
val text: Codec[String] = new:
120-
def encode(s: String) = LiteValue.Text(s) :: Nil
121-
def decode = StateT {
122-
case LiteValue.Text(s) :: tail => Right((tail, s))
123-
case other => Left(new RuntimeException(s"Expected text, got ${other.headOption}"))
124-
}
142+
val real: Codec[Double] =
143+
new Simple("real", LiteValue.Real.apply, { case LiteValue.Real(r) => r })
125144

126-
val blob: Codec[ByteVector] = new:
127-
def encode(b: ByteVector) = LiteValue.Blob(b) :: Nil
128-
def decode = StateT {
129-
case LiteValue.Blob(b) :: tail => Right((tail, b))
130-
case other => Left(new RuntimeException(s"Expected blob, got ${other.headOption}"))
131-
}
145+
val text: Codec[String] =
146+
new Simple("text", LiteValue.Text.apply, { case LiteValue.Text(t) => t })
132147

133-
val `null`: Codec[None.type] = new:
134-
def encode(n: None.type) = LiteValue.Null :: Nil
135-
def decode = StateT {
136-
case LiteValue.Null :: tail => Right((tail, None))
137-
case other => Left(new RuntimeException(s"Expected NULL, got ${other.headOption}"))
138-
}
148+
val blob: Codec[ByteVector] =
149+
new Simple("blob", LiteValue.Blob.apply, { case LiteValue.Blob(b) => b })
150+
151+
val `null`: Codec[None.type] =
152+
new Simple("NULL", _ => LiteValue.Null, { case LiteValue.Null => None })
139153

140154
def unit: Codec[Unit] = new:
155+
def parameters: Int = 0
141156
def encode(u: Unit) = Nil
142157
def decode = StateT.pure(())
143158

@@ -147,12 +162,16 @@ object Codec:
147162
def unit = Codec.unit
148163

149164
def product[A, B](fa: Codec[A], fb: Codec[B]) = new:
165+
def parameters =
166+
fa.parameters + fb.parameters
167+
150168
def encode(ab: (A, B)) =
151169
val (a, b) = ab
152170
fa.encode(a) ::: fb.encode(b)
153171

154172
def decode = fa.decode.product(fb.decode)
155173

156174
def imap[A, B](fa: Codec[A])(f: A => B)(g: B => A) = new:
175+
def parameters = fa.parameters
157176
def encode(b: B) = fa.encode(g(b))
158177
def decode = fa.decode.map(f)

core/shared/src/main/scala/porcupine/sql.scala

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ package porcupine
1919
import cats.ContravariantMonoidal
2020
import cats.Monoid
2121
import cats.arrow.Profunctor
22+
import cats.data.State
2223
import cats.syntax.all.*
23-
2424
import scala.quoted.Expr
2525
import scala.quoted.Exprs
2626
import scala.quoted.Quotes
@@ -34,24 +34,40 @@ object Query:
3434
def dimap[A, B, C, D](fab: Query[A, B])(f: C => A)(g: B => D) =
3535
Query(fab.sql, fab.encoder.contramap(f), fab.decoder.map(g))
3636

37-
final class Fragment[A](val fragment: String, val encoder: Encoder[A]):
38-
def command: Query[A, Unit] = Query(fragment, encoder, Codec.unit)
37+
final class Fragment[A](
38+
val parts: List[Either[String, Int]],
39+
val encoder: Encoder[A]
40+
):
41+
def sql: String = parts.foldMap {
42+
case Left(s) => s
43+
case Right(i) => ("?, " * (i - 1)) ++ "?"
44+
}
45+
46+
def command: Query[A, Unit] = Query(sql, encoder, Codec.unit)
47+
48+
def query[B](decoder: Decoder[B]): Query[A, B] = Query(sql, encoder, decoder)
3949

40-
def query[B](decoder: Decoder[B]): Query[A, B] = Query(fragment, encoder, decoder)
50+
def apply(a: A): Fragment[Unit] = Fragment(parts, encoder.contramap(_ => a))
4151

42-
def apply(a: A): Fragment[Unit] = Fragment(fragment, encoder.contramap(_ => a))
52+
def stripMargin: Fragment[A] = stripMargin('|')
4353

44-
def stripMargin: Fragment[A] = Fragment(fragment.stripMargin, encoder)
4554
def stripMargin(marginChar: Char): Fragment[A] =
46-
Fragment(fragment.stripMargin(marginChar), encoder)
55+
val head = parts.headOption
56+
val tail = parts.tail
57+
val ps = head.map {
58+
_.leftMap(_.stripMargin(marginChar))
59+
}.toList ++ tail.map {
60+
_.leftMap(str => str.takeWhile(_ != '\n') + str.dropWhile(_ != '\n').stripMargin(marginChar))
61+
}
62+
Fragment(ps, encoder)
4763

4864
object Fragment:
4965
given ContravariantMonoidal[Fragment] = new:
50-
val unit = Fragment("", Codec.unit)
66+
val unit = Fragment(List.empty, Codec.unit)
5167
def product[A, B](fa: Fragment[A], fb: Fragment[B]) =
52-
Fragment(fa.fragment + fb.fragment, (fa.encoder, fb.encoder).tupled)
68+
Fragment(fa.parts ++ fb.parts, (fa.encoder, fb.encoder).tupled)
5369
def contramap[A, B](fa: Fragment[A])(f: B => A) =
54-
Fragment(fa.fragment, fa.encoder.contramap(f))
70+
Fragment(fa.parts, fa.encoder.contramap(f))
5571

5672
given Monoid[Fragment[Unit]] = new:
5773
def empty = ContravariantMonoidal[Fragment].unit
@@ -73,11 +89,14 @@ private def sqlImpl(
7389

7490
val args = Varargs.unapply(argsExpr).toList.flatMap(_.toList)
7591

76-
val fragment = parts.zipAll(args, '{ "" }, '{ "" }).foldLeft('{ "" }) {
77-
case ('{ $acc: String }, ('{ $p: String }, '{ $s: String })) => '{ $acc + $p + $s }
78-
case ('{ $acc: String }, ('{ $p: String }, '{ $e: Encoder[t] })) => '{ $acc + $p + "?" }
79-
case ('{ $acc: String }, ('{ $p: String }, '{ $f: Fragment[t] })) =>
80-
'{ $acc + $p + $f.fragment }
92+
// TODO appending to `List` is slow
93+
val fragment = parts.zipAll(args, '{ "" }, '{ "" }).foldLeft('{ List.empty[Either[String, Int]] }) {
94+
case ('{ $acc: List[Either[String, Int]] }, ('{ $p: String }, '{ $s: String })) =>
95+
'{ $acc :+ Left($p) :+ Left($s) }
96+
case ('{ $acc: List[Either[String, Int]] }, ('{ $p: String }, '{ $e: Encoder[t] })) =>
97+
'{ $acc :+ Left($p) :+ Right($e.parameters) }
98+
case ('{ $acc: List[Either[String, Int]] }, ('{ $p: String }, '{ $f: Fragment[t] })) =>
99+
'{ $acc :+ Left($p) :++ $f.parts }
81100
}
82101

83102
val encoder = args.collect {
@@ -103,5 +122,5 @@ private def sqlImpl(
103122
}
104123

105124
(fragment, encoder) match
106-
case ('{ $s: String }, '{ $e: Encoder[a] }) => '{ Fragment[a]($s, $e) }
125+
case ('{ $s: List[Either[String, Int]] }, '{ $e: Encoder[a] }) => '{ Fragment[a]($s, $e) }
107126
case _ => sys.error("porcupine pricked itself")

core/shared/src/test/scala/porcupine/PorcupineTest.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@ import cats.effect.IOApp
2020
import cats.effect.IO
2121
import cats.syntax.all.*
2222
import scodec.bits.ByteVector
23-
2423
import Codec.*
2524

2625
object PorcupineTest extends IOApp.Simple:
2726

2827
def run = Database.open[IO](":memory:").use { db =>
28+
// TODO figure out why this is broken inside interpolation
29+
val q = `null` *: integer *: real *: text *: blob *: nil
2930
db.execute(sql"create table porcupine (n, i, r, t, b);".command) *>
3031
db.execute(
31-
sql"insert into porcupine values(${`null`}, $integer, $real, $text, $blob);".command,
32+
sql"insert into porcupine values(${ q });".command,
3233
(None, 42L, 3.14, "quill-pig", ByteVector(0, 1, 2, 3)),
3334
) *>
3435
db.unique(

0 commit comments

Comments
 (0)