In today's AI boom, large language models (LLMs) are reshaping how we work and live with their remarkable general-purpose abilities. Yet you may have run into this scenario: an LLM that shines on general tasks suddenly flounders, or even answers beside the point, when faced with your specific industry or private data.
It is like having an encyclopedic assistant with vast general knowledge who falls short on specialized problems: interpreting complex legal clauses, analyzing internal financial reports, or drafting marketing copy that matches your company's tone. The assistant's answers may be too generic, or simply off the mark.
A pain-point code example:
Let's look at a simple Python pseudo-code example that simulates how a general-purpose LLM behaves on a specific task before any fine-tuning. Suppose we have a base LLM and want it to answer specialized questions about a particular company (say, TechCorp):
```python
# A mock interface simulating a base LLM
class GenericLLM:
    def generate(self, prompt: str) -> str:
        # Simulate the LLM's generic answering logic
        if "TechCorp" in prompt:
            return ("TechCorp is a high-tech company committed to innovation "
                    "and technology. They have many products and services.")
        elif "legal clause" in prompt:
            return ("Legal clauses usually cover rights, obligations, and liabilities; "
                    "consult the relevant legal documents for details.")
        else:
            return "I am a large language model and can answer many kinds of questions."

# Instantiate the generic LLM
base_llm = GenericLLM()

# The user asks company-specific questions
question_1 = "What new AI strategies did TechCorp mention in its latest quarterly report?"
question_2 = "Following TechCorp's internal guidelines, draft a notice about a project delay."

print(f"**Question 1:** {question_1}")
print(f"**Generic LLM answer:** {base_llm.generate(question_1)}\n")  # Vague; no actual financial details
print(f"**Question 2:** {question_2}")
print(f"**Generic LLM answer:** {base_llm.generate(question_2)}\n")  # Ignores the internal guidelines; gives a generic reply
```
The output above makes the problem plain: a generic LLM can recognize keywords, but it cannot grasp domain-specific nuance, and it cannot follow a required format or tone. This is exactly why we fine-tune LLMs.
Fine-tuning is like handing that encyclopedic assistant a dedicated industry training manual and a company code of conduct, so it can quickly adapt to and master your specific needs. It markedly improves a model's performance on downstream tasks, and compared with training a model from scratch, it is cheaper and far more efficient. So how do we fine-tune an LLM efficiently and at low cost, turning it into our own dedicated assistant? Let's explore that question today.
Chapter 1: LLM Fine-tuning Fundamentals: Why Fine-tune at All?
1.1 What Is LLM Fine-tuning? Pre-training vs. Fine-tuning
The life cycle of a large language model (LLM) typically has two major stages: pre-training and fine-tuning. Understanding both is the key to mastering LLM fine-tuning.
Pre-training trains the model on massive amounts of general text (books, articles, code, Wikipedia, and other web data) via self-supervised objectives such as next-token prediction or masked-token filling. The goal is for the model to acquire general linguistic patterns, world knowledge, and basic reasoning ability. Think of a pre-trained model as a star student with a "general-purpose brain": widely read, with an enormous store of knowledge, able to understand and generate all kinds of text, yet lacking hands-on experience in any one specialty. Models at this stage are huge: GPT-3 has 175 billion parameters, and the Llama family ranges from several billion to tens of billions.
Fine-tuning continues training the pre-trained model on data for a specific task or domain. The goal is to adapt the model to a particular downstream task and lift its performance there; it may update all of the model's parameters or only part of them. It is like sending that "general-purpose brain" to professional training: the student quickly absorbs domain knowledge and learns to solve problems in a particular way, becoming a specialist. For example, we can fine-tune a general LLM to excel at drafting legal documents or medical reports, or to become a customer-service bot that understands and responds to customer emotions. Fine-tuning converts general ability into dedicated ability, putting the LLM to work on our specific application.
Why fine-tune?
- Insufficient domain fit: a generic LLM may not understand industry-specific terminology, jargon, or business logic. Fine-tuning lets the model learn this expertise, such as medical vocabulary or a bank's risk-assessment rules.
- Performance ceilings: on certain tasks (sentiment analysis, code generation, legal Q&A), a generic LLM may fall short of requirements. Fine-tuning can markedly improve accuracy and relevance, producing output that is more precise and closer to expectations.
- **Customized behavior**: we may want the LLM to reply in a specific tone, style, or format. Fine-tuning can steer the model toward desired behavior, such as answering with humor, strictly emitting JSON, or avoiding sensitive content.
- Cost efficiency: training an LLM of GPT-3/4 scale from scratch is an astronomical investment in both compute and time. Fine-tuning an existing model delivers high performance at a fraction of the cost, which for most individual developers and companies is the more economical and efficient choice.
1.2 The Pains of Traditional Fine-tuning: Expensive and Forgetful
Traditional full fine-tuning (Full Fine-tuning) updates all of a model's parameters during fine-tuning. For small models (tens of millions to a few hundred million parameters) this is manageable: the compute and time required stay within reach. But for LLMs with tens or hundreds of billions of parameters, full fine-tuning faces serious challenges, chiefly around compute demand and model stability:
- **Enormous compute demand**: an LLM with tens of billions of parameters occupies tens to hundreds of GB of GPU memory for its weights alone, even in 16-bit floating point (FP16). Training must also store gradients and optimizer state, so full fine-tuning requires clusters of high-end GPUs (such as NVIDIA A100 or H100), a cost beyond most individual developers and small companies. The hardware outlay and power consumption form a steep barrier to entry.
- **Long training times**: even with enough hardware, updating that many parameters takes a long time. A training run can last days to weeks, which stretches out model iteration and deployment cycles and slows development.
- Catastrophic forgetting (Catastrophic Forgetting): during fine-tuning the model can "forget" general knowledge learned in pre-training, especially when the new task's data distribution differs substantially from the pre-training data. It is like a student cramming for one new exam and losing the general knowledge learned earlier. The model then performs well on the fine-tuning task but visibly degrades on general abilities, hurting generalization.
Data preparation for traditional fine-tuning (conceptual code):
Although traditional fine-tuning is resource-hungry, data preparation is its foundation. The code below shows how to convert raw data into text an LLM can consume and tokenize it. Even with the data ready, though, full fine-tuning still runs into the challenges above.
```python
# Basic example: preparing a dataset for traditional fine-tuning
# Uses PyTorch with the Hugging Face Transformers and Datasets libraries
from datasets import Dataset
from transformers import AutoTokenizer

# Suppose we have a raw dataset of instruction/answer pairs.
# This is a deliberately tiny example; real data is messier and needs a unified format.
raw_data = [
    {"instruction": "Explain how LoRA fine-tuning works.",
     "output": "LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method that updates the model through injected low-rank matrices."},
    {"instruction": "Explain the difference between QLoRA and LoRA.",
     "output": "QLoRA is a quantized variant of LoRA that further reduces memory by quantizing the base model."},
    {"instruction": "How do I prepare LLM fine-tuning data?",
     "output": "Fine-tuning data usually follows an instruction-tuning format and must be cleaned and formatted."}
]

# Convert the raw data into a Hugging Face Dataset object.
# We concatenate instruction and output into one full training text,
# the common approach in traditional supervised fine-tuning.
def format_instruction_data(example):
    # In practice you may need a richer prompt template, e.g. the ChatML format
    example["text"] = f"### Instruction:\n{example['instruction']}\n### Output:\n{example['output']}"
    return example

# Build and format the dataset
dataset = Dataset.from_list(raw_data)
formatted_dataset = dataset.map(format_instruction_data)
print("Raw sample:")
print(raw_data[0])
print("\nFormatted sample:")
print(formatted_dataset[0])

# Load the pre-trained model's tokenizer; google/gemma-2b is used as the example.
# Note: your environment needs access to this model, or use a local model path.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
# If the tokenizer has no padding token, add one manually so batching works
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Tokenize the text
def tokenize_function(examples):
    # truncation=True: cut off over-long sequences
    # max_length: cap the length; tune it to the model's context window and your data
    # padding="max_length": pad to max_length (set to False if you use a DataCollator instead)
    return tokenizer(examples["text"], truncation=True, max_length=512, padding="max_length")

tokenized_dataset = formatted_dataset.map(
    tokenize_function, batched=True, remove_columns=["instruction", "output", "text"]
)
print("\nTokenized sample (first few input_ids):")
print(tokenized_dataset[0]["input_ids"][:10])
print("Tokenized sample (attention_mask):")
print(tokenized_dataset[0]["attention_mask"][:10])
# For actual full fine-tuning we would still need a DataCollator and a Hugging Face
# Trainer to start training. Its heavy resource cost is what pushes us toward a
# better answer: parameter-efficient fine-tuning (PEFT).
```
This code shows how to turn structured data into LLM-ready text and tokenize it. Yet even with the data prepared, the resources that full fine-tuning demands remain enormous. That is why parameter-efficient fine-tuning (PEFT) emerged: a "small changes, big gains" strategy that effectively resolves the dilemma of traditional fine-tuning.
Chapter 2: The Core Ideas of Parameter-Efficient Fine-tuning (PEFT): Small Changes, Big Gains!
2.1 PEFT in a Nutshell: Doing More with Less
Parameter-Efficient Fine-tuning (PEFT) is a family of techniques designed to drastically reduce the number of parameters that must be trained during fine-tuning. The core insight is elegant: the vast majority of an LLM's general knowledge is already there from pre-training, so we do not need to reshape the whole model. Instead, we fine-tune only a few key parts of it, or inject a small number of extra parameters that "steer" it, and we can approach or even exceed full fine-tuning while slashing compute and storage costs and substantially mitigating catastrophic forgetting. It is like a huge, complex machine: rather than replacing the entire engine, we adjust a few key screws or bolt on a few small parts to adapt it to a new job.
The advantages of PEFT are striking:
- Much lower compute requirements: training only a small number of parameters sharply cuts the GPU memory and compute needed, making it feasible to fine-tune large models on consumer GPUs.
- Shorter training times: fewer trainable parameters directly means faster training and quicker model-iteration cycles.
- Less catastrophic forgetting: PEFT methods usually freeze most pre-trained weights and modify or add only a few parameters, better preserving the general knowledge gained during pre-training.
- **Multi-task adaptability**: since each task introduces only a small adapter, we can train separate PEFT adapters for different tasks and load them on demand at inference, making the model "modular" and "pluggable".
Hugging Face's peft library is a powerful tool here: it implements several PEFT methods (LoRA, QLoRA, Prefix Tuning, and more), greatly simplifying the work of applying these efficient fine-tuning strategies to Transformers models. Next, let's dig into the most popular and effective PEFT methods.
2.2 LoRA (Low-Rank Adaptation): The Essence of Low-Rank Adapters
How it works: LoRA's core idea is to inject a pair of small, trainable low-rank matrices (the LoRA weights) alongside each targeted pre-trained weight matrix. During fine-tuning, the model's original weights stay frozen and only these low-rank matrices are updated. At inference, the product of the low-rank matrices is added to the original weight matrix, forming a "fine-tuned" set of weights.
Concretely, for an original weight matrix W of size d×d, LoRA introduces two much smaller matrices B (d×r) and A (r×d), where the rank r is far smaller than d. During fine-tuning we train only A and B while W stays fixed; the update is ΔW = BA. Because r ≪ d, the number of trainable parameters (2rd) is far smaller than d². For example, if W is a 1024×1024 matrix it has 1024² ≈ 10^6 parameters. With r = 8, LoRA trains only 1024×8 + 8×1024 = 2×8×1024 ≈ 1.6×10^4 parameters, about 1.6% of the original! This clever design lets LoRA keep model quality while slashing training cost.
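The parameter arithmetic above is easy to check in a few lines (a standalone sketch; the 1024×1024 matrix and r = 8 are the example values from the text):

```python
# Minimal arithmetic sketch of LoRA's parameter savings for one weight matrix:
# a full d_out x d_in matrix vs. the two low-rank factors B (d_out x r) and A (r x d_in).
def lora_param_counts(d_out: int, d_in: int, r: int):
    full = d_out * d_in          # parameters in the frozen weight W
    lora = d_out * r + r * d_in  # trainable parameters in B and A
    return full, lora, lora / full

full, lora, ratio = lora_param_counts(1024, 1024, 8)
print(full, lora, f"{ratio:.1%}")  # 1048576 16384 1.6%
```

The same function also shows why larger layers benefit even more: the saving ratio is roughly 2r/d, so it shrinks as the layer grows.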
Advantages:
* Far fewer parameters: LoRA updates only a tiny fraction of the weights, sharply lowering memory and compute needs and making single-GPU fine-tuning of large models feasible.
* Avoids catastrophic forgetting: the original model weights are untouched, so general knowledge is preserved and the model still generalizes well after fine-tuning.
* **Flexible deployment**: LoRA weights can be merged into the base weights ("baked in") or loaded as standalone adapter files, which makes multi-task management and switching easy.
Code example: configuring LoRA with the peft library

```python
# Hands-on: configuring LoRA parameters with the peft library.
# This sets up LoRA on the model's linear layers (typically the attention projections).
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM
import torch

# 1. Load a pre-trained LLM.
# A smallish model such as Google's gemma-2b is used for the demo;
# in practice you would load something larger (Llama-2, Mistral, ...).
print("Loading pre-trained model...")
model_name_or_path = "google/gemma-2b"  # or a local model path
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,  # bfloat16 is recommended for efficient training and lower memory
    device_map="auto"            # place the model on available GPUs automatically
)

print("\nConfiguring LoRA...")
# 2. Define the LoRA configuration.
# target_modules: the key setting; usually the linear layers of the attention
#   mechanism (q_proj, k_proj, v_proj, o_proj). Module names differ across
#   architectures (Llama, Mixtral, Gemma, ...); explore the structure via
#   model.named_modules().
# r: the LoRA rank; sets the number of new parameters and the adapter's capacity,
#   typically 8-64. Larger r means more parameters and expressiveness, but also
#   more overfitting risk.
# lora_alpha: the LoRA scaling factor, often 2*r or equal to r; it controls how
#   strongly the LoRA update influences the base model.
# lora_dropout: dropout rate, for regularization against overfitting.
# bias: whether to adapt bias terms; usually "none".
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # causal language modeling (most LLMs)
    inference_mode=False,          # training mode
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    # For the Gemma family, the common linear-layer module names include
    # 'q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none"  # do not adapt bias terms
)

# 3. Apply the LoRA config to the base model.
# get_peft_model returns a PEFT model that exposes only the LoRA parameters as trainable.
peft_model = get_peft_model(model, lora_config)
print("\nLoRA model parameter stats:")
peft_model.print_trainable_parameters()  # prints trainable parameter count and percentage

# Sanity check: after applying LoRA, only the LoRA layers should be trainable.
print("\nSample of trainable parameters:")
trainable = [name for name, p in peft_model.named_parameters() if p.requires_grad]
for name in trainable[:5]:  # print a few to avoid flooding the output
    print(f"Trainable parameter after LoRA: {name}")
if not trainable:
    print("No trainable LoRA parameters found; check target_modules or the model structure.")
else:
    print("\nLoRA configured: the model now contains only a small set of trainable parameters.")
```
This code shows how to attach LoRA adapters to a Hugging Face model with the peft library. Through the target_modules parameter of lora_config we control precisely which modules get LoRA-ized. peft_model.print_trainable_parameters() then shows plainly that, compared with the full parameter count of the original model, the trainable count has collapsed: that is LoRA's charm!
2.3 QLoRA (Quantized Low-Rank Adaptation): Quantization Meets Memory Optimization
How it works: QLoRA is a transformative variant of LoRA. It quantizes the pre-trained model to the 4-bit NormalFloat (NF4) data type to shrink memory further, while still training LoRA adapters as before. The base LLM's parameters are loaded at 4-bit precision, drastically compressing their footprint in GPU memory. During the forward pass the quantized weights are dequantized to 16-bit BFloat16 for computation, while the LoRA adapters themselves train in 16-bit precision. This combination pushes the GPU memory needed to fine-tune an LLM to a minimum while preserving model quality.
Advantages:
* Minimal memory footprint: QLoRA's standout advantage. A 7-billion-parameter Llama 2 model needs about 14 GB of GPU memory for its weights in FP16, but quantized with QLoRA it can fine-tune in roughly 8-10 GB, which puts large-model fine-tuning within reach of a single consumer card such as an RTX 3090/4090. For larger models (say 30 billion parameters), QLoRA can likewise cut the memory requirement from 60+ GB down to around 20 GB.
* Performance close to LoRA: despite the aggressive quantization, NF4 works remarkably well; the impact on model quality is small, and results are usually comparable to FP16 LoRA fine-tuning.
* **Reduced memory traffic**: storing weights in 4 bits shrinks the data moved between GPU memory and compute cores, which can speed up memory-bound phases of training, though the extra dequantization does add some compute overhead.
Code example: configuring QLoRA with bitsandbytes and peft

```python
# Hands-on: configuring QLoRA with the bitsandbytes and peft libraries.
# This combines 4-bit quantization of the base model with LoRA adapters.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model
import torch

print("Configuring 4-bit quantization...")
# 1. Define the BitsAndBytes config that enables 4-bit quantization.
# load_in_4bit=True: enable 4-bit quantization.
# bnb_4bit_quant_type="nf4": use the 4-bit NormalFloat type, recommended for
#   transformer models. NF4 is designed for normally distributed weights and
#   outperforms other 4-bit quantization schemes.
# bnb_4bit_use_double_quant=True: nested quantization; the quantization constants
#   themselves are quantized, typically saving an extra ~0.4 bits per parameter.
# bnb_4bit_compute_dtype=torch.bfloat16: the dtype used when computing on the
#   quantized model; usually bfloat16 or float16 to keep numerical precision.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# 2. Load the pre-trained model with the quantization config applied.
print("Loading pre-trained model with QLoRA quantization...")
model_name_or_path = "google/gemma-2b"
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    quantization_config=quantization_config,
    device_map="auto"  # place the model on available devices automatically
)

# 3. Configure LoRA, just like before, but now on top of the quantized model.
# QLoRA can often afford a larger rank (r), since each rank costs less memory,
# which may buy extra quality.
lora_config_qlora = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32,           # a larger rank is affordable because memory pressure is lower
    lora_alpha=64,  # the matching scaling factor, kept at 2*r
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none"
)

# 4. Apply the LoRA config to the quantized model.
qlora_model = get_peft_model(quantized_model, lora_config_qlora)
print("\nQLoRA model parameter stats:")
qlora_model.print_trainable_parameters()
print("\nQLoRA configured: huge GPU memory savings that open large-LLM fine-tuning to far more people.")

# Note: to check actual memory usage, watch nvidia-smi. QLoRA already cuts memory
# at load time; a 7B model may need only 8-10 GB of GPU memory, which is what makes
# fine-tuning a large model on a single RTX 3090/4090 possible.
```
QLoRA is one of the best choices for fine-tuning large LLMs (the Llama family, Mistral, Gemma), especially under tight resource constraints. It lets individual developers and small teams take part in large-model fine-tuning, genuinely moving toward "fine-tuning for everyone".
2.4 Prompt Tuning / P-Tuning / Prefix Tuning: The Model Body Stays Put
The core idea of this family is to freeze all parameters of the large pre-trained model and train only a small set of continuous, learnable "soft prompts" or "prefixes". These are added to the model's input embeddings or to the activations of the Transformer layers, steering the model toward the desired output. The appeal is a tiny parameter count and zero intrusion into the model body.
- Prompt Tuning: the simplest form; it trains only some continuous vectors ahead of the input embedding layer. They are concatenated with the original input and fed to the model together; the model itself needs no modification.
- P-Tuning: builds on Prompt Tuning by introducing a small neural network (such as an LSTM) to generate the soft prompts. This makes them more expressive: rather than simple fixed vectors, they are produced dynamically by a small model and can adapt better to different inputs.
- Prefix Tuning: goes a step further than Prompt Tuning. Instead of soft prompts only at the input layer, it adds trainable prefix vectors at every Transformer layer, injected into the key (Key) and value (Value) matrices of self-attention, so the information flow is guided at every layer of the model.
Advantages:
* Tiny parameter count: typically only thousands to a few hundred thousand parameters, far fewer than LoRA. Training is extremely fast and the compute requirements are minimal.
* Very low memory footprint: with the model body fully frozen, the adapter's own overhead is very small; it can even undercut QLoRA, since no quantized base model or dequantization lookup tables are needed for the adapter itself.
* **Flexibility**: particularly suited to few-shot (Few-shot) and even zero-shot (Zero-shot) scenarios, because it adapts to the task through external steering rather than by modifying the model's core knowledge. A model can be re-targeted to a new task quickly without any large-scale retraining.
Code example: configuring Prefix Tuning with the peft library

```python
# Comparison example: configuring Prefix Tuning.
# A parameter-efficient method different from LoRA/QLoRA, with a minimal parameter count.
from peft import PrefixTuningConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM
import torch

print("Loading pre-trained model...")
model_name_or_path = "google/gemma-2b"
model_prefix = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

print("\nConfiguring Prefix Tuning...")
# Define the Prefix Tuning configuration.
# num_virtual_tokens: number of virtual tokens whose embeddings will be trained.
#   More tokens can mean more expressive power, but also more parameters.
# encoder_hidden_size: usually the model's hidden size, used to shape the prefix;
#   this is critical and must match the model.
# prefix_projection: whether to project the prefix through a small MLP first.
#   True usually works better but adds a few parameters.
prefix_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    num_virtual_tokens=30,  # more virtual tokens (e.g. 20-100) give more expressiveness
    encoder_hidden_size=model_prefix.config.hidden_size,  # read the hidden size from the model
    prefix_projection=False  # keep it simple here; True often helps at a small parameter cost
)

# Apply the Prefix Tuning config to the base model.
peft_model_prefix = get_peft_model(model_prefix, prefix_config)
print("\nPrefix Tuning model parameter stats:")
peft_model_prefix.print_trainable_parameters()
print("\nPrefix Tuning configured: fine-tuning with a minimal number of parameters.")

# Comparison:
# LoRA and QLoRA adapt to the task by modifying weight matrices inside the model;
# they usually come closer to full fine-tuning in quality and suit scenarios that
# need deeper domain understanding or complex behavioral change.
# Prefix Tuning steers behavior by modifying the input-side prompts, barely
# touching the model's internals; it suits lightweight behavior adjustment,
# style control, and quick few-shot adaptation, but may trail LoRA on complex tasks.
```
Because its parameter count is usually so small, Prefix Tuning can be an extremely efficient choice in the right scenarios. Understanding how the different PEFT methods work, and where each fits, lets us pick the best one for the need at hand.
Chapter 3: The Art of Data Preparation and Processing: Quality Data Is the Bedrock of Fine-tuning
"Garbage in, garbage out" applies with special force to LLM fine-tuning. High-quality, well-formatted data in the right quantity is the key to success. Here we focus on data preparation for instruction tuning (Instruction Tuning), currently one of the most common and effective fine-tuning paradigms.
3.1 The Instruction-Tuning Data Format: Teach the Model Your Intent
Instruction tuning trains the model on a series of "instruction-response" pairs so it learns to understand and follow human instructions. Each data point typically contains a user instruction (instruction), optional context (input or context), and the desired response (output) the model should produce. This format teaches the model to behave like an assistant, producing an appropriate reply for the given instruction and context.
A recommended format (an empty "input" field marks an example with no context; JSON itself does not allow comments):

```json
[
  {
    "instruction": "Summarize the core points of the following article.",
    "input": "Fine-tuning methods for large language models (LLMs)...",
    "output": "The article covers why LLM fine-tuning matters, the principles and practice of PEFT methods such as LoRA, QLoRA, and Prefix Tuning, and key data-preparation techniques."
  },
  {
    "instruction": "Recommend a science-fiction film about AI.",
    "input": "",
    "output": "1. 2001: A Space Odyssey 2. Blade Runner 3. Ex Machina"
  },
  {
    "instruction": "Translate the following English sentence into Chinese, keeping technical terms unchanged.",
    "input": "Large Language Models (LLMs) are revolutionizing the field of Artificial Intelligence.",
    "output": "大型语言模型（LLM）正在彻底改变人工智能领域。"
  }
]
```
Before feeding this data to the model, we normally render it into a specific prompt template that guides the model to follow the instruction. Common templates are based on the Alpaca or ChatML formats. These templates use distinctive markers (such as ### Instruction: and ### Response:, or <|im_start|>user and <|im_end|>) to separate the instruction, the context, and the expected output, helping the model understand the role of each part. Choosing a format similar to the template the model saw during pre-training usually yields better results.
Code example: loading and formatting an instruction-tuning dataset

```python
# Basic example: loading and formatting data with the Hugging Face `datasets` library
from datasets import Dataset
import random

# A mock instruction-tuning dataset.
# In real projects you would load your data from JSONL, CSV, or similar files.
instruction_data = [
    {"instruction": "Summarize Python's key traits in one sentence.", "input": "",
     "output": "Python is a high-level, interpreted, general-purpose language known for its clean syntax and rich library ecosystem."},
    {"instruction": "How do I compute the mean of the list [1,2,3,4,5]?", "input": "",
     "output": "Use the built-in `sum()` and `len()` functions: `sum(my_list) / len(my_list)`."},
    {"instruction": "Explain the gradient descent algorithm.", "input": "",
     "output": "Gradient descent is an optimization algorithm that seeks a function's minimum by iteratively stepping against the gradient (the slope)."},
    {"instruction": "Translate the sentence into French: 'Hello, how are you?'", "input": "",
     "output": "Bonjour, comment allez-vous ?"},
    {"instruction": "Using the details below, draft a short product-launch announcement email.",
     "input": "Product name: AI Assistant V1.0; launch date: 2023-10-26; highlights: smart Q&A, multilingual support.",
     "output": "Subject: AI Assistant V1.0 Launches Soon!\n\nDear users,\n\nWe are excited to announce that AI Assistant V1.0 goes live on October 26, 2023. The new release brings powerful smart Q&A and multilingual support for a more efficient, more convenient experience.\n\nThank you for your support!\nThe AI Assistant Team"}
]

# Convert the raw data into a Hugging Face Dataset object
dataset = Dataset.from_list(instruction_data)

# Define a formatting function that follows an Alpaca-like prompt template
def format_alpaca_prompt(example):
    instruction = example["instruction"]
    input_text = example["input"]
    output_text = example["output"]
    if input_text:
        # Template for examples that include context
        full_text = (
            f"### Instruction:\n{instruction}\n"
            f"### Input:\n{input_text}\n"
            f"### Response:\n{output_text}"
        )
    else:
        # Template for examples without context
        full_text = (
            f"### Instruction:\n{instruction}\n"
            f"### Response:\n{output_text}"
        )
    return {"text": full_text}

# Apply the formatting function
formatted_dataset = dataset.map(format_alpaca_prompt)
print("Raw sample (random):")
print(random.choice(instruction_data))
print("\nFormatted sample (random):")
print(random.choice(formatted_dataset)["text"])

# Good practice: data cleaning and filtering
def clean_and_filter_data(example):
    # Example: drop samples that are empty, too short, or too long
    if not example["text"] or len(example["text"]) < 50 or len(example["text"]) > 2000:
        return False  # filter out samples that violate the length constraints
    # More logic can go here: validating JSON, removing duplicates, etc.
    return True

cleaned_dataset = formatted_dataset.filter(clean_and_filter_data)
print(f"\nOriginal samples: {len(formatted_dataset)}, cleaned samples: {len(cleaned_dataset)}")

# Bad practice: crudely dropping every sample that contains certain keywords risks false positives
# def bad_filter_data(example):
#     if "advertisement" in example["text"] or "promotion" in example["text"]:
#         return False
#     return True
# bad_cleaned_dataset = formatted_dataset.filter(bad_filter_data)
# print(f"Crudely filtered samples: {len(bad_cleaned_dataset)}")  # may drop normal text that merely mentions ads

# Recommended: smarter keyword matching, or a model, for labeling and filtering.
# For instance, a small text classifier can flag low-quality or irrelevant samples,
# or more precise regular expressions can be combined with human review.
```
This code shows how to use the datasets library to format raw data into a model-friendly prompt template, plus a simple cleaning example. Quality data is the bedrock of successful fine-tuning; the time invested in collecting, cleaning, annotating, and formatting data is absolutely worth it.
Chapter 4: Advanced Techniques and Practice: Optimization, Pitfalls, and Deployment
With the basics of PEFT and data preparation in hand, the next step is to learn how to optimize the fine-tuning process further, sidestep common pitfalls, and finally put the fine-tuned model into real use.
4.1 Performance Optimization Strategies: More than Just QLoRA
Beyond the memory and speed advantages QLoRA brings, several other techniques further improve the performance and efficiency of LLM fine-tuning.
- Gradient accumulation (Gradient Accumulation):
  When your GPU memory cannot hold a large batch_size, gradient accumulation lets you run several forward and backward passes with a small batch_size, accumulate those gradients, and apply a parameter update only every N steps. This simulates a larger effective batch_size, which can help the model converge to better results while avoiding OOM (Out Of Memory) errors.

```python
# Performance optimization: gradient accumulation and mixed-precision training
from transformers import TrainingArguments, Trainer
import torch

# Assume peft_model, tokenizer, and tokenized_dataset already exist
# peft_model = ...
# tokenizer = ...
# tokenized_dataset = ...

# Recommended: configure TrainingArguments with gradient accumulation and mixed precision
training_args = TrainingArguments(
    output_dir="./fine_tuned_model",
    num_train_epochs=3,
    per_device_train_batch_size=4,   # the actual per-step batch size
    gradient_accumulation_steps=8,   # accumulation steps; effective batch size = 4 * 8 = 32
    learning_rate=2e-4,
    logging_steps=10,
    save_strategy="epoch",
    fp16=False,                      # fp16 off, since QLoRA already computes in bfloat16
    bf16=True,                       # bfloat16 mixed precision: large gains in speed and stability
    optim="paged_adamw_8bit",        # the optimizer recommended for QLoRA; saves further memory
    report_to="tensorboard",
)

# Not recommended: no gradient accumulation and no mixed precision; may OOM or train slowly
# bad_training_args = TrainingArguments(
#     output_dir="./bad_model",
#     per_device_train_batch_size=32,  # OOMs if GPU memory is insufficient
#     fp16=False,                      # no mixed precision: slower training
# )

# Instantiate the Trainer
# trainer = Trainer(
#     model=peft_model,
#     args=training_args,
#     train_dataset=tokenized_dataset,
#     tokenizer=tokenizer,
# )
# trainer.train()

print("TrainingArguments ready, with gradient accumulation and bfloat16 mixed precision.")
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"Mixed precision: {'bfloat16' if training_args.bf16 else ('fp16' if training_args.fp16 else 'disabled')}")
print(f"Optimizer: {training_args.optim}")
```

  **Key takeaway**: with gradient_accumulation_steps=8 we get training equivalent to batch_size=32 while only ever holding batch_size=4 in memory.
- Mixed-precision training (Mixed Precision Training):
  Modern GPUs (NVIDIA A100 and H100, or consumer RTX 30- and 40-series cards) support the FP16 and BF16 (bfloat16) floating-point formats. Doing most computation in BF16 while keeping a few critical steps (such as weight updates) in FP32 significantly accelerates training and cuts memory use without losing much precision. The transformers Trainer enables this simply by setting fp16=True or bf16=True. BF16 is usually more stable than FP16 because it has a wider exponent range.
- FlashAttention / xFormers:
  FlashAttention is an efficient attention implementation that significantly accelerates Transformer training and inference by reducing memory I/O. In Hugging Face transformers models you enable it by installing the flash-attn package and passing attn_implementation="flash_attention_2"; xFormers is a separate library offering its own memory-efficient attention kernels. Either gives a large boost on compute-heavy workloads.
```python
# Performance optimization: enabling FlashAttention (conceptual)
# Requires the flash-attn package: pip install flash-attn
# from transformers import AutoModelForCausalLM
# model_flash = AutoModelForCausalLM.from_pretrained(
#     "google/gemma-2b",
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     attn_implementation="flash_attention_2"  # the key argument that enables FlashAttention
# )
print("FlashAttention 2 enabled (after installing flash-attn): further speeds up attention computation.")
```

Key takeaway: FlashAttention 2 can deliver a 2-4x speedup and cut attention memory use by as much as half, with the biggest gains on long sequences.
4.2 Common Pitfalls and Fixes: Dodging the Potholes on the Fine-tuning Road
Fine-tuning an LLM is rarely smooth sailing. Here are some common pitfalls and their remedies:
- Catastrophic forgetting (Catastrophic Forgetting):
  - Pitfall: while fine-tuning on a specific task, the model may forget general knowledge from pre-training, degrading generalization.
  - Fixes:
    - Use a PEFT method such as LoRA: it freezes most of the original weights and updates only a few parameters, effectively preserving general knowledge.
    - Data blending (Data Blending): mix a small amount of general-domain or multi-task data into the fine-tuning set to remind the model not to forget its general knowledge.
    - Knowledge distillation (Knowledge Distillation): use the unmodified original model as a "teacher" and distill its general knowledge into the fine-tuned "student" model.
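Data blending from the list above can be sketched in plain Python (a toy illustration; the 10% general-data ratio and the sampling scheme are arbitrary choices for demonstration, not recommendations from the text):

```python
import random

# Toy data-blending sketch: mix a small fraction of general-domain samples
# into the task-specific fine-tuning set to reduce catastrophic forgetting.
def blend(task_data, general_data, general_ratio=0.1, seed=42):
    rng = random.Random(seed)  # fixed seed for reproducibility
    n_general = int(len(task_data) * general_ratio)
    mixed = list(task_data) + rng.choices(general_data, k=n_general)
    rng.shuffle(mixed)  # interleave so general samples appear throughout training
    return mixed

task = [f"task-{i}" for i in range(100)]
general = [f"general-{i}" for i in range(1000)]
mixed = blend(task, general)
print(len(mixed))  # 110 samples: 100 task + 10 general
```

With Hugging Face datasets the same idea maps onto concatenating or interleaving two Dataset objects before training.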
- Overfitting (Overfitting) and underfitting (Underfitting):
  - Pitfall:
    - Overfitting: the model excels on the training data but performs poorly on unseen data, usually because the training set is too small, the learning rate too high, the epoch count too large, or the model capacity excessive (such as an oversized LoRA r).
    - Underfitting: the model performs poorly on both training and test data, likely because the training set is too small, the model capacity insufficient (such as an undersized LoRA r), the learning rate too low, or the epoch count too small.
  - Fixes:
    - Data augmentation: enlarge the training set and increase its diversity.
    - Regularization: raise lora_dropout in the LoRA config, or use weight decay (weight_decay).
    - Early stopping (Early Stopping): monitor performance on a validation set and stop training when it stops improving.
    - Hyperparameter tuning: carefully adjust the learning rate, LoRA's r, lora_alpha, and related parameters.
Common-pitfall code example: the early-stopping strategy (conceptual)

```python
from transformers import EarlyStoppingCallback, TrainingArguments, Trainer

training_args_with_early_stopping = TrainingArguments(
    output_dir="./fine_tuned_model_early_stop",
    evaluation_strategy="steps",        # evaluate every fixed number of steps
    eval_steps=50,                      # evaluation interval
    load_best_model_at_end=True,        # reload the best validation checkpoint when training ends
    metric_for_best_model="eval_loss",  # the metric to monitor
    greater_is_better=False,            # lower loss is better
    # ... other arguments ...
)

trainer_with_early_stopping = Trainer(
    model=peft_model,
    args=training_args_with_early_stopping,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_validation_dataset,  # a validation set is required
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],  # stop after 3 evaluations without improvement
    tokenizer=tokenizer,
)
trainer_with_early_stopping.train()
print("\nEarly stopping is an effective guard against overfitting: validation performance decides when to stop.")
```
- Data quality problems:
  - Pitfall: the training data contains noise, errors, inconsistent formatting, bias, or irrelevant information.
  - Fixes:
    - Rigorous data cleaning: remove duplicated, truncated, and malformed data.
    - Human review and annotation: for critical data, invest in high-quality review and labeling.
    - Debiasing: identify and try to reduce bias in the data, ensuring diversity and representativeness.
  - Code example (data cleaning):

```python
import re

# Recommended: a defensive data-cleaning pass, written as a map + filter pair.
# (Hugging Face `datasets` map functions must return a dict, so invalid samples
# are flagged here with a "keep" column and dropped by the filter afterwards.)
def robust_data_cleaner(example):
    text = example["text"]
    # 1. Collapse redundant whitespace
    text = ' '.join(text.split())
    # 2. Normalize punctuation (e.g. full-width to half-width)
    text = text.replace('，', ',').replace('。', '.')
    # 3. Optional: strip specific patterns such as URLs with a regular expression
    text = re.sub(r'http[s]?://\S+', '', text)
    example["text"] = text
    # 4. Flag samples with obvious mojibake or HTML remnants, or invalid lengths
    example["keep"] = (
        '\ufffd' not in text and '<html' not in text  # replacement chars / web debris
        and 30 <= len(text) <= 3000                   # drop extremely short or long samples
    )
    return example

cleaned_dataset_robust = (
    formatted_dataset.map(robust_data_cleaner)
    .filter(lambda x: x["keep"])
    .remove_columns(["keep"])
)
print(f"\nSamples after robust cleaning: {len(cleaned_dataset_robust)}")

# Not recommended: skipping cleaning entirely and using raw data invites noise.
bad_data = ["This is a test.", "This is a test.", "garbled data", "image."]
raw_dataset = Dataset.from_dict({"text": bad_data})
print(f"\nUncleaned data sample: {raw_dataset['text']}")
```

  **Key takeaway**: data cleaning is an iterative process that combines domain knowledge with automated tooling.
- Hyperparameter tuning (Hyperparameter Tuning):
  - Pitfall: blindly picking the learning rate, LoRA rank (r), lora_alpha, and so on leads to unstable training or poor results.
  - Fixes:
    - Grid search / random search: explore the key hyperparameters systematically.
    - Bayesian optimization: a more efficient hyperparameter search method (e.g. with Optuna or Ray Tune).
    - Learning-rate schedulers (Learning Rate Scheduler): schedules such as cosine annealing (Cosine Annealing) adjust the learning rate dynamically and help the model converge better.
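The cosine-annealing schedule mentioned in the last bullet can be written out directly (a self-contained sketch of plain cosine decay from a peak learning rate to zero; real trainers usually add a warmup phase in front):

```python
import math

# Cosine-annealed learning rate: starts at lr_max and decays smoothly to 0.
def cosine_lr(step: int, total_steps: int, lr_max: float) -> float:
    progress = step / total_steps               # fraction of training completed
    return 0.5 * lr_max * (1 + math.cos(math.pi * progress))

total = 1000
for step in (0, 250, 500, 1000):
    print(step, f"{cosine_lr(step, total, 2e-4):.2e}")
# the rate falls from 2.00e-04 at step 0 to 1.00e-04 at the midpoint and 0 at the end
```

The smooth tail is the point: large steps early for fast progress, tiny steps late so the model settles into a minimum instead of bouncing around it.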
4.3 Deployment and Application After Fine-tuning: Making the Model Land
Once fine-tuning is done, getting the model into production is the crucial next step.
- Merging LoRA weights:
  After training, LoRA adapters can be merged into the weights of the original pre-trained model. That simplifies deployment: you load one complete model rather than a base model plus a separate adapter.

```python
# Complete project code (simplified): saving and merging a LoRA model
# from peft import PeftModel
# from transformers import AutoModelForCausalLM, AutoTokenizer

# peft_model = ...   # the trained PEFT model
# tokenizer = ...    # the matching tokenizer

# Save the LoRA adapter
# peft_model.save_pretrained("./my_lora_adapter")
# tokenizer.save_pretrained("./my_lora_adapter")
print("LoRA adapter saved to ./my_lora_adapter.")

# Load the base model
# base_model = AutoModelForCausalLM.from_pretrained(
#     "google/gemma-2b",
#     torch_dtype=torch.bfloat16,
# )
#
# # Merge the LoRA adapter into the base model
# merged_model = PeftModel.from_pretrained(base_model, "./my_lora_adapter")
# merged_model = merged_model.merge_and_unload()  # merge and strip the PEFT wrappers
#
# # Save the merged model
# merged_model.save_pretrained("./my_merged_model")
# tokenizer.save_pretrained("./my_merged_model")
print("A LoRA adapter can be merged with its base model into one complete fine-tuned model, ready to deploy.")
```
- Model serving (Model Serving):
  In production, running inference directly through Hugging Face's pipeline may not be efficient enough. Consider a dedicated LLM serving framework:
  - **vLLM**: a high-performance LLM inference and serving engine supporting continuous batching and the PagedAttention algorithm, which markedly raises throughput and lowers latency.
  - **Text Generation Inference (TGI)**: Hugging Face's production-grade inference solution, with optimizations such as FlashAttention, quantization, and continuous batching.

  These tools help you serve many users efficiently and reliably on limited hardware.
Chapter 5: Wrap-up and Outlook: Start Your Own LLM Journey
Congratulations! Through this article we have dug into what LLM fine-tuning is, the core principles of parameter-efficient fine-tuning (PEFT), the art of data preparation, and the practical tricks for optimizing performance and avoiding pitfalls. You now hold the key skills for turning a generic LLM into a dedicated intelligent assistant.
5.1 Key Takeaways
- **Why fine-tune**: generic LLMs underperform in specific domains and tasks; fine-tuning is the key to improving their specialization and behavior patterns.
- The pains of traditional fine-tuning: heavy compute, long training times, and catastrophic forgetting are its main bottlenecks.
- PEFT (parameter-efficient fine-tuning): a do-more-with-less strategy that fine-tunes efficiently by adding a few trainable parameters or modifying input prompts.
- LoRA: injects low-rank matrices to update the model, drastically cutting trainable parameters while balancing quality and efficiency.
- QLoRA: adds 4-bit quantization on top of LoRA, pushing memory use to a minimum so even consumer GPUs can fine-tune large LLMs.
- Prompt Tuning / Prefix Tuning: freeze the model body and train only soft prompts or prefixes, steering behavior with minimal parameters.
- Data is king: high-quality, well-formatted instruction-tuning data is the bedrock of success; data cleaning and templating are critical.
- Performance optimization: gradient accumulation, mixed-precision training, FlashAttention, and similar techniques further raise training efficiency.
- Common pitfalls: catastrophic forgetting, overfitting, data-quality problems, and hyperparameter tuning are the main challenges during fine-tuning.
- Deployment: trained LoRA weights can be merged, then served efficiently with tools such as vLLM or TGI.
5.2 Practical Advice
- Start small: begin with a smaller model (such as Gemma-2B or Llama-3-8B) and a modest dataset; iterate fast and validate the results early.
- Pick the right PEFT method:
  - Tight resources and maximal memory savings: QLoRA is the first choice.
  - **A balance of performance and resources**: LoRA is the versatile, powerful default.
  - A minimal parameter budget, or only lightweight behavior steering: Prompt Tuning / Prefix Tuning may fit best.
- Invest in data preparation: spend enough time collecting, cleaning, and formatting data. Data quality matters more than data quantity.
- Tune hyperparameters carefully: the learning rate and LoRA's r and lora_alpha strongly affect model performance; experiment widely.
- Monitor and evaluate: track the training loss and validation metrics throughout training, and catch overfitting or underfitting early.
- Lean on community resources: the Hugging Face ecosystem (transformers, peft, datasets) provides a wealth of tools and pre-trained models; it is your treasure chest.
5.3 Advanced Directions and Future Trends
LLM fine-tuning is still evolving quickly. Some advanced directions worth watching:
- Preference alignment (Preference Alignment): methods such as DPO (Direct Preference Optimization) and RLHF (Reinforcement Learning from Human Feedback) further refine model behavior with human feedback so it better matches human values and preferences.
- Retrieval-augmented generation (RAG, Retrieval Augmented Generation): pairing the LLM with an external knowledge base so it can retrieve fresh, accurate information at generation time, mitigating LLM "hallucinations".
- Multimodal LLM fine-tuning: extending LLMs to images, audio, and other modalities, for example fine-tuning vision-language models (VLMs) to understand and generate image descriptions.
- Agentic Workflows: fine-tuning LLMs to act as agents that plan autonomously, call tools, and execute complex tasks.
- Ever-newer PEFT methods: researchers keep proposing new PEFT techniques, including LoRA variants (such as S-LoRA and DoRA) and more general adapter frameworks.
LLM fine-tuning is a field full of challenges and opportunities. With this guide, you now have the knowledge and skills to start your own LLM journey. It is time to put the theory into practice and let your LLM shine on the tasks that matter to you. Happy fine-tuning, and may your road in AI keep widening!