socket_util.py 工具类
import json
import socket
import threading
import hashlib
import base64
import struct
import time
from urllib.parse import unquote
# 全局变量存储所有客户端连接
connected_clients = {}
clients_lock = threading.Lock()
heartbeat_interval = 30 # 心跳间隔(秒)
def create_websocket_server(host, port):
"""
创建WebSocket服务器
"""
server_socket = socket.socket()
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind((host, int(port)))
server_socket.listen(5)
print(f"WebSocket服务端已启动,地址{host},端口{port}")
print(f"正在等待客户端连接...")
client_id = 0
while True:
client_id += 1
conn, address = server_socket.accept()
print(f"服务端已接受到客户端 {client_id}号 的连接请求,客户端信息:{address}")
# 存储客户端连接
with clients_lock:
connected_clients[client_id] = {
'conn': conn,
'address': address,
'last_heartbeat': time.time()
}
client_handler = threading.Thread(target=handle_websocket_client, args=(conn, address, client_id))
client_handler.daemon = True
client_handler.start()
def handle_websocket_handshake(conn):
"""
处理WebSocket握手
"""
request = conn.recv(1024).decode('utf-8')
headers = {}
lines = request.splitlines()
for line in lines[1:]:
if ': ' in line:
key, value = line.split(': ', 1)
headers[key] = value
if 'Sec-WebSocket-Key' in headers:
key = headers['Sec-WebSocket-Key']
magic_string = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
accept_key = base64.b64encode(
hashlib.sha1((key + magic_string).encode()).digest()
).decode('utf-8')
response = (
"HTTP/1.1 101 Switching Protocols\r\n"
"Upgrade: websocket\r\n"
"Connection: Upgrade\r\n"
f"Sec-WebSocket-Accept: {accept_key}\r\n\r\n"
)
conn.send(response.encode('utf-8'))
return True
return False
def handle_websocket_client(conn, address, client_id):
"""
处理WebSocket客户端连接
"""
# 先进行握手
if not handle_websocket_handshake(conn):
print(f"客户端 {client_id}号 握手失败")
conn.close()
return
# 启动心跳检测线程
heartbeat_thread = threading.Thread(target=heartbeat_checker, args=(client_id,))
heartbeat_thread.daemon = True
heartbeat_thread.start()
while True:
try:
# 接收WebSocket数据帧
data = receive_websocket_frame(conn)
if data is None: # 客户端断开连接
break
message = data.decode('utf-8')
print(f"客户端 {client_id}号:{address}发来的消息是:{message}")
# 更新心跳时间
with clients_lock:
if client_id in connected_clients:
connected_clients[client_id]['last_heartbeat'] = time.time()
# 处理心跳响应
if message == "ping":
send_websocket_frame(conn, "pong".encode('utf-8'), 0x1) # 文本帧
print(f"客户端 {client_id}号 心跳响应")
except Exception as e:
print(f"客户端 {client_id}号 连接异常: {e}")
break
# 客户端断开时清理资源
with clients_lock:
if client_id in connected_clients:
del connected_clients[client_id]
conn.close()
print(f"客户端 {client_id}号 已断开连接")
def receive_websocket_frame(conn):
"""
接收WebSocket数据帧
"""
try:
# 读取帧头
header = conn.recv(2)
if not header:
return None
# 解析帧头
first_byte, second_byte = struct.unpack('!BB', header)
# 检查是否为FIN帧
fin = (first_byte & 0b10000000) >> 7
opcode = first_byte & 0b00001111
# 检查是否为掩码帧
masked = (second_byte & 0b10000000) >> 7
payload_length = second_byte & 0b01111111
# 处理扩展长度
if payload_length == 126:
payload_length = struct.unpack('!H', conn.recv(2))[0]
elif payload_length == 127:
payload_length = struct.unpack('!Q', conn.recv(8))[0]
# 读取掩码键
if masked:
mask_key = conn.recv(4)
# 读取载荷数据
payload = conn.recv(payload_length)
# 如果有掩码,进行解码
if masked:
payload = bytearray(payload)
for i in range(len(payload)):
payload[i] ^= mask_key[i % 4]
# 处理控制帧
if opcode == 0x8: # 连接关闭帧
return None
elif opcode == 0x9: # ping帧
# 发送pong响应
send_websocket_frame(conn, payload, 0xA)
return b''
elif opcode == 0xA: # pong帧
return b''
return payload
except Exception as e:
print(f"接收WebSocket帧失败: {e}")
return None
def send_websocket_frame(conn, payload, opcode=0x1):
"""
发送WebSocket数据帧
"""
if isinstance(payload, str):
payload = payload.encode('utf-8')
# 计算载荷长度
payload_length = len(payload)
# 构造帧头
frame = bytearray()
# 第一个字节:FIN位 + RSV位 + 操作码
frame.append(0b10000000 | opcode)
# 第二个字节:掩码位 + 长度
if payload_length <= 125:
frame.append(payload_length)
elif payload_length <= 65535:
frame.append(126)
frame.extend(struct.pack('!H', payload_length))
else:
frame.append(127)
frame.extend(struct.pack('!Q', payload_length))
# 添加载荷数据
frame.extend(payload)
try:
conn.send(frame)
return True
except Exception as e:
print(f"发送WebSocket帧失败: {e}")
return False
def heartbeat_checker(client_id):
"""
心跳检测器
"""
while True:
time.sleep(heartbeat_interval)
with clients_lock:
if client_id not in connected_clients:
break
client_info = connected_clients[client_id]
last_heartbeat = client_info['last_heartbeat']
# 检查是否超时
if time.time() - last_heartbeat > heartbeat_interval * 3:
print(f"客户端 {client_id}号 心跳超时,断开连接")
try:
client_info['conn'].close()
except:
pass
if client_id in connected_clients:
del connected_clients[client_id]
break
# 发送心跳请求
try:
send_websocket_frame(client_info['conn'], "ping".encode('utf-8'), 0x1)
print(f"向客户端 {client_id}号 发送心跳请求")
except Exception as e:
print(f"向客户端 {client_id}号 发送心跳失败: {e}")
try:
client_info['conn'].close()
except:
pass
if client_id in connected_clients:
del connected_clients[client_id]
break
def broadcast_to_all_clients(message):
"""
向所有已连接的客户端发送消息
Args:
message (str): 要发送的消息内容
"""
disconnected_clients = []
with clients_lock:
# 遍历所有客户端并发送消息
for client_id, client_info in connected_clients.items():
try:
conn = client_info['conn']
if isinstance(message, str):
data = message.encode("UTF-8")
else:
# 将对象序列化为JSON字符串再编码
data = json.dumps(message, ensure_ascii=False).encode("UTF-8")
if send_websocket_frame(conn, data, 0x1): # 文本帧
print(f"消息已发送至客户端 {client_id}号")
else:
raise Exception("发送失败")
except Exception as e:
print(f"向客户端 {client_id}号 发送消息失败: {e}")
disconnected_clients.append(client_id)
# 清理已断开的客户端
for client_id in disconnected_clients:
if client_id in connected_clients:
del connected_clients[client_id]
def send_to_specific_client(client_id, message):
"""
向特定客户端发送消息
Args:
client_id (int): 客户端编号
message (str): 要发送的消息内容
"""
with clients_lock:
if client_id in connected_clients:
try:
conn = connected_clients[client_id]['conn']
if isinstance(message, str):
data = message.encode("UTF-8")
else:
data = json.dumps(message, ensure_ascii=False).encode("UTF-8")
if send_websocket_frame(conn, data, 0x1): # 文本帧
print(f"消息已发送至客户端 {client_id}号")
return True
else:
raise Exception("发送失败")
except Exception as e:
print(f"向客户端 {client_id}号 发送消息失败: {e}")
# 移除失效连接
del connected_clients[client_id]
return False
else:
print(f"客户端 {client_id}号 不存在或已断开连接")
return False
def get_local_ip():
"""获取本机IP地址"""
try:
# 创建一个UDP socket
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# 连接到一个远程地址(不会实际发送数据)
s.connect(("8.8.8.8", 80))
# 获取本地IP地址
local_ip = s.getsockname()[0]
s.close()
return local_ip
except Exception:
# 如果无法获取,则返回localhost
return "127.0.0.1"
# 其他函数调用示例
def some_function_that_sends_data():
"""示例函数:在其他函数被调用时发送数据"""
data = "这是从其他函数发送的数据"
broadcast_to_all_clients(data)
def another_function_sending_to_client_1():
"""示例函数:向特定客户端发送数据"""
data = "这是专门发送给客户端1的消息"
send_to_specific_client(1, data)
if __name__ == '__main__':
# 自动获取本机IP地址
server_host = get_local_ip()
server_port = int(input("请输入服务端port:"))
print(f"使用本机IP地址: {server_host}")
create_websocket_server(server_host, server_port)
其他调用:
if __name__ == "__main__":
logging.log(logging.INFO, "http://127.0.0.1:8009/docs")
server_host = get_local_ip()
# 在单独的线程中启动 socket 服务器
import threading
socket_thread = threading.Thread(
target=create_server_socket,
args=(server_host, settings.mySettings.socket_port),
daemon=True # 设置为守护线程,主程序退出时自动结束
)
socket_thread.start()
# 现在可以正常启动 uvicorn
uvicorn.run(app, host="0.0.0.0", port=int(settings.mySettings.server_port))