Skip to content
2 changes: 2 additions & 0 deletions lib/graphql/dataloader.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
require "graphql/dataloader/request"
require "graphql/dataloader/request_all"
require "graphql/dataloader/source"
require "graphql/dataloader/active_record_association_source"
require "graphql/dataloader/active_record_source"

module GraphQL
# This plugin supports Fiber-based concurrency, along with {GraphQL::Dataloader::Source}.
Expand Down
64 changes: 64 additions & 0 deletions lib/graphql/dataloader/active_record_association_source.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# frozen_string_literal: true
require "graphql/dataloader/source"
require "graphql/dataloader/active_record_source"

module GraphQL
class Dataloader
class ActiveRecordAssociationSource < GraphQL::Dataloader::Source
RECORD_SOURCE_CLASS = ActiveRecordSource

def initialize(association, scope: nil)
@association = association
@scope = nil
end

def load(record)
if (assoc = record.association(@association))&.loaded?
assoc.target
else
super
end
end

def fetch(records)
record_classes = Set.new.compare_by_identity
associated_classes = Set.new.compare_by_identity
records.each do |record|
if record_classes.add?(record.class)
reflection = record.class.reflect_on_association(@association)
if !reflection.polymorphic? && reflection.klass
associated_classes.add(reflection.klass)
end
end
end

available_records = []
associated_classes.each do |assoc_class|
already_loaded_records = dataloader.with(RECORD_SOURCE_CLASS, assoc_class).results.values
available_records.concat(already_loaded_records)
end

::ActiveRecord::Associations::Preloader.new(records: records, associations: @association, available_records: available_records, scope: @scope).call

loaded_associated_records = records.map { |r| r.public_send(@association) }
records_by_model = {}
loaded_associated_records.each do |record|
if record
updates = records_by_model[record.class] ||= {}
updates[record.id] = record
end
end

if @scope.nil?
# Don't cache records loaded via scope because they might have reduced `SELECT`s
# Could check .select_values here?
records_by_model.each do |model_class, updates|
dataloader.with(RECORD_SOURCE_CLASS, model_class).merge(updates)
end
end

loaded_associated_records
end
end
end
end
26 changes: 26 additions & 0 deletions lib/graphql/dataloader/active_record_source.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# frozen_string_literal: true
require "graphql/dataloader/source"

module GraphQL
class Dataloader
class ActiveRecordSource < GraphQL::Dataloader::Source
def initialize(model_class, find_by: model_class.primary_key)
@model_class = model_class
@find_by = find_by
@type_for_column = @model_class.type_for_attribute(@find_by)
end

def load(requested_key)
casted_key = @type_for_column.cast(requested_key)
super(casted_key)
end

def fetch(record_ids)
records = @model_class.where(@find_by => record_ids)
record_lookup = {}
records.each { |r| record_lookup[r.public_send(@find_by)] = r }
record_ids.map { |id| record_lookup[id] }
end
end
end
end
1 change: 1 addition & 0 deletions lib/graphql/schema/member.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
require 'graphql/schema/member/base_dsl_methods'
require 'graphql/schema/member/graphql_type_names'
require 'graphql/schema/member/has_ast_node'
require 'graphql/schema/member/has_dataloader'
require 'graphql/schema/member/has_directives'
require 'graphql/schema/member/has_deprecation_reason'
require 'graphql/schema/member/has_interfaces'
Expand Down
47 changes: 47 additions & 0 deletions lib/graphql/schema/member/has_dataloader.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# frozen_string_literal: true

module GraphQL
class Schema
class Member
module HasDataloader
# @return [GraphQL::Dataloader] The dataloader for the currently-running query
def dataloader
context.dataloader
end

# Find an object with ActiveRecord via {Dataloader::ActiveRecordSource}.
# @param model [Class<ActiveRecord::Base>]
# @param find_by_value [Object] Usually an `id`, might be another value if `find_by:` is also provided
# @param find_by [Symbol, String] A column name to look the record up by. (Defaults to the model's primary key.)
# @return [ActiveRecord::Base, nil]
def dataload_record(model, find_by_value, find_by: nil)
source = if find_by
dataloader.with(Dataloader::ActiveRecordSource, model, find_by: find_by)
else
dataloader.with(Dataloader::ActiveRecordSource, model)
end

source.load(find_by_value)
end

# Look up an associated record using a Rails association.
# @param association_name [Symbol] A `belongs_to` or `has_one` association. (If a `has_many` association is named here, it will be selected without pagination.)
# @param record [ActiveRecord::Base] The object that the association belongs to.
# @param scope [ActiveRecord::Relation] A scope to look up the associated record in
# @return [ActiveRecord::Base, nil] The associated record, if there is one
# @example Looking up a belongs_to on the current object
# dataload_association(:parent) # Equivalent to `object.parent`, but dataloaded
# @example Looking up an associated record on some other object
# dataload_association(:post, comment) # Equivalent to `comment.post`, but dataloaded
def dataload_association(association_name, record = object, scope: nil)
source = if scope
dataloader.with(Dataloader::ActiveRecordAssociationSource, association_name, scope: scope)
else
dataloader.with(Dataloader::ActiveRecordAssociationSource, association_name)
end
source.load(record)
end
end
end
end
end
6 changes: 1 addition & 5 deletions lib/graphql/schema/resolver.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class Resolver
include Schema::Member::HasPath
extend Schema::Member::HasPath
extend Schema::Member::HasDirectives
include Schema::Member::HasDataloader

# @param object [Object] The application object that this field is being resolved on
# @param context [GraphQL::Query::Context]
Expand All @@ -50,11 +51,6 @@ def initialize(object:, context:, field:)
# @return [GraphQL::Query::Context]
attr_reader :context

# @return [GraphQL::Dataloader]
def dataloader
context.dataloader
end

# @return [GraphQL::Schema::Field]
attr_reader :field

Expand Down
74 changes: 74 additions & 0 deletions spec/graphql/dataloader/active_record_association_source_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# frozen_string_literal: true
require "spec_helper"

describe GraphQL::Dataloader::ActiveRecordAssociationSource do
if testing_rails?
it_dataloads "queries for associated records when the association isn't already loaded" do |d|
my_first_car = ::Album.find(2)
homey = ::Album.find(4)
log = with_active_record_log(colorize: false) do
vulfpeck, chon = d.with(GraphQL::Dataloader::ActiveRecordAssociationSource, :band).load_all([my_first_car, homey])
assert_equal "Vulfpeck", vulfpeck.name
assert_equal "Chon", chon.name
end

assert_includes log, '[["id", 1], ["id", 3]]'

toms_story = ::Album.find(3)
log = with_active_record_log(colorize: false) do
vulfpeck, chon, toms_story_band = d.with(GraphQL::Dataloader::ActiveRecordAssociationSource, :band).load_all([my_first_car, homey, toms_story])
assert_equal "Vulfpeck", vulfpeck.name
assert_equal "Chon", chon.name
assert_equal "Tom's Story", toms_story_band.name
end

assert_includes log, '[["id", 2]]'
end

it_dataloads "doesn't load records that are already cached by ActiveRecordSource" do |d|
d.with(GraphQL::Dataloader::ActiveRecordSource, Band).load_all([1,2,3])

my_first_car = ::Album.find(2)
homey = ::Album.find(4)
toms_story = ::Album.find(3)

log = with_active_record_log(colorize: false) do
vulfpeck, chon, toms_story_band = d.with(GraphQL::Dataloader::ActiveRecordAssociationSource, :band).load_all([my_first_car, homey, toms_story])
assert_equal "Vulfpeck", vulfpeck.name
assert_equal "Chon", chon.name
assert_equal "Tom's Story", toms_story_band.name
end

assert_equal "", log
end

it_dataloads "warms the cache for ActiveRecordSource" do |d|
my_first_car = ::Album.find(2)
homey = ::Album.find(4)
toms_story = ::Album.find(3)
d.with(GraphQL::Dataloader::ActiveRecordAssociationSource, :band).load_all([my_first_car, homey, toms_story])

log = with_active_record_log(colorize: false) do
d.with(GraphQL::Dataloader::ActiveRecordSource, Band).load_all([1,2,3])
end

assert_equal "", log
end

it_dataloads "doesn't pause when the association is already loaded" do |d|
source = d.with(GraphQL::Dataloader::ActiveRecordAssociationSource, :band)
assert_equal 0, source.results.size
assert_equal 0, source.pending.size

my_first_car = ::Album.find(2)
vulfpeck = my_first_car.band

vulfpeck2 = source.load(my_first_car)

assert_equal vulfpeck, vulfpeck2

assert_equal 0, source.results.size
assert_equal 0, source.pending.size
end
end
end
98 changes: 98 additions & 0 deletions spec/graphql/dataloader/active_record_source_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# frozen_string_literal: true
require "spec_helper"

describe GraphQL::Dataloader::ActiveRecordSource do
if testing_rails?
describe "finding by ID" do
it_dataloads "loads once, then returns from a cache when available" do |d|
log = with_active_record_log(colorize: false) do
r1 = d.with(GraphQL::Dataloader::ActiveRecordSource, Band).load(1)
assert_equal "Vulfpeck", r1.name
end

assert_includes log, 'SELECT "bands".* FROM "bands" WHERE "bands"."id" = ? [["id", 1]]'

log = with_active_record_log(colorize: false) do
r1 = d.with(GraphQL::Dataloader::ActiveRecordSource, Band).load(1)
assert_equal "Vulfpeck", r1.name
end

assert_equal "", log

log = with_active_record_log(colorize: false) do
records = d.with(GraphQL::Dataloader::ActiveRecordSource, Band).load_all([1, 99, 2, 3])
assert_equal ["Vulfpeck", nil, "Tom's Story", "Chon"], records.map { |r| r&.name }
end

assert_includes log, '[["id", 99], ["id", 2], ["id", 3]]'
end

it_dataloads "casts load values to the column type" do |d|
log = with_active_record_log(colorize: false) do
r1 = d.with(GraphQL::Dataloader::ActiveRecordSource, Band).load("1")
assert_equal "Vulfpeck", r1.name
end

assert_includes log, 'SELECT "bands".* FROM "bands" WHERE "bands"."id" = ? [["id", 1]]'

log = with_active_record_log(colorize: false) do
d.with(GraphQL::Dataloader::ActiveRecordSource, Band).load(1)
end

assert_equal "", log

log = with_active_record_log(colorize: false) do
d.with(GraphQL::Dataloader::ActiveRecordSource, Band).load("1")
end

assert_equal "", log
end
end

describe "finding by other columns" do
it_dataloads "uses the alternative primary key" do |d|
log = with_active_record_log(colorize: false) do
r1 = d.with(GraphQL::Dataloader::ActiveRecordSource, AlternativeBand).load("Vulfpeck")
assert_equal "Vulfpeck", r1.name
if Rails::VERSION::STRING > "8"
assert_equal 1, r1["id"]
else
assert_equal 1, r1._read_attribute("id")
end
end

assert_includes log, 'SELECT "bands".* FROM "bands" WHERE "bands"."name" = ? [["name", "Vulfpeck"]]'
end

it_dataloads "uses specified find_by columns" do |d|
log = with_active_record_log(colorize: false) do
r1 = d.with(GraphQL::Dataloader::ActiveRecordSource, Band, find_by: :name).load("Chon")
assert_equal "Chon", r1.name
assert_equal 3, r1.id
end

assert_includes log, 'SELECT "bands".* FROM "bands" WHERE "bands"."name" = ? [["name", "Chon"]]'
end
end

describe "warming the cache" do
it_dataloads "can receive passed-in objects with a class" do |d|
d.with(GraphQL::Dataloader::ActiveRecordSource, Band).merge({ 100 => Band.find(3) })
log = with_active_record_log(colorize: false) do
band3 = d.with(GraphQL::Dataloader::ActiveRecordSource, Band).load(100)
assert_equal "Chon", band3.name
assert_equal 3, band3.id
end

assert_equal "", log
end
it "can infer class of passed-in objects"
end

describe "in queries" do
it "loads records with dataload_record"

it "accepts custom find-by with dataload_record"
end
end
end
12 changes: 12 additions & 0 deletions spec/spec_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,15 @@ def assert_warns(warning, printing = "")
assert_equal stdout, printing, "It produced the expected stdout"
return_val
end

module Minitest
class Test
def self.it_dataloads(message, &block)
it(message) do
GraphQL::Dataloader.with_dataloading do |d|
self.instance_exec(d, &block)
end
end
end
end
end
Loading
Loading