Introduction: Breaking Through the Boundaries of Text, Embracing Multimodal Intelligence
Hi, fellow technology enthusiasts! As large language models (LLMs) revolutionize natural language processing, we have witnessed their astonishing abilities in text generation, comprehension, and reasoning. The world, however, is not made of text alone. Our lives are filled with rich multimodal information: images, sound, video. Imagine an LLM that can only "read" text but cannot "see" the objects in a picture or "hear" the emotion in a voice; how limited its intelligent services would be!
This is today's topic: why LLM multimodal fusion matters. A text-only LLM is like a learned but blind sage, eloquent yet unable to directly perceive the rich details of the real world. For example, when you ask it to produce a detailed description of an image, it might "answer" like this:
# Pseudocode: a text-only LLM confronted with an image
class TextOnlyLLM:
    def generate_description(self, prompt: str, image_path: str = None) -> str:
        if image_path:
            # A text-only LLM cannot process an image file directly
            return (f"Sorry, I cannot understand image files directly. "
                    f"Please describe the image in words and I will do my best. "
                    f"You said: '{prompt}'")
        else:
            # It can only handle text prompts
            return self._text_generation_logic(prompt)

# Scenario: the user uploads a cat photo and asks "What is this?"
llm = TextOnlyLLM()
image_file = "path/to/cat_image.jpg"
response = llm.generate_description(prompt="Please describe this picture", image_path=image_file)
print(response)
# Output: Sorry, I cannot understand image files directly. Please describe the image
# in words and I will do my best. You said: 'Please describe this picture'
This is clearly not the intelligence we expect. For LLMs to truly "perceive" the world and to understand and generate information across modalities, multimodal fusion has become an indispensable technology. Its goal is to integrate data from different modalities (text, images, audio) so that the model understands its input more fully and deeply, achieving a higher level of intelligence.
Today we will dig into the various approaches to LLM multimodal fusion, from basic principles to frontier models, and on to practical challenges and optimization techniques. Let's give large models the wings of "perception"!
Chapter 1: Why Is Multimodal Fusion Imperative?
1.1 The "Blind Spots" and "Deaf Spots" of LLMs: the Limits of Text
Large language models excel at natural language text: they understand complex semantics, generate coherent articles, and perform logical reasoning, thanks to the language patterns and world knowledge learned from massive text corpora. Yet their knowledge boundary also stops at text. They cannot directly understand the spatial relations among objects in an image, cannot perceive the emotion of a speaker in an audio clip, and cannot follow the dynamic development of events in a video. We call this ignorance of non-textual information the LLM's "blind spots" and "deaf spots".
The limitation of a text-only LLM is its lack of grounding in the physical world. Text is a highly abstract symbolic representation: it describes the world rather than being the world. A text-only LLM may know that an "apple" is a fruit, but it cannot tell a red apple from a green one in a photo, nor recognize that a photo shows the Apple logo. It does not even know that the "cat" in an image and the "cat" in text are the same concept, because the two are entirely different representations at the raw-data level, and the model cannot naturally build this cross-modal semantic association. This "blindness" limits LLMs in scenarios that require real-world perception and interaction, preventing them from understanding and responding to the environment through multiple sensory inputs the way humans do.
# A text-only LLM handles text but cannot connect it to images
def analyze_text_llm(text_input: str) -> str:
    if "cat" in text_input:
        return "The text mentions a cat, a common pet."
    return "Text processing result."

# A function that processes images (but the LLM cannot call it or understand its output)
def process_image_vision_model(image_data: bytes) -> str:
    # Assume this is a standalone vision model that can recognize image content.
    # In reality it would return classification, detection results, or a caption.
    if b"cat_features" in image_data:  # simulate detecting cat features
        return "The image contains a cat."
    return "Image recognition result."

# Scenario: the user shows a cat picture and asks "What is this?"
user_image_data = b"cat_features"  # simulated cat image bytes carrying the mock marker
user_text_question = "What is this?"

# Response of the text-only LLM
llm_response = analyze_text_llm(user_text_question)
print(f"LLM (text-only) response: {llm_response}")
# Output: LLM (text-only) response: Text processing result. (the question never mentions "cat")

# Response of the vision model (disconnected from the LLM)
vision_response = process_image_vision_model(user_image_data)
print(f"Vision model response: {vision_response}")
# Output: Vision model response: The image contains a cat.
# The two models cannot cooperate, so overall intelligence falls short:
# the LLM cannot use the vision model's result to answer the question.
1.2 夿¨¡ææºè½ç巨大价å¼
å°LLMä¸å¤æ¨¡ææç¥è½åç»åï¼è½å¤ unlock åææªæçæºè½åºç¨åºæ¯ï¼å¸¦æ¥å·¨å¤§çåä¸å社ä¼ä»·å¼ãå®ä½¿å¾AIè½å¤æ´å ¨é¢ãæ´æ·±å ¥å°çè§£ç°å®ä¸çï¼ä»èæä¾æ´èªç¶ãæ´å¼ºå¤§ç交äºåæå¡ã
- æºè½å®¢æä¸èæå©æï¼ ä¸åä» ä» ä¾èµæåï¼å¯ä»¥ç´æ¥çè§£ç¨æ·ä¸ä¼ çæªå¾ãè¯é³æ¶æ¯ï¼æä¾æ´åç¡®ãæ´äººæ§åçæå¡ãä¾å¦ï¼ç¨æ·æä¸çµå¨æ éç §çï¼AIå¯ä»¥ç´æ¥è¯æå¹¶æä¾ç»´ä¿®å»ºè®®ï¼éè¿è¯é³è¯å«ç¨æ·é®é¢ï¼å¹¶ç»åå±å¹å 容è¿è¡æä½æå¯¼ãè¿æå¤§å°æåäºç¨æ·ä½éªåé®é¢è§£å³æçã
- èªå¨é©¾é©¶ï¼ ç»åè§è§ï¼æå头ï¼ãé·è¾¾ãæ¿å é·è¾¾æ°æ®ä¸è¯è¨çè§£ï¼å®ç°æ´ç²¾åçç¯å¢æç¥ãå³çä¸äººè½¦äº¤äºã夿¨¡æLLMå¯ä»¥çè§£âåæ¹çº¢ç¯å³è½¬âè¿æ ·çæä»¤ï¼å¹¶ç»åä¼ æå¨æ°æ®å¤æè·¯åµï¼è§åå®å ¨è·¯å¾ï¼çè³å¨ç´§æ¥æ åµä¸ä¸ä¹å®¢è¿è¡æ²éï¼è§£éå½åæ åµã
- å»çè¯æï¼ ç»åå»å¦å½±åï¼Xå ãCTãMRIï¼ãç çæ¥åãçµåç åææ¬ï¼è¾ å©å»çè¿è¡æ´å ¨é¢çè¯æåæ²»çæ¹æ¡å»ºè®®ãAIå¯ä»¥è¯å«å½±åä¸çå¼å¸¸ï¼å ³èç å²ä¸çå ³é®ä¿¡æ¯ï¼å¹¶ç¨èªç¶è¯è¨è§£éè¯æç»æåæ¨èæ²»çæ¹æ¡ï¼æä¸ºå»ççå¾å婿ã
- æºè½é¶å®ä¸å·¥ä¸æ£æµï¼ éè¿åæé¡¾å®¢å¨åºå çè¡ä¸ºè§é¢ãè¯é³ãæå交æµï¼æä¾ä¸ªæ§åæ¨èåè´ç©ä½éªãå¨å·¥ä¸é¢åï¼ç»åå¾åæ£æµäº§å缺é·ã声é³è¯å«è®¾å¤å¼å¸¸ï¼å¹¶ç¨è¯è¨æ¥åé®é¢åæä¾è§£å³æ¹æ¡ï¼å®ç°æºè½è´¨æ£å颿µæ§ç»´æ¤ã
- å 容åä½ä¸æè²ï¼ æ ¹æ®ç¨æ·æä¾çå¾ççææ äºãè¯æãçµå½±å§æ¬ï¼æè æ ¹æ®æåæè¿°çæå¾åã卿è²é¢åï¼å¤æ¨¡æLLMå¯ä»¥çè§£å¦çæäº¤çæåä½ä¸å¾çï¼å¹¶æä¾ä¸ªæ§åçæ¹æ¹å讲解ã
å¯ä»¥è¯´ï¼å¤æ¨¡æè忝æå»ºçæ£éç¨äººå·¥æºè½ï¼AGIï¼çå¿ ç±ä¹è·¯ãå®è®©AIè½å¤åäººç±»ä¸æ ·ï¼éè¿å¤ç§æå®è¾å ¥æ¥ç解并交äºä¸çï¼ä»èè§£éæ éå¯è½ã
# Pseudocode: how an ideal multimodal LLM would handle the scenarios above
class MultimodalLLM:
    def understand_and_respond(self, prompt: str, image_data: bytes = None) -> str:
        # Core fusion step: merge the modalities into one unified context representation
        multimodal_context = self._fuse_modalities(prompt, image_data)
        # Understand and generate based on the fused context
        if "cat" in multimodal_context and "The image contains a cat" in multimodal_context:
            return ("I can see a cute cat in the image; it may be playing. "
                    "What would you like to know about it?")
        elif image_data and ("fault" in prompt or "not working" in prompt):
            # Simulate diagnosing the fault shown in the picture together with the text question
            if "wire" in multimodal_context and "worn" in multimodal_context:
                return ("The picture shows a worn wire, which is likely the cause of the fault. "
                        "Please inspect and replace the damaged wire immediately.")
            return f"I understood the image and your text: '{prompt}'. How can I help?"
        else:
            # Without an image or a specific multimodal interaction, fall back to text-only generation
            return self._text_generation_logic(prompt)

    def _fuse_modalities(self, text: str, image: bytes = None) -> str:
        # A simplified fusion logic; real systems involve complex feature extraction and alignment
        fused_info = f"Text info: {text}"
        if image:
            # Assume a vision encoder is called here and its output is converted
            # into a textual description or feature tokens
            vision_features_as_text = self._vision_encoder_to_text(image)
            fused_info += f"\nVisual info: {vision_features_as_text}"
        return fused_info

    def _vision_encoder_to_text(self, image_data: bytes) -> str:
        # Simulate a vision encoder turning the image into text or tokens the LLM can understand,
        # e.g. "the image contains an object, feature vector: [0.1, 0.5, ...]" or a direct caption
        if b"cat_features" in image_data:
            return "The image contains a cat; the background is a living room."
        elif b"broken_wire" in image_data:
            return "The image shows a visibly worn wire that may cause a short circuit."
        return "Image recognition result."

# Scenario 1: the user shows a cat picture and asks "What is this?"
multimodal_llm = MultimodalLLM()
user_image_data_cat = b"cat_features"  # simulated cat image bytes
user_text_question_cat = "What is this?"
response_cat = multimodal_llm.understand_and_respond(user_text_question_cat, user_image_data_cat)
print(f"Multimodal LLM response (cat): {response_cat}")
# Expected output: Multimodal LLM response (cat): I can see a cute cat in the image; it may be
# playing. What would you like to know about it?

# Scenario 2: the user uploads a picture of a faulty appliance and asks why
user_image_data_fault = b"broken_wire"  # simulated appliance fault image
user_text_question_fault = "My appliance is not working. What is going on?"
response_fault = multimodal_llm.understand_and_respond(user_text_question_fault, user_image_data_fault)
print(f"Multimodal LLM response (fault): {response_fault}")
# Expected output: Multimodal LLM response (fault): The picture shows a worn wire, which is likely
# the cause of the fault. Please inspect and replace the damaged wire immediately.
Chapter 2: Core Strategies and Paradigms of Multimodal Fusion
Multimodal fusion is not a single technique but a family of strategies and paradigms that integrate multimodal information at different stages and in different ways. We can roughly divide them into early fusion, late fusion, and hybrid/cross-modal fusion. Understanding these paradigms helps us choose the approach best suited to a specific task and its resource constraints.
2.1 Early Fusion: Intimate Contact at the Data Level
Concept: early fusion happens at the input stage, merging the modalities before, or very early in, their respective encoders. The most common approach is to directly concatenate the raw or low-level features of different modalities and feed the combined feature into a shared model (such as one large neural network) for learning. The advantage is that the model can capture the finest-grained cross-modal interactions; in theory it can discover more complex, implicit association patterns because fusion happens with minimal information loss. The drawbacks are equally obvious: the modalities must be highly aligned in time or semantics; the concatenated feature dimension can be so high that computation explodes; and noise propagates easily, so a low-quality modality can contaminate the information of the others.
Application scenarios: early fusion is common in tasks that need tight spatio-temporal or semantic coupling between modalities. For example, in emotion recognition, pixel features of facial expressions and spectrogram features of speech are concatenated and fed into one deep model to capture the synchronized subtle changes of expression and intonation; in video event detection, pixel information of video frames is combined with audio waveform data.
# Early fusion example: simple concatenation of image and text features
import torch
import torch.nn as nn

# Assume we have separate image and text encoders
class ImageEncoder(nn.Module):
    def __init__(self, output_dim=128):
        super().__init__()
        # Simplified: in reality this would be a CNN or Vision Transformer
        self.output_dim = output_dim
        self.linear = nn.Linear(224 * 224 * 3, output_dim)  # pretend to process raw pixels

    def forward(self, image_input):  # image_input: raw pixels or a preprocessed tensor
        # A real encoder would run many conv/Transformer layers
        return self.linear(image_input.view(image_input.size(0), -1))  # flatten the image

class TextEncoder(nn.Module):
    def __init__(self, output_dim=128):
        super().__init__()
        # Simplified: in reality this would be a BERT or Transformer encoder
        self.output_dim = output_dim
        self.embedding = nn.Embedding(1000, output_dim)                # mock word embedding
        self.lstm = nn.LSTM(output_dim, output_dim, batch_first=True)  # mock sequence model

    def forward(self, text_input_ids):  # text_input_ids: token IDs
        # Simplified: return the final LSTM hidden state as the sequence representation
        embedded = self.embedding(text_input_ids)      # [batch, seq_len, dim]
        _, (hidden, _) = self.lstm(embedded)
        return hidden.squeeze(0)                        # [batch, dim]

# Instantiate the encoders
image_encoder = ImageEncoder(output_dim=256)
text_encoder = TextEncoder(output_dim=256)

# Mock input data
image_data = torch.randn(1, 224, 224, 3)         # one image (Batch=1)
text_data_ids = torch.randint(0, 1000, (1, 5))   # token IDs of a short text (Batch=1, SeqLen=5)

# 1. Encode each modality separately
image_features = image_encoder(image_data)
text_features = text_encoder(text_data_ids)
print(f"Image feature shape: {image_features.shape}")  # e.g. torch.Size([1, 256])
print(f"Text feature shape: {text_features.shape}")    # e.g. torch.Size([1, 256])

# 2. Early fusion: feature concatenation
fused_features_early = torch.cat((image_features, text_features), dim=-1)
print(f"Early-fused feature shape: {fused_features_early.shape}")  # e.g. torch.Size([1, 512])

# 3. Feed the concatenated feature into a downstream LLM or task head
class UnifiedFusionModel(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.classifier = nn.Linear(input_dim, num_classes)  # e.g. a classification head

    def forward(self, fused_features):
        return self.classifier(fused_features)

fusion_model = UnifiedFusionModel(fused_features_early.shape[-1], 2)  # e.g. binary classification
output = fusion_model(fused_features_early)
print(f"Fusion model output (e.g. class logits): {output}")

# Notes:
# Pros: in theory captures the finest-grained cross-modal interactions and gives the downstream
#       task the most complete raw information.
# Cons: requires tight temporal/semantic alignment; the feature dimension may become very high;
#       noise from one modality easily contaminates the others.
# In practice, because raw data is high-dimensional, some feature extraction and dimensionality
# reduction are usually applied before concatenation.
2.2 Late Fusion: Smart Division of Labor at the Decision Level
Concept: late fusion happens after each modality has been processed independently by its own model and has produced a preliminary result (class probabilities, detection boxes, feature vectors). These preliminary results are then combined at the decision level, for example by voting, weighted averaging, stacking, or a more sophisticated meta-learner. The advantages are a simple overall structure, unimodal models that can be optimized and trained independently, flexible deployment, and high tolerance to asynchrony between modalities (if one modality is missing, the others still work). Because each unimodal model can be tuned for its own data type, late fusion also tends to be robust. The drawback is that it cannot capture deep cross-modal interactions and may miss important associations, because fusion happens at a high level of abstraction where fine-grained cross-modal dependencies have already been "digested" or discarded by the individual models.
Application scenarios: late fusion suits tasks where independent per-modality judgments can be combined into a final decision. For example, in medical diagnosis, independent reports or model predictions from different departments (radiology, pathology) are combined into a comprehensive diagnosis; in video activity recognition, actions recognized by an image model and sound events recognized by an audio model are combined by a fusion layer to decide the final activity.
# Late fusion example: fusing the outputs of independent models
import torch
import torch.nn as nn

# Assume we have an independent image classifier and a text sentiment analyzer
class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # Assume it can tell whether the image looks "happy"
        self.linear = nn.Linear(256, 2)  # e.g. outputs [not_happy, happy]

    def forward(self, image_features):
        return torch.softmax(self.linear(image_features), dim=-1)

class TextSentimentAnalyzer(nn.Module):
    def __init__(self):
        super().__init__()
        # Assume it can tell whether the text sounds "happy"
        self.linear = nn.Linear(256, 2)  # e.g. outputs [not_happy, happy]

    def forward(self, text_features):
        return torch.softmax(self.linear(text_features), dim=-1)

# Mock feature extraction (stand-ins for the earlier encoders)
class DummyImageEncoder(nn.Module):  # simplified: returns random features
    def __init__(self, output_dim):
        super().__init__()
        self.output_dim = output_dim

    def forward(self, image_input):
        return torch.randn(1, self.output_dim)

class DummyTextEncoder(nn.Module):  # simplified: returns random features
    def __init__(self, output_dim):
        super().__init__()
        self.output_dim = output_dim

    def forward(self, text_input):
        return torch.randn(1, self.output_dim)

image_encoder_late = DummyImageEncoder(output_dim=256)
text_encoder_late = DummyTextEncoder(output_dim=256)
image_features_late = image_encoder_late(None)  # mock image input
text_features_late = text_encoder_late(None)    # mock text input

# Instantiate the independent classifiers
image_classifier = ImageClassifier()
text_sentiment_analyzer = TextSentimentAnalyzer()

# 1. Each model predicts independently
image_pred_probs = image_classifier(image_features_late)
text_pred_probs = text_sentiment_analyzer(text_features_late)
print(f"Image model probabilities (e.g. [not happy, happy]): {image_pred_probs}")
print(f"Text model probabilities (e.g. [not happy, happy]): {text_pred_probs}")

# 2. Late fusion: weighted averaging of the results (or voting)
# Assume the image and text predictions carry equal weight
final_pred_probs = (image_pred_probs + text_pred_probs) / 2

# Final decision
predicted_class_id = torch.argmax(final_pred_probs).item()
classes = ["not happy", "happy"]
print(f"Late-fusion prediction: {classes[predicted_class_id]} "
      f"(probability: {final_pred_probs[0][predicted_class_id]:.4f})")

# Notes:
# Pros: modular, easy to debug and extend; unimodal models can be optimized independently;
#       tolerant to asynchrony or missing modalities.
# Cons: cannot capture deep cross-modal interactions and dependencies; may miss associations
#       that only show up at lower levels.
# Best suited to tasks that decompose cleanly into independent sub-tasks.
2.3 Hybrid/Cross-modal Fusion: Deep Interaction and Learning
Concept: hybrid/cross-modal fusion combines traits of early and late fusion, or rather emphasizes deep interaction learning between modalities. Instead of settling for simple feature concatenation or decision-level ensembling, it introduces mechanisms in the middle layers of the model that allow features of different modalities to exchange and align information dynamically and selectively. The most advanced approaches usually rely on cross-modal attention or a modality bridge. Cross-modal attention lets one modality (say, text) act as the query and attend to specific parts of another modality (say, an image) serving as keys and values, achieving selective information extraction and integration. A modality bridge converts the features of one modality into a form the other can understand, closing the semantic gap between them. This approach aims to capture complex cross-modal associations while preserving a degree of modularity and flexibility; it avoids the dimensionality blow-up of early fusion and makes up for the shallow interaction of late fusion.
Application scenarios: hybrid/cross-modal fusion is the mainstream of current multimodal-LLM research and is widely used in visual question answering (VQA), image captioning, multimodal dialogue systems, and any scenario requiring complex cross-modal reasoning. In VQA the model must locate the image regions relevant to a textual question and reason over them; in image captioning it must generate a coherent textual description from the image content.
# æ··å/跨模æèå示ä¾ï¼ç®åçè·¨æ¨¡ææ³¨æåå±ä¼ªä»£ç
import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossModalAttention(nn.Module):
def __init__(self, query_dim, key_dim, value_dim, output_dim, num_heads=4):
super().__init__()
self.num_heads = num_heads
self.head_dim = value_dim // num_heads
assert self.head_dim * num_heads == value_dim, "value_dim must be divisible by num_heads"
# 线æ§å±ç¨äºå°è¾å
¥æå½±å°Q, K, V空é´
self.wq = nn.Linear(query_dim, value_dim)
self.wk = nn.Linear(key_dim, value_dim)
self.wv = nn.Linear(value_dim, value_dim) # è¿éç®åï¼é常kvæ¯æ¥èªå䏿ºï¼key_dim=value_dim
self.fc_out = nn.Linear(value_dim, output_dim)
def forward(self, query_features, key_features, value_features):
# query_features (ä¾å¦ï¼ææ¬åµå
¥): [batch_size, query_seq_len, query_dim]
# key_features (ä¾å¦ï¼å¾ååºåç¹å¾): [batch_size, key_seq_len, key_dim]
# value_features (ä¾å¦ï¼å¾ååºåç¹å¾): [batch_size, value_seq_len, value_dim]
batch_size = query_features.shape[0]
# å°Q, K, Væå½±å¹¶åå²ä¸ºå¤å¤´
Q = self.wq(query_features).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
K = self.wk(key_features).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
V = self.wv(value_features).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
# è®¡ç®æ³¨æååæ° (Q @ K^T)
energy = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
attention_weights = F.softmax(energy, dim=-1)
# å ææ±å (Attention @ V)
x = torch.matmul(attention_weights, V)
# æ¼æ¥å¤å¤´ç»æå¹¶éå
¥è¾åºçº¿æ§å±
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
return self.fc_out(x)
# 模æè¾å
¥ï¼ææ¬ä½ä¸ºQueryï¼å¾åä½ä¸ºKeyåValue
# åè®¾ææ¬å·²ç»ç¼ç æåºåç¹å¾ (ä¾å¦ï¼æ¯ä¸ªtokençembedding)
text_query_embeddings = torch.randn(1, 10, 768) # Batch=1, 10个token, dim=768
# å设å¾ååºåç¹å¾ (ä¾å¦ï¼ä»Vision Transformerçpatch embeddings)
image_key_value_features = torch.randn(1, 49, 768) # Batch=1, 49个patch, dim=768
# å®ä¾åè·¨æ¨¡ææ³¨æåå±
cross_attention_layer = CrossModalAttention(
query_dim=768, key_dim=768, value_dim=768, output_dim=768 # è¾åºç»´åº¦ä¸LLMè¾å
¥ç»´åº¦å¹é
)
# æ§è¡è·¨æ¨¡ææ³¨æåï¼è®©ææ¬æ¥è¯¢å¾åä¿¡æ¯
fused_output_from_attention = cross_attention_layer(
query_features=text_query_embeddings,
key_features=image_key_value_features,
value_features=image_key_value_features
)
print(f"ææ¬ä¸å¾åè·¨æ¨¡ææ³¨æåèååçè¾åºç»´åº¦: {fused_output_from_attention.shape}") # ä¾å¦: torch.Size([1, 10, 768])
# è¿ä¸ªè¾åºå¯ä»¥ä½ä¸ºLLMçè¾å
¥ï¼LLMç°å¨è½å¤âæç¥âå¾åç¸å
³çä¿¡æ¯äºã
# æ¨èåæ³ï¼
# ä¼ç¹ï¼è½å¤å®ç°æ¨¡æé´çæ·±å±ãå¨æäº¤äºï¼æè·å¤æçè¯ä¹å
³èã
# 缺ç¹ï¼æ¨¡åå¤æåº¦é«ï¼è®¡ç®é大ï¼å¯¹è®ç»æ°æ®å对é½è¦æ±ä¹æ´é«ã
# 宿¯ç®åæå»ºå¼ºå¤§å¤æ¨¡æLLMçä¸»æµæ¹æ³ï¼éè¿ç²¾å·§çè®¾è®¡å¹³è¡¡äºæ§è½ä¸æçã
Chapter 3: A Deep Dive into Mainstream LLM Multimodal Fusion Approaches
In recent years, many new multimodal LLMs have emerged; through clever design they give large language models the ability to caption images, answer visual questions, and perform even higher-level perception. Let's look at several milestone approaches, each representing a different fusion philosophy and technical route.
3.1 CLIP: From Contrastive Learning to Cross-modal Alignment
Concept: CLIP (Contrastive Language–Image Pre-training) is an important model released by OpenAI in 2021. It does not perform multimodal generation or question answering directly; instead, through contrastive learning on a huge dataset of image-text pairs, it learns joint embeddings of images and texts in one shared semantic space. The core idea is to pull the embeddings of matching image-text pairs closer while pushing the embeddings of mismatched pairs apart. In this way CLIP learns to associate images with the texts that describe them, enabling powerful cross-modal retrieval and zero-shot classification. It laid an essential modality-alignment foundation for later generative multimodal models.
How it works: CLIP contains two separate encoders, an image encoder (usually a Vision Transformer or ResNet) and a text encoder (a Transformer). During training, the model receives N images and N text descriptions, forming N positive pairs and N*(N-1) negative pairs. The training objective maximizes the similarity of the positive pairs on the diagonal while minimizing the similarity of the negative pairs elsewhere. This global contrastive learning gives CLIP's embedding space strong semantic expressiveness, aligning visual concepts with linguistic concepts.
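To make the training objective above concrete, here is a minimal sketch of the symmetric contrastive (InfoNCE-style) loss used to train CLIP-like models. The embedding shapes, batch size, and temperature are assumptions for illustration; the CLIP example later in this section only demonstrates inference.

import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    # Normalize both sets of embeddings so the dot product is a cosine similarity
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    # [N, N] similarity matrix: entry (i, j) compares image i with text j
    logits = image_embeds @ text_embeds.T / temperature
    # The matching pairs sit on the diagonal
    targets = torch.arange(logits.size(0), device=logits.device)
    # Symmetric cross-entropy: image-to-text and text-to-image directions
    loss_i2t = F.cross_entropy(logits, targets)
    loss_t2i = F.cross_entropy(logits.T, targets)
    return (loss_i2t + loss_t2i) / 2

# Toy batch of N=8 pre-computed embeddings (assumed 512-dim)
loss = clip_contrastive_loss(torch.randn(8, 512), torch.randn(8, 512))
print(f"Contrastive loss on a random batch: {loss.item():.4f}")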
Application scenarios: CLIP itself is not an LLM, but the alignment it learns is key to giving LLMs multimodal perception. It can serve as:
* Image retrieval: find the most relevant images from a text description.
* Zero-shot image classification: classify an image by the similarity between its embedding and the text embeddings of the class names, without any extra training.
* Guidance for text-to-image generation: models such as DALL-E 2 and Stable Diffusion use CLIP's image-text alignment to assess how well a generated image matches its text description.
* The vision encoder of a multimodal LLM: many multimodal LLMs (such as LLaVA) directly use a frozen CLIP vision encoder to extract image features.
# Example: image-text matching with CLIP via Hugging Face Transformers
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests
import torch

# 1. Load the pre-trained CLIP model and processor
# Downloading the model may take a while; make sure you have network access
print("Loading CLIP model and processor...")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print("CLIP model loaded.")

# 2. Prepare the image and text inputs
# Load an image from the web
url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # cats on a couch
print(f"Downloading image: {url}")
image = Image.open(requests.get(url, stream=True).raw)
texts = ["a cat lying on a couch", "a dog running on the grass",
         "a group of people walking down the street", "a parrot sleeping"]

# 3. Preprocess the inputs
# The processor resizes and normalizes the image and tokenizes the texts
inputs = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)

# 4. Run the model to obtain image and text embeddings
with torch.no_grad():
    outputs = model(**inputs)

# 5. Extract the image and text features
image_features = outputs.image_embeds  # global image feature
text_features = outputs.text_embeds    # one feature per text description

# 6. Compute image-text similarity (cosine similarity)
# Normalize the feature vectors to unit L2 norm so a dot product equals cosine similarity
image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
# similarity[i][j] is the similarity between image i and text j
similarity = torch.matmul(image_features_norm, text_features_norm.T)

print("\nImage-text matching scores:")
for i, text in enumerate(texts):
    print(f"  '{text}': {similarity[0][i].item():.4f}")

# Notes:
# Through contrastive learning, CLIP lets images and texts be compared and matched in a shared
# semantic space. This is the foundation of many later multimodal LLMs, because it provides a
# strong modality-alignment "bridge". CLIP itself is not an LLM, but the alignment it learns is
# key to multimodal perception, and it excels at zero-shot understanding and cross-modal retrieval.
3.2 BLIP and BLIP-2: Building Blocks for Vision-Language Foundation Models
Concept: BLIP (Bootstrapping Language-Image Pre-training) and its successor BLIP-2, both from Salesforce AI, aim to build more general vision-language pre-trained models. Their goal is to close the huge gap between vision models and language models so that even a frozen LLM can "see" images. BLIP improves vision-language understanding and generation through multi-task learning, including image-text contrastive learning, image-text matching, and image captioning.
BLIP-2 goes a step further. Its core idea is a lightweight Q-Former (Querying Transformer) that serves as a bridge: it extracts from a frozen image encoder the visual features most relevant to the text prompt, converts them into representations the LLM can consume, and feeds them into another frozen LLM. Its two-stage pre-training strategy (vision-language representation learning followed by vision-language generative learning) greatly improves training efficiency and model performance, letting the model leverage existing powerful vision and language models instead of training an enormous multimodal model from scratch.
Advantages:
* Efficient use of existing resources: frozen pre-trained image encoders and LLMs drastically reduce the number of trainable parameters, cutting compute cost and data requirements.
* The elegant Q-Former design: the Q-Former effectively "queries" the image for the meaningful visual features relevant to the text, avoiding the information overload of feeding every visual feature into the LLM.
* Strong vision-language capabilities: excellent performance on a range of vision-language tasks such as image captioning and visual question answering, placing it among the strongest multimodal LLMs at release.
# Example: a simplified sketch of BLIP-2-style visual feature extraction and LLM interaction
# (pseudocode / conceptual implementation; reuses the CrossModalAttention class defined earlier)
import torch
import torch.nn as nn

# Mock BLIP-2 key components
class FrozenImageEncoder(nn.Module):
    def __init__(self):  # a large pre-trained model with frozen parameters
        super().__init__()
        # In reality this is a ViT; here we just emit fixed-size features,
        # e.g. a CLS token plus a sequence of patch tokens
        self.output_dim = 1024        # dimension of each token
        self.num_image_tokens = 257   # e.g. 1 CLS token + 256 patch tokens (16x16)

    def forward(self, image_input):
        # Simulate the feature sequence produced by the image encoder
        return torch.randn(1, self.num_image_tokens, self.output_dim)

class QFormer(nn.Module):
    def __init__(self, image_feature_dim, num_query_tokens, llm_embedding_dim):
        super().__init__()
        # Learnable query tokens; they extract information from the image features via cross-attention
        self.query_tokens = nn.Parameter(torch.randn(1, num_query_tokens, llm_embedding_dim))
        # The core is cross-attention: the query tokens attend to the image features
        # (reusing the CrossModalAttention class from the earlier example)
        self.cross_attention = CrossModalAttention(
            query_dim=llm_embedding_dim,
            key_dim=image_feature_dim,
            value_dim=image_feature_dim,
            output_dim=llm_embedding_dim
        )
        # The real Q-Former is more complex: self-attention layers, multiple cross-attention layers, FFNs.
        # output_linear just ensures the output dimension matches the LLM embedding dimension.
        self.output_linear = nn.Linear(llm_embedding_dim, llm_embedding_dim)

    def forward(self, image_features):
        # query_tokens act as Query; image features act as Key and Value.
        # The essence of the Q-Former: extract relevant information from the image via the query tokens.
        # expand matches the batch size.
        extracted_features = self.cross_attention(
            query_features=self.query_tokens.expand(image_features.shape[0], -1, -1),
            key_features=image_features,
            value_features=image_features
        )
        return self.output_linear(extracted_features)  # visual tokens the LLM can accept

class FrozenLLM(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        # A pre-trained large language model with frozen parameters
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # A real LLM stacks many Transformer decoder layers; one layer here for simplicity
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=4, batch_first=True)
        self.output_head = nn.Linear(embedding_dim, vocab_size)  # project to the vocabulary

    def forward(self, input_embeddings, attention_mask=None):
        # Simplified: one decoder layer; a real LLM would set causal/cross-attention masks as needed
        output_hidden_states = self.decoder_layer(input_embeddings, input_embeddings)
        return self.output_head(output_hidden_states)  # vocabulary logits

# Mock data and components
image_data_blip = torch.randn(1, 224, 224, 3)            # mock image
text_prompt_blip = "What does this picture show?"         # the LLM should answer from the visual tokens

# Instantiate the components
frozen_image_encoder = FrozenImageEncoder()
# The number of Q-Former output tokens (e.g. 32) sets how much visual information the LLM receives
q_former = QFormer(image_feature_dim=frozen_image_encoder.output_dim, num_query_tokens=32, llm_embedding_dim=768)
frozen_llm = FrozenLLM(vocab_size=50000, embedding_dim=768)

# 1. Image encoding (frozen)
image_features_blip = frozen_image_encoder(image_data_blip)
print(f"Frozen image encoder output shape: {image_features_blip.shape}")

# 2. The Q-Former extracts visual information (trainable)
vision_tokens_for_llm = q_former(image_features_blip)
print(f"Q-Former visual token shape: {vision_tokens_for_llm.shape}")  # e.g. torch.Size([1, 32, 768])

# 3. Text encoding (the frozen LLM's embedding layer)
# In reality the LLM tokenizer would encode the prompt; here we use a few dummy token IDs
text_input_ids = torch.tensor([[100, 200, 300]])  # mock token IDs
text_embeddings_blip = frozen_llm.token_embeddings(text_input_ids)
print(f"LLM text embedding shape: {text_embeddings_blip.shape}")

# 4. Concatenate visual tokens and text tokens and feed the frozen LLM
# The visual tokens usually act as a prefix that conditions the LLM's generation
fused_llm_input = torch.cat((vision_tokens_for_llm, text_embeddings_blip), dim=1)
print(f"LLM input sequence length after Q-Former fusion: {fused_llm_input.shape[1]}")  # e.g. 32 + 3 = 35

# 5. The LLM generates the response (frozen)
llm_output_logits = frozen_llm(fused_llm_input)
# Sampling / decoding over the vocabulary logits would follow in a real pipeline
print(f"LLM output logits shape: {llm_output_logits.shape}")
print("The LLM can now generate text conditioned on image information!")

# Notes:
# BLIP-2's Q-Former is a highly efficient and effective modality bridge: it compresses high-dimensional
# image information into a format the LLM can understand, granting strong visual understanding without
# modifying or heavily training the LLM. This frozen-bridge-frozen paradigm greatly lowers the cost of
# training multimodal foundation models.
3.3 LLaVA: A Visual Assistant via Instruction Following
Concept: LLaVA (Large Language and Vision Assistant) is an influential multimodal LLM. It connects a pre-trained vision encoder (such as CLIP's ViT) with a large language model (such as LLaMA) through a simple linear projection layer that maps visual features into the language model's embedding space. The combined model is then fine-tuned on visual instruction-tuning data so it can understand and follow image-related instructions such as image description and visual question answering. LLaVA's success shows that even a relatively simple connection, combined with high-quality instruction-tuning data, can turn an LLM into a capable multimodal assistant.
Core ideas:
1. Build on strong existing base models: fully exploit the image-representation power of CLIP's vision encoder and the language understanding and generation power of large models such as LLaMA.
2. Minimal modality alignment: a simple multi-layer perceptron (MLP) or linear layer projects the vision encoder's output into the LLM's token-embedding space, matching the dimensions of the two modalities.
3. Visual instruction tuning: the key to LLaVA's success. The model is fine-tuned on a large, high-quality visual-instruction dataset containing images, user instructions (such as "describe this picture" or "what is in the image?"), and the corresponding answers, teaching the LLM how to reason and respond multimodally given visual input and an instruction.
Advantages:
* Relatively simple architecture: compared with BLIP-2's Q-Former, LLaVA's connector is more direct and easier to understand and implement.
* Lower training cost: it mainly fine-tunes a pre-trained LLM and the connector rather than training the whole model from scratch.
* Surprisingly effective: strong performance on visual question answering and image description, demonstrating the great potential of instruction tuning.
# Example: a simplified LLaVA-style visual question answering (VQA) setup
import torch
import torch.nn as nn

# 1. Frozen vision encoder (e.g. CLIP's ViT)
class FrozenVisionEncoder(nn.Module):
    def __init__(self, output_dim=768):
        super().__init__()
        # Simulate the ViT output (patch features or pooled output);
        # here we assume a single vector representing the whole image
        self.output_dim = output_dim

    def forward(self, image_input):
        # A real encoder would run a ViT and return a feature sequence or a pooled feature;
        # here we return a batch_size x feature_dim tensor
        return torch.randn(image_input.shape[0], self.output_dim)

# 2. The modality-alignment projection (usually an MLP; simplified to a linear layer here)
class ProjectionLayer(nn.Module):
    def __init__(self, vision_feature_dim, llm_embedding_dim):
        super().__init__()
        # Maps the vision feature dimension into the LLM token-embedding dimension
        self.linear = nn.Linear(vision_feature_dim, llm_embedding_dim)
        self.activation = nn.GELU()  # LLaVA typically uses GELU activations
        self.norm = nn.LayerNorm(llm_embedding_dim)

    def forward(self, vision_features):
        projected = self.linear(vision_features)
        projected = self.activation(projected)
        return self.norm(projected)

# 3. Frozen LLM (e.g. LLaMA)
class FrozenLLaMA(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
        # A single Transformer decoder layer stands in for the many layers of a real LLaMA
        self.transformer_decoder = nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=4, batch_first=True)
        self.output_head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, input_embeddings):
        # In LLaVA, the visual features act as context while the text tokens act as queries.
        # Simplified forward pass through one decoder layer (self-attention only).
        output_hidden_states = self.transformer_decoder(input_embeddings, input_embeddings)
        return self.output_head(output_hidden_states)

# Mock data and components
image_data_llava = torch.randn(1, 3, 224, 224)  # one image (batch, C, H, W)
# User instruction
instruction_text = "Please describe what is happening in this picture."

# Instantiate the components
vision_encoder_llava = FrozenVisionEncoder(output_dim=768)       # assume 768-dim image features
projection_layer = ProjectionLayer(vision_feature_dim=768, llm_embedding_dim=768)
llama_model = FrozenLLaMA(vocab_size=32000, embedding_dim=768)   # simplified embedding dim for the demo

# 1. Image feature extraction (frozen)
raw_vision_features = vision_encoder_llava(image_data_llava)
print(f"Raw visual feature shape: {raw_vision_features.shape}")

# 2. Project into the LLM embedding space (trainable)
# The visual features become the embedding of one or more "visual" tokens
projected_vision_tokens = projection_layer(raw_vision_features)
# LLaVA treats the image features as a sequence; even a single vector is unsqueezed into a sequence here
projected_vision_tokens = projected_vision_tokens.unsqueeze(1)  # mock a one-token sequence
print(f"Projected visual token shape: {projected_vision_tokens.shape}")  # e.g. torch.Size([1, 1, 768])

# 3. Prepare the text instruction (with an image placeholder)
# In practice LLaVA replaces an <image> placeholder token with the projected image tokens;
# here we just mock the instruction's token embeddings
text_input_ids_instruction = torch.randint(0, 32000, (1, 15))  # mock 15 tokens
text_tokens_embeddings = llama_model.token_embeddings(text_input_ids_instruction)
print(f"Text instruction embedding shape: {text_tokens_embeddings.shape}")

# 4. Concatenate visual tokens and text tokens as the LLM input
# LLaVA usually puts the image features before the text instruction, as context
fused_input_for_llama = torch.cat((projected_vision_tokens, text_tokens_embeddings), dim=1)
print(f"LLM input sequence length after LLaVA-style fusion: {fused_input_for_llama.shape[1]}")  # e.g. 1 + 15 = 16

# 5. The LLM generates a reply from the fused input (frozen here; its weights are updated during fine-tuning)
llm_output_logits_llava = llama_model(fused_input_for_llama)
# Decoding the output logits would produce the final text reply
print(f"LLM output logits shape: {llm_output_logits_llava.shape}")
print("The LLM can now answer questions or describe images following visual instructions!")

# Notes:
# LLaVA shows that a simple projection layer plus instruction tuning can effectively turn an LLM into a
# powerful multimodal assistant. Its success comes from cleverly leveraging pre-trained vision and language
# models and focusing on modality alignment and instruction following, setting an important example for
# later open-source multimodal LLMs.
3.4 Flamingo: A Perceiver Paired with a Frozen LLM
Concept: DeepMind's Flamingo is another important multimodal LLM. Its design goal is multimodal few-shot learning: adapting quickly to new multimodal tasks from only a handful of examples. Flamingo's core components are the Perceiver Resampler and gated cross-attention layers. The Perceiver Resampler extracts a small, fixed-length set of outputs from high-dimensional, variable-length visual (or multimodal) input, converting it into tokens compatible with the LLM's token embeddings. These tokens are then injected between the Transformer layers of a frozen LLM through gated cross-attention layers, modulating the LLM with visual information so it can generate text conditioned on visual context.
Core ideas:
1. Perceiver Resampler: solves the problem that multimodal inputs (especially video) vary in length and carry too much information. Through cross-attention, a set of learnable query latents "samples" a small, fixed number of semantically rich visual tokens from the large pool of visual features, compressing and distilling the information effectively.
2. Gated cross-attention: the visual tokens interact with the LLM's text tokens through cross-attention layers, and a **gating mechanism** controls how strongly visual information is injected. The gate ensures the injection does not damage the LLM's original language ability, modulating it in a gentle, controllable way.
3. Frozen LLM: like BLIP-2 and LLaVA, Flamingo reuses a pre-trained LLM's language ability and trains only the Perceiver Resampler and the cross-attention modules, greatly reducing the number of trainable parameters, speeding up training, and reducing the dependence on large amounts of paired multimodal data.
# Example: simulating how the Perceiver Resampler turns visual features into LLM-compatible tokens
import torch
import torch.nn as nn
import torch.nn.functional as F

# Re-declare the CrossModalAttention class used earlier (so this block is self-contained)
class CrossModalAttention(nn.Module):
    def __init__(self, query_dim, key_dim, value_dim, output_dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = value_dim // num_heads
        assert self.head_dim * num_heads == value_dim, "value_dim must be divisible by num_heads"
        self.wq = nn.Linear(query_dim, value_dim)
        self.wk = nn.Linear(key_dim, value_dim)
        self.wv = nn.Linear(value_dim, value_dim)
        self.fc_out = nn.Linear(value_dim, output_dim)

    def forward(self, query_features, key_features, value_features):
        batch_size = query_features.shape[0]
        Q = self.wq(query_features).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.wk(key_features).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.wv(value_features).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        energy = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(energy, dim=-1)
        x = torch.matmul(attention_weights, V)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        return self.fc_out(x)

class PerceiverResampler(nn.Module):
    def __init__(self, input_dim, output_len, output_dim, num_layers=2):
        super().__init__()
        # Learnable query latents; their count determines the number of output visual tokens
        self.query_latents = nn.Parameter(torch.randn(1, output_len, output_dim))
        # A single cross-attention block (the real Flamingo resampler stacks several Transformer blocks)
        self.cross_attention_block = CrossModalAttention(
            query_dim=output_dim, key_dim=input_dim, value_dim=input_dim, output_dim=output_dim
        )
        self.norm = nn.LayerNorm(output_dim)

    def forward(self, visual_features):
        # visual_features: [batch_size, num_visual_tokens, input_dim]
        batch_size = visual_features.shape[0]
        # Query latents act as Query; visual features act as Key and Value.
        # Cross-attention "samples" a small, representative set of tokens from the many visual features.
        # expand matches the batch size.
        resampled_tokens = self.cross_attention_block(
            query_features=self.query_latents.expand(batch_size, -1, -1),
            key_features=visual_features,
            value_features=visual_features
        )
        return self.norm(resampled_tokens)

class GatedCrossAttention(nn.Module):
    def __init__(self, llm_hidden_dim, vision_token_dim, num_heads=4):
        super().__init__()
        # Cross-attention: the LLM hidden states act as Query, the visual tokens as Key/Value
        self.cross_attention = CrossModalAttention(
            query_dim=llm_hidden_dim, key_dim=vision_token_dim, value_dim=vision_token_dim, output_dim=llm_hidden_dim
        )
        # Gate: a linear layer followed by a sigmoid controls how strongly visual information is injected
        self.gate = nn.Linear(llm_hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, llm_hidden_states, vision_tokens):
        # llm_hidden_states: [batch_size, seq_len, llm_hidden_dim]
        # vision_tokens:     [batch_size, num_vision_tokens, vision_token_dim]
        # The LLM hidden states query the visual tokens for visual context
        attended_vision_info = self.cross_attention(
            query_features=llm_hidden_states,
            key_features=vision_tokens,
            value_features=vision_tokens
        )
        # Gate values between 0 and 1 control how much visual information flows into the LLM
        gate_values = self.sigmoid(self.gate(llm_hidden_states))
        # Add the gated visual information to the LLM's original hidden states
        return llm_hidden_states + gate_values * attended_vision_info

# Mock inputs
video_features = torch.randn(1, 1000, 1024)      # 1000 frames of 1024-dim video features
llm_hidden_states = torch.randn(1, 50, 768)       # LLM intermediate hidden states (Batch=1, SeqLen=50, Dim=768)

# Instantiate the components
perceiver_resampler = PerceiverResampler(input_dim=1024, output_len=64, output_dim=768)  # compress the video into 64 tokens
gated_cross_attention_layer = GatedCrossAttention(llm_hidden_dim=768, vision_token_dim=768)

# 1. The Perceiver Resampler produces a fixed number of visual tokens from the video features
compressed_vision_tokens = perceiver_resampler(video_features)
print(f"Perceiver Resampler output visual token shape: {compressed_vision_tokens.shape}")  # e.g. torch.Size([1, 64, 768])

# 2. The gated cross-attention layer injects the visual tokens into the LLM hidden states
modulated_llm_hidden_states = gated_cross_attention_layer(llm_hidden_states, compressed_vision_tokens)
print(f"LLM hidden state shape after gated cross-attention: {modulated_llm_hidden_states.shape}")  # e.g. torch.Size([1, 50, 768])
# These modulated hidden states flow through the LLM's remaining layers and influence the generated text.

# Notes:
# Flamingo's Perceiver Resampler solves the problem of variable-length, information-heavy multimodal
# input, which is especially useful for video. The gated cross-attention injects visual information
# without breaking the LLM's original language ability, which is why Flamingo does so well at few-shot
# adaptation to new multimodal tasks.
Chapter 4: Practical Challenges and Optimization for LLM Multimodal Fusion
Multimodal fusion sounds wonderful, but in practice we still face many challenges. Building an efficient, robust multimodal LLM requires optimization at the data, training, and inference levels.
4.1 Data Alignment and Annotation: the Foundation of High-Quality Fusion
Concept: a high-quality multimodal dataset is the key to successful training. It is not just a matter of collecting data from different modalities; what matters more is that the data is accurately aligned in both semantics and time. This includes:
* Semantic alignment across modalities: images and text, video and audio must match semantically. If a picture shows a cat, the paired text should be about a cat and should accurately describe the key elements and relations in the image. This often requires time-consuming, labor-intensive human annotation.
* Temporal alignment: for sequential modality pairs such as video-text or audio-text, the modalities must be synchronized in time. For example, an action in the video must correspond precisely in time to the sound or the textual description of that action.
* Scale and diversity: multimodal models need extremely large datasets to converge and generalize. Diversity of domains and scenarios is just as crucial, so the model can handle the variety of real-world situations.
Challenges: building multimodal datasets is expensive and quality is hard to guarantee. Human annotation is not only slow and costly but also error-prone, especially for fine-grained alignment and ambiguity. Data privacy and ethics concerns are also increasingly prominent. Wrong or low-quality alignment seriously hurts model performance: the model learns spurious associations and then hallucinates or answers inaccurately at inference time.
# Pseudocode: alignment checks in multimodal dataset preprocessing
import torch
import torchvision.transforms as transforms  # would be used by a real image pipeline
from PIL import Image
import numpy as np

def load_and_transform_image(path):
    # A real implementation would load the file and preprocess it; here we return a mock tensor
    return torch.randn(1, 3, 224, 224)

def tokenize_text(text):
    # A real implementation would use the LLM tokenizer to produce token IDs
    return torch.randint(0, 1000, (1, 20))  # mock token IDs, length 20

def process_audio(audio):
    # A real implementation would extract acoustic features such as a Mel spectrogram
    return torch.randn(1, 100, 128)  # mock audio features: 100 frames, 128 dims each

def check_semantic_consistency(img_tensor, txt_tensor):
    # In reality this is a model-based check or human review, e.g. CLIP similarity between image and text.
    # Here we simply return True/False at random.
    return np.random.rand() > 0.1   # 90% chance of being semantically consistent

def check_temporal_alignment(aud_tensor, txt_tensor):
    # In reality this is a non-trivial algorithm, e.g. time-series alignment
    return np.random.rand() > 0.05  # 95% chance of being temporally aligned

def preprocess_multimodal_data(image_path, text_caption, audio_segment=None):
    # 1. Image preprocessing (resize, normalize)
    image_tensor = load_and_transform_image(image_path)
    # 2. Text preprocessing (tokenize, pad)
    text_tokens = tokenize_text(text_caption)
    # 3. Audio preprocessing (resample, mel-spectrogram)
    audio_tensor = None
    if audio_segment:
        audio_tensor = process_audio(audio_segment)
    # 4. Alignment checks (simplified pseudocode; real pipelines use richer logic or human review).
    #    In production these warnings usually call for stricter handling, e.g. filtering the sample out.
    if not check_semantic_consistency(image_tensor, text_tokens):
        print(f"  Warning: image '{image_path}' and its caption may be semantically inconsistent!")
        # In practice: filter, re-annotate, or fall back to weak supervision
        # return None  # optionally skip this sample
    if audio_tensor is not None and not check_temporal_alignment(audio_tensor, text_tokens):
        print(f"  Warning: audio and text may be temporally misaligned!")
        # return None  # optionally skip this sample
    return {
        "image_input": image_tensor,
        "text_input": text_tokens,
        "audio_input": audio_tensor
    }

# Example call
processed_data = preprocess_multimodal_data(
    "/data/img_001.jpg",
    "A cute puppy playing in the park"
)
if processed_data:
    print("\nData preprocessing finished; ready for model training.")
else:
    print("\nData quality below threshold; sample skipped.")

# Not recommended: using unvalidated or low-quality multimodal data directly.
# Problem: noise and misalignment teach the model spurious associations, hurting generalization
# and even producing "hallucinations".
# Best-practice checklist:
# 1. Diversify data sources: combine public high-quality multimodal datasets (MS-COCO,
#    Conceptual Captions, WebLI), in-house data, and data synthesized via rules or generative models.
# 2. Fine-grained annotation: annotate not only the overall content but also the detailed
#    correspondences between modalities (e.g. image regions to specific words). For complex tasks,
#    consider multi-round annotation and cross-validation.
# 3. Automated assistance and weak supervision: use existing models (e.g. CLIP) for preliminary
#    alignment or filtering to cut annotation cost (a sketch follows this code block); explore weakly
#    supervised learning on noisy data or self-supervised pre-training such as modality matching.
# 4. Quality control: establish strict annotation guidelines and multi-round review; audit the data
#    regularly to find and fix labeling errors.
# 5. Privacy protection: when handling sensitive data, strictly follow privacy regulations and apply
#    anonymization, de-identification, or differential privacy.
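Point 3 of the checklist above mentions using an existing model such as CLIP for preliminary alignment filtering. The sketch below shows one way this could look, reusing the openai/clip-vit-base-patch32 checkpoint from the earlier CLIP example; the 0.2 similarity threshold is an illustrative assumption that would need tuning on real data.

from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def clip_pair_score(image: Image.Image, caption: str) -> float:
    # Cosine similarity between the image embedding and the caption embedding
    inputs = clip_processor(text=[caption], images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        out = clip_model(**inputs)
    img = out.image_embeds / out.image_embeds.norm(dim=-1, keepdim=True)
    txt = out.text_embeds / out.text_embeds.norm(dim=-1, keepdim=True)
    return (img @ txt.T).item()

def filter_pairs(pairs, threshold=0.2):
    # Keep only (PIL image, caption) pairs whose CLIP similarity clears the assumed threshold
    kept = [(img, cap) for img, cap in pairs if clip_pair_score(img, cap) >= threshold]
    print(f"Kept {len(kept)} / {len(pairs)} pairs after CLIP filtering")
    return kept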
4.2 Training Strategy and Resource Consumption
Challenge: multimodal LLMs usually have enormous parameter counts spanning several encoders and a language model. End-to-end training demands massive compute (GPU memory and FLOPs) and long training times, a huge barrier for most research teams and companies. Training stability of such large models is also harder to control.
Optimization strategies: to cope with these challenges, we usually adopt the following.
* Frozen pre-trained models: the most common strategy, as in BLIP-2 and Flamingo. Freeze the powerful vision encoder and LLM and train only the modality bridge (Q-Former, Perceiver Resampler) or a small number of adapter layers. This markedly reduces trainable parameters and compute, focusing training on cross-modal alignment and information exchange.
* Parameter-efficient fine-tuning (PEFT): with most pre-trained parameters frozen, insert a few trainable modules or apply low-rank updates to the existing weights.
  * Adapter tuning: insert small trainable "adapter" modules at specific positions in the Transformer layers (e.g. around the FFN).
  * LoRA (Low-Rank Adaptation): approximate the weight update of a pre-trained layer with a low-rank factorization. Two small matrices A and B represent the weight increment ΔW = BA; only A and B are trained while the original weight W stays frozen. The number of fine-tuned parameters shrinks dramatically, typically to between 0.01% and 1% of the original model.
* Progressive training: train in stages, for example modality-alignment pre-training first (as in CLIP), then multimodal generative pre-training (as in BLIP), and finally multimodal instruction tuning (as in LLaVA). Introducing complexity gradually improves training stability and efficiency.
* Distributed training: use multiple GPUs and nodes with data, model, and pipeline parallelism to handle extremely large models and datasets.
# Example: applying LoRA in a multimodal model (sketch)
import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRA_Linear(nn.Module):
    def __init__(self, original_linear_layer, rank=4, alpha=1.0):
        super().__init__()
        self.original_linear = original_linear_layer
        self.rank = rank
        self.alpha = alpha
        # Freeze the original weights so they are not updated during LoRA training
        self.original_linear.weight.requires_grad = False
        if self.original_linear.bias is not None:
            self.original_linear.bias.requires_grad = False
        # LoRA low-rank factors A and B:
        # A maps the input dimension down to rank r, B maps rank r to the output dimension
        self.lora_A = nn.Parameter(torch.randn(original_linear_layer.in_features, rank))
        self.lora_B = nn.Parameter(torch.randn(rank, original_linear_layer.out_features))
        # Initialize the LoRA weights:
        # A usually gets Kaiming init, B starts at zero so the initial increment is zero
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        # Original computation path
        original_output = self.original_linear(x)
        # LoRA increment: (x @ A) @ B, scaled by alpha / rank
        lora_delta = (x @ self.lora_A @ self.lora_B) * (self.alpha / self.rank)
        # Final output = original output + LoRA increment
        return original_output + lora_delta

# A linear layer inside an LLM (e.g. a Transformer QKV projection)
llm_original_linear = nn.Linear(768, 768)       # input and output dims are both 768
llm_input_features = torch.randn(1, 128, 768)   # mock LLM input (Batch, SeqLen, Dim)
print(f"Original linear layer parameter count: {sum(p.numel() for p in llm_original_linear.parameters())}")

# Wrap it with LoRA; rank=8 means far fewer trainable parameters
llm_lora_linear = LoRA_Linear(llm_original_linear, rank=8)

# Print trainable parameters to verify that only lora_A and lora_B are trainable
print("\nTrainable parameters in the LoRA layer:")
trainable_params_count = 0
for name, param in llm_lora_linear.named_parameters():
    if param.requires_grad:
        print(f"  - {name}, shape: {param.shape}, count: {param.numel()}")
        trainable_params_count += param.numel()
    else:
        print(f"  - {name}, shape: {param.shape}, requires_grad: {param.requires_grad} (frozen)")
print(f"Total trainable parameters in the LoRA layer: {trainable_params_count}")
# Original parameters: 768*768 + 768 = 590592
# LoRA parameters: 768*8 + 8*768 = 12288, a large reduction

# Forward pass
output_with_lora = llm_lora_linear(llm_input_features)
print(f"Output shape with LoRA: {output_with_lora.shape}")

# Not recommended: full-parameter fine-tuning of a multimodal LLM with billions of parameters;
# it consumes enormous compute and time and overfits easily.
# Recommended: parameter-efficient fine-tuning such as LoRA or adapters, which preserves performance
# while greatly cutting training cost and speeding up experimentation.
4.3 Performance Optimization and Inference Efficiency
Challenge: multimodal LLMs typically contain several large models (a vision encoder plus an LLM), leading to high inference latency and large memory footprints that make real-time applications or resource-constrained devices (mobile, edge) hard to serve. The high cost of inference also limits large-scale deployment.
Optimization: to make multimodal LLMs viable in practice, several techniques are combined (a small quantization sketch follows this list):
* Model quantization: convert weights and activations from floating point (FP32, FP16) to low-precision integers (e.g. INT8). This shrinks model size and memory use substantially (typically 2-4x) and speeds up inference, since integer arithmetic is faster and maps well onto dedicated hardware. The challenge is possible accuracy loss, so the quantization strategy (post-training quantization vs. quantization-aware training) must be chosen carefully.
* Model pruning: remove unimportant connections (weights) or neurons (channels) to shrink the model without a significant performance drop. Pruning can be unstructured (individual weights) or structured (whole neurons or channels); the latter is friendlier to hardware acceleration.
* Knowledge distillation: train a smaller "student" model to imitate a large "teacher" model's behavior (its logits or intermediate features), obtaining performance close to the teacher with a much smaller, faster model.
* Distributed inference: split a large model across multiple GPUs or machines to cope with insufficient single-device resources, via model parallelism (placing different layers or parts on different devices) or pipeline parallelism (streaming requests through the layers).
* Batching: batch multiple inputs whenever possible to exploit GPU parallelism. Batching raises throughput significantly at the cost of some latency.
* Hardware acceleration: use dedicated AI silicon such as NVIDIA Tensor Cores (FP16/INT8), Google TPUs, and various edge AI chips, which are optimized for deep-learning workloads and deliver better efficiency than general-purpose CPUs/GPUs.
* Optimized inference frameworks: high-performance engines such as TensorRT, OpenVINO, and ONNX Runtime optimize the model graph (layer fusion, memory planning) and accelerate inference further.
# Performance comparison: simulated differences in inference speed across fusion strategies (sketch)
# Assume a simple image-text question-answering task
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

# Re-declare the CrossModalAttention class used earlier (so this block is self-contained)
class CrossModalAttention(nn.Module):
    def __init__(self, query_dim, key_dim, value_dim, output_dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = value_dim // num_heads
        assert self.head_dim * num_heads == value_dim, "value_dim must be divisible by num_heads"
        self.wq = nn.Linear(query_dim, value_dim)
        self.wk = nn.Linear(key_dim, value_dim)
        self.wv = nn.Linear(value_dim, value_dim)
        self.fc_out = nn.Linear(value_dim, output_dim)

    def forward(self, query_features, key_features, value_features):
        batch_size = query_features.shape[0]
        Q = self.wq(query_features).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.wk(key_features).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.wv(value_features).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        energy = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(energy, dim=-1)
        x = torch.matmul(attention_weights, V)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        return self.fc_out(x)

def measure_inference_time(model_func, *args, iterations=10, device="cpu"):
    # Move the model and inputs to the target device
    model_func.to(device)
    args_on_device = [arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args]
    # Warm-up: a few runs so GPU/CPU caches are primed
    for _ in range(3):
        _ = model_func(*args_on_device)
    # Measure: average over several runs
    start_time = time.perf_counter()
    for _ in range(iterations):
        _ = model_func(*args_on_device)
    end_time = time.perf_counter()
    return ((end_time - start_time) / iterations) * 1000  # milliseconds per run

# Models simulating the different fusion strategies
class EarlyFusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(512, 10)  # 256 (img_feat) + 256 (txt_feat)

    def forward(self, img_feat, txt_feat):
        return self.linear(torch.cat((img_feat, txt_feat), dim=-1))

class LateFusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.img_cls = nn.Linear(256, 5)
        self.txt_cls = nn.Linear(256, 5)
        self.final_cls = nn.Linear(10, 10)  # fuse the two 5-dim outputs

    def forward(self, img_feat, txt_feat):
        img_out = self.img_cls(img_feat)
        txt_out = self.txt_cls(txt_feat)
        return self.final_cls(torch.cat((img_out, txt_out), dim=-1))

class HybridFusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.cross_attn = CrossModalAttention(query_dim=256, key_dim=256, value_dim=256, output_dim=256, num_heads=4)
        self.linear = nn.Linear(256, 10)  # same output dimension as the early-fusion model

    def forward(self, img_feat, txt_feat):
        # Assume txt_feat is the query and img_feat provides key/value;
        # unsqueeze(1) mocks a sequence length of 1
        fused_feat = self.cross_attn(
            txt_feat.unsqueeze(1),
            img_feat.unsqueeze(1),
            img_feat.unsqueeze(1)
        ).squeeze(1)  # back to [batch_size, feature_dim]
        return self.linear(fused_feat)

# Mock features
img_feat_sample = torch.randn(1, 256)
txt_feat_sample = torch.randn(1, 256)

# Detect the available device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Inference device: {device.upper()}")

# Instantiate the models
early_model = EarlyFusionModel()
late_model = LateFusionModel()
hybrid_model = HybridFusionModel()

print("\n=== Simulated inference-time comparison of fusion strategies ===")
print("(Note: real performance depends on model size, hardware, batch size, and more.)")

# Early fusion is usually fastest: one single, compact network
early_time = measure_inference_time(early_model, img_feat_sample, txt_feat_sample, device=device)
print(f"Early fusion inference time: {early_time:.2f} ms/run")

# Late fusion runs several sub-models and then merges the results. If the sub-models run in parallel,
# total time approaches the slowest one; run serially (as here), the times add up and may be slightly longer.
late_time = measure_inference_time(late_model, img_feat_sample, txt_feat_sample, device=device)
print(f"Late fusion inference time: {late_time:.2f} ms/run")

# Hybrid fusion (especially with attention) usually costs the most compute and may be the slowest,
# since attention involves matrix multiplications and softmax.
hybrid_time = measure_inference_time(hybrid_model, img_feat_sample, txt_feat_sample, device=device)
print(f"Hybrid/cross-modal fusion inference time: {hybrid_time:.2f} ms/run")

# Notes:
# Performance optimization is a systems effort: choose strategies to fit the application and its resource
# limits. Quantization, pruning, and distillation shrink models and speed up inference; distributed
# inference and batching raise throughput. Real projects usually combine several techniques to reach the
# best performance-accuracy balance, and a high-performance inference framework such as TensorRT can
# squeeze out additional hardware performance.
Chapter 5: Building Your First Multimodal LLM Application
Let's build a simple Visual Question Answering (VQA) assistant and put the theory discussed above into practice. The user uploads a picture and asks a question about its content; the model combines the image information with the question and produces a textual answer. The application ties together the visual encoding, modality alignment, and LLM reasoning we have discussed.
5.1 Application Scenario: an Image Question-Answering Assistant
We aim to create a command-line VQA assistant. It takes an image URL and a user question, then outputs an answer based on the picture and the question. The assistant demonstrates how a multimodal LLM combines visual perception with language understanding to enable smarter interaction.
5.2 Core Component Overview
A multimodal image question-answering assistant usually contains the following core components:
- Configuration management (config.py): centralizes model names, dimensions, image-processing parameters, and so on.
- Model definition and loading (models.py): wraps the loading and forward-pass logic of the vision encoder, the modality projection layer, and the large language model. We use the Hugging Face transformers library to simplify model handling.
- Utility functions (utils.py): handle image download, preprocessing, and construction of the LLM input embeddings.
- Main program logic (main.py): wires all components together into an end-to-end flow.
5.3 Code Implementation: Modular Design
To keep the code clear and maintainable, we adopt a modular design and split the functionality across files.
config.py: configuration parameters
# config.py
import torch

class AppConfig:
    # Vision encoder model name (e.g. CLIP's Vision Transformer)
    VISION_ENCODER_MODEL = "openai/clip-vit-base-patch32"
    # LLM model name (a small GPT-2 for the demo; a real multimodal LLM would be BLIP-2, LLaVA, etc.)
    # GPT-2 works as the base LLM here, combined with a projection layer
    LLM_MODEL = "gpt2"
    # Vision feature dimension (CLIP ViT-B/32 output dim)
    VISION_FEATURE_DIM = 768
    # LLM embedding dimension (GPT-2 embedding dim)
    LLM_EMBEDDING_DIM = 768
    # Image processing parameters (the CLIP processor handles this automatically; recorded for reference)
    IMAGE_SIZE = (224, 224)
    # LLM text-generation parameters
    MAX_NEW_TOKENS = 50
    TEMPERATURE = 0.7
    # Device
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
models.py: model definition and loading
# models.py
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPVisionModel, CLIPProcessor

class MultimodalVQA:
    def __init__(self, config):
        self.config = config
        self.device = config.DEVICE
        print(f"Loading Vision Encoder: {config.VISION_ENCODER_MODEL} to {self.device}")
        # Load the frozen CLIP vision encoder
        self.vision_processor = CLIPProcessor.from_pretrained(config.VISION_ENCODER_MODEL)
        self.vision_encoder = CLIPVisionModel.from_pretrained(config.VISION_ENCODER_MODEL).to(self.device)
        self.vision_encoder.eval()  # freeze: no training
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

        print(f"Loading LLM: {config.LLM_MODEL} to {self.device}")
        # Load the frozen LLM (GPT-2 for the demo; a real system would use a multimodal LLM such as BLIP-2 or LLaVA)
        self.llm_tokenizer = AutoTokenizer.from_pretrained(config.LLM_MODEL)
        if self.llm_tokenizer.pad_token is None:  # GPT-2 has no default pad token
            self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
        self.llm_model = AutoModelForCausalLM.from_pretrained(config.LLM_MODEL).to(self.device)
        self.llm_model.eval()  # freeze: no training
        for param in self.llm_model.parameters():
            param.requires_grad = False

        # Modality projection layer: maps visual features into the LLM embedding space
        # (the simple linear-layer idea from LLaVA)
        self.projection_layer = nn.Linear(config.VISION_FEATURE_DIM, config.LLM_EMBEDDING_DIM).to(self.device)
        # Note: in a real multimodal LLM this projection is trained during the vision-language
        # alignment stage. In this simplified example it is randomly initialized and untrained.
        # For a real application you would load weights that have already been aligned, or make the
        # layer trainable and train it during instruction tuning.

    def encode_image(self, image):
        # Preprocess and encode the image
        inputs = self.vision_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            # Use pooler_output as the global image feature, shape [batch_size, VISION_FEATURE_DIM]
            image_features = self.vision_encoder(**inputs).pooler_output
        return image_features

    def project_vision_features(self, image_features):
        # Project the visual features into the LLM embedding space
        projected_features = self.projection_layer(image_features)
        return projected_features  # [batch_size, LLM_EMBEDDING_DIM]

    def generate_answer(self, combined_input_embeddings):
        # Generate the answer with the LLM
        with torch.no_grad():
            output_ids = self.llm_model.generate(
                inputs_embeds=combined_input_embeddings,
                max_new_tokens=self.config.MAX_NEW_TOKENS,
                do_sample=True,  # sampling so the temperature setting takes effect
                temperature=self.config.TEMPERATURE,
                pad_token_id=self.llm_tokenizer.pad_token_id,
                attention_mask=torch.ones(combined_input_embeddings.shape[:-1],
                                          dtype=torch.long, device=self.device)  # mask for generation
            )
        # Decode the generated tokens, dropping special tokens
        generated_text = self.llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return generated_text
utils.py: utility functions
# utils.py
from PIL import Image
import requests
from io import BytesIO
import torch

def load_image_from_url(url):
    """Load an image from a URL and convert it to a PIL Image."""
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()  # check that the HTTP request succeeded
        image = Image.open(BytesIO(response.content)).convert("RGB")
        return image
    except requests.exceptions.RequestException as e:
        print(f"Error loading image from URL: {e}")
        return None
    except Exception as e:
        print(f"Error processing image: {e}")
        return None

def prepare_llm_input(question, projected_vision_features, llm_tokenizer, llm_model, device):
    """
    Concatenate the visual features and the text question into the LLM input,
    mimicking the LLaVA input format: <image_features_as_prefix> + text_prompt
    """
    # Encode the text question
    text_input_ids = llm_tokenizer(
        question,
        return_tensors="pt",
        add_special_tokens=True  # include any special tokens the tokenizer defines
    ).input_ids.to(device)
    # Get the LLM's token embedding layer
    llm_embeddings_layer = llm_model.get_input_embeddings()
    # Convert the text input_ids into embedding vectors
    text_embeddings = llm_embeddings_layer(text_input_ids)
    # Prepend the projected visual features to the text embeddings.
    # projected_vision_features has shape [batch_size, LLM_EMBEDDING_DIM];
    # expand it to [batch_size, 1, LLM_EMBEDDING_DIM] so it can be concatenated with the text sequence.
    combined_embeddings = torch.cat(
        (projected_vision_features.unsqueeze(1), text_embeddings),
        dim=1
    )
    return combined_embeddings
main.py: main program logic
# main.py
from config import AppConfig
from models import MultimodalVQA
from utils import load_image_from_url, prepare_llm_input
import torch

def main():
    config = AppConfig()
    vqa_model = MultimodalVQA(config)

    # Example image and question
    image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # cats on a couch
    question = "What animals are in this picture, and what are they doing?"
    print(f"\n--- Processing image: {image_url} ---")
    print(f"--- User question: {question} ---")

    # 1. Load the image
    image = load_image_from_url(image_url)
    if image is None:
        print("Could not load the image; please check the URL or your network connection.")
        return

    # 2. Encode the image to obtain visual features
    image_features = vqa_model.encode_image(image)
    print(f"Image feature shape: {image_features.shape}")

    # 3. Project the visual features into the LLM embedding space
    projected_vision_features = vqa_model.project_vision_features(image_features)
    print(f"Projected visual feature shape: {projected_vision_features.shape}")

    # 4. Prepare the LLM input embeddings (visual features + text question)
    combined_input_embeddings = prepare_llm_input(
        question,
        projected_vision_features,
        vqa_model.llm_tokenizer,
        vqa_model.llm_model,
        vqa_model.device
    )
    print(f"Final LLM input embedding shape: {combined_input_embeddings.shape}")

    # 5. Generate the answer with the LLM
    print("\nThe LLM is generating an answer... (this may take a while, depending on hardware and model size)")
    # Note: GPT-2 is a pure text model. It will try to generate from the concatenated "visual features"
    # and the question, but the answer quality depends heavily on how the projection layer was trained
    # and on GPT-2's own generalization. A real multimodal LLM (BLIP-2 or LLaVA) uses the visual
    # information far more deeply during generation.
    answer = vqa_model.generate_answer(combined_input_embeddings)
    print(f"\n--- Answer: {answer} ---")

if __name__ == "__main__":
    main()

# Notes:
# This modular example shows how to combine visual encoding, modality alignment, and language generation
# into a simple multimodal LLM application. In real projects model loading and training are far more
# involved (the projection layer in particular must be trained for vision-language alignment), but the
# core data flow and the way the components cooperate are the same. The structure helps team collaboration
# and maintenance and makes it easy to swap in a different vision encoder or LLM.
Chapter 6: Summary and Outlook: the Future of Multimodal LLMs
We have explored the fascinating world of LLM multimodal fusion together! Starting from the "blind spots" of text-only LLMs, we covered early, late, and hybrid fusion strategies; dissected the milestone models CLIP, BLIP-2, LLaVA, and Flamingo; discussed the practical challenges and optimizations around data, training, and inference; and finally built a simple multimodal question-answering application to put the theory into practice.
6.1 Key Takeaways
A quick recap of today's key points:
- Why multimodal fusion is necessary: text-only LLMs cannot perceive real-world images, sound, and other signals and lack grounding in the physical world. Multimodal fusion is a key step toward giving LLMs broader perception and toward general artificial intelligence.
- Three fusion paradigms:
  - Early fusion: joins at the feature level and captures fine-grained interactions, but demands tight alignment and is sensitive to noise.
  - Late fusion: ensembles at the decision level, highly modular, but limited interaction depth and may miss low-level associations.
  - Hybrid/cross-modal fusion: combines the strengths of both, achieving deep, dynamic interaction through attention mechanisms and modality bridges; the current mainstream and best-performing approach.
- Milestone models:
  - CLIP: image-text alignment via contrastive learning in a shared embedding space; the foundation of many multimodal models.
  - BLIP/BLIP-2: a Q-Former efficiently connects frozen vision and language models for strong vision-language understanding and generation.
  - LLaVA: simple and efficient; a linear projection plus large-scale instruction tuning turns an LLM into a visual assistant, proving the power of instruction tuning.
  - Flamingo: a Perceiver Resampler plus gated cross-attention excels at variable-length visual input and few-shot learning.
- Practical challenges and optimization: high-quality data alignment and annotation are the foundation of success; frozen pre-trained models and parameter-efficient fine-tuning (PEFT, e.g. LoRA) are effective answers to training cost; quantization, pruning, and knowledge distillation are key to inference efficiency and deployment cost.
6.2 Practical Advice and Directions for Further Study
- Start small and iterate: if you plan to build your own multimodal application, start from existing pre-trained models (such as BLIP-2 or LLaVA on Hugging Face) instead of training from scratch. Understand their architecture and data flow first, then customize and optimize on top.
- Care about data quality: invest time and effort in data collection, cleaning, and annotation; high-quality data is the decisive factor for model performance. For specific domains, explore synthetic data or weak supervision to expand the dataset.
- Embrace parameter-efficient fine-tuning (PEFT): for resource-limited teams, LoRA, adapters, and similar PEFT methods are the ideal way to fine-tune large multimodal models, reaching near full-fine-tuning quality at a fraction of the compute and storage cost.
- Pay attention to inference optimization: when deploying, consider quantization and pruning to cut inference cost and improve responsiveness; a suitable inference framework (such as TensorRT) also brings noticeable gains.
- Explore multimodal reasoning: beyond simple question answering and captioning, multimodal LLMs have great potential in more complex reasoning tasks such as visual commonsense reasoning, step planning, and multi-hop question answering. Teaching models deeper cross-modal logical reasoning is an important future direction.
- Mind ethics and safety: multimodal models can generate harmful or biased content or be misused (e.g. deepfakes). Consider the ethical impact and safety measures during development and deployment, including bias detection and content filtering.
- Deep fusion of video and audio: current work focuses mostly on image-text, but video and audio fusion is an important next step. Coupling LLMs more tightly with speech recognition and video understanding, handling temporal dynamics, and building more complete perceptual agents will be the next frontier.
Multimodal LLMs are advancing at an astonishing pace. They are not just a technical breakthrough but a bridge toward AI that understands the world more like we do. I hope this article lights the way forward, and I look forward to seeing the impressive applications you build with multimodal LLMs!