diff --git a/lib/closure_tree.rb b/lib/closure_tree.rb index c675ba2e..fb03939c 100644 --- a/lib/closure_tree.rb +++ b/lib/closure_tree.rb @@ -4,6 +4,7 @@ module ClosureTree extend ActiveSupport::Autoload autoload :HasClosureTree + autoload :HasClosureTreeRoot autoload :Support autoload :HierarchyMaintenance autoload :Model @@ -25,4 +26,5 @@ def self.configuration ActiveSupport.on_load :active_record do ActiveRecord::Base.send :extend, ClosureTree::HasClosureTree + ActiveRecord::Base.send :extend, ClosureTree::HasClosureTreeRoot end diff --git a/lib/closure_tree/has_closure_tree_root.rb b/lib/closure_tree/has_closure_tree_root.rb new file mode 100644 index 00000000..8e03bc16 --- /dev/null +++ b/lib/closure_tree/has_closure_tree_root.rb @@ -0,0 +1,88 @@ +module ClosureTree + class MultipleRootError < StandardError; end + + module HasClosureTreeRoot + + def has_closure_tree_root(assoc_name, options = {}) + options.assert_valid_keys( + :class_name, + :foreign_key + ) + + options[:class_name] ||= assoc_name.to_s.sub(/\Aroot_/, "").classify + options[:foreign_key] ||= self.name.underscore << "_id" + + has_one assoc_name, -> { where(parent: nil) }, options + + # Fetches the association, eager loading all children and given associations + define_method("#{assoc_name}_including_tree") do |assoc_map_or_reload = nil, assoc_map = nil| + reload = false + if assoc_map_or_reload.is_a?(::Hash) + assoc_map = assoc_map_or_reload + else + reload = assoc_map_or_reload + end + + unless reload + # Memoize + @closure_tree_roots ||= {} + @closure_tree_roots[assoc_name] ||= {} + if @closure_tree_roots[assoc_name].has_key?(assoc_map) + return @closure_tree_roots[assoc_name][assoc_map] + end + end + + roots = options[:class_name].constantize.where(parent: nil, options[:foreign_key] => id).to_a + + return nil if roots.empty? + + if roots.size > 1 + raise MultipleRootError.new("#{self.class.name}: has_closure_tree_root requires a single root") + end + + temp_root = roots.first + root = nil + id_hash = {} + parent_col_id = temp_root.class._ct.options[:parent_column_name] + + # Lookup inverse belongs_to association reflection on target class. + inverse = temp_root.class.reflections.values.detect do |r| + r.macro == :belongs_to && r.klass == self.class + end + + # Fetch all descendants in constant number of queries. + # This is the last query-triggering statement in the method. + temp_root.self_and_descendants.includes(assoc_map).each do |node| + id_hash[node.id] = node + parent_node = id_hash[node[parent_col_id]] + + # Pre-assign parent association + parent_assoc = node.association(:parent) + parent_assoc.loaded! + parent_assoc.target = parent_node + + # Pre-assign children association as empty for now, + # children will be added in subsequent loop iterations + children_assoc = node.association(:children) + children_assoc.loaded! + + if parent_node + parent_node.association(:children).target << node + else + # Capture the root we're going to use + root = node + end + + # Pre-assign inverse association back to this class, if it exists on target class. + if inverse + inverse_assoc = node.association(inverse.name) + inverse_assoc.loaded! + inverse_assoc.target = self + end + end + + @closure_tree_roots[assoc_name][assoc_map] = root + end + end + end +end diff --git a/spec/db/models.rb b/spec/db/models.rb index 9060ab85..74a0ef4f 100644 --- a/spec/db/models.rb +++ b/spec/db/models.rb @@ -35,13 +35,30 @@ def add_destroyed_tag class DestroyedTag < ActiveRecord::Base end +class Group < ActiveRecord::Base + has_closure_tree_root :root_user +end + +class Grouping < ActiveRecord::Base + has_closure_tree_root :root_person, class_name: "User", foreign_key: :group_id +end + +class UserSet < ActiveRecord::Base + has_closure_tree_root :root_user, class_name: "Useur" +end + +class Team < ActiveRecord::Base + has_closure_tree_root :root_user, class_name: "User", foreign_key: :grp_id +end + class User < ActiveRecord::Base acts_as_tree :parent_column_name => "referrer_id", :name_column => 'email', :hierarchy_class_name => 'ReferralHierarchy', :hierarchy_table_name => 'referral_hierarchies' - has_many :contracts + has_many :contracts, inverse_of: :user + belongs_to :group # Can't use and don't need inverse_of here when using has_closure_tree_root. def indirect_contracts Contract.where(:user_id => descendant_ids) @@ -53,7 +70,12 @@ def to_s end class Contract < ActiveRecord::Base - belongs_to :user + belongs_to :user, inverse_of: :contracts + belongs_to :contract_type, inverse_of: :contracts +end + +class ContractType < ActiveRecord::Base + has_many :contracts, inverse_of: :contract_type end class Label < ActiveRecord::Base diff --git a/spec/db/schema.rb b/spec/db/schema.rb index 96a30ff8..cb53906b 100644 --- a/spec/db/schema.rb +++ b/spec/db/schema.rb @@ -43,9 +43,26 @@ add_index "tag_hierarchies", [:ancestor_id, :descendant_id, :generations], :unique => true, :name => "tag_anc_desc_idx" add_index "tag_hierarchies", [:descendant_id], :name => "tag_desc_idx" + create_table "groups" do |t| + t.string "name", null: false + end + + create_table "groupings" do |t| + t.string "name", null: false + end + + create_table "user_sets" do |t| + t.string "name", null: false + end + + create_table "teams" do |t| + t.string "name", null: false + end + create_table "users" do |t| t.string "email" t.integer "referrer_id" + t.integer "group_id" t.timestamps null: false end @@ -53,6 +70,11 @@ create_table "contracts" do |t| t.integer "user_id", :null => false + t.integer "contract_type_id" + end + + create_table "contract_types" do |t| + t.string "name", :null => false end create_table "referral_hierarchies", :id => false do |t| diff --git a/spec/has_closure_tree_root_spec.rb b/spec/has_closure_tree_root_spec.rb new file mode 100644 index 00000000..858b54e6 --- /dev/null +++ b/spec/has_closure_tree_root_spec.rb @@ -0,0 +1,132 @@ +require "spec_helper" + +describe "has_closure_tree_root" do + let!(:ct1) { ContractType.create!(name: "Type1") } + let!(:ct2) { ContractType.create!(name: "Type2") } + let!(:user1) { User.create!(email: "1@example.com", group_id: group.id) } + let!(:user2) { User.create!(email: "2@example.com", group_id: group.id) } + let!(:user3) { User.create!(email: "3@example.com", group_id: group.id) } + let!(:user4) { User.create!(email: "4@example.com", group_id: group.id) } + let!(:user5) { User.create!(email: "5@example.com", group_id: group.id) } + let!(:user6) { User.create!(email: "6@example.com", group_id: group.id) } + let!(:group_reloaded) { group.class.find(group.id) } # Ensures were starting fresh. + + before do + # The tree (contract types in parens) + # + # U1(1) + # / \ + # U2(1) U3(1&2) + # / / \ + # U4(2) U5(1) U6(2) + + user1.children << user2 + user1.children << user3 + user2.children << user4 + user3.children << user5 + user3.children << user6 + + user1.contracts.create!(contract_type: ct1) + user2.contracts.create!(contract_type: ct1) + user3.contracts.create!(contract_type: ct1) + user3.contracts.create!(contract_type: ct2) + user4.contracts.create!(contract_type: ct2) + user5.contracts.create!(contract_type: ct1) + user6.contracts.create!(contract_type: ct2) + end + + context "with basic config" do + let!(:group) { Group.create!(name: "TheGroup") } + + it "loads all nodes and associations in a constant number of queries" do + expect do + root = group_reloaded.root_user_including_tree(contracts: :contract_type) + expect(root.children[0].email).to eq "2@example.com" + expect(root.children[0].parent.children[1].email).to eq "3@example.com" + expect(root.children[1].contracts.map(&:contract_type).map(&:name)).to eq %w(Type1 Type2) + expect(root.children[1].children[0].contracts[0].contract_type.name).to eq "Type1" + expect(root.children[0].children[0].contracts[0].user. + parent.parent.children[1].children[1].contracts[0].contract_type.name).to eq "Type2" + end.to_not exceed_query_limit(4) # Without this feature, this is 15, and scales with number of nodes. + end + + it "memoizes by assoc_map" do + group_reloaded.root_user_including_tree.email = "x" + expect(group_reloaded.root_user_including_tree.email).to eq "x" + group_reloaded.root_user_including_tree(contracts: :contract_type).email = "y" + expect(group_reloaded.root_user_including_tree(contracts: :contract_type).email).to eq "y" + expect(group_reloaded.root_user_including_tree.email).to eq "x" + end + + it "doesn't memoize if true argument passed" do + group_reloaded.root_user_including_tree.email = "x" + expect(group_reloaded.root_user_including_tree(true).email).to eq "1@example.com" + group_reloaded.root_user_including_tree(contracts: :contract_type).email = "y" + expect(group_reloaded.root_user_including_tree(true, contracts: :contract_type).email). + to eq "1@example.com" + end + + it "eager loads inverse association to group" do + expect do + root = group_reloaded.root_user_including_tree + expect(root.group).to eq group + expect(root.children[0].group).to eq group + end.to_not exceed_query_limit(2) + end + + it "works if eager load association map is not given" do + expect do + root = group_reloaded.root_user_including_tree + expect(root.children[0].email).to eq "2@example.com" + expect(root.children[0].parent.children[1].children[0].email).to eq "5@example.com" + end.to_not exceed_query_limit(2) + end + + context "with no tree root" do + let(:group2) { Group.create!(name: "OtherGroup") } + + it "should return nil" do + expect(group2.root_user_including_tree(contracts: :contract_type)).to be_nil + end + end + + context "with multiple tree roots" do + let!(:other_root) { User.create!(email: "10@example.com", group_id: group.id) } + + it "should error" do + expect do + root = group_reloaded.root_user_including_tree(contracts: :contract_type) + end.to raise_error(ClosureTree::MultipleRootError) + end + end + end + + context "with explicit class_name and foreign_key" do + let(:group) { Grouping.create!(name: "TheGrouping") } + + it "should still work" do + root = group_reloaded.root_person_including_tree(contracts: :contract_type) + expect(root.children[0].email).to eq "2@example.com" + end + end + + context "with bad class_name" do + let(:group) { UserSet.create!(name: "TheUserSet") } + + it "should error" do + expect do + root = group_reloaded.root_user_including_tree(contracts: :contract_type) + end.to raise_error(NameError) + end + end + + context "with bad foreign_key" do + let(:group) { Team.create!(name: "TheTeam") } + + it "should error" do + expect do + root = group_reloaded.root_user_including_tree(contracts: :contract_type) + end.to raise_error(ActiveRecord::StatementInvalid) + end + end +end diff --git a/spec/support/exceed_query_limit.rb b/spec/support/exceed_query_limit.rb new file mode 100644 index 00000000..b167a826 --- /dev/null +++ b/spec/support/exceed_query_limit.rb @@ -0,0 +1,18 @@ +# Derived from http://stackoverflow.com/a/13423584/153896. Updated for RSpec 3. +RSpec::Matchers.define :exceed_query_limit do |expected| + supports_block_expectations + + match do |block| + query_count(&block) > expected + end + + failure_message_when_negated do |actual| + "Expected to run maximum #{expected} queries, got #{@counter.query_count}" + end + + def query_count(&block) + @counter = ActiveRecord::QueryCounter.new + ActiveSupport::Notifications.subscribed(@counter.to_proc, 'sql.active_record', &block) + @counter.query_count + end +end diff --git a/spec/support/query_counter.rb b/spec/support/query_counter.rb new file mode 100644 index 00000000..f36cfc09 --- /dev/null +++ b/spec/support/query_counter.rb @@ -0,0 +1,18 @@ +# From http://stackoverflow.com/a/13423584/153896 +module ActiveRecord + class QueryCounter + attr_reader :query_count + + def initialize + @query_count = 0 + end + + def to_proc + lambda(&method(:callback)) + end + + def callback(name, start, finish, message_id, values) + @query_count += 1 unless %w(CACHE SCHEMA).include?(values[:name]) + end + end +end