about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--app/lib/access_token_extension.rb17
-rw-r--r--app/models/session_activation.rb16
-rw-r--r--config/application.rb1
-rw-r--r--streaming/index.js82
4 files changed, 109 insertions, 7 deletions
diff --git a/app/lib/access_token_extension.rb b/app/lib/access_token_extension.rb
new file mode 100644
index 000000000..3e184e775
--- /dev/null
+++ b/app/lib/access_token_extension.rb
@@ -0,0 +1,17 @@
+# frozen_string_literal: true
+
+module AccessTokenExtension
+  extend ActiveSupport::Concern
+
+  included do
+    after_commit :push_to_streaming_api
+  end
+
+  def revoke(clock = Time)
+    update(revoked_at: clock.now.utc)
+  end
+
+  def push_to_streaming_api
+    Redis.current.publish("timeline:access_token:#{id}", Oj.dump(event: :kill)) if revoked? || destroyed?
+  end
+end
diff --git a/app/models/session_activation.rb b/app/models/session_activation.rb
index 34d25c83d..b0ce9d112 100644
--- a/app/models/session_activation.rb
+++ b/app/models/session_activation.rb
@@ -70,12 +70,16 @@ class SessionActivation < ApplicationRecord
   end
 
   def assign_access_token
-    superapp = Doorkeeper::Application.find_by(superapp: true)
+    self.access_token = Doorkeeper::AccessToken.create!(access_token_attributes)
+  end
 
-    self.access_token = Doorkeeper::AccessToken.create!(application_id: superapp&.id,
-                                                        resource_owner_id: user_id,
-                                                        scopes: 'read write follow',
-                                                        expires_in: Doorkeeper.configuration.access_token_expires_in,
-                                                        use_refresh_token: Doorkeeper.configuration.refresh_token_enabled?)
+  def access_token_attributes
+    {
+      application_id: Doorkeeper::Application.find_by(superapp: true)&.id,
+      resource_owner_id: user_id,
+      scopes: 'read write follow',
+      expires_in: Doorkeeper.configuration.access_token_expires_in,
+      use_refresh_token: Doorkeeper.configuration.refresh_token_enabled?,
+    }
   end
 end
diff --git a/config/application.rb b/config/application.rb
index de2951487..af7735221 100644
--- a/config/application.rb
+++ b/config/application.rb
@@ -140,6 +140,7 @@ module Mastodon
       Doorkeeper::AuthorizationsController.layout 'modal'
       Doorkeeper::AuthorizedApplicationsController.layout 'admin'
       Doorkeeper::Application.send :include, ApplicationExtension
+      Doorkeeper::AccessToken.send :include, AccessTokenExtension
       Devise::FailureApp.send :include, AbstractController::Callbacks
       Devise::FailureApp.send :include, HttpAcceptLanguage::EasyAccess
       Devise::FailureApp.send :include, Localized
diff --git a/streaming/index.js b/streaming/index.js
index 791a26941..877f45d19 100644
--- a/streaming/index.js
+++ b/streaming/index.js
@@ -294,7 +294,7 @@ const startWorker = (workerId) => {
         return;
       }
 
-      client.query('SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes, devices.device_id FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id LEFT OUTER JOIN devices ON oauth_access_tokens.id = devices.access_token_id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => {
+      client.query('SELECT oauth_access_tokens.id, oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes, devices.device_id FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id LEFT OUTER JOIN devices ON oauth_access_tokens.id = devices.access_token_id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => {
         done();
 
         if (err) {
@@ -310,6 +310,7 @@ const startWorker = (workerId) => {
           return;
         }
 
+        req.accessTokenId = result.rows[0].id;
         req.scopes = result.rows[0].scopes.split(' ');
         req.accountId = result.rows[0].account_id;
         req.chosenLanguages = result.rows[0].chosen_languages;
@@ -451,6 +452,55 @@ const startWorker = (workerId) => {
   };
 
   /**
+   * @typedef SystemMessageHandlers
+   * @property {function(): void} onKill
+   */
+
+  /**
+   * @param {any} req
+   * @param {SystemMessageHandlers} eventHandlers
+   * @return {function(string): void}
+   */
+  const createSystemMessageListener = (req, eventHandlers) => {
+    return message => {
+      const json = parseJSON(message);
+
+      if (!json) return;
+
+      const { event } = json;
+
+      log.silly(req.requestId, `System message for ${req.accountId}: ${event}`);
+
+      if (event === 'kill') {
+        log.verbose(req.requestId, `Closing connection for ${req.accountId} due to expired access token`);
+        eventHandlers.onKill();
+      }
+    }
+  };
+
+  /**
+   * @param {any} req
+   * @param {any} res
+   */
+  const subscribeHttpToSystemChannel = (req, res) => {
+    const systemChannelId = `timeline:access_token:${req.accessTokenId}`;
+
+    const listener = createSystemMessageListener(req, {
+
+      onKill () {
+        res.end();
+      },
+
+    });
+
+    res.on('close', () => {
+      unsubscribe(`${redisPrefix}${systemChannelId}`, listener);
+    });
+
+    subscribe(`${redisPrefix}${systemChannelId}`, listener);
+  };
+
+  /**
    * @param {any} req
    * @param {any} res
    * @param {function(Error=): void} next
@@ -462,6 +512,8 @@ const startWorker = (workerId) => {
     }
 
     accountFromRequest(req, alwaysRequireAuth).then(() => checkScopes(req, channelNameFromPath(req))).then(() => {
+      subscribeHttpToSystemChannel(req, res);
+    }).then(() => {
       next();
     }).catch(err => {
       next(err);
@@ -536,7 +588,9 @@ const startWorker = (workerId) => {
 
     const listener = message => {
       const json = parseJSON(message);
+
       if (!json) return;
+
       const { event, payload, queued_at } = json;
 
       const transmit = () => {
@@ -903,6 +957,28 @@ const startWorker = (workerId) => {
     });
 
   /**
+   * @param {WebSocketSession} session
+   */
+  const subscribeWebsocketToSystemChannel = ({ socket, request, subscriptions }) => {
+    const systemChannelId = `timeline:access_token:${request.accessTokenId}`;
+
+    const listener = createSystemMessageListener(request, {
+
+      onKill () {
+        socket.close();
+      },
+
+    });
+
+    subscribe(`${redisPrefix}${systemChannelId}`, listener);
+
+    subscriptions[systemChannelId] = {
+      listener,
+      stopHeartbeat: () => {},
+    };
+  };
+
+  /**
    * @param {string|string[]} arrayOrString
    * @return {string}
    */
@@ -948,7 +1024,9 @@ const startWorker = (workerId) => {
 
     ws.on('message', data => {
       const json = parseJSON(data);
+
       if (!json) return;
+
       const { type, stream, ...params } = json;
 
       if (type === 'subscribe') {
@@ -960,6 +1038,8 @@ const startWorker = (workerId) => {
       }
     });
 
+    subscribeWebsocketToSystemChannel(session);
+
     if (location.query.stream) {
       subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query);
     }