多模态大模型微调:LLaVA 与 Qwen-VL 视觉语言模型训练

发布时间:2026/6/21 16:07:27
多模态大模型微调:LLaVA 与 Qwen-VL 视觉语言模型训练 1. 引言多模态大模型如 LLaVA、Qwen-VL、InternVL能够同时理解图像和文本实现视觉问答、图像描述、OCR 等任务。本文将介绍如何微调这些模型以适应特定领域。主流多模态架构对比模型视觉编码器LLM参数量特点LLaVA-1.5CLIP-ViT-LVicuna/LLaMA7B/13B简单高效Qwen-VLViT-bigGQwen-7B9.6B中文优秀InternVL-2InternViT-6BInternLM28B-76B开源最强Phi-3-VisionCLIP-ViTPhi-34.2B轻量级2. LLaVA 架构解析2.1 三组件架构图像 → Vision Encoder (CLIP ViT-L/14) → 视觉 tokens ↓ Projection Layer (MLP) ↓ 文本 → Tokenizer → 文本 tokens ──────→ 拼接 → LLM → 回答2.2 两阶段训练阶段一预训练投影层 - 冻结 Vision Encoder 和 LLM - 只训练 Projection Layer - 数据558K 图文对图像描述 - 目标对齐视觉和语言空间 阶段二指令微调 - 冻结 Vision Encoder - 训练 Projection Layer LLM - 数据665K 多模态指令数据 - 目标学习遵循指令回答问题3. 数据准备3.1 数据格式{id:vqa_001,image:images/001.jpg,conversations:[{from:human,value:image\n这张图片中有什么},{from:gpt,value:图片中显示了一条繁忙的城市街道有多个行人和车辆。}]}3.2 数据处理脚本importjsonfromPILimportImagefromtorch.utils.dataimportDatasetclassMultimodalDataset(Dataset):多模态指令微调数据集def__init__(self,data_path,image_dir,processor,tokenizer,max_length2048):withopen(data_path)asf:self.datajson.load(f)self.image_dirimage_dir self.processorprocessor self.tokenizertokenizer self.max_lengthmax_lengthdef__len__(self):returnlen(self.data)def__getitem__(self,idx):itemself.data[idx]# 加载图像image_pathf{self.image_dir}/{item[image]}imageImage.open(image_path).convert(RGB)# 处理对话conversationsitem[conversations]promptconversations[0][value].replace(image,)answerconversations[1][value]# 构造输入input_textfUSER: image\n{prompt}\nASSISTANT:{answer}# 编码image_inputsself.processor(imagesimage,return_tensorspt)text_inputsself.tokenizer(input_text,truncationTrue,max_lengthself.max_length,paddingmax_length,return_tensorspt,)return{pixel_values:image_inputs[pixel_values].squeeze(),input_ids:text_inputs[input_ids].squeeze(),attention_mask:text_inputs[attention_mask].squeeze(),}4. LLaVA 微调4.1 环境准备pipinstalltransformers accelerate peft pipinstallflash-attn --no-build-isolation4.2 加载模型fromtransformersimportLlavaForConditionalGeneration,AutoProcessor,BitsAndBytesConfigfrompeftimportLoraConfig,get_peft_model model_idllava-hf/llava-1.5-7b-hf# QLoRA 配置bnb_configBitsAndBytesConfig(load_in_4bitTrue,bnb_4bit_quant_typenf4,bnb_4bit_compute_dtypetorch.bfloat16,)# 加载模型modelLlavaForConditionalGeneration.from_pretrained(model_id,quantization_configbnb_config,device_mapauto,torch_dtypetorch.bfloat16,attn_implementationflash_attention_2,)processorAutoProcessor.from_pretrained(model_id)# LoRA 配置只适配语言模型部分lora_configLoraConfig(r16,lora_alpha32,target_modules[q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj],lora_dropout0.05,biasnone,)modelget_peft_model(model,lora_config)model.print_trainable_parameters()4.3 训练fromtransformersimportTrainingArguments,Trainer training_argsTrainingArguments(output_dir./llava-finetuned,num_train_epochs3,per_device_train_batch_size4,gradient_accumulation_steps4,learning_rate2e-5,weight_decay0.01,warmup_ratio0.03,lr_scheduler_typecosine,bf16True,gradient_checkpointingTrue,logging_steps10,save_strategyepoch,remove_unused_columnsFalse,optimpaged_adamw_8bit,)trainerTrainer(modelmodel,argstraining_args,train_datasettrain_dataset,data_collatorlambdabatch:{pixel_values:torch.stack([b[pixel_values]forbinbatch]),input_ids:torch.stack([b[input_ids]forbinbatch]),attention_mask:torch.stack([b[attention_mask]forbinbatch]),labels:torch.stack([b[input_ids]forbinbatch]),},)trainer.train()5. Qwen-VL 微调5.1 加载 Qwen-VLfromtransformersimportAutoModelForCausalLM,AutoTokenizer model_idQwen/Qwen-VL-ChatmodelAutoModelForCausalLM.from_pretrained(model_id,device_mapauto,trust_remote_codeTrue,bf16True,)tokenizerAutoTokenizer.from_pretrained(model_id,trust_remote_codeTrue)5.2 Qwen-VL 数据格式{id:vqa_001,conversations:[{from:user,value:Picture 1: images/001.jpg\n这张图片中有什么},{from:assistant,value:图片中显示了一条繁忙的城市街道。}]}6. 推理与评估6.1 推理代码fromPILimportImagedefinference(model,processor,image_path,question):多模态推理imageImage.open(image_path).convert(RGB)promptfUSER: image\n{question}\nASSISTANT:inputsprocessor(textprompt,imagesimage,return_tensorspt)inputs{k:v.to(model.device)fork,vininputs.items()}outputmodel.generate(**inputs,max_new_tokens512,do_sampleTrue,temperature0.7,)responseprocessor.decode(output[0],skip_special_tokensTrue)# 提取 ASSISTANT 后的回答returnresponse.split(ASSISTANT:)[-1].strip()# 使用answerinference(model,processor,test.jpg,描述这张图片的内容)print(answer)6.2 评估指标# 常用评估指标# 1. VQA Accuracy: 答案是否正确# 2. CIDEr: 图像描述质量# 3. BLEU: 文本生成质量# 4. ROUGE: 召考率导向的评估frompycocoevalcap.cider.ciderimportCiderdefevaluate_cider(predictions,references):CIDEr 评估scorerCider()score,_scorer.compute_score(references,predictions)returnscore7. 领域适配示例7.1 医学影像分析# 医学影像数据集格式medical_data{id:med_001,image:xray/chest_001.jpg,conversations:[{from:human,value:image\n分析这张X光片指出异常区域。},{from:gpt,value:右下肺野可见片状模糊影考虑肺炎可能。建议CT进一步检查。}]}7.2 工业质检# 工业质检数据格式qc_data{id:qc_001,image:defect/scratch_001.jpg,conversations:[{from:human,value:image\n检测这张产品图片中的缺陷。},{from:gpt,value:检测到表面划痕缺陷位于图片中部偏左位置长度约2cm。严重程度中等。}]}8. 总结多模态大模型微调的关键要点数据质量 数据量1000 条高质量标注 10000 条低质量数据两阶段训练先预训练投影层对齐模态再指令微调LoRA 微调只适配语言模型部分视觉编码器通常冻结领域数据收集领域特定的图文对是成功的关键