BERTopic for Long Chinese Texts: Mean BERT Chunk Embeddings as Whole-Document Features with SentenceTransformer


❓️ Problem

BERTopic builds its document features from BERT, and BERT itself caps the input at 512 tokens; anything longer is cut off. BERTopic handles this by simply truncating, which is not ideal and is unfriendly to long documents. One way to improve this: extract BERT features for every 512-character chunk of a document and take their mean as the document feature. How exactly should this be done, and which code needs to be modified?
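Before touching any library code, the chunk-and-average idea itself can be prototyped on top of a stock SentenceTransformer and the resulting vectors handed to BERTopic as precomputed embeddings. A rough sketch (the model name, chunk size, and helper function are illustrative assumptions, not BERTopic internals):

```python
# Sketch only: split each document into 512-character pieces, embed each piece,
# and average the piece embeddings into one document vector.
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")  # assumed model

def chunk_mean_embedding(doc: str, chunk_size: int = 512) -> np.ndarray:
    chunks = [doc[i : i + chunk_size] for i in range(0, len(doc), chunk_size)] or [doc]
    return model.encode(chunks).mean(axis=0)  # (n_chunks, dim) -> (dim,)

docs = ["……很长的中文文档……"] * 3  # placeholder documents
embeddings = np.vstack([chunk_mean_embedding(d) for d in docs])  # (n_docs, dim)
# BERTopic accepts precomputed embeddings: topic_model.fit_transform(docs, embeddings=embeddings)
```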

🔗 Existing Resources

The material I could find online so far:

zhuanlan.zhihu.com/p/588248281 blog.csdn.net/qq_30565883…

Both seem to come from the same author. However, following the rewrite in those posts as written, I could never get correct output: the length of all_embeddings did not match the number of sentences, and I spent quite a while debugging. Perhaps some of the methods changed between library versions?

⛑️ A Method That Works as of 2025.01

After a lot of fiddling along that line of thought, here is a version that works. The final step, which directly reduces all_embeddings by one dimension, is a bit of a trick and may still need further adjustment:
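To make that "trick" concrete: every per-sentence entry appended to all_embeddings in the patch below keeps a leading batch axis of size 1, so the final np.mean(..., axis=1) only collapses that singleton axis. A toy shape check, using NumPy arrays as stand-ins for the CPU tensors (arbitrary values, not from the original posts):

```python
import numpy as np

dim = 4
# all_embeddings as built in the patched encode(): one (1, dim) embedding per sentence
all_embeddings = [np.zeros((1, dim)) for _ in range(3)]

squeezed = np.mean(all_embeddings, axis=1)  # stacks to (3, 1, dim), then averages the size-1 axis
print(squeezed.shape)                       # (3, 4): one row per sentence, values unchanged
```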

File: SentenceTransformer.py

Concrete changes:

```python
# Only encode() needs to change; its original signature and the untouched
# parts of its body are elided ("...") below.
def encode():
        ...
        all_embeddings = []
        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]

        maxworklength = 512  # extract features for at most maxworklength characters at a time

        for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
            sentences_batch = sentences_sorted[start_index : start_index + batch_size]
            # features = self.tokenize(sentences_batch)  # no longer tokenize the whole batch at once; each sentence is handled individually below

            batch_embeddings = []
            for sentence in sentences_batch:
                if len(sentence) > maxworklength:
                    # If the sentence is too long, split it into smaller chunks
                    chunks = [
                        sentence[i : i + maxworklength]
                        for i in range(0, len(sentence), maxworklength)
                    ]
                    chunk_embeddings = []

                    for chunk in chunks:
                        chunk_features = self.tokenize([chunk])
                        chunk_features = batch_to_device(chunk_features, device)

                        with torch.no_grad():
                            out_features = self.forward(chunk_features, **kwargs)
                            chunk_embedding = out_features["sentence_embedding"]
                            if normalize_embeddings:
                                chunk_embedding = torch.nn.functional.normalize(chunk_embedding, p=2, dim=1)
                            if convert_to_numpy:
                                chunk_embedding = chunk_embedding.cpu()
                            chunk_embeddings.append(chunk_embedding)

                    # Average the embeddings of all chunks
                    chunk_embeddings = torch.stack(chunk_embeddings).mean(dim=0)
                    batch_embeddings.append(chunk_embeddings)
                else:
                    tokenized_sentence = self.tokenize([sentence])
                    features = batch_to_device(tokenized_sentence, device)
                    with torch.no_grad():
                        out_features = self.forward(features, **kwargs)
                        sentence_embedding = out_features["sentence_embedding"]
                        if normalize_embeddings:
                            sentence_embedding = torch.nn.functional.normalize(sentence_embedding, p=2, dim=1)
                        if convert_to_numpy:
                            sentence_embedding = sentence_embedding.cpu()
                        batch_embeddings.append(sentence_embedding)


            all_embeddings.extend(batch_embeddings)
        ...
        # Each per-sentence embedding above still carries a leading batch axis of
        # size 1, so averaging over axis=1 simply collapses it and yields an
        # (n_sentences, dim) array. Doing this before the single-string shortcut
        # also keeps a plain-string input flat with shape (dim,).
        all_embeddings = np.mean(all_embeddings, axis=1)

        if input_was_string:
            all_embeddings = all_embeddings[0]

        return all_embeddings
```
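Once the patched SentenceTransformer.py is in place, nothing special should be needed on the BERTopic side. A hypothetical usage sketch (the model name and documents are placeholders, not from the original post):

```python
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer

# Any SentenceTransformer checkpoint should work once encode() is patched;
# the model name here is only a placeholder.
embedding_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")

docs = ["……很长的中文文档……"] * 100  # placeholder long Chinese documents

embeddings = embedding_model.encode(docs, show_progress_bar=True)
print(embeddings.shape)  # expected: (len(docs), dim) after the final axis=1 mean

topic_model = BERTopic(embedding_model=embedding_model)
topics, probs = topic_model.fit_transform(docs, embeddings=embeddings)
```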