diff --git a/CHANGELOG.md b/CHANGELOG.md index 404d8dae..f6733d9c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ when generating position for cursor based on `:id` column (Rails 7.1 and above, where composite primary models are now supported). This ensures we grab the value of the id column, rather than a potentially composite primary key value. +- [456](https://github.com/Shopify/job-iteration/pull/431) - Use Arel to generate SQL that's type compatible for the + cursor pagination conditionals in ActiveRecord cursor. Previously, the cursor would coerce numeric ids to a string value + (e.g.: `... AND id > '1'`) ## v1.4.1 (Sep 5, 2023) diff --git a/lib/job-iteration/active_record_cursor.rb b/lib/job-iteration/active_record_cursor.rb index 10a8f4ef..d60c8976 100644 --- a/lib/job-iteration/active_record_cursor.rb +++ b/lib/job-iteration/active_record_cursor.rb @@ -18,12 +18,8 @@ def initialize end end - def initialize(relation, columns = nil, position = nil) - @columns = if columns - Array(columns) - else - Array(relation.primary_key).map { |pk| "#{relation.table_name}.#{pk}" } - end + def initialize(relation, columns, position = nil) + @columns = columns self.position = Array.wrap(position) raise ArgumentError, "Must specify at least one column" if columns.empty? if relation.joins_values.present? && !@columns.all? { |column| column.to_s.include?(".") } @@ -34,7 +30,7 @@ def initialize(relation, columns = nil, position = nil) raise ConditionNotSupportedError end - @base_relation = relation.reorder(@columns.join(",")) + @base_relation = relation.reorder(*@columns) @reached_end = false end @@ -54,12 +50,10 @@ def position=(position) def update_from_record(record) self.position = @columns.map do |column| - method = column.to_s.split(".").last - - if ActiveRecord.version >= Gem::Version.new("7.1.0.alpha") && method == "id" + if ActiveRecord.version >= Gem::Version.new("7.1.0.alpha") && column.name == "id" record.id_value else - record.send(method.to_sym) + record.send(column.name) end end end @@ -89,14 +83,14 @@ def conditions i = @position.size - 1 column = @columns[i] conditions = if @columns.size == @position.size - "#{column} > ?" + column.gt(@position[i]) else - "#{column} >= ?" + column.gteq(@position[i]) end while i > 0 i -= 1 column = @columns[i] - conditions = "#{column} > ? OR (#{column} = ? AND (#{conditions}))" + conditions = column.gt(@position[i]).or(column.eq(@position[i]).and(conditions)) end ret = @position.reduce([conditions]) { |params, value| params << value << value } ret.pop diff --git a/lib/job-iteration/active_record_enumerator.rb b/lib/job-iteration/active_record_enumerator.rb index 363a4ecf..f21fdd91 100644 --- a/lib/job-iteration/active_record_enumerator.rb +++ b/lib/job-iteration/active_record_enumerator.rb @@ -11,9 +11,9 @@ def initialize(relation, columns: nil, batch_size: 100, cursor: nil) @relation = relation @batch_size = batch_size @columns = if columns - Array(columns) + Array(columns).map { |col| relation.arel_table[col.to_sym] } else - Array(relation.primary_key).map { |pk| "#{relation.table_name}.#{pk}" } + Array(relation.primary_key).map { |pk| relation.arel_table[pk.to_sym] } end @cursor = cursor end @@ -45,7 +45,7 @@ def size def cursor_value(record) positions = @columns.map do |column| - attribute_name = column.to_s.split(".").last + attribute_name = column.name.to_sym column_value(record, attribute_name) end return positions.first if positions.size == 1 @@ -58,8 +58,8 @@ def finder_cursor end def column_value(record, attribute) - value = record.read_attribute(attribute.to_sym) - case record.class.columns_hash.fetch(attribute).type + value = record.read_attribute(attribute) + case record.class.columns_hash.fetch(attribute.to_s).type when :datetime value.strftime(SQL_DATETIME_WITH_NSEC) else diff --git a/test/test_helper.rb b/test/test_helper.rb index f05bf268..4641babb 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -108,6 +108,40 @@ def assert_logged(message) end end +module ActiveRecordHelpers + def assert_sql(*patterns_to_match, &block) + captured_queries = [] + assert_nothing_raised do + ActiveSupport::Notifications.subscribed( + ->(_name, _start_time, _end_time, _subscriber_id, payload) { captured_queries << payload[:sql] }, + "sql.active_record", + &block + ) + end + + failed_patterns = [] + patterns_to_match.each do |pattern| + failed_check = captured_queries.none? do |sql| + case pattern + when Regexp + sql.match?(pattern) + when String + sql == pattern + else + raise ArgumentError, "#assert_sql encountered an unknown matcher #{pattern.inspect}" + end + end + failed_patterns << pattern if failed_check + end + queries = captured_queries.empty? ? "" : "\nQueries:\n #{captured_queries.join("\n ")}" + assert_predicate( + failed_patterns, + :empty?, + "Query pattern(s) #{failed_patterns.map(&:inspect).join(", ")} not found.#{queries}", + ) + end +end + JobIteration.logger = Logger.new(IO::NULL) ActiveJob::Base.logger = Logger.new(IO::NULL) diff --git a/test/unit/active_record_enumerator_test.rb b/test/unit/active_record_enumerator_test.rb index 724bbaae..f87b17b8 100644 --- a/test/unit/active_record_enumerator_test.rb +++ b/test/unit/active_record_enumerator_test.rb @@ -4,6 +4,8 @@ module JobIteration class ActiveRecordEnumeratorTest < IterationUnitTest + include ActiveRecordHelpers + SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%S.%N" test "#records yields every record with their cursor position" do enum = build_enumerator.records @@ -133,6 +135,13 @@ class ActiveRecordEnumeratorTest < IterationUnitTest end end + test "enumerator paginates using integer conditionals for primary key when no columns are defined" do + enum = build_enumerator(relation: Product.all, batch_size: 1).records + assert_sql(/`products`\.`id` > 1/) do + enum.take(2) + end + end + private def build_enumerator(relation: Product.all, batch_size: 2, columns: nil, cursor: nil)