万字带你手撕Websocket协议,从零实现一个基于node的Websocket服务器

815 阅读3分钟

我正在参加「掘金·启航计划」

新年的时候闲的无聊,简单过了一遍《WebSocket权威指南》然后来兴趣花了2天时间写好的,算是库存货,最近正好在做开发工具链的搭建,正好要用到websocket,就没用第三方库,就把自己写的demo改吧改吧就直接上了,那么第二篇文章,我们就来聊一聊Websocket协议的实现(借助node的net模块实现)

1 浏览器发起ws请求

浏览器发起一个ws请求的时候

new WebSocket('ws://localhost:3000')
const net = require('net')
// 校验websocket-key要用到
const crypto = require('crypto')

const server = net.createServer((socket) => {
  socket.once('data', buffer => {
    // 接收到HTTP请求头数据
    const str = buffer.toString()
    console.log(str)
  })
})

server.listen(3000)

2 ws请求头分析

服务端接收到的响应报文如下

GET / HTTP/1.1
Host: localhost:3000
Connection: Upgrade
Pragma: no-cache
Cache-Control: no-cache
User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36   
Upgrade: websocket
Origin: http://127.0.0.1:5500
Sec-WebSocket-Version: 13
Accept-Encoding: gzip, deflate, br
Accept-Language: zh-CN,zh;q=0.9
Sec-WebSocket-Key: BNlBqioQ++EwOor3joITDg==
Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits


编写一个函数,用于把这个报文中的请求头转化为对象类型

function parseHeader(str) {
  // 将请求头数据按回车符切割为数组,得到每一行数据
  let arr = str.split('\r\n').filter(item => item)
  // 第一行数据为GET / HTTP/1.1,可以丢弃。
  arr.shift()
  let headers = {}  // 存储最终处理的数据
  arr.forEach((item) => {
    // 需要用":"将数组切割成key和value
    let [ name, value ] = item.split(':')
    // 去除无用的空格,将属性名转为小写
    name = name.replace(/^\s|\s+$/g, '').toLowerCase()
    value = value.replace(/^\s|\s+$/g, '')
    // 获取所有的请求头属性
    headers[name] = value
  })
  return headers
}

转化之后的值如下

image-20230125202815338.png

注意观察头中有两个头

// 告诉服务器,这个请求想要升级为websocket请求
Upgrade: websocket
// 告诉服务器,websocket的版本是13
Sec-WebSocket-Version: 13

那么服务端可以添加如下两个判断,如果请求头中Upgrade不是websocket,或者版本不是13那么就中断请求,代码如下

const server = net.createServer((socket) => {
  socket.once('data', buffer => {
    // 接收到HTTP请求头数据
    const str = buffer.toString()
    // 获取客户端的请求头
    const headers = parseHeader(str)
    // 校验请求头是否合法
    if (
      headers['upgrade'] !== 'websocket' ||
      headers['sec-websocket-version'] !== '13'
    ) {
      // 不合法直接断开
      socket.end()
    }
    // 合法
    else {

    }
    socket.end()
  })
})

3 校验Sec-WebSocket-Key

具体逻辑如下

// 设置一个GUID,这是个定值
const GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
// 创建一个签名算法为sha1的哈希对象
const hash = crypto.createHash('sha1')
// 将sec-websocket-key和GUID连接,并使用sha1加密
hash.update(`${headers['sec-websocket-key']}${GUID}`)

// 生成供前端校验用的请求头
const responseHeader = [
    // 响应头告诉客户端101,升级协议
    'HTTP/1.1 101 Switching Protocols',
    // 升级协议为 websocket
    'Upgrade: websocket',
    'Connection: Upgrade',
    // 把加密的值转换成base64返回给客户端
    `Sec-Websocket-Accept: ${hash.digest('base64')}`,
    // 最后结尾需要有两个空行,join拼接了一个,这里再接一个
    '\r\n'
].join('\r\n')

// 返回给客户端,如果客户端校验成功,那么就会触发 onopen 方法
socket.write(responseHeader)

4 服务端接收到的Websocket帧组成

websocket帧组成结构如下

websocket帧组成.jpg

4.1 FIN

image-20230126141513948.png

此占位用作多帧信息,有两个可选值

  • 0:表示这条消息还没结束,被拆分为多条
  • 1:标识此消息是最后一条或只有一条

4.2 RSV1、2、3

image-20230126141549483.png

除非一个扩展经过协商赋予了非零值以某种含义,否则必须为0 如果没有定义非零值,并且收到了非零的RSV,则websocket链接会失败

4.3 Opcode

image-20230126141636836.png

解释说明 Payload data 的用途/功能 如果收到了未知的Opcode,最后会断开链接,预设好的code如下

  • 1: 表示这是一个文本帧
  • 2: 表示这是一个二进制帧
  • 8: 表示连接断开
  • 9: 表示这是一个ping操作
  • 10: 表示这是一个pong操作

4.4 MASK

image-20230126141701611.png

此位占据了1bit,那么意味着有两种选项

  • 0: 服务端向客户端发送帧的MASK就是这个值
  • 1: 客户端向服务端发送的MASK就是这个值,如果服务端接收到的不是1,那么就需要断开连接

如果Mask是1,那么就会在Payload length之后加上4个字节存储4个Masking-keyPayload Data无法直接使用,Payload Data必须经过Masking-key还原成可用的二进制数据

4.5 Payload length

  • payload data 的长度如果在0~125 bytes范围内,它就是payload length

image-20230126141731113.png

  • 如果是126 bytes, 紧随其后的被表示为16 bits2 bytes无符号整型就是payload length

image-20230126141751047.png

  • 如果是127 bytes, 紧随其后的被表示为64 bits8 bytes无符号整型就是payload length

image-20230126163104347.png

一条websocket消息至多可以接受8 * 8bit,每位全是1长度的二进制数据,也就是

const data = [
  0b11111111,
  0b11111111,
  0b11111111,
  0b11111111,
  0b11111111,
  0b11111111,
  0b11111111,
  0b11111111
]

let payloadLen = 0
let start = 0

for (let i = 7; i >= 0; --i) {
  // 不能用下面的方式,因为js位运算只能处理32位 位运算
  // payloadLen += (data[start++] << (i * 8))
  payloadLen += data[start++] * Math.pow(2, i * 8)
}

console.log(payloadLen)    // 18446744073709552000字节

4.6 Masking-key

Masking-key是由4个8bit组成的,它存在的目的是为了解码payload data,解码流程的伪代码如下

// 首先先把4个Masking-key取出来
const maskingKey = [ maskingKey1, maskingKey2, maskingKey3, maskingKey4 ]

// 取出 payloadData(模拟)
const payloadData = Buffer.from([ 3, 4, 5, 6 ])

// 解码 payloadData
const usablePayloadData = payloadData.map((byte, index) => {
  // 字节 与 maskingKey[0 - 3] 进行异或运算
  return byte ^ maskingKey[index % 4]
  // 如果想提高一点点性能,可以使用位运算,如果超过32位可表示的无符号整数的话,位运算就不适用了,切记
  // return byte ^ maskingKey[index & 0b11]
})

5 Masking-key作用

为什么要引入掩码计算呢,除了增加计算机器的运算量外似乎并没有太多的收益?Masking-key的作用并不是为了加密数据,是为了防止早期版本的协议中存在的代理缓存污染攻击

在正式描述攻击步骤之前,我们假设有如下参与者:

  • 攻击者、攻击者自己控制的服务器(简称“邪恶服务器”)、攻击者伪造的资源(简称“邪恶资源”)
  • 受害者、受害者想要访问的资源(简称“正义资源”)
  • 受害者实际想要访问的服务器(简称“正义服务器”)
  • 中间代理服务器

攻击步骤一:

  1. 攻击者浏览器 向 邪恶服务器 发起WebSocket连接。根据前文,首先是一个协议升级请求。
  2. 协议升级请求 实际到达 代理服务器
  3. 代理服务器 将协议升级请求转发到 邪恶服务器
  4. 邪恶服务器 同意连接,代理服务器 将响应转发给 攻击者

由于 upgrade 的实现上有缺陷,代理服务器 以为之前转发的是普通的HTTP消息。因此,当协议服务器 同意连接,代理服务器 以为本次会话已经结束。

攻击步骤二:

  1. 攻击者 在之前建立的连接上,通过WebSocket的接口向 邪恶服务器 发送数据,且数据是精心构造的HTTP格式的文本。其中包含了 正义资源 的地址,以及一个伪造的host(指向正义服务器)。(见后面报文)
  2. 请求到达 代理服务器 。虽然复用了之前的TCP连接,但 代理服务器 以为是新的HTTP请求。
  3. 代理服务器邪恶服务器 请求 邪恶资源
  4. 邪恶服务器 返回 邪恶资源代理服务器 缓存住 邪恶资源(url是对的,但host是 正义服务器 的地址)。

到这里,受害者可以登场了:

  1. 受害者 通过 代理服务器 访问 正义服务器正义资源
  2. 代理服务器 检查该资源的url、host,发现本地有一份缓存(伪造的)。
  3. 代理服务器邪恶资源 返回给 受害者
  4. 受害者 卒。

附:前面提到的精心构造的“HTTP请求报文”。

Client  Server:
POST /path/of/attackers/choice HTTP/1.1 Host: host-of-attackers-choice.com Sec-WebSocket-Key: <connection-key>
Server  Client:
HTTP/1.1 200 OK
Sec-WebSocket-Accept: <connection-key>

最初的提案是对数据进行加密处理。基于安全、效率的考虑,最终采用了折中的方案:对数据载荷进行掩码处理。

需要注意的是,这里只是限制了浏览器对数据载荷进行掩码处理,但是坏人完全可以实现自己的WebSocket客户端、服务端,不按规则来,攻击可以照常进行。

但是对浏览器加上这个限制后,可以大大增加攻击的难度,以及攻击的影响范围。如果没有这个限制,只需要在网上放个钓鱼网站骗人去访问,一下子就可以在短时间内展开大范围的攻击

6 解析客户端->服务端的ws帧

客户端在接收到服务端的允许升级协议响应后,向服务端发送如下数据

const ws = new WebSocket('ws://localhost:3000')

ws.onopen = function () {
    console.log('连接成功')
    ws.send('你好')
}

那么服务端会收到下面一串Buffer数据

// 十六进制
<Buffer 81 86 0e dd f8 20 ea 60 58 c5 ab 60>

// 十进制
[ 129, 134, 14, 221, 248, 32, 234, 96, 88, 197, 171, 96 ]

// 二进制
[
  0b10000001,
  0b10000110,
  0b00001110,
  0b11011101,
  0b11111000,
  0b00100000,
  0b11101010,
  0b01100000,
  0b01011000,
  0b11000101,
  0b10101011,
  0b01100000
]

6.1 取FIN

image-20230126141513948.png

从图中可知,想要取FIN,那么就是把第一个二进制的0b10000001的第一位取出来,其实很简单

0b10000001 >> 7      // 1

6.2 取Opcode

image-20230126141636836.png

Opcode在后4位,因此想要取Opcode可以借助与运算的取值特性,值为1,说明当前是文本帧

0b10000001 & 0b1111   // 1

6.3 取MASK

image-20230126141701611.png

MASK的值在第二个字节的第一个比特位,取法和FIN一致,这里不说了

6.4 取Payload length

image-20230126141731113.png

取最后7位,“你好”使用utf8存储,占用6个字节正好,既然这个数在 [ 0, 126 ) 之间,那么后面的扩展位就不需要了

0b10000110 & 0b1111111     // 6

6.5 取Masking-Key

3-6字节就是Masking-Key

[
  0b10000001,
  0b10000110,
  // Masking-Key1
  0b00001110,
  // Masking-Key2
  0b11011101,
  // Masking-Key3
  0b11111000,
  // Masking-Key4
  0b00100000,
  0b11101010,
  0b01100000,
  0b01011000,
  0b11000101,
  0b10101011,
  0b01100000
]

6.6 取Payload Data

Masking-Key4之后就是Payload Data

[
  0b10000001,
  0b10000110,
  0b00001110,
  0b11011101,
  0b11111000,
  0b00100000,
  // -----------Payload Data-----------
  0b11101010,
  0b01100000,
  0b01011000,
  0b11000101,
  0b10101011,
  0b01100000
  // -----------Payload Data-----------
]

7 生成服务端->客户端的ws帧

服务端向客户端发送消息,封装的帧就不需要那么麻烦了,不需要加MASKMasking-key,别的和接收帧一致

返回帧.jpg

那么难点只有一个,就是计算payload length,这个在第九章代码实现中有很详细的注释

8 ping/pong

因为网络的不可靠性,有可能在 TCP 保持长连接的过程中, 由于某些突发情况, 例如网线被拔出, 突然掉电等, 会造成服务器和客户端的连接中断,在这些突发情况下, 如果恰好服务器和客户端之间没有交互的话, 那么它们是不能在短时间内发现对方已经掉线的。 websocket是,chrome是实现了ping/pong的,只要服务端发送了ping, 那么会立即收到一个pong

建议30s服务端发一次心跳监测

9 简易版本ws服务器代码

import net from 'net'
import crypto from 'crypto'

// 启动一个tcp服务器
net
  .createServer(socket => new WsSocket(socket))
  .listen(3000)

const enum OPCODES {
  /** 文本帧 */
  TEXT = 1,
  /** 二进制帧 */
  BINARY = 2,
  /** 关闭帧 */
  CLOSE = 8,
  /** PING帧 */
  PING = 9,
  /** PONG帧 */
  PONG = 10
}

type IClass = new (...args: any[]) => any
type IClassPrototype = IClass['prototype']

/** 解码ws frame结果 */
type DecodeFrame = {
  fin: 0 | 1,
  opcode: OPCODES,
  mask: 0 | 1,
  payloadLen: number,
  maskingKey: [number, number, number, number],
  payloadData: Buffer
}

class WsSocket {

  /**
   * 解析websocket的首次http请求头数据
   * @param str 
   * @returns 
   */
  static parseHeader(str: string) {
    // 得到每一行数据
    const arr = str.split('\r\n').filter(item => item)
    // 第一行数据为GET / HTTP/1.1,可以丢弃。
    arr.shift()
    return arr.reduce<Record<string, string>>((headers, item) => {
      // 需要用":"将数组切割成key和value
      let [ name, value ] = item.split(':')
      // 去除无用的空格,将属性名转为小写
      name = name.replace(/^\s|\s+$/g, '').toLowerCase()
      value = value.replace(/^\s|\s+$/g, '')
      // 获取所有的请求头属性
      headers[name] = value
      return headers
    }, {})
  }

  /**
   * 解码ws帧
   */
  static decodeWsFrame(prototype: IClassPrototype, propertyName: string, _: PropertyDescriptor): any {
    const oldFunc = prototype[propertyName]
    return {
      value(data: Buffer) {

        // 定义一个指针,从0开始
        let i = 0

        // 定义初始数据
        const wsFrame: DecodeFrame = {
          fin:         data[i] >> 7 as 0 | 1,
          opcode:      data[i] & 0b1111 as OPCODES,
          mask:        data[++i] >> 7 as 0 | 1,
          payloadLen:  data[i] & 0b1111111,
          maskingKey:  null!,
          payloadData: null!
        }

        // 处理payloadLen
        if (wsFrame.payloadLen === 126) {
          // 重写wsFrame.payloadLen为 后两位
          wsFrame.payloadLen = (data[++i] << 8) + data[++i]
        }
        else if (wsFrame.payloadLen === 127) {
          // 重写wsFrame.payloadLen为后8帧
          let payloadLen = 0
          for (let j = 7; j >= 0; --j) {
            // 不能用下面的方式,因为js位运算只能处理32位 位运算
            // payloadLen += (data[start++] << (i * 8))
            payloadLen += data[++i] * Math.pow(2, j * 8)
          }
          wsFrame.payloadLen = payloadLen
        }

        // 处理payloadData
        if (wsFrame.payloadLen) {
          wsFrame.payloadData = data.slice(i + 5, i + 5 + wsFrame.payloadLen)
          // 开启了屏蔽
          if (wsFrame.mask) {
            const maskingKey = wsFrame.maskingKey = [
              data[++i],
              data[++i],
              data[++i],
              data[++i]
            ]
            wsFrame.payloadData = (
              wsFrame
                .payloadData
                // .map((byte, index) => byte ^ maskingKey[index % 4]) as Buffer
                .map((byte, idx) => byte ^ maskingKey[idx & 0b11]) as Buffer
            )
          }
        }

        // 把处理好的数据给原函数
        oldFunc.call(this, wsFrame)
      }
    }
  }

  /**
   * 对发送数据进行封装
   * @param sendData 
   */
  static encodeWsFrame(sendData: string | Buffer) {
    // 校验数据
    if (typeof sendData === 'string') {
      sendData = Buffer.from(sendData)
    }
    if (!(sendData instanceof Buffer)) {
      throw new Error('发送的数据必须是string或Buffer类型')
    }

    // 生成要发送的数据
    const wsFrame = [
      // byte1  FIN:1 + 3个0 + 0001
      0b10000001
    ]

    // 获取Payload len
    let payloadLen = sendData.length
    // 要发送的内容小于126
    if (payloadLen < 126) {
      wsFrame.push(payloadLen)
    }
    // [126, 65535] 访问内,占用2个字节
    else if (payloadLen < 65536) {
      wsFrame.push(
        126,
        // 比如想把 0b11001 10000111 拆成 00011001 10000111 的话
        // 第一个字节,把二进制右移8位,得 00011001
        payloadLen >> 8,
        // 然后与运算取最后8位
        payloadLen & 0b11111111
      )
    }
    // 占用8个字节
    else {
      wsFrame.push(127)
      // 把长度长度转化成 64 位字符串
      // 比如 (999999).toString(2).padStart(64, '0')
      // 0000000000000000000000000000000000000000000011110100001000111111
      const binaryString = payloadLen.toString(2)
      for (let i = 0; i < 8; i++) {
        wsFrame.push(
          parseInt(binaryString.slice(i * 8, (i + 1) * 8), 2)
        )
      }
    }

    // 填充 帧
    return Buffer.concat([
      Buffer.from(wsFrame),
      sendData
    ])

  }

  // 存储心跳的定时器
  heartbeatTimer: NodeJS.Timer = null!
  // 监测心跳回应
  connectWaiting = false

  constructor(public socket: net.Socket) {
    this.connectionUpgrade()
    this.listen = this.listen.bind(this)
  }

  /** 协议升级 */
  connectionUpgrade() {
    this.socket.once('data', buffer => {
      // 获取接到到的http报文
      const str = buffer.toString()
      // 获取客户端的请求头
      const headers = WsSocket.parseHeader(str)
      // 校验请求头是否合法
      if (
        headers['upgrade'] !== 'websocket' ||
        headers['sec-websocket-version'] !== '13'
      ) {
        // 不合法直接断开
        this.socket.end()
      }
      // 合法
      else {
        // 设置一个GUID,这是个定值
        const GUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
        // 创建一个签名算法为sha1的哈希对象
        const hash = crypto.createHash('sha1')
        // 将sec-websocket-key和GUID连接,并使用sha1加密
        hash.update(`${headers['sec-websocket-key']}${GUID}`)

        // 生成供前端校验用的请求头
        const responseHeader = [
          // 响应头告诉客户端101,升级协议
          'HTTP/1.1 101 Switching Protocols',
          // 升级协议为 websocket
          'Upgrade: websocket',
          'Connection: Upgrade',
          // 把加密的值转换成base64返回给客户端
          `Sec-Websocket-Accept: ${hash.digest('base64')}`,
          // 最后结尾需要有两个空行,join拼接了一个,这里再接一个
          '\r\n'
        ].join('\r\n')

        // 返回给客户端,如果客户端校验成功,那么就会触发 onopen 方法
        this.socket.write(responseHeader)

        // 监听数据
        this.socket.on('data', this.listen)
        // this.socket.on('close', this.close)
        // 30秒发一次心跳
        this.heartbeatTimer = setInterval(this.heartbeat, 30 * 1000)

      }
    })
  }

  @WsSocket.decodeWsFrame
  listen(wsFrame: DecodeFrame) {
    switch (wsFrame.opcode) {
      case OPCODES.TEXT:
      // 暂时二进制和文本一样处理
      case OPCODES.BINARY:
        const decodeMsg = wsFrame.payloadData.toString('utf-8')
        console.log('接收到普通消息:', decodeMsg)
        this.send('服务端接收到消息了:' + decodeMsg)
        break
      case OPCODES.PING:
        console.log('接收到PING')
        this.pongHeartbeat()
        break
      case OPCODES.PONG:
        console.log('接收到PONG')
        // 把ping等待标识改为false
        this.connectWaiting = false
        break
      case OPCODES.CLOSE:
        this.close()
        break
      default:
        console.error('未处理的消息:')
        console.log(wsFrame)
    }
  }

  /** 心跳 */
  heartbeat = () => {
    // 如果再次调用心跳,上次心跳还未上一次心跳标识还是true,那么说明可能tcp连接断开
    if (this.connectWaiting){
      return this.destroySocket()
    }
    console.log('心跳监测')
    this.connectWaiting = true
    this.socket.write(
      Buffer.of(
        // byte1  FIN:1 + 3个0 + 9(心跳,二进制是1001)
        0b10001001,
        // 第二个帧虽然是0,但一定要写,不然chrome没办法返回PONG
        0b00000000
      )
    )
  }

  /** 回应心跳 */
  pongHeartbeat(){
    console.log('回应心跳')
    this.socket.write(
      Buffer.of(
        // byte1  FIN:1 + 3个0 + 10(心跳,二进制是1010)
        0b10001010,
        0b00000000
      )
    )
  }

  /** 发送数据 */
  send = (data: any) => {
    this.socket.write(
      WsSocket.encodeWsFrame(JSON.stringify(data))
    )
  }

  /** 连接关闭 */
  close = () => {
    this
      .socket
      .removeAllListeners('data')
    // .removeAllListeners('close')
    this.destroySocket()
    console.log('连接关闭')
  }

  /** 关闭连接 */
  destroySocket = () => {
    clearInterval(this.heartbeatTimer)
    this.socket.end()
    this.socket.destroy()
  }

}