about summary refs log tree commit diff
path: root/streaming
diff options
context:
space:
mode:
authorEugen Rochko <eugen@zeonfederated.com>2020-11-12 23:05:24 +0100
committerGitHub <noreply@github.com>2020-11-12 23:05:24 +0100
commitaa10200e58ccb340b6384532ccdba6b7fbac037a (patch)
tree079694664dd9fc7b98aef2704cb4b79c69fa7831 /streaming
parent8532429af749339a3ff6af4130de3743cd8d1c68 (diff)
Fix streaming API allowing connections to persist after access token invalidation (#15111)
Fix #14816
Diffstat (limited to 'streaming')
-rw-r--r--streaming/index.js82
1 files changed, 81 insertions, 1 deletions
diff --git a/streaming/index.js b/streaming/index.js
index 791a26941..877f45d19 100644
--- a/streaming/index.js
+++ b/streaming/index.js
@@ -294,7 +294,7 @@ const startWorker = (workerId) => {
         return;
       }
 
-      client.query('SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes, devices.device_id FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id LEFT OUTER JOIN devices ON oauth_access_tokens.id = devices.access_token_id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => {
+      client.query('SELECT oauth_access_tokens.id, oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes, devices.device_id FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id LEFT OUTER JOIN devices ON oauth_access_tokens.id = devices.access_token_id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => {
         done();
 
         if (err) {
@@ -310,6 +310,7 @@ const startWorker = (workerId) => {
           return;
         }
 
+        req.accessTokenId = result.rows[0].id;
         req.scopes = result.rows[0].scopes.split(' ');
         req.accountId = result.rows[0].account_id;
         req.chosenLanguages = result.rows[0].chosen_languages;
@@ -451,6 +452,55 @@ const startWorker = (workerId) => {
   };
 
   /**
+   * @typedef SystemMessageHandlers
+   * @property {function(): void} onKill
+   */
+
+  /**
+   * @param {any} req
+   * @param {SystemMessageHandlers} eventHandlers
+   * @return {function(string): void}
+   */
+  const createSystemMessageListener = (req, eventHandlers) => {
+    return message => {
+      const json = parseJSON(message);
+
+      if (!json) return;
+
+      const { event } = json;
+
+      log.silly(req.requestId, `System message for ${req.accountId}: ${event}`);
+
+      if (event === 'kill') {
+        log.verbose(req.requestId, `Closing connection for ${req.accountId} due to expired access token`);
+        eventHandlers.onKill();
+      }
+    }
+  };
+
+  /**
+   * @param {any} req
+   * @param {any} res
+   */
+  const subscribeHttpToSystemChannel = (req, res) => {
+    const systemChannelId = `timeline:access_token:${req.accessTokenId}`;
+
+    const listener = createSystemMessageListener(req, {
+
+      onKill () {
+        res.end();
+      },
+
+    });
+
+    res.on('close', () => {
+      unsubscribe(`${redisPrefix}${systemChannelId}`, listener);
+    });
+
+    subscribe(`${redisPrefix}${systemChannelId}`, listener);
+  };
+
+  /**
    * @param {any} req
    * @param {any} res
    * @param {function(Error=): void} next
@@ -462,6 +512,8 @@ const startWorker = (workerId) => {
     }
 
     accountFromRequest(req, alwaysRequireAuth).then(() => checkScopes(req, channelNameFromPath(req))).then(() => {
+      subscribeHttpToSystemChannel(req, res);
+    }).then(() => {
       next();
     }).catch(err => {
       next(err);
@@ -536,7 +588,9 @@ const startWorker = (workerId) => {
 
     const listener = message => {
       const json = parseJSON(message);
+
       if (!json) return;
+
       const { event, payload, queued_at } = json;
 
       const transmit = () => {
@@ -903,6 +957,28 @@ const startWorker = (workerId) => {
     });
 
   /**
+   * @param {WebSocketSession} session
+   */
+  const subscribeWebsocketToSystemChannel = ({ socket, request, subscriptions }) => {
+    const systemChannelId = `timeline:access_token:${request.accessTokenId}`;
+
+    const listener = createSystemMessageListener(request, {
+
+      onKill () {
+        socket.close();
+      },
+
+    });
+
+    subscribe(`${redisPrefix}${systemChannelId}`, listener);
+
+    subscriptions[systemChannelId] = {
+      listener,
+      stopHeartbeat: () => {},
+    };
+  };
+
+  /**
    * @param {string|string[]} arrayOrString
    * @return {string}
    */
@@ -948,7 +1024,9 @@ const startWorker = (workerId) => {
 
     ws.on('message', data => {
       const json = parseJSON(data);
+
       if (!json) return;
+
       const { type, stream, ...params } = json;
 
       if (type === 'subscribe') {
@@ -960,6 +1038,8 @@ const startWorker = (workerId) => {
       }
     });
 
+    subscribeWebsocketToSystemChannel(session);
+
     if (location.query.stream) {
       subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query);
     }