工作单元模式深度解析与实践

1 阅读14分钟

引言

在企业级应用开发中,如何优雅地处理复杂的数据持久化操作一直是开发者面临的重要挑战。工作单元模式(Unit of Work Pattern)作为 Martin Fowler 在《企业应用架构模式》中提出的经典设计模式,为解决这一问题提供了强有力的解决方案。本文将深入解析工作单元模式的核心原理、实现细节以及在现代 Java 生态中的应用实践。

1. 工作单元模式概述

1.1 核心定义

工作单元模式的核心思想是:维护一个受业务事务影响的对象列表,并协调变更的写出和并发问题的解决

它具有以下关键特征:

  • 事务边界管理:将多个数据库操作组织在统一的事务边界内
  • 延迟写入策略:收集所有变更操作,在事务提交时统一执行
  • 自动变更跟踪:智能监控对象状态变化,自动识别增删改操作
  • 依赖关系管理:按照正确的顺序执行操作,维护引用完整性

1.2 解决的核心问题

传统的立即执行模式存在诸多问题:

// ❌ 传统方式的问题示例
public class TraditionalApproach {
    
    public void transferMoney(String fromAccount, String toAccount, BigDecimal amount) {
        try {
            // 问题1: 每次操作都是独立事务,无法保证原子性
            accountDao.debit(fromAccount, amount);   // 事务1
            
            // 如果这里发生异常,第一步已经提交无法回滚
            validateBusinessRules(fromAccount, toAccount, amount);
            
            accountDao.credit(toAccount, amount);    // 事务2
            
            // 问题2: 频繁的数据库连接开关,性能低下
            auditDao.logTransaction(fromAccount, toAccount, amount); // 事务3
            
        } catch (Exception e) {
            // 问题3: 无法实现完整的回滚机制
            log.error("Transfer failed, but partial operations may have been committed", e);
            throw e;
        }
    }
}

1.3 架构图

┌─────────────────────────────────────────────────────────────┐
│                    工作单元模式架构                          │
├─────────────────────────────────────────────────────────────┤
│  Business Logic Layer                                       │
│  ┌─────────────────┐    ┌─────────────────┐                │
│  │   Service 1     │    │   Service 2     │                │
│  │                 │    │                 │                │
│  └─────────┬───────┘    └─────────┬───────┘                │
│            │                      │                        │
│            ▼                      ▼                        │
│  ┌─────────────────────────────────────────────────────────┐│
│  │              Unit of Work                               ││
│  │  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐       ││
│  │  │ New Objects │ │Dirty Objects│ │Deleted Objs │       ││
│  │  └─────────────┘ └─────────────┘ └─────────────┘       ││
│  └─────────────────────────────────────────────────────────┘│
│                              │                              │
│                              ▼ commit()                     │
│  ┌─────────────────────────────────────────────────────────┐│
│  │              Data Mapper Layer                          ││
│  │  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐       ││
│  │  │   INSERT    │ │   UPDATE    │ │   DELETE    │       ││
│  │  └─────────────┘ └─────────────┘ └─────────────┘       ││
│  └─────────────────────────────────────────────────────────┘│
│                              │                              │
│                              ▼                              │
│  ┌─────────────────────────────────────────────────────────┐│
│  │                   Database                              ││
│  └─────────────────────────────────────────────────────────┘│
└─────────────────────────────────────────────────────────────┘

2. 核心架构设计

2.1 基础架构组件

工作单元模式的核心架构包含以下关键组件:

public class UnitOfWork {
    
    // 核心状态管理容器
    private final Map<Class<?>, List<Entity>> newObjects = new LinkedHashMap<>();
    private final Map<Class<?>, List<Entity>> dirtyObjects = new LinkedHashMap<>();
    private final Map<Class<?>, List<Entity>> removedObjects = new LinkedHashMap<>();
    
    // 事务管理
    private final DataSource dataSource;
    private Connection connection;
    private boolean isCommitted = false;
    
    // 线程本地存储,确保线程安全
    private static final ThreadLocal<UnitOfWork> current = new ThreadLocal<>();
    
    public static UnitOfWork getCurrent() {
        UnitOfWork uow = current.get();
        if (uow == null) {
            uow = new UnitOfWork();
            current.set(uow);
        }
        return uow;
    }
    
    public UnitOfWork(DataSource dataSource) {
        this.dataSource = dataSource;
    }
    
    private Connection getConnection() throws SQLException {
        if (connection == null) {
            connection = dataSource.getConnection();
            connection.setAutoCommit(false); // 关键:启用手动事务管理
        }
        return connection;
    }
}

2.2 对象注册机制

工作单元通过精确的对象注册机制来跟踪实体状态:

public class UnitOfWork {
    
    /**
     * 注册新对象 - 等待插入的实体
     */
    public void registerNew(Entity entity) {
        Assert.notNull(entity, "Entity cannot be null");
        Assert.isTrue(!dirtyObjects.containsKey(entity.getClass()) || 
                     !dirtyObjects.get(entity.getClass()).contains(entity), 
                     "Object already marked as dirty");
        Assert.isTrue(!removedObjects.containsKey(entity.getClass()) || 
                     !removedObjects.get(entity.getClass()).contains(entity), 
                     "Object already marked for removal");
        
        newObjects.computeIfAbsent(entity.getClass(), k -> new ArrayList<>()).add(entity);
        log.debug("Registered new entity: {} with id: {}", entity.getClass().getSimpleName(), entity.getId());
    }
    
    /**
     * 注册脏对象 - 需要更新的实体
     */
    public void registerDirty(Entity entity) {
        Assert.notNull(entity, "Entity cannot be null");
        Class<?> entityClass = entity.getClass();
        
        // 避免重复注册
        List<Entity> newList = newObjects.get(entityClass);
        if (newList != null && newList.contains(entity)) {
            return; // 新对象无需标记为脏对象
        }
        
        List<Entity> removedList = removedObjects.get(entityClass);
        Assert.isTrue(removedList == null || !removedList.contains(entity), 
                     "Cannot mark removed object as dirty");
        
        dirtyObjects.computeIfAbsent(entityClass, k -> new ArrayList<>()).add(entity);
        log.debug("Registered dirty entity: {} with id: {}", entity.getClass().getSimpleName(), entity.getId());
    }
    
    /**
     * 注册删除对象 - 等待删除的实体
     */
    public void registerRemoved(Entity entity) {
        Assert.notNull(entity, "Entity cannot be null");
        Class<?> entityClass = entity.getClass();
        
        // 如果是新对象,直接从新对象列表移除
        List<Entity> newList = newObjects.get(entityClass);
        if (newList != null && newList.remove(entity)) {
            log.debug("Removed new entity from registration: {}", entity.getClass().getSimpleName());
            return;
        }
        
        // 从脏对象列表移除
        List<Entity> dirtyList = dirtyObjects.get(entityClass);
        if (dirtyList != null) {
            dirtyList.remove(entity);
        }
        
        // 加入删除列表
        removedObjects.computeIfAbsent(entityClass, k -> new ArrayList<>()).add(entity);
        log.debug("Registered entity for removal: {} with id: {}", entity.getClass().getSimpleName(), entity.getId());
    }
}

3. 核心执行引擎

3.1 统一提交机制

工作单元的核心价值在于其统一的提交机制:

public class UnitOfWork {
    
    /**
     * 提交所有变更 - 核心方法
     */
    public void commit() throws SQLException {
        if (isCommitted) {
            throw new IllegalStateException("Unit of work already committed");
        }
        
        long startTime = System.currentTimeMillis();
        log.info("Starting unit of work commit with {} new, {} dirty, {} removed objects", 
                 getTotalCount(newObjects), getTotalCount(dirtyObjects), getTotalCount(removedObjects));
        
        try {
            getConnection(); // 确保连接已建立
            
            // 按依赖顺序执行操作
            performInserts();
            performUpdates();
            performDeletes();
            
            // 提交事务
            connection.commit();
            isCommitted = true;
            
            long duration = System.currentTimeMillis() - startTime;
            log.info("Unit of work committed successfully in {}ms", duration);
            
        } catch (SQLException e) {
            log.error("Unit of work commit failed, rolling back", e);
            if (connection != null) {
                try {
                    connection.rollback();
                } catch (SQLException rollbackEx) {
                    log.error("Rollback failed", rollbackEx);
                    e.addSuppressed(rollbackEx);
                }
            }
            throw e;
        } finally {
            cleanup();
        }
    }
    
    private void performInserts() throws SQLException {
        log.debug("Performing insert operations");
        
        // 按实体依赖顺序处理插入
        List<Class<?>> insertionOrder = getInsertionOrder();
        
        for (Class<?> entityClass : insertionOrder) {
            List<Entity> entities = newObjects.get(entityClass);
            if (entities != null && !entities.isEmpty()) {
                EntityMapper<?> mapper = getMapper(entityClass);
                mapper.batchInsert(entities, getConnection());
                log.debug("Inserted {} entities of type {}", entities.size(), entityClass.getSimpleName());
            }
        }
    }
    
    private void performUpdates() throws SQLException {
        log.debug("Performing update operations");
        
        for (Map.Entry<Class<?>, List<Entity>> entry : dirtyObjects.entrySet()) {
            Class<?> entityClass = entry.getKey();
            List<Entity> entities = entry.getValue();
            
            if (!entities.isEmpty()) {
                EntityMapper<?> mapper = getMapper(entityClass);
                mapper.batchUpdate(entities, getConnection());
                log.debug("Updated {} entities of type {}", entities.size(), entityClass.getSimpleName());
            }
        }
    }
    
    private void performDeletes() throws SQLException {
        log.debug("Performing delete operations");
        
        // 删除操作按相反的依赖顺序执行
        List<Class<?>> deletionOrder = getDeletionOrder();
        
        for (Class<?> entityClass : deletionOrder) {
            List<Entity> entities = removedObjects.get(entityClass);
            if (entities != null && !entities.isEmpty()) {
                EntityMapper<?> mapper = getMapper(entityClass);
                mapper.batchDelete(entities, getConnection());
                log.debug("Deleted {} entities of type {}", entities.size(), entityClass.getSimpleName());
            }
        }
    }
}

3.2 实体映射器抽象

为了支持不同类型实体的批量操作,我们需要一个灵活的映射器系统:

public abstract class EntityMapper<T extends Entity> {
    
    protected final Class<T> entityClass;
    
    public EntityMapper(Class<T> entityClass) {
        this.entityClass = entityClass;
    }
    
    /**
     * 批量插入实体
     */
    @SuppressWarnings("unchecked")
    public void batchInsert(List<Entity> entities, Connection connection) throws SQLException {
        String sql = getInsertSQL();
        
        try (PreparedStatement ps = connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) {
            
            for (Entity entity : entities) {
                setInsertParameters(ps, (T) entity);
                ps.addBatch();
            }
            
            int[] results = ps.executeBatch();
            
            // 处理生成的主键
            processGeneratedKeys(entities, ps);
            
            log.debug("Batch insert completed: {} records affected", Arrays.stream(results).sum());
        }
    }
    
    /**
     * 批量更新实体
     */
    @SuppressWarnings("unchecked")
    public void batchUpdate(List<Entity> entities, Connection connection) throws SQLException {
        String sql = getUpdateSQL();
        
        try (PreparedStatement ps = connection.prepareStatement(sql)) {
            for (Entity entity : entities) {
                setUpdateParameters(ps, (T) entity);
                ps.addBatch();
            }
            
            int[] results = ps.executeBatch();
            
            // 检查乐观锁
            validateUpdateResults(entities, results);
            
            log.debug("Batch update completed: {} records affected", Arrays.stream(results).sum());
        }
    }
    
    /**
     * 批量删除实体
     */
    @SuppressWarnings("unchecked")
    public void batchDelete(List<Entity> entities, Connection connection) throws SQLException {
        String sql = getDeleteSQL();
        
        try (PreparedStatement ps = connection.prepareStatement(sql)) {
            for (Entity entity : entities) {
                setDeleteParameters(ps, (T) entity);
                ps.addBatch();
            }
            
            int[] results = ps.executeBatch();
            
            // 验证删除结果
            validateDeleteResults(entities, results);
            
            log.debug("Batch delete completed: {} records affected", Arrays.stream(results).sum());
        }
    }
    
    private void processGeneratedKeys(List<Entity> entities, PreparedStatement ps) throws SQLException {
        try (ResultSet rs = ps.getGeneratedKeys()) {
            int index = 0;
            while (rs.next() && index < entities.size()) {
                Entity entity = entities.get(index);
                Long generatedId = rs.getLong(1);
                entity.setId(generatedId);
                index++;
            }
        }
    }
    
    private void validateUpdateResults(List<Entity> entities, int[] results) {
        for (int i = 0; i < results.length; i++) {
            if (results[i] == 0) {
                Entity entity = entities.get(i);
                throw new OptimisticLockException(
                    String.format("Entity %s with id %s was modified by another transaction", 
                                entity.getClass().getSimpleName(), entity.getId()));
            }
        }
    }
    
    private void validateDeleteResults(List<Entity> entities, int[] results) {
        for (int i = 0; i < results.length; i++) {
            if (results[i] == 0) {
                Entity entity = entities.get(i);
                log.warn("Entity {} with id {} was not found for deletion", 
                        entity.getClass().getSimpleName(), entity.getId());
            }
        }
    }
    
    // 抽象方法,由具体实现类提供
    protected abstract String getInsertSQL();
    protected abstract String getUpdateSQL();
    protected abstract String getDeleteSQL();
    protected abstract void setInsertParameters(PreparedStatement ps, T entity) throws SQLException;
    protected abstract void setUpdateParameters(PreparedStatement ps, T entity) throws SQLException;
    protected abstract void setDeleteParameters(PreparedStatement ps, T entity) throws SQLException;
}

4. 高级特性实现

4.1 自动脏检查机制

自动脏检查是工作单元模式的高级特性,可以自动识别对象的变更:

public class DirtyCheckingUnitOfWork extends UnitOfWork {
    
    // 对象状态快照存储
    private final Map<Entity, EntitySnapshot> entitySnapshots = new ConcurrentHashMap<>();
    
    /**
     * 注册干净对象(刚从数据库加载)
     */
    public void registerClean(Entity entity) {
        if (!entitySnapshots.containsKey(entity)) {
            EntitySnapshot snapshot = createSnapshot(entity);
            entitySnapshots.put(entity, snapshot);
            log.debug("Registered clean entity: {} with id: {}", 
                     entity.getClass().getSimpleName(), entity.getId());
        }
    }
    
    @Override
    public void commit() throws SQLException {
        // 在提交前自动检测脏对象
        detectDirtyObjects();
        
        // 调用父类提交逻辑
        super.commit();
    }
    
    /**
     * 自动检测脏对象
     */
    private void detectDirtyObjects() {
        log.debug("Starting automatic dirty object detection");
        int dirtyCount = 0;
        
        for (Map.Entry<Entity, EntitySnapshot> entry : entitySnapshots.entrySet()) {
            Entity entity = entry.getKey();
            EntitySnapshot originalSnapshot = entry.getValue();
            
            // 创建当前状态快照
            EntitySnapshot currentSnapshot = createSnapshot(entity);
            
            // 比较状态
            if (!originalSnapshot.equals(currentSnapshot)) {
                registerDirty(entity);
                dirtyCount++;
                log.debug("Detected dirty entity: {} with id: {}", 
                         entity.getClass().getSimpleName(), entity.getId());
            }
        }
        
        log.info("Dirty object detection completed: {} dirty objects found", dirtyCount);
    }
    
    /**
     * 创建实体状态快照
     */
    private EntitySnapshot createSnapshot(Entity entity) {
        try {
            // 使用反射获取所有字段值
            Map<String, Object> fieldValues = new HashMap<>();
            Class<?> clazz = entity.getClass();
            
            while (clazz != null && clazz != Object.class) {
                Field[] fields = clazz.getDeclaredFields();
                
                for (Field field : fields) {
                    if (!Modifier.isStatic(field.getModifiers()) && 
                        !Modifier.isTransient(field.getModifiers())) {
                        
                        field.setAccessible(true);
                        Object value = field.get(entity);
                        
                        // 深拷贝复杂对象
                        if (value != null && !isPrimitiveOrWrapper(field.getType())) {
                            value = deepCopy(value);
                        }
                        
                        fieldValues.put(field.getName(), value);
                    }
                }
                clazz = clazz.getSuperclass();
            }
            
            return new EntitySnapshot(entity.getClass(), entity.getId(), fieldValues);
            
        } catch (IllegalAccessException e) {
            throw new RuntimeException("Failed to create entity snapshot", e);
        }
    }
    
    /**
     * 实体状态快照
     */
    private static class EntitySnapshot {
        private final Class<?> entityClass;
        private final Object entityId;
        private final Map<String, Object> fieldValues;
        
        public EntitySnapshot(Class<?> entityClass, Object entityId, Map<String, Object> fieldValues) {
            this.entityClass = entityClass;
            this.entityId = entityId;
            this.fieldValues = new HashMap<>(fieldValues);
        }
        
        @Override
        public boolean equals(Object obj) {
            if (this == obj) return true;
            if (obj == null || getClass() != obj.getClass()) return false;
            
            EntitySnapshot that = (EntitySnapshot) obj;
            return Objects.equals(entityClass, that.entityClass) &&
                   Objects.equals(entityId, that.entityId) &&
                   Objects.equals(fieldValues, that.fieldValues);
        }
        
        @Override
        public int hashCode() {
            return Objects.hash(entityClass, entityId, fieldValues);
        }
    }
}

4.2 依赖关系管理

处理实体间的依赖关系是工作单元模式的重要职责:

public class DependencyManager {
    
    private final Map<Class<?>, Set<Class<?>>> dependencyGraph = new HashMap<>();
    
    /**
     * 注册实体依赖关系
     */
    public void registerDependency(Class<?> dependent, Class<?> dependency) {
        dependencyGraph.computeIfAbsent(dependent, k -> new HashSet<>()).add(dependency);
        log.debug("Registered dependency: {} depends on {}", 
                 dependent.getSimpleName(), dependency.getSimpleName());
    }
    
    /**
     * 获取插入顺序(依赖关系正序)
     */
    public List<Class<?>> getInsertionOrder() {
        return topologicalSort(false);
    }
    
    /**
     * 获取删除顺序(依赖关系逆序)
     */
    public List<Class<?>> getDeletionOrder() {
        return topologicalSort(true);
    }
    
    /**
     * 拓扑排序算法
     */
    private List<Class<?>> topologicalSort(boolean reverse) {
        Map<Class<?>, Integer> inDegree = calculateInDegree();
        Queue<Class<?>> queue = new LinkedList<>();
        List<Class<?>> result = new ArrayList<>();
        
        // 找到所有入度为0的节点
        for (Map.Entry<Class<?>, Integer> entry : inDegree.entrySet()) {
            if (entry.getValue() == 0) {
                queue.offer(entry.getKey());
            }
        }
        
        while (!queue.isEmpty()) {
            Class<?> current = queue.poll();
            result.add(current);
            
            // 获取当前节点的邻接节点
            Set<Class<?>> neighbors = reverse ? 
                getDependents(current) : dependencyGraph.getOrDefault(current, Collections.emptySet());
            
            for (Class<?> neighbor : neighbors) {
                inDegree.put(neighbor, inDegree.get(neighbor) - 1);
                if (inDegree.get(neighbor) == 0) {
                    queue.offer(neighbor);
                }
            }
        }
        
        if (result.size() != inDegree.size()) {
            throw new IllegalStateException("Circular dependency detected in entity relationships");
        }
        
        if (reverse) {
            Collections.reverse(result);
        }
        
        return result;
    }
    
    private Map<Class<?>, Integer> calculateInDegree() {
        Map<Class<?>, Integer> inDegree = new HashMap<>();
        
        // 初始化所有节点的入度为0
        for (Class<?> clazz : getAllEntityClasses()) {
            inDegree.put(clazz, 0);
        }
        
        // 计算每个节点的入度
        for (Map.Entry<Class<?>, Set<Class<?>>> entry : dependencyGraph.entrySet()) {
            for (Class<?> dependency : entry.getValue()) {
                inDegree.put(dependency, inDegree.getOrDefault(dependency, 0) + 1);
            }
        }
        
        return inDegree;
    }
    
    private Set<Class<?>> getDependents(Class<?> clazz) {
        Set<Class<?>> dependents = new HashSet<>();
        for (Map.Entry<Class<?>, Set<Class<?>>> entry : dependencyGraph.entrySet()) {
            if (entry.getValue().contains(clazz)) {
                dependents.add(entry.getKey());
            }
        }
        return dependents;
    }
    
    private Set<Class<?>> getAllEntityClasses() {
        Set<Class<?>> allClasses = new HashSet<>();
        for (Map.Entry<Class<?>, Set<Class<?>>> entry : dependencyGraph.entrySet()) {
            allClasses.add(entry.getKey());
            allClasses.addAll(entry.getValue());
        }
        return allClasses;
    }
}

5. 业务层应用实践

5.1 服务层集成

@Service
@Transactional
public class OrderProcessingService {
    
    private final UnitOfWorkFactory unitOfWorkFactory;
    private final OrderRepository orderRepository;
    private final InventoryRepository inventoryRepository;
    private final PaymentService paymentService;
    private final NotificationService notificationService;
    
    /**
     * 复杂订单处理业务场景
     */
    public OrderProcessingResult processOrder(OrderRequest request) {
        return unitOfWorkFactory.executeInUnitOfWork(uow -> {
            
            // 1. 创建订单
            Order order = createOrder(request, uow);
            
            // 2. 处理库存
            List<InventoryAdjustment> adjustments = processInventory(order, uow);
            
            // 3. 处理支付
            Payment payment = processPayment(order, uow);
            
            // 4. 创建发货信息
            Shipment shipment = createShipment(order, uow);
            
            // 5. 发送通知
            scheduleNotifications(order, uow);
            
            // 6. 统一提交所有变更
            // 注意:实际的数据库操作在这里统一执行
            return new OrderProcessingResult(order, payment, shipment, adjustments);
        });
    }
    
    private Order createOrder(OrderRequest request, UnitOfWork uow) {
        Order order = new Order();
        order.setCustomerId(request.getCustomerId());
        order.setStatus(OrderStatus.PENDING);
        order.setOrderDate(LocalDateTime.now());
        order.setTotalAmount(request.calculateTotal());
        
        // 创建订单项
        for (OrderItemRequest itemRequest : request.getItems()) {
            OrderItem item = new OrderItem();
            item.setOrder(order);
            item.setProductId(itemRequest.getProductId());
            item.setQuantity(itemRequest.getQuantity());
            item.setUnitPrice(itemRequest.getUnitPrice());
            
            order.addItem(item);
            uow.registerNew(item); // 注册新的订单项
        }
        
        uow.registerNew(order); // 注册新订单
        log.info("Order created for customer: {}", request.getCustomerId());
        return order;
    }
    
    private List<InventoryAdjustment> processInventory(Order order, UnitOfWork uow) {
        List<InventoryAdjustment> adjustments = new ArrayList<>();
        
        for (OrderItem item : order.getItems()) {
            // 查询库存
            Inventory inventory = inventoryRepository.findByProductId(item.getProductId(), uow);
            
            if (inventory.getAvailableQuantity() < item.getQuantity()) {
                throw new InsufficientInventoryException(
                    String.format("Product %s has insufficient inventory: available=%d, required=%d",
                                item.getProductId(), inventory.getAvailableQuantity(), item.getQuantity()));
            }
            
            // 调整库存
            inventory.decreaseQuantity(item.getQuantity());
            uow.registerDirty(inventory); // 标记库存为脏对象
            
            // 创建库存调整记录
            InventoryAdjustment adjustment = new InventoryAdjustment();
            adjustment.setProductId(item.getProductId());
            adjustment.setQuantityChange(-item.getQuantity());
            adjustment.setReason("ORDER_FULFILLMENT");
            adjustment.setOrderId(order.getId());
            adjustment.setTimestamp(LocalDateTime.now());
            
            adjustments.add(adjustment);
            uow.registerNew(adjustment);
            
            // 检查是否需要补货
            if (inventory.shouldReorder()) {
                RestockRequest restockRequest = new RestockRequest();
                restockRequest.setProductId(item.getProductId());
                restockRequest.setRequestedQuantity(inventory.getReorderQuantity());
                restockRequest.setPriority(inventory.calculatePriority());
                
                uow.registerNew(restockRequest);
                log.info("Restock request created for product: {}", item.getProductId());
            }
        }
        
        return adjustments;
    }
    
    private Payment processPayment(Order order, UnitOfWork uow) {
        Payment payment = new Payment();
        payment.setOrderId(order.getId());
        payment.setAmount(order.getTotalAmount());
        payment.setPaymentMethod(order.getPaymentMethod());
        payment.setStatus(PaymentStatus.PENDING);
        payment.setCreatedAt(LocalDateTime.now());
        
        try {
            // 调用外部支付服务
            PaymentResult result = paymentService.processPayment(payment);
            
            payment.setTransactionId(result.getTransactionId());
            payment.setStatus(result.getStatus());
            payment.setProcessedAt(LocalDateTime.now());
            
            if (result.isSuccessful()) {
                order.setStatus(OrderStatus.PAID);
                uow.registerDirty(order); // 更新订单状态
            } else {
                throw new PaymentProcessingException("Payment failed: " + result.getErrorMessage());
            }
            
        } catch (Exception e) {
            payment.setStatus(PaymentStatus.FAILED);
            payment.setErrorMessage(e.getMessage());
            order.setStatus(OrderStatus.PAYMENT_FAILED);
            uow.registerDirty(order);
            throw e;
        }
        
        uow.registerNew(payment);
        return payment;
    }
}

5.2 Spring Framework 集成

@Configuration
@EnableTransactionManagement
public class UnitOfWorkConfiguration {
    
    @Bean
    public UnitOfWorkFactory unitOfWorkFactory(DataSource dataSource, 
                                             PlatformTransactionManager transactionManager) {
        return new SpringUnitOfWorkFactory(dataSource, transactionManager);
    }
}

@Component
public class SpringUnitOfWorkFactory implements UnitOfWorkFactory {
    
    private final DataSource dataSource;
    private final PlatformTransactionManager transactionManager;
    
    public SpringUnitOfWorkFactory(DataSource dataSource, PlatformTransactionManager transactionManager) {
        this.dataSource = dataSource;
        this.transactionManager = transactionManager;
    }
    
    @Override
    public <T> T executeInUnitOfWork(UnitOfWorkCallback<T> callback) {
        TransactionTemplate transactionTemplate = new TransactionTemplate(transactionManager);
        
        return transactionTemplate.execute(status -> {
            UnitOfWork uow = new UnitOfWork(dataSource);
            UnitOfWorkContext.setCurrent(uow);
            
            try {
                T result = callback.execute(uow);
                uow.commit();
                return result;
                
            } catch (Exception e) {
                status.setRollbackOnly();
                log.error("Unit of work execution failed, transaction will be rolled back", e);
                throw new UnitOfWorkException("Unit of work execution failed", e);
                
            } finally {
                UnitOfWorkContext.clear();
            }
        });
    }
}

public class UnitOfWorkContext {
    private static final ThreadLocal<UnitOfWork> current = new ThreadLocal<>();
    
    public static void setCurrent(UnitOfWork unitOfWork) {
        current.set(unitOfWork);
    }
    
    public static UnitOfWork getCurrent() {
        UnitOfWork uow = current.get();
        if (uow == null) {
            throw new IllegalStateException("No active unit of work in current thread");
        }
        return uow;
    }
    
    public static void clear() {
        current.remove();
    }
}

6. 性能优化策略

6.1 批量操作优化

public class OptimizedEntityMapper<T extends Entity> extends EntityMapper<T> {
    
    private static final int DEFAULT_BATCH_SIZE = 1000;
    private final int batchSize;
    
    public OptimizedEntityMapper(Class<T> entityClass, int batchSize) {
        super(entityClass);
        this.batchSize = batchSize > 0 ? batchSize : DEFAULT_BATCH_SIZE;
    }
    
    @Override
    public void batchInsert(List<Entity> entities, Connection connection) throws SQLException {
        if (entities.size() <= batchSize) {
            // 小批量直接执行
            super.batchInsert(entities, connection);
        } else {
            // 大批量分片执行
            executeLargeBatch(entities, connection, this::executeBatchInsert);
        }
    }
    
    private void executeLargeBatch(List<Entity> entities, Connection connection, 
                                 BatchOperation operation) throws SQLException {
        int totalSize = entities.size();
        int processedCount = 0;
        
        log.info("Processing large batch: {} entities, batch size: {}", totalSize, batchSize);
        
        for (int i = 0; i < totalSize; i += batchSize) {
            int endIndex = Math.min(i + batchSize, totalSize);
            List<Entity> batch = entities.subList(i, endIndex);
            
            operation.execute(batch, connection);
            processedCount += batch.size();
            
            // 定期刷新以避免内存压力
            if (processedCount % (batchSize * 10) == 0) {
                log.debug("Processed {} of {} entities", processedCount, totalSize);
            }
        }
        
        log.info("Large batch processing completed: {} entities processed", processedCount);
    }
    
    private void executeBatchInsert(List<Entity> batch, Connection connection) throws SQLException {
        super.batchInsert(batch, connection);
    }
    
    @FunctionalInterface
    private interface BatchOperation {
        void execute(List<Entity> batch, Connection connection) throws SQLException;
    }
}

6.2 内存管理优化

public class MemoryOptimizedUnitOfWork extends UnitOfWork {
    
    private static final int MEMORY_THRESHOLD = 10000; // 对象数量阈值
    private final MemoryMonitor memoryMonitor;
    
    public MemoryOptimizedUnitOfWork(DataSource dataSource) {
        super(dataSource);
        this.memoryMonitor = new MemoryMonitor();
    }
    
    @Override
    public void registerNew(Entity entity) {
        super.registerNew(entity);
        checkMemoryPressure();
    }
    
    @Override
    public void registerDirty(Entity entity) {
        super.registerDirty(entity);
        checkMemoryPressure();
    }
    
    private void checkMemoryPressure() {
        int totalObjects = getTotalObjectCount();
        
        if (totalObjects > MEMORY_THRESHOLD) {
            log.warn("Memory pressure detected: {} objects in unit of work", totalObjects);
            
            if (memoryMonitor.isMemoryPressureHigh()) {
                // 建议部分提交以释放内存
                log.warn("High memory pressure, consider partial commit");
                // 可以实现智能的部分提交逻辑
                performPartialCommit();
            }
        }
    }
    
    private void performPartialCommit() throws SQLException {
        // 智能部分提交:优先提交独立的实体
        // 这是一个高级特性,需要仔细设计以维护数据一致性
        log.info("Performing partial commit to reduce memory pressure");
        
        try {
            // 先提交没有依赖关系的新对象
            commitIndependentNewObjects();
            
            // 然后提交更新操作
            commitUpdateOperations();
            
        } catch (SQLException e) {
            log.error("Partial commit failed", e);
            throw e;
        }
    }
    
    private static class MemoryMonitor {
        private final MemoryMXBean memoryBean = ManagementFactory.getMemoryMXBean();
        
        public boolean isMemoryPressureHigh() {
            MemoryUsage heapUsage = memoryBean.getHeapMemoryUsage();
            double usageRatio = (double) heapUsage.getUsed() / heapUsage.getMax();
            
            return usageRatio > 0.8; // 内存使用超过80%
        }
    }
}

7. 框架中的实现

7.1 Hibernate 实现

// Hibernate 的 Session 就是工作单元的实现
@Service
@Transactional
public class HibernateExampleService {
    
    @PersistenceContext
    private EntityManager entityManager; // 工作单元
    
    public void businessOperation() {
        // 1. 查询实体(自动纳入工作单元管理)
        User user = entityManager.find(User.class, 1L);
        
        // 2. 修改实体(自动脏检查)
        user.setEmail("new@email.com");
        
        // 3. 创建新实体
        Order order = new Order(user, "New Order");
        entityManager.persist(order); // 注册新对象
        
        // 4. 删除实体
        Product product = entityManager.find(Product.class, 100L);
        entityManager.remove(product); // 注册删除对象
        
        // 5. 事务提交时自动调用 flush()
        // - 执行脏检查
        // - 按依赖顺序执行 SQL
        // - 处理级联操作
    }
}

7.2 Spring Data JPA 实现

@Service
@Transactional
public class SpringDataExampleService {
    
    @Autowired
    private UserRepository userRepository;
    
    @Autowired
    private OrderRepository orderRepository;
    
    public void businessOperation() {
        // Spring Data JPA 底层使用 Hibernate 的工作单元
        
        // 1. 查询和修改
        User user = userRepository.findById(1L).orElseThrow();
        user.setEmail("updated@email.com"); // 自动脏检查
        
        // 2. 批量操作
        List<Order> orders = createOrders(user);
        orderRepository.saveAll(orders); // 批量插入优化
        
        // 3. 删除操作
        orderRepository.deleteByUserId(user.getId());
        
        // 事务结束时统一提交
    }
}

8. 对比分析:工作单元模式 vs. 现代ORM

8.1 与 Hibernate/JPA 对比

特性自定义工作单元Hibernate/JPA
控制粒度完全控制框架控制
性能优化手动优化自动优化
学习曲线较陡峭相对平缓
透明度完全透明部分黑盒
灵活性极高受限于规范
// 自定义工作单元的精确控制
public void customUnitOfWorkExample() {
    UnitOfWork uow = new UnitOfWork(dataSource);
    
    // 精确控制何时注册对象
    Order order = new Order();
    uow.registerNew(order);
    
    // 明确的脏检查时机
    order.setStatus(OrderStatus.PROCESSING);
    uow.registerDirty(order);
    
    // 完全掌控提交时机和批量大小
    uow.commit(); // 明确知道这里会执行哪些SQL
}

// JPA的自动管理
@Service
@Transactional
public void jpaExample() {
    Order order = new Order();
    entityManager.persist(order); // 框架控制
    
    order.setStatus(OrderStatus.PROCESSING); // 自动脏检查
    
    // 事务结束时自动提交,但不能精确控制时机
}

8.2 适用场景分析

选择自定义工作单元的场景:

  • 对性能有极高要求的系统
  • 需要精确控制SQL执行的应用
  • 复杂的遗留系统集成
  • 特殊的业务规则需求

选择成熟ORM框架的场景:

  • 标准的CRUD操作
  • 快速原型开发
  • 团队ORM经验丰富
  • 标准企业应用

9. 最佳实践与注意事项

9.1 设计原则

public class BestPracticeUnitOfWork extends UnitOfWork {
    
    /**
     * 原则1:及早验证,快速失败
     */
    @Override
    public void registerNew(Entity entity) {
        validateEntity(entity);
        super.registerNew(entity);
    }
    
    private void validateEntity(Entity entity) {
        if (entity == null) {
            throw new IllegalArgumentException("Entity cannot be null");
        }
        
        if (entity.getId() != null) {
            throw new IllegalArgumentException("New entity should not have an ID");
        }
        
        // 业务规则验证
        entity.validate();
    }
    
    /**
     * 原则2:提供清晰的错误信息
     */
    @Override
    public void commit() throws SQLException {
        try {
            validateCommitPreconditions();
            super.commit();
            
        } catch (SQLException e) {
            String errorMessage = buildDetailedErrorMessage(e);
            log.error(errorMessage, e);
            throw new UnitOfWorkException(errorMessage, e);
        }
    }
    
    private void validateCommitPreconditions() {
        if (getTotalObjectCount() == 0) {
            log.warn("Attempting to commit empty unit of work");
        }
        
        if (getTotalObjectCount() > 50000) {
            log.warn("Large unit of work detected: {} objects. Consider breaking into smaller transactions.", 
                     getTotalObjectCount());
        }
    }
    
    /**
     * 原则3:提供丰富的监控信息
     */
    private String buildDetailedErrorMessage(SQLException e) {
        StringBuilder sb = new StringBuilder();
        sb.append("Unit of work commit failed. ");
        sb.append("Total objects: ").append(getTotalObjectCount()).append(". ");
        sb.append("New: ").append(getTotalCount(newObjects)).append(", ");
        sb.append("Dirty: ").append(getTotalCount(dirtyObjects)).append(", ");
        sb.append("Removed: ").append(getTotalCount(removedObjects)).append(". ");
        sb.append("SQL Error: ").append(e.getMessage());
        
        return sb.toString();
    }
}

9.2 常见陷阱与解决方案

public class UnitOfWorkAntiPatterns {
    
    /**
     * ❌ 反模式1:长时间持有工作单元
     */
    public void badLongRunningUnitOfWork() {
        UnitOfWork uow = new UnitOfWork(dataSource);
        
        // 长时间运行的操作
        for (int i = 0; i < 100000; i++) {
            Entity entity = createComplexEntity(i);
            uow.registerNew(entity); // 内存持续增长
        }
        
        uow.commit(); // 可能导致内存溢出
    }
    
    /**
     * ✅ 正确做法:分批处理
     */
    public void goodBatchProcessing() {
        int batchSize = 1000;
        
        for (int i = 0; i < 100000; i += batchSize) {
            UnitOfWork uow = new UnitOfWork(dataSource);
            
            for (int j = i; j < Math.min(i + batchSize, 100000); j++) {
                Entity entity = createComplexEntity(j);
                uow.registerNew(entity);
            }
            
            uow.commit(); // 定期提交,控制内存使用
        }
    }
    
    /**
     * ❌ 反模式2:忽略异常处理
     */
    public void badExceptionHandling() {
        UnitOfWork uow = new UnitOfWork(dataSource);
        
        try {
            // 业务操作
            performBusinessOperations(uow);
            uow.commit();
        } catch (Exception e) {
            // 没有清理资源
            throw e;
        }
    }
    
    /**
     * ✅ 正确做法:完整的异常处理
     */
    public void goodExceptionHandling() {
        UnitOfWork uow = new UnitOfWork(dataSource);
        
        try {
            performBusinessOperations(uow);
            uow.commit();
            
        } catch (Exception e) {
            log.error("Business operation failed", e);
            
            try {
                uow.rollback();
            } catch (SQLException rollbackEx) {
                log.error("Rollback failed", rollbackEx);
                e.addSuppressed(rollbackEx);
            }
            
            throw new BusinessException("Operation failed", e);
            
        } finally {
            try {
                uow.close();
            } catch (Exception closeEx) {
                log.warn("Failed to close unit of work", closeEx);
            }
        }
    }
}

10. 总结

工作单元模式作为企业级应用架构的重要组成部分,为复杂事务处理提供了强大而灵活的解决方案。通过本文的深入分析,我们可以看到:

核心价值

  1. 事务一致性保证:确保相关操作的原子性
  2. 性能优化机制:通过批量操作和延迟执行提升效率
  3. 复杂性封装:隐藏底层数据访问的复杂性
  4. 灵活性提供:支持复杂业务场景的定制需求

技术挑战

  1. 实现复杂度:需要深入理解事务、并发和数据库特性
  2. 内存管理:大量对象跟踪可能导致内存压力
  3. 调试困难:延迟执行使问题定位更具挑战性
  4. 学习成本:团队需要深入理解模式的运行机制

应用建议

  • 适度使用:根据实际业务复杂度选择合适的实现程度
  • 渐进演进:从简单实现开始,逐步增加高级特性
  • 充分测试:重点关注事务边界和异常场景的测试
  • 监控完善:建立完整的性能监控和错误跟踪机制

工作单元模式的掌握需要扎实的技术功底和丰富的实践经验。在现代Java生态中,虽然Hibernate、MyBatis等成熟框架已经提供了优秀的实现,但深入理解其背后的设计原理对于构建高质量的企业级应用仍然具有重要意义。

通过合理运用工作单元模式,我们可以构建出既保证数据一致性又具备良好性能的应用系统,为企业的数字化转型提供强有力的技术支撑。