[深度学习]基于 PyWavelet 的小波包分解完成时序数据降噪

284 阅读4分钟

一、前言

小波包(Wavelet Packet Transform)变换是基于小波变换(Wavelet Transform)的第二代技术。在深度学习中,我们关心它们对时序数据的降噪效果

小波变换挑出某个小波基函数,将原始信号递归地分解为高频(细节系数)和低频(近似系数)。其局限性在于,每次只分解上一层的低频信号

Pasted image 20230815104805.png level=5

而小波包变换则在每一次分解时,均对高频和低频信号都进行各自的分解,试图达到更精确的降噪效果。其数据结构类似二叉树:

image.png 小波变换降噪时序数据的 Python 代码很容易找到,不再赘述。本文重点记录小波包降噪处理的相关 Python 代码。

二、基于 PyWavelet 的代码实现

PyWavelet 中,小波包的类有 pywt.WaveletPacketpywt.WaveletPacket2Dpywt.WaveletPacketND三种,分别处理1-D、2-D、N-D数据。
本例中的时序数据属于一维数据,故选用pywt.WaveletPacket,有且只有两个成员函数

  • get_level() 获取指定分解级别的所有 Nodes
  • reconstruct() 将所有 Sub Nodes 重组为新数据。

因此小波包的处理流程也很简单:

  1. 用 wp.get_level() 获取原始数据的所有分量
  2. 对分量进行降噪处理(准确来说,是对除了'A...A'以外的所有含'D'的分量的处理)
  3. 使用new_wp.reconstruct() 重新组合所有处理过的分量。得到降噪数据

1.导包

采用 PyWavelet 包,包名为pywt

import pywt

2.选取小波包函数

wavelet = 'db4'
padding_mode = 'symmetric'
order = "natural"
# WaveletPacket 返回一个树结构,包括划分好的高低频滤波结果
wp = pywt.WaveletPacket(data, wavelet, mode=padding_mode)  # 选用 Daubechies8 小波
new_wp = pywt.WaveletPacket(training_set_scaled, wavelet, mode=padding_mode)  # 存放降噪后的信号

maxlevel = wp.maxlevel  # 具体选用的层数取决于信号的长度和小波的长度
maxlevel
8

核心函数是 pywt.WaveletPacket() ,用于创建一个小波包对象,其中

  • data 传入一个 array-like 对象
  • wavelet 指定小波族的某一个小波基函数,官网文档给出了Wavelets List,了解有哪些小波基函数可用。
  • mode 通常实际传入的data是一个有限长度的数组,其首尾两端是被 “截断” 的,这容易导致两端数据处理异常。一种解决方法是将数据周期性复制并拼接起来。mode 可以选择多个副本拼接的方式。(这和 CNN 的 padding 有相似之处)

3. 取出指定层次的结点

nodes = wp.get_level(3, order=order)

此处传入 3,得到第三次分解的23=82^3 = 8 个分量(type: list),PyWaveletpywt.Node 类型表示二叉树上的每一个分量。后面会介绍 Node 结点的常用属性。
根据数据的长度和所选取的特定 wavelet 基函数,小波包所能分解的级别是有上限 maxlevel 的(即二叉树最多只能有 maxlevel 层)。

labels = [n.path for n in nodes]
values = [n.data for n in nodes]
labels, values
(['aaa', 'aad', 'ada', 'add', 'daa', 'dad', 'dda', 'ddd'],
 [array([1.05758547, 1.06546091, 1.05783781, 1.06609727, 1.06085854,...]...)]

4.查看图表

# 创建一个2行4列的子图布局,并指定每个子图的尺寸比例
fig, axes = plt.subplots(2, 4, figsize=(32, 8))
# 可以通过索引访问每个子图
# 第一行的第一个子图
axes[0, 0].plot(values[0])
axes[0, 0].set_title(labels[0])

...

axes[1, 3].plot(values[7])
axes[1, 3].set_title(labels[7].upper())

# 调整子图之间的间距
plt.tight_layout()

# 显示图形
plt.show()

image.png

5.※ 降噪

# 设置阈值
threshold = 0.3

# 复制一份数据,并更新除了 'AAA' 以外的所有分量
denoise_values = values
denoise_values[1:] = pywt.threshold(values[1:], threshold, mode='soft') 

降噪并不是将原始数据分离为近似系数细节系数后,抛弃细节系数。
而是将细节系数用 threshold 降噪后,再与近似系数重组回去。

注意,相关中文 blog 大多以讹传讹,将 pywt.threshold()误写为 WaveletPacket对象的成员函数。

6.写入 node 数据

(new_wp['aaa'].data, new_wp['aad'].data, new_wp['ada'].data, new_wp['add'].data, 
 new_wp['daa'].data, new_wp['dad'].data,new_wp['dda'].data, new_wp['ddd'].data) = denoise_values

7.※ 重构、绘图

fig = plt.figure(figsize=(32, 5))  # 设置了 fig 变量,就会影响到本jupyter代码单元的 plt 的画布尺寸
plt.plot(training_set_scaled, "r", label="original data")

# ※ 重构
plt.plot(new_wp.reconstruct(update=True), "b", label="denoised data")

plt.legend()  # 按照 label 生成图例

参考引用