本文已参与「新人创作礼」活动,一起开启掘金创作之路
[Paddle2.0学习之第二步]实现图像分类
文章目录
本案例通过Cifer10数据集进行举例,十几行代码完成图像分类
项目放置在AIstudio平台
经典五步:
1. 导包
# 从paddle.vision.models 模块中import 残差网络,VGG网络,LeNet网络
from paddle.vision.models import resnet50, vgg16, LeNet
from paddle.vision.datasets import Cifar10
from paddle.optimizer import Momentum
from paddle.regularizer import L2Decay
from paddle.nn import CrossEntropyLoss
from paddle.metric import Accuracy
from paddle.vision.transforms import Transpose
2. 处理数据
# 使用Cifar10数据集
train_dataset = Cifar10(mode='train', transform=Transpose())
val_dataset = Cifar10(mode='test', transform=Transpose())
train_dataset[0]
(array([[[178., 178., 178., ..., 170., 168., 165.],
[180., 179., 180., ..., 173., 171., 168.],
[177., 177., 178., ..., 171., 169., 167.],
...,
[112., 113., 114., ..., 100., 98., 101.],
[112., 112., 113., ..., 102., 102., 102.],
[103., 100., 103., ..., 92., 93., 91.]],
[[176., 176., 176., ..., 168., 166., 163.],
[178., 177., 178., ..., 171., 169., 166.],
[175., 175., 176., ..., 169., 167., 165.],
...,
[107., 109., 110., ..., 97., 94., 95.],
[102., 103., 103., ..., 95., 93., 92.],
[ 96., 93., 95., ..., 84., 86., 84.]],
[[189., 189., 189., ..., 180., 177., 174.],
[191., 190., 191., ..., 182., 180., 177.],
[188., 188., 189., ..., 180., 178., 176.],
...,
[107., 108., 110., ..., 94., 93., 95.],
[101., 102., 103., ..., 93., 91., 91.],
[ 92., 90., 94., ..., 80., 80., 77.]]], dtype=float32), array(0))
3. 创建模型
import paddle
# 确保从paddle.vision.datasets.Cifar10中加载的图像数据是np.ndarray类型
paddle.vision.set_image_backend('cv2')
# 调用resnet50模型
model = paddle.Model(resnet50(pretrained=False, num_classes=10))
4. 优化准备
# 定义优化器
optimizer = Momentum(learning_rate=0.01,
momentum=0.9,
weight_decay=L2Decay(1e-4),
parameters=model.parameters())
# 进行训练前准备
model.prepare(optimizer, CrossEntropyLoss(), Accuracy(topk=(1, 5)))
5. 训练
#开启GPU
use_gpu = True
paddle.set_device('gpu:0') if use_gpu else paddle.set_device('cpu')
# 启动训练
model.fit(train_dataset,
val_dataset,
epochs=30,
batch_size=64,
save_dir="./output",
num_workers=10)
总结
paddle yyds!
我在AI Studio上获得黄金等级,点亮8个徽章,来互关呀~