446 字
1 分钟
timm库的相关使用
2026-04-05

Timm 库使用教程#

简介#

timm (PyTorch Image Models) 是一个由 Ross Wightman 开发的优秀计算机视觉库,提供了:

  • 500+ 预训练模型
  • 标准化的模型接口
  • 高效的数据增强和训练工具
  • SOTA 模型的快速实现

安装方式:

pip install timm

一、ResNet 使用示例#

1.1 基础使用#

加载预训练模型#

import timm
import torch

# Enumerate every pretrained ResNet checkpoint that timm ships.
available = timm.list_models('resnet*', pretrained=True)
print(f"可用的 ResNet 模型数量: {len(available)}")
print("部分模型:", available[:5])

# Build a pretrained ResNet50 and switch it to inference mode.
model = timm.create_model('resnet50', pretrained=True)
model.eval()
print(f"模型类型: {type(model)}")
print(f"参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

推理示例#

from PIL import Image
import requests
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# 1. Fetch a sample image.
url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
image = Image.open(requests.get(url, stream=True).raw)

# 2. Read the preprocessing config the checkpoint was trained with.
config = resolve_data_config({}, model=model)
print("数据配置:", config)

# 3. Build the matching input transform.
transform = create_transform(**config)

# 4. Preprocess: transform, then add the batch dimension.
input_tensor = transform(image).unsqueeze(0)

# 5. Inference with autograd disabled (restored indentation: the original
#    `with`/`for` bodies were flattened to column 0 and would not parse).
with torch.no_grad():
    output = model(input_tensor)
print(f"输出形状: {output.shape}")  # [1, 1000] - ImageNet's 1000 classes

# 6. Top-5 predictions from the softmax distribution.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_prob, top5_catid = torch.topk(probabilities, 5)
print("\nTop-5 预测:")
for i in range(5):
    print(f"{i+1}. 类别 {top5_catid[i].item()}, 概率: {top5_prob[i].item():.4f}")

1.2 自定义分类头#

# Rebuild ResNet50 with a fresh classification head sized for our task.
num_classes = 10  # e.g. CIFAR-10
model = timm.create_model(
    'resnet50',
    pretrained=True,
    num_classes=num_classes,
)
print(f"新的分类头输出: {model.fc}")

# num_classes=0 strips the head entirely: the model then returns pooled features.
feature_model = timm.create_model(
    'resnet50',
    pretrained=True,
    num_classes=0,
)

# Feature extraction check (restored the flattened `with` body).
with torch.no_grad():
    features = feature_model(input_tensor)
print(f"特征向量形状: {features.shape}")  # [1, 2048]

1.3 ResNet 变体对比#

# Compare parameter counts across ResNet variants.
variants = [
    'resnet34',     # standard ResNet34
    'resnet50',     # standard ResNet50
    'resnet50d',    # ResNet50-D (improved stem/downsampling)
    'resnet101',    # deeper ResNet
    'resnetv2_50',  # ResNet V2 (pre-activation)
]
# Restored indentation: the loop body was flattened to column 0.
for variant in variants:
    model = timm.create_model(variant, pretrained=False)
    params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"{variant:20s} - 参数量: {params:6.2f}M")

1.4 微调训练示例#

import torch.nn as nn
import torch.optim as optim

# Linear-probe setup: train only the classification head.
model = timm.create_model('resnet50', pretrained=True, num_classes=10)

# Freeze every parameter outside the head ('fc' on timm ResNets).
# Restored indentation: the for/if bodies were flattened and would not parse.
for name, param in model.named_parameters():
    if 'fc' not in name:
        param.requires_grad = False

# Optimize only the head parameters.
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

model.train()
# Example training step (uncomment to run):
# dummy_data = torch.randn(8, 3, 224, 224)
# dummy_labels = torch.randint(0, 10, (8,))
#
# optimizer.zero_grad()
# outputs = model(dummy_data)
# loss = criterion(outputs, dummy_labels)
# loss.backward()
# optimizer.step()

二、Vision Transformer (ViT) 使用示例#

2.1 基础使用#

加载预训练 ViT 模型#

import timm
import torch

# Enumerate the pretrained ViT checkpoints.
vit_models = timm.list_models('vit*', pretrained=True)
print(f"可用的 ViT 模型数量: {len(vit_models)}")
print("部分模型:", vit_models[:10])

# ViT-Base: 16x16 patches on a 224x224 input.
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()
print(f"\n模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

ViT 推理#

from PIL import Image
import requests
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Fetch a sample image.
url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
image = Image.open(requests.get(url, stream=True).raw)

# ViT uses the same timm preprocessing pipeline as the CNNs.
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
input_tensor = transform(image).unsqueeze(0)

# Inference (restored the flattened `with` body — it would not parse).
with torch.no_grad():
    output = model(input_tensor)
print(f"输出形状: {output.shape}")

# Top-5 predictions.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_prob, top5_catid = torch.topk(probabilities, 5)
print("\nTop-5 预测:")
for i in range(5):
    print(f"{i+1}. 类别 {top5_catid[i].item()}, 概率: {top5_prob[i].item():.4f}")

2.2 ViT 模型架构详解#

# Inspect the ViT configuration.
model = timm.create_model('vit_base_patch16_224', pretrained=True)
print("ViT 配置信息:")
print(f"Patch 大小: {model.patch_embed.patch_size}")
print(f"嵌入维度: {model.embed_dim}")
# Fix: timm's VisionTransformer exposes no `depth`/`num_heads` attributes;
# derive them from the block list and the first attention module instead.
print(f"深度(Transformer 层数): {len(model.blocks)}")
print(f"注意力头数: {model.blocks[0].attn.num_heads}")
print(f"图像尺寸: {model.patch_embed.img_size}")

# Walk the major components of the architecture.
print("\n模型主要组件:")
print(f"1. Patch Embedding: {model.patch_embed}")
print(f"2. 位置编码形状: {model.pos_embed.shape}")
print(f"3. CLS Token 形状: {model.cls_token.shape}")
print(f"4. Transformer 块数量: {len(model.blocks)}")
print(f"5. 分类头: {model.head}")

2.3 ViT 变体对比#

# Parameter counts for the ViT family.
vit_variants = [
    'vit_tiny_patch16_224',   # Tiny:  ~5.7M params
    'vit_small_patch16_224',  # Small: ~22M params
    'vit_base_patch16_224',   # Base:  ~86M params
    'vit_large_patch16_224',  # Large: ~304M params
]
print("ViT 模型对比:\n")
# Restored the flattened for/try bodies; also narrowed the bare `except:`
# (which would swallow KeyboardInterrupt/SystemExit too).
for variant in vit_variants:
    try:
        model = timm.create_model(variant, pretrained=False)
        params = sum(p.numel() for p in model.parameters()) / 1e6
        print(f"{variant:30s} - 参数量: {params:7.2f}M")
    except Exception:
        print(f"{variant:30s} - 不可用")

# Smaller patches -> more tokens -> more compute.
print("\n不同 Patch 大小:")
patch_variants = [
    'vit_base_patch32_224',  # 32x32 patches
    'vit_base_patch16_224',  # 16x16 patches
]
for variant in patch_variants:
    model = timm.create_model(variant, pretrained=False)
    params = sum(p.numel() for p in model.parameters()) / 1e6
    num_patches = model.patch_embed.num_patches
    print(f"{variant:30s} - Patches: {num_patches:4d}, 参数: {params:.2f}M")

2.4 提取中间层特征#

import torch

model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()

# Method 1: forward_features returns the token-level features
# (restored the flattened `with` body — it would not parse).
with torch.no_grad():
    input_tensor = torch.randn(1, 3, 224, 224)
    features = model.forward_features(input_tensor)
    print(f"特征形状: {features.shape}")  # [1, 197, 768]
    # 197 = 1 (CLS token) + 196 (14x14 patches)
    cls_token = features[:, 0]  # CLS token, typically used for classification
    print(f"CLS Token 形状: {cls_token.shape}")  # [1, 768]

# Method 2: forward hooks capture intermediate outputs.
features_dict = {}

def get_features(name):
    """Return a forward hook that stores the module output under `name`."""
    def hook(model, input, output):
        features_dict[name] = output.detach()
    return hook

# Register the hook on the last transformer block and run a forward pass.
model.blocks[-1].register_forward_hook(get_features('last_block'))
with torch.no_grad():
    _ = model(input_tensor)
print(f"\n最后一个 Transformer 块的输出: {features_dict['last_block'].shape}")

2.5 ViT 微调示例#

import torch.nn as nn
import torch.optim as optim

# ViT with a 10-way classification head.
num_classes = 10
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=num_classes,
)

# Strategy 1: freeze the patch embedding and the first two transformer blocks
# (restored the flattened for/if bodies — they would not parse).
for name, param in model.named_parameters():
    if 'patch_embed' in name or 'blocks.0.' in name or 'blocks.1.' in name:
        param.requires_grad = False

# Strategy 2: discriminative learning rates (head > blocks > patch embed).
param_groups = [
    {'params': model.head.parameters(), 'lr': 1e-3},
    {'params': model.blocks.parameters(), 'lr': 1e-4},
    {'params': model.patch_embed.parameters(), 'lr': 1e-5},
]
optimizer = optim.AdamW(param_groups, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()
print("ViT 微调配置完成")
print(f"可训练参数: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M")
print(f"总参数: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

三、高级功能#

3.1 模型信息查询#

import timm

# Inspect a model's default (pretraining) configuration.
model_name = 'resnet50'
model = timm.create_model(model_name, pretrained=False)

default_cfg = model.default_cfg
print("默认配置:")
# Restored the flattened loop body — it would not parse.
for key, value in default_cfg.items():
    print(f" {key}: {value}")

# Derive the expected input pipeline from the default config.
data_config = timm.data.resolve_data_config(model.default_cfg)
print(f"\n推荐输入尺寸: {data_config['input_size']}")
print(f"均值: {data_config['mean']}")
print(f"标准差: {data_config['std']}")

3.2 数据增强#

from timm.data import create_transform
from timm.data.auto_augment import rand_augment_transform

# Training-time augmentation: RandAugment + Random Erasing.
transform_train = create_transform(
    input_size=(3, 224, 224),
    is_training=True,
    auto_augment='rand-m9-mstd0.5-inc1',  # RandAugment policy string
    re_prob=0.25,                          # Random Erasing probability
    re_mode='pixel',
    re_count=1,
)

# Deterministic eval-time transform (resize, crop, normalize only).
transform_eval = create_transform(
    input_size=(3, 224, 224),
    is_training=False,
)
print("训练变换:", transform_train)
print("\n评估变换:", transform_eval)

3.3 混合精度训练#

import torch
from torch.cuda.amp import autocast, GradScaler

# Model + optimizer + gradient scaler for mixed-precision training.
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()

# Mixed-precision training loop (uncomment with a real dataloader):
# for batch in dataloader:
#     images, labels = batch
#     images, labels = images.cuda(), labels.cuda()
#
#     optimizer.zero_grad()
#
#     with autocast():                     # forward runs in reduced precision
#         outputs = model(images)
#         loss = criterion(outputs, labels)
#
#     scaler.scale(loss).backward()        # scaled backward avoids fp16 underflow
#     scaler.step(optimizer)
#     scaler.update()

3.4 模型导出#

import torch

model = timm.create_model('resnet50', pretrained=True)
model.eval()

# ONNX export with a dynamic batch dimension.
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "resnet50.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)
print("模型已导出为 ONNX 格式")

# TorchScript export via tracing.
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save("resnet50_traced.pt")
print("模型已导出为 TorchScript 格式")

四、常用技巧和最佳实践#

4.1 查找合适的模型#

import timm

# All pretrained classification models timm knows about.
classification_models = timm.list_models(pretrained=True)
print(f"预训练分类模型总数: {len(classification_models)}")

# Filter by architecture family with glob patterns.
efficient_nets = timm.list_models('efficientnet*', pretrained=True)
convnext_models = timm.list_models('convnext*', pretrained=True)
swin_models = timm.list_models('swin*', pretrained=True)
print(f"\nEfficientNet 系列: {len(efficient_nets)} 个")
print(f"ConvNeXt 系列: {len(convnext_models)} 个")
print(f"Swin Transformer 系列: {len(swin_models)} 个")

# A few lightweight choices with their parameter counts
# (restored the flattened loop body — it would not parse).
print("\n推荐的轻量级模型:")
lightweight = ['mobilenetv3_large_100', 'efficientnet_b0', 'resnet34']
for model_name in lightweight:
    model = timm.create_model(model_name, pretrained=False)
    params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f" {model_name:30s}: {params:5.2f}M 参数")

4.2 性能优化#

import torch
import timm

model = timm.create_model('resnet50', pretrained=True)
model.eval()

# 1. torch.compile (PyTorch 2.0+) fuses/optimizes the graph
#    (restored the flattened if/with bodies — they would not parse).
if hasattr(torch, 'compile'):
    model = torch.compile(model)
    print("模型已编译优化")

# 2. channels_last often speeds up convolutions on modern hardware.
model = model.to(memory_format=torch.channels_last)
dummy_input = torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last)

# 3. Disable autograd for inference.
with torch.no_grad():
    output = model(dummy_input)

4.3 批量推理#

import torch
from torch.utils.data import DataLoader, Dataset
from timm.data import create_transform
class ImageDataset(Dataset):
    """Minimal dataset: loads images from file paths and applies a transform.

    Restored the class-body indentation — the extracted source had the
    methods flattened to column 0, which is a syntax error.
    """

    def __init__(self, image_paths, transform):
        self.image_paths = image_paths  # list of image file paths
        self.transform = transform      # callable mapping PIL.Image -> tensor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Local import keeps PIL optional until an item is actually fetched.
        from PIL import Image
        image = Image.open(self.image_paths[idx]).convert('RGB')
        return self.transform(image)
# Inference setup: pretrained ResNet50 on GPU, eval-time transform.
model = timm.create_model('resnet50', pretrained=True)
model.eval()
model = model.cuda()

transform = create_transform(
    input_size=(3, 224, 224),
    is_training=False,
)

# Batched inference over a list of image paths:
# image_paths = ['path1.jpg', 'path2.jpg', ...]
# dataset = ImageDataset(image_paths, transform)
# loader = DataLoader(dataset, batch_size=32, num_workers=4)
# for batch in loader:
#     batch = batch.cuda()
#     with torch.no_grad():
#         outputs = model(batch)
#     # process outputs...

五、总结#

ResNet vs ViT 对比#

| 特性 | ResNet | ViT |
| --- | --- | --- |
| 架构类型 | CNN | Transformer |
| 归纳偏置 | 强(局部性、平移不变性) | 弱 |
| 数据需求 | 中等 | 大(需要大规模预训练) |
| 计算效率 | 高 | 中等(与输入大小平方相关) |
| 参数量 | 较少 | 较多 |
| 迁移学习 | 优秀 | 优秀(预训练后) |
| 推理速度 | 快 | 中等 |

选择建议#

使用 ResNet 当:

  • 数据集较小(< 100k 图像)
  • 需要快速推理
  • 计算资源有限
  • 需要部署到边缘设备

使用 ViT 当:

  • 有大规模预训练模型可用
  • 数据集较大
  • 追求最高精度
  • 有足够的计算资源

常用资源#


附录:完整代码示例#

ResNet 完整训练示例#

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import timm
from timm.data import create_transform

# 1. ResNet50 with a 10-class head, moved to GPU.
model = timm.create_model('resnet50', pretrained=True, num_classes=10)
model = model.cuda()

# 2. Augmentation pipelines for train and validation.
transform_train = create_transform(
    input_size=(3, 224, 224),
    is_training=True,
    auto_augment='rand-m9-mstd0.5-inc1',
)
transform_val = create_transform(
    input_size=(3, 224, 224),
    is_training=False,
)

# 3. Dataloaders (plug in your own dataset class):
# train_dataset = YourDataset(transform=transform_train)
# val_dataset = YourDataset(transform=transform_val)
# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# 4. Optimizer and loss.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()
# 5. 训练循环
def train_epoch(model, loader, criterion, optimizer):
    """Run one training epoch on `loader`; return (mean loss, accuracy %).

    Restored the function-body indentation — the extracted source had the
    body flattened to column 0, which is a syntax error.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        # NOTE(review): assumes a CUDA device is available — confirm.
        images, labels = images.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    return running_loss / len(loader), 100. * correct / total
def validate(model, loader, criterion):
    """Evaluate on `loader` without gradients; return (mean loss, accuracy %).

    Restored the function-body indentation — the extracted source had the
    body flattened to column 0, which is a syntax error.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            # NOTE(review): assumes a CUDA device is available — confirm.
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return running_loss / len(loader), 100. * correct / total
# 训练
# num_epochs = 10
# for epoch in range(num_epochs):
# train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer)
# val_loss, val_acc = validate(model, val_loader, criterion)
#
# print(f'Epoch {epoch+1}/{num_epochs}')
# print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
# print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

ViT 完整训练示例#

import torch
import torch.nn as nn
import torch.optim as optim
import timm

# 1. ViT-Base with a 10-class head, moved to GPU.
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
model = model.cuda()
# 2. 分层学习率
def get_parameter_groups(model):
    """Split `model`'s parameters into three AdamW groups.

    Restored the function-body indentation — the extracted source had the
    body flattened to column 0, which is a syntax error.

    Groups:
      1. backbone weights        — lr 1e-4, weight decay 0.05
      2. backbone bias/norm      — lr 1e-4, no weight decay (standard practice)
      3. classification head     — lr 1e-3, weight decay 0.05
    """
    no_decay = ['bias', 'norm']
    return [
        {
            'params': [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay) and 'head' not in n],
            'lr': 1e-4,
            'weight_decay': 0.05,
        },
        {
            'params': [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay) and 'head' not in n],
            'lr': 1e-4,
            'weight_decay': 0.0,
        },
        {
            'params': model.head.parameters(),
            'lr': 1e-3,
            'weight_decay': 0.05,
        },
    ]
# Build the optimizer from the layered parameter groups.
param_groups = get_parameter_groups(model)
optimizer = optim.AdamW(param_groups)

# 3. Cosine learning-rate schedule with one warmup epoch.
from timm.scheduler import CosineLRScheduler
scheduler = CosineLRScheduler(
    optimizer,
    t_initial=10,        # total epochs
    lr_min=1e-6,
    warmup_t=1,
    warmup_lr_init=1e-6,
)

# 4. Training loop: same shape as the ResNet example (often needs more epochs).
criterion = nn.CrossEntropyLoss()
# ... 训练循环代码同上
分享

如果这篇文章对你有帮助,欢迎分享给更多人!

timm库的相关使用
https://blog.azusacat.cn/posts/timm库的相关使用/
作者
Yui
发布于
2026-04-05
许可协议
CC BY-NC-SA 4.0

部分信息可能已经过时

随机文章 随机推荐
暂无数据

目录