|
| 1 | +type Pos = {x: Int with x >= 0} |
| 2 | + |
| 3 | +def safeDiv(x: Pos, y: Pos with y > 1): {res: Pos with res < x} = |
| 4 | + (x / y).runtimeChecked |
| 5 | + |
| 6 | +object SafeSeqs: |
| 7 | + opaque type SafeSeq[T] = Seq[T] |
| 8 | + object SafeSeq: |
| 9 | + def fromSeq[T](seq: Seq[T]): SafeSeq[T] = seq |
| 10 | + def apply[T](elems: T*): SafeSeq[T] = fromSeq(elems) |
| 11 | + extension [T](a: SafeSeq[T]) |
| 12 | + def len: Pos = a.length.runtimeChecked |
| 13 | + def apply(i: Pos with i < a.len): T = a(i) |
| 14 | + def splitAt(i: Pos with i < a.len): (SafeSeq[T], SafeSeq[T]) = a.splitAt(i) |
| 15 | + def ++(that: SafeSeq[T]): SafeSeq[T] = a ++ that |
| 16 | + extension [T](a: SafeSeq[T] with a.len > 0) |
| 17 | + def head: T = a.head |
| 18 | + def tail: SafeSeq[T] = a.tail |
| 19 | + |
| 20 | +import SafeSeqs.* |
| 21 | + |
| 22 | +def merge[T: Ordering as ord](left: SafeSeq[T], right: SafeSeq[T]): SafeSeq[T] = |
| 23 | + (left, right) match |
| 24 | + case (left: SafeSeq[T], right: SafeSeq[T] with right.len > 0) => |
| 25 | + if ord.lt(left.head, right.head) then SafeSeq(left.head) ++ merge(left.tail, right) |
| 26 | + else SafeSeq(right.head) ++ merge(left, right.tail) |
| 27 | + case _ => |
| 28 | + if left.len == 0 then right |
| 29 | + else left |
| 30 | + |
| 31 | +def mergeSort[T: Ordering](list: SafeSeq[T]): SafeSeq[T] = |
| 32 | + val len = list.len |
| 33 | + val middle = safeDiv(len, 2) |
| 34 | + if middle == 0 then |
| 35 | + list |
| 36 | + else |
| 37 | + val (left, right) = list.splitAt(middle) |
| 38 | + merge(mergeSort(left), mergeSort(right)) |
| 39 | + |
| 40 | +@main def Test = |
| 41 | + val nums = SafeSeq(5, 3, 8, 1, 2, 7, 4, 6) |
| 42 | + val sortedNums = mergeSort(nums) |
| 43 | + println(s"Unsorted: $nums") |
| 44 | + println(s"Sorted: $sortedNums") |
| 45 | + |
| 46 | + val nums2 = SafeSeq(7, 4, 5, 3, 2, 6, 1) |
| 47 | + val sortedNums2 = mergeSort(nums2) |
| 48 | + println(s"Unsorted: $nums2") |
| 49 | + println(s"Sorted: $sortedNums2") |
0 commit comments