Skip to content

Commit fb48812

Browse files
committed
Merge PR rails#46843
2 parents 2fa766d + 15a7b1a commit fb48812

File tree

3 files changed

+83
-14
lines changed

3 files changed

+83
-14
lines changed

activerecord/CHANGELOG.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
* Make `.joins` / `.left_outer_joins` work with CTEs.
2+
3+
For example:
4+
5+
```ruby
6+
Post
7+
.with(commented_posts: Comment.select(:post_id).distinct)
8+
.joins(:commented_posts)
9+
#=> WITH (...) SELECT ... INNER JOIN commented_posts on posts.id = commented_posts.post_id
10+
```
11+
12+
*Vladimir Dementyev*
13+
114
* Add a load hook for `ActiveRecord::ConnectionAdapters::Mysql2Adapter`
215
(named `active_record_mysql2adapter`) to allow for overriding aspects of the
316
`ActiveRecord::ConnectionAdapters::Mysql2Adapter` class. This makes `Mysql2Adapter`

activerecord/lib/active_record/relation/query_methods.rb

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,15 @@ def scope_association_reflection(association)
120120
end
121121
end
122122

123+
# A wrapper to distinguish CTE joins from other nodes.
124+
class CTEJoin # :nodoc:
125+
attr_reader :name
126+
127+
def initialize(name)
128+
@name = name
129+
end
130+
end
131+
123132
FROZEN_EMPTY_ARRAY = [].freeze
124133
FROZEN_EMPTY_HASH = {}.freeze
125134

@@ -1535,13 +1544,13 @@ def each_join_dependencies(join_dependencies = build_join_dependencies, &block)
15351544
end
15361545

15371546
def build_join_dependencies
1538-
associations = joins_values | left_outer_joins_values
1539-
associations |= eager_load_values unless eager_load_values.empty?
1540-
associations |= includes_values unless includes_values.empty?
1547+
joins = joins_values | left_outer_joins_values
1548+
joins |= eager_load_values unless eager_load_values.empty?
1549+
joins |= includes_values unless includes_values.empty?
15411550

15421551
join_dependencies = []
15431552
join_dependencies.unshift construct_join_dependency(
1544-
select_association_list(associations, join_dependencies), nil
1553+
select_named_joins(joins, join_dependencies), nil
15451554
)
15461555
end
15471556

@@ -1598,6 +1607,18 @@ def build_from
15981607
end
15991608
end
16001609

1610+
def select_named_joins(join_names, stashed_joins = nil, &block)
1611+
cte_joins, associations = join_names.partition do |join_name|
1612+
Symbol === join_name && with_values.any? { _1.key?(join_name) }
1613+
end
1614+
1615+
cte_joins.each do |cte_name|
1616+
block&.call(CTEJoin.new(cte_name))
1617+
end
1618+
1619+
select_association_list(associations, stashed_joins, &block)
1620+
end
1621+
16011622
def select_association_list(associations, stashed_joins = nil)
16021623
result = []
16031624
associations.each do |association|
@@ -1618,12 +1639,16 @@ def build_join_buckets
16181639

16191640
unless left_outer_joins_values.empty?
16201641
stashed_left_joins = []
1621-
left_joins = select_association_list(left_outer_joins_values, stashed_left_joins) do
1622-
raise ArgumentError, "only Hash, Symbol and Array are allowed"
1642+
left_joins = select_named_joins(left_outer_joins_values, stashed_left_joins) do |left_join|
1643+
if left_join.is_a?(CTEJoin)
1644+
buckets[:join_node] << build_with_join_node(left_join.name, Arel::Nodes::OuterJoin)
1645+
else
1646+
raise ArgumentError, "only Hash, Symbol and Array are allowed"
1647+
end
16231648
end
16241649

16251650
if joins_values.empty?
1626-
buckets[:association_join] = left_joins
1651+
buckets[:named_join] = left_joins
16271652
buckets[:stashed_join] = stashed_left_joins
16281653
return buckets, Arel::Nodes::OuterJoin
16291654
else
@@ -1649,9 +1674,11 @@ def build_join_buckets
16491674
end
16501675
end
16511676

1652-
buckets[:association_join] = select_association_list(joins, buckets[:stashed_join]) do |join|
1677+
buckets[:named_join] = select_named_joins(joins, buckets[:stashed_join]) do |join|
16531678
if join.is_a?(Arel::Nodes::Join)
16541679
buckets[:join_node] << join
1680+
elsif join.is_a?(CTEJoin)
1681+
buckets[:join_node] << build_with_join_node(join.name)
16551682
else
16561683
raise "unknown class: %s" % join.class.name
16571684
end
@@ -1668,16 +1695,16 @@ def build_joins(join_sources, aliases = nil)
16681695

16691696
buckets, join_type = build_join_buckets
16701697

1671-
association_joins = buckets[:association_join]
1672-
stashed_joins = buckets[:stashed_join]
1673-
leading_joins = buckets[:leading_join]
1674-
join_nodes = buckets[:join_node]
1698+
named_joins = buckets[:named_join]
1699+
stashed_joins = buckets[:stashed_join]
1700+
leading_joins = buckets[:leading_join]
1701+
join_nodes = buckets[:join_node]
16751702

16761703
join_sources.concat(leading_joins) unless leading_joins.empty?
16771704

1678-
unless association_joins.empty? && stashed_joins.empty?
1705+
unless named_joins.empty? && stashed_joins.empty?
16791706
alias_tracker = alias_tracker(leading_joins + join_nodes, aliases)
1680-
join_dependency = construct_join_dependency(association_joins, join_type)
1707+
join_dependency = construct_join_dependency(named_joins, join_type)
16811708
join_sources.concat(join_dependency.join_constraints(stashed_joins, alias_tracker, references_values))
16821709
end
16831710

@@ -1721,6 +1748,14 @@ def build_with_value_from_hash(hash)
17211748
end
17221749
end
17231750

1751+
def build_with_join_node(name, kind = Arel::Nodes::InnerJoin)
1752+
with_table = Arel::Table.new(name)
1753+
1754+
table.join(with_table, kind).on(
1755+
with_table[klass.model_name.to_s.foreign_key].eq(table[klass.primary_key])
1756+
).join_sources.first
1757+
end
1758+
17241759
def arel_columns(columns)
17251760
columns.flat_map do |field|
17261761
case field

activerecord/test/cases/relation/with_test.rb

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,27 @@ def test_with_when_invalid_params_are_passed
6161
assert_raise(ArgumentError) { Post.with(posts_with_tags: nil).load }
6262
assert_raise(ArgumentError) { Post.with(posts_with_tags: [Post.where("tags_count > 0")]).load }
6363
end
64+
65+
def test_with_joins
66+
relation = Post
67+
.with(commented_posts: Comment.select(:post_id).distinct)
68+
.joins(:commented_posts)
69+
70+
assert_equal POSTS_WITH_COMMENTS, relation.order(:id).pluck(:id)
71+
end
72+
73+
def test_with_left_joins
74+
relation = Post
75+
.with(commented_posts: Comment.select(:post_id).distinct)
76+
.left_outer_joins(:commented_posts)
77+
.select("posts.*, commented_posts.post_id as has_comments")
78+
79+
records = relation.order(:id).to_a
80+
81+
# Make sure we load all records (thus, left outer join is used)
82+
assert_equal Post.count, records.size
83+
assert_equal POSTS_WITH_COMMENTS, records.filter_map { _1.id if _1.has_comments }
84+
end
6485
else
6586
def test_common_table_expressions_are_unsupported
6687
assert_raises ActiveRecord::StatementInvalid do

0 commit comments

Comments
 (0)