446 字
1 分钟
timm库的相关使用
Timm 库使用教程
简介
timm (PyTorch Image Models) 是一个由 Ross Wightman 开发的优秀计算机视觉库,提供了:
- 500+ 预训练模型
- 标准化的模型接口
- 高效的数据增强和训练工具
- SOTA 模型的快速实现
安装方式:
pip install timm

一、ResNet 使用示例
1.1 基础使用
加载预训练模型
import timm
import torch

# Enumerate the pretrained ResNet family available in timm.
resnet_models = timm.list_models('resnet*', pretrained=True)
print(f"可用的 ResNet 模型数量: {len(resnet_models)}")
print("部分模型:", resnet_models[:5])

# Load a pretrained ResNet50 and switch to eval mode for inference.
model = timm.create_model('resnet50', pretrained=True)
model.eval()

print(f"模型类型: {type(model)}")
print(f"参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

## 推理示例

from PIL import Image
import requests
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# 1. Fetch a sample image.
url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
image = Image.open(requests.get(url, stream=True).raw)

# 2. Resolve the data configuration the model was trained with.
config = resolve_data_config({}, model=model)
print("数据配置:", config)

# 3. Build the matching preprocessing transform.
transform = create_transform(**config)

# 4. Preprocess the image (add a batch dimension).
input_tensor = transform(image).unsqueeze(0)

# 5. Inference with gradients disabled.
with torch.no_grad():
    output = model(input_tensor)

print(f"输出形状: {output.shape}")  # [1, 1000] - the ImageNet classes

# 6. Decode the top-5 predictions.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_prob, top5_catid = torch.topk(probabilities, 5)

print("\nTop-5 预测:")
for i in range(5):
    print(f"{i+1}. 类别 {top5_catid[i].item()}, 概率: {top5_prob[i].item():.4f}")

## 1.2 自定义分类头
# Build a model whose classifier matches a custom number of classes.
num_classes = 10  # e.g. CIFAR-10

model = timm.create_model(
    'resnet50',
    pretrained=True,
    num_classes=num_classes,
)

print(f"新的分类头输出: {model.fc}")

# Feature extraction only: num_classes=0 removes the classification head.
feature_model = timm.create_model(
    'resnet50',
    pretrained=True,
    num_classes=0,
)

# Sanity-check the extracted feature vector.
with torch.no_grad():
    features = feature_model(input_tensor)
    print(f"特征向量形状: {features.shape}")  # [1, 2048]

## 1.3 ResNet 变体对比
# Compare parameter counts across ResNet variants.
variants = [
    'resnet34',     # standard ResNet34
    'resnet50',     # standard ResNet50
    'resnet50d',    # ResNet50-D (improved stem/downsampling)
    'resnet101',    # deeper ResNet
    'resnetv2_50',  # ResNet V2
]

for variant in variants:
    model = timm.create_model(variant, pretrained=False)
    params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"{variant:20s} - 参数量: {params:6.2f}M")

## 1.4 微调训练示例
import torch.nn as nn
import torch.optim as optim

# Create a 10-class model initialized from the pretrained weights.
model = timm.create_model('resnet50', pretrained=True, num_classes=10)

# Freeze every layer except the classification head.
for name, param in model.named_parameters():
    if 'fc' not in name:  # 'fc' is the classification head
        param.requires_grad = False

# Optimize only the head's parameters.
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Example training step (commented out)
model.train()
# dummy_data = torch.randn(8, 3, 224, 224)
# dummy_labels = torch.randint(0, 10, (8,))
#
# optimizer.zero_grad()
# outputs = model(dummy_data)
# loss = criterion(outputs, dummy_labels)
# loss.backward()
# optimizer.step()

## 二、Vision Transformer (ViT) 使用示例
2.1 基础使用
加载预训练 ViT 模型
import timm
import torch

# Enumerate the pretrained ViT family.
vit_models = timm.list_models('vit*', pretrained=True)
print(f"可用的 ViT 模型数量: {len(vit_models)}")
print("部分模型:", vit_models[:10])

# Load ViT-Base (16x16 patches, 224x224 input) for inference.
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()

print(f"\n模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

## ViT 推理

from PIL import Image
import requests
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Fetch a sample image.
url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
image = Image.open(requests.get(url, stream=True).raw)

# Build the preprocessing pipeline the model expects.
config = resolve_data_config({}, model=model)
transform = create_transform(**config)

input_tensor = transform(image).unsqueeze(0)

# Inference with gradients disabled.
with torch.no_grad():
    output = model(input_tensor)

print(f"输出形状: {output.shape}")

# Decode the top-5 predictions.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_prob, top5_catid = torch.topk(probabilities, 5)

print("\nTop-5 预测:")
for i in range(5):
    print(f"{i+1}. 类别 {top5_catid[i].item()}, 概率: {top5_prob[i].item():.4f}")

## 2.2 ViT 模型架构详解
# Inspect the configuration of a pretrained ViT.
model = timm.create_model('vit_base_patch16_224', pretrained=True)

print("ViT 配置信息:")
print(f"Patch 大小: {model.patch_embed.patch_size}")
print(f"嵌入维度: {model.embed_dim}")
# FIX: timm's VisionTransformer does not expose `depth` or `num_heads` as
# model attributes (the original `model.depth` / `model.num_heads` raises
# AttributeError). Derive depth from the block list and the head count from
# the first block's attention module instead.
print(f"深度(Transformer 层数): {len(model.blocks)}")
print(f"注意力头数: {model.blocks[0].attn.num_heads}")
print(f"图像尺寸: {model.patch_embed.img_size}")

# Walk the main components of the architecture.
print("\n模型主要组件:")
print(f"1. Patch Embedding: {model.patch_embed}")
print(f"2. 位置编码形状: {model.pos_embed.shape}")
print(f"3. CLS Token 形状: {model.cls_token.shape}")
print(f"4. Transformer 块数量: {len(model.blocks)}")
print(f"5. 分类头: {model.head}")

## 2.3 ViT 变体对比
# Compare ViT model sizes.
vit_variants = [
    'vit_tiny_patch16_224',   # Tiny:  ~5.7M params
    'vit_small_patch16_224',  # Small: ~22M params
    'vit_base_patch16_224',   # Base:  ~86M params
    'vit_large_patch16_224',  # Large: ~304M params
]

print("ViT 模型对比:\n")
for variant in vit_variants:
    try:
        model = timm.create_model(variant, pretrained=False)
        params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f"{variant:30s} - 参数量: {params:7.2f}M")
    # FIX: the original bare `except:` also swallows KeyboardInterrupt and
    # SystemExit; only unavailable-model errors should be reported here.
    except Exception:
        print(f"{variant:30s} - 不可用")

# Effect of different patch sizes at the same model scale.
print("\n不同 Patch 大小:")
patch_variants = [
    'vit_base_patch32_224',  # 32x32 patches
    'vit_base_patch16_224',  # 16x16 patches
]

for variant in patch_variants:
    model = timm.create_model(variant, pretrained=False)
    params = sum(p.numel() for p in model.parameters()) / 1e6
    num_patches = model.patch_embed.num_patches
    print(f"{variant:30s} - Patches: {num_patches:4d}, 参数: {params:.2f}M")

## 2.4 提取中间层特征
import torch

model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()

# Method 1: forward_features returns the token sequence before the head.
with torch.no_grad():
    input_tensor = torch.randn(1, 3, 224, 224)

    # All patch tokens plus the CLS token.
    features = model.forward_features(input_tensor)
    print(f"特征形状: {features.shape}")  # [1, 197, 768]
    # 197 = 1 (CLS token) + 196 (14x14 patches)

    # The CLS token alone is typically used for classification.
    cls_token = features[:, 0]
    print(f"CLS Token 形状: {cls_token.shape}")  # [1, 768]

# Method 2: capture intermediate outputs via forward hooks.
features_dict = {}

def get_features(name):
    """Return a forward hook that stores the module output under `name`."""
    def hook(model, input, output):
        features_dict[name] = output.detach()
    return hook

# Register the hook on the last transformer block.
model.blocks[-1].register_forward_hook(get_features('last_block'))

with torch.no_grad():
    _ = model(input_tensor)

print(f"\n最后一个 Transformer 块的输出: {features_dict['last_block'].shape}")

## 2.5 ViT 微调示例
import torch.nn as nn
import torch.optim as optim

# Build a ViT with a custom classification head.
num_classes = 10
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=num_classes,
)

# Strategy 1: freeze the patch embedding and the first two transformer blocks.
for name, param in model.named_parameters():
    if 'patch_embed' in name or 'blocks.0.' in name or 'blocks.1.' in name:
        param.requires_grad = False

# Strategy 2: layer-wise learning rates (head fastest, embedding slowest).
param_groups = [
    {'params': model.head.parameters(), 'lr': 1e-3},
    {'params': model.blocks.parameters(), 'lr': 1e-4},
    {'params': model.patch_embed.parameters(), 'lr': 1e-5},
]

optimizer = optim.AdamW(param_groups, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()

print("ViT 微调配置完成")
print(f"可训练参数: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M")
print(f"总参数: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

## 三、高级功能
3.1 模型信息查询
import timm

# Inspect a specific model's metadata.
model_name = 'resnet50'
model = timm.create_model(model_name, pretrained=False)

# The default config describes the pretrained weights' preprocessing.
default_cfg = model.default_cfg
print("默认配置:")
for key, value in default_cfg.items():
    print(f" {key}: {value}")

# Resolve the input requirements (size, normalization statistics).
data_config = timm.data.resolve_data_config(model.default_cfg)
print(f"\n推荐输入尺寸: {data_config['input_size']}")
print(f"均值: {data_config['mean']}")
print(f"标准差: {data_config['std']}")

## 3.2 数据增强
from timm.data import create_transform
from timm.data.auto_augment import rand_augment_transform

# Training-time augmentation pipeline.
transform_train = create_transform(
    input_size=(3, 224, 224),
    is_training=True,
    auto_augment='rand-m9-mstd0.5-inc1',  # RandAugment policy
    re_prob=0.25,   # Random Erasing probability
    re_mode='pixel',
    re_count=1,
)

# Deterministic transform for validation / test.
transform_eval = create_transform(
    input_size=(3, 224, 224),
    is_training=False,
)

print("训练变换:", transform_train)
print("\n评估变换:", transform_eval)

## 3.3 混合精度训练
import torch
# NOTE(review): `torch.cuda.amp` is the legacy AMP namespace; newer PyTorch
# prefers `torch.amp` — confirm against the project's torch version.
from torch.cuda.amp import autocast, GradScaler

model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
model = model.cuda()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()

# Training loop (commented template)
# for batch in dataloader:
#     images, labels = batch
#     images, labels = images.cuda(), labels.cuda()
#
#     optimizer.zero_grad()
#
#     # mixed-precision forward pass
#     with autocast():
#         outputs = model(images)
#         loss = criterion(outputs, labels)
#
#     # mixed-precision backward pass
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)
#     scaler.update()

## 3.4 模型导出
import torch

model = timm.create_model('resnet50', pretrained=True)
model.eval()

# Export to ONNX with a dynamic batch dimension.
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "resnet50.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    })

print("模型已导出为 ONNX 格式")

# Export to TorchScript via tracing.
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save("resnet50_traced.pt")

print("模型已导出为 TorchScript 格式")

## 四、常用技巧和最佳实践
4.1 查找合适的模型
import timm

# Browse models by task: image classification.
classification_models = timm.list_models(pretrained=True)
print(f"预训练分类模型总数: {len(classification_models)}")

# Filter by architecture family with wildcards.
efficient_nets = timm.list_models('efficientnet*', pretrained=True)
convnext_models = timm.list_models('convnext*', pretrained=True)
swin_models = timm.list_models('swin*', pretrained=True)

print(f"\nEfficientNet 系列: {len(efficient_nets)} 个")
print(f"ConvNeXt 系列: {len(convnext_models)} 个")
print(f"Swin Transformer 系列: {len(swin_models)} 个")

# Shortlist of lightweight models by parameter count.
print("\n推荐的轻量级模型:")
lightweight = ['mobilenetv3_large_100', 'efficientnet_b0', 'resnet34']
for model_name in lightweight:
    model = timm.create_model(model_name, pretrained=False)
    params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f" {model_name:30s}: {params:5.2f}M 参数")

## 4.2 性能优化
import torch
import timm

model = timm.create_model('resnet50', pretrained=True)
model.eval()

# 1. torch.compile (PyTorch 2.0+), guarded for older versions.
if hasattr(torch, 'compile'):
    model = torch.compile(model)
    print("模型已编译优化")

# 2. channels_last memory layout for faster convolutions.
model = model.to(memory_format=torch.channels_last)
dummy_input = torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last)

# 3. Disable gradient tracking for inference.
with torch.no_grad():
    output = model(dummy_input)

## 4.3 批量推理
import torch
from torch.utils.data import DataLoader, Dataset
from timm.data import create_transform

class ImageDataset(Dataset):
    """Minimal dataset: load images from paths and apply a transform."""

    def __init__(self, image_paths, transform):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        from PIL import Image
        image = Image.open(self.image_paths[idx]).convert('RGB')
        return self.transform(image)
# Set up the model on GPU for batched inference.
model = timm.create_model('resnet50', pretrained=True)
model.eval()
model = model.cuda()

transform = create_transform(
    input_size=(3, 224, 224),
    is_training=False
)

# Given a list of image paths:
# image_paths = ['path1.jpg', 'path2.jpg', ...]
# dataset = ImageDataset(image_paths, transform)
# loader = DataLoader(dataset, batch_size=32, num_workers=4)

# for batch in loader:
#     batch = batch.cuda()
#     with torch.no_grad():
#         outputs = model(batch)
#     # process outputs...

## 五、总结
ResNet vs ViT 对比
| 特性 | ResNet | ViT |
|---|---|---|
| 架构类型 | CNN | Transformer |
| 归纳偏置 | 强(局部性、平移不变性) | 弱 |
| 数据需求 | 中等 | 大(需要大规模预训练) |
| 计算效率 | 高 | 中等(与输入大小平方相关) |
| 参数量 | 较少 | 较多 |
| 迁移学习 | 优秀 | 优秀(预训练后) |
| 推理速度 | 快 | 中等 |
选择建议
使用 ResNet 当:
- 数据集较小(< 100k 图像)
- 需要快速推理
- 计算资源有限
- 需要部署到边缘设备
使用 ViT 当:
- 有大规模预训练模型可用
- 数据集较大
- 追求最高精度
- 有足够的计算资源
常用资源
- GitHub: https://github.com/huggingface/pytorch-image-models
- 文档: https://huggingface.co/docs/timm
- 模型库: https://huggingface.co/timm
- 论文集合: https://github.com/rwightman/pytorch-image-models#papers
附录:完整代码示例
ResNet 完整训练示例
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import timm
from timm.data import create_transform

# 1. Create the model
model = timm.create_model('resnet50', pretrained=True, num_classes=10)
model = model.cuda()

# 2. Data augmentation
transform_train = create_transform(
    input_size=(3, 224, 224),
    is_training=True,
    auto_augment='rand-m9-mstd0.5-inc1',
)

transform_val = create_transform(
    input_size=(3, 224, 224),
    is_training=False,
)

# 3. Data loaders (plug in your own dataset)
# train_dataset = YourDataset(transform=transform_train)
# val_dataset = YourDataset(transform=transform_val)
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# 4. Optimizer and loss
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()
# 5. Training loop
def train_epoch(model, loader, criterion, optimizer):
    """Run one training epoch.

    Returns (mean loss over batches, accuracy in percent).
    """
    model.train()
    # FIX: the original hard-coded `.cuda()` and crashed on CPU-only
    # machines; infer the device from the model instead (identical
    # behavior when the model lives on GPU).
    device = next(model.parameters()).device
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    return running_loss / len(loader), 100. * correct / total

def validate(model, loader, criterion):
    """Evaluate the model; returns (mean loss, accuracy in percent)."""
    model.eval()
    device = next(model.parameters()).device
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    return running_loss / len(loader), 100. * correct / total

# Training driver (commented template)
# num_epochs = 10
# for epoch in range(num_epochs):
#     train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
#     val_loss, val_acc = validate(model, val_loader, criterion)
#
#     print(f'Epoch {epoch+1}/{num_epochs}')
#     print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
#     print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

## ViT 完整训练示例
import torch
import torch.nn as nn
import torch.optim as optim
import timm

# 1. Create the ViT model
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
model = model.cuda()
# 2. Layer-wise learning rates
def get_parameter_groups(model):
    """Build three AdamW parameter groups for ViT fine-tuning.

    Group 0: backbone weights (decayed, base lr).
    Group 1: backbone bias/norm parameters (no weight decay).
    Group 2: classification head (higher lr).
    """
    no_decay = ['bias', 'norm']
    backbone = [(n, p) for n, p in model.named_parameters() if 'head' not in n]
    decayed = [p for n, p in backbone
               if not any(tag in n for tag in no_decay)]
    undecayed = [p for n, p in backbone
                 if any(tag in n for tag in no_decay)]
    return [
        {'params': decayed, 'lr': 1e-4, 'weight_decay': 0.05},
        {'params': undecayed, 'lr': 1e-4, 'weight_decay': 0.0},
        {'params': model.head.parameters(), 'lr': 1e-3, 'weight_decay': 0.05},
    ]
param_groups = get_parameter_groups(model)
optimizer = optim.AdamW(param_groups)

# 3. Learning-rate scheduler
from timm.scheduler import CosineLRScheduler

scheduler = CosineLRScheduler(
    optimizer,
    t_initial=10,  # epochs
    lr_min=1e-6,
    warmup_t=1,
    warmup_lr_init=1e-6,
)

# 4. Training (same as ResNet, though ViT may need more epochs)
criterion = nn.CrossEntropyLoss()
# ... training loop as above

## 分享
如果这篇文章对你有帮助,欢迎分享给更多人!
部分信息可能已经过时
随机文章 随机推荐
暂无数据






