LangChain调用大模型时,实现动态模型和负载均衡

8 阅读3分钟

这应该是一个小众的需求,写出来仅供参考。

需求背景

  • 每次和大模型的交互都是无状态的。即每次交互都是一个独立会话。
  • 为了提高性能使用的模型实例有多个,本文假设有3个吧。

实现要求

  1. 3个模型动态切换
  2. 实现负载均衡

实现思路

根据需求描述,首先要做的事情就是得先创建3个模型实例,以供使用。然后呢,实际调用模型之前进行横切,由算法选择一个模型实例再去调用。

这里有两个点没有清晰的解决方案:

  1. 以什么样的方式去调用模型选择算法,我的选择是使用中间件。当然,这是事后诸葛亮了,一开始为此费了不少脑细胞。究其原因还是对切面、装饰器模式这类概念理解不深入,以至于不能第一时间想到。
  2. 如何在调用前,替换掉默认的模型。这一步虽然看起来很难,当确定使用中间件后,解决方案是相对好找到的。本来官方文档是有的,还是看得太少了。等我解决这个问题的两天后,居然发现就是这么写的。哎,看得太少了。

初始化3个模型

from langchain.chat_models import init_chat_model

model1 = init_chat_model(
    "model-1"
)

model2 = init_chat_model(
    "model-2"
)

model3 = init_chat_model(
    "model-3"
)

LOADBALANCE_MODELS = [model1, model2, model3]

创建中间件

这里要确定使用什么钩子。根据需求,这里选择wrap_model_call

下面示例代码使用的是基于类的中间件实现,使用轮询算法来保证3个模型的均衡调度:

class RoundRobinModelLBMiddleware(AgentMiddleware):
    def __init__(self, model_list: list):
        self.model_list = model_list
        self.index = 0
        self.lock = threading.Lock()

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse]
    ) -> ModelResponse:
        
        with self.lock:
            target_model = self.model_list[self.index]
            self.index = (self.index + 1) % len(self.model_list)
            request = request.overwrite(model=target_model)

        return handler(request)

这里最主要的就是16行,通过request对象的overwrite方法返回一个新的request对象。这个新对象就使用了替换后的model对象。

使用中间件

创建agent时,model参数是必须指定的。这里就瞎选一个,反正又不用它。

lb_middleware = RoundRobinModelLBMiddleware(model_list=LOADBALANCE_MODELS)

agent = create_agent(
    model=LOADBALANCE_MODELS[0],
    middleware=[
        lb_middleware,
    ]
)

到这里就完成了。这3个模型的轮询负载均衡就能正常使用了。

扩展

当然,本文的条件比较理想化。

假如,是有状态的模型交互呢?比如一个会话有多次交互,如果换了个模型实例,就无法实现会话的连续性。又该怎么办呢?

我想可以从算法上做文章,使用粘性会话的思路。通过会话id和模型进行绑定。这样做也会有很多的问题,比如:会话是无限的,如何防止内存泄漏。

总结

  1. 没有好好的阅读官方文档,走了点弯路。向那些能把文档看5遍的大神学习。
  2. 要思考一下成熟框架对通用问题的解决方案。

参考

Dynamic model - Docs by LangChain