在线聊天系统的面向对象设计

115 阅读 · 6 分钟

在线聊天系统面向对象设计完整版

以下是在线聊天系统的面向对象设计，包含各核心组件和服务：

1. 用户与关系管理模块

class User:
    """A chat user: identity, credentials and per-user social/chat state.

    Every action method is a thin delegation: it forwards ``self.user_id``
    to the service object passed in and returns whatever that service returns.
    """

    def __init__(self, user_id, username, password):
        # Identity and credentials.
        self.user_id = user_id
        self.username = username
        # NOTE: kept exactly as given (plain text); a real system should store
        # a salted hash instead.
        self.password = password
        # Per-user state, each map keyed by the related object's id.
        self.friends = {}                  # friend_id -> Friendship
        self.pending_friend_requests = {}  # request_id -> FriendRequest
        self.chat_sessions = {}            # session_id -> ChatSession

    # -- friend requests -------------------------------------------------

    def send_friend_request(self, recipient_id, friend_request_service):
        """Ask the service to create a request from this user to ``recipient_id``."""
        sender = self.user_id
        return friend_request_service.create_request(sender, recipient_id)

    def accept_friend_request(self, request_id, friend_request_service):
        """Accept ``request_id`` with this user as the receiver."""
        receiver = self.user_id
        return friend_request_service.accept_request(request_id, receiver)

    def reject_friend_request(self, request_id, friend_request_service):
        """Reject ``request_id`` with this user as the receiver."""
        receiver = self.user_id
        return friend_request_service.reject_request(request_id, receiver)

    # -- friendships and chats --------------------------------------------

    def remove_friend(self, friend_id, friendship_service):
        """Delegate removal of the friendship with ``friend_id``.

        NOTE(review): no service with ``remove_friendship`` is defined in this
        listing; callers must supply one.
        """
        me = self.user_id
        return friendship_service.remove_friendship(me, friend_id)

    def create_group_chat(self, name, friend_ids, chat_service):
        """Create a group chat called ``name`` with this user as creator."""
        creator = self.user_id
        return chat_service.create_group_chat(creator, name, friend_ids)

    def send_message(self, session_id, content, message_service):
        """Send ``content`` into chat session ``session_id`` as this user."""
        sender = self.user_id
        return message_service.send_message(sender, session_id, content)


class Friendship:
    """A two-way friendship between two users.

    The pair is recorded once; use ``get_friend_id`` to get "the other side"
    from either member's point of view.
    """

    def __init__(self, user_id1, user_id2, timestamp):
        self.user_id1 = user_id1
        self.user_id2 = user_id2
        self.timestamp = timestamp  # when the friendship was created
        self.status = "active"  # "active" or "blocked"

    def get_friend_id(self, user_id):
        """Return the id of the other member of this friendship.

        Raises:
            ValueError: if ``user_id`` is not one of the two members. Silently
                returning ``user_id1`` in that case would yield a wrong id.
        """
        if user_id == self.user_id1:
            return self.user_id2
        if user_id == self.user_id2:
            return self.user_id1
        raise ValueError("User not part of friendship")


class FriendRequest:
    """A friend request and its lifecycle.

    A request starts ``PENDING`` and can move exactly once to ``ACCEPTED``,
    ``REJECTED`` or ``EXPIRED``; those three states are final.
    """

    PENDING = "pending"
    ACCEPTED = "accepted"
    REJECTED = "rejected"
    EXPIRED = "expired"

    def __init__(self, request_id, sender_id, receiver_id, timestamp):
        self.request_id = request_id
        self.sender_id = sender_id
        self.receiver_id = receiver_id
        self.status = self.PENDING
        self.timestamp = timestamp  # when the request was created

    def _transition(self, new_status):
        """Move from PENDING to ``new_status``.

        Raises:
            ValueError: if the request has already left the PENDING state,
                so e.g. an accepted request cannot later be marked expired.
        """
        if self.status != self.PENDING:
            raise ValueError("Request not pending")
        self.status = new_status

    def accept(self):
        """Mark the pending request as accepted."""
        self._transition(self.ACCEPTED)

    def reject(self):
        """Mark the pending request as rejected."""
        self._transition(self.REJECTED)

    def expire(self):
        """Mark the pending request as expired."""
        self._transition(self.EXPIRED)

2. 聊天会话模块

import uuid
from abc import ABC, abstractmethod
from datetime import datetime

class ChatSession(ABC):
    """Abstract base for chat sessions (one-to-one or group).

    Holds the participant list and the in-memory message history. Subclasses
    decide who may join or leave via ``add_participant``/``remove_participant``.
    """

    def __init__(self, session_id, creator_id):
        self.session_id = session_id
        self.creator_id = creator_id
        self.participants = [creator_id]  # the creator is always the first member
        self.messages = []  # appended in send order, oldest first
        self.created_at = datetime.now()

    @abstractmethod
    def add_participant(self, user_id):
        """Add ``user_id`` to the session (policy defined by subclasses)."""

    @abstractmethod
    def remove_participant(self, user_id):
        """Remove ``user_id`` from the session (policy defined by subclasses)."""

    def send_message(self, sender_id, content):
        """Create a message from ``sender_id``, append it to the history and return it.

        Raises:
            PermissionError: if ``sender_id`` is not a participant.
        """
        if sender_id not in self.participants:
            raise PermissionError("User not in chat session")
        message = Message(
            message_id=uuid.uuid4().hex,
            sender_id=sender_id,
            content=content,
            timestamp=datetime.now(),
        )
        self.messages.append(message)
        return message

    def get_messages(self, user_id, limit=50, offset=0):
        """Return up to ``limit`` messages starting at index ``offset`` (oldest first).

        Raises:
            PermissionError: if ``user_id`` is not a participant.
            ValueError: if ``limit`` or ``offset`` is negative; a negative
                offset would otherwise silently slice from the end of history.
        """
        if user_id not in self.participants:
            raise PermissionError("User not in chat session")
        if limit < 0 or offset < 0:
            raise ValueError("limit and offset must be non-negative")
        return self.messages[offset:offset + limit]


class PrivateChat(ChatSession):
    """One-to-one chat between (at most) two users."""

    def __init__(self, session_id, user_id1, user_id2):
        super().__init__(session_id, user_id1)  # user_id1 becomes the creator
        self.add_participant(user_id2)

    def add_participant(self, user_id):
        """Add ``user_id``; re-adding an existing member is a no-op.

        Raises:
            ValueError: if a new user would make more than two participants.
        """
        if user_id in self.participants:
            return
        if len(self.participants) >= 2:
            raise ValueError("Private chat can only have two participants")
        self.participants.append(user_id)

    def remove_participant(self, user_id):
        """Remove a non-creator participant.

        The creator is never removed, and removing a non-member is a no-op.
        The membership guard avoids an unhandled ValueError from ``list.remove``.
        """
        if user_id != self.creator_id and user_id in self.participants:
            self.participants.remove(user_id)


class GroupChat(ChatSession):
    """Multi-user chat with a display name and a list of administrators."""

    def __init__(self, session_id, creator_id, name):
        super().__init__(session_id, creator_id)
        self.name = name
        # The creator starts out as the only administrator.
        self.administrators = [creator_id]

    def add_participant(self, user_id):
        """Add ``user_id`` if not already a member (idempotent)."""
        if user_id in self.participants:
            return
        self.participants.append(user_id)

    def remove_participant(self, user_id, remover_id=None):
        """Remove ``user_id`` from the group, dropping any admin role too.

        A user may always remove themself. Removing someone else requires
        ``remover_id`` to be an administrator; a falsy ``remover_id`` skips
        the check entirely. Removing a non-member is a no-op.

        Raises:
            PermissionError: if a non-admin tries to remove another user.
        """
        removing_someone_else = bool(remover_id) and remover_id != user_id
        if removing_someone_else and remover_id not in self.administrators:
            raise PermissionError("Only admins can remove others")
        if user_id not in self.participants:
            return
        self.participants.remove(user_id)
        if user_id in self.administrators:
            self.administrators.remove(user_id)

    def add_administrator(self, user_id, adder_id):
        """Promote member ``user_id`` to admin; only an admin may do this.

        Non-members and users who are already admins are silently ignored.

        Raises:
            PermissionError: if ``adder_id`` is not an administrator.
        """
        if adder_id not in self.administrators:
            raise PermissionError("Only admins can add new admins")
        is_member = user_id in self.participants
        if is_member and user_id not in self.administrators:
            self.administrators.append(user_id)

3. 消息模块

class Message:
    """A chat message plus its delivery status.

    ``status`` only moves forward: "sent" -> "delivered" -> "read". A late
    delivery acknowledgement can therefore never downgrade a read message.
    """

    # Rank of each status; a transition is applied only if it increases rank.
    _STATUS_RANK = {"sent": 0, "delivered": 1, "read": 2}

    def __init__(self, message_id, sender_id, content, timestamp):
        self.message_id = message_id
        self.sender_id = sender_id
        self.content = content
        self.timestamp = timestamp
        self.status = "sent"  # one of: sent, delivered, read

    def _advance_to(self, new_status):
        """Set ``status`` to ``new_status`` unless it is already at or past it."""
        current_rank = self._STATUS_RANK.get(self.status, -1)
        if self._STATUS_RANK[new_status] > current_rank:
            self.status = new_status

    def mark_as_delivered(self):
        """Record delivery (no effect if the message was already read)."""
        self._advance_to("delivered")

    def mark_as_read(self):
        """Record that the message was read."""
        self._advance_to("read")


class MessageFactory:
    """Factory for message objects; currently only plain text messages."""

    @staticmethod
    def create_text_message(sender_id, content):
        """Build a new text ``Message`` with a random hex id, stamped with now()."""
        fields = {
            "message_id": uuid.uuid4().hex,
            "sender_id": sender_id,
            "content": content,
            "timestamp": datetime.now(),
        }
        return Message(**fields)

4. 服务层

class UserService:
    """用户服务,处理用户相关业务逻辑"""
    def __init__(self, user_repository):
        self.user_repository = user_repository

    def create_user(self, username, password):
        user_id = uuid.uuid4().hex
        user = User(user_id, username, password)
        self.user_repository.save(user)
        return user

    def get_user(self, user_id):
        return self.user_repository.get(user_id)

    def update_user(self, user_id, **kwargs):
        user = self.get_user(user_id)
        if user:
            for attr, value in kwargs.items():
                if hasattr(user, attr):
                    setattr(user, attr, value)
            self.user_repository.save(user)
            return True
        return False


class FriendRequestService:
    """Application service for the friend-request workflow.

    Creates requests and lets the receiver accept or reject them. On
    acceptance, one ``Friendship`` object is stored in both users' ``friends``.
    """

    def __init__(self, request_repository, user_repository, friendship_repository):
        self.request_repository = request_repository
        self.user_repository = user_repository
        self.friendship_repository = friendship_repository

    def create_request(self, sender_id, receiver_id):
        """Create, persist and return a pending request from sender to receiver.

        Raises:
            ValueError: on a self-request, an unknown user, users who are
                already friends, or a duplicate pending request.
        """
        if sender_id == receiver_id:
            raise ValueError("Cannot send request to self")

        sender = self.user_repository.get(sender_id)
        receiver = self.user_repository.get(receiver_id)
        if not sender or not receiver:
            raise ValueError("User not found")

        # Friendship is read from the sender's own ``friends`` map. No
        # repository in this design implements a ``check_friendship`` query.
        if receiver_id in sender.friends:
            raise ValueError("Already friends")

        if self._has_pending_request(sender_id, receiver):
            raise ValueError("Request already pending")

        request_id = uuid.uuid4().hex
        request = FriendRequest(request_id, sender_id, receiver_id, datetime.now())
        self.request_repository.save(request)

        receiver.pending_friend_requests[request_id] = request
        self.user_repository.save(receiver)
        return request

    def accept_request(self, request_id, receiver_id):
        """Accept a pending request addressed to ``receiver_id`` and befriend both users.

        Returns:
            True on success.
        Raises:
            ValueError: if the request is unknown, addressed to someone else,
                no longer pending, or either user no longer exists.
        """
        request = self._get_pending_request(request_id, receiver_id)

        # Load both users before changing any state, so a missing user cannot
        # leave the request accepted without a friendship.
        sender = self.user_repository.get(request.sender_id)
        receiver = self.user_repository.get(request.receiver_id)
        if not sender or not receiver:
            raise ValueError("User not found")

        request.accept()
        self.request_repository.save(request)

        friendship = Friendship(request.sender_id, request.receiver_id, datetime.now())
        self.friendship_repository.save(friendship)

        sender.friends[request.receiver_id] = friendship
        receiver.friends[request.sender_id] = friendship
        # Drop the handled request *before* saving so the removal is persisted.
        receiver.pending_friend_requests.pop(request_id, None)

        self.user_repository.save(sender)
        self.user_repository.save(receiver)
        return True

    def reject_request(self, request_id, receiver_id):
        """Reject a pending request addressed to ``receiver_id``.

        This is the method ``User.reject_friend_request`` delegates to.

        Returns:
            True on success.
        Raises:
            ValueError: if the request is unknown, addressed to someone else,
                or no longer pending.
        """
        request = self._get_pending_request(request_id, receiver_id)
        request.reject()
        self.request_repository.save(request)

        receiver = self.user_repository.get(receiver_id)
        if receiver:
            receiver.pending_friend_requests.pop(request_id, None)
            self.user_repository.save(receiver)
        return True

    def _get_pending_request(self, request_id, receiver_id):
        """Fetch ``request_id`` and check it is pending and addressed to ``receiver_id``."""
        request = self.request_repository.get(request_id)
        if not request or request.receiver_id != receiver_id:
            raise ValueError("Invalid request")
        if request.status != FriendRequest.PENDING:
            raise ValueError("Request not pending")
        return request

    @staticmethod
    def _has_pending_request(sender_id, receiver):
        """Return True if ``receiver`` already holds a pending request from ``sender_id``."""
        return any(
            req.sender_id == sender_id and req.status == FriendRequest.PENDING
            for req in receiver.pending_friend_requests.values()
        )


class ChatService:
    """Application service for creating private and group chat sessions."""

    def __init__(self, chat_repository, user_repository):
        self.chat_repository = chat_repository
        self.user_repository = user_repository

    def create_private_chat(self, user_id1, user_id2):
        """Create and register a one-to-one chat between two friends.

        Raises:
            ValueError: if either user is unknown or they are not friends.
        """
        user1 = self.user_repository.get(user_id1)
        user2 = self.user_repository.get(user_id2)
        if not user1 or not user2:
            raise ValueError("User not found")

        if not self._are_friends(user_id1, user_id2):
            raise ValueError("Users must be friends")

        session_id = uuid.uuid4().hex
        chat = PrivateChat(session_id, user_id1, user_id2)
        self.chat_repository.save(chat)

        user1.chat_sessions[session_id] = chat
        user2.chat_sessions[session_id] = chat
        self.user_repository.save(user1)
        self.user_repository.save(user2)
        return chat

    def create_group_chat(self, creator_id, name, friend_ids):
        """Create a group chat with the creator and those of ``friend_ids`` who are friends.

        Unknown ids and non-friends are silently skipped; duplicate ids are
        collapsed.

        Raises:
            ValueError: if the creator is unknown or no valid friend remains.
        """
        creator = self.user_repository.get(creator_id)
        if not creator:
            raise ValueError("Creator not found")

        # dict.fromkeys drops duplicate ids while keeping the caller's order.
        # Each valid friend is fetched once and reused below.
        valid_friends = {}
        for friend_id in dict.fromkeys(friend_ids):
            friend = self.user_repository.get(friend_id)
            if friend and self._are_friends(creator_id, friend_id):
                valid_friends[friend_id] = friend

        if not valid_friends:
            raise ValueError("No valid friends provided")

        session_id = uuid.uuid4().hex
        chat = GroupChat(session_id, creator_id, name)
        for friend_id in valid_friends:
            chat.add_participant(friend_id)
        self.chat_repository.save(chat)

        creator.chat_sessions[session_id] = chat
        self.user_repository.save(creator)
        for friend in valid_friends.values():
            friend.chat_sessions[session_id] = chat
            self.user_repository.save(friend)
        return chat

    def _are_friends(self, user_id1, user_id2):
        """Return True if ``user_id2`` is in ``user_id1``'s ``friends`` map.

        Returns False, rather than raising AttributeError, when ``user_id1``
        is unknown.
        """
        user1 = self.user_repository.get(user_id1)
        return bool(user1) and user_id2 in user1.friends


class MessageService:
    """Application service for sending, listing and marking chat messages."""

    def __init__(self, message_repository, chat_repository, user_repository, notification_service):
        self.message_repository = message_repository
        self.chat_repository = chat_repository
        self.user_repository = user_repository
        self.notification_service = notification_service

    def send_message(self, sender_id, session_id, content):
        """Send ``content`` into a chat session and notify the other participants.

        Exactly one message object is created (by the chat session). That same
        object is appended to the session history, saved in the message
        repository, passed to notifications and returned. Its ``message_id``
        can therefore be used later with ``mark_message_as_read``.

        Raises:
            ValueError: if the sender or the session is unknown.
            PermissionError: if the sender is not a participant.
        """
        sender = self.user_repository.get(sender_id)
        if not sender:
            raise ValueError("Sender not found")

        chat_session = self.chat_repository.get(session_id)
        if not chat_session:
            raise ValueError("Chat session not found")

        if sender_id not in chat_session.participants:
            raise PermissionError("User not in chat session")

        message = chat_session.send_message(sender_id, content)
        self.message_repository.save(message)
        self.chat_repository.save(chat_session)

        self._notify_participants(chat_session, message)
        return message

    def _notify_participants(self, chat_session, message):
        """Notify every participant except the sender about ``message``.

        Participants that cannot be loaded from the user repository are
        skipped rather than handed to the notifier as None.
        """
        sender = self.user_repository.get(message.sender_id)
        for participant_id in chat_session.participants:
            if participant_id == message.sender_id:
                continue
            recipient = self.user_repository.get(participant_id)
            if not recipient:
                continue
            self.notification_service.notify_new_message(
                recipient, chat_session, sender, message
            )

    def get_messages(self, user_id, session_id, limit=50, offset=0):
        """Return a page of a session's history for a participant.

        Raises:
            ValueError: if the user or the session is unknown.
            PermissionError: if the user is not a participant.
        """
        user = self.user_repository.get(user_id)
        if not user:
            raise ValueError("User not found")

        chat_session = self.chat_repository.get(session_id)
        if not chat_session:
            raise ValueError("Chat session not found")

        if user_id not in chat_session.participants:
            raise PermissionError("User not in chat session")

        return chat_session.get_messages(user_id, limit, offset)

    def mark_message_as_read(self, user_id, message_id):
        """Mark a stored message as read.

        NOTE(review): there is no check that ``user_id`` participates in the
        message's session; the message object does not record its session.

        Returns:
            True on success.
        Raises:
            ValueError: if the user or the message is unknown.
        """
        user = self.user_repository.get(user_id)
        if not user:
            raise ValueError("User not found")

        message = self.message_repository.get(message_id)
        if not message:
            raise ValueError("Message not found")

        message.mark_as_read()
        self.message_repository.save(message)
        return True


class NotificationService:
    """Delivers new-message notifications.

    This demo implementation only prints to stdout; a real one might push
    over WebSocket or a mobile push service instead.
    """

    def notify_new_message(self, recipient, chat_session, sender, message):
        """Print a notice that ``recipient`` has a new message from ``sender``.

        ``chat_session`` and ``message`` are part of the interface but are not
        used by this implementation.
        """
        text = f"通知 {recipient.username}: 收到来自 {sender.username} 的新消息"
        print(text)

5. 仓储接口(数据访问层)

class Repository(ABC):
    """Abstract data-access interface: persist, look up and delete objects by id."""

    @abstractmethod
    def save(self, obj):
        """Persist ``obj``."""

    @abstractmethod
    def get(self, obj_id):
        """Return the object stored under ``obj_id``."""

    @abstractmethod
    def delete(self, obj_id):
        """Remove the object stored under ``obj_id``."""


class InMemoryRepository(Repository):
    """Dict-backed repository for demos and tests; nothing is persisted.

    Each object is keyed by the first attribute in ``_ID_ATTRIBUTES`` that it
    defines with a non-None value. Objects with none of them (e.g. a
    ``Friendship``, which only has ``user_id1``/``user_id2``) are not stored,
    and ``save`` returns False for them.
    """

    # Checked in order; these are the id attributes of User, FriendRequest,
    # ChatSession and Message respectively.
    _ID_ATTRIBUTES = ("user_id", "request_id", "session_id", "message_id")

    def __init__(self):
        self.data = {}

    def _key_for(self, obj):
        """Return the id ``obj`` should be stored under, or None if it has none."""
        for attr in self._ID_ATTRIBUTES:
            value = getattr(obj, attr, None)
            if value is not None:
                return value
        return None

    def save(self, obj):
        """Store ``obj`` under its id.

        Returns:
            True if stored, False if no id attribute was found. Falsy but
            valid ids such as "" or 0 are accepted.
        """
        key = self._key_for(obj)
        if key is None:
            return False
        self.data[key] = obj
        return True

    def get(self, obj_id):
        """Return the object stored under ``obj_id``, or None if absent."""
        return self.data.get(obj_id)

    def delete(self, obj_id):
        """Remove ``obj_id``; return True if an entry was removed."""
        if obj_id not in self.data:
            return False
        del self.data[obj_id]
        return True

6. 用户操作流程示例

def main():
    """Walk through a typical flow: sign-up, friending, private chat and group chat."""
    # Storage: one in-memory repository per object type.
    user_repo = InMemoryRepository()
    request_repo = InMemoryRepository()
    friendship_repo = InMemoryRepository()
    chat_repo = InMemoryRepository()
    message_repo = InMemoryRepository()

    # Services, wired together by constructor injection.
    notification_service = NotificationService()
    user_service = UserService(user_repo)
    friend_request_service = FriendRequestService(request_repo, user_repo, friendship_repo)
    chat_service = ChatService(chat_repo, user_repo)
    message_service = MessageService(message_repo, chat_repo, user_repo, notification_service)

    # Create users.
    user_a = user_service.create_user("Alice", "password123")
    user_b = user_service.create_user("Bob", "securepass")
    user_c = user_service.create_user("Charlie", "letmein")

    # Alice sends Bob a friend request; Bob accepts it.
    request = user_a.send_friend_request(user_b.user_id, friend_request_service)
    user_b.accept_friend_request(request.request_id, friend_request_service)

    # Alice and Bob (now friends) open a private chat.
    private_chat = chat_service.create_private_chat(user_a.user_id, user_b.user_id)

    # Exchange messages in the private chat.
    message_service.send_message(user_a.user_id, private_chat.session_id, "Hi Bob!")
    message_service.send_message(user_b.user_id, private_chat.session_id, "Hello Alice!")

    # Alice creates a group. Charlie is not Alice's friend, so only Bob is added.
    group_chat = user_a.create_group_chat("Team Chat", [user_b.user_id, user_c.user_id], chat_service)

    # Group messages.
    message_service.send_message(user_a.user_id, group_chat.session_id, "Hello everyone!")
    message_service.send_message(user_b.user_id, group_chat.session_id, "Hi team!")

    # Print Bob's view of the private-chat history.
    messages = message_service.get_messages(user_b.user_id, private_chat.session_id)
    for msg in messages:
        sender = user_repo.get(msg.sender_id).username
        print(f"{sender}: {msg.content}")

    # Bob, the recipient, marks Alice's first message as read. Charlie is not
    # a participant of this private chat, so it is not his to mark.
    if messages:
        message_service.mark_message_as_read(user_b.user_id, messages[0].message_id)

设计特点

  1. 分层架构:清晰分离领域模型、服务层和数据访问层

  2. 面向接口编程:通过抽象基类定义契约

  3. 单一职责原则:每个类专注于单一功能

  4. 可扩展性:易于添加新功能(如多媒体消息、撤回功能)

  5. 松耦合:通过依赖注入实现组件间解耦

这个设计覆盖了用户管理、好友关系、聊天会话和消息处理的核心功能，同时保持了良好的可维护性和扩展性。（注：`User.remove_friend` 所依赖的 `FriendshipService` 在本文中未给出实现。）