about summary refs log tree commit diff
path: root/streaming/index.js
diff options
context:
space:
mode:
Diffstat (limited to 'streaming/index.js')
-rw-r--r--streaming/index.js737
1 files changed, 502 insertions, 235 deletions
diff --git a/streaming/index.js b/streaming/index.js
index f69064e33..f5c9b4224 100644
--- a/streaming/index.js
+++ b/streaming/index.js
@@ -1,3 +1,5 @@
+// @ts-check
+
 const os = require('os');
 const throng = require('throng');
 const dotenv = require('dotenv');
@@ -12,7 +14,7 @@ const uuid = require('uuid');
 const fs = require('fs');
 
 const env = process.env.NODE_ENV || 'development';
-const alwaysRequireAuth = process.env.WHITELIST_MODE === 'true' || process.env.AUTHORIZED_FETCH === 'true';
+const alwaysRequireAuth = process.env.LIMITED_FEDERATION_MODE === 'true' || process.env.WHITELIST_MODE === 'true' || process.env.AUTHORIZED_FETCH === 'true';
 
 dotenv.config({
   path: env === 'production' ? '.env.production' : '.env',
@@ -20,6 +22,10 @@ dotenv.config({
 
 log.level = process.env.LOG_LEVEL || 'verbose';
 
+/**
+ * @param {string} dbUrl
+ * @return {Object.<string, any>}
+ */
 const dbUrlToConfig = (dbUrl) => {
   if (!dbUrl) {
     return {};
@@ -53,6 +59,10 @@ const dbUrlToConfig = (dbUrl) => {
   return config;
 };
 
+/**
+ * @param {Object.<string, any>} defaultConfig
+ * @param {string} redisUrl
+ */
 const redisUrlToClient = (defaultConfig, redisUrl) => {
   const config = defaultConfig;
 
@@ -108,6 +118,7 @@ const startWorker = (workerId) => {
   }
 
   const app = express();
+
   app.set('trusted proxy', process.env.TRUSTED_PROXY_IP || 'loopback,uniquelocal');
 
   const pgPool = new pg.Pool(Object.assign(pgConfigs[env], dbUrlToConfig(process.env.DATABASE_URL)));
@@ -130,6 +141,9 @@ const startWorker = (workerId) => {
   const redisSubscribeClient = redisUrlToClient(redisParams, process.env.REDIS_URL);
   const redisClient = redisUrlToClient(redisParams, process.env.REDIS_URL);
 
+  /**
+   * @type {Object.<string, Array.<function(string): void>>}
+   */
   const subs = {};
 
   redisSubscribeClient.on('message', (channel, message) => {
@@ -144,11 +158,11 @@ const startWorker = (workerId) => {
     callbacks.forEach(callback => callback(message));
   });
 
+  /**
+   * @param {string[]} channels
+   * @return {function(): void}
+   */
   const subscriptionHeartbeat = channels => {
-    if (!Array.isArray(channels)) {
-      channels = [channels];
-    }
-
     const interval = 6 * 60;
 
     const tellSubscribed = () => {
@@ -164,25 +178,65 @@ const startWorker = (workerId) => {
     };
   };
 
+  /**
+   * @param {string} channel
+   * @param {function(string): void} callback
+   */
   const subscribe = (channel, callback) => {
     log.silly(`Adding listener for ${channel}`);
     subs[channel] = subs[channel] || [];
+
     if (subs[channel].length === 0) {
       log.verbose(`Subscribe ${channel}`);
       redisSubscribeClient.subscribe(channel);
     }
+
     subs[channel].push(callback);
   };
 
+  /**
+   * @param {string} channel
+   * @param {function(string): void} callback
+   */
   const unsubscribe = (channel, callback) => {
     log.silly(`Removing listener for ${channel}`);
+
+    if (!subs[channel]) {
+      return;
+    }
+
     subs[channel] = subs[channel].filter(item => item !== callback);
+
     if (subs[channel].length === 0) {
       log.verbose(`Unsubscribe ${channel}`);
       redisSubscribeClient.unsubscribe(channel);
     }
   };
 
+  const FALSE_VALUES = [
+    false,
+    0,
+    "0",
+    "f",
+    "F",
+    "false",
+    "FALSE",
+    "off",
+    "OFF"
+  ];
+
+  /**
+   * @param {any} value
+   * @return {boolean}
+   */
+  const isTruthy = value =>
+    value && !FALSE_VALUES.includes(value);
+
+  /**
+   * @param {any} req
+   * @param {any} res
+   * @param {function(Error=): void}
+   */
   const allowCrossDomain = (req, res, next) => {
     res.header('Access-Control-Allow-Origin', '*');
     res.header('Access-Control-Allow-Headers', 'Authorization, Accept, Cache-Control');
@@ -191,6 +245,11 @@ const startWorker = (workerId) => {
     next();
   };
 
+  /**
+   * @param {any} req
+   * @param {any} res
+   * @param {function(Error=): void}
+   */
   const setRequestId = (req, res, next) => {
     req.requestId = uuid.v4();
     res.header('X-Request-Id', req.requestId);
@@ -198,16 +257,26 @@ const startWorker = (workerId) => {
     next();
   };
 
+  /**
+   * @param {any} req
+   * @param {any} res
+   * @param {function(Error=): void}
+   */
   const setRemoteAddress = (req, res, next) => {
     req.remoteAddress = req.connection.remoteAddress;
 
     next();
   };
 
-  const accountFromToken = (token, allowedScopes, req, next) => {
+  /**
+   * @param {string} token
+   * @param {any} req
+   * @return {Promise.<void>}
+   */
+  const accountFromToken = (token, req) => new Promise((resolve, reject) => {
     pgPool.connect((err, client, done) => {
       if (err) {
-        next(err);
+        reject(err);
         return;
       }
 
@@ -215,62 +284,89 @@ const startWorker = (workerId) => {
         done();
 
         if (err) {
-          next(err);
+          reject(err);
           return;
         }
 
         if (result.rows.length === 0) {
           err = new Error('Invalid access token');
-          err.statusCode = 401;
+          err.status = 401;
 
-          next(err);
-          return;
-        }
-
-        const scopes = result.rows[0].scopes.split(' ');
-
-        if (allowedScopes.size > 0 && !scopes.some(scope => allowedScopes.includes(scope))) {
-          err = new Error('Access token does not cover required scopes');
-          err.statusCode = 401;
-
-          next(err);
+          reject(err);
           return;
         }
 
+        req.scopes = result.rows[0].scopes.split(' ');
         req.accountId = result.rows[0].account_id;
         req.chosenLanguages = result.rows[0].chosen_languages;
-        req.allowNotifications = scopes.some(scope => ['read', 'read:notifications'].includes(scope));
+        req.allowNotifications = req.scopes.some(scope => ['read', 'read:notifications'].includes(scope));
         req.deviceId = result.rows[0].device_id;
 
-        next();
+        resolve();
       });
     });
-  };
+  });
 
-  const accountFromRequest = (req, next, required = true, allowedScopes = ['read']) => {
+  /**
+   * @param {any} req
+   * @param {boolean=} required
+   * @return {Promise.<void>}
+   */
+  const accountFromRequest = (req, required = true) => new Promise((resolve, reject) => {
     const authorization = req.headers.authorization;
-    const location = url.parse(req.url, true);
-    const accessToken = location.query.access_token || req.headers['sec-websocket-protocol'];
+    const location      = url.parse(req.url, true);
+    const accessToken   = location.query.access_token || req.headers['sec-websocket-protocol'];
 
     if (!authorization && !accessToken) {
       if (required) {
         const err = new Error('Missing access token');
-        err.statusCode = 401;
+        err.status = 401;
 
-        next(err);
+        reject(err);
         return;
       } else {
-        next();
+        resolve();
         return;
       }
     }
 
     const token = authorization ? authorization.replace(/^Bearer /, '') : accessToken;
 
-    accountFromToken(token, allowedScopes, req, next);
+    resolve(accountFromToken(token, req));
+  });
+
+  /**
+   * @param {any} req
+   * @return {string}
+   */
+  const channelNameFromPath = req => {
+    const { path, query } = req;
+    const onlyMedia = isTruthy(query.only_media);
+    const allowLocalOnly = isTruthy(query.allow_local_only);
+
+    switch(path) {
+    case '/api/v1/streaming/user':
+      return 'user';
+    case '/api/v1/streaming/user/notification':
+      return 'user:notification';
+    case '/api/v1/streaming/public':
+      return onlyMedia ? 'public:media' : 'public';
+    case '/api/v1/streaming/public/local':
+      return onlyMedia ? 'public:local:media' : 'public:local';
+    case '/api/v1/streaming/public/remote':
+      return onlyMedia ? 'public:remote:media' : 'public:remote';
+    case '/api/v1/streaming/hashtag':
+      return 'hashtag';
+    case '/api/v1/streaming/hashtag/local':
+      return 'hashtag:local';
+    case '/api/v1/streaming/direct':
+      return 'direct';
+    case '/api/v1/streaming/list':
+      return 'list';
+    }
   };
 
-  const PUBLIC_STREAMS = [
+  const PUBLIC_CHANNELS = [
     'public',
     'public:media',
     'public:local',
@@ -281,96 +377,149 @@ const startWorker = (workerId) => {
     'hashtag:local',
   ];
 
-  const wsVerifyClient = (info, cb) => {
-    const location = url.parse(info.req.url, true);
-    const authRequired = alwaysRequireAuth || !PUBLIC_STREAMS.some(stream => stream === location.query.stream);
-    const allowedScopes = [];
+  /**
+   * @param {any} req
+   * @param {string} channelName
+   * @return {Promise.<void>}
+   */
+  const checkScopes = (req, channelName) => new Promise((resolve, reject) => {
+    log.silly(req.requestId, `Checking OAuth scopes for ${channelName}`);
+
+    // When accessing public channels, no scopes are needed
+    if (PUBLIC_CHANNELS.includes(channelName)) {
+      resolve();
+      return;
+    }
 
-    if (authRequired) {
-      allowedScopes.push('read');
-      if (location.query.stream === 'user:notification') {
-        allowedScopes.push('read:notifications');
-      } else {
-        allowedScopes.push('read:statuses');
-      }
+    // The `read` scope has the highest priority, if the token has it
+    // then it can access all streams
+    const requiredScopes = ['read'];
+
+    // When accessing specifically the notifications stream,
+    // we need a read:notifications, while in all other cases,
+    // we can allow access with read:statuses. Mind that the
+    // user stream will not contain notifications unless
+    // the token has either read or read:notifications scope
+    // as well, this is handled separately.
+    if (channelName === 'user:notification') {
+      requiredScopes.push('read:notifications');
+    } else {
+      requiredScopes.push('read:statuses');
     }
 
-    accountFromRequest(info.req, err => {
-      if (!err) {
-        cb(true, undefined, undefined);
-      } else {
-        log.error(info.req.requestId, err.toString());
-        cb(false, 401, 'Unauthorized');
-      }
-    }, authRequired, allowedScopes);
-  };
+    if (requiredScopes.some(requiredScope => req.scopes.includes(requiredScope))) {
+      resolve();
+      return;
+    }
 
-  const PUBLIC_ENDPOINTS = [
-    '/api/v1/streaming/public',
-    '/api/v1/streaming/public/allow_local_only',
-    '/api/v1/streaming/public/local',
-    '/api/v1/streaming/public/remote',
-    '/api/v1/streaming/hashtag',
-    '/api/v1/streaming/hashtag/local',
-  ];
+    const err = new Error('Access token does not cover required scopes');
+    err.status = 401;
+
+    reject(err);
+  });
+
+  /**
+   * @param {any} info
+   * @param {function(boolean, number, string): void} callback
+   */
+  const wsVerifyClient = (info, callback) => {
+    // When verifying the websockets connection, we no longer pre-emptively
+    // check OAuth scopes and drop the connection if they're missing. We only
+    // drop the connection if access without token is not allowed by environment
+    // variables. OAuth scope checks are moved to the point of subscription
+    // to a specific stream.
+
+    accountFromRequest(info.req, alwaysRequireAuth).then(() => {
+      callback(true, undefined, undefined);
+    }).catch(err => {
+      log.error(info.req.requestId, err.toString());
+      callback(false, 401, 'Unauthorized');
+    });
+  };
 
+  /**
+   * @param {any} req
+   * @param {any} res
+   * @param {function(Error=): void} next
+   */
   const authenticationMiddleware = (req, res, next) => {
     if (req.method === 'OPTIONS') {
       next();
       return;
     }
 
-    const authRequired = alwaysRequireAuth || !PUBLIC_ENDPOINTS.some(endpoint => endpoint === req.path);
-    const allowedScopes = [];
-
-    if (authRequired) {
-      allowedScopes.push('read');
-      if (req.path === '/api/v1/streaming/user/notification') {
-        allowedScopes.push('read:notifications');
-      } else {
-        allowedScopes.push('read:statuses');
-      }
-    }
-
-    accountFromRequest(req, next, authRequired, allowedScopes);
+    accountFromRequest(req, alwaysRequireAuth).then(() => checkScopes(req, channelNameFromPath(req))).then(() => {
+      next();
+    }).catch(err => {
+      next(err);
+    });
   };
 
-  const errorMiddleware = (err, req, res, {}) => {
+  /**
+   * @param {Error} err
+   * @param {any} req
+   * @param {any} res
+   * @param {function(Error=): void} next
+   */
+  const errorMiddleware = (err, req, res, next) => {
     log.error(req.requestId, err.toString());
-    res.writeHead(err.statusCode || 500, { 'Content-Type': 'application/json' });
-    res.end(JSON.stringify({ error: err.statusCode ? err.toString() : 'An unexpected error occurred' }));
+
+    if (res.headersSent) {
+      return next(err);
+    }
+
+    res.writeHead(err.status || 500, { 'Content-Type': 'application/json' });
+    res.end(JSON.stringify({ error: err.status ? err.toString() : 'An unexpected error occurred' }));
   };
 
+  /**
+   * @param {array}
+   * @param {number=} shift
+   * @return {string}
+   */
   const placeholders = (arr, shift = 0) => arr.map((_, i) => `$${i + 1 + shift}`).join(', ');
 
-  const authorizeListAccess = (id, req, next) => {
+  /**
+   * @param {string} listId
+   * @param {any} req
+   * @return {Promise.<void>}
+   */
+  const authorizeListAccess = (listId, req) => new Promise((resolve, reject) => {
+    const { accountId } = req;
+
     pgPool.connect((err, client, done) => {
       if (err) {
-        next(false);
+        reject();
         return;
       }
 
-      client.query('SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1', [id], (err, result) => {
+      client.query('SELECT id, account_id FROM lists WHERE id = $1 LIMIT 1', [listId], (err, result) => {
         done();
 
-        if (err || result.rows.length === 0 || result.rows[0].account_id !== req.accountId) {
-          next(false);
+        if (err || result.rows.length === 0 || result.rows[0].account_id !== accountId) {
+          reject();
           return;
         }
 
-        next(true);
+        resolve();
       });
     });
-  };
+  });
 
+  /**
+   * @param {string[]} ids
+   * @param {any} req
+   * @param {function(string, string): void} output
+   * @param {function(string[], function(string): void): void} attachCloseHandler
+   * @param {boolean=} needsFiltering
+   * @param {boolean=} notificationOnly
+   * @param {boolean=} allowLocalOnly
+   * @return {function(string): void}
+   */
   const streamFrom = (ids, req, output, attachCloseHandler, needsFiltering = false, notificationOnly = false, allowLocalOnly = false) => {
     const accountId  = req.accountId || req.remoteAddress;
     const streamType = notificationOnly ? ' (notification)' : '';
 
-    if (!Array.isArray(ids)) {
-      ids = [ids];
-    }
-
     log.verbose(req.requestId, `Starting stream from ${ids.join(', ')} for ${accountId}${streamType}`);
 
     const listener = message => {
@@ -454,10 +603,18 @@ const startWorker = (workerId) => {
       subscribe(`${redisPrefix}${id}`, listener);
     });
 
-    attachCloseHandler(ids.map(id => `${redisPrefix}${id}`), listener);
+    if (attachCloseHandler) {
+      attachCloseHandler(ids.map(id => `${redisPrefix}${id}`), listener);
+    }
+
+    return listener;
   };
 
-  // Setup stream output to HTTP
+  /**
+   * @param {any} req
+   * @param {any} res
+   * @return {function(string, string): void}
+   */
   const streamToHttp = (req, res) => {
     const accountId = req.accountId || req.remoteAddress;
 
@@ -480,12 +637,12 @@ const startWorker = (workerId) => {
     };
   };
 
-  // Setup stream end for HTTP
-  const streamHttpEnd = (req, closeHandler = false) => (ids, listener) => {
-    if (!Array.isArray(ids)) {
-      ids = [ids];
-    }
-
+  /**
+   * @param {any} req
+   * @param {function(): void} [closeHandler]
+   * @return {function(string[], function(string): void)}
+   */
+  const streamHttpEnd = (req, closeHandler = undefined) => (ids, listener) => {
     req.on('close', () => {
       ids.forEach(id => {
         unsubscribe(id, listener);
@@ -497,37 +654,24 @@ const startWorker = (workerId) => {
     });
   };
 
-  // Setup stream output to WebSockets
-  const streamToWs = (req, ws) => (event, payload) => {
+  /**
+   * @param {any} req
+   * @param {any} ws
+   * @param {string[]} streamName
+   * @return {function(string, string): void}
+   */
+  const streamToWs = (req, ws, streamName) => (event, payload) => {
     if (ws.readyState !== ws.OPEN) {
       log.error(req.requestId, 'Tried writing to closed socket');
       return;
     }
 
-    ws.send(JSON.stringify({ event, payload }));
-  };
-
-  // Setup stream end for WebSockets
-  const streamWsEnd = (req, ws, closeHandler = false) => (id, listener) => {
-    const accountId = req.accountId || req.remoteAddress;
-
-    ws.on('close', () => {
-      log.verbose(req.requestId, `Ending stream for ${accountId}`);
-      unsubscribe(id, listener);
-      if (closeHandler) {
-        closeHandler();
-      }
-    });
-
-    ws.on('error', () => {
-      log.verbose(req.requestId, `Ending stream for ${accountId}`);
-      unsubscribe(id, listener);
-      if (closeHandler) {
-        closeHandler();
-      }
-    });
+    ws.send(JSON.stringify({ stream: streamName, event, payload }));
   };
 
+  /**
+   * @param {any} res
+   */
   const httpNotFound = res => {
     res.writeHead(404, { 'Content-Type': 'application/json' });
     res.end(JSON.stringify({ error: 'Not found' }));
@@ -545,165 +689,281 @@ const startWorker = (workerId) => {
   app.use(authenticationMiddleware);
   app.use(errorMiddleware);
 
-  app.get('/api/v1/streaming/user', (req, res) => {
-    const channels = [`timeline:${req.accountId}`];
-
-    if (req.deviceId) {
-      channels.push(`timeline:${req.accountId}:${req.deviceId}`);
-    }
-
-    streamFrom(channels, req, streamToHttp(req, res), streamHttpEnd(req, subscriptionHeartbeat(channels)));
-  });
-
-  app.get('/api/v1/streaming/user/notification', (req, res) => {
-    streamFrom(`timeline:${req.accountId}`, req, streamToHttp(req, res), streamHttpEnd(req), false, true);
-  });
-
-  app.get('/api/v1/streaming/public', (req, res) => {
-    const onlyMedia = req.query.only_media === '1' || req.query.only_media === 'true';
-    const channel   = onlyMedia ? 'timeline:public:media' : 'timeline:public';
-
-    const allowLocalOnly = req.query.allow_local_only === '1' || req.query.allow_local_only === 'true';
-
-    streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req), true, false, allowLocalOnly);
-  });
-
-  app.get('/api/v1/streaming/public/local', (req, res) => {
-    const onlyMedia = req.query.only_media === '1' || req.query.only_media === 'true';
-    const channel   = onlyMedia ? 'timeline:public:local:media' : 'timeline:public:local';
-
-    streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req), true, false, true);
-  });
-
-  app.get('/api/v1/streaming/public/remote', (req, res) => {
-    const onlyMedia = req.query.only_media === '1' || req.query.only_media === 'true';
-    const channel   = onlyMedia ? 'timeline:public:remote:media' : 'timeline:public:remote';
-
-    streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req), true);
-  });
-
-  app.get('/api/v1/streaming/direct', (req, res) => {
-    const channel = `timeline:direct:${req.accountId}`;
-    streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req, subscriptionHeartbeat(channel)), true);
-  });
-
-  app.get('/api/v1/streaming/hashtag', (req, res) => {
-    const { tag } = req.query;
-
-    if (!tag || tag.length === 0) {
-      httpNotFound(res);
-      return;
-    }
-
-    streamFrom(`timeline:hashtag:${tag.toLowerCase()}`, req, streamToHttp(req, res), streamHttpEnd(req), true);
-  });
-
-  app.get('/api/v1/streaming/hashtag/local', (req, res) => {
-    const { tag } = req.query;
+  app.get('/api/v1/streaming/*', (req, res) => {
+    channelNameToIds(req, channelNameFromPath(req), req.query).then(({ channelIds, options }) => {
+      const onSend = streamToHttp(req, res);
+      const onEnd  = streamHttpEnd(req, subscriptionHeartbeat(channelIds));
 
-    if (!tag || tag.length === 0) {
+      streamFrom(channelIds, req, onSend, onEnd, options.needsFiltering, options.notificationOnly, options.allowLocalOnly);
+    }).catch(err => {
+      log.verbose(req.requestId, 'Subscription error:', err.toString());
       httpNotFound(res);
-      return;
-    }
-
-    streamFrom(`timeline:hashtag:${tag.toLowerCase()}:local`, req, streamToHttp(req, res), streamHttpEnd(req), true);
-  });
-
-  app.get('/api/v1/streaming/list', (req, res) => {
-    const listId = req.query.list;
-
-    authorizeListAccess(listId, req, authorized => {
-      if (!authorized) {
-        httpNotFound(res);
-        return;
-      }
-
-      const channel = `timeline:list:${listId}`;
-      streamFrom(channel, req, streamToHttp(req, res), streamHttpEnd(req, subscriptionHeartbeat(channel)));
     });
   });
 
   const wss = new WebSocketServer({ server, verifyClient: wsVerifyClient });
 
-  wss.on('connection', (ws, req) => {
-    const location = url.parse(req.url, true);
-    req.requestId  = uuid.v4();
-    req.remoteAddress = ws._socket.remoteAddress;
-
-    let channel;
-
-    switch(location.query.stream) {
+  /**
+   * @typedef StreamParams
+   * @property {string} [tag]
+   * @property {string} [list]
+   * @property {string} [only_media]
+   */
+
+  /**
+   * @param {any} req
+   * @param {string} name
+   * @param {StreamParams} params
+   * @return {Promise.<{ channelIds: string[], options: { needsFiltering: boolean, notificationOnly: boolean } }>}
+   */
+  const channelNameToIds = (req, name, params) => new Promise((resolve, reject) => {
+    switch(name) {
     case 'user':
-      channel = [`timeline:${req.accountId}`];
-
-      if (req.deviceId) {
-        channel.push(`timeline:${req.accountId}:${req.deviceId}`);
-      }
+      resolve({
+        channelIds: req.deviceId ? [`timeline:${req.accountId}`, `timeline:${req.accountId}:${req.deviceId}`] : [`timeline:${req.accountId}`],
+        options: { needsFiltering: false, notificationOnly: false, allowLocalOnly: true },
+      });
 
-      streamFrom(channel, req, streamToWs(req, ws), streamWsEnd(req, ws, subscriptionHeartbeat(channel)));
       break;
     case 'user:notification':
-      streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws), false, true);
+      resolve({
+        channelIds: [`timeline:${req.accountId}`],
+        options: { needsFiltering: false, notificationOnly: true, allowLocalOnly: true },
+      });
+
       break;
     case 'public':
-      streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
+      resolve({
+        channelIds: ['timeline:public'],
+        options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: isTruthy(params.allow_local_only) },
+      });
+
       break;
     case 'public:allow_local_only':
-      streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true, false, true);
+      resolve({
+        channelIds: ['timeline:public'],
+        options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: true },
+      });
+
       break;
     case 'public:local':
-      streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true, false, true);
+      resolve({
+        channelIds: ['timeline:public:local'],
+        options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: true },
+      });
+
       break;
     case 'public:remote':
-      streamFrom('timeline:public:remote', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
+      resolve({
+        channelIds: ['timeline:public:remote'],
+        options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: false },
+      });
+
       break;
     case 'public:media':
-      streamFrom('timeline:public:media', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
+      resolve({
+        channelIds: ['timeline:public:media'],
+        options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: isTruthy(query.allow_local_only) },
+      });
+
       break;
     case 'public:allow_local_only:media':
-      streamFrom('timeline:public:media', req, streamToWs(req, ws), streamWsEnd(req, ws), true, false, true);
+      resolve({
+        channelIds: ['timeline:public:media'],
+        options: { needsFiltering: true, notificationsOnly: false, allowLocalOnly: true },
+      });
+
       break;
     case 'public:local:media':
-      streamFrom('timeline:public:local:media', req, streamToWs(req, ws), streamWsEnd(req, ws), true, false, true);
+      resolve({
+        channelIds: ['timeline:public:local:media'],
+        options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: true },
+      });
+
       break;
     case 'public:remote:media':
-      streamFrom('timeline:public:remote:media', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
+      resolve({
+        channelIds: ['timeline:public:remote:media'],
+        options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: false },
+      });
+
       break;
     case 'direct':
-      channel = `timeline:direct:${req.accountId}`;
-      streamFrom(channel, req, streamToWs(req, ws), streamWsEnd(req, ws, subscriptionHeartbeat(channel)), true);
+      resolve({
+        channelIds: [`timeline:direct:${req.accountId}`],
+        options: { needsFiltering: false, notificationOnly: false, allowLocalOnly: true },
+      });
+
       break;
     case 'hashtag':
-      if (!location.query.tag || location.query.tag.length === 0) {
-        ws.close();
-        return;
+      if (!params.tag || params.tag.length === 0) {
+        reject('No tag for stream provided');
+      } else {
+        resolve({
+          channelIds: [`timeline:hashtag:${params.tag.toLowerCase()}`],
+          options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: true },
+        });
       }
 
-      streamFrom(`timeline:hashtag:${location.query.tag.toLowerCase()}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true);
       break;
     case 'hashtag:local':
-      if (!location.query.tag || location.query.tag.length === 0) {
-        ws.close();
-        return;
+      if (!params.tag || params.tag.length === 0) {
+        reject('No tag for stream provided');
+      } else {
+        resolve({
+          channelIds: [`timeline:hashtag:${params.tag.toLowerCase()}:local`],
+          options: { needsFiltering: true, notificationOnly: false, allowLocalOnly: true },
+        });
       }
 
-      streamFrom(`timeline:hashtag:${location.query.tag.toLowerCase()}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true);
       break;
     case 'list':
-      const listId = location.query.list;
-
-      authorizeListAccess(listId, req, authorized => {
-        if (!authorized) {
-          ws.close();
-          return;
-        }
-
-        channel = `timeline:list:${listId}`;
-        streamFrom(channel, req, streamToWs(req, ws), streamWsEnd(req, ws, subscriptionHeartbeat(channel)));
+      authorizeListAccess(params.list, req).then(() => {
+        resolve({
+          channelIds: [`timeline:list:${params.list}`],
+          options: { needsFiltering: false, notificationOnly: false, allowLocalOnly: true },
+        });
+      }).catch(() => {
+        reject('Not authorized to stream this list');
       });
+
       break;
     default:
-      ws.close();
+      reject('Unknown stream type');
+    }
+  });
+
+  /**
+   * @param {string} channelName
+   * @param {StreamParams} params
+   * @return {string[]}
+   */
+  const streamNameFromChannelName = (channelName, params) => {
+    if (channelName === 'list') {
+      return [channelName, params.list];
+    } else if (['hashtag', 'hashtag:local'].includes(channelName)) {
+      return [channelName, params.tag];
+    } else {
+      return [channelName];
+    }
+  };
+
+  /**
+   * @typedef WebSocketSession
+   * @property {any} socket
+   * @property {any} request
+   * @property {Object.<string, { listener: function(string): void, stopHeartbeat: function(): void }>} subscriptions
+   */
+
+  /**
+   * @param {WebSocketSession} session
+   * @param {string} channelName
+   * @param {StreamParams} params
+   */
+  const subscribeWebsocketToChannel = ({ socket, request, subscriptions }, channelName, params) =>
+    checkScopes(request, channelName).then(() => channelNameToIds(request, channelName, params)).then(({ channelIds, options }) => {
+      if (subscriptions[channelIds.join(';')]) {
+        return;
+      }
+
+      const onSend        = streamToWs(request, socket, streamNameFromChannelName(channelName, params));
+      const stopHeartbeat = subscriptionHeartbeat(channelIds);
+      const listener      = streamFrom(channelIds, request, onSend, undefined, options.needsFiltering, options.notificationOnly, options.allowLocalOnly);
+
+      subscriptions[channelIds.join(';')] = {
+        listener,
+        stopHeartbeat,
+      };
+    }).catch(err => {
+      log.verbose(request.requestId, 'Subscription error:', err.toString());
+      socket.send(JSON.stringify({ error: err.toString() }));
+    });
+
+  /**
+   * @param {WebSocketSession} session
+   * @param {string} channelName
+   * @param {StreamParams} params
+   */
+  const unsubscribeWebsocketFromChannel = ({ socket, request, subscriptions }, channelName, params) =>
+    channelNameToIds(request, channelName, params).then(({ channelIds }) => {
+      log.verbose(request.requestId, `Ending stream from ${channelIds.join(', ')} for ${request.accountId}`);
+
+      const { listener, stopHeartbeat } = subscriptions[channelIds.join(';')];
+
+      if (!listener) {
+        return;
+      }
+
+      channelIds.forEach(channelId => {
+        unsubscribe(`${redisPrefix}${channelId}`, listener);
+      });
+
+      stopHeartbeat();
+
+      subscriptions[channelIds.join(';')] = undefined;
+    }).catch(err => {
+      log.verbose(request.requestId, 'Unsubscription error:', err);
+      socket.send(JSON.stringify({ error: err.toString() }));
+    });
+
+  /**
+   * @param {string|string[]} arrayOrString
+   * @return {string}
+   */
+  const firstParam = arrayOrString => {
+    if (Array.isArray(arrayOrString)) {
+      return arrayOrString[0];
+    } else {
+      return arrayOrString;
+    }
+  };
+
+  wss.on('connection', (ws, req) => {
+    const location = url.parse(req.url, true);
+
+    req.requestId     = uuid.v4();
+    req.remoteAddress = ws._socket.remoteAddress;
+
+    /**
+     * @type {WebSocketSession}
+     */
+    const session = {
+      socket: ws,
+      request: req,
+      subscriptions: {},
+    };
+
+    const onEnd = () => {
+      const keys = Object.keys(session.subscriptions);
+
+      keys.forEach(channelIds => {
+        const { listener, stopHeartbeat } = session.subscriptions[channelIds];
+
+        channelIds.split(';').forEach(channelId => {
+          unsubscribe(`${redisPrefix}${channelId}`, listener);
+        });
+
+        stopHeartbeat();
+      });
+    };
+
+    ws.on('close', onEnd);
+    ws.on('error', onEnd);
+
+    ws.on('message', data => {
+      const { type, stream, ...params } = JSON.parse(data);
+
+      if (type === 'subscribe') {
+        subscribeWebsocketToChannel(session, firstParam(stream), params);
+      } else if (type === 'unsubscribe') {
+        unsubscribeWebsocketFromChannel(session, firstParam(stream), params)
+      } else {
+        // Unknown action type
+      }
+    });
+
+    if (location.query.stream) {
+      subscribeWebsocketToChannel(session, firstParam(location.query.stream), location.query);
     }
   });
 
@@ -731,6 +991,10 @@ const startWorker = (workerId) => {
   process.on('uncaughtException', onError);
 };
 
+/**
+ * @param {any} server
+ * @param {function(string): void} [onSuccess]
+ */
 const attachServerWithConfig = (server, onSuccess) => {
   if (process.env.SOCKET || process.env.PORT && isNaN(+process.env.PORT)) {
     server.listen(process.env.SOCKET || process.env.PORT, () => {
@@ -748,6 +1012,9 @@ const attachServerWithConfig = (server, onSuccess) => {
   }
 };
 
+/**
+ * @param {function(Error=): void} onSuccess
+ */
 const onPortAvailable = onSuccess => {
   const testServer = http.createServer();