大模型LLM:合成训练样本的数据分布问题

767 阅读3分钟

近几天在研究大模型LLM数数问题时,使用合成数据集来训练LLM“统计字符串(100个单词以内)中字母的个数”的能力,基于Word进行分词。原始的合成代码在生成随机字符串时,采用如下代码:

# self.words为常见英文单词数组,长度为3432
if random.random() < 0.1:
    ss = random.choices(self.words, k=random.randint(1, 9))
else:
    ss = random.choices(self.words, k=random.randint(1, 99))

合成样本示例如下:

how many letters are there in the following string: "spread high"? 10
how many letters are there in the following string: "european contradictory"? 21
how many letters are there in the following string: "lock over constitution smart boil superior patient teenager graduation drop speaker pronounce contribution boring step carpet realize format surprise disappoint promote track thick rank affect nurse preparation armchair data warn pint construction tale organization tank wear understand vast tremble"? 261

使用单卡训练12个小时左右,测试准确率约为99.937%。

这个准确率看上去很高的,但在人工测试过程中发现,模型对一些简单的case都会预测错误。例如:

how many letters are there in the following string: "a a"?  4, 2(expected)
how many letters are there in the following string: "be be be be"?  0, 8(expected)
how many letters are there in the following string: "dog dog a"?  8, 7(expected)
how many letters are there in the following string: "mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark mark"?  21292220, 396(expected)
how many letters are there in the following string: "world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world world"?  32, 320(expected)
how many letters are there in the following string: "i am fine"?  9, 7(expected)

从直觉上看,出错的case应该是更容易预测的。对此我提出如下猜测:如果测试样本的数据分布跟训练样本的数据分布差异较大,就会导致测试准确率降低。主要表现在: 1、长度为1的单词只有2个,占比为2/3432,但在实际测试中,"a"和"I"是高频单词 2、通过有放回地从词汇表中随机选取k个单词,难以出现类似"mark mark mark mark mark mark mark mark mark"这样的字符串

基于以上猜测,我修改了合成样本的代码:

# 提升短单词在单词表中的比例
self.short_words = []
for w in self.words:
  if len(w) == 1:
    self.short_words += [w] * 50
  elif len(w) == 2:
    self.short_words += [w] * 10
  elif len(w) == 3:
    self.short_words += [w] * 2                
self.words_new = self.words + self.short_words

# 提升同一个单词在字符串中多次出现的概率
if random.random() < 0.05:
  words = random.choices(self.words_new, k=random.randint(1, 5))
else:
  words = self.words_new

if random.random() < 0.1:  
  ss = random.choices(words, k=random.randint(1, 9))
else:
  ss = random.choices(words, k=random.randint(1, 99))

重新训练模型后再进行测试,上述错误的case就全部预测正确了。

总结:在合成训练样本时,应考虑实际使用场景的数据分布。

原文链接 请勿转载