GLM-4v-9b保姆级教程:Jupyter中可视化注意力热图,理解图文对齐过程
GLM-4v-9b保姆级教程:Jupyter中可视化注意力热图,理解图文对齐过程
你有没有想过,一个能看懂图片的AI模型,它到底是怎么“看”的?当它回答“图片里有什么”时,它的“注意力”究竟聚焦在图片的哪个部分?
今天,我们就来亲手揭开这个黑盒。我将带你一步步在Jupyter Notebook中,运行GLM-4v-9b这个强大的多模态模型,并可视化它的注意力热图。通过这个教程,你不仅能学会如何调用模型,更能直观地看到模型“思考”的过程,理解它是如何将文字问题和图片内容联系起来的。
1. 教程目标与环境准备
1.1 我们能学到什么?
通过这篇教程,你将掌握:
- 基础调用:学会在Jupyter中加载并运行GLM-4v-9b模型,进行基础的图文问答。
- 核心揭秘:学会提取并可视化模型的“交叉注意力”权重,生成热力图。
- 过程理解:通过热力图,直观理解模型是如何将文本中的词语与图片中的区域进行对齐和关联的。
- 实用技巧:获得一套可复用的代码模板,方便你未来分析其他多模态模型。
简单说,就是让你从“只会问问题”,升级到“能看见模型是怎么想问题的”。
1.2 你需要准备什么?
门槛非常低,跟着做就行:
- 基础知识:会用Python,知道怎么打开Jupyter Notebook。
- 硬件要求:教程主要基于CSDN星图镜像环境,它已经为你准备好了所有依赖和模型。如果你想在本地运行,需要一张显存不小于24GB的GPU(例如RTX 4090)。
- 心态准备:保持好奇,我们是在探索AI的“视觉焦点”。
2. 快速上手:第一次图文对话
我们先让模型跑起来,看看它最基本的本事。
2.1 启动环境与安装依赖
如果你使用的是CSDN星图镜像,那么恭喜你,最复杂的模型部署和环境配置步骤已经完成了。你只需要创建一个新的Jupyter Notebook即可开始。
如果你在本地环境,需要安装以下核心库:
# 在Jupyter Notebook的Cell中运行,或使用终端
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install transformers Pillow matplotlib
!pip install accelerate # 用于模型加载优化
2.2 加载模型与处理图片
接下来,我们写代码加载GLM-4v-9b模型,并准备一张测试图片。这里我准备了一张包含猫和狗的图片。
import torch
from PIL import Image
import requests
from io import BytesIO
from transformers import AutoModelForCausalLM, AutoProcessor
# 1. 指定模型名称(使用智谱AI在Hugging Face上的官方仓库)
model_id = "THUDM/glm-4v-9b"
# 2. 加载模型和处理器
print("正在加载模型和处理器,这可能需要几分钟...")
device = "cuda" if torch.cuda.is_available() else "cpu"
# 使用低精度加载以节省显存,如果显存充足可以去掉 torch_dtype 参数
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16, # 使用半精度浮点数
trust_remote_code=True, # 信任来自作者的远程代码
device_map="auto" # 自动将模型层分配到可用的设备上
).eval() # 设置为评估模式,关闭dropout等训练层
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
print("模型加载完成!")
# 3. 准备一张图片
# 你可以从本地加载,也可以从网络下载
# 方式一:从本地文件加载
# image_path = “./your_image.jpg”
# image = Image.open(image_path).convert(“RGB”)
# 方式二:从网络下载示例图片(一只猫和一只狗)
url = “https://images.unsplash.com/photo-1514888286974-6d03bdeacba8?ixlib=rb-4.0.3&auto=format&fit=crop&w=800&q=80”
response = requests.get(url)
image = Image.open(BytesIO(response.content)).convert(“RGB”)
image.thumbnail((800, 800)) # 调整大小以便显示
display(image) # 在Jupyter中显示图片
print(“图片加载完成,尺寸为:”, image.size)
2.3 进行第一次问答
现在,让我们问模型一个简单的问题。
# 4. 构建对话
# GLM-4v-9b 使用特定的对话格式
conversation = [
{
“role”: “user”,
“content”: [
{“type”: “image”},
{“type”: “text”, “text”: “描述一下这张图片。”}
]
}
]
# 5. 处理输入
inputs = processor.apply_chat_template(
conversation,
add_generation_prompt=True, # 添加让模型开始生成的提示
images=[image], # 传入图片列表
return_tensors=“pt” # 返回PyTorch张量
).to(model.device)
# 6. 生成回答
print(“模型正在思考...”)
with torch.no_grad(): # 禁用梯度计算,节省内存和计算资源
outputs = model.generate(**inputs, max_new_tokens=100)
# max_new_tokens 限制生成的最大新令牌数,防止生成过长
# 7. 解码并打印结果
# 需要跳过输入部分,只解码新生成的部分
input_length = inputs[“input_ids”].shape[1]
response = processor.decode(outputs[0][input_length:], skip_special_tokens=True)
print(“\n=== 模型回答 ===")
print(response)
运行这段代码,你应该能看到模型对图片的描述,比如“图片中有一只猫和一只狗…”。到这一步,你已经成功完成了基础的图文交互!但这只是开始,我们还没看到模型内部的运作。
3. 核心揭秘:提取并可视化注意力热图
关键来了。我们要让模型在生成回答的每个步骤中,把它对图片不同区域的“关注程度”吐出来。
3.1 理解“交叉注意力”
多模态模型(如GLM-4v-9b)的核心机制之一是交叉注意力。你可以把它想象成模型的一束“思维探照灯”:
- 文本令牌:模型正在处理的文字(比如“描述”、“猫”、“狗”)。
- 图像令牌:图片被分割编码成的一系列小片段。
- 交叉注意力权重:当模型思考“猫”这个词时,这个权重就代表了“探照灯”在图片的各个片段上分别打了多强的光。权重高的区域,就是模型认为与当前文字最相关的区域。
我们的目标就是捕获并可视化这束“光”。
3.2 修改代码以捕获注意力
我们需要“钩住”模型内部的注意力层,把数据拿出来。这里我们使用PyTorch的register_forward_hook方法。
import matplotlib.pyplot as plt
import numpy as np
# 存储注意力权重的列表
attention_maps = []
# 定义钩子函数
def hook_fn(module, input, output):
# output 通常是一个元组,其中包含注意力权重
# 对于GLM-4v-9b,我们需要找到正确的索引,这里假设在 output[1]
if isinstance(output, tuple) and len(output) > 1:
attn_weights = output[1] # 获取注意力权重,形状通常是 (batch, heads, seq_len, seq_len)
attention_maps.append(attn_weights.cpu().detach()) # 转移到CPU并脱离计算图
# 找到模型的交叉注意力层并注册钩子
# GLM-4v-9b的层名可能不同,这里需要根据实际情况调整
# 一个常见的方法是查找包含 “vision” 或 “cross” 的注意力层
for name, module in model.named_modules():
if “vision” in name and “attention” in name:
# if “cross_attention” in name: # 另一种可能的命名
print(f”注册钩子到层: {name}”)
module.register_forward_hook(hook_fn)
# 清空之前的存储
attention_maps.clear()
# 重新进行生成,这次会触发钩子
print(“再次生成回答,并捕获注意力...”)
with torch.no_grad():
outputs = model.generate(**inputs, max_new_tokens=50, output_attentions=True) # 确保输出注意力
print(f”捕获到 {len(attention_maps)} 个注意力张量”)
3.3 可视化热力图
现在attention_maps里存储了注意力权重。我们需要处理这些数据,将其映射回原图。
def visualize_attention(image, attn_weights, token_idx, head_idx=0):
“””
可视化某个文本令牌对图像区域的注意力。
image: PIL Image 对象
attn_weights: 注意力权重张量 (layers, batch, heads, seq_len, seq_len)
token_idx: 要查看的文本令牌在序列中的索引
head_idx: 使用哪个注意力头(通常取平均或选一个)
“””
# 1. 获取最后一层的注意力,并选择指定的头和文本令牌
# attn_weights 可能是一个列表,每个元素是一个层的输出
# 我们取最后一个层(通常是最终输出前的层),并取第一个批次
if isinstance(attn_weights, list):
last_layer_attn = attn_weights[-1] # 取最后一个层
else:
last_layer_attn = attn_weights
# 假设形状是 (batch=1, heads, seq_len, seq_len)
# seq_len 包含图像令牌 + 文本令牌
attn_to_img = last_layer_attn[0, head_idx, token_idx, :] # 获取指定令牌对所有令牌的注意力
# 2. 我们需要知道图像令牌在序列中的起始和结束位置
# 这取决于处理器的具体实现。一个简单的方法是:假设开头的部分是图像令牌。
# 更准确的方法需要查看处理器的内部细节,这里我们做一个简化演示。
# 假设 inputs[‘image_patches’] 或类似属性能告诉我们数量,这里我们手动估算。
# 例如,如果图片被编码为 256 个令牌,那么 attn_to_img[:256] 就是对应图像的注意力
num_image_tokens = 256 # 这是一个示例值,实际需要根据模型和图片尺寸确定
image_attention = attn_to_img[:num_image_tokens].numpy()
# 3. 将一维的注意力权重重塑为二维网格(模拟图像空间布局)
# 图像令牌通常对应一个 H’ x W’ 的网格。假设是 16x16
grid_h, grid_w = 16, 16
if len(image_attention) != grid_h * grid_w:
# 如果不匹配,调整网格大小或使用插值
grid_h = int(np.sqrt(len(image_attention)))
grid_w = int(np.sqrt(len(image_attention)))
image_attention = image_attention[:grid_h*grid_w] # 截断或填充
attn_map = image_attention.reshape(grid_h, grid_w)
# 4. 将热力图放大到原图尺寸并叠加
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# 显示原图
axes[0].imshow(image)
axes[0].set_title(“Original Image”)
axes[0].axis(‘off’)
# 显示热力图
im = axes[1].imshow(attn_map, cmap=‘hot’, interpolation=‘bilinear’)
axes[1].set_title(f”Attention Heatmap (Token {token_idx}, Head {head_idx})”)
axes[1].axis(‘off’)
plt.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)
plt.tight_layout()
plt.show()
# 尝试可视化生成第一个词时的注意力
# 我们需要知道在生成序列中,第一个新文本令牌的索引
# input_length 是输入序列的长度,input_length 之后的索引就是生成令牌
first_gen_token_idx = input_length # 这是第一个生成位置的索引
if attention_maps and first_gen_token_idx < attention_maps[0].shape[-1]:
print(f”可视化第一个生成令牌(索引 {first_gen_token_idx})的注意力”)
visualize_attention(image, attention_maps, first_gen_token_idx)
else:
print(“无法可视化,注意力图维度或索引有问题。”)
print(f”注意力图长度: {len(attention_maps)}”)
if attention_maps:
print(f”注意力图形状: {attention_maps[0].shape}”)
print(f”输入长度: {input_length}”)
注意:上面的visualize_attention函数是一个概念演示。实际中,num_image_tokens、grid_h、grid_w需要根据GLM-4v-9b模型具体的视觉编码器(如CLIP-ViT)的参数来确定。你可能需要查阅模型的文档或源代码来获取这些信息,例如图片被分割成了多少个小块(patches)。
3.4 进阶:追踪特定词语的注意力
如果我们想看看模型在说出“猫”这个词的时候,注意力在哪里,该怎么做?
# 假设我们想知道模型生成“猫”这个字时的注意力
# 1. 首先,我们需要得到“猫”这个字在生成序列中的令牌ID和位置
# 2. 然后,在生成过程中,在对应的解码步骤捕获注意力
# 这是一个更高级的实现思路:
# 使用 model.generate 的 callback 函数,在每一步生成后获取中间层的注意力
from transformers import GenerationConfig
generation_output = model.generate(
**inputs,
max_new_tokens=20,
output_attentions=True,
return_dict_in_generate=True, # 返回详细信息
generation_config=GenerationConfig(do_sample=False) # 使用贪婪解码保证确定性
)
# 从输出中提取注意力
# generation_output.attentions 可能包含各层的注意力
if hasattr(generation_output, ‘attentions’) and generation_output.attentions:
all_attentions = generation_output.attentions # 可能是一个元组或列表
print(“成功获取生成过程中的所有注意力张量”)
# 后续处理 all_attentions 来定位特定令牌
4. 解读热图:理解模型的“视觉焦点”
当你成功生成热力图后,你会看到图片上有些区域被高亮(红色/黄色),有些区域暗淡(蓝色)。这代表了模型在处理当前文本概念时的“视觉焦点”。
- 当问题为“描述图片”时:第一个生成令牌的注意力可能比较分散,或者聚焦在图片的中心主体上,因为模型在尝试获取全局信息。
- 当问题为“猫在哪里?”时:模型在生成“猫”或相关描述词时,其注意力热图应该清晰地高亮图片中猫所在的区域。
- 当分析复杂图表时:你可以问“曲线在哪个点达到峰值?”,然后观察模型在生成“峰值”、“点”等词时,注意力是否精准地聚焦在图表的坐标轴和数据点上。
通过对比不同问题下的热力图,你能直观验证模型是否真的“理解”了图文之间的语义关联,而不仅仅是进行模式匹配。
5. 总结与后续探索
5.1 本教程回顾
我们完成了一次从使用到洞察的旅程:
- 搭建环境:学会了在Jupyter中加载强大的GLM-4v-9b多模态模型。
- 基础应用:实现了图文对话,验证了模型的基础能力。
- 深度探索:通过注册钩子,捕获了模型内部的交叉注意力权重。
- 可视化呈现:将抽象的权重数据转化为直观的热力图,让模型的“注意力”变得可见。
- 过程理解:通过热力图,我们得以窥见模型是如何实现图文对齐这一核心任务的。
5.2 你可以继续做什么?
- 更换图片和问题:用你自己的图片,问更刁钻的问题,观察注意力如何变化。
- 分析不同层和注意力头:模型有多层网络和多个“注意力头”,每个头可能关注不同类型的信息(如颜色、形状、纹理)。尝试可视化不同的头,看看它们的分工。
- 定量分析:计算注意力聚焦区域的IoU(交并比)等指标,定量评估模型定位的准确性。
- 对比其他模型:用同样的方法去分析其他开源多模态模型(如Qwen-VL,LLaVA),比较它们在图文对齐机制上的异同。
可视化注意力不仅是炫酷的技术,更是理解、调试和信任AI模型的重要手段。希望这篇教程为你打开了一扇窗,让你在应用多模态AI时,不仅知其然,更能知其所以然。
获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。
更多推荐



所有评论(0)