@@ -120,6 +120,15 @@ def scope_association_reflection(association)
120
120
end
121
121
end
122
122
123
+ # A wrapper to distinguish CTE joins from other nodes.
124
+ class CTEJoin # :nodoc:
125
+ attr_reader :name
126
+
127
+ def initialize ( name )
128
+ @name = name
129
+ end
130
+ end
131
+
123
132
FROZEN_EMPTY_ARRAY = [ ] . freeze
124
133
FROZEN_EMPTY_HASH = { } . freeze
125
134
@@ -1535,13 +1544,13 @@ def each_join_dependencies(join_dependencies = build_join_dependencies, &block)
1535
1544
end
1536
1545
1537
1546
def build_join_dependencies
1538
- associations = joins_values | left_outer_joins_values
1539
- associations |= eager_load_values unless eager_load_values . empty?
1540
- associations |= includes_values unless includes_values . empty?
1547
+ joins = joins_values | left_outer_joins_values
1548
+ joins |= eager_load_values unless eager_load_values . empty?
1549
+ joins |= includes_values unless includes_values . empty?
1541
1550
1542
1551
join_dependencies = [ ]
1543
1552
join_dependencies . unshift construct_join_dependency (
1544
- select_association_list ( associations , join_dependencies ) , nil
1553
+ select_named_joins ( joins , join_dependencies ) , nil
1545
1554
)
1546
1555
end
1547
1556
@@ -1598,6 +1607,18 @@ def build_from
1598
1607
end
1599
1608
end
1600
1609
1610
+ def select_named_joins ( join_names , stashed_joins = nil , &block )
1611
+ cte_joins , associations = join_names . partition do |join_name |
1612
+ Symbol === join_name && with_values . any? { _1 . key? ( join_name ) }
1613
+ end
1614
+
1615
+ cte_joins . each do |cte_name |
1616
+ block &.call ( CTEJoin . new ( cte_name ) )
1617
+ end
1618
+
1619
+ select_association_list ( associations , stashed_joins , &block )
1620
+ end
1621
+
1601
1622
def select_association_list ( associations , stashed_joins = nil )
1602
1623
result = [ ]
1603
1624
associations . each do |association |
@@ -1618,12 +1639,16 @@ def build_join_buckets
1618
1639
1619
1640
unless left_outer_joins_values . empty?
1620
1641
stashed_left_joins = [ ]
1621
- left_joins = select_association_list ( left_outer_joins_values , stashed_left_joins ) do
1622
- raise ArgumentError , "only Hash, Symbol and Array are allowed"
1642
+ left_joins = select_named_joins ( left_outer_joins_values , stashed_left_joins ) do |left_join |
1643
+ if left_join . is_a? ( CTEJoin )
1644
+ buckets [ :join_node ] << build_with_join_node ( left_join . name , Arel ::Nodes ::OuterJoin )
1645
+ else
1646
+ raise ArgumentError , "only Hash, Symbol and Array are allowed"
1647
+ end
1623
1648
end
1624
1649
1625
1650
if joins_values . empty?
1626
- buckets [ :association_join ] = left_joins
1651
+ buckets [ :named_join ] = left_joins
1627
1652
buckets [ :stashed_join ] = stashed_left_joins
1628
1653
return buckets , Arel ::Nodes ::OuterJoin
1629
1654
else
@@ -1649,9 +1674,11 @@ def build_join_buckets
1649
1674
end
1650
1675
end
1651
1676
1652
- buckets [ :association_join ] = select_association_list ( joins , buckets [ :stashed_join ] ) do |join |
1677
+ buckets [ :named_join ] = select_named_joins ( joins , buckets [ :stashed_join ] ) do |join |
1653
1678
if join . is_a? ( Arel ::Nodes ::Join )
1654
1679
buckets [ :join_node ] << join
1680
+ elsif join . is_a? ( CTEJoin )
1681
+ buckets [ :join_node ] << build_with_join_node ( join . name )
1655
1682
else
1656
1683
raise "unknown class: %s" % join . class . name
1657
1684
end
@@ -1668,16 +1695,16 @@ def build_joins(join_sources, aliases = nil)
1668
1695
1669
1696
buckets , join_type = build_join_buckets
1670
1697
1671
- association_joins = buckets [ :association_join ]
1672
- stashed_joins = buckets [ :stashed_join ]
1673
- leading_joins = buckets [ :leading_join ]
1674
- join_nodes = buckets [ :join_node ]
1698
+ named_joins = buckets [ :named_join ]
1699
+ stashed_joins = buckets [ :stashed_join ]
1700
+ leading_joins = buckets [ :leading_join ]
1701
+ join_nodes = buckets [ :join_node ]
1675
1702
1676
1703
join_sources . concat ( leading_joins ) unless leading_joins . empty?
1677
1704
1678
- unless association_joins . empty? && stashed_joins . empty?
1705
+ unless named_joins . empty? && stashed_joins . empty?
1679
1706
alias_tracker = alias_tracker ( leading_joins + join_nodes , aliases )
1680
- join_dependency = construct_join_dependency ( association_joins , join_type )
1707
+ join_dependency = construct_join_dependency ( named_joins , join_type )
1681
1708
join_sources . concat ( join_dependency . join_constraints ( stashed_joins , alias_tracker , references_values ) )
1682
1709
end
1683
1710
@@ -1721,6 +1748,14 @@ def build_with_value_from_hash(hash)
1721
1748
end
1722
1749
end
1723
1750
1751
+ def build_with_join_node ( name , kind = Arel ::Nodes ::InnerJoin )
1752
+ with_table = Arel ::Table . new ( name )
1753
+
1754
+ table . join ( with_table , kind ) . on (
1755
+ with_table [ klass . model_name . to_s . foreign_key ] . eq ( table [ klass . primary_key ] )
1756
+ ) . join_sources . first
1757
+ end
1758
+
1724
1759
def arel_columns ( columns )
1725
1760
columns . flat_map do |field |
1726
1761
case field
0 commit comments