about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--streaming/index.js87
1 files changed, 48 insertions, 39 deletions
diff --git a/streaming/index.js b/streaming/index.js
index fe39cf21d..0411ae8ef 100644
--- a/streaming/index.js
+++ b/streaming/index.js
@@ -95,7 +95,6 @@ const startWorker = (workerId) => {
   const app    = express();
   const pgPool = new pg.Pool(Object.assign(pgConfigs[env], dbUrlToConfig(process.env.DATABASE_URL)));
   const server = http.createServer(app);
-  const wss    = new WebSocket.Server({ server });
   const redisNamespace = process.env.REDIS_NAMESPACE || null;
 
   const redisParams = {
@@ -186,14 +185,10 @@ const startWorker = (workerId) => {
     });
   };
 
-  const authenticationMiddleware = (req, res, next) => {
-    if (req.method === 'OPTIONS') {
-      next();
-      return;
-    }
-
-    const authorization = req.get('Authorization');
-    const accessToken = req.query.access_token;
+  const accountFromRequest = (req, next) => {
+    const authorization = req.headers.authorization;
+    const location = url.parse(req.url, true);
+    const accessToken = location.query.access_token;
 
     if (!authorization && !accessToken) {
       const err = new Error('Missing access token');
@@ -208,6 +203,26 @@ const startWorker = (workerId) => {
     accountFromToken(token, req, next);
   };
 
+  const wsVerifyClient = (info, cb) => {
+    accountFromRequest(info.req, err => {
+      if (!err) {
+        cb(true, undefined, undefined);
+      } else {
+        log.error(info.req.requestId, err.toString());
+        cb(false, 401, 'Unauthorized');
+      }
+    });
+  };
+
+  const authenticationMiddleware = (req, res, next) => {
+    if (req.method === 'OPTIONS') {
+      next();
+      return;
+    }
+
+    accountFromRequest(req, next);
+  };
+
   const errorMiddleware = (err, req, res, next) => {
     log.error(req.requestId, err.toString());
     res.writeHead(err.statusCode || 500, { 'Content-Type': 'application/json' });
@@ -352,10 +367,12 @@ const startWorker = (workerId) => {
     streamFrom(`timeline:hashtag:${req.query.tag}:local`, req, streamToHttp(req, res), streamHttpEnd(req), true);
   });
 
+  const wss    = new WebSocket.Server({ server, verifyClient: wsVerifyClient });
+
   wss.on('connection', ws => {
-    const location = url.parse(ws.upgradeReq.url, true);
-    const token    = location.query.access_token;
-    const req      = { requestId: uuid.v4() };
+    const req      = ws.upgradeReq;
+    const location = url.parse(req.url, true);
+    req.requestId  = uuid.v4();
 
     ws.isAlive = true;
 
@@ -363,33 +380,25 @@ const startWorker = (workerId) => {
       ws.isAlive = true;
     });
 
-    accountFromToken(token, req, err => {
-      if (err) {
-        log.error(req.requestId, err);
-        ws.close();
-        return;
-      }
-
-      switch(location.query.stream) {
-      case 'user':
-        streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws));
-        break;
-      case 'public':
-        streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
-        break;
-      case 'public:local':
-        streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
-        break;
-      case 'hashtag':
-        streamFrom(`timeline:hashtag:${location.query.tag}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true);
-        break;
-      case 'hashtag:local':
-        streamFrom(`timeline:hashtag:${location.query.tag}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true);
-        break;
-      default:
-        ws.close();
-      }
-    });
+    switch(location.query.stream) {
+    case 'user':
+      streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws));
+      break;
+    case 'public':
+      streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
+      break;
+    case 'public:local':
+      streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true);
+      break;
+    case 'hashtag':
+      streamFrom(`timeline:hashtag:${location.query.tag}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true);
+      break;
+    case 'hashtag:local':
+      streamFrom(`timeline:hashtag:${location.query.tag}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true);
+      break;
+    default:
+      ws.close();
+    }
   });
 
   const wsInterval = setInterval(() => {