about summary refs log tree commit diff
path: root/lib/mastodon/snowflake.rb
diff options
context:
space:
mode:
authorEugen Rochko <eugen@zeonfederated.com>2017-10-08 17:34:34 +0200
committerGitHub <noreply@github.com>2017-10-08 17:34:34 +0200
commit0717d9b3e6904a4dcd5d2dc9e680cc5b21c50e51 (patch)
treefc95b8a715b8035231a6aa009bc82b3662ab236c /lib/mastodon/snowflake.rb
parent6e4046fc3f3973ba0b6994930a8b58726e507003 (diff)
Set snowflake IDs for backdated statuses (#5260)
- Rename Mastodon::TimestampIds into Mastodon::Snowflake for clarity
- Skip for statuses coming from inbox, aka delivered in real-time
- Skip for statuses that claim to be from the future
Diffstat (limited to 'lib/mastodon/snowflake.rb')
-rw-r--r--lib/mastodon/snowflake.rb162
1 files changed, 162 insertions, 0 deletions
diff --git a/lib/mastodon/snowflake.rb b/lib/mastodon/snowflake.rb
new file mode 100644
index 000000000..219e323d4
--- /dev/null
+++ b/lib/mastodon/snowflake.rb
@@ -0,0 +1,162 @@
+# frozen_string_literal: true
+
+module Mastodon::Snowflake
+  DEFAULT_REGEX = /timestamp_id\('(?<seq_prefix>\w+)'/
+
+  class Callbacks
+    def self.around_create(record)
+      now = Time.now.utc
+
+      if record.created_at.nil? || record.created_at >= now || record.created_at == record.updated_at
+        yield
+      else
+        record.id = Mastodon::Snowflake.id_at(record.created_at)
+        tries     = 0
+
+        begin
+          yield
+        rescue ActiveRecord::RecordNotUnique
+          raise if tries > 100
+
+          tries     += 1
+          record.id += rand(100)
+
+          retry
+        end
+      end
+    end
+  end
+
+  class << self
+    # Our ID will be composed of the following:
+    # 6 bytes (48 bits) of millisecond-level timestamp
+    # 2 bytes (16 bits) of sequence data
+    #
+    # The 'sequence data' is intended to be unique within a
+    # given millisecond, yet obscure the 'serial number' of
+    # this row.
+    #
+    # To do this, we hash the following data:
+    # * Table name (if provided, skipped if not)
+    # * Secret salt (should not be guessable)
+    # * Timestamp (again, millisecond-level granularity)
+    #
+    # We then take the first two bytes of that value, and add
+    # the lowest two bytes of the table ID sequence number
+    # (`table_name`_id_seq). This means that even if we insert
+    # two rows at the same millisecond, they will have
+    # distinct 'sequence data' portions.
+    #
+    # If this happens, and an attacker can see both such IDs,
+    # they can determine which of the two entries was inserted
+    # first, but not the total number of entries in the table
+    # (even mod 2**16).
+    #
+    # The table name is included in the hash to ensure that
+    # different tables derive separate sequence bases so rows
+    # inserted in the same millisecond in different tables do
+    # not reveal the table ID sequence number for one another.
+    #
+    # The secret salt is included in the hash to ensure that
+    # external users cannot derive the sequence base given the
+    # timestamp and table name, which would allow them to
+    # compute the table ID sequence number.
+    def define_timestamp_id
+      return if already_defined?
+
+      connection.execute(<<~SQL)
+        CREATE OR REPLACE FUNCTION timestamp_id(table_name text)
+        RETURNS bigint AS
+        $$
+          DECLARE
+            time_part bigint;
+            sequence_base bigint;
+            tail bigint;
+          BEGIN
+            time_part := (
+              -- Get the time in milliseconds
+              ((date_part('epoch', now()) * 1000))::bigint
+              -- And shift it over two bytes
+              << 16);
+
+            sequence_base := (
+              'x' ||
+              -- Take the first two bytes (four hex characters)
+              substr(
+                -- Of the MD5 hash of the data we documented
+                md5(table_name ||
+                  '#{SecureRandom.hex(16)}' ||
+                  time_part::text
+                ),
+                1, 4
+              )
+            -- And turn it into a bigint
+            )::bit(16)::bigint;
+
+            -- Finally, add our sequence number to our base, and chop
+            -- it to the last two bytes
+            tail := (
+              (sequence_base + nextval(table_name || '_id_seq'))
+              & 65535);
+
+            -- Return the time part and the sequence part. OR appears
+            -- faster here than addition, but they're equivalent:
+            -- time_part has no trailing two bytes, and tail is only
+            -- the last two bytes.
+            RETURN time_part | tail;
+          END
+        $$ LANGUAGE plpgsql VOLATILE;
+      SQL
+    end
+
+    def ensure_id_sequences_exist
+      # Find tables using timestamp IDs.
+      connection.tables.each do |table|
+        # We're only concerned with "id" columns.
+        next unless (id_col = connection.columns(table).find { |col| col.name == 'id' })
+
+        # And only those that are using timestamp_id.
+        next unless (data = DEFAULT_REGEX.match(id_col.default_function))
+
+        seq_name = data[:seq_prefix] + '_id_seq'
+
+        # If we were on Postgres 9.5+, we could do CREATE SEQUENCE IF
+        # NOT EXISTS, but we can't depend on that. Instead, catch the
+        # possible exception and ignore it.
+        # Note that seq_name isn't a column name, but it's a
+        # relation, like a column, and follows the same quoting rules
+        # in Postgres.
+        connection.execute(<<~SQL)
+          DO $$
+            BEGIN
+              CREATE SEQUENCE #{connection.quote_column_name(seq_name)};
+            EXCEPTION WHEN duplicate_table THEN
+              -- Do nothing, we have the sequence already.
+            END
+          $$ LANGUAGE plpgsql;
+        SQL
+      end
+    end
+
+    def id_at(timestamp)
+      id  = timestamp.to_i * 1000 + rand(1000)
+      id  = id << 16
+      id += rand(2**16)
+      id
+    end
+
+    private
+
+    def already_defined?
+      connection.execute(<<~SQL).values.first.first
+        SELECT EXISTS(
+          SELECT * FROM pg_proc WHERE proname = 'timestamp_id'
+        );
+      SQL
+    end
+
+    def connection
+      ActiveRecord::Base.connection
+    end
+  end
+end