使用 pywt 执行小波变换实现图像去噪

748 阅读2分钟

开启掘金成长之旅!这是我参与「掘金日新计划 · 12 月更文挑战」的第29天,点击查看活动详情

前言

小波变换通常用于图像去噪,基于小波变换的图像去噪步骤如下:

  • 选择一个小波类型(例如,双正交小波或 N 级分解小波),利用小波对图像执行离散小波变换
  • 在图像分解后,确定每个级别的阈值 (Birgé-Massart 策略是选择阈值的常见方法),使用此过程,可以为 N 个级别设置单独的阈值
  • 最后一步是使用逆离散小波变换从修改后的级别重建图像

需要注意的是,选择使用不同小波、级别和阈值策略可能会导致不同类型的滤波。

基于小波变换实现图像去噪

在本节中,我们将在输入 RGB 图像中添加高斯噪声,并使用小波的软阈值消除噪声。

(1) 首先,导入所需库,读取输入 RGB 图像,并用 σ=0.25σ=0.25 添加高斯噪声,得到带有噪声的图像:

import numpy as np
import pywt
from skimage import img_as_float
import matplotlib.pylab as plt
from skimage.io import imread

image = img_as_float(imread('3.png'))
noise_sigma = 0.25 #16.0
image += np.random.normal(0, noise_sigma, size=image.shape)

(2) 我们可以使用 pywt 中的函数 Wavelet() 应用多级 2D DWT,该函数会应用小波变换并且级别 level=7

wavelet = pywt.Wavelet('haar')
levels  = int(np.floor(np.log2(image.shape[0])))
print(levels)
wavelet_coeffs = pywt.wavedec2(image, wavelet, level=levels)
# 7

(3) 定义函数 denoise() 执行图像去噪操作,该函数接受给定类型的小波对象和(估计的)噪声标准差作为参数。然后,函数计算出不同级别的 DWT 系数,并使用估计的噪声应用软阈值来计算新的系数;最后,使用新系数重建图像,并得到的返回图像:

def denoise(image, wavelet, noise_sigma):
    levels = int(np.floor(np.log2(image.shape[0])))
    wc = pywt.wavedec2(image, wavelet, level=levels)
    arr, coeff_slices = pywt.coeffs_to_array(wc)
    arr = pywt.threshold(arr, noise_sigma, mode='soft')
    nwc = pywt.array_to_coeffs(arr, coeff_slices, output_format='wavedec2')
    return pywt.waverec2(nwc, wavelet)

(4) 使用不同类型的离散小波,并将它们应用于带有噪声的图像,并使用 denoise() 函数从有噪输入 RGB 图像的每个颜色通道中去除噪声:

print(pywt.wavelist(kind='discrete'))
wlts = ['bior1.5', 'coif5', 'db6', 'dmey', 'haar', 'rbio2.8', 'sym15'] # pywt.wavelist(kind='discrete')
Denoised={}
for wlt in wlts:
    out = image.copy()
    for i in range(3):
        out[...,i] = denoise(image[...,i], wavelet=wlt, noise_sigma=3/2*noise_sigma)
    Denoised[wlt] = np.clip(out, 0, 1)
print(len(Denoised))

(5) 最后,绘制用不同小波类型去噪的所有输出图像,并使用 PSNR 比较图像质量:

plt.figure(figsize=(15,8))
plt.subplots_adjust(0,0,1,0.9,0.05,0.07)
plt.subplot(241), plt.imshow(np.clip(image,0,1)), plt.axis('off'), plt.title('original image', size=8)
i = 2
for wlt in Denoised:
    plt.subplot(2,4,i), plt.imshow(Denoised[wlt]), plt.axis('off'), plt.title(wlt, size=8)
    i += 1
plt.suptitle('Image Denoising with Wavelets', size=12)
plt.show()

Figure_11.png