python使用socket

48 阅读5分钟

socket_util.py 工具类


import json
import socket
import threading
import hashlib
import base64
import struct
import time
from urllib.parse import unquote

# 全局变量存储所有客户端连接
connected_clients = {}
clients_lock = threading.Lock()
heartbeat_interval = 30  # 心跳间隔(秒)


def create_websocket_server(host, port):
    """
    创建WebSocket服务器
    """
    server_socket = socket.socket()
    server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server_socket.bind((host, int(port)))
    server_socket.listen(5)
    print(f"WebSocket服务端已启动,地址{host},端口{port}")
    print(f"正在等待客户端连接...")

    client_id = 0
    while True:
        client_id += 1
        conn, address = server_socket.accept()
        print(f"服务端已接受到客户端 {client_id}号 的连接请求,客户端信息:{address}")

        # 存储客户端连接
        with clients_lock:
            connected_clients[client_id] = {
                'conn': conn,
                'address': address,
                'last_heartbeat': time.time()
            }

        client_handler = threading.Thread(target=handle_websocket_client, args=(conn, address, client_id))
        client_handler.daemon = True
        client_handler.start()


def handle_websocket_handshake(conn):
    """
    处理WebSocket握手
    """
    request = conn.recv(1024).decode('utf-8')
    headers = {}
    lines = request.splitlines()
    for line in lines[1:]:
        if ': ' in line:
            key, value = line.split(': ', 1)
            headers[key] = value

    if 'Sec-WebSocket-Key' in headers:
        key = headers['Sec-WebSocket-Key']
        magic_string = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
        accept_key = base64.b64encode(
            hashlib.sha1((key + magic_string).encode()).digest()
        ).decode('utf-8')

        response = (
            "HTTP/1.1 101 Switching Protocols\r\n"
            "Upgrade: websocket\r\n"
            "Connection: Upgrade\r\n"
            f"Sec-WebSocket-Accept: {accept_key}\r\n\r\n"
        )
        conn.send(response.encode('utf-8'))
        return True
    return False


def handle_websocket_client(conn, address, client_id):
    """
    处理WebSocket客户端连接
    """
    # 先进行握手
    if not handle_websocket_handshake(conn):
        print(f"客户端 {client_id}号 握手失败")
        conn.close()
        return

    # 启动心跳检测线程
    heartbeat_thread = threading.Thread(target=heartbeat_checker, args=(client_id,))
    heartbeat_thread.daemon = True
    heartbeat_thread.start()

    while True:
        try:
            # 接收WebSocket数据帧
            data = receive_websocket_frame(conn)
            if data is None:  # 客户端断开连接
                break

            message = data.decode('utf-8')
            print(f"客户端 {client_id}号:{address}发来的消息是:{message}")

            # 更新心跳时间
            with clients_lock:
                if client_id in connected_clients:
                    connected_clients[client_id]['last_heartbeat'] = time.time()

            # 处理心跳响应
            if message == "ping":
                send_websocket_frame(conn, "pong".encode('utf-8'), 0x1)  # 文本帧
                print(f"客户端 {client_id}号 心跳响应")

        except Exception as e:
            print(f"客户端 {client_id}号 连接异常: {e}")
            break

    # 客户端断开时清理资源
    with clients_lock:
        if client_id in connected_clients:
            del connected_clients[client_id]
    conn.close()
    print(f"客户端 {client_id}号 已断开连接")


def receive_websocket_frame(conn):
    """
    接收WebSocket数据帧
    """
    try:
        # 读取帧头
        header = conn.recv(2)
        if not header:
            return None

        # 解析帧头
        first_byte, second_byte = struct.unpack('!BB', header)

        # 检查是否为FIN帧
        fin = (first_byte & 0b10000000) >> 7
        opcode = first_byte & 0b00001111

        # 检查是否为掩码帧
        masked = (second_byte & 0b10000000) >> 7
        payload_length = second_byte & 0b01111111

        # 处理扩展长度
        if payload_length == 126:
            payload_length = struct.unpack('!H', conn.recv(2))[0]
        elif payload_length == 127:
            payload_length = struct.unpack('!Q', conn.recv(8))[0]

        # 读取掩码键
        if masked:
            mask_key = conn.recv(4)

        # 读取载荷数据
        payload = conn.recv(payload_length)

        # 如果有掩码,进行解码
        if masked:
            payload = bytearray(payload)
            for i in range(len(payload)):
                payload[i] ^= mask_key[i % 4]

        # 处理控制帧
        if opcode == 0x8:  # 连接关闭帧
            return None
        elif opcode == 0x9:  # ping帧
            # 发送pong响应
            send_websocket_frame(conn, payload, 0xA)
            return b''
        elif opcode == 0xA:  # pong帧
            return b''

        return payload
    except Exception as e:
        print(f"接收WebSocket帧失败: {e}")
        return None


def send_websocket_frame(conn, payload, opcode=0x1):
    """
    发送WebSocket数据帧
    """
    if isinstance(payload, str):
        payload = payload.encode('utf-8')

    # 计算载荷长度
    payload_length = len(payload)

    # 构造帧头
    frame = bytearray()

    # 第一个字节:FIN位 + RSV位 + 操作码
    frame.append(0b10000000 | opcode)

    # 第二个字节:掩码位 + 长度
    if payload_length <= 125:
        frame.append(payload_length)
    elif payload_length <= 65535:
        frame.append(126)
        frame.extend(struct.pack('!H', payload_length))
    else:
        frame.append(127)
        frame.extend(struct.pack('!Q', payload_length))

    # 添加载荷数据
    frame.extend(payload)

    try:
        conn.send(frame)
        return True
    except Exception as e:
        print(f"发送WebSocket帧失败: {e}")
        return False


def heartbeat_checker(client_id):
    """
    心跳检测器
    """
    while True:
        time.sleep(heartbeat_interval)
        with clients_lock:
            if client_id not in connected_clients:
                break
            client_info = connected_clients[client_id]
            last_heartbeat = client_info['last_heartbeat']

            # 检查是否超时
            if time.time() - last_heartbeat > heartbeat_interval * 3:
                print(f"客户端 {client_id}号 心跳超时,断开连接")
                try:
                    client_info['conn'].close()
                except:
                    pass
                if client_id in connected_clients:
                    del connected_clients[client_id]
                break

            # 发送心跳请求
            try:
                send_websocket_frame(client_info['conn'], "ping".encode('utf-8'), 0x1)
                print(f"向客户端 {client_id}号 发送心跳请求")
            except Exception as e:
                print(f"向客户端 {client_id}号 发送心跳失败: {e}")
                try:
                    client_info['conn'].close()
                except:
                    pass
                if client_id in connected_clients:
                    del connected_clients[client_id]
                break


def broadcast_to_all_clients(message):
    """
    向所有已连接的客户端发送消息

    Args:
        message (str): 要发送的消息内容
    """
    disconnected_clients = []

    with clients_lock:
        # 遍历所有客户端并发送消息
        for client_id, client_info in connected_clients.items():
            try:
                conn = client_info['conn']

                if isinstance(message, str):
                    data = message.encode("UTF-8")
                else:
                    # 将对象序列化为JSON字符串再编码
                    data = json.dumps(message, ensure_ascii=False).encode("UTF-8")

                if send_websocket_frame(conn, data, 0x1):  # 文本帧
                    print(f"消息已发送至客户端 {client_id}号")
                else:
                    raise Exception("发送失败")

            except Exception as e:
                print(f"向客户端 {client_id}号 发送消息失败: {e}")
                disconnected_clients.append(client_id)

        # 清理已断开的客户端
        for client_id in disconnected_clients:
            if client_id in connected_clients:
                del connected_clients[client_id]


def send_to_specific_client(client_id, message):
    """
    向特定客户端发送消息

    Args:
        client_id (int): 客户端编号
        message (str): 要发送的消息内容
    """
    with clients_lock:
        if client_id in connected_clients:
            try:
                conn = connected_clients[client_id]['conn']
                if isinstance(message, str):
                    data = message.encode("UTF-8")
                else:
                    data = json.dumps(message, ensure_ascii=False).encode("UTF-8")

                if send_websocket_frame(conn, data, 0x1):  # 文本帧
                    print(f"消息已发送至客户端 {client_id}号")
                    return True
                else:
                    raise Exception("发送失败")

            except Exception as e:
                print(f"向客户端 {client_id}号 发送消息失败: {e}")
                # 移除失效连接
                del connected_clients[client_id]
                return False
        else:
            print(f"客户端 {client_id}号 不存在或已断开连接")
            return False


def get_local_ip():
    """获取本机IP地址"""
    try:
        # 创建一个UDP socket
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # 连接到一个远程地址(不会实际发送数据)
        s.connect(("8.8.8.8", 80))
        # 获取本地IP地址
        local_ip = s.getsockname()[0]
        s.close()
        return local_ip
    except Exception:
        # 如果无法获取,则返回localhost
        return "127.0.0.1"


# 其他函数调用示例
def some_function_that_sends_data():
    """示例函数:在其他函数被调用时发送数据"""
    data = "这是从其他函数发送的数据"
    broadcast_to_all_clients(data)


def another_function_sending_to_client_1():
    """示例函数:向特定客户端发送数据"""
    data = "这是专门发送给客户端1的消息"
    send_to_specific_client(1, data)


if __name__ == '__main__':
    # 自动获取本机IP地址
    server_host = get_local_ip()
    server_port = int(input("请输入服务端port:"))
    print(f"使用本机IP地址: {server_host}")
    create_websocket_server(server_host, server_port)

其他调用:


if __name__ == "__main__":
    logging.log(logging.INFO, "http://127.0.0.1:8009/docs")
    server_host = get_local_ip()
    
    # 在单独的线程中启动 socket 服务器
    import threading
    socket_thread = threading.Thread(
        target=create_server_socket, 
        args=(server_host, settings.mySettings.socket_port),
        daemon=True  # 设置为守护线程,主程序退出时自动结束
    )
    socket_thread.start()
    
    # 现在可以正常启动 uvicorn
    uvicorn.run(app, host="0.0.0.0", port=int(settings.mySettings.server_port))