卷积操作具有不变形和局部性
卷积层是错误叫法,实际其运算是互相关运算(cross-correlation),而不是数学上的卷积运算。 卷积层对输入和卷积核权重进行互相关运算,并添加偏置后输出。卷积层中的两个被训练的参数是卷积核权重和标量偏置
互相关运算
当卷积核在图像上滑动时,它会计算当前像素与其周围像素之间的差异。如果在这个窗口内有一个边缘,那么像素值的变化会很大,从而导致卷积操作的结果产生较大的响应。这样,通过卷积操作,我们可以将边缘区域与其他区域区分开来。卷积操作的结果是生成一个新的图像,其中边缘信息被强调出来,而其他区域则被抑制。这使得边缘检测成为许多计算机视觉任务的基础,例如目标检测、图像分割和图像识别等。
一个简单的目标边缘检测
通过找到像素变化的位置,来检测图像中不同颜色的边缘
class Conv2D(nn.Module):
def __init__(self, kernel_size):
super().__init__()
self.weight = nn.Parameter(torch.rand(kernel_size))
self.bias = nn.Parameter(torch.zeros(1))
def forward(self, x):
**return** corr2d(x, self.weight) + self.bias
#构造一个6*8的黑白图像
x = torch.ones((6,8))
x[:,2:6]=0
#构造一个高为1,宽为2的卷积核
K = torch.tensor([[1.0, -1.0]])
Y = corr2d(X, K)
一个有意义的卷积操作的完整代码
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
# 加载图像并将其转换为灰度图
image_path = "494601.jpg"
image = Image.open(image_path).convert("L")
# 归一化图像
image_tensor = transforms.ToTensor()(image).unsqueeze(0)
image_tensor = (image_tensor - image_tensor.min()) / (image_tensor.max() - image_tensor.min())
# 输出输入图像的统计信息
print("Input Image Statistics:")
print("Min:", image_tensor.min().item())
print("Max:", image_tensor.max().item())
# 使用PyTorch内置的边缘检测卷积核
edge_detection_kernel = torch.tensor([[1, 0, -1],
[1, 0, -1],
[1, 0, -1]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)
# 输出卷积核的值
print("Convolution Kernel:")
print(edge_detection_kernel)
# 定义卷积层
conv_layer = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
# 将卷积层的权重设置为边缘检测卷积核
with torch.no_grad():
conv_layer.weight.copy_(edge_detection_kernel)
# 应用卷积操作
edge_map = conv_layer(image_tensor)
# 可视化结果
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.imshow(image, cmap='gray')
plt.title('原始图像')
plt.subplot(1, 3, 2)
plt.imshow(edge_map.squeeze().detach().numpy(), cmap='gray', vmin=edge_map.min().item(), vmax=edge_map.max().item())
plt.title('边缘检测结果')
plt.subplot(1, 3, 3)
plt.imshow(edge_map.squeeze().detach().numpy(), cmap='gray', vmin=-1, vmax=1) # 调整比例以获得更好的可视化效果
plt.title('边缘检测结果(调整后)')
plt.show()