NumPy 秘籍中文第二版(二)
四、将 NumPy 与世界的其他地方连接
在本章中,我们将介绍以下秘籍:
- 使用缓冲区协议
- 使用数组接口
- 与 MATLAB 和 Octave 交换数据
- 安装 RPy2
- 与 R 交互
- 安装 JPype
- 将 NumPy 数组发送到 JPype
- 安装 Google App Engine
- 在 Google Cloud 上部署 NumPy 代码
- 在 PythonAnywhere Web 控制台中运行 NumPy 代码
简介
本章是关于互操作性的。 我们必须不断提醒自己,NumPy 在科学(Python)软件生态系统中并不孤单。 与 SciPy 和 matplotlib 一起工作非常容易。 还存在用于与其他 Python 包互操作性的协议。 在 Python 生态系统之外,Java,R,C 和 Fortran 等语言非常流行。 我们将详细介绍与这些环境交换数据的细节。
此外,我们还将讨论如何在云上获取 NumPy 代码。 这是在快速移动的空间中不断发展的技术。 您可以使用许多选项,其中包括 Google App Engine 和 PythonAnywhere。
使用缓冲区协议
基于 C 的 Python 对象具有所谓的缓冲区接口。 Python 对象可以公开其数据以进行直接访问,而无需复制它们。 缓冲区协议使我们能够与其他 Python 软件进行通信,例如 Python 图像库(PIL)。
我们将看到一个从 NumPy 数组保存 PIL 图像的示例。
准备
如有必要,请安装 PIL 和 SciPy。 有关说明,查阅本秘籍的“另见”部分。
操作步骤
该秘籍的完整代码在本书代码包的buffer.py文件中:
import numpy as np
import Image #from PIL import Image (Python3)
import scipy.misc
lena = scipy.misc.lena()
data = np.zeros((lena.shape[0], lena.shape[1], 4), dtype=np.int8)
data[:,:,3] = lena.copy()
img = Image.frombuffer("RGBA", lena.shape, data, 'raw', "RGBA", 0, 1)
img.save('lena_frombuffer.png')
data[:,:,3] = 255
data[:,:,0] = 222
img.save('lena_modified.png')
首先,我们需要一个 NumPy 数组来玩:
-
在前面的章节中,我们看到了如何加载 Lena 的样例图像。 创建一个填充零的数组,并使用图像数据填充 alpha 通道:
lena = scipy.misc.lena() data = np.zeros((lena.shape[0], lena.shape[1], 4), dtype=numpy.int8) data[:,:,3] = lena.copy() -
使用 PIL API 将数据另存为 RGBA 图像:
img = Image.frombuffer("RGBA", lena.shape, data, 'raw', "RGBA", 0, 1) img.save('lena_frombuffer.png') -
通过去除图像数据并使图像变为红色来修改数据数组。 使用 PIL API 保存图像:
data[:,:,3] = 255 data[:,:,0] = 222 img.save('lena_modified.png')以下是之前的图片:
注意
在计算机图形中,原点的位置与您从高中数学中知道的通常的直角坐标系不同。 原点位于屏幕,画布或图像的左上角,y 轴向下。
PIL 图像对象的数据由于缓冲接口的作用而发生了变化,因此,我们看到以下图像:
工作原理
我们从缓冲区(一个 NumPy 数组)创建了一个 PIL 图像。 更改缓冲区后,我们看到更改反映在图像对象中。 我们这样做时没有复制 PIL 图像对象; 相反,我们直接访问并修改了其数据,以使模型的图片显示红色图像。 通过一些简单的更改,代码就可以与其他基于 PIL 的库一起使用,例如 Pillow。
另见
- 第 2 章,“高级索引和数组概念”中的“安装 PIL”
- 第 2 章,“高级索引和数组概念”中的“安装 SciPy”
- 这个页面中介绍了 Python 缓冲区协议。
使用数组接口
数组接口是用于与其他 Python 应用通信的另一种机制。 顾名思义,该协议仅适用于类似数组的对象。 进行了示范。 让我们再次使用 PIL,但不保存文件。
准备
我们将重用先前秘籍中的部分代码,因此前提条件是相似的。 在这里,我们将跳过上一秘籍的第一步,并假定它已经为人所知。
操作步骤
该秘籍的代码在本书代码包的array_interface.py文件中:
from __future__ import print_function
import numpy as np
import Image
import scipy.misc
lena = scipy.misc.lena()
data = np.zeros((lena.shape[0], lena.shape[1], 4), dtype=np.int8)
data[:,:,3] = lena.copy()
img = Image.frombuffer("RGBA", lena.shape, data, 'raw', "RGBA", 0, 1)
array_interface = img.__array_interface__
print("Keys", array_interface.keys())
print("Shape", array_interface['shape'])
print("Typestr", array_interface['typestr'])
numpy_array = np.asarray(img)
print("Shape", numpy_array.shape)
print("Data type", numpy_array.dtype)
以下步骤将使我们能够探索数组接口:
-
PIL
Image对象具有__array_interface__属性。 让我们检查它的内容。 此属性的值是 Python 字典:array_interface = img.__array_interface__ print("Keys", array_interface.keys()) print("Shape", array_interface['shape']) print("Typestr", array_interface['typestr'])此代码显示以下信息:
Keys ['shape', 'data', 'typestr'] Shape (512, 512, 4) Typestr |u1 -
NumPy
ndarray类也具有__array_interface__属性。 我们可以使用asarray()函数将 PIL 图像转换成 NumPy 数组:numpy_array = np.asarray(img) print("Shape", numpy_array.shape) print("Data type", numpy_array.dtype)数组的形状和数据类型如下:
Shape (512, 512, 4) Data type uint8
如您所见,形状没有改变。
工作原理
数组接口或协议使我们可以在类似数组的 Python 对象之间共享数据。 NumPy 和 PIL 都提供了这样的接口。
另见
- 本章中的“使用缓冲区协议”
- 数组接口在这个页面中进行了详细描述。
与 MATLAB 和 Octave 交换数据
MATLAB 及其开放源代码 Octave 是流行的数学应用。 scipy.io包具有savemat()函数,该函数允许您将 NumPy 数组存储为.mat文件作为 Python 字典的值。
准备
安装 MATLAB 或 Octave 超出了本书的范围。 Octave 网站上有一些安装的指南。 如有必要,检查本秘籍的“另见”部分,来获取安装 SciPy 的说明。
操作步骤
该秘籍的完整代码在本书代码包的octave.py文件中:
import numpy as np
import scipy.io
a = np.arange(7)
scipy.io.savemat("a.mat", {"array": a})
一旦安装了 MATLAB 或 Octave,就需要按照以下步骤存储 NumPy 数组:
-
创建一个 NumPy 数组,然后调用
savemat()将其存储在.mat文件中。 此函数有两个参数-文件名和包含变量名和值的字典。a = np.arange(7) scipy.io.savemat("a.mat", {"array": a}) -
导航到创建文件的目录。 加载文件并检查数组:
octave-3.4.0:2> load a.mat octave-3.4.0:3> array array = 0 1 2 3 4 5 6
另见
- 第 2 章,“高级索引和数组概念”中的“安装 SciPy”
savemat()函数的 SciPy 文档
安装 RPy2
R 是一种流行的脚本语言,用于统计和数据分析。 RPy2 是 R 和 Python 之间的接口。 我们将在此秘籍中安装 RPy2。
操作步骤
如果要安装 RPy2,请选择以下选项之一:
-
使用
pip或easy_install进行安装:RPy2 在 PYPI 上可用,因此我们可以使用以下命令进行安装:$ easy_install rpy2另外,我们可以使用以下命令:
$ sudo pip install rpy2 $ pip freeze|grep rpy2 rpy2==2.4.2 -
从源代码安装:我们可以从
tar.gz源安装 RPy2:$ tar -xzf <rpy2_package>.tar.gz $ cd <rpy2_package> $ python setup.py build install
另见
与 R 交互
RPy2 只能用作从 Python 调用 R,而不能相反。 我们将导入一些样本 R 数据集并绘制其中之一的数据。
准备
如有必要,请安装 RPy2。 请参阅先前的秘籍。
操作步骤
该秘籍的完整代码在本书代码包的rdatasets.py文件中:
from rpy2.robjects.packages import importr
import numpy as np
import matplotlib.pyplot as plt
datasets = importr('datasets')
mtcars = datasets.__rdata__.fetch('mtcars')['mtcars']
plt.title('R mtcars dataset')
plt.xlabel('wt')
plt.ylabel('mpg')
plt.plot(mtcars)
plt.grid(True)
plt.show()
motorcars数据集在这个页面中进行了描述。 让我们从加载此样本 R 数据集开始:
-
使用 RPy2
importr()函数将数据集加载到数组中。 此函数可以导入R包。 在此示例中,我们将导入数据集 R 包。 从mtcars数据集创建一个 NumPy 数组:datasets = importr('datasets') mtcars = np.array(datasets.mtcars) -
使用 matplotlib 绘制数据集:
plt.plot(mtcars) plt.show()数据包含英里每加仑(
mpg)和重量(wt)值,单位为千分之一磅。 以下屏幕快照显示了数据,它是一个二维数组:
另见
- 第 1 章“使用 IPython”中的“安装 matplotlib”
安装 JPype
Jython 是用于 Python 和 Java 的默认互操作性解决方案。 但是,Jython 在 Java 虚拟机(JVM)上运行。 因此,它无法访问主要用 C 语言编写的 NumPy 模块。 JPype 是一个开放源代码项目,试图解决此问题。 接口发生在 Python 和 JVM 之间的本机级别上。 让我们安装 JPype。
操作步骤
-
从这里下载 JPype。
-
打开压缩包,然后运行以下命令:
$ python setup.py install
将 NumPy 数组发送到 JPype
在此秘籍中,我们将启动 JVM 并向其发送 NumPy 数组。 我们将使用标准 Java 调用打印接收到的数组。 显然,您将需要安装 Java。
操作步骤
该秘籍的完整代码在本书代码包的hellojpype.py文件中:
import jpype
import numpy as np
#1\. Start the JVM
jpype.startJVM(jpype.getDefaultJVMPath())
#2\. Print hello world
jpype.java.lang.System.out.println("hello world")
#3\. Send a NumPy array
values = np.arange(7)
java_array = jpype.JArray(jpype.JDouble, 1)(values.tolist())
for item in java_array:
jpype.java.lang.System.out.println(item)
#4\. Shutdown the JVM
jpype.shutdownJVM()
首先,我们需要从 JPype 启动 JVM:
-
从 JPype 启动 JVM; JPype 可以方便地找到默认的 JVM 路径:
jpype.startJVM(jpype.getDefaultJVMPath()) -
仅出于传统原因,让我们打印
"hello world":jpype.java.lang.System.out.println("hello world") -
创建一个 NumPy 数组,将其转换为 Python 列表,然后将其传递给 JPype。 现在很容易打印数组元素:
values = np.arange(7) java_array = jpype.JArray(jpype.JDouble, 1)(values.tolist()) for item in java_array: jpype.java.lang.System.out.println(item) -
完成后,让我们关闭 JVM:
jpype.shutdownJVM()JPype 中一次只能运行一个 JVM。 如果我们忘记关闭 JVM,则可能导致意外错误。 程序输出如下:
hello world 0.0 1.0 2.0 3.0 4.0 5.0 6.0 JVM activity report : classes loaded : 31 JVM has been shutdown
工作原理
JPype 允许我们启动和关闭 JVM。 它为标准 Java API 调用提供了包装器。 如本例所示,我们可以传递要由 JArray 包装器转换为 Java 数组的 Python 列表。 JPype 使用 Java 本机接口(JNI),这是本机 C 代码和 Java 之间的桥梁。 不幸的是,使用 JNI 会损害性能,因此您必须注意这一事实。
另见
- 本章中的“安装 JPype”
- JPype 主页
安装 Google App Engine
Google App Engine(GAE)使您可以在 Google Cloud 上构建 Web 应用。 自 2012 年以来, 是 NumPy 的官方支持; 您需要一个 Google 帐户才能使用 GAE。
操作步骤
第一步是下载 GAE:
您也可以从此页面下载文档和 GAE Eclipse 插件。 如果使用 Eclipse 开发,则一定要安装它。
-
开发环境。
GAE 带有一个模拟生产云的开发环境。 在撰写本书时,GAE 正式仅支持 Python 2.5 和 2.7。 GAE 将尝试在您的系统上找到 Python; 但是,例如,如果您有多个 Python 版本,则可能需要自行设置。 您可以在启动器应用的首选项对话框中设置此设置。
SDK 中有两个重要的脚本:
dev_appserver.py:开发服务器appcfg.py:部署在云上
在 Windows 和 Mac 上,有一个 GAE 启动器应用。 启动器具有运行和部署按钮,它们执行与上述脚本相同的操作。
在 Google Cloud 上部署 NumPy 代码
部署 GAE 应用非常容易。 对于 NumPy,需要额外的配置步骤,但这仅需几分钟。
操作步骤
让我们创建一个新的应用:
-
使用启动器创建一个新应用(文件 | 新应用)。 命名为
numpycloud。 这将创建一个包含以下文件的同名文件夹:app.yaml:YAML 应用配置文件favicon.ico:一个图标index.yaml:自动生成的文件main.py:Web 应用的主要入口点
-
将 NumPy 添加到库中。
首先,我们需要让 GAE 知道我们要使用 NumPy。 将以下行添加到库部分中的
app.yaml配置文件中:- name: NumPy version: "1.6.1"这不是最新的 NumPy 版本,但它是 GAE 当前支持的最新版本。 配置文件应具有以下内容:
application: numpycloud version: 1 runtime: python27 api_version: 1 threadsafe: yes handlers: - url: /favicon\.ico static_files: favicon.ico upload: favicon\.ico - url: .* script: main.app libraries: - name: webapp2 version: "2.5.1" - name: numpy version: "1.6.1" -
为了演示我们可以使用 NumPy 代码,让我们修改
main.py文件。 有一个MainHandler类,带有用于 GET 请求的处理器方法。 用以下代码替换此方法:def get(self): self.response.out.write('Hello world!<br/>') self.response.out.write('NumPy sum = ' + str(numpy.arange(7).sum()))最后,我们将提供以下代码:
import webapp2 import numpy class MainHandler(webapp2.RequestHandler): def get(self): self.response.out.write('Hello world!<br/>') self.response.out.write('NumPy sum = ' + str(numpy.arange(7).sum())) app = webapp2.WSGIApplication([('/', MainHandler)], debug=True)
如果您单击在 GAE 启动器中浏览按钮(在 Linux 上,以项目根为参数运行dev_appserver.py),则您应该在默认浏览器中看到一个包含以下文字的网页:
Hello world!NumPy sum = 21
工作原理
GAE 是免费的,具体取决于使用了多少资源。 您最多可以创建 10 个 Web 应用。 GAE 采用沙盒方法,这意味着 NumPy 暂时无法使用,但现在可以使用,如本秘籍所示。
在 PythonAnywhere Web 控制台中运行 NumPy 代码
在第 1 章,“使用 IPython”中,我们已经看到了运行 PythonAnywhere 控制台的过程,而没有任何权限。 此秘籍将需要您有一个帐户,但不要担心-它是免费的,如果您不需要太多资源,至少是免费的。
注册是一个非常简单的过程,此处将不涉及。 NumPy 已经与其他 Python 软件一起安装。 有关完整列表,请参见这里。
我们将建立一个简单的脚本,该脚本每分钟从 Google 财经获取价格数据,并使用 NumPy 对价格进行简单的统计。
操作步骤
当我们签名后,我们可以登录并查看 PythonAnywhere 信息中心。
-
编写代码。 此示例的完整代码如下:
from __future__ import print_function import urllib2 import re import time import numpy as np prices = np.array([]) for i in xrange(3): req = urllib2.Request('http://finance.google.com/finance/info?client=ig&q=AAPL') req.add_header('User-agent', 'Mozilla/5.0') response = urllib2.urlopen(req) page = response.read() m = re.search('l_cur" : "(.*)"', page) prices = np.append(prices, float(m.group(1))) avg = prices.mean() stddev = prices.std() devFactor = 1 bottom = avg - devFactor * stddev top = avg + devFactor * stddev timestr = time.strftime("%H:%M:%S", time.gmtime()) print(timestr, "Average", avg, "-Std", bottom, "+Std", top) time.sleep(60)除我们在其中生长包含价格的 NumPy 数组并计算价格的均值和标准差的位以外,大多数都是标准 Python。 如果有股票代号,例如
AAPL,则可以使用 URL 从 Google 财经下载 JSON 格式的价格数据。 该 URL 当然可以更改。接下来,我们使用正则表达式解析 JSON 以提取价格。 此价格已添加到 NumPy 数组中。 我们计算价格的均值和标准差。 价格是根据标准差乘以我们指定的某个因素后在时间戳的顶部和底部打印出来的。
-
上传代码。
在本地计算机上完成代码后,我们可以将脚本上传到 PythonAnywhere。 转到仪表板,然后单击文件选项卡。 从页面底部的小部件上传脚本。
-
要运行代码,请单击控制台选项卡,然后单击 Bash 链接。 PythonAnywhere 应该立即为我们创建一个 bash 控制台。 现在,我们可以在一个标准差范围内运行
AAPL程序,如以下屏幕截图所示:
工作原理
如果您想在远程服务器上运行 NumPy 代码,则 PythonAnywhere 是完美的选择,尤其是当您需要程序在计划的时间执行时。 至少对于免费帐户而言,进行交互式工作并不那么方便,因为每当您在 Web 控制台中输入文本时都会有一定的滞后。
但是,正如我们所看到的,可以在本地创建和测试程序,并将其上传到 PythonAnywhere。 这也会释放本地计算机上的资源。 我们可以做一些花哨的事情,例如根据股价发送电子邮件或安排在交易时间内激活脚本 。 通过 ,使用 Google App Engine 也可以做到这一点,但是它是通过 Google 方式完成的,因此您需要了解其 API。
五、音频和图像处理
在本章中,我们将介绍 NumPy 和 SciPy 的基本图像和音频(WAV 文件)处理。 在以下秘籍中,我们将使用 NumPy 对声音和图像进行有趣的操作:
- 将图像加载到内存映射中
- 添加图像
- 图像模糊
- 重复音频片段
- 产生声音
- 设计音频过滤器
- 使用 Sobel 过滤器进行边界检测
简介
尽管本书中的所有章节都很有趣,但在本章中,我们确实会继续努力并专注于获得乐趣。 在第 10 章,“Scikits 的乐趣”中,您会发现更多使用scikits-image的图像处理秘籍。 不幸的是,本书没有对音频文件的直接支持,因此您确实需要运行代码示例以充分了解其中的秘籍。
将图像加载到内存映射中
建议将大文件加载到内存映射中。 内存映射文件仅加载大文件的一小部分。 NumPy 内存映射类似于数组。 在此示例中,我们将生成彩色正方形的图像并将其加载到内存映射中。
准备
如有必要,“安装 matplotlib”的“另请参见”部分具有对相应秘籍的引用。
操作步骤
我们将通过初始化数组来开始 :
-
首先,我们需要初始化以下数组:
- 保存图像数据的数组
- 具有正方形中心随机坐标的数组
- 具有平方的随机半径(复数个半径)的数组
- 具有正方形随机颜色的数组
初始化数组:
img = np.zeros((N, N), np.uint8) NSQUARES = 30 centers = np.random.random_integers(0, N, size=(NSQUARES, 2)) radii = np.random.randint(0, N/9, size=NSQUARES) colors = np.random.randint(100, 255, size=NSQUARES)如您所见,我们正在将第一个数组初始化为零。 其他数组使用
numpy.random包中的函数初始化,这些函数生成随机整数。 -
下一步是生成正方形。 我们在上一步中使用数组创建正方形。 使用
clip()函数,我们将确保正方形不会在图像区域外徘徊。meshgrid()函数为我们提供了正方形的坐标。 如果我们给此函数两个大小分别为N和M的数组,它将给我们两个形状为N x M的数组。第一个数组的元素将沿 x 轴重复。 第二个数组将沿 y 轴重复其元素。 以下示例 IPython 会话应该使这一点更加清楚:注意
In: x = linspace(1, 3, 3) In: x Out: array([ 1., 2., 3.]) In: y = linspace(1, 2, 2) In: y Out: array([ 1., 2.]) In: meshgrid(x, y) Out: [array([[ 1., 2., 3.], [ 1., 2., 3.]]), array([[ 1., 1., 1.], [ 2., 2., 2.]])] -
最后,我们将设置正方形的颜色:
for i in xrange(NSQUARES): xindices = range(centers[i][0] - radii[i], centers[i][0] + radii[i]) xindices = np.clip(xindices, 0, N - 1) yindices = range(centers[i][1] - radii[i], centers[i][1] + radii[i]) yindices = np.clip(yindices, 0, N - 1) if len(xindices) == 0 or len(yindices) == 0: continue coordinates = np.meshgrid(xindices, yindices) img[coordinates] = colors[i] -
在将图像数据加载到内存映射之前,我们需要使用
tofile()函数将其存储在文件中。 然后使用memmap()函数将图像文件中的图像数据加载到内存映射中:img.tofile('random_squares.raw') img_memmap = np.memmap('random_squares.raw', shape=img.shape) -
为了确认一切正常,我们使用 matplotlib 显示图像:
plt.imshow(img_memmap) plt.axis('off') plt.show()注意,我们没有显示轴。 生成图像的示例如下所示:
这是本书代码包中
memmap.py文件的完整源代码:import numpy as np import matplotlib.pyplot as plt N = 512 NSQUARES = 30 # Initialize img = np.zeros((N, N), np.uint8) centers = np.random.random_integers(0, N, size=(NSQUARES, 2)) radii = np.random.randint(0, N/9, size=NSQUARES) colors = np.random.randint(100, 255, size=NSQUARES) # Generate squares for i in xrange(NSQUARES): xindices = range(centers[i][0] - radii[i], centers[i][0] + radii[i]) xindices = np.clip(xindices, 0, N - 1) yindices = range(centers[i][1] - radii[i], centers[i][1] + radii[i]) yindices = np.clip(yindices, 0, N - 1) if len(xindices) == 0 or len(yindices) == 0: continue coordinates = np.meshgrid(xindices, yindices) img[coordinates] = colors[i] # Load into memory map img.tofile('random_squares.raw') img_memmap = np.memmap('random_squares.raw', shape=img.shape) # Display image plt.imshow(img_memmap) plt.axis('off') plt.show()
工作原理
我们在此秘籍中使用了以下函数:
| 函数 | 描述 |
|---|---|
zeros() | 此函数给出一个由零填充的数组。 |
random_integers() | 此函数返回一个数组,数组中的随机整数值在上限和下限之间。 |
randint() | 该函数与random_integers()函数相同,除了它使用半开间隔而不是关闭间隔。 |
clip() | 该函数在给定最小值和最大值的情况下裁剪数组的值。 |
meshgrid() | 此函数从包含 x 坐标的数组和包含 y 坐标的数组返回坐标数组。 |
tofile() | 此函数将数组写入文件。 |
memmap() | 给定文件名,此函数从文件创建 NumPy 内存映射。 (可选)您可以指定数组的形状。 |
axis() | 该函数是用于配置绘图轴的 matplotlib 函数。 例如,我们可以将其关闭。 |
另见
- 第 1 章“使用 IPython”中的“安装 matplotlib”
- NumPy 内存映射文档
合成图像
在此秘籍中,我们将结合著名的 Mandelbrot 分形和 Lena 图像。 Mandelbrot 集是数学家 Benoit Mandelbrot 发明的。 这些类型的分形由递归公式定义,您可以在其中通过将当前拥有的复数乘以自身并为其添加一个常数来计算序列中的下一个复数。 本秘籍将涵盖更多详细信息。
准备
如有必要,请安装 SciPy。“另请参见”部分具有相关秘籍的参考。
操作步骤
首先初始化数组,然后生成和绘制分形,最后将分形与 Lena 图像组合:
-
使用
meshgrid(),zeros()和linspace()函数初始化对应于图像区域中像素的x,y和z数组:x, y = np.meshgrid(np.linspace(x_min, x_max, SIZE), np.linspace(y_min, y_max, SIZE)) c = x + 1j * y z = c.copy() fractal = np.zeros(z.shape, dtype=np.uint8) + MAX_COLOR -
如果
z是复数,则对于 Mandelbrot 分形具有以下关系:在此,
c是常数复数。 这可以在复平面上绘制,水平轴显示实数值,垂直轴显示虚数值。 我们将使用所谓的逃逸时间算法绘制分形。 该算法以大约 2 个单位的距离扫描原点周围小区域中的点。 这些点中的每一个都用作c值,并根据逃避区域所需的迭代次数为其指定颜色。 如果所需的迭代次数超过了预定义的次数,则像素将获得默认背景色。 有关更多信息,请参见此秘籍中已提及的维基百科文章:for n in range(ITERATIONS): print(n) mask = numpy.abs(z) <= 4 z[mask] = z[mask] ** 2 + c[mask] fractal[(fractal == MAX_COLOR) & (-mask)] = (MAX_COLOR - 1) * n / ITERATIONS Plot the fractal with matplotlib: plt.subplot(211) plt.imshow(fractal) plt.title('Mandelbrot') plt.axis('off') Use the choose() function to pick a value from the fractal or Lena array: plt.subplot(212) plt.imshow(numpy.choose(fractal < lena, [fractal, lena])) plt.axis('off') plt.title('Mandelbrot + Lena')结果图像如下所示:
以下是本书代码集中
mandelbrot.py文件中该秘籍的完整代码:import numpy as np import matplotlib.pyplot as plt from scipy.misc import lena ITERATIONS = 10 lena = lena() SIZE = lena.shape[0] MAX_COLOR = 255. x_min, x_max = -2.5, 1 y_min, y_max = -1, 1 # Initialize arrays x, y = np.meshgrid(np.linspace(x_min, x_max, SIZE), np.linspace(y_min, y_max, SIZE)) c = x + 1j * y z = c.copy() fractal = np.zeros(z.shape, dtype=np.uint8) + MAX_COLOR # Generate fractal for n in range(ITERATIONS): mask = np.abs(z) <= 4 z[mask] = z[mask] ** 2 + c[mask] fractal[(fractal == MAX_COLOR) & (-mask)] = (MAX_COLOR - 1) * n / ITERATIONS # Display the fractal plt.subplot(211) plt.imshow(fractal) plt.title('Mandelbrot') plt.axis('off') # Combine with lena plt.subplot(212) plt.imshow(np.choose(fractal < lena, [fractal, lena])) plt.axis('off') plt.title('Mandelbrot + Lena') plt.show()
工作原理
在此示例中使用了以下函数:
| 函数 | 描述 |
|---|---|
linspace() | 此函数返回范围内具有指定间隔的数字 |
choose() | 此函数通过根据条件从数组中选择值来创建数组 |
meshgrid() | 此函数从包含 x 坐标的数组和包含 y 坐标的数组返回坐标数组 |
另见
- 第 1 章,“使用 IPython”中的“安装 matplotlib”秘籍
- 第 2 章,“高级索引和数组”中的“安装 SciPy”秘籍
图像模糊
我们可以使用高斯过滤器来模糊图像。 该过滤器基于正态分布。 对应的 SciPy 函数需要标准差作为参数。 在此秘籍中,我们还将绘制极地玫瑰和螺旋形。 这些数字没有直接关系,但是在这里将它们组合起来似乎更有趣。
操作步骤
我们从初始化极坐标图开始,之后我们将模糊 Lena 图像并使用极坐标进行绘图:
-
初始化极坐标图:
NFIGURES = 5 k = np.random.random_integers(1, 5, NFIGURES) a = np.random.random_integers(1, 5, NFIGURES) colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] -
要模糊 Lena,请应用标准差为
4的高斯过滤器:plt.subplot(212) blurred = scipy.ndimage.gaussian_filter(lena, sigma=4) plt.imshow(blurred) plt.axis('off') -
matplotlib 有一个
polar()函数,它以极坐标进行绘制:theta = np.linspace(0, k[0] * np.pi, 200) plt.polar(theta, np.sqrt(theta), choice(colors)) for i in xrange(1, NFIGURES): theta = np.linspace(0, k[i] * np.pi, 200) plt.polar(theta, a[i] * np.cos(k[i] * theta), choice(colors))这是这本书的代码集中
spiral.py文件中该秘籍的完整代码:import numpy as np import matplotlib.pyplot as plt from random import choice import scipy import scipy.ndimage # Initialization NFIGURES = 5 k = np.random.random_integers(1, 5, NFIGURES) a = np.random.random_integers(1, 5, NFIGURES) colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] lena = scipy.misc.lena() plt.subplot(211) plt.imshow(lena) plt.axis('off') # Blur Lena plt.subplot(212) blurred = scipy.ndimage.gaussian_filter(lena, sigma=4) plt.imshow(blurred) plt.axis('off') # Plot in polar coordinates theta = np.linspace(0, k[0] * np.pi, 200) plt.polar(theta, np.sqrt(theta), choice(colors)) for i in xrange(1, NFIGURES): theta = np.linspace(0, k[i] * np.pi, 200) plt.polar(theta, a[i] * np.cos(k[i] * theta), choice(colors)) plt.axis('off') plt.show()
工作原理
在本教程中,我们使用了以下函数:
| 函数 | 描述 |
|---|---|
gaussian_filter() | 此函数应用高斯过滤器 |
random_integers() | 此函数返回一个数组,数组中的随机整数值在上限和下限之间 |
polar() | 该函数使用极坐标绘制图形 |
另见
- 可以在这个页面中找到
scipy.ndimage文档。
重复音频片段
正如我们在第 2 章,“高级索引和数组概念”中所看到的那样,我们可以使用 WAV 文件来完成整洁的事情。 只需使用urllib2标准 Python 模块下载文件并将其加载到 SciPy 中即可。 让我们下载一个 WAV 文件并重复 3 次。 我们将跳过在第 2 章,“高级索引和数组概念”中已经看到的一些步骤。
操作步骤
-
尽管NumPy具有
repeat()函数,但在这种情况下,更适合使用tile()函数。 函数repeat()的作用是通过重复单个元素而不重复其内容来扩大数组。 以下 IPython 会话应阐明这些函数之间的区别:In: x = array([1, 2]) In: x Out: array([1, 2]) In: repeat(x, 3) Out: array([1, 1, 1, 2, 2, 2]) In: tile(x, 3) Out: array([1, 2, 1, 2, 1, 2])现在,有了这些知识,就可以应用
tile()函数:repeated = np.tile(data, 3) -
使用 matplotlib 绘制音频数据:
plt.title("Repeated") plt.plot(repeated)原始声音数据和重复数据图如下所示:
这是本书代码包中
repeat_audio.py文件中该秘籍的完整代码:import scipy.io.wavfile import matplotlib.pyplot as plt import urllib2 import numpy as np response = urllib2.urlopen('http://www.thesoundarchive.com/austinpowers/smashingbaby.wav') print(response.info()) WAV_FILE = 'smashingbaby.wav' filehandle = open(WAV_FILE, 'w') filehandle.write(response.read()) filehandle.close() sample_rate, data = scipy.io.wavfile.read(WAV_FILE) print("Data type", data.dtype, "Shape", data.shape) plt.subplot(2, 1, 1) plt.title("Original") plt.plot(data) plt.subplot(2, 1, 2) # Repeat the audio fragment repeated = np.tile(data, 3) # Plot the audio data plt.title("Repeated") plt.plot(repeated) scipy.io.wavfile.write("repeated_yababy.wav", sample_rate, repeated) plt.show()
工作原理
以下是此秘籍中最重要的函数:
| 函数 | 描述 |
|---|---|
scipy.io.wavfile.read() | 将 WAV 文件读入数组 |
numpy.tile() | 重复数组指定次数 |
scipy.io.wavfile.write() | 从 NumPy 数组中以指定的采样率创建 WAV 文件 |
另见
- 可以在这个页面中找到
scipy.io文档。
产生声音
声音可以用具有一定幅度,频率和相位的正弦波在数学上表示如下。 我们可以从这个页面中指定的列表中随机选择符合以下公式的频率:
此处,n是钢琴键的编号。 我们将键的编号从 1 到 88。我们将随机选择振幅,持续时间和相位。
操作步骤
首先初始化随机值,然后生成正弦波,编写旋律,最后使用 matplotlib 绘制生成的音频数据:
-
初始化随机值为:
-
200-2000之间的幅度 -
0.01-0.2的持续时间 -
使用已经提到的公式的频率
-
0和2 pi之间的相位值:NTONES = 89 amps = 2000\. * np.random.random((NTONES,)) + 200. durations = 0.19 * np.random.random((NTONES,)) + 0.01 keys = np.random.random_integers(1, 88, NTONES) freqs = 440.0 * 2 ** ((keys - 49.)/12.) phi = 2 * np.pi * np.random.random((NTONES,))
-
-
编写
generate()函数以生成正弦波:def generate(freq, amp, duration, phi): t = np.linspace(0, duration, duration * RATE) data = np.sin(2 * np.pi * freq * t + phi) * amp return data.astype(DTYPE) -
一旦我们产生了一些音调,我们只需要组成一个连贯的旋律。 现在,我们将连接正弦波。 这不会产生良好的旋律,但可以作为更多实验的起点:
for i in xrange(NTONES): newtone = generate(freqs[i], amp=amps[i], duration=durations[i], phi=phi[i]) tone = np.concatenate((tone, newtone)) -
使用 matplotlib 绘制生成的音频数据:
plt.plot(np.linspace(0, len(tone)/RATE, len(tone)), tone) plt.show()生成的音频数据图如下:
可以在此处找到本示例的源代码 ,该代码来自本书代码包中的
tone_generation.py文件:import scipy.io.wavfile import numpy as np import matplotlib.pyplot as plt RATE = 44100 DTYPE = np.int16 # Generate sine wave def generate(freq, amp, duration, phi): t = np.linspace(0, duration, duration * RATE) data = np.sin(2 * np.pi * freq * t + phi) * amp return data.astype(DTYPE) # Initialization NTONES = 89 amps = 2000\. * np.random.random((NTONES,)) + 200. durations = 0.19 * np.random.random((NTONES,)) + 0.01 keys = np.random.random_integers(1, 88, NTONES) freqs = 440.0 * 2 ** ((keys - 49.)/12.) phi = 2 * np.pi * np.random.random((NTONES,)) tone = np.array([], dtype=DTYPE) # Compose for i in xrange(NTONES): newtone = generate(freqs[i], amp=amps[i], duration=durations[i], phi=phi[i]) tone = np.concatenate((tone, newtone)) scipy.io.wavfile.write('generated_tone.wav', RATE, tone) # Plot audio data plt.plot(np.linspace(0, len(tone)/RATE, len(tone)), tone) plt.show()
工作原理
我们创建了带有随机生成声音的 WAV 文件 。 concatenate()函数用于连接正弦波。
另见
设计音频过滤器
我记得在模拟电子课上学习了所有类型的过滤器。 然后,我们实际上构造了这些过滤器。 可以想象,用软件制作过滤器要比用硬件制作过滤器容易得多。
我们将构建一个过滤器,并将其应用于要下载的音频片段。 在本章之前,我们已经完成了一些步骤,因此我们将省略那些部分。
操作步骤
顾名思义,iirdesign()函数允许我们构造几种类型的模拟和数字过滤器。 可以在scipy.signal模块中找到。 该模块包含信号处理函数的完整列表:
-
使用
scipy.signal模块的iirdesign()函数设计过滤器。 IIR 代表无限冲激响应; 有关更多信息,请参见这里。 我们将不涉及iirdesign())函数的所有细节。 在这个页面中查看文档。 简而言之,这些是我们将要设置的参数:-
频率标准化为 0 到 1
-
最大损失
-
最小衰减
-
过滤器类型:
b,a = scipy.signal.iirdesign(wp=0.2, ws=0.1, gstop=60, gpass=1, ftype='butter')
此过滤器的配置对应于 Butterworth 过滤器。 Butterworth 过滤器由物理学家 Stephen Butterworth 于 1930 年首次描述。
-
-
使用
scipy.signal.lfilter()函数应用过滤器。 它接受上一步的值作为参数,当然也接受要过滤的数据数组:filtered = scipy.signal.lfilter(b, a, data)写入新的音频文件时,请确保其数据类型与原始数据数组相同:
scipy.io.wavfile.write('filtered.wav', sample_rate, filtered.astype(data.dtype))在绘制原始数据和过滤后的数据之后,我们得到以下图:
音频过滤器的代码列出如下:
import scipy.io.wavfile import matplotlib.pyplot as plt import urllib2 import scipy.signal response =urllib2.urlopen('http://www.thesoundarchive.com/austinpowers/smashingbaby.wav') print response.info() WAV_FILE = 'smashingbaby.wav' filehandle = open(WAV_FILE, 'w') filehandle.write(response.read()) filehandle.close() sample_rate, data = scipy.io.wavfile.read(WAV_FILE) print("Data type", data.dtype, "Shape", data.shape) plt.subplot(2, 1, 1) plt.title("Original") plt.plot(data) # Design the filter b,a = scipy.signal.iirdesign(wp=0.2, ws=0.1, gstop=60, gpass=1, ftype='butter') # Filter filtered = scipy.signal.lfilter(b, a, data) # Plot filtered data plt.subplot(2, 1, 2) plt.title("Filtered") plt.plot(filtered) scipy.io.wavfile.write('filtered.wav', sample_rate, filtered.astype(data.dtype)) plt.show()
工作原理
我们创建并应用了 Butterworth 带通过滤器。 引入了以下函数来创建过滤器:
| 函数 | 描述 |
|---|---|
scipy.signal.iirdesign() | 创建一个 IIR 数字或模拟过滤器。 此函数具有广泛的参数列表,该列表在这个页面中记录。 |
scipy.signal.lfilter() | 给定一个数字过滤器,对数组进行滤波。 |
使用 Sobel 过滤器进行边界检测
Sobel 过滤器可以用于图像中的边界检测 。 边界检测基于对图像强度执行离散差分。 由于图像是二维的,因此渐变也有两个分量,除非我们将自身限制为一维。 我们将 Sobel 过滤器应用于 Lena 的图片。
操作步骤
在本部分中,您将学习如何应用 Sobel 过滤器来检测 Lena 图像中的边界:
-
要在 x 方向上应用 Sobel 过滤器,请将轴参数设置为
0:sobelx = scipy.ndimage.sobel(lena, axis=0, mode='constant') -
要在 y 方向上应用 Sobel 过滤器,请将轴参数设置为
1:sobely = scipy.ndimage.sobel(lena, axis=1, mode='constant') -
默认的 Sobel 过滤器仅需要输入数组:
default = scipy.ndimage.sobel(lena)是原始图像图和所得图像图,显示了使用 Sobel 过滤器进行边界检测:
完整的边界检测代码如下:
import scipy import scipy.ndimage import matplotlib.pyplot as plt lena = scipy.misc.lena() plt.subplot(221) plt.imshow(lena) plt.title('Original') plt.axis('off') # Sobel X filter sobelx = scipy.ndimage.sobel(lena, axis=0, mode='constant') plt.subplot(222) plt.imshow(sobelx) plt.title('Sobel X') plt.axis('off') # Sobel Y filter sobely = scipy.ndimage.sobel(lena, axis=1, mode='constant') plt.subplot(223) plt.imshow(sobely) plt.title('Sobel Y') plt.axis('off') # Default Sobel filter default = scipy.ndimage.sobel(lena) plt.subplot(224) plt.imshow(default) plt.title('Default Filter') plt.axis('off') plt.show()
工作原理
我们将 Sobel 过滤器应用于著名模型 Lena 的图片。 如本例所示,我们可以指定沿哪个轴进行计算。 默认设置为独立于轴。
六、特殊数组和通用函数
在本章中,我们将介绍以下秘籍:
- 创建通用函数
- 查找勾股三元组
- 用
chararray执行字符串操作 - 创建一个遮罩数组
- 忽略负值和极值
- 使用
recarray函数创建一个得分表
简介
本章是关于特殊数组和通用函数的。 这些是您每天可能不会遇到的主题,但是它们仍然很重要,因此在此需要提及。**通用函数(Ufuncs)**逐个元素或标量地作用于数组。 Ufuncs 接受一组标量作为输入,并产生一组标量作为输出。 通用函数通常可以映射到它们的数学对等物上,例如加法,减法,除法,乘法等。 这里提到的特殊数组是基本 NumPy 数组对象的所有子类,并提供其他功能。
创建通用函数
我们可以使用frompyfunc() NumPy 函数从 Python 函数创建通用函数。
操作步骤
以下步骤可帮助我们创建通用函数:
-
定义一个简单的 Python 函数以使输入加倍:
def double(a): return 2 * a -
用
frompyfunc()创建通用函数。 指定输入参数的数目和返回的对象数目(均等于1):from __future__ import print_function import numpy as np def double(a): return 2 * a ufunc = np.frompyfunc(double, 1, 1) print("Result", ufunc(np.arange(4)))该代码在执行时输出以下输出:
Result [0 2 4 6]
工作原理
我们定义了一个 Python 函数,该函数会将接收到的数字加倍。 实际上,我们也可以将字符串作为输入,因为这在 Python 中是合法的。 我们使用frompyfunc() NumPy 函数从此 Python 函数创建了一个通用函数。 通用函数是 NumPy 类,具有特殊功能,例如广播和适用于 NumPy 数组的逐元素处理。 实际上,许多 NumPy 函数都是通用函数,但是都是用 C 编写的。
另见
查找勾股三元组
对于本教程,您可能需要阅读有关勾股三元组的维基百科页面。 勾股三元组是一组三个自然数,即a < b < c,为此,。
这是勾股三元组的示例:。
勾股三元组与勾股定理密切相关,您可能在中学几何学过的。
勾股三元组代表直角三角形的三个边,因此遵循勾股定理。 让我们找到一个分量总数为 1,000 的勾股三元组。 我们将使用欧几里得公式进行此操作:
在此示例中,我们将看到通用函数的运行。
操作步骤
欧几里得公式定义了m和n索引。
-
创建包含以下索引的数组:
m = np.arange(33) n = np.arange(33) -
第二步是使用欧几里得公式计算勾股三元组的数量
a,b和c。 使用outer()函数获得笛卡尔积,差和和:a = np.subtract.outer(m ** 2, n ** 2) b = 2 * np.multiply.outer(m, n) c = np.add.outer(m ** 2, n ** 2) -
现在,我们有许多包含
a,b和c值的数组。 但是,我们仍然需要找到符合问题条件的值。 使用where()NumPy 函数查找这些值的索引:idx = np.where((a + b + c) == 1000) -
使用
numpy.testing模块检查解决方案:np.testing.assert_equal(a[idx]**2 + b[idx]**2, c[idx]**2)
以下代码来自本书代码包中的triplets.py文件:
from __future__ import print_function
import numpy as np
#A Pythagorean triplet is a set of three natural numbers, a < b < c, for which,
#a ** 2 + b ** 2 = c ** 2
#
#For example, 3 ** 2 + 4 ** 2 = 9 + 16 = 25 = 5 ** 2.
#
#There exists exactly one Pythagorean triplet for which a + b + c = 1000.
#Find the product abc.
#1\. Create m and n arrays
m = np.arange(33)
n = np.arange(33)
#2\. Calculate a, b and c
a = np.subtract.outer(m ** 2, n ** 2)
b = 2 * np.multiply.outer(m, n)
c = np.add.outer(m ** 2, n ** 2)
#3\. Find the index
idx = np.where((a + b + c) == 1000)
#4\. Check solution
np.testing.assert_equal(a[idx]**2 + b[idx]**2, c[idx]**2)
print(a[idx], b[idx], c[idx])
# [375] [200] [425]
工作原理
通用函数不是实函数,而是表示函数的对象。 工具具有outer()方法,我们已经在实践中看到它。 NumPy 的许多标准通用函数都是用 C 实现的 ,因此比常规的 Python 代码要快。 Ufuncs 支持逐元素处理和类型转换,这意味着更少的循环。
另见
使用chararray执行字符串操作
NumPy 具有保存字符串的专用chararray对象。 它是ndarray的子类,并具有特殊的字符串方法。 我们将从 Python 网站下载文本并使用这些方法。 chararray相对于普通字符串数组的优点如下:
- 索引时会自动修剪数组元素的空白
- 字符串末尾的空格也被比较运算符修剪
- 向量化字符串操作可用,因此不需要循环
操作步骤
让我们创建字符数组:
-
创建字符数组作为视图:
carray = np.array(html).view(np.chararray) -
使用
expandtabs()函数将制表符扩展到空格。 此函数接受制表符大小作为参数。 如果未指定,则值为8:carray = carray.expandtabs(1) -
使用
splitlines()函数将行分割成几行:carray = carray.splitlines()以下是此示例的完整代码:
import urllib2 import numpy as np import re response = urllib2.urlopen('http://python.org/') html = response.read() html = re.sub(r'<.*?>', '', html) carray = np.array(html).view(np.chararray) carray = carray.expandtabs(1) carray = carray.splitlines() print(carray)
工作原理
我们看到了专门的chararray类在起作用。 它提供了一些向量化的字符串操作以及有关空格的便捷行为。
另见
创建遮罩数组
遮罩数组可用于忽略丢失或无效的数据项。 numpy.ma模块中的MaskedArray类是ndarray的子类,带有遮罩。 我们将使用 Lena 图像作为数据源,并假装其中一些数据已损坏。 最后,我们将绘制原始图像,原始图像的对数值,遮罩数组及其对数值。
操作步骤
让我们创建被屏蔽的数组:
-
要创建一个遮罩数组,我们需要指定一个遮罩。 创建一个随机遮罩,其值为
0或1:random_mask = np.random.randint(0, 2, size=lena.shape) -
使用上一步中的遮罩,创建一个遮罩数组:
masked_array = np.ma.array(lena, mask=random_mask)以下是此遮罩数组教程的完整代码:
from __future__ import print_function import numpy as np from scipy.misc import lena import matplotlib.pyplot as plt lena = lena() random_mask = np.random.randint(0, 2, size=lena.shape) plt.subplot(221) plt.title("Original") plt.imshow(lena) plt.axis('off') masked_array = np.ma.array(lena, mask=random_mask) print(masked_array) plt.subplot(222) plt.title("Masked") plt.imshow(masked_array) plt.axis('off') plt.subplot(223) plt.title("Log") plt.imshow(np.log(lena)) plt.axis('off') plt.subplot(224) plt.title("Log Masked") plt.imshow(np.log(masked_array)) plt.axis('off') plt.show()这是显示结果图像的屏幕截图:
工作原理
我们对 NumPy 数组应用了随机的遮罩。 这具有忽略对应于遮罩的数据的效果。 您可以在numpy.ma 模块中找到一系列遮罩数组操作 。 在本教程中,我们仅演示了如何创建遮罩数组。
另见
忽略负值和极值
当我们想忽略负值时,例如当取数组值的对数时,屏蔽的数组很有用。 遮罩数组的另一个用例是排除极值。 这基于极限值的上限和下限。
我们将把这些技术应用于股票价格数据。 我们将跳过前面几章已经介绍的下载数据的步骤。
操作步骤
我们将使用包含负数的数组的对数:
-
创建一个数组,该数组包含可被三除的数字:
triples = np.arange(0, len(close), 3) print("Triples", triples[:10], "...")接下来,使用与价格数据数组大小相同的数组创建一个数组:
signs = np.ones(len(close)) print("Signs", signs[:10], "...")借助您在第 2 章,“高级索引和数组概念”中学习的索引技巧,将每个第三个数字设置为负数。
signs[triples] = -1 print("Signs", signs[:10], "...")最后,取该数组的对数:
ma_log = np.ma.log(close * signs) print("Masked logs", ma_log[:10], "...")这应该为
AAPL打印以下输出:Triples [ 0 3 6 9 12 15 18 21 24 27] ... Signs [ 1\. 1\. 1\. 1\. 1\. 1\. 1\. 1\. 1\. 1.] ... Signs [-1\. 1\. 1\. -1\. 1\. 1\. -1\. 1\. 1\. -1.] ... Masked logs [-- 5.93655586575 5.95094223368 -- 5.97468290742 5.97510711452 -- 6.01674381162 5.97889061623 --] ... -
让我们将极值定义为低于平均值的一个标准差,或高于平均值的一个标准差(这仅用于演示目的)。 编写以下代码以屏蔽极值:
dev = close.std() avg = close.mean() inside = numpy.ma.masked_outside(close, avg - dev, avg + dev) print("Inside", inside[:10], "...")此代码显示前十个元素:
Inside [-- -- -- -- -- -- 409.429675172 410.240597855 -- --] ...绘制原始价格数据,绘制对数后的数据,再次绘制指数,最后绘制基于标准差的遮罩后的数据。 以下屏幕截图显示了结果(此运行):
本教程的完整程序如下:
from __future__ import print_function import numpy as np from matplotlib.finance import quotes_historical_yahoo from datetime import date import matplotlib.pyplot as plt def get_close(ticker): today = date.today() start = (today.year - 1, today.month, today.day) quotes = quotes_historical_yahoo(ticker, start, today) return np.array([q[4] for q in quotes]) close = get_close('AAPL') triples = np.arange(0, len(close), 3) print("Triples", triples[:10], "...") signs = np.ones(len(close)) print("Signs", signs[:10], "...") signs[triples] = -1 print("Signs", signs[:10], "...") ma_log = np.ma.log(close * signs) print("Masked logs", ma_log[:10], "...") dev = close.std() avg = close.mean() inside = np.ma.masked_outside(close, avg - dev, avg + dev) print("Inside", inside[:10], "...") plt.subplot(311) plt.title("Original") plt.plot(close) plt.subplot(312) plt.title("Log Masked") plt.plot(np.exp(ma_log)) plt.subplot(313) plt.title("Not Extreme") plt.plot(inside) plt.tight_layout() plt.show()
工作原理
numpy.ma模块中的函数掩盖了数组元素,我们认为这些元素是非法的。 例如,log()和sqrt()函数不允许使用负值。 屏蔽值类似于数据库和编程中的NULL或None值。 具有屏蔽值的所有操作都将导致屏蔽值。
另见
使用recarray函数创建得分表
recarray类是ndarray的子类。 这些数组可以像数据库中一样保存记录,具有不同的数据类型。 例如,我们可以存储有关员工的记录,其中包含诸如薪水之类的数字数据和诸如员工姓名之类的字符串。
现代经济理论告诉我们,投资归结为优化风险和回报。 风险是由对数回报的标准差表示的。 另一方面,奖励由对数回报的平均值表示。 我们可以拿出相对分数,高分意味着低风险和高回报。 这只是理论上的,未经测试,所以不要太在意。 我们将计算几只股票的得分,并将它们与股票代号一起使用 NumPy recarray()函数中的表格格式存储。
操作步骤
让我们从创建记录数组开始:
-
为每个记录创建一个包含符号,标准差得分,平均得分和总得分的记录数组:
weights = np.recarray((len(tickers),), dtype=[('symbol', np.str_, 16), ('stdscore', float), ('mean', float), ('score', float)]) -
为了简单起见,请根据对数收益在循环中初始化得分:
for i, ticker in enumerate(tickers): close = get_close(ticker) logrets = np.diff(np.log(close)) weights[i]['symbol'] = ticker weights[i]['mean'] = logrets.mean() weights[i]['stdscore'] = 1/logrets.std() weights[i]['score'] = 0如您所见,我们可以使用在上一步中定义的字段名称来访问元素。
-
现在,我们有一些数字,但是它们很难相互比较。 归一化分数,以便我们以后可以将它们合并。 在这里,归一化意味着确保分数加起来为:
for key in ['mean', 'stdscore']: wsum = weights[key].sum() weights[key] = weights[key]/wsum -
总体分数将只是中间分数的平均值。 对总分上的记录进行排序以产生排名:
weights['score'] = (weights['stdscore'] + weights['mean'])/2 weights['score'].sort()The following is the complete code for this example:
from __future__ import print_function import numpy as np from matplotlib.finance import quotes_historical_yahoo from datetime import date tickers = ['MRK', 'T', 'VZ'] def get_close(ticker): today = date.today() start = (today.year - 1, today.month, today.day) quotes = quotes_historical_yahoo(ticker, start, today) return np.array([q[4] for q in quotes]) weights = np.recarray((len(tickers),), dtype=[('symbol', np.str_, 16), ('stdscore', float), ('mean', float), ('score', float)]) for i, ticker in enumerate(tickers): close = get_close(ticker) logrets = np.diff(np.log(close)) weights[i]['symbol'] = ticker weights[i]['mean'] = logrets.mean() weights[i]['stdscore'] = 1/logrets.std() weights[i]['score'] = 0 for key in ['mean', 'stdscore']: wsum = weights[key].sum() weights[key] = weights[key]/wsum weights['score'] = (weights['stdscore'] + weights['mean'])/2 weights['score'].sort() for record in weights: print("%s,mean=%.4f,stdscore=%.4f,score=%.4f" % (record['symbol'], record['mean'], record['stdscore'], record['score']))该程序产生以下输出:
MRK,mean=0.8185,stdscore=0.2938,score=0.2177 T,mean=0.0927,stdscore=0.3427,score=0.2262 VZ,mean=0.0888,stdscore=0.3636,score=0.5561
分数已归一化,因此值介于0和1之间,我们尝试从秘籍开始使用定义获得最佳收益和风险组合 。 根据输出,VZ得分最高,因此是最好的投资。 当然,这只是一个 NumPy 演示,数据很少,所以不要认为这是推荐。
工作原理
我们计算了几只股票的得分,并将它们存储在recarray NumPy 对象中。 这个数组使我们能够混合不同数据类型的数据,在这种情况下,是股票代码和数字得分。 记录数组使我们可以将字段作为数组成员访问,例如arr.field。 本教程介绍了记录数组的创建。 您可以在numpy.recarray模块中找到更多与记录数组相关的功能。
另见
七、性能分析和调试
在本章中,我们将介绍以下秘籍:
- 使用
timeit进行性能分析 - 使用 IPython 进行分析
- 安装
line_profiler - 使用
line_profiler分析代码 - 具有
cProfile扩展名的性能分析代码 - 使用 IPython 进行调试
- 使用
PuDB进行调试
简介
调试是从软件中查找和删除错误的行为。 分析是指构建程序的概要文件,以便收集有关内存使用或时间复杂度的信息。 分析和调试是开发人员生活中必不可少的活动。 对于复杂的软件尤其如此。 好消息是,许多工具可以为您提供帮助。 我们将回顾 NumPy 用户中流行的技术。
使用timeit进行性能分析
timeit是一个模块,可用于计时代码段。 它是标准 Python 库的一部分。 我们将使用几种数组大小对sort() NumPy 函数计时。 经典的快速排序和归并排序算法的平均运行时间为O(N log N),因此我们将尝试将这个模型拟合到结果。
操作步骤
我们将要求数组进行排序:
-
创建数组以排序包含随机整数值的各种大小:
times = np.array([]) for size in sizes: integers = np.random.random_integers (1, 10 ** 6, size) -
要测量时间,请创建一个计时器,为其提供执行函数,并指定相关的导入。 然后,排序 100 次以获取有关排序时间的数据:
def measure(): timer = timeit.Timer('dosort()', 'from __main__ import dosort') return timer.timeit(10 ** 2) -
通过一次乘以时间来构建测量时间数组:
times = np.append(times, measure()) -
将时间拟合为
n log n的理论模型。 由于我们将数组大小更改为 2 的幂,因此很容易:fit = np.polyfit(sizes * powersOf2, times, 1)以下是完整的计时代码:
import numpy as np import timeit import matplotlib.pyplot as plt # This program measures the performance of the NumPy sort function # and plots time vs array size. integers = [] def dosort(): integers.sort() def measure(): timer = timeit.Timer('dosort()', 'from __main__ import dosort') return timer.timeit(10 ** 2) powersOf2 = np.arange(0, 19) sizes = 2 ** powersOf2 times = np.array([]) for size in sizes: integers = np.random.random_integers(1, 10 ** 6, size) times = np.append(times, measure()) fit = np.polyfit(sizes * powersOf2, times, 1) print(fit) plt.title("Sort array sizes vs execution times") plt.xlabel("Size") plt.ylabel("(s)") plt.semilogx(sizes, times, 'ro') plt.semilogx(sizes, np.polyval(fit, sizes * powersOf2)) plt.grid() plt.show()以下屏幕截图显示了运行时间与数组大小的关系图:
工作原理
我们测量了sort() NumPy 函数的平均运行时间。 此秘籍中使用了以下函数:
| 函数 | 描述 |
|---|---|
random_integers() | 给定值和数组大小的范围时,此函数创建一个随机整数数组 |
append() | 此函数将值附加到 NumPy 数组 |
polyfit() | 此函数将数据拟合为给定阶数的多项式 |
polyval() | 此函数计算多项式,并为给定的 x值返回相应的值 |
semilogx() | 此函数使用对数刻度在 X 轴上绘制数据 |
另见
使用 IPython 进行分析
在 IPython 中,我们可以使用timeit来分析代码的小片段。 我们可以也分析较大的脚本。 我们将展示两种方法。
操作步骤
首先,我们将介绍一个小片段:
-
以
pylab模式启动 IPython:$ ipython --pylab创建一个包含 1000 个介于 0 到 1000 之间的整数值的数组:
In [1]: a = arange(1000)测量在数组中搜索“所有问题的答案”(42)所花费的时间。 是的,所有问题的答案都是 42。如果您不相信我,请阅读这个页面:
In [2]: %timeit searchsorted(a, 42) 100000 loops, best of 3: 7.58 us per loop -
剖析以下小脚本,该小脚本可以反转包含随机值的大小可变的矩阵。 NumPy 矩阵的
.I属性(即大写I)表示该矩阵的逆:import numpy as np def invert(n): a = np.matrix(np.random.rand(n, n)) return a.I sizes = 2 ** np.arange(0, 12) for n in sizes: invert(n)将此代码计时如下:
In [1]: %run -t invert_matrix.py IPython CPU timings (estimated): User : 6.08 s. System : 0.52 s. Wall time: 19.26 s.然后使用
p选项对脚本进行配置:In [2]: %run -p invert_matrix.py 852 function calls in 6.597 CPU seconds Ordered by: internal time ncalls tottime percall cumtime percall filename:lineno(function) 12 3.228 0.269 3.228 0.269 {numpy.linalg.lapack_lite.dgesv} 24 2.967 0.124 2.967 0.124 {numpy.core.multiarray._fastCopyAndTranspose} 12 0.156 0.013 0.156 0.013 {method 'rand' of 'mtrand.RandomState' objects} 12 0.087 0.007 0.087 0.007 {method 'copy' of 'numpy.ndarray' objects} 12 0.069 0.006 0.069 0.006 {method 'astype' of 'numpy.ndarray' objects} 12 0.025 0.002 6.304 0.525 linalg.py:404(inv) 12 0.024 0.002 6.328 0.527 defmatrix.py:808(getI) 1 0.017 0.017 6.596 6.596 invert_matrix.py:1(<module>) 24 0.014 0.001 0.014 0.001 {numpy.core.multiarray.zeros} 12 0.009 0.001 6.580 0.548 invert_matrix.py:3(invert) 12 0.000 0.000 6.264 0.522 linalg.py:244(solve) 12 0.000 0.000 0.014 0.001 numeric.py:1875(identity) 1 0.000 0.000 6.597 6.597 {execfile} 36 0.000 0.000 0.000 0.000 defmatrix.py:279(__array_finalize__) 12 0.000 0.000 2.967 0.247 linalg.py:139(_fastCopyAndTranspose) 24 0.000 0.000 0.087 0.004 defmatrix.py:233(__new__) 12 0.000 0.000 0.000 0.000 linalg.py:99(_commonType) 24 0.000 0.000 0.000 0.000 {method '__array_prepare__' of 'numpy.ndarray' objects} 36 0.000 0.000 0.000 0.000 linalg.py:66(_makearray) 36 0.000 0.000 0.000 0.000 {numpy.core.multiarray.array} 12 0.000 0.000 0.000 0.000 {method 'view' of 'numpy.ndarray' objects} 12 0.000 0.000 0.000 0.000 linalg.py:127(_to_native_byte_order) 1 0.000 0.000 6.597 6.597 interactiveshell.py:2270(safe_execfile)
工作原理
我们通过分析器运行了上述 NumPy 代码。 下表概述了分析器的输出:
| 函数 | 描述 |
|---|---|
ncalls | 这是调用次数 |
tottime | 这是一个函数花费的总时间 |
percall | 这是每次通话所花费的时间 ,计算方法是将总时间除以通话次数 |
cumtime | 这是在函数和由函数调用的函数(包括递归调用)上花费的累积时间 |
另见
安装line_profiler
line_profiler由 NumPy 的开发人员之一创建。 此模块对 Python 代码进行逐行分析。 我们将在此秘籍中描述必要的安装步骤。
准备
您可能需要安装setuptools。 先前的秘籍中对此进行了介绍; 如有必要,请参阅“另见”部分。 为了安装开发版本,您将需要 Git。 安装 Git 超出了本书的范围。
操作步骤
选择适合您的安装选项:
-
使用以下任一命令将
line_profiler与easy_install一起安装:$ easy_install line_profiler $ pip install line_profiler -
安装开发版本。
使用 Git 查看源代码:
$ git clone https://github.com/rkern/line_profiler签出源代码后,按如下所示构建它:
$ python setup.py install
另见
- 第 1 章,“使用 IPython”中的“安装 IPython”
使用line_profiler分析代码
现在我们已经安装完毕,可以开始分析。
操作步骤
显然,我们将需要代码来分析:
-
编写以下代码,以自身乘以大小可变的随机矩阵。 此外,线程将休眠几秒钟。 使用
@profile注解函数以进行概要分析:import numpy as np import time @profile def multiply(n): A = np.random.rand(n, n) time.sleep(np.random.randint(0, 2)) return np.matrix(A) ** 2 for n in 2 ** np.arange(0, 10): multiply(n) -
使用以下命令运行事件分析器:
$ kernprof.py -l -v mat_mult.py Wrote profile results to mat_mult.py.lprof Timer unit: 1e-06 s File: mat_mult.py Function: multiply at line 4 Total time: 3.19654 s Line # Hits Time Per Hit % Time Line Contents ============================================================== 4 @profile 5 def multiply(n): 6 10 13461 1346.1 0.4 A = numpy.random.rand(n, n) 7 10 3000689 300068.9 93.9 time.sleep(numpy.random.randint(0, 2)) 8 10 182386 18238.6 5.7 return numpy.matrix(A) ** 2
工作原理
@profile装饰器告诉line_profiler要分析哪些函数。 下表说明了分析器的输出:
| 函数 | 描述 |
|---|---|
Line # | 文件中的行号 |
Hits | 执行该行的次数 |
Time | 执行该行所花费的时间 |
Per Hit | 执行该行所花费的平均时间 |
% Time | 执行该行所花费的时间相对于执行所有行所花费的时间的百分比 |
Line Contents | 该行的内容 |
另见
cProfile扩展和代码性能分析
cProfile是 Python 2.5 中引入的C扩展名。 它可以用于确定性分析。 确定性分析表示所获得的时间测量是精确的,并且不使用采样。 这与统计分析相反,统计分析来自随机样本。 我们将使用cProfile对一个小的 NumPy 程序进行分析,该程序会对具有随机值的数组进行转置。
操作步骤
同样,我们需要代码来配置:
-
编写以下
transpose()函数以创建具有随机值的数组并将其转置:def transpose(n): random_values = np.random.random((n, n)) return random_values.T -
运行分析器,并为其提供待分析函数:
cProfile.run('transpose (1000)')可以在以下片段中找到本教程的完整代码:
import numpy as np import cProfile def transpose(n): random_values = np.random.random((n, n)) return random_values.T cProfile.run('transpose (1000)')对于
1000 x 1000的数组,我们得到以下输出:4 function calls in 0.029 CPU seconds Ordered by: standard name ncalls tottime percall cumtime percall filename:lineno(function) 1 0.001 0.001 0.029 0.029 <string>:1(<module>) 1 0.000 0.000 0.028 0.028 cprofile_transpose.py:5(transpose) 1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects} 1 0.028 0.028 0.028 0.028 {method 'random_sample' of 'mtrand.RandomState' objects}输出中的列与 IPython 分析秘籍中看到的列相同。
另见
使用 IPython 进行调试
“如果调试是清除软件错误的过程,则编程必须是放入它们的过程。”
-- 荷兰计算机科学家 Edsger Dijkstra,1972 年图灵奖的获得者
调试是没人真正喜欢,但是掌握这些东西非常重要的东西之一。 这可能需要几个小时,并且由于墨菲定律,您很可能没有时间。 因此,重要的是要系统地了解您的工具。 找到错误并实现修复后,您应该进行单元测试(如果该错误具有来自问题跟踪程序的相关 ID,我通常在末尾附加 ID 来命名测试)。 这样,您至少不必再次进行调试。 下一章将介绍单元测试。 我们将调试以下错误代码。 它尝试访问不存在的数组元素:
import numpy as np
a = np.arange(7)
print(a[8])
IPython 调试器充当普通的 Python pdb调试器; 它添加了选项卡补全和语法突出显示等功能。
操作步骤
以下步骤说明了典型的调试会话:
-
启动 IPython Shell。 通过发出以下命令在 IPython 中运行错误脚本:
In [1]: %run buggy.py --------------------------------------------------------------------------- IndexError Traceback (most recent call last) .../site-packages/IPython/utils/py3compat.pyc in execfile(fname, *where) 173 else: 174 filename = fname --> 175 __builtin__.execfile(filename, *where) .../buggy.py in <module>() 2 3 a = numpy.arange(7) ----> 4 print a[8] IndexError: index out of bounds -
现在您的程序崩溃了,启动调试器。 在发生错误的行上设置一个断点:
In [2]: %debug > .../buggy.py(4)<module>() 2 3 a = numpy.arange(7) ----> 4 print a[8] -
使用
list命令列出代码,或使用简写l:ipdb> list 1 import numpy as np 2 3 a = np.arange(7) ----> 4 print(a[8]) -
现在,我们可以在调试器当前所在的行上求值任意代码:
ipdb> len(a) 7 ipdb> print(a) [0 1 2 3 4 5 6] -
调用栈是包含有关正在运行的程序的活动函数的信息的栈。 使用
bt命令查看调用栈:ipdb> bt .../py3compat.py(175)execfile() 171 if isinstance(fname, unicode): 172 filename = fname.encode(sys.getfilesystemencoding()) 173 else: 174 filename = fname --> 175 __builtin__.execfile(filename, *where) > .../buggy.py(4)<module>() 0 print a[8]向上移动调用栈:
ipdb> u > .../site-packages/IPython/utils/py3compat.py(175)execfile() 173 else: 174 filename = fname --> 175 __builtin__.execfile(filename, *where)下移调用栈:
ipdb> d > .../buggy.py(4)<module>() 2 3 a = np.arange(7) ----> 4 print(a[8])
工作原理
在本教程中,您学习了如何使用 IPython 调试 NumPy 程序。 我们设置一个断点并导航调用栈。 使用了以下调试器命令:
| 函数 | 描述 |
|---|---|
list或 l | 列出源代码 |
bt | 显示调用栈 |
u | 向上移动调用栈 |
d | 下移调用栈 |
另见
使用 PuDB 进行调试
PuDB 是基于视觉的,全屏,基于控制台的 Python 调试器,易于安装。 PuDB 支持光标键和 vi 命令。 如果需要,我们还可以将此调试器与 IPython 集成。
操作步骤
我们将从安装pudb开始:
-
要安装
pudb,我们只需执行以下命令(或等效的pip命令):$ sudo easy_install pudb $ pip install pudb $ pip freeze|grep pudb pudb==2014.1 -
让我们调试前面示例中的
buggy程序。 如下所示启动调试器:$ python -m pudb buggy.py以下屏幕截图显示了调试器的用户界面:
屏幕快照在顶部显示了最重要的调试命令。 我们还可以看到正在调试的代码,变量,栈和定义的断点。 键入q退出大多数菜单。 键入n将调试器移至下一行。 我们还可以使用光标键或 vi 的j和k键移动,例如,通过键入b设置断点。
另见
八、质量保证
“如果您对计算机撒谎,它将帮助您。”
-- Perry Farrar,ACM 通讯,第 28 卷
在本章中,我们将介绍以下秘籍:
- 安装 Pyflakes
- 使用 Pyflakes 执行静态分析
- 用 Pylint 分析代码
- 使用 Pychecker 执行静态分析
- 使用
docstrings测试代码 - 编写单元测试
- 使用模拟测试代码
- 以 BDD 方式来测试
简介
与普遍的看法相反,质量保证与其说是发现错误,不如说是发现它们。 我们将讨论两种提高代码质量,从而防止出现问题的方法。 首先,我们将对已经存在的代码进行静态分析。 然后,我们将讨论单元测试; 这包括模拟和行为驱动开发(BDD)。
安装 Pyflakes
Pyflakes 是 Python 代码分析包。 它可以分析代码并发现潜在的问题,例如:
- 未使用的导入
- 未使用的变量
准备
如有必要,请安装pip或easy_install。
操作步骤
选择以下之一来安装pyflakes:
-
使用
pip命令安装 pyflakes:$ sudo pip install pyflakes -
使用
easy_install命令安装 Pyflakes:$ sudo easy_install pyflakes -
这是在 Linux 上安装此包的两种方法:
Linux 包的名称也为
pyflakes。 例如,在 RedHat 上执行以下操作:$ sudo yum install pyflakes在 Debian/Ubuntu 上,命令如下:
$ sudo apt-get install pyflakes
另见
使用 Pyflakes 执行静态分析
我们将对 NumPy 代码库的一部分执行静态分析。 为此,我们将使用 Git 签出代码。 然后,我们将使用pyflakes对部分代码进行静态分析。
操作步骤
要检查 NumPy 代码中,我们需要 Git。 安装 Git 超出了本书的范围:
-
用 Git 命令检索代码如下:
$ git clone git://github.com/numpy/numpy.git numpy或者,从这里下载源档案。
-
上一步使用完整的 NumPy 代码创建一个
numpy目录。 转到此目录,并在其中运行以下命令:$ pyflakes *.py pavement.py:71: redefinition of unused 'md5' from line 69 pavement.py:88: redefinition of unused 'GIT_REVISION' from line 86 pavement.py:314: 'virtualenv' imported but unused pavement.py:315: local variable 'e' is assigned to but never used pavement.py:380: local variable 'sdir' is assigned to but never used pavement.py:381: local variable 'bdir' is assigned to but never used pavement.py:536: local variable 'st' is assigned to but never used setup.py:21: 're' imported but unused setup.py:27: redefinition of unused 'builtins' from line 25 setup.py:124: redefinition of unused 'GIT_REVISION' from line 118 setupegg.py:17: 'setup' imported but unused setupscons.py:61: 'numpy' imported but unused setupscons.py:64: 'numscons' imported but unused setupsconsegg.py:6: 'setup' imported but unused这将对代码样式进行分析,并检查当前目录中所有 Python 脚本中的 PEP-8 违规情况。 如果愿意,还可以分析单个文件。
工作原理
正如您所见,分析代码样式并使用 Pyflakes 查找违反 PEP-8 的行为非常简单。 另一个优点是速度。 但是,Pyflakes 报告的错误类型的数量是有限的。
使用 Pylint 分析代码
Pylint 是另一个由 Logilab 创建的开源静态分析器 。 Pylint 比 Pyflakes 更复杂; 它允许更多的自定义和代码检查。 但是,它比 Pyflakes 慢。 有关更多信息,请参见手册。
在本秘籍中,我们再次从 Git 存储库下载 NumPy 代码-为简便起见,省略了此步骤。
准备
您可以从源代码发行版中安装 Pylint。 但是,有很多依赖项,因此最好使用easy_install或pip进行安装。 安装命令如下:
$ easy_install pylint
$ sudo pip install pylint
操作步骤
我们将再次从 NumPy 代码库的顶部目录进行分析。 注意,我们得到了更多的输出。 实际上,Pylint 打印了太多文本,因此在这里大部分都必须省略:
$ pylint *.py
No config file found, using default configuration
************* Module pavement
C: 60: Line too long (81/80)
C:139: Line too long (81/80)
...
W: 50: TODO
W:168: XXX: find out which env variable is necessary to avoid the pb with python
W: 71: Reimport 'md5' (imported line 143)
F: 73: Unable to import 'paver'
F: 74: Unable to import 'paver.easy'
C: 79: Invalid name "setup_py" (should match (([A-Z_][A-Z0-9_]*)|(__.*__))$)
F: 86: Unable to import 'numpy.version'
E: 86: No name 'version' in module 'numpy'
C:149: Operator not followed by a space
if sys.platform =="darwin":
^^
C:202:prepare_nsis_script: Missing docstring
W:228:bdist_superpack: Redefining name 'options' from outer scope (line 74)
C:231:bdist_superpack.copy_bdist: Missing docstring
W:275:bdist_wininst_nosse: Redefining name 'options' from outer scope (line 74)
工作原理
Pylint 默认输出原始文本; 但是我们可以根据需要请求 HTML 输出。 消息具有以下格式:
MESSAGE_TYPE: LINE_NUM:[OBJECT:] MESSAGE
消息类型可以是以下之一:
[R]:这意味着建议进行重构[C]:这意味着存在违反代码风格的情况[W]:用于警告小问题[E]:用于错误或潜在的错误[F]:这表明发生致命错误,阻止了进一步的分析
另见
- 使用 Pyflakes 执行静态分析
使用 Pychecker 执行静态分析
Pychecker 是一个古老的静态分析工具。 它不是十分活跃的开发工具,但它在此提到的速度又足够好。 在编写本书时,最新版本是 0.8.19,最近一次更新是在 2011 年。Pychecker 尝试导入每个模块并对其进行处理。 然后,它搜索诸如传递不正确数量的参数,使用不存在的方法传递不正确的格式字符串以及其他问题之类的问题。 在本秘籍中,我们将再次分析代码,但是这次使用 Pychecker。
操作步骤
-
从 Sourceforge 下载
tar.gz。 解压缩源归档文件并运行以下命令:$ python setup.py install或者,使用
pip安装 Pychecker:$ sudo pip install http://sourceforge.net/projects/pychecker/files/pychecker/0.8.19/pychecker-0.8.19.tar.gz/download -
分析代码,就像先前的秘籍一样。 我们需要的命令如下:
$ pychecker *.py ... Warnings... ... setup.py:21: Imported module (re) not used setup.py:27: Module (builtins) re-imported ...
使用文档字符串测试代码
Doctests 是注释字符串,它们嵌入在类似交互式会话的 Python 代码中。 这些字符串可用于测试某些假设或仅提供示例。 我们需要使用doctest模块来运行这些测试。
让我们写一个简单的示例,该示例应该计算阶乘,但不涵盖所有可能的边界条件。 换句话说,某些测试将失败。
操作步骤
-
用将通过的测试和将失败的另一个测试编写
docstring。docstring文本应类似于在 Python shell 中通常看到的文本:""" Test for the factorial of 3 that should pass. >>> factorial(3) 6 Test for the factorial of 0 that should fail. >>> factorial(0) 1 """ -
编写以下 NumPy 代码:
return np.arange(1, n+1).cumprod()[-1]我们希望这段代码有时会故意失败。 它将创建一个序列号数组,计算该数组的累积乘积,并返回最后一个元素。
-
使用
doctest模块运行测试:doctest.testmod()以下是本书代码包中
docstringtest.py文件的完整测试示例代码:import numpy as np import doctest def factorial(n): """ Test for the factorial of 3 that should pass. >>> factorial(3) 6 Test for the factorial of 0 that should fail. >>> factorial(0) 1 """ return np.arange(1, n+1).cumprod()[-1] doctest.testmod()我们可以使用
-v选项获得详细的输出,如下所示:$ python docstringtest.py -v Trying: factorial(3) Expecting: 6 ok Trying: factorial(0) Expecting: 1 ********************************************************************** File "docstringtest.py", line 11, in __main__.factorial Failed example: factorial(0) Exception raised: Traceback (most recent call last): File ".../doctest.py", line 1253, in __run compileflags, 1) in test.globs File "<doctest __main__.factorial[1]>", line 1, in <module> factorial(0) File "docstringtest.py", line 14, in factorial return numpy.arange(1, n+1).cumprod()[-1] IndexError: index out of bounds 1 items had no tests: __main__ ********************************************************************** 1 items had failures: 1 of 2 in __main__.factorial 2 tests in 2 items. 1 passed and 1 failed. ***Test Failed*** 1 failures.
工作原理
如您所见,我们没有考虑零和负数。 实际上,由于数组为空,我们出现了index out of bounds错误。 当然,这很容易解决,我们将在下一个教程中进行。
另见
编写单元测试
测试驱动开发(TDD)是本世纪软件开发诞生的最好的事情。 TDD 的最重要方面之一是,几乎把重点放在单元测试上。
注意
TDD 方法使用所谓的测试优先方法,在此方法中,我们首先编写一个失败的测试,然后编写相应的代码以通过测试。 测试应记录开发人员的意图,但要比功能设计的水平低。 一组测试通过降低回归概率来增加置信度,并促进重构。
单元测试是自动测试,通常测试一小段代码,通常是一个函数或方法。 Python 具有用于单元测试的 PyUnit API。 作为 NumPy 的用户,我们也可以使用numpy.testing模块中的便捷函数。 顾名思义,该模块专用于测试。
操作步骤
让我们编写一些代码进行测试:
-
首先编写以下
factorial()函数:def factorial(n): if n == 0: return 1 if n < 0: raise ValueError, "Don't be so negative" return np.arange(1, n+1).cumprod()该代码与前面的秘籍中的代码相同,但是我们添加了一些边界条件检查。
-
让我们写一个类; 此类将包含单元测试。 它从
unittest模块扩展了TestCase类,是 Python 标准测试的一部分。 我们通过调用factorial()函数并运行以下代码来运行测试:-
一个正数-幸福的道路!
-
边界条件等于
0 -
负数,这将导致错误:
class FactorialTest(unittest.TestCase): def test_factorial(self): #Test for the factorial of 3 that should pass. self.assertEqual(6, factorial(3)[-1]) np.testing.assert_equal(np.array([1, 2, 6]), factorial(3)) def test_zero(self): #Test for the factorial of 0 that should pass. self.assertEqual(1, factorial(0)) def test_negative(self): #Test for the factorial of negative numbers that should fail. # It should throw a ValueError, but we expect IndexError self.assertRaises(IndexError, factorial(-10))factorial()函数和整个单元测试的代码如下:import numpy as np import unittest def factorial(n): if n == 0: return 1 if n < 0: raise ValueError, "Don't be so negative" return np.arange(1, n+1).cumprod() class FactorialTest(unittest.TestCase): def test_factorial(self): #Test for the factorial of 3 that should pass. self.assertEqual(6, factorial(3)[-1]) np.testing.assert_equal(np.array([1, 2, 6]), factorial(3)) def test_zero(self): #Test for the factorial of 0 that should pass. self.assertEqual(1, factorial(0)) def test_negative(self): #Test for the factorial of negative numbers that should fail. # It should throw a ValueError, but we expect IndexError self.assertRaises(IndexError, factorial(-10)) if __name__ == '__main__': unittest.main()负数测试失败,如以下输出所示:
.E. ====================================================================== ERROR: test_negative (__main__.FactorialTest) ---------------------------------------------------------------------- Traceback (most recent call last): File "unit_test.py", line 26, in test_negative self.assertRaises(IndexError, factorial(-10)) File "unit_test.py", line 9, in factorial raise ValueError, "Don't be so negative" ValueError: Don't be so negative ---------------------------------------------------------------------- Ran 3 tests in 0.001s FAILED (errors=1)
-
工作原理
我们看到了如何使用标准unittest Python 模块实现简单的单元测试。 我们编写了一个测试类 ,该类从unittest模块扩展了TestCase类。 以下函数用于执行各种测试:
| 函数 | 描述 |
|---|---|
numpy.testing.assert_equal() | 测试两个 NumPy 数组是否相等 |
unittest.assertEqual() | 测试两个值是否相等 |
unittest.assertRaises() | 测试是否引发异常 |
testing NumPy 包具有许多我们应该了解的测试函数,如下所示:
| 函数 | 描述 |
|---|---|
assert_almost_equal() | 如果两个数字不等于指定的精度,则此函数引发异常 |
assert_approx_equal() | 如果两个数字在一定意义上不相等,则此函数引发异常 |
assert_array_almost_equal() | 如果两个数组不等于指定的精度,此函数会引发异常 |
assert_array_equal() | 如果两个数组不相等,则此函数引发异常 |
assert_array_less() | 如果两个数组的形状不同,并且此函数引发异常,则第一个数组的元素严格小于第二个数组的元素 |
assert_raises() | 如果使用定义的参数调用的可调用对象未引发指定的异常,则此函数将失败 |
assert_warns() | 如果未抛出指定的警告,则此函数失败 |
assert_string_equal() | 此函数断言两个字符串相等 |
使用模拟测试代码
模拟是用来代替真实对象的对象,目的是测试真实对象的部分行为。 如果您看过电影《身体抢夺者》,您可能已经对基本概念有所了解。 一般来说, 仅在被测试的真实对象的创建成本很高(例如数据库连接)或测试可能产生不良副作用时才有用。 例如,我们可能不想写入文件系统或数据库。
在此秘籍中,我们将测试一个核反应堆,当然不是真正的反应堆! 此类核反应堆执行阶乘计算,从理论上讲,它可能导致连锁反应,进而导致核灾难。 我们将使用mock包通过模拟来模拟阶乘计算。
操作步骤
首先,我们将安装mock包; 之后,我们将创建一个模拟并测试一段代码:
-
要安装
mock包,请执行以下命令:$ sudo easy_install mock -
核反应堆类有一个
do_work()方法,该方法调用了我们要模拟的危险的factorial()方法。 创建一个模拟,如下所示:reactor.factorial = MagicMock(return_value=6)这样可以确保模拟返回值
6。 -
我们可以通过多种方式检查模拟的行为,然后从中检查真实对象的行为。 例如,断言使用正确的参数调用了潜在爆炸性的
factorial()方法,如下所示:reactor.factorial.assert_called_with(3, "mocked")带有模拟的完整测试代码如下:
from __future__ import print_function from mock import MagicMock import numpy as np import unittest class NuclearReactor(): def __init__(self, n): self.n = n def do_work(self, msg): print("Working") return self.factorial(self.n, msg) def factorial(self, n, msg): print(msg) if n == 0: return 1 if n < 0: raise ValueError, "Core meltdown" return np.arange(1, n+1).cumprod() class NuclearReactorTest(unittest.TestCase): def test_called(self): reactor = NuclearReactor(3) reactor.factorial = MagicMock(return_value=6) result = reactor.do_work("mocked") self.assertEqual(6, result) reactor.factorial.assert_called_with(3, "mocked") def test_unmocked(self): reactor = NuclearReactor(3) reactor.factorial(3, "unmocked") np.testing.assert_raises(ValueError) if __name__ == '__main__': unittest.main()
我们将一个字符串传递给factorial()方法,以显示带有模拟的代码不会执行实际的代码。 该单元测试的工作方式与上一秘籍中的单元测试相同。 这里的第二项测试不测试任何内容。 第二个测试的目的只是演示,如果我们在没有模拟的情况下执行真实代码,会发生什么。
测试的输出如下:
Working
.unmocked
.
----------------------------------------------------------------------
Ran 2 tests in 0.000s
OK
工作原理
模拟没有任何行为。 他们就像外星人的克隆人,假装是真实的人。 只能比外星人傻—外星人克隆人无法告诉您被替换的真实人物的生日。 我们需要设置它们以适当的方式进行响应。 例如,在此示例中,模拟返回6 。 我们可以记录模拟发生了什么,被调用了多少次以及使用了哪些参数。
另见
以 BDD 方式来测试
BDD(行为驱动开发)是您可能遇到的另一个热门缩写。 在 BDD 中,我们首先根据某些约定和规则定义(英语)被测系统的预期行为。 在本秘籍中,我们将看到这些约定的示例。
这种方法背后的想法是,我们可以让可能无法编程或编写测试大部分内容的人员参加。 这些人编写的功能采用句子的形式,包括多个步骤。 每个步骤或多或少都是我们可以编写的单元测试,例如,使用 NumPy。 有许多 Python BDD 框架。 在本秘籍中,我们使用 Lettuce 来测试阶乘函数。
操作步骤
在本节中,您将学习如何安装 Lettuce,设置测试以及编写测试规范:
-
要安装 Lettuce,请运行以下命令之一:
$ pip install lettuce $ sudo easy_install lettuce -
Lettuce 需要特殊的目录结构进行测试。 在
tests目录中,我们将有一个名为features的目录,其中包含factorial.feature文件,以及steps.py文件中的功能说明和测试代码:./tests: features ./tests/features: factorial.feature steps.py -
提出业务需求是一项艰巨的工作。 以易于测试的方式将其全部写下来更加困难。 幸运的是,这些秘籍的要求非常简单-我们只需写下不同的输入值和预期的输出。 我们在
Given,When和Then部分中有不同的方案,它们对应于不同的测试步骤。 为阶乘函数定义以下三种方案:Feature: Compute factorial Scenario: Factorial of 0 Given I have the number 0 When I compute its factorial Then I see the number 1 Scenario: Factorial of 1 Given I have the number 1 When I compute its factorial Then I see the number 1 Scenario: Factorial of 3 Given I have the number 3 When I compute its factorial Then I see the number 1, 2, 6 -
我们将定义与场景步骤相对应的方法。 要特别注意用于注释方法的文本。 它与业务场景文件中的文本匹配,并且我们使用正则表达式获取输入参数。 在前两个方案中,我们匹配数字,在最后一个方案中,我们匹配任何文本。
fromstring()NumPy 函数用于从 NumPy 数组创建字符串,字符串中使用整数数据类型和逗号分隔符。 以下代码测试了我们的方案:from lettuce import * import numpy as np @step('I have the number (\d+)') def have_the_number(step, number): world.number = int(number) @step('I compute its factorial') def compute_its_factorial(step): world.number = factorial(world.number) @step('I see the number (.*)') def check_number(step, expected): expected = np.fromstring(expected, dtype=int, sep=',') np.testing.assert_equal(world.number, expected, \ "Got %s" % world.number) def factorial(n): if n == 0: return 1 if n < 0: raise ValueError, "Core meltdown" return np.arange(1, n+1).cumprod() -
要运行测试,请转到
tests目录,然后键入以下命令:$ lettuce Feature: Compute factorial # features/factorial.feature:1 Scenario: Factorial of 0 # features/factorial.feature:3 Given I have the number 0 # features/steps.py:5 When I compute its factorial # features/steps.py:9 Then I see the number 1 # features/steps.py:13 Scenario: Factorial of 1 # features/factorial.feature:8 Given I have the number 1 # features/steps.py:5 When I compute its factorial # features/steps.py:9 Then I see the number 1 # features/steps.py:13 Scenario: Factorial of 3 # features/factorial.feature:13 Given I have the number 3 # features/steps.py:5 When I compute its factorial # features/steps.py:9 Then I see the number 1, 2, 6 # features/steps.py:13 1 feature (1 passed) 3 scenarios (3 passed) 9 steps (9 passed)
工作原理
我们定义了具有三个方案和相应步骤的函数。 我们使用 NumPy 的测试函数来测试不同步骤,并使用fromstring()函数从规格文本创建 NumPy 数组。