这应该是一个小众的需求,写出来仅供参考。
需求背景
- 每次和大模型的交互都是无状态的。即每次交互都是一个独立会话。
- 为了提高性能使用的模型实例有多个,本文假设有3个吧。
实现要求
- 3个模型动态切换
- 实现负载均衡
实现思路
根据需求描述,首先要做的事情就是得先创建3个模型实例,以供使用。然后呢,实际调用模型之前进行横切,由算法选择一个模型实例再去调用。
这里有两个点没有清晰的解决方案:
- 以什么样的方式去调用模型选择算法,我的选择是使用中间件。当然,这是事后诸葛亮了,一开始为此费了不少脑细胞。究其原因还是对切面、装饰器模式这类概念理解不深入,以至于不能第一时间想到。
- 如何在调用前,替换掉默认的模型。这一步虽然看起来很难,当确定使用中间件后,解决方案是相对好找到的。本来官方文档是有的,还是看得太少了。等我解决这个问题的两天后,居然发现就是这么写的。哎,看得太少了。
初始化3个模型
from langchain.chat_models import init_chat_model
model1 = init_chat_model(
"model-1"
)
model2 = init_chat_model(
"model-2"
)
model3 = init_chat_model(
"model-3"
)
LOADBALANCE_MODELS = [model1, model2, model3]
创建中间件
这里要确定使用什么钩子。根据需求,这里选择wrap_model_call。
下面示例代码使用的是基于类的中间件实现,使用轮询算法来保证3个模型的均衡调度:
class RoundRobinModelLBMiddleware(AgentMiddleware):
def __init__(self, model_list: list):
self.model_list = model_list
self.index = 0
self.lock = threading.Lock()
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse]
) -> ModelResponse:
with self.lock:
target_model = self.model_list[self.index]
self.index = (self.index + 1) % len(self.model_list)
request = request.overwrite(model=target_model)
return handler(request)
这里最主要的就是16行,通过request对象的overwrite方法返回一个新的request对象。这个新对象就使用了替换后的model对象。
使用中间件
创建agent时,model参数是必须指定的。这里就瞎选一个,反正又不用它。
lb_middleware = RoundRobinModelLBMiddleware(model_list=LOADBALANCE_MODELS)
agent = create_agent(
model=LOADBALANCE_MODELS[0],
middleware=[
lb_middleware,
]
)
到这里就完成了。这3个模型的轮询负载均衡就能正常使用了。
扩展
当然,本文的条件比较理想化。
假如,是有状态的模型交互呢?比如一个会话有多次交互,如果换了个模型实例,就无法实现会话的连续性。又该怎么办呢?
我想可以从算法上做文章,使用粘性会话的思路。通过会话id和模型进行绑定。这样做也会有很多的问题,比如:会话是无限的,如何防止内存泄漏。
总结
- 没有好好的阅读官方文档,走了点弯路。向那些能把文档看5遍的大神学习。
- 要思考一下成熟框架对通用问题的解决方案。