about summary refs log tree commit diff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/mastodon/timestamp_ids.rb201
-rw-r--r--lib/tasks/db.rake2
2 files changed, 104 insertions, 99 deletions
diff --git a/lib/mastodon/timestamp_ids.rb b/lib/mastodon/timestamp_ids.rb
index d49b5c1b5..3b048a50c 100644
--- a/lib/mastodon/timestamp_ids.rb
+++ b/lib/mastodon/timestamp_ids.rb
@@ -1,120 +1,111 @@
 # frozen_string_literal: true
 
-module Mastodon
-  module TimestampIds
-    def self.define_timestamp_id
-      conn = ActiveRecord::Base.connection
-
-      # Make sure we don't already have a `timestamp_id` function.
-      unless conn.execute(<<~SQL).values.first.first
-        SELECT EXISTS(
-          SELECT * FROM pg_proc WHERE proname = 'timestamp_id'
-        );
+module Mastodon::TimestampIds
+  DEFAULT_REGEX = /timestamp_id\('(?<seq_prefix>\w+)'/
+
+  class << self
+    # Our ID will be composed of the following:
+    # 6 bytes (48 bits) of millisecond-level timestamp
+    # 2 bytes (16 bits) of sequence data
+    #
+    # The 'sequence data' is intended to be unique within a
+    # given millisecond, yet obscure the 'serial number' of
+    # this row.
+    #
+    # To do this, we hash the following data:
+    # * Table name (if provided, skipped if not)
+    # * Secret salt (should not be guessable)
+    # * Timestamp (again, millisecond-level granularity)
+    #
+    # We then take the first two bytes of that value, and add
+    # the lowest two bytes of the table ID sequence number
+    # (`table_name`_id_seq). This means that even if we insert
+    # two rows at the same millisecond, they will have
+    # distinct 'sequence data' portions.
+    #
+    # If this happens, and an attacker can see both such IDs,
+    # they can determine which of the two entries was inserted
+    # first, but not the total number of entries in the table
+    # (even mod 2**16).
+    #
+    # The table name is included in the hash to ensure that
+    # different tables derive separate sequence bases so rows
+    # inserted in the same millisecond in different tables do
+    # not reveal the table ID sequence number for one another.
+    #
+    # The secret salt is included in the hash to ensure that
+    # external users cannot derive the sequence base given the
+    # timestamp and table name, which would allow them to
+    # compute the table ID sequence number.
+    def define_timestamp_id
+      return if already_defined?
+
+      connection.execute(<<~SQL)
+        CREATE OR REPLACE FUNCTION timestamp_id(table_name text)
+        RETURNS bigint AS
+        $$
+          DECLARE
+            time_part bigint;
+            sequence_base bigint;
+            tail bigint;
+          BEGIN
+            time_part := (
+              -- Get the time in milliseconds
+              ((date_part('epoch', now()) * 1000))::bigint
+              -- And shift it over two bytes
+              << 16);
+
+            sequence_base := (
+              'x' ||
+              -- Take the first two bytes (four hex characters)
+              substr(
+                -- Of the MD5 hash of the data we documented
+                md5(table_name ||
+                  '#{SecureRandom.hex(16)}' ||
+                  time_part::text
+                ),
+                1, 4
+              )
+            -- And turn it into a bigint
+            )::bit(16)::bigint;
+
+            -- Finally, add our sequence number to our base, and chop
+            -- it to the last two bytes
+            tail := (
+              (sequence_base + nextval(table_name || '_id_seq'))
+              & 65535);
+
+            -- Return the time part and the sequence part. OR appears
+            -- faster here than addition, but they're equivalent:
+            -- time_part has no trailing two bytes, and tail is only
+            -- the last two bytes.
+            RETURN time_part | tail;
+          END
+        $$ LANGUAGE plpgsql VOLATILE;
       SQL
-        # The function doesn't exist, so we'll define it.
-        conn.execute(<<~SQL)
-          CREATE OR REPLACE FUNCTION timestamp_id(table_name text)
-          RETURNS bigint AS
-          $$
-            DECLARE
-              time_part bigint;
-              sequence_base bigint;
-              tail bigint;
-            BEGIN
-              -- Our ID will be composed of the following:
-              -- 6 bytes (48 bits) of millisecond-level timestamp
-              -- 2 bytes (16 bits) of sequence data
-
-              -- The 'sequence data' is intended to be unique within a
-              -- given millisecond, yet obscure the 'serial number' of
-              -- this row.
-
-              -- To do this, we hash the following data:
-              -- * Table name (if provided, skipped if not)
-              -- * Secret salt (should not be guessable)
-              -- * Timestamp (again, millisecond-level granularity)
-
-              -- We then take the first two bytes of that value, and add
-              -- the lowest two bytes of the table ID sequence number
-              -- (`table_name`_id_seq). This means that even if we insert
-              -- two rows at the same millisecond, they will have
-              -- distinct 'sequence data' portions.
-
-              -- If this happens, and an attacker can see both such IDs,
-              -- they can determine which of the two entries was inserted
-              -- first, but not the total number of entries in the table
-              -- (even mod 2**16).
-
-              -- The table name is included in the hash to ensure that
-              -- different tables derive separate sequence bases so rows
-              -- inserted in the same millisecond in different tables do
-              -- not reveal the table ID sequence number for one another.
-
-              -- The secret salt is included in the hash to ensure that
-              -- external users cannot derive the sequence base given the
-              -- timestamp and table name, which would allow them to
-              -- compute the table ID sequence number.
-
-              time_part := (
-                -- Get the time in milliseconds
-                ((date_part('epoch', now()) * 1000))::bigint
-                -- And shift it over two bytes
-                << 16);
-
-              sequence_base := (
-                'x' ||
-                -- Take the first two bytes (four hex characters)
-                substr(
-                  -- Of the MD5 hash of the data we documented
-                  md5(table_name ||
-                    '#{SecureRandom.hex(16)}' ||
-                    time_part::text
-                  ),
-                  1, 4
-                )
-              -- And turn it into a bigint
-              )::bit(16)::bigint;
-
-              -- Finally, add our sequence number to our base, and chop
-              -- it to the last two bytes
-              tail := (
-                (sequence_base + nextval(table_name || '_id_seq'))
-                & 65535);
-
-              -- Return the time part and the sequence part. OR appears
-              -- faster here than addition, but they're equivalent:
-              -- time_part has no trailing two bytes, and tail is only
-              -- the last two bytes.
-              RETURN time_part | tail;
-            END
-          $$ LANGUAGE plpgsql VOLATILE;
-        SQL
-      end
     end
 
-    def self.ensure_id_sequences_exist
-      conn = ActiveRecord::Base.connection
-
+    def ensure_id_sequences_exist
       # Find tables using timestamp IDs.
-      default_regex = /timestamp_id\('(?<seq_prefix>\w+)'/
-      conn.tables.each do |table|
+      connection.tables.each do |table|
         # We're only concerned with "id" columns.
-        next unless (id_col = conn.columns(table).find { |col| col.name == 'id' })
+        next unless (id_col = connection.columns(table).find { |col| col.name == 'id' })
 
         # And only those that are using timestamp_id.
-        next unless (data = default_regex.match(id_col.default_function))
+        next unless (data = DEFAULT_REGEX.match(id_col.default_function))
 
         seq_name = data[:seq_prefix] + '_id_seq'
+
         # If we were on Postgres 9.5+, we could do CREATE SEQUENCE IF
         # NOT EXISTS, but we can't depend on that. Instead, catch the
         # possible exception and ignore it.
         # Note that seq_name isn't a column name, but it's a
         # relation, like a column, and follows the same quoting rules
         # in Postgres.
-        conn.execute(<<~SQL)
+        connection.execute(<<~SQL)
           DO $$
             BEGIN
-              CREATE SEQUENCE #{conn.quote_column_name(seq_name)};
+              CREATE SEQUENCE #{connection.quote_column_name(seq_name)};
             EXCEPTION WHEN duplicate_table THEN
               -- Do nothing, we have the sequence already.
             END
@@ -122,5 +113,19 @@ module Mastodon
         SQL
       end
     end
+
+    private
+
+    def already_defined?
+      connection.execute(<<~SQL).values.first.first
+        SELECT EXISTS(
+          SELECT * FROM pg_proc WHERE proname = 'timestamp_id'
+        );
+      SQL
+    end
+
+    def connection
+      ActiveRecord::Base.connection
+    end
   end
 end
diff --git a/lib/tasks/db.rake b/lib/tasks/db.rake
index 66468d999..6af6bb6fb 100644
--- a/lib/tasks/db.rake
+++ b/lib/tasks/db.rake
@@ -20,10 +20,10 @@ def each_schema_load_environment
 
   if Rails.env == 'development'
     test_conf = ActiveRecord::Base.configurations['test']
+
     if test_conf['database']&.present?
       ActiveRecord::Base.establish_connection(:test)
       yield
-
       ActiveRecord::Base.establish_connection(Rails.env.to_sym)
     end
   end