前言
在23种设计模式中有11种行为型模式。行为型模式对类和对象如何交互和怎样分配职责进行描述。其中,策略模式描述的是,将算法封装在对象中,进而可以方便的指定和改变一个对象所使用的算法。本文以假定的需求为基础,介绍下策略模式的应用,并给出代码示例。
需求假设
在之前的文章中,我们假定了公司的产品为客户提供了体验版、基础版和高级版三种套餐,套餐的创建使用了工厂模式。现在我们将需求聚焦在客户购买套餐后,支付价格的计算上。由于不同的套餐具有不同的价格,同时为了促进产品售卖,公司还提供了不同的优惠手段,如打折、满减、无门槛优惠卷等。因此,客户最终的支付价格面临着不同计算规则。
不同的优惠手段就对应着不同的价格计算规则(或称计算策略、计算算法),我们需要利用策略模式来完成价格计算的设计。
模式定义
策略模式就是定义一系列算法,分别封装起来,让他们之间可以互相替换,此模式让算法的变化不会影响到使用算法的客户。
- "一系列算法": 从概念上看所有这些算法完成的都是相同的工作,只是实现不同。例如,打折和满减是不同的算法,但各自的目的是一致的,就是计算出最终的价格。
- "分别封装"、"互相替换":有多少种具体策略,就要封装多少个策略类,这些策略类之间可以相互替换,从而改变算法。例如,针对打折、满减、无门槛卷三种优惠方式,我们就需要分别封装为三个不同的策略类。
- "不会影响到使用算法的客户": 通过实现一个策略上下文类(
Context)来封装当前使用的具体策略和策略方法的调用,同时Strategy类为Context定义了一系列可供重用的算法或行为。客户端仅通过Context类来使用策略,保证了以相同的方式调用所有的算法。如此以来,当各策略算法发生变化时也就不会影响到算法的使用。
模式构成
根据UML结构图,策略模式包含:
- 一个抽象策略类或策略接口:用于约定各具体策略应实现的具体算法。也就是规定这些算法应完成的共同工作是啥。例如**价格计算策略(
IPriceCalculationStrategy)**规定了各种促销手段都需要完成 计算价格calculatePrice()这一行为。 - 多个具体策略类:分别实现策略接口约定的方法。具体策略如,保持原价策略(
KeepNormalStrategy)、打折策略(DiscountStrategy)、满减策略(FullReductionStrategy)、无门槛策略(DirectReductionStrategy)。 - 一个上下文类(
Context): 用以配置一个具体策略,维护一个对策略对象的引用。通过Context类来持有当前所使用的具体策略对象。
代码示例
C++
/*
* File: strategy.hpp
* Created Date: 2021-12-12 05:08:43
* Author: ysj
* Description: 策略类
*/
#pragma once
#include <iostream>
using namespace std;
// 价格计算抽象策略类
class PriceCalculationStrategy
{
public:
virtual float calculatePrice(float originPrice) = 0;
};
// 原价策略
class KeepNormalStrategy : public PriceCalculationStrategy
{
public:
virtual float calculatePrice(float originPrice)
{
return originPrice;
}
};
// 打折策略
class DiscountStrategy : public PriceCalculationStrategy
{
private:
float discount;
public:
DiscountStrategy(float discount)
{
this->discount = discount;
}
virtual float calculatePrice(float originPrice)
{
return originPrice * discount;
}
};
// 满减策略
class FullReductionStrategy : public PriceCalculationStrategy
{
private:
float full;
float reduction;
public:
FullReductionStrategy(float full, float reduction)
{
this->full = full;
this->reduction = reduction;
}
virtual float calculatePrice(float originPrice)
{
float totalPrice = originPrice;
if (originPrice >= full)
{
totalPrice = originPrice - reduction;
}
return totalPrice;
}
};
// 无门槛策略
class DirectReductionStrategy : public PriceCalculationStrategy
{
private:
float reduction;
public:
DirectReductionStrategy(float reduction)
{
this->reduction = reduction;
}
virtual float calculatePrice(float originPrice)
{
float totalPrice = originPrice - reduction;
if (totalPrice < 0)
{
return 0;
}
return totalPrice;
}
};
/*
* File: context.hpp
* Created Date: 2021-12-12 05:08:58
* Author: ysj
* Description: 策略上下文类
*/
#pragma once
#include <iostream>
#include "strategy.hpp"
using namespace std;
class PriceCalculationContext
{
private:
PriceCalculationStrategy *strategy;
public:
PriceCalculationContext(PriceCalculationStrategy *strategy)
{
this->strategy = strategy;
}
float calculatePrice(float originPrice)
{
float totalPrice = strategy->calculatePrice(originPrice);
return totalPrice;
}
};
/*
* File: main.cpp
* Created Date: 2021-12-12 05:06:54
* Author: ysj
* Description: cpp策略模式
*/
#include <iostream>
#include "strategy.hpp"
#include "context.hpp"
using namespace std;
int main()
{
// 购买价格、数量
float price = 599;
float quantity = 1;
float originPrice = price * quantity;
cout << "单价:" << price << " 数量:" << quantity << " 原价:" << originPrice << endl;
// 正常计算策略
PriceCalculationStrategy *keepNormal = new KeepNormalStrategy();
PriceCalculationContext *ctx = new PriceCalculationContext(keepNormal);
float totalPrice = ctx->calculatePrice(originPrice);
cout << "正常价格:" << totalPrice << endl;
// 打折计算策略
PriceCalculationStrategy *discount = new DiscountStrategy(0.8);
ctx = new PriceCalculationContext(discount);
totalPrice = ctx->calculatePrice(originPrice);
cout << "八折价格:" << totalPrice << endl;
// 满减策略
PriceCalculationStrategy *fullReduction = new FullReductionStrategy(500, 200);
ctx = new PriceCalculationContext(fullReduction);
totalPrice = ctx->calculatePrice(originPrice);
cout << "满500减200价格:" << totalPrice << endl;
// 无门槛策略
PriceCalculationStrategy *directReduction = new DirectReductionStrategy(100);
ctx = new PriceCalculationContext(directReduction);
totalPrice = ctx->calculatePrice(originPrice);
cout << "直减100价格:" << totalPrice << endl;
return 0;
}
$ g++ main.cpp -I include -o main && ./main
单价:599 数量:1 原价:599
正常价格:599
八折价格:479.2
满500减200价格:399
直减100价格:499
Golang
/*
* File: strategy.go
* Created Date: 2021-12-12 04:25:14
* Author: ysj
* Description: 策略
*/
package main
// 策略接口
type IPriceCalculationStrategy interface {
calculatePrice(originPrice float64) float64
}
// 原价策略
type KeepNormalStrategy struct{}
func NewKeepNormalStrategy() IPriceCalculationStrategy {
return &KeepNormalStrategy{}
}
func (k *KeepNormalStrategy) calculatePrice(originPrice float64) float64 {
return originPrice
}
// 打折策略
type DiscountStrategy struct {
discount float64
}
func NewDiscountStrategy(discount float64) IPriceCalculationStrategy {
return &DiscountStrategy{
discount: discount,
}
}
func (d *DiscountStrategy) calculatePrice(originPrice float64) float64 {
return originPrice * d.discount
}
// 满减策略
type FullReductionStrategy struct {
full float64
reduction float64
}
func NewFullReductionStrategy(full, reduction float64) IPriceCalculationStrategy {
return &FullReductionStrategy{
full: full,
reduction: reduction,
}
}
func (f *FullReductionStrategy) calculatePrice(originPrice float64) float64 {
totalPrice := originPrice
if originPrice >= f.full {
totalPrice = originPrice - f.reduction
}
return totalPrice
}
// 无门槛策略
type DirectReductionStrategy struct {
reduction float64
}
func NewDirectReductionStrategy(reduction float64) IPriceCalculationStrategy {
return &DirectReductionStrategy{
reduction: reduction,
}
}
func (d *DirectReductionStrategy) calculatePrice(originPrice float64) float64 {
totalPrice := originPrice - d.reduction
if totalPrice < 0 {
return 0
}
return totalPrice
}
/*
* File: context.go
* Created Date: 2021-12-12 04:25:39
* Author: ysj
* Description: 策略上下文
*/
package main
type PriceCalculationContext struct {
strategy IPriceCalculationStrategy
}
func NewPriceCalculationContext(strategy IPriceCalculationStrategy) *PriceCalculationContext {
return &PriceCalculationContext{
strategy: strategy,
}
}
func (p *PriceCalculationContext) calculatePrice(originPrice float64) float64 {
return p.strategy.calculatePrice(originPrice)
}
/*
* File: main.go
* Created Date: 2021-12-12 04:26:03
* Author: ysj
* Description: golang策略模式
*/
package main
import "fmt"
func main() {
// 购买价格、数量
price := 599.0
quantity := 1.0
originPrice := price * quantity
fmt.Printf("单价:%.2f 数量:%.2f 原价:%.2f\n", price, quantity, originPrice)
// 正常计算策略
keepNormal := NewKeepNormalStrategy()
ctx := NewPriceCalculationContext(keepNormal)
totalPrice := ctx.calculatePrice(originPrice)
fmt.Printf("正常价格:%.2f\n", totalPrice)
// 打折计算策略
discount := NewDiscountStrategy(0.8)
ctx = NewPriceCalculationContext(discount)
totalPrice = ctx.calculatePrice(originPrice)
fmt.Printf("八折价格:%.2f\n", totalPrice)
// 满减策略
fullReduction := NewFullReductionStrategy(500, 200)
ctx = NewPriceCalculationContext(fullReduction)
totalPrice = ctx.calculatePrice(originPrice)
fmt.Printf("满500减200价格:%.2f\n", totalPrice)
// 无门槛策略
directReduction := NewDirectReductionStrategy(100)
ctx = NewPriceCalculationContext(directReduction)
totalPrice = ctx.calculatePrice(originPrice)
fmt.Printf("直减100价格:%.2f\n", totalPrice)
}
$ go run .
单价:599.00 数量:1.00 原价:599.00
正常价格:599.00
八折价格:479.20
满500减200价格:399.00
直减100价格:499.00
Python
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: strategy.py
# Created Date: 2021-12-07 02:21:10
# Author: ysj
# Description: 策略类
###
from abc import ABCMeta, abstractmethod
class PriceCalculationStrategy(metaclass=ABCMeta):
"""抽象基类"""
@abstractmethod
def calculate_price(self, origin_price):
pass
class KeepNormalStrategy(PriceCalculationStrategy):
"""原价策略"""
def __init__(self):
super().__init__()
def calculate_price(self, origin_price):
total_price = round(origin_price, 2)
return total_price
class DiscountStrategy(PriceCalculationStrategy):
"""打折策略"""
def __init__(self, discount):
self.__discount = discount
def calculate_price(self, origin_price):
total_price = origin_price * self.__discount
total_price = round(total_price, 2)
return total_price
class FullReductionStrategy(PriceCalculationStrategy):
"""满减策略"""
def __init__(self, full, reduction):
self.__full = full
self.__reduction = reduction
def calculate_price(self, origin_price):
if origin_price >= self.__full:
origin_price -= self.__reduction
total_price = round(origin_price, 2)
return total_price
class DirectReductionStrategy(PriceCalculationStrategy):
"""无门槛策略"""
def __init__(self, reduction):
self.__reduction = reduction
def calculate_price(self, origin_price):
total_price = max(0, origin_price-self.__reduction)
total_price = round(total_price, 2)
return total_price
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: context.py
# Created Date: 2021-12-07 03:03:37
# Author: ysj
# Description: 策略上下文
###
from strategy import PriceCalculationStrategy
class PriceCalculationContext(object):
"""策略上下文-由客户端判断选择策略"""
def __init__(self, strategy: PriceCalculationStrategy):
self.__strategy = strategy
def calculate_price(self, origin_price):
total_price = self.__strategy.calculate_price(origin_price)
return total_price
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: main.py
# Created Date: 2021-12-07 03:14:12
# Author: ysj
# Description: python 策略模式
###
from strategy import (
KeepNormalStrategy, DiscountStrategy,
FullReductionStrategy, DirectReductionStrategy,
)
from context1 import PriceCalculationContext
# 购买价格、数量
price = 599
quantity = 1
origin_price = price * quantity
print(f"单价:{price} 数量:{quantity} 原价:{origin_price}")
# 正常计算策略
keep_normal = KeepNormalStrategy()
ctx = PriceCalculationContext(keep_normal)
total_price = ctx.calculate_price(origin_price)
print(f"正常价格:{total_price}")
# 打折计算策略
discount = DiscountStrategy(discount=0.8)
ctx = PriceCalculationContext(discount)
total_price = ctx.calculate_price(origin_price)
print(f"八折价格:{total_price}")
# 满减策略
full_reduction = FullReductionStrategy(full=500, reduction=200)
ctx = PriceCalculationContext(full_reduction)
total_price = ctx.calculate_price(origin_price)
print(f"满500减200价格:{total_price}")
# 无门槛策略
direct_reduction = DirectReductionStrategy(reduction=100)
ctx = PriceCalculationContext(direct_reduction)
total_price = ctx.calculate_price(origin_price)
print(f"直减100价格:{total_price}")
$ python3 main.py
单价:599 数量:1 原价:599
正常价格:599
八折价格:479.2
满500减200价格:399
直减100价格:499
TypeScript
/**
* -------------------------------------------------------
* File: strategy.ts
* Created Date: 2021-12-12 01:46:00
* Author: ysj
* Description: 策略类
* -------------------------------------------------------
*/
/**策略接口 */
export interface IPriceCalculationStrategy {
calculatePrice(originPrice: number): number;
}
/**原价策略 */
export class KeepNormalStrategy implements IPriceCalculationStrategy {
calculatePrice(originPrice: number) {
const totalPrice = Number(originPrice.toFixed(2));
return totalPrice;
}
}
/**打折策略 */
export class DiscountStrategy implements IPriceCalculationStrategy {
private discount: number;
constructor(discount: number) {
this.discount = discount;
}
calculatePrice(originPrice: number) {
let totalPrice = originPrice * this.discount;
totalPrice = Number(totalPrice.toFixed(2));
return totalPrice;
}
}
/**满减策略 */
export class FullReductionStrategy implements IPriceCalculationStrategy {
private full: number;
private reduction: number;
constructor(full: number, reduction: number) {
this.full = full;
this.reduction = reduction;
}
calculatePrice(originPrice: number) {
let totalPrice =
originPrice >= this.full ? originPrice - this.reduction : originPrice;
totalPrice = Number(totalPrice.toFixed(2));
return totalPrice;
}
}
/**无门槛策略 */
export class DirectReductionStrategy implements IPriceCalculationStrategy {
private reduction: number;
constructor(reduction: number) {
this.reduction = reduction;
}
calculatePrice(originPrice: number) {
let totalPrice = Math.max(0, originPrice - this.reduction);
totalPrice = Number(totalPrice.toFixed(2));
return totalPrice;
}
}
/**
* -------------------------------------------------------
* File: context.ts
* Created Date: 2021-12-12 01:46:28
* Author: ysj
* Description: 策略上下文
* -------------------------------------------------------
*/
import { IPriceCalculationStrategy } from './strategy';
export default class PriceCalculationContext {
private strategy: IPriceCalculationStrategy;
constructor(strategy: IPriceCalculationStrategy) {
this.strategy = strategy;
}
calculatePrice(originPrice: number): number {
const totalPrice = this.strategy.calculatePrice(originPrice);
return totalPrice;
}
}
/**
* -------------------------------------------------------
* File: index.ts
* Created Date: 2021-12-12 01:45:20
* Author: ysj
* Description: ts 策略模式
* -------------------------------------------------------
*/
import PriceCalculationContext from './context';
import {
KeepNormalStrategy,
DiscountStrategy,
FullReductionStrategy,
DirectReductionStrategy,
} from './strategy';
// 购买价格、数量
const price = 599;
const quantity = 1;
const originPrice = price * quantity;
console.log(`单价:${price} 数量:${quantity} 原价:${originPrice}`);
// 正常计算策略
const keepNormal = new KeepNormalStrategy();
let ctx = new PriceCalculationContext(keepNormal);
let totalPrice = ctx.calculatePrice(originPrice);
console.log(`正常价格:${totalPrice}`);
// 打折计算策略
const discount = new DiscountStrategy(0.8);
ctx = new PriceCalculationContext(discount);
totalPrice = ctx.calculatePrice(originPrice);
console.log(`八折价格:${totalPrice}`);
// 满减策略
const fullReduction = new FullReductionStrategy(500, 200);
ctx = new PriceCalculationContext(fullReduction);
totalPrice = ctx.calculatePrice(originPrice);
console.log(`满500减200价格:${totalPrice}`);
// 无门槛策略
const directReduction = new DirectReductionStrategy(100);
ctx = new PriceCalculationContext(directReduction);
totalPrice = ctx.calculatePrice(originPrice);
console.log(`直减100价格:${totalPrice}`);
$ tsc -p ./tsconfig.json && node build/index.js
单价:599 数量:1 原价:599
正常价格:599
八折价格:479.2
满500减200价格:399
直减100价格:499
✨ Done in 2.35s.
模式扩展
上述代码是基本策略模式的实现。在基本策略模式中,我们需要在客户端决定使用哪一个具体策略,实例化一个策略对象,让后将其传递给策略模式的 Context 对象,然后调用Context对象的接口方法。这样的话,选择所用具体实现的职责就由客户端对象来承担,这会让客户端面临选择判断的压力。
因此,在基本策略模式的基础上,出现了一些改进,扩展了策略模式的设计。这些扩展主要的目的是将"选择具体实现的职责"或者叫做"选择判断的压力"转移到Context来承担,从而减轻客户端的职责。主要有两种实现方式:
一是策略模式与简单工厂模式的结合。例如如下Python实现:
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: context.py
# Created Date: 2021-12-07 03:43:19
# Author: ysj
# Description: 策略上下文 + 简单工厂
###
from strategy import (
KeepNormalStrategy, DiscountStrategy,
FullReductionStrategy, DirectReductionStrategy,
)
class PriceCalculationContext(object):
"""策略上下文-由简单工厂判断创建策略"""
def __init__(self, strategy_type):
if strategy_type == "keep_normal":
self.__strategy = KeepNormalStrategy()
elif strategy_type == "discount":
self.__strategy = DiscountStrategy(discount=0.8)
elif strategy_type == "full_reduction":
self.__strategy = FullReductionStrategy(full=500, reduction=200)
elif strategy_type == "direct_reduction":
self.__strategy = DirectReductionStrategy(reduction=100)
else:
self.__strategy = KeepNormalStrategy()
def calculate_price(self, origin_price):
total_price = self.__strategy.calculate_price(origin_price)
return total_price
这样客户端在决定使用哪种算法时,只需要通过strategy_type的值来控制,策略类实例化的过程封装在了Context类中。但是,这样有一个缺点,不同策略类的初始参数需要提前确定,除非各策略类不需要初始参数。或者,不同的初始参数通过Context实例化时进行传递,但这无疑让Context的初始化变得复杂。
二是策略模式与HashMap的结合。这种方式抛弃了策略类的定义和实现。直接使用HashMap来存储策略算法名及其对应的算法。参考如下Python示例:
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: context2.py
# Created Date: 2021-12-13 03:55:57
# Author: ysj
# Description: 策略上下文 + HashMap
###
class PriceCalculationContext(object):
"""策略上下文-由HashMap存储策略算法"""
def __init__(self):
self.__strategyMap = {
"keep_normal": self.keep_normal,
"discount": self.discount,
"full_reduction": self.full_reduction,
"direct_reduction": self.direct_reduction,
}
@staticmethod
def keep_normal(origin_price):
return round(origin_price, 2)
@staticmethod
def discount(origin_price, discount):
return round(origin_price * discount, 2)
@staticmethod
def full_reduction(origin_price, full, reduction):
if origin_price < full:
return round(origin_price, 2)
return round(origin_price-reduction, 2)
@staticmethod
def direct_reduction(origin_price, reduction):
return round(max(0, origin_price-reduction), 2)
def calculate_price(self, strategy_type, origin_price, *args, **kwargs):
calculate_func = self.__strategyMap[strategy_type]
total_price = calculate_func(origin_price, *args, **kwargs)
return total_price
这种方式,虽然代码量降低了不少,看起来是变简单了,但完全不符合开闭原则和单一职责原则。不推荐!
适用场景
策略模式就是用来封装算法的,在实践中可以用来封装几乎任何类型的规则。关于该模式的适用场景,某些教程的说法条条框框,显得有些不知所谓。根据定义,总结一句话,只要出于同一目的需要应用不同的业务规则,就可以考虑使用策略模式来处理这种规则变化。
参考资料
- 程杰.大话设计模式[M].北京:清华大学出版社,2007.12