about summary refs log tree commit diff
path: root/streaming
diff options
context:
space:
mode:
authorThibG <thib@sitedethib.com>2019-05-24 15:21:42 +0200
committermultiple creatures <dev@multiple-creature.party>2019-11-19 13:43:02 -0600
commit64a68bf2ea2d750894e55ab1b0f8024466e6f8ad (patch)
tree74a8e15bcf9a5717ed0bd189cd4749671039f2f5 /streaming
parent7952281bbb3b79884a73a6dd5f141469f5c42153 (diff)
Improve streaming server security (#10818)
* Check OAuth token scopes in the streaming API

* Use Sec-WebSocket-Protocol instead of query string to pass WebSocket token

Inspired by https://github.com/kubevirt/kubevirt/issues/1242
Diffstat (limited to 'streaming')
-rw-r--r--streaming/index.js70
1 files changed, 55 insertions, 15 deletions
diff --git a/streaming/index.js b/streaming/index.js
index bd99e4db8..6b9dd33cf 100644
--- a/streaming/index.js
+++ b/streaming/index.js
@@ -195,14 +195,14 @@ const startWorker = (workerId) => {
     next();
   };
 
-  const accountFromToken = (token, req, next) => {
+  const accountFromToken = (token, allowedScopes, req, next) => {
     pgPool.connect((err, client, done) => {
       if (err) {
         next(err);
         return;
       }
 
-      client.query('SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => {
+      client.query('SELECT oauth_access_tokens.resource_owner_id, users.account_id, users.chosen_languages, oauth_access_tokens.scopes FROM oauth_access_tokens INNER JOIN users ON oauth_access_tokens.resource_owner_id = users.id WHERE oauth_access_tokens.token = $1 AND oauth_access_tokens.revoked_at IS NULL LIMIT 1', [token], (err, result) => {
         done();
 
         if (err) {
@@ -218,18 +218,29 @@ const startWorker = (workerId) => {
           return;
         }
 
+        const scopes = result.rows[0].scopes.split(' ');
+
+        if (allowedScopes.size > 0 && !scopes.some(scope => allowedScopes.includes(scope))) {
+          err = new Error('Access token does not cover required scopes');
+          err.statusCode = 401;
+
+          next(err);
+          return;
+        }
+
         req.accountId = result.rows[0].account_id;
         req.chosenLanguages = result.rows[0].chosen_languages;
+        req.allowNotifications = scopes.some(scope => ['read', 'read:notifications'].includes(scope));
 
         next();
       });
     });
   };
 
-  const accountFromRequest = (req, next, required = true) => {
+  const accountFromRequest = (req, next, required = true, allowedScopes = ['read']) => {
     const authorization = req.headers.authorization;
     const location = url.parse(req.url, true);
-    const accessToken = location.query.access_token;
+    const accessToken = location.query.access_token || req.headers['sec-websocket-protocol'];
 
     if (!authorization && !accessToken) {
       if (required) {
@@ -246,10 +257,12 @@ const startWorker = (workerId) => {
 
     const token = authorization ? authorization.replace(/^Bearer /, '') : accessToken;
 
-    accountFromToken(token, req, next);
+    accountFromToken(token, allowedScopes, req, next);
   };
 
   const PUBLIC_STREAMS = [
+    'public',
+    'public:media',
     'hashtag',
     'hashtag:local',
   ];
@@ -257,6 +270,16 @@ const startWorker = (workerId) => {
   const wsVerifyClient = (info, cb) => {
     const location = url.parse(info.req.url, true);
     const authRequired = !PUBLIC_STREAMS.some(stream => stream === location.query.stream);
+    const allowedScopes = [];
+
+    if (authRequired) {
+      allowedScopes.push('read');
+      if (location.query.stream === 'user:notification') {
+        allowedScopes.push('read:notifications');
+      } else {
+        allowedScopes.push('read:statuses');
+      }
+    }
 
     accountFromRequest(info.req, err => {
       if (!err) {
@@ -265,10 +288,11 @@ const startWorker = (workerId) => {
         log.error(info.req.requestId, err.toString());
         cb(false, 401, 'Unauthorized');
       }
-    }, authRequired);
+    }, authRequired, allowedScopes);
   };
 
   const PUBLIC_ENDPOINTS = [
+    '/api/v1/streaming/public',
     '/api/v1/streaming/hashtag',
   ];
 
@@ -279,7 +303,18 @@ const startWorker = (workerId) => {
     }
 
     const authRequired = !PUBLIC_ENDPOINTS.some(endpoint => endpoint === req.path);
-    accountFromRequest(req, next, authRequired);
+    const allowedScopes = [];
+
+    if (authRequired) {
+      allowedScopes.push('read');
+      if (req.path === '/api/v1/streaming/user/notification') {
+        allowedScopes.push('read:notifications');
+      } else {
+        allowedScopes.push('read:statuses');
+      }
+    }
+
+    accountFromRequest(req, next, authRequired, allowedScopes);
   };
 
   const errorMiddleware = (err, req, res, {}) => {
@@ -328,12 +363,11 @@ const startWorker = (workerId) => {
         output(event, encodedPayload);
       };
 
-      if (!req.accountId) {
-        log.error(req.requestId, `Unauthorized: ${accountId} is not logged in.`)
+      if (notificationOnly && event !== 'notification') {
         return;
       }
 
-      if (notificationOnly && event !== 'notification') {
+      if (event === 'notification' && !req.allowNotifications) {
         return;
       }
 
@@ -359,17 +393,23 @@ const startWorker = (workerId) => {
       const targetAccountIds = [unpackedPayload.account.id].concat(unpackedPayload.mentions.map(item => item.id));
       const accountDomain    = unpackedPayload.account.acct.split('@')[1];
 
+      if (Array.isArray(req.chosenLanguages) && unpackedPayload.language !== null && req.chosenLanguages.indexOf(unpackedPayload.language) === -1) {
+        log.silly(req.requestId, `Message ${unpackedPayload.id} filtered by language (${unpackedPayload.language})`);
+        return;
+      }
+
+      // When the account is not logged in, it is not necessary to confirm the block or mute
+      if (!req.accountId) {
+        transmit();
+        return;
+      }
+
       // Don't filter user's own events.
       if (req.accountId === unpackedPayload.account.id) {
         transmit();
         return
       }
 
-      if (Array.isArray(req.chosenLanguages) && unpackedPayload.language !== null && req.chosenLanguages.indexOf(unpackedPayload.language) === -1) {
-        log.silly(req.requestId, `Message ${unpackedPayload.id} filtered by language (${unpackedPayload.language})`);
-        return;
-      }
-
       pgPool.connect((err, client, done) => {
         if (err) {
           log.error(err);