非常友好的带你了解模板方法模式Python版

0 阅读10分钟

第一章:初识模板方法——就像做菜一样简单

想象一下,你学会了两件事情:泡茶冲咖啡。它们的步骤非常固定:

  • 泡茶:烧水 -> 放入茶叶 -> 倒入开水 -> 等待3分钟 -> 完成
  • 冲咖啡:烧水 -> 放入咖啡粉 -> 倒入开水 -> 搅拌 -> 完成

你发现了吗?烧水倒入开水这两个步骤是完全一样的,只有中间的“放什么”和最后的“怎么处理”不同。

模板方法模式就是把固定的步骤(算法骨架)写在一个“模板”方法里,把那些会变化的步骤交给子类去实现

一个超简单的例子:

from abc import ABC, abstractmethod

# 1. 定义“食谱”模板 (抽象基类)
class BeverageRecipe(ABC):
    """饮料制作模板"""
    
    def make(self):
        """模板方法:定义了固定的制作步骤"""
        self.boil_water()        # 固定步骤1:烧水
        self.add_main_ingredient() # 可变步骤1:加主料(茶/咖啡)
        self.pour_water()        # 固定步骤2:倒水
        self.final_touch()       # 可变步骤2:最后处理(等/搅)
        self.done()              # 固定步骤3:完成
    
    def boil_water(self):
        print("1. 烧开水...")
    
    def pour_water(self):
        print("3. 倒入开水...")
    
    def done(self):
        print("5. 完成!可以享用了。\n")
    
    @abstractmethod
    def add_main_ingredient(self):
        """加主料 - 子类必须实现"""
        pass
    
    @abstractmethod
    def final_touch(self):
        """最后处理 - 子类必须实现"""
        pass

# 2. 实现具体的“茶”食谱
class TeaRecipe(BeverageRecipe):
    def add_main_ingredient(self):
        print("2. 放入绿茶茶叶...")
    
    def final_touch(self):
        print("4. 等待3分钟,让茶叶舒展...")

# 3. 实现具体的“咖啡”食谱
class CoffeeRecipe(BeverageRecipe):
    def add_main_ingredient(self):
        print("2. 加入一勺咖啡粉...")
    
    def final_touch(self):
        print("4. 快速搅拌一下...")

# 使用
print("--- 开始泡茶 ---")
tea = TeaRecipe()
tea.make()  # 调用模板方法,自动按顺序执行

print("--- 开始冲咖啡 ---")
coffee = CoffeeRecipe()
coffee.make()

输出:

--- 开始泡茶 ---
1. 烧开水...
2. 放入绿茶茶叶...
3. 倒入开水...
4. 等待3分钟,让茶叶舒展...
5. 完成!可以享用了。

--- 开始冲咖啡 ---
1. 烧开水...
2. 加入一勺咖啡粉...
3. 倒入开水...
4. 快速搅拌一下...
5. 完成!可以享用了。

核心要点:

  • make()就是“模板方法” :它规定了步骤顺序,像一个不可更改的流水线。
  • ABC@abstractmethod:来自Python的abc模块,用来强制子类必须实现某些方法(这里是add_main_ingredientfinal_touch),否则会报错。
  • 优点:避免了重复代码(boil_water, pour_water, done只写了一遍),并且很容易扩展新的饮料(比如新增一个“热巧克力”类)。

第二章:进阶——更接近编程的例子

现在,我们模拟一个数据处理流程,它也需要“模板化”。

假设我们有两个数据处理任务:

  1. 处理数字:读取 -> 每个数字加10 -> 保存
  2. 处理文本:读取 -> 每行文字变大写 -> 保存
from abc import ABC, abstractmethod

class DataProcessor(ABC):
    """数据处理模板"""
    
    def process(self, input_data):
        """模板方法:处理数据的主流程"""
        data = self.load_data(input_data)  # 1. 加载 (可变)
        processed_data = self.transform(data) # 2. 转换 (可变)
        self.save_data(processed_data)     # 3. 保存 (可变)
        print("数据处理流程结束。\n")
    
    @abstractmethod
    def load_data(self, input_data):
        pass
    
    @abstractmethod
    def transform(self, data):
        pass
    
    @abstractmethod
    def save_data(self, data):
        pass

# 具体类:数字处理器
class NumberProcessor(DataProcessor):
    def load_data(self, input_data):
        print(f"[数字] 加载数据: {input_data}")
        # 假设输入是逗号分隔的数字字符串
        return [int(x) for x in input_data.split(',')]
    
    def transform(self, data):
        print(f"[数字] 转换数据: 每个值+10")
        return [x + 10 for x in data]
    
    def save_data(self, data):
        print(f"[数字] 保存结果到列表: {data}")

# 具体类:文本处理器
class TextProcessor(DataProcessor):
    def load_data(self, input_data):
        print(f"[文本] 加载数据: '{input_data}'")
        # 假设输入是多行文本字符串
        return input_data.split('\n')
    
    def transform(self, data):
        print(f"[文本] 转换数据: 转为大写")
        return [line.upper() for line in data]
    
    def save_data(self, data):
        joined_text = ' | '.join(data)
        print(f"[文本] 保存结果到字符串: '{joined_text}'")

# 使用
num_input = "1,2,3,4"
text_input = "hello\nworld\npython"

print("=== 处理数字 ===")
num_proc = NumberProcessor()
num_proc.process(num_input)  # 调用模板方法

print("=== 处理文本 ===")
text_proc = TextProcessor()
text_proc.process(text_input)

输出:

=== 处理数字 ===
[数字] 加载数据: 1,2,3,4
[数字] 转换数据: 每个值+10
[数字] 保存结果到列表: [11, 12, 13, 14]
数据处理流程结束。

=== 处理文本 ===
[文本] 加载数据: 'hello
world
python'
[文本] 转换数据: 转为大写
[文本] 保存结果到字符串: 'HELLO | WORLD | PYTHON'
数据处理流程结束。

到这一步,你应该已经掌握了模板方法的核心思想:固定流程,变化细节。


第三章:解决你的实际案例

现在,我们回到你的Salesforce数据查询案例。三个函数 get_parent_ids, get_record_type_ids, get_owner_ids的“固定流程”是:

1. 获取唯一值列表
2. 构建IN语句
3. 执行SOQL查询
4. 构建映射字典
5. 检查未找到的值

变化的细节是:

  • 查询哪个对象Account/ RecordType/ User
  • 用哪个字段做查询条件Name/ DeveloperName/ Email
  • 用哪个字段做映射的Key

让我们用模板方法模式重构它。

第一步:定义模板(抽象基类)

from abc import ABC, abstractmethod
import logging

# 配置日志,方便看到警告信息
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

class BaseIdQueryStrategy(ABC):
    """
    ID查询策略基类(模板)
    定义了通用的5步查询流程。
    """
    
    def query_ids(self, values, sf_instance):
        """
        模板方法 - 这是算法的骨架,定义了固定的5个步骤。
        Args:
            values: 原始的、可能有重复和空值的列表
            sf_instance: Salesforce连接实例
        Returns:
            dict: 映射字典,格式为 {源值: Salesforce Id}
        """
        # 1. 预处理:获取唯一、非空的值
        unique_values = self._get_unique_values(values)
        if not unique_values:
            logger.info("输入值为空,跳过查询。")
            return {}
        
        # 2. 构建SOQL查询语句
        soql_query = self._build_soql_query(unique_values)
        logger.debug(f"执行查询: {soql_query}")
        
        # 3. 执行查询
        query_result = sf_instance.query_all(soql_query)  # 假设sf_instance有这个方法
        logger.info(f"查询完成,返回 {query_result.get('totalSize', 0)} 条记录")
        
        # 4. 从结果构建映射字典
        id_mapping = self._build_id_mapping(query_result.get('records', []))
        
        # 5. 检查是否有没找到的值,并处理
        self._handle_missing_values(id_mapping, unique_values)
        
        return id_mapping
    
    # ----- 以下是可变的步骤,由子类实现 -----
    @abstractmethod
    def _build_soql_query(self, unique_values):
        """步骤2:根据唯一值列表,构建具体的SOQL查询字符串。"""
        pass
    
    @abstractmethod
    def _build_id_mapping(self, records):
        """步骤4:从Salesforce查询结果中提取出 {源值: Id} 的映射。"""
        pass
    
    # ----- 以下是可重写的步骤,子类可以按需覆盖 -----
    def _get_unique_values(self, values):
        """
        步骤1:获取唯一、非空的值列表。
        这是通用逻辑,通常不需要子类修改。
        """
        # 去重,并过滤掉空值(如None, '', 等Falsy值)
        return [v for v in set(values) if v]
    
    def _handle_missing_values(self, id_mapping, unique_values):
        """
        步骤5:处理未找到对应ID的源值。
        子类可以覆盖此方法来实现不同的处理逻辑(如记录警告、赋默认值等)。
        """
        missing = [v for v in unique_values if v not in id_mapping]
        if missing:
            logger.warning(f"以下 {len(missing)} 个值在Salesforce中未找到: {missing}")
            # 可以选择为缺失项设置一个默认值(如空字符串),避免后续KeyError
            for missing_value in missing:
                id_mapping[missing_value] = ''  # 赋空值

第二步:实现具体的策略(子类)

现在,为三种不同的查询创建子类。每个子类只需要关心查询哪个对象、用哪个字段

class ParentAccountIdQuery(BaseIdQueryStrategy):
    """通过 Account.Name 查询 Account.Id """
    
    def _build_soql_query(self, unique_values):
        # 用 Name 字段查询 Account 对象
        in_clause = ", ".join([f"'{v}'" for v in unique_values])
        return f"""
            SELECT Id, Name
            FROM Account
            WHERE Name IN ({in_clause})
        """
    
    def _build_id_mapping(self, records):
        # 映射关系:Name -> Id
        return {record['Name']: record['Id'] for record in records}


class RecordTypeIdQuery(BaseIdQueryStrategy):
    """通过 RecordType.DeveloperName 查询 RecordType.Id """
    
    def _build_soql_query(self, unique_values):
        in_clause = ", ".join([f"'{v}'" for v in unique_values])
        return f"""
            SELECT Id, DeveloperName
            FROM RecordType
            WHERE DeveloperName IN ({in_clause})
        """
    
    def _build_id_mapping(self, records):
        # 映射关系:DeveloperName -> Id
        return {record['DeveloperName']: record['Id'] for record in records}


class OwnerUserIdQuery(BaseIdQueryStrategy):
    """通过 User.Email 查询 User.Id (作为OwnerId) """
    
    def _build_soql_query(self, unique_values):
        in_clause = ", ".join([f"'{v}'" for v in unique_values])
        return f"""
            SELECT Id, Email
            FROM User
            WHERE Email IN ({in_clause})
        """
    
    def _build_id_mapping(self, records):
        # 映射关系:Email -> Id
        return {record['Email']: record['Id'] for record in records}

或者使用通用映射方法

在基类中实现一个通用的 _build_id_mapping方法,让子类只指定关键的配置信息:

from abc import ABC, abstractmethod
import logging

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
logger = logging.getLogger(__name__)

class BaseIdQueryStrategy(ABC):
    """
    ID查询策略基类(模板)
    定义了通用的5步查询流程。
    """
    
    @property
    @abstractmethod
    def key_field(self):
        """返回作为映射键的字段名"""
        pass
    
    @property
    @abstractmethod
    def sobject_name(self):
        """返回要查询的Salesforce对象名"""
        pass
    
    def query_ids(self, values, sf_instance):
        """
        模板方法 - 这是算法的骨架,定义了固定的5个步骤。
        Args:
            values: 原始的、可能有重复和空值的列表
            sf_instance: Salesforce连接实例
        Returns:
            dict: 映射字典,格式为 {源值: Salesforce Id}
        """
        # 1. 预处理:获取唯一、非空的值
        unique_values = self._get_unique_values(values)
        if not unique_values:
            logger.info("输入值为空,跳过查询。")
            return {}
        
        # 2. 构建SOQL查询语句
        soql_query = self._build_soql_query(unique_values)
        logger.debug(f"执行查询: {soql_query}")
        
        # 3. 执行查询
        query_result = sf_instance.query_all(soql_query)
        logger.info(f"查询完成,返回 {query_result.get('totalSize', 0)} 条记录")
        
        # 4. 从结果构建映射字典
        id_mapping = self._build_id_mapping(query_result.get('records', []))
        
        # 5. 检查是否有没找到的值,并处理
        self._handle_missing_values(id_mapping, unique_values)
        
        return id_mapping
    
    # ----- 可变的步骤 -----
    def _build_soql_query(self, unique_values):
        """步骤2:构建SOQL查询"""
        in_clause = ", ".join([f"'{v}'" for v in unique_values])
        return f"""
            SELECT Id, {self.key_field}
            FROM {self.sobject_name}
            WHERE {self.key_field} IN ({in_clause})
        """
    
    def _build_id_mapping(self, records):
        """步骤4:构建映射字典"""
        return {record[self.key_field]: record['Id'] for record in records}
    
    # ----- 通用步骤 -----
    def _get_unique_values(self, values):
        """步骤1:获取唯一、非空的值列表"""
        return [v for v in set(values) if v]
    
    def _handle_missing_values(self, id_mapping, unique_values):
        """步骤5:处理未找到对应ID的源值"""
        missing = [v for v in unique_values if v not in id_mapping]
        if missing:
            logger.warning(f"以下 {len(missing)} 个值在Salesforce中未找到: {missing}")
            for missing_value in missing:
                id_mapping[missing_value] = ''

# 子类变得非常简洁
class ParentAccountIdQuery(BaseIdQueryStrategy):
    """通过 Account.Name 查询 Account.Id """
    @property
    def key_field(self):
        return "Name"
    
    @property
    def sobject_name(self):
        return "Account"


class RecordTypeIdQuery(BaseIdQueryStrategy):
    """通过 RecordType.DeveloperName 查询 RecordType.Id """
    @property
    def key_field(self):
        return "DeveloperName"
    
    @property
    def sobject_name(self):
        return "RecordType"


class OwnerUserIdQuery(BaseIdQueryStrategy):
    """通过 User.Email 查询 User.Id """
    @property
    def key_field(self):
        return "Email"
    
    @property
    def sobject_name(self):
        return "User"

总结与收获

通过这个从浅入深的学习,我们掌握了模板方法模式:

  1. 核心思想定义一个操作中的算法骨架,将一些步骤延迟到子类中。使得子类可以不改变算法结构的情况下,重新定义某些特定步骤。

  2. Python实现关键

    • 使用 abc.ABC定义抽象基类。
    • 使用 @abstractmethod装饰器声明必须由子类实现的抽象方法(可变步骤)。
    • 在基类中定义一个模板方法(如 query_ids),它按顺序调用其他方法(包括抽象方法)。
    • 将通用逻辑写在基类的具体方法中(如 _get_unique_values)。
  3. 在案例中的优势

    • 消除重复:三个函数近20行几乎一样的代码,被浓缩到一个模板方法里。
    • 结构清晰:查询逻辑被分解为5个明确的步骤,易于理解和维护。
    • 易于扩展:如果未来要增加通过 PhoneContactId 的功能,你只需要创建一个新的 ContactPhoneIdQuery子类,实现两个抽象方法即可。主流程 (query_ids方法) 一行都不用改!
    • 便于维护:如果要修改“检查未找到值”的逻辑(比如改为抛出异常),只需要修改基类中的 _handle_missing_values方法,所有子类都会自动生效。

希望这个循序渐进的讲解能帮你真正理解并应用模板方法模式!