❓️ Problem
BERTopic builds on BERT-style embeddings, and BERT itself accepts at most 512 tokens of input; anything longer is truncated. BERTopic handles this by simply truncating the text, which is not a good fit for long documents. One way to improve this is to extract BERT features for each 512-character chunk of a document and average them to obtain the document embedding. How do we do that, and where does the code need to change?
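Before touching library code, the idea itself can be prototyped outside of BERTopic: split each document into chunks of at most 512 characters, encode each chunk with a sentence-transformers model, and mean-pool the chunk embeddings. Below is a minimal sketch under those assumptions; the helper name, model name, and chunk size are illustrative, not part of BERTopic.

import numpy as np
from sentence_transformers import SentenceTransformer

def encode_long_texts(model, docs, chunk_size=512):
    # Encode documents of arbitrary length by mean-pooling chunk embeddings.
    doc_embeddings = []
    for doc in docs:
        # Split the document into character chunks of at most chunk_size.
        chunks = [doc[i:i + chunk_size] for i in range(0, len(doc), chunk_size)] or [doc]
        chunk_embeddings = model.encode(chunks)            # shape: (num_chunks, dim)
        doc_embeddings.append(chunk_embeddings.mean(axis=0))
    return np.vstack(doc_embeddings)                       # shape: (num_docs, dim)

# model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")  # example model
# embeddings = encode_long_texts(model, docs)

The patch described below bakes the same logic directly into SentenceTransformer.encode(), so BERTopic can keep calling the model as usual.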
🔗 Existing material
What I could find online so far:
zhuanlan.zhihu.com/p/588248281 blog.csdn.net/qq_30565883…
Both seem to come from the same author. However, following the rewrite described in those posts verbatim never gave me correct output: the length of all_embeddings did not match the number of sentences, and I spent a long time debugging. Possibly some methods have changed between library versions.
⛑️ A working approach as of 2025.01
After a fair amount of trial and error I arrived at a version along those lines that works. The final step that collapses all_embeddings to the right shape is a bit of a trick (illustrated below) and may still need adjustment.
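For context, the "trick" relies on the fact that every per-sentence embedding produced in the patched code has an extra batch dimension of size 1, so a mean over axis 1 simply squeezes that dimension away. A quick check with dummy arrays (the dimension 384 is just an example):

import numpy as np

embs = [np.full((1, 384), i, dtype=np.float32) for i in range(3)]  # three (1, dim) embeddings
stacked = np.asarray(embs)                                         # shape (3, 1, 384)
print(np.mean(stacked, axis=1).shape)                              # (3, 384): the singleton axis is squeezed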
File: SentenceTransformer.py
Changes:
# Only the encode() function needs to change
def encode():
    ...
    all_embeddings = []
    length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
    sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
    maxworklength = 512  # extract features for at most maxworklength characters at a time
    for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
        sentences_batch = sentences_sorted[start_index : start_index + batch_size]
        # features = self.tokenize(sentences_batch)  # don't tokenize the whole batch at once; process each sentence individually
        batch_embeddings = []
        for sentence in sentences_batch:
            if len(sentence) > maxworklength:
                # If the sentence is too long, split it into smaller chunks
                chunks = [
                    sentence[i : i + maxworklength]
                    for i in range(0, len(sentence), maxworklength)
                ]
                chunk_embeddings = []
                for chunk in chunks:
                    chunk_features = self.tokenize([chunk])
                    chunk_features = batch_to_device(chunk_features, device)
                    with torch.no_grad():
                        out_features = self.forward(chunk_features, **kwargs)
                        chunk_embedding = out_features["sentence_embedding"]
                        if normalize_embeddings:
                            chunk_embedding = torch.nn.functional.normalize(chunk_embedding, p=2, dim=1)
                        if convert_to_numpy:
                            chunk_embedding = chunk_embedding.cpu()
                        chunk_embeddings.append(chunk_embedding)
                # Average the embeddings of all chunks
                chunk_embeddings = torch.stack(chunk_embeddings).mean(dim=0)
                batch_embeddings.append(chunk_embeddings)
            else:
                tokenized_sentence = self.tokenize([sentence])
                features = batch_to_device(tokenized_sentence, device)
                with torch.no_grad():
                    out_features = self.forward(features, **kwargs)
                    sentence_embedding = out_features["sentence_embedding"]
                    if normalize_embeddings:
                        sentence_embedding = torch.nn.functional.normalize(sentence_embedding, p=2, dim=1)
                    if convert_to_numpy:
                        sentence_embedding = sentence_embedding.cpu()
                    batch_embeddings.append(sentence_embedding)
        all_embeddings.extend(batch_embeddings)
    ...
    # Each per-sentence embedding above has shape (1, dim); averaging over axis 1
    # just squeezes that singleton dimension, leaving an (N, dim) array. Do this
    # before the single-string branch so a string input returns a (dim,) vector.
    all_embeddings = np.mean(all_embeddings, axis=1)
    if input_was_string:
        all_embeddings = all_embeddings[0]
    return all_embeddings
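With the patched encode() in place, the model can be plugged into BERTopic as usual; alternatively, the chunk-averaged embeddings can be precomputed and passed to fit_transform, which bypasses BERTopic's internal call to encode(). A rough sketch, where the model name and the docs variable are placeholders:

from bertopic import BERTopic
from sentence_transformers import SentenceTransformer

docs = ["..."]  # the list of (possibly very long) documents to model
sentence_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")  # example model

# Option 1: let BERTopic call the patched encode() internally.
topic_model = BERTopic(embedding_model=sentence_model)
topics, probs = topic_model.fit_transform(docs)

# Option 2: precompute the chunk-averaged embeddings and hand them to BERTopic.
embeddings = sentence_model.encode(docs, show_progress_bar=True)
topics, probs = topic_model.fit_transform(docs, embeddings)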