深入理解设计模式之模板方法模式(Template Method Pattern)
一、引言
在软件开发中,我们经常会遇到这样的场景:多个类需要执行相似的操作流程,但某些步骤的具体实现有所不同。如果每个类都完整实现整个流程,会导致大量代码重复。模板方法模式正是为了解决这类问题而诞生的一种行为型设计模式。
本文将深入探讨模板方法模式的原理、实现方式,并结合实际生产场景和开源框架(如Apache Commons Lang、Guava)中的应用,帮助你全面掌握这一重要的设计模式。
二、什么是模板方法模式
2.1 定义
模板方法模式(Template Method Pattern)是一种行为设计模式,它在父类中定义了一个算法的骨架,允许子类在不改变算法整体结构的情况下重写算法的某些特定步骤。
2.2 核心思想
- 封装不变部分:将算法的主体结构固定在父类中
- 扩展可变部分:将算法中易变的部分延迟到子类实现
- 控制扩展点:通过钩子方法(Hook Method)提供可选的扩展点
2.3 模式结构
┌─────────────────────────────────┐
│ AbstractClass (抽象类) │
├─────────────────────────────────┤
│ + templateMethod() │ <-- 定义算法骨架(final)
│ # primitiveOperation1() │ <-- 抽象方法,子类必须实现
│ # primitiveOperation2() │ <-- 抽象方法,子类必须实现
│ # hook() │ <-- 钩子方法,子类可选实现
└─────────────────────────────────┘
△
│ 继承
│
┌──────┴──────────┐
│ │
┌───┴────────┐ ┌────┴──────────┐
│ConcreteA │ │ ConcreteB │
├────────────┤ ├───────────────┤
│具体实现1 │ │ 具体实现2 │
└────────────┘ └───────────────┘
三、基础示例
3.1 场景:饮料制作流程
让我们通过一个经典的饮料制作流程来理解模板方法模式。制作茶和咖啡都有固定的步骤,但具体实现有所不同。
/**
* 抽象类:饮料制作模板
*/
public abstract class Beverage {
/**
* 模板方法:定义制作饮料的算法骨架
* 使用final防止子类修改流程
*/
public final void prepareBeverage() {
// 步骤1:烧水
boilWater();
// 步骤2:冲泡(由子类实现)
brew();
// 步骤3:倒入杯子
pourInCup();
// 步骤4:添加调料(由子类实现)
addCondiments();
// 步骤5:可选的钩子方法
if (customerWantsCondiments()) {
addExtraCondiments();
}
}
// 通用步骤:烧水
private void boilWater() {
System.out.println("烧开水...");
}
// 通用步骤:倒入杯子
private void pourInCup() {
System.out.println("倒入杯子...");
}
// 抽象方法:冲泡(子类必须实现)
protected abstract void brew();
// 抽象方法:添加调料(子类必须实现)
protected abstract void addCondiments();
// 钩子方法:是否需要调料(子类可选实现)
protected boolean customerWantsCondiments() {
return true;
}
// 钩子方法:添加额外调料(子类可选实现)
protected void addExtraCondiments() {
// 默认不做任何事
}
}
茶的具体实现:
/**
* 具体类:茶
*/
public class Tea extends Beverage {
@Override
protected void brew() {
System.out.println("用热水浸泡茶叶...");
}
@Override
protected void addCondiments() {
System.out.println("添加柠檬...");
}
@Override
protected boolean customerWantsCondiments() {
return getUserInput();
}
@Override
protected void addExtraCondiments() {
System.out.println("添加蜂蜜...");
}
private boolean getUserInput() {
// 简化示例,实际应从用户输入获取
return true;
}
}
咖啡的具体实现:
/**
* 具体类:咖啡
*/
public class Coffee extends Beverage {
@Override
protected void brew() {
System.out.println("用热水冲泡咖啡粉...");
}
@Override
protected void addCondiments() {
System.out.println("添加糖和牛奶...");
}
@Override
protected void addExtraCondiments() {
System.out.println("添加焦糖糖浆...");
}
}
客户端使用:
public class BeverageTest {
public static void main(String[] args) {
System.out.println("=== 制作茶 ===");
Beverage tea = new Tea();
tea.prepareBeverage();
System.out.println("\n=== 制作咖啡 ===");
Beverage coffee = new Coffee();
coffee.prepareBeverage();
}
}
输出结果:
=== 制作茶 ===
烧开水...
用热水浸泡茶叶...
倒入杯子...
添加柠檬...
添加蜂蜜...
=== 制作咖啡 ===
烧开水...
用热水冲泡咖啡粉...
倒入杯子...
添加糖和牛奶...
添加焦糖糖浆...
四、实际生产场景应用
4.1 场景:数据导入流程
在企业级应用中,我们经常需要从不同数据源(CSV、Excel、数据库)导入数据。虽然数据源不同,但导入流程基本一致。
/**
* 抽象类:数据导入模板
*/
public abstract class DataImporter {
/**
* 模板方法:定义数据导入的标准流程
*/
public final ImportResult importData(String source) {
ImportResult result = new ImportResult();
try {
// 1. 验证数据源
System.out.println("步骤1: 验证数据源");
validateSource(source);
// 2. 打开连接
System.out.println("步骤2: 打开数据源连接");
openConnection(source);
// 3. 读取数据
System.out.println("步骤3: 读取数据");
List<String[]> rawData = readData();
// 4. 数据校验
System.out.println("步骤4: 数据校验");
List<String[]> validData = validateData(rawData);
// 5. 数据转换
System.out.println("步骤5: 数据转换");
List<Object> transformedData = transformData(validData);
// 6. 数据持久化
System.out.println("步骤6: 数据持久化");
persistData(transformedData);
// 7. 可选:数据后处理
if (needsPostProcessing()) {
System.out.println("步骤7: 数据后处理");
postProcess(transformedData);
}
result.setSuccess(true);
result.setRecordCount(transformedData.size());
} catch (Exception e) {
result.setSuccess(false);
result.setErrorMessage(e.getMessage());
handleError(e);
} finally {
// 8. 关闭连接
System.out.println("最后步骤: 关闭连接");
closeConnection();
}
return result;
}
// 抽象方法:子类必须实现
protected abstract void validateSource(String source) throws Exception;
protected abstract void openConnection(String source) throws Exception;
protected abstract List<String[]> readData() throws Exception;
protected abstract List<Object> transformData(List<String[]> rawData);
protected abstract void persistData(List<Object> data) throws Exception;
protected abstract void closeConnection();
// 钩子方法:子类可选实现
protected boolean needsPostProcessing() {
return false;
}
protected void postProcess(List<Object> data) {
// 默认不做任何处理
}
// 通用方法:数据校验
protected List<String[]> validateData(List<String[]> rawData) {
// 通用的数据校验逻辑
return rawData.stream()
.filter(row -> row != null && row.length > 0)
.collect(java.util.stream.Collectors.toList());
}
// 通用方法:错误处理
protected void handleError(Exception e) {
System.err.println("导入失败: " + e.getMessage());
e.printStackTrace();
}
}
/**
* 导入结果类
*/
class ImportResult {
private boolean success;
private int recordCount;
private String errorMessage;
// Getters and Setters
public boolean isSuccess() { return success; }
public void setSuccess(boolean success) { this.success = success; }
public int getRecordCount() { return recordCount; }
public void setRecordCount(int recordCount) { this.recordCount = recordCount; }
public String getErrorMessage() { return errorMessage; }
public void setErrorMessage(String errorMessage) { this.errorMessage = errorMessage; }
}
CSV导入实现:
import java.io.*;
import java.util.*;
/**
* 具体类:CSV数据导入
*/
public class CsvDataImporter extends DataImporter {
private BufferedReader reader;
private List<User> users;
@Override
protected void validateSource(String source) throws Exception {
File file = new File(source);
if (!file.exists() || !file.getName().endsWith(".csv")) {
throw new IllegalArgumentException("无效的CSV文件: " + source);
}
}
@Override
protected void openConnection(String source) throws Exception {
reader = new BufferedReader(new FileReader(source));
}
@Override
protected List<String[]> readData() throws Exception {
List<String[]> data = new ArrayList<>();
String line;
// 跳过标题行
reader.readLine();
while ((line = reader.readLine()) != null) {
String[] fields = line.split(",");
data.add(fields);
}
return data;
}
@Override
protected List<Object> transformData(List<String[]> rawData) {
users = new ArrayList<>();
for (String[] row : rawData) {
if (row.length >= 3) {
User user = new User();
user.setName(row[0].trim());
user.setEmail(row[1].trim());
user.setAge(Integer.parseInt(row[2].trim()));
users.add(user);
}
}
return new ArrayList<>(users);
}
@Override
protected void persistData(List<Object> data) throws Exception {
// 模拟保存到数据库
System.out.println("保存 " + data.size() + " 条记录到数据库");
}
@Override
protected void closeConnection() {
try {
if (reader != null) {
reader.close();
}
} catch (IOException e) {
e.printStackTrace();
}
}
@Override
protected boolean needsPostProcessing() {
return true;
}
@Override
protected void postProcess(List<Object> data) {
// 发送导入完成通知
System.out.println("发送邮件通知:成功导入 " + data.size() + " 个用户");
}
}
/**
* 用户实体类
*/
class User {
private String name;
private String email;
private int age;
// Getters and Setters
public String getName() { return name; }
public void setName(String name) { this.name = name; }
public String getEmail() { return email; }
public void setEmail(String email) { this.email = email; }
public int getAge() { return age; }
public void setAge(int age) { this.age = age; }
}
数据库导入实现:
import java.sql.*;
import java.util.*;
/**
* 具体类:数据库数据导入
*/
public class DatabaseImporter extends DataImporter {
private Connection connection;
private String tableName;
public DatabaseImporter(String tableName) {
this.tableName = tableName;
}
@Override
protected void validateSource(String source) throws Exception {
// 验证数据库连接字符串
if (source == null || !source.startsWith("jdbc:")) {
throw new IllegalArgumentException("无效的数据库连接: " + source);
}
}
@Override
protected void openConnection(String source) throws Exception {
connection = DriverManager.getConnection(source);
}
@Override
protected List<String[]> readData() throws Exception {
List<String[]> data = new ArrayList<>();
String sql = "SELECT name, email, age FROM " + tableName;
try (Statement stmt = connection.createStatement();
ResultSet rs = stmt.executeQuery(sql)) {
while (rs.next()) {
String[] row = new String[3];
row[0] = rs.getString("name");
row[1] = rs.getString("email");
row[2] = String.valueOf(rs.getInt("age"));
data.add(row);
}
}
return data;
}
@Override
protected List<Object> transformData(List<String[]> rawData) {
List<Object> users = new ArrayList<>();
for (String[] row : rawData) {
User user = new User();
user.setName(row[0]);
user.setEmail(row[1]);
user.setAge(Integer.parseInt(row[2]));
users.add(user);
}
return users;
}
@Override
protected void persistData(List<Object> data) throws Exception {
// 批量插入到目标表
String sql = "INSERT INTO target_table (name, email, age) VALUES (?, ?, ?)";
try (PreparedStatement pstmt = connection.prepareStatement(sql)) {
for (Object obj : data) {
User user = (User) obj;
pstmt.setString(1, user.getName());
pstmt.setString(2, user.getEmail());
pstmt.setInt(3, user.getAge());
pstmt.addBatch();
}
pstmt.executeBatch();
}
}
@Override
protected void closeConnection() {
try {
if (connection != null && !connection.isClosed()) {
connection.close();
}
} catch (SQLException e) {
e.printStackTrace();
}
}
}
五、开源框架中的应用
5.1 Apache Commons Lang - 抽象枚举比较器
Apache Commons Lang中的ObjectUtils类使用了模板方法模式的思想。让我们看一个典型的例子:
import org.apache.commons.lang3.ObjectUtils;
/**
* 模拟Apache Commons Lang中的toString方法模板
*/
public abstract class AbstractToStringStyle {
/**
* 模板方法:生成对象的字符串表示
*/
public final String toString(Object object) {
if (object == null) {
return nullText();
}
StringBuilder buffer = new StringBuilder();
// 添加前缀
appendPrefix(buffer, object);
// 添加字段内容
appendFields(buffer, object);
// 添加后缀
appendSuffix(buffer, object);
return buffer.toString();
}
// 钩子方法:null值的文本表示
protected String nullText() {
return "null";
}
// 抽象方法:添加前缀
protected abstract void appendPrefix(StringBuilder buffer, Object object);
// 抽象方法:添加字段
protected abstract void appendFields(StringBuilder buffer, Object object);
// 抽象方法:添加后缀
protected abstract void appendSuffix(StringBuilder buffer, Object object);
}
/**
* JSON风格的toString实现
*/
class JsonToStringStyle extends AbstractToStringStyle {
@Override
protected void appendPrefix(StringBuilder buffer, Object object) {
buffer.append("{ \"class\": \"")
.append(object.getClass().getSimpleName())
.append("\", ");
}
@Override
protected void appendFields(StringBuilder buffer, Object object) {
// 简化示例,实际需要使用反射获取字段
buffer.append("\"fields\": {...}");
}
@Override
protected void appendSuffix(StringBuilder buffer, Object object) {
buffer.append(" }");
}
}
/**
* XML风格的toString实现
*/
class XmlToStringStyle extends AbstractToStringStyle {
@Override
protected void appendPrefix(StringBuilder buffer, Object object) {
buffer.append("<")
.append(object.getClass().getSimpleName())
.append(">");
}
@Override
protected void appendFields(StringBuilder buffer, Object object) {
buffer.append("<fields>...</fields>");
}
@Override
protected void appendSuffix(StringBuilder buffer, Object object) {
buffer.append("</")
.append(object.getClass().getSimpleName())
.append(">");
}
}
5.2 JDK中的应用
AbstractList中的迭代器:
import java.util.*;
/**
* 模拟JDK中AbstractList的模板方法实现
*/
public abstract class AbstractCustomList<E> {
/**
* 模板方法:创建迭代器
*/
public Iterator<E> iterator() {
return new Itr();
}
// 抽象方法:子类必须实现
public abstract E get(int index);
public abstract int size();
/**
* 内部迭代器类 - 使用外部类的模板方法
*/
private class Itr implements Iterator<E> {
int cursor = 0;
@Override
public boolean hasNext() {
return cursor != size();
}
@Override
public E next() {
if (cursor >= size()) {
throw new NoSuchElementException();
}
return get(cursor++);
}
}
}
/**
* 具体实现:简单的ArrayList
*/
class SimpleArrayList<E> extends AbstractCustomList<E> {
private Object[] elements;
private int size;
public SimpleArrayList(int capacity) {
elements = new Object[capacity];
size = 0;
}
public void add(E element) {
elements[size++] = element;
}
@Override
@SuppressWarnings("unchecked")
public E get(int index) {
if (index >= size) {
throw new IndexOutOfBoundsException();
}
return (E) elements[index];
}
@Override
public int size() {
return size;
}
}
5.3 Servlet中的应用
import javax.servlet.http.*;
import java.io.IOException;
/**
* HttpServlet本身就是一个典型的模板方法模式应用
* service()方法是模板方法,doGet()、doPost()等是具体方法
*/
public abstract class AbstractHttpHandler extends HttpServlet {
/**
* 模板方法:处理HTTP请求的标准流程
*/
@Override
protected final void service(HttpServletRequest req, HttpServletResponse resp)
throws IOException {
// 1. 请求前置处理
if (!preHandle(req, resp)) {
return;
}
// 2. 根据请求方法分发
String method = req.getMethod();
try {
if ("GET".equals(method)) {
doGet(req, resp);
} else if ("POST".equals(method)) {
doPost(req, resp);
}
// 3. 请求后置处理
postHandle(req, resp);
} catch (Exception e) {
// 4. 异常处理
handleException(req, resp, e);
} finally {
// 5. 完成处理
afterCompletion(req, resp);
}
}
// 钩子方法:前置处理
protected boolean preHandle(HttpServletRequest req, HttpServletResponse resp) {
// 可以用于权限检查、日志记录等
return true;
}
// 钩子方法:后置处理
protected void postHandle(HttpServletRequest req, HttpServletResponse resp) {
// 可以用于添加通用响应头等
}
// 钩子方法:异常处理
protected void handleException(HttpServletRequest req, HttpServletResponse resp, Exception e) {
e.printStackTrace();
}
// 钩子方法:完成处理
protected void afterCompletion(HttpServletRequest req, HttpServletResponse resp) {
// 清理资源
}
}
六、Spring框架中的应用
6.1 JdbcTemplate
Spring的JdbcTemplate是模板方法模式的经典应用:
import javax.sql.DataSource;
import java.sql.*;
import java.util.*;
/**
* 简化版的JdbcTemplate实现
*/
public abstract class SimpleJdbcTemplate {
private DataSource dataSource;
public SimpleJdbcTemplate(DataSource dataSource) {
this.dataSource = dataSource;
}
/**
* 模板方法:执行查询操作
*/
public <T> List<T> query(String sql, Object[] args, RowMapper<T> rowMapper) {
Connection conn = null;
PreparedStatement pstmt = null;
ResultSet rs = null;
try {
// 1. 获取连接
conn = getConnection();
// 2. 创建语句
pstmt = createPreparedStatement(conn, sql);
// 3. 设置参数
setParameters(pstmt, args);
// 4. 执行查询
rs = executeQuery(pstmt);
// 5. 处理结果集(由调用者提供)
return extractData(rs, rowMapper);
} catch (SQLException e) {
// 6. 异常处理
handleException(e);
return Collections.emptyList();
} finally {
// 7. 关闭资源
closeResources(rs, pstmt, conn);
}
}
// 模板方法内部的固定步骤
protected Connection getConnection() throws SQLException {
return dataSource.getConnection();
}
protected PreparedStatement createPreparedStatement(Connection conn, String sql)
throws SQLException {
return conn.prepareStatement(sql);
}
protected void setParameters(PreparedStatement pstmt, Object[] args)
throws SQLException {
if (args != null) {
for (int i = 0; i < args.length; i++) {
pstmt.setObject(i + 1, args[i]);
}
}
}
protected ResultSet executeQuery(PreparedStatement pstmt) throws SQLException {
return pstmt.executeQuery();
}
// 可变部分:由调用者通过RowMapper接口提供
protected <T> List<T> extractData(ResultSet rs, RowMapper<T> rowMapper)
throws SQLException {
List<T> results = new ArrayList<>();
int rowNum = 0;
while (rs.next()) {
results.add(rowMapper.mapRow(rs, rowNum++));
}
return results;
}
protected void handleException(SQLException e) {
System.err.println("数据库操作异常: " + e.getMessage());
e.printStackTrace();
}
protected void closeResources(ResultSet rs, PreparedStatement pstmt, Connection conn) {
try {
if (rs != null) rs.close();
if (pstmt != null) pstmt.close();
if (conn != null) conn.close();
} catch (SQLException e) {
e.printStackTrace();
}
}
}
/**
* 行映射器接口
*/
interface RowMapper<T> {
T mapRow(ResultSet rs, int rowNum) throws SQLException;
}
/**
* 使用示例
*/
class JdbcTemplateExample {
public void queryUsers(SimpleJdbcTemplate jdbcTemplate) {
String sql = "SELECT id, name, email FROM users WHERE age > ?";
Object[] args = {18};
List<User> users = jdbcTemplate.query(sql, args, new RowMapper<User>() {
@Override
public User mapRow(ResultSet rs, int rowNum) throws SQLException {
User user = new User();
user.setName(rs.getString("name"));
user.setEmail(rs.getString("email"));
return user;
}
});
users.forEach(user -> System.out.println(user.getName()));
}
}
七、模板方法模式的优缺点
7.1 优点
1. 代码复用
┌─────────────────────────────────────┐
│ 父类封装公共逻辑 │
│ ┌─────────────────────────────┐ │
│ │ Step 1: Common Logic │ │
│ │ Step 2: Variant Logic ───┐ │ │
│ │ Step 3: Common Logic │ │ │
│ │ Step 4: Variant Logic ───┤ │ │
│ └──────────────────────────┼──┘ │
└─────────────────────────────┼──────┘
│
┌─────────────────┴─────────────────┐
▼ ▼
┌───────────────┐ ┌───────────────┐
│子类A只实现 │ │子类B只实现 │
│变化的部分 │ │变化的部分 │
└───────────────┘ └───────────────┘
2. 符合开闭原则
- 对扩展开放:通过继承扩展新的实现
- 对修改封闭:不需要修改父类代码
3. 控制反转(IoC)
- 父类调用子类的方法,而不是相反
- 这是好莱坞原则的体现:"Don't call us, we'll call you"
4. 行为可控
- 父类控制整体流程
- 子类只能影响特定步骤
7.2 缺点
1. 类的数量增加
- 每个具体实现都需要一个子类
- 可能导致类层次结构复杂
2. 继承的固有缺点
- 子类与父类耦合度高
- 不够灵活,违反组合优于继承的原则
3. 阅读困难
- 算法流程分散在父类和子类中
- 需要同时查看多个类才能理解完整逻辑
八、最佳实践
8.1 使用final修饰模板方法
public abstract class BaseProcessor {
// 使用final防止子类重写模板方法
public final void process() {
step1();
step2();
step3();
}
protected abstract void step2();
private void step1() { /* ... */ }
private void step3() { /* ... */ }
}
8.2 合理使用访问修饰符
public abstract class SecureTemplate {
// public: 模板方法对外公开
public final void execute() {
init();
doExecute();
cleanup();
}
// protected: 抽象方法供子类实现
protected abstract void doExecute();
// private: 内部步骤不暴露给子类
private void init() { /* ... */ }
private void cleanup() { /* ... */ }
}
8.3 提供合理的默认实现
public abstract class FlexibleTemplate {
public final void run() {
preProcess();
process();
postProcess();
}
// 抽象方法:核心逻辑必须实现
protected abstract void process();
// 钩子方法:提供默认实现
protected void preProcess() {
// 默认什么都不做,子类可选择覆盖
}
protected void postProcess() {
// 默认什么都不做,子类可选择覆盖
}
}
8.4 使用策略模式组合优化
当变化点较多时,可以结合策略模式:
/**
* 结合策略模式的模板方法
*/
public abstract class FlexibleDataProcessor {
private ValidationStrategy validationStrategy;
private TransformStrategy transformStrategy;
public FlexibleDataProcessor(ValidationStrategy validationStrategy,
TransformStrategy transformStrategy) {
this.validationStrategy = validationStrategy;
this.transformStrategy = transformStrategy;
}
// 模板方法
public final void processData(List<String> rawData) {
// 使用策略对象处理变化点
List<String> validData = validationStrategy.validate(rawData);
List<Object> transformedData = transformStrategy.transform(validData);
// 固定流程
persistData(transformedData);
}
// 固定步骤
protected abstract void persistData(List<Object> data);
}
interface ValidationStrategy {
List<String> validate(List<String> data);
}
interface TransformStrategy {
List<Object> transform(List<String> data);
}
8.5 文档化模板流程
/**
* 订单处理模板
*
* 处理流程:
* 1. 验证订单 {@link #validateOrder(Order)}
* 2. 计算价格 {@link #calculatePrice(Order)}
* 3. 库存检查 {@link #checkInventory(Order)}
* 4. 创建订单 {@link #createOrder(Order)}
* 5. 发送通知 {@link #sendNotification(Order)} (可选)
*
* 子类必须实现:
* - calculatePrice(): 不同类型订单的价格计算逻辑
* - createOrder(): 不同订单的创建逻辑
*
* 子类可选实现:
* - sendNotification(): 自定义通知方式
*/
public abstract class OrderProcessor {
/**
* 模板方法:处理订单
* @param order 待处理的订单
* @return 处理结果
*/
public final OrderResult processOrder(Order order) {
validateOrder(order);
calculatePrice(order);
checkInventory(order);
createOrder(order);
if (shouldSendNotification()) {
sendNotification(order);
}
return new OrderResult(true, order.getOrderId());
}
// 方法文档...
protected abstract void calculatePrice(Order order);
protected abstract void createOrder(Order order);
// 其他方法...
private void validateOrder(Order order) { /* ... */ }
private void checkInventory(Order order) { /* ... */ }
protected boolean shouldSendNotification() { return true; }
protected void sendNotification(Order order) { /* ... */ }
}
class Order {
private String orderId;
public String getOrderId() { return orderId; }
}
class OrderResult {
private boolean success;
private String orderId;
public OrderResult(boolean success, String orderId) {
this.success = success;
this.orderId = orderId;
}
}
九、模板方法 vs 策略模式
两者都处理算法变化,但侧重点不同:
模板方法模式(继承) 策略模式(组合)
┌─────────────────┐ ┌──────────────┐
│ AbstractClass │ │ Context │
├─────────────────┤ ├──────────────┤
│ templateMethod()│ │ - strategy │◇───┐
│ primitiveOp() │ │ + execute() │ │
└────────△────────┘ └──────────────┘ │
│ │
┌────┴────┐ │
│ │ ┌───────▽────────┐
┌───┴───┐ ┌──┴────┐ │ <<interface>>│
│ConcreteA ConcreteB│ │ Strategy │
└───────┘ └───────┘ ├────────────────┤
│ + algorithm() │
└───────△────────┘
│
┌───────┴────────┐
│ │
┌─────┴────┐ ┌─────┴────┐
│StrategyA │ │StrategyB │
└──────────┘ └──────────┘
使用场景: 使用场景:
- 算法骨架固定 - 算法可完全替换
- 变化点在特定步骤 - 运行时切换算法
- 通过继承扩展 - 通过组合扩展
十、实战演练:构建通用报表生成器
让我们通过一个完整的案例来综合运用模板方法模式:
import java.util.*;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
/**
* 抽象类:报表生成器模板
*/
public abstract class ReportGenerator {
/**
* 模板方法:生成报表的标准流程
*/
public final Report generateReport(ReportRequest request) {
System.out.println("开始生成报表: " + request.getReportName());
// 1. 验证请求
validateRequest(request);
// 2. 查询数据
List<Map<String, Object>> rawData = queryData(request);
// 3. 数据聚合(可选)
if (needsAggregation()) {
rawData = aggregateData(rawData);
}
// 4. 格式化数据
String formattedContent = formatData(rawData, request);
// 5. 生成报表对象
Report report = createReport(request, formattedContent);
// 6. 添加元数据
enrichMetadata(report);
// 7. 后处理(可选)
if (needsPostProcessing()) {
postProcess(report);
}
System.out.println("报表生成完成: " + report.getReportId());
return report;
}
// ===== 抽象方法:子类必须实现 =====
/**
* 查询数据
*/
protected abstract List<Map<String, Object>> queryData(ReportRequest request);
/**
* 格式化数据
*/
protected abstract String formatData(List<Map<String, Object>> data,
ReportRequest request);
// ===== 钩子方法:子类可选实现 =====
protected boolean needsAggregation() {
return false;
}
protected List<Map<String, Object>> aggregateData(List<Map<String, Object>> data) {
return data;
}
protected boolean needsPostProcessing() {
return false;
}
protected void postProcess(Report report) {
// 默认不做处理
}
// ===== 通用方法:所有子类共享 =====
private void validateRequest(ReportRequest request) {
if (request == null || request.getReportName() == null) {
throw new IllegalArgumentException("无效的报表请求");
}
}
private Report createReport(ReportRequest request, String content) {
Report report = new Report();
report.setReportId(generateReportId());
report.setReportName(request.getReportName());
report.setContent(content);
report.setCreateTime(LocalDateTime.now());
return report;
}
private void enrichMetadata(Report report) {
report.addMetadata("generator", this.getClass().getSimpleName());
report.addMetadata("version", "1.0");
}
private String generateReportId() {
return "RPT-" + System.currentTimeMillis();
}
}
/**
* 具体类:CSV格式报表生成器
*/
class CsvReportGenerator extends ReportGenerator {
@Override
protected List<Map<String, Object>> queryData(ReportRequest request) {
// 模拟从数据库查询数据
List<Map<String, Object>> data = new ArrayList<>();
Map<String, Object> row1 = new HashMap<>();
row1.put("name", "张三");
row1.put("department", "技术部");
row1.put("salary", 10000);
data.add(row1);
Map<String, Object> row2 = new HashMap<>();
row2.put("name", "李四");
row2.put("department", "销售部");
row2.put("salary", 8000);
data.add(row2);
return data;
}
@Override
protected String formatData(List<Map<String, Object>> data, ReportRequest request) {
StringBuilder csv = new StringBuilder();
if (!data.isEmpty()) {
// 添加表头
Map<String, Object> firstRow = data.get(0);
csv.append(String.join(",", firstRow.keySet())).append("\n");
// 添加数据行
for (Map<String, Object> row : data) {
List<String> values = new ArrayList<>();
for (Object value : row.values()) {
values.add(String.valueOf(value));
}
csv.append(String.join(",", values)).append("\n");
}
}
return csv.toString();
}
}
/**
* 具体类:HTML格式报表生成器
*/
class HtmlReportGenerator extends ReportGenerator {
@Override
protected List<Map<String, Object>> queryData(ReportRequest request) {
// 复用相同的查询逻辑
List<Map<String, Object>> data = new ArrayList<>();
Map<String, Object> row1 = new HashMap<>();
row1.put("name", "张三");
row1.put("department", "技术部");
row1.put("salary", 10000);
data.add(row1);
Map<String, Object> row2 = new HashMap<>();
row2.put("name", "李四");
row2.put("department", "销售部");
row2.put("salary", 8000);
data.add(row2);
return data;
}
@Override
protected String formatData(List<Map<String, Object>> data, ReportRequest request) {
StringBuilder html = new StringBuilder();
html.append("<!DOCTYPE html>\n<html>\n<head>\n");
html.append("<title>").append(request.getReportName()).append("</title>\n");
html.append("<style>table{border-collapse:collapse;} th,td{border:1px solid black;padding:8px;}</style>\n");
html.append("</head>\n<body>\n");
html.append("<h1>").append(request.getReportName()).append("</h1>\n");
html.append("<table>\n");
if (!data.isEmpty()) {
// 表头
html.append("<tr>");
for (String key : data.get(0).keySet()) {
html.append("<th>").append(key).append("</th>");
}
html.append("</tr>\n");
// 数据行
for (Map<String, Object> row : data) {
html.append("<tr>");
for (Object value : row.values()) {
html.append("<td>").append(value).append("</td>");
}
html.append("</tr>\n");
}
}
html.append("</table>\n</body>\n</html>");
return html.toString();
}
@Override
protected boolean needsPostProcessing() {
return true;
}
@Override
protected void postProcess(Report report) {
// HTML报表添加生成时间戳
String timestamp = LocalDateTime.now()
.format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"));
report.addMetadata("generated_at", timestamp);
}
}
/**
* 具体类:汇总报表生成器
*/
class SummaryReportGenerator extends ReportGenerator {
@Override
protected List<Map<String, Object>> queryData(ReportRequest request) {
List<Map<String, Object>> data = new ArrayList<>();
for (int i = 1; i <= 5; i++) {
Map<String, Object> row = new HashMap<>();
row.put("department", "部门" + i);
row.put("salary", 5000 + i * 1000);
data.add(row);
}
return data;
}
@Override
protected boolean needsAggregation() {
return true;
}
@Override
protected List<Map<String, Object>> aggregateData(List<Map<String, Object>> data) {
// 计算总和
int totalSalary = 0;
for (Map<String, Object> row : data) {
totalSalary += (Integer) row.get("salary");
}
// 添加汇总行
Map<String, Object> summaryRow = new HashMap<>();
summaryRow.put("department", "总计");
summaryRow.put("salary", totalSalary);
data.add(summaryRow);
return data;
}
@Override
protected String formatData(List<Map<String, Object>> data, ReportRequest request) {
StringBuilder sb = new StringBuilder();
sb.append("=== ").append(request.getReportName()).append(" ===\n\n");
for (Map<String, Object> row : data) {
sb.append(row.get("department")).append(": ")
.append(row.get("salary")).append("\n");
}
return sb.toString();
}
}
/**
* 报表请求类
*/
class ReportRequest {
private String reportName;
private Map<String, Object> parameters;
public ReportRequest(String reportName) {
this.reportName = reportName;
this.parameters = new HashMap<>();
}
public String getReportName() { return reportName; }
public Map<String, Object> getParameters() { return parameters; }
}
/**
* 报表类
*/
class Report {
private String reportId;
private String reportName;
private String content;
private LocalDateTime createTime;
private Map<String, Object> metadata = new HashMap<>();
// Getters and Setters
public String getReportId() { return reportId; }
public void setReportId(String reportId) { this.reportId = reportId; }
public String getReportName() { return reportName; }
public void setReportName(String reportName) { this.reportName = reportName; }
public String getContent() { return content; }
public void setContent(String content) { this.content = content; }
public LocalDateTime getCreateTime() { return createTime; }
public void setCreateTime(LocalDateTime createTime) { this.createTime = createTime; }
public void addMetadata(String key, Object value) { metadata.put(key, value); }
}
/**
* 测试类
*/
class ReportGeneratorTest {
public static void main(String[] args) {
// 1. 生成CSV报表
System.out.println("=== 测试CSV报表 ===");
ReportGenerator csvGenerator = new CsvReportGenerator();
Report csvReport = csvGenerator.generateReport(new ReportRequest("员工工资表-CSV"));
System.out.println(csvReport.getContent());
// 2. 生成HTML报表
System.out.println("\n=== 测试HTML报表 ===");
ReportGenerator htmlGenerator = new HtmlReportGenerator();
Report htmlReport = htmlGenerator.generateReport(new ReportRequest("员工工资表-HTML"));
System.out.println(htmlReport.getContent().substring(0, 200) + "...");
// 3. 生成汇总报表
System.out.println("\n=== 测试汇总报表 ===");
ReportGenerator summaryGenerator = new SummaryReportGenerator();
Report summaryReport = summaryGenerator.generateReport(new ReportRequest("部门工资汇总"));
System.out.println(summaryReport.getContent());
}
}
十一、总结
11.1 核心要点
- 模板方法模式的本质:在父类中定义算法骨架,子类实现具体步骤
- 适用场景:多个类有相似的操作流程,但某些步骤实现不同
- 关键技巧:
- 使用
final修饰模板方法 - 合理划分抽象方法和钩子方法
- 提供合适的默认实现
- 使用
11.2 实践建议
选择模板方法模式的检查清单:
✓ 是否有多个类执行相似的流程?
✓ 流程的主体结构是否稳定?
✓ 是否只有部分步骤需要定制?
✓ 是否需要严格控制算法结构?
如果以上都是"是",那么模板方法模式是个好选择!
11.3 与其他模式的配合
- 策略模式:处理算法变化点
- 工厂方法模式:创建模板方法中使用的对象
- 观察者模式:在模板方法的各个步骤发送通知
11.4 现代Java的替代方案
在Java 8+中,可以使用函数式接口替代某些场景的模板方法模式:
public class FunctionalTemplate {
public void execute(Runnable preProcess,
Runnable mainProcess,
Runnable postProcess) {
preProcess.run();
mainProcess.run();
postProcess.run();
}
}
// 使用
FunctionalTemplate template = new FunctionalTemplate();
template.execute(
() -> System.out.println("前置处理"),
() -> System.out.println("主要处理"),
() -> System.out.println("后置处理")
);