about summary refs log tree commit diff
path: root/streaming
diff options
context:
space:
mode:
Diffstat (limited to 'streaming')
-rw-r--r--streaming/index.js151
1 files changed, 111 insertions, 40 deletions
diff --git a/streaming/index.js b/streaming/index.js
index e5a2778f8..16dda5c1e 100644
--- a/streaming/index.js
+++ b/streaming/index.js
@@ -1,8 +1,11 @@
 import dotenv from 'dotenv'
 import express from 'express'
+import http from 'http'
 import redis from 'redis'
 import pg from 'pg'
 import log from 'npmlog'
+import url from 'url'
+import WebSocket from 'ws'
 
 const env = process.env.NODE_ENV || 'development'
 
@@ -27,8 +30,10 @@ const pgConfigs = {
   }
 }
 
-const app = express()
+const app    = express()
 const pgPool = new pg.Pool(pgConfigs[env])
+const server = http.createServer(app)
+const wss    = new WebSocket.Server({ server })
 
 const allowCrossDomain = (req, res, next) => {
   res.header('Access-Control-Allow-Origin', '*')
@@ -38,22 +43,7 @@ const allowCrossDomain = (req, res, next) => {
   next()
 }
 
-const authenticationMiddleware = (req, res, next) => {
-  if (req.method === 'OPTIONS') {
-    return next()
-  }
-
-  const authorization = req.get('Authorization')
-
-  if (!authorization) {
-    const err = new Error('Missing access token')
-    err.statusCode = 401
-
-    return next(err)
-  }
-
-  const token = authorization.replace(/^Bearer /, '')
-
+const accountFromToken = (token, req, next) => {
   pgPool.connect((err, client, done) => {
     if (err) {
       return next(err)
@@ -80,26 +70,36 @@ const authenticationMiddleware = (req, res, next) => {
   })
 }
 
+const authenticationMiddleware = (req, res, next) => {
+  if (req.method === 'OPTIONS') {
+    return next()
+  }
+
+  const authorization = req.get('Authorization')
+
+  if (!authorization) {
+    const err = new Error('Missing access token')
+    err.statusCode = 401
+
+    return next(err)
+  }
+
+  const token = authorization.replace(/^Bearer /, '')
+
+  accountFromToken(token, req, next)
+}
+
 const errorMiddleware = (err, req, res, next) => {
   log.error(err)
   res.writeHead(err.statusCode || 500, { 'Content-Type': 'application/json' })
-  res.end(JSON.stringify({ error: err.statusCode ? `${err}` : 'An unexpected error occured' }))
+  res.end(JSON.stringify({ error: err.statusCode ? `${err}` : 'An unexpected error occurred' }))
 }
 
 const placeholders = (arr, shift = 0) => arr.map((_, i) => `$${i + 1 + shift}`).join(', ');
 
-const streamFrom = (id, req, res, needsFiltering = false) => {
+const streamFrom = (redisClient, id, req, output, needsFiltering = false) => {
   log.verbose(`Starting stream from ${id} for ${req.accountId}`)
 
-  res.setHeader('Content-Type', 'text/event-stream')
-  res.setHeader('Transfer-Encoding', 'chunked')
-
-  const redisClient = redis.createClient({
-    host:     process.env.REDIS_HOST     || '127.0.0.1',
-    port:     process.env.REDIS_PORT     || 6379,
-    password: process.env.REDIS_PASSWORD
-  })
-
   redisClient.on('message', (channel, message) => {
     const { event, payload } = JSON.parse(message)
 
@@ -127,36 +127,107 @@ const streamFrom = (id, req, res, needsFiltering = false) => {
             return
           }
 
-          res.write(`event: ${event}\n`)
-          res.write(`data: ${payload}\n\n`)
+          log.silly(`Transmitting for ${req.accountId}: ${event} ${payload}`)
+          output(event, payload)
         })
       })
     } else {
-      res.write(`event: ${event}\n`)
-      res.write(`data: ${payload}\n\n`)
+      log.silly(`Transmitting for ${req.accountId}: ${event} ${payload}`)
+      output(event, payload)
     }
   })
 
+  redisClient.subscribe(id)
+}
+
+// Setup stream output to HTTP
+const streamToHttp = (req, res, redisClient) => {
+  res.setHeader('Content-Type', 'text/event-stream')
+  res.setHeader('Transfer-Encoding', 'chunked')
+
   const heartbeat = setInterval(() => res.write(':thump\n'), 15000)
 
   req.on('close', () => {
-    log.verbose(`Ending stream from ${id} for ${req.accountId}`)
+    log.verbose(`Ending stream for ${req.accountId}`)
     clearInterval(heartbeat)
     redisClient.quit()
   })
 
-  redisClient.subscribe(id)
+  return (event, payload) => {
+    res.write(`event: ${event}\n`)
+    res.write(`data: ${payload}\n\n`)
+  }
+}
+
+// Setup stream output to WebSockets
+const streamToWs = (req, ws, redisClient) => {
+  ws.on('close', () => {
+    log.verbose(`Ending stream for ${req.accountId}`)
+    redisClient.quit()
+  })
+
+  return (event, payload) => {
+    ws.send(JSON.stringify({ event, payload }))
+  }
 }
 
+// Get new redis connection
+const getRedisClient = () => redis.createClient({
+  host:     process.env.REDIS_HOST     || '127.0.0.1',
+  port:     process.env.REDIS_PORT     || 6379,
+  password: process.env.REDIS_PASSWORD
+})
+
 app.use(allowCrossDomain)
 app.use(authenticationMiddleware)
 app.use(errorMiddleware)
 
-app.get('/api/v1/streaming/user',    (req, res) => streamFrom(`timeline:${req.accountId}`, req, res))
-app.get('/api/v1/streaming/public',  (req, res) => streamFrom('timeline:public', req, res, true))
-app.get('/api/v1/streaming/hashtag', (req, res) => streamFrom(`timeline:hashtag:${req.params.tag}`, req, res, true))
+app.get('/api/v1/streaming/user', (req, res) => {
+  const redisClient = getRedisClient()
+  streamFrom(redisClient, `timeline:${req.accountId}`, req, streamToHttp(req, res, redisClient))
+})
+
+app.get('/api/v1/streaming/public', (req, res) => {
+  const redisClient = getRedisClient()
+  streamFrom(redisClient, 'timeline:public', req, streamToHttp(req, res, redisClient), true)
+})
+
+app.get('/api/v1/streaming/hashtag', (req, res) => {
+  const redisClient = getRedisClient()
+  streamFrom(redisClient, `timeline:hashtag:${req.params.tag}`, req, streamToHttp(req, res, redisClient), true)
+})
 
-log.level = 'verbose'
-log.info(`Starting HTTP server on port ${process.env.PORT || 4000}`)
+wss.on('connection', ws => {
+  const location = url.parse(ws.upgradeReq.url, true)
+  const token    = location.query.access_token
+  const req      = {}
 
-app.listen(process.env.PORT || 4000)
+  accountFromToken(token, req, err => {
+    if (err) {
+      log.error(err)
+      ws.close()
+      return
+    }
+
+    const redisClient = getRedisClient()
+
+    switch(location.query.stream) {
+    case 'user':
+      streamFrom(redisClient, `timeline:${req.accountId}`, req, streamToWs(req, ws, redisClient))
+      break;
+    case 'public':
+      streamFrom(redisClient, 'timeline:public', req, streamToWs(req, ws, redisClient), true)
+      break;
+    case 'hashtag':
+      streamFrom(redisClient, `timeline:hashtag:${location.query.tag}`, req, streamToWs(req, ws, redisClient), true)
+      break;
+    default:
+      ws.close()
+    }
+  })
+})
+
+server.listen(process.env.PORT || 4000, () => {
+  log.level = process.env.LOG_LEVEL || 'verbose'
+  log.info(`Starting streaming API server on port ${server.address().port}`)
+})