使用golang对文本Embedding

681 阅读1分钟

Embedding

构建Rag应用通常需要对文档内容存储在向量数据库里面,再根据用户查询从数据库中找出相匹配的内容返回,文本内容和用户查询需要通过Embedding后转化成向量存储。

langchaingo使用方式

ollama示例

func TestLocalLlmEmbedding(T *testing.T) {
   //示例化一个llm
   llm, err := ollama.New(ollama.WithModel("llama2"))
   if err != nil {
      log.Fatal(err)
   }
   input := []string{"hello world!", "hello jhonroxton!"}
   ctx := context.Background()
   //调用CreateEmbedding方法,返回编码后的向量切片
   embedding, err := llm.CreateEmbedding(ctx, input)
   if err != nil {
      log.Fatal(err)
   }
   fmt.Printf("embedding vector is: %v", embedding)
}

实现方式

首先是EmbedderClient这个接口,langchaingo中的LLM struct都需要实现这个接口

// EmbedderClient is the interface LLM clients implement for embeddings.
type EmbedderClient interface {
   CreateEmbedding(ctx context.Context, texts []string) ([][]float32, error)
}

以下是ollama中的实现方式,github.com/tmc/langchaingo/llms/ollama/ollamallm.go

func (o *LLM) CreateEmbedding(ctx context.Context, inputTexts []string) ([][]float32, error) {
   embeddings := [][]float32{}

   for _, input := range inputTexts {
      req := &ollamaclient.EmbeddingRequest{
         Prompt: input,
         Model:  o.options.model,
      }
      if o.options.keepAlive != "" {
         req.KeepAlive = o.options.keepAlive
      }

      embedding, err := o.client.CreateEmbedding(ctx, req)
      if err != nil {
         return nil, err
      }

      if len(embedding.Embedding) == 0 {
         return nil, ErrEmptyResponse
      }

      embeddings = append(embeddings, embedding.Embedding)
   }

   if len(inputTexts) != len(embeddings) {
      return embeddings, ErrIncompleteEmbedding
   }

   return embeddings, nil
}

CreateEmbedding这个方法中, 调用了o.client.CreateEmbedding(ctx, req), Client包含一个基础的url和一个httpClient, 并且Client定义了CreateEmbedding的具体实现

type Client struct {
   base       *url.URL
   httpClient *http.Client
}

实现方法

func (c *Client) CreateEmbedding(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
   resp := &EmbeddingResponse{}
   if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
      return resp, err
   }
   return resp, nil
}

实际就是通过httpClient调用了ollama的embedding接口,具体可见https://github.com/ollama/ollama/blob/main/docs/api.md#/generate-embeddings