NumPy 初学者指南中文第三版(三)
七、探索特殊例程
作为 NumPy 的用户,我们有时会发现自己有特殊需要,例如财务计算或信号处理。 幸运的是,NumPy 满足了我们的大多数需求。 本章介绍一些更专门的 NumPy 函数。
在本章中,我们将介绍以下主题:
- 排序和搜索
- 特殊函数
- 财务函数
- 窗口函数
排序
NumPy 具有几个数据排序例程:
sort()函数返回排序数组lexsort()函数使用键列表执行排序argsort()函数返回将对数组进行排序的索引ndarray类具有执行原地排序的sort()方法msort()函数沿第一轴对数组进行排序sort_complex()函数按复数的实部和虚部对它们进行排序
从此列表中,argsort()和sort()函数也可用作 NumPy 数组的方法。
实战时间 – 按词法排序
NumPy lexsort()函数返回输入数组元素的索引数组,这些索引对应于按词法对数组进行排序。 我们需要给函数一个数组或排序键元组:
-
让我们回到第 3 章,“熟悉常用函数”。 在该章中,我们使用了
AAPL的股价数据。 我们将加载收盘价和(总是复杂的)日期。 实际上,只为日期创建一个转换器函数:def datestr2num(s): return datetime.datetime.strptime(s, "%d-%m-%Y").toordinal() dates, closes=np.loadtxt('AAPL.csv', delimiter=',', usecols=(1, 6), converters={1:datestr2num}, unpack=True) -
使用
lexsort()函数按词法对名称进行排序。 数据已经按日期排序,但也按结束排序:indices = np.lexsort((dates, closes)) print("Indices", indices) print(["%s %s" % (datetime.date.fromordinal(dates[i]), closes[i]) for i in indices])该代码显示以下内容:
Indices [ 0 16 1 17 18 4 3 2 5 28 19 21 15 6 29 22 27 20 9 7 25 26 10 8 14 11 23 12 24 13] ['2011-01-28 336.1', '2011-02-22 338.61', '2011-01-31 339.32', '2011-02-23 342.62', '2011-02-24 342.88', '2011-02-03 343.44', '2011-02-02 344.32', '2011-02-01 345.03', '2011-02-04 346.5', '2011-03-10 346.67', '2011-02-25 348.16', '2011-03-01 349.31', '2011-02-18 350.56', '2011-02-07 351.88', '2011-03-11 351.99', '2011-03-02 352.12', '2011-03-09 352.47', '2011-02-28 353.21', '2011-02-10 354.54', '2011-02-08 355.2', '2011-03-07 355.36', '2011-03-08 355.76', '2011-02-11 356.85', '2011-02-09 358.16', '2011-02-17 358.3', '2011-02-14 359.18', '2011-03-03 359.56', '2011-02-15 359.9', '2011-03-04 360.0', '2011-02-16 363.13']
刚刚发生了什么?
我们使用 NumPy lexsort()函数按词法对AAPL的收盘价进行分类。 该函数返回与数组排序相对应的索引(请参见lex.py):
from __future__ import print_function
import numpy as np
import datetime
def datestr2num(s):
return datetime.datetime.strptime(s, "%d-%m-%Y").toordinal()
dates, closes=np.loadtxt('AAPL.csv', delimiter=',', usecols=(1, 6), converters={1:datestr2num}, unpack=True)
indices = np.lexsort((dates, closes))
print("Indices", indices)
print(["%s %s" % (datetime.date.fromordinal(int(dates[i])), closes[i]) for i in indices])
勇往直前 – 尝试不同的排序顺序
我们使用日期和收盘价顺序进行了排序。 请尝试其他顺序。 使用我们在上一章中学习到的随机模块生成随机数,然后使用lexsort()对其进行排序。
实战时间 – 通过使用partition()函数选择快速中位数进行部分排序
partition()函数执行部分排序, 应该比完整排序更快,因为它的工作量较小。
注意
有关更多信息,请参考这里。 一个常见的用例是获取集合的前 10 个元素。 部分排序不能保证顶部元素组本身的正确顺序。
该函数的第一个参数是要部分排序的数组。 第二个参数是与数组元素索引相对应的整数或整数序列。 partition()函数对那些索引中的元素进行正确排序。 使用一个指定的索引,我们得到两个分区。 具有多个索引,我们得到多个分区。 排序算法确保分区中的元素(小于正确排序的元素)位于该元素之前。 否则,它们将放置在此元素后面。 让我们用一个例子来说明这个解释。 启动 Python 或 IPython Shell 并导入 NumPy:
$ ipython
In [1]: import numpy as np
创建一个包含随机元素的数组以进行排序:
In [2]: np.random.seed(20)
In [3]: a = np.random.random_integers(0, 9, 9)
In [4]: a
Out[4]: array([3, 9, 4, 6, 7, 2, 0, 6, 8])
通过将其分成两个大致相等的部分,对数组进行部分排序:
In [5]: np.partition(a, 4)
Out[5]: array([0, 2, 3, 4, 6, 6, 7, 9, 8])
除了最后两个元素外,我们得到了几乎完美的排序。
刚刚发生了什么?
我们对 9 个元素的数组进行了部分排序。 排序仅保证索引 4 中间的一个元素位于正确的位置。 这对应于尝试获取数组的前五个元素而不关心前五个组中的顺序。 由于正确排序的元素位于中间,因此这也给出了数组的中位数。
复数
复数是具有实部和虚部的数字。 如您在前几章中所记得的那样,NumPy 具有特殊的复杂数据类型,这些数据类型通过两个浮点数表示复数。 可以使用 NumPy sort_complex()函数对这些数字进行排序。 此函数首先对实部进行排序,然后对虚部进行排序。
实战时间 – 对复数进行排序
我们将创建复数数组并将其排序:
-
为复数的实部生成五个随机数,为虚部生成五个数。 将随机生成器播种到
42:np.random.seed(42) complex_numbers = np.random.random(5) + 1j * np.random.random(5) print("Complex numbers\n", complex_numbers) -
调用
sort_complex()函数对我们在上一步中生成的复数进行排序:print("Sorted\n", np.sort_complex(complex_numbers))排序的数字将是:
Sorted [ 0.39342751+0.34955771j 0.40597665+0.77477433j 0.41516850+0.26221878j 0.86631422+0.74612422j 0.92293095+0.81335691j]
刚刚发生了什么?
我们生成了随机复数,并使用sort_complex()函数对其进行了排序(请参见sortcomplex.py):
from __future__ import print_function
import numpy as np
np.random.seed(42)
complex_numbers = np.random.random(5) + 1j * np.random.random(5)
print("Complex numbers\n", complex_numbers)
print("Sorted\n", np.sort_complex(complex_numbers))
小测验 - 生成随机数
Q1. 哪个 NumPy 模块处理随机数?
randnumrandomrandomutilrand
搜索
NumPy 具有几个可以搜索数组的函数:
-
argmax()函数提供数组最大值的索引 :>>> a = np.array([2, 4, 8]) >>> np.argmax(a) 2 -
nanargmax()函数的作用与上面相同,但忽略 NaN 值:>>> b = np.array([np.nan, 2, 4]) >>> np.nanargmax(b) 2 -
argmin()和nanargmin()函数提供相似的功能,但针对最小值。argmax()和nanargmax()函数也可用作ndarray类的方法。 -
argwhere()函数搜索非零值,并返回按元素分组的相应索引:>>> a = np.array([2, 4, 8]) >>> np.argwhere(a <= 4) array([[0], [1]]) -
searchsorted()函数告诉您数组中的索引,指定值所属的数组将保持排序顺序。 它使用二分搜索,即O(log n)算法。 我们很快就会看到此函数的作用。 -
extract()函数根据条件从数组中检索值。
实战时间 – 使用searchsorted
searchsorted()函数获取排序数组中值的索引。 一个例子应该清楚地说明这一点:
-
为了演示,使用
arange()创建一个数组,该数组当然被排序:a = np.arange(5) -
是时候调用
searchsorted()函数了:indices = np.searchsorted(a, [-2, 7]) print("Indices", indices)索引,应保持排序顺序:
Indices [0 5] -
用
insert()函数构造完整的数组:print("The full array", np.insert(a, indices, [-2, 7]))这给了我们完整的数组:
The full array [-2 0 1 2 3 4 7]
刚刚发生了什么?
searchsorted()函数为我们提供了7和-2的索引5和0。 使用这些索引,我们将数组设置为array [-2, 0, 1, 2, 3, 4, 7],因此数组保持排序状态(请参见sortedsearch.py):
from __future__ import print_function
import numpy as np
a = np.arange(5)
indices = np.searchsorted(a, [-2, 7])
print("Indices", indices)
print("The full array", np.insert(a, indices, [-2, 7]))
数组元素提取
NumPy extract()函数使我们可以根据条件从数组中提取项目。 此函数类似于第 3 章,“我们熟悉的函数”。 特殊的nonzero()函数选择非零元素。
实战时间 – 从数组中提取元素
让我们提取数组的偶数元素:
-
使用
arange()函数创建数组:a = np.arange(7) -
创建选择偶数元素的条件:
condition = (a % 2) == 0 -
使用我们的条件和
extract()函数提取偶数元素:print("Even numbers", np.extract(condition, a))这为我们提供了所需的偶数(
np.extract(condition, a)等于a[np.where(condition)[0]]):Even numbers [0 2 4 6] -
使用
nonzero()函数选择非零值:print("Non zero", np.nonzero(a))这将打印数组的所有非零值:
Non zero (array([1, 2, 3, 4, 5, 6]),)
刚刚发生了什么?
我们使用布尔值条件和 NumPy extract()函数从数组中提取了偶数元素(请参见extracted.py):
from __future__ import print_function
import numpy as np
a = np.arange(7)
condition = (a % 2) == 0
print("Even numbers", np.extract(condition, a))
print("Non zero", np.nonzero(a))
财务函数
NumPy 具有多种财务函数:
fv()函数计算出所谓的未来值。 未来值基于某些假设,给出了金融产品在未来日期的价值。pv()函数计算当前值(请参阅这里)。 当前值是今天的资产价值。npv()函数返回净当前值。 净当前值定义为所有当前现金流的总和。pmt()函数计算借贷还款的本金加上利息。irr()函数计算的内部收益率。 内部收益率是实际利率, 未将通货膨胀考虑在内。mirr()函数计算修正的内部收益率。 修正的内部收益率是内部收益率的改进版本。nper()函数返回定期付款数值。rate()函数计算利率。
实战时间 – 确定未来值
未来值根据某些假设给出了金融产品在未来日期的价值。 终值取决于四个参数-利率,周期数,定期付款和当前值。
注意
在这个页面上阅读更多关于未来值的东西。 具有复利的终值的公式如下:
在上式中, PV是当前值,r是利率,n是周期数。
在本节中,让我们以3% 的利率,5年的季度10的季度付款以及1000的当前值。 用适当的值调用fv()函数(负值表示支出现金流):
print("Future value", np.fv(0.03/4, 5 * 4, -10, -1000))
终值如下:
Future value 1376.09633204
如果我们改变保存和保持其他参数不变的年数,则会得到以下图表:
刚刚发生了什么?
我们使用 NumPy fv()函数从1000的当前值,3的利率,5年和10的季度付款开始计算未来值。 。 我们绘制了各种保存期的未来值(请参见futurevalue.py):
from __future__ import print_function
import numpy as np
import matplotlib.pyplot as plt
print("Future value", np.fv(0.03/4, 5 * 4, -10, -1000))
fvals = []
for i in xrange(1, 10):
fvals.append(np.fv(.03/4, i * 4, -10, -1000))
plt.plot(range(1, 10), fvals, 'bo')
plt.title('Future value, 3 % interest,\n Quarterly payment of 10')
plt.xlabel('Saving periods in years')
plt.ylabel('Future value')
plt.grid()
plt.legend(loc='best')
plt.show()
当前值
当前值是今天的资产价值。 NumPy pv()函数可以计算当前值。 此函数与fv()函数类似,并且需要利率,期间数和定期还款,但是这里我们从终值开始。
了解有关当前值的更多信息。 如果需要,可以很容易地从将来值的公式中得出当前值的公式。
实战时间 – 获得当前值
让我们将“实战时间 – 确定未来值”中的数字反转:
插入“实战时间 – 确定未来值”部分:
print("Present value", np.pv(0.03/4, 5 * 4, -10, 1376.09633204))
除了微小的数值误差外,这给了我们1000预期的效果。 实际上,这不是错误,而是表示问题。 我们在这里处理现金流出,这就是负值的原因:
Present value -999.999999999
刚刚发生了什么?
我们反转了“实战时间 – 确定将来值”部分,以从将来值中获得当前值。 这是通过 NumPy pv()函数完成的。
净当前值
净当前值定义为所有当前值现金流的总和。 NumPy npv()函数返回现金流的净当前值。 该函数需要两个参数:rate和代表现金流的数组。
阅读有关净当前值的更多信息,。 在净当前值的公式中, Rt是时间段的现金流,r是折现率,t是时间段的指数:
实战时间 – 计算净当前值
我们将计算随机产生的现金流序列的净当前值:
-
为现金流量序列生成五个随机值。 插入 -100 作为起始值:
cashflows = np.random.randint(100, size=5) cashflows = np.insert(cashflows, 0, -100) print("Cashflows", cashflows)现金流如下:
Cashflows [-100 38 48 90 17 36] -
调用
npv()函数从上一步生成的现金流量序列中计算净当前值。 使用百分之三的比率:print("Net present value", np.npv(0.03, cashflows))净当前值:
Net present value 107.435682443
刚刚发生了什么?
我们使用 NumPy npv()函数(请参见netpresentvalue.py)从随机生成的现金流序列中计算出净当前值:
from __future__ import print_function
import numpy as np
cashflows = np.random.randint(100, size=5)
cashflows = np.insert(cashflows, 0, -100)
print("Cashflows", cashflows)
print("Net present value", np.npv(0.03, cashflows))
内部收益率
收益率的内部利率是有效利率,它没有考虑通货膨胀。 NumPy irr()函数返回给定现金流序列的内部收益率。
实战时间 – 确定内部收益率
让我们重用“实战时间 – 计算净当前值”部分的现金流序列。 在现金流序列上调用irr()函数:
print("Internal rate of return", np.irr([-100, 38, 48, 90, 17, 36]))
内部收益率:
Internal rate of return 0.373420226888
刚刚发生了什么?
我们根据“实战时间 – 计算净当前值”部分的现金流系列计算内部收益率。 该值由 NumPy irr()函数给出。
定期付款
NumPy pmt()函数允许您基于利率和定期还款次数来计算贷款的定期还款。
实战时间 – 计算定期付款
假设您的贷款为 1000 万,利率为1%。 您有30年还清贷款。 您每个月要付多少钱? 让我们找出答案。
使用上述值调用pmt()函数:
print("Payment", np.pmt(0.01/12, 12 * 30, 10000000))
每月付款:
Payment -32163.9520447
刚刚发生了什么?
我们以每年1% 的利率计算了 1000 万的贷款的每月付款。 鉴于我们有30年的还款期,pmt()函数告诉我们我们需要每月支付32163.95。
付款次数
NumPy nper()函数告诉我们要偿还贷款需要多少次定期付款。 必需的参数是贷款的利率,固定金额的定期还款以及当前值。
实战时间 – 确定定期付款的次数
考虑一笔9000的贷款,其利率为10% ,固定每月还款100。
使用 NumPy nper()函数找出需要多少笔付款:
print("Number of payments", np.nper(0.10/12, -100, 9000))
付款次数:
Number of payments 167.047511801
刚刚发生了什么?
我们确定了还清利率为10的9000贷款和100每月还款所需的还款次数。 返回的付款数为167。
利率
NumPy rate()函数根据给定的定期付款次数, 付款金额,当前值和终值来计算利率。
实战时间 – 确定利率
让我们从“实战时间 – 确定定期付款的数量”部分的值,并从其他参数反向计算利率。
填写上一个“实战时间”部分中的数字:
print("Interest rate", 12 * np.rate(167, -100, 9000, 0))
预期的利率约为 10%:
Interest rate 0.0999756420664
刚刚发生了什么?
我们使用 NumPy rate()函数和“实战时间 – 确定定期付款的数量”部分的值来计算贷款的利率。 忽略舍入错误,我们得到了最初的10百分比。
窗口函数
窗口函数是信号处理中常用的数学函数。 应用包括光谱分析和过滤器设计。 这些函数在指定域之外定义为 0。 NumPy 具有许多窗口函数:bartlett(),blackman(),hamming(),hanning()和kaiser()。 您可以在第 4 章,“便捷函数”和第 3 章,“熟悉常用函数”。
实战时间 – 绘制 Bartlett 窗口
Bartlett 窗口是三角形平滑窗口:
-
调用 NumPy
bartlett()函数:window = np.bartlett(42) -
使用 matplotlib 进行绘图很容易:
plt.plot(window) plt.show()如下所示,这是 Bartlett 窗口,该窗口是三角形的:
刚刚发生了什么?
我们用 NumPy bartlett()函数绘制了 Bartlett 窗口。
布莱克曼窗口
布莱克曼窗口是以下余弦的和:
NumPy blackman()函数返回布莱克曼窗口。 唯一参数是输出窗口中M的点数。 如果该数字为0或小于0,则该函数返回一个空数组。
实战时间 – 使用布莱克曼窗口平滑股票价格
让我们从小型AAPL股价数据文件中平滑收盘价:
-
将数据加载到 NumPy 数组中。 调用 NumPy
blackman()函数形成一个窗口,然后使用该窗口平滑价格信号:closes=np.loadtxt('AAPL.csv', delimiter=',', usecols=(6,), converters={1:datestr2num}, unpack=True) N = 5 window = np.blackman(N) smoothed = np.convolve(window/window.sum(), closes, mode='same') -
使用 matplotlib 绘制平滑价格。 在此示例中,我们将省略前五个数据点和后五个数据点。 这样做的原因是存在强烈的边界效应:
plt.plot(smoothed[N:-N], lw=2, label="smoothed") plt.plot(closes[N:-N], label="closes") plt.legend(loc='best') plt.show()使用布莱克曼窗口平滑的
AAPL收盘价应如下所示:
刚刚发生了什么?
我们从样本数据文件中绘制了AAPL的收盘价,该价格使用布莱克曼窗口和 NumPy blackman()函数进行了平滑处理(请参见plot_blackman.py):
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.dates import datestr2num
closes=np.loadtxt('AAPL.csv', delimiter=',', usecols=(6,), converters={1:datestr2num}, unpack=True)
N = 5
window = np.blackman(N)
smoothed = np.convolve(window/window.sum(), closes, mode='same')
plt.plot(smoothed[N:-N], lw=2, label="smoothed")
plt.plot(closes[N:-N], '--', label="closes")
plt.title('Blackman window')
plt.xlabel('Days')
plt.ylabel('Price ($)')
plt.grid()
plt.legend(loc='best')
plt.show()
汉明窗口
汉明窗由加权余弦形成。 计算公式如下:
NumPy hamming()函数返回汉明窗口。 唯一的参数是输出窗口中点的数量M。 如果此数字为0或小于0,则返回一个空数组。
实战时间 – 绘制汉明窗口
让我们绘制汉明窗口:
-
调用 NumPy
hamming()函数:window = np.hamming(42) -
使用 matplotlib 绘制窗口:
plt.plot(window) plt.show()汉明窗图显示如下:
刚刚发生了什么?
我们使用 NumPy hamming()函数绘制了汉明窗口。
凯撒窗口
凯撒窗口由贝塞尔函数形成。
注意
公式如下:
I0是零阶贝塞尔函数。 NumPy kaiser()函数返回凯撒窗口。 第一个参数是输出窗口中的点数。 如果此数字为0或小于0,则函数将返回一个空数组。 第二个参数是beta。
实战时间 – 绘制凯撒窗口
让我们绘制凯撒窗口:
-
调用 NumPy
kaiser()函数:window = np.kaiser(42, 14) -
使用 matplotlib 绘制窗口:
plt.plot(window) plt.show()凯撒窗口显示如下:
刚刚发生了什么?
我们使用 NumPy kaiser()函数绘制了凯撒窗口。
特殊数学函数
我们将以一些特殊的数学函数结束本章。 第一类 0 阶的修正的贝塞尔函数由i0()表示为 NumPy 中的 。 sinc函数在 NumPy 中由具有相同名称的函数表示, 也有此函数的二维版本。 sinc是三角函数; 有关更多详细信息,请参见这里。 sinc()函数具有两个定义。
NumPy sinc()函数符合以下定义:
实战时间 – 绘制修正的贝塞尔函数
让我们看看修正的第一种零阶贝塞尔函数是什么样的:
-
使用 NumPy
linspace()函数计算均匀间隔的值:x = np.linspace(0, 4, 100) -
调用 NumPy
i0()函数:vals = np.i0(x) -
使用 matplotlib 绘制修正的贝塞尔函数:
plt.plot(x, vals) plt.show()修正的贝塞尔函数将具有以下输出:
刚刚发生了什么?
我们用 NumPy i0()函数绘制了第一种零阶修正的贝塞尔函数。
sinc
sinc()函数广泛用于数学和信号处理中。 NumPy 具有相同名称的函数。 也存在二维函数。
实战时间 – 绘制sinc函数
我们将绘制sinc()函数:
-
使用 NumPy
linspace()函数计算均匀间隔的值:x = np.linspace(0, 4, 100) -
调用 NumPy
sinc()函数:vals = np.sinc(x) -
用 matplotlib 绘制
sinc()函数:plt.plot(x, vals) plt.show()sinc()函数将具有以下输出:sinc2d()函数需要二维数组。 我们可以使用outer()函数创建它,从而得到该图(代码在以下部分中):
刚刚发生了什么?
我们用 NumPy sinc()函数(参见plot_sinc.py)绘制了众所周知的sinc函数:
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 4, 100)
vals = np.sinc(x)
plt.plot(x, vals)
plt.title('Sinc function')
plt.xlabel('x')
plt.ylabel('y')
plt.grid()
plt.show()
我们在两个维度上都做了相同的操作(请参见sinc2d.py):
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 4, 100)
xx = np.outer(x, x)
vals = np.sinc(xx)
plt.imshow(vals)
plt.title('Sinc 2D')
plt.xlabel('x')
plt.ylabel('y')
plt.grid()
plt.show()
总结
这是一章,涵盖了更多专门的 NumPy 主题。 我们介绍了排序和搜索,特殊函数,财务工具和窗口函数。
下一章是关于非常重要的测试主题的。
八、通过测试确保质量
一些程序员仅在生产中进行测试。 如果您不是其中之一,那么您可能熟悉单元测试的概念。 单元测试是程序员编写的用于测试其代码的自动测试。 例如,这些测试可以单独测试函数或函数的一部分。 每个测试仅覆盖一小部分代码。 这样做的好处是提高了对代码质量,可重复测试的信心,并附带了更清晰的代码。
Python 对单元测试有很好的支持。 此外,NumPy 将numpy.testing包添加到 NumPy 代码单元测试的包中。
测试驱动的开发(TDD)是最重要的事情之一发生在软件开发中。 TDD 将集中在自动化单元测试上。 目标是尽可能自动地测试代码。 下次更改代码时,我们可以运行测试并捕获潜在的回归。 换句话说,任何已经存在的函数仍然可以使用。
本章中的主题包括:
- 单元测试
- 断言
- 浮点精度
断言函数
单元测试通常使用函数,这些函数断言某些内容是测试的一部分。 在进行数值计算时,通常存在一个基本问题,即试图比较几乎相等的浮点数。 对于整数,比较是微不足道的操作,但对于浮点数则不是,因为计算机的表示不准确。 NumPy testing包具有许多工具函数,这些函数可以测试先决条件是否成立,同时考虑到浮点比较的问题。 下表显示了不同的工具函数:
| 函数 | 描述 |
|---|---|
assert_almost_equal() | 如果两个数字不等于指定的精度,则此函数引发异常 |
assert_approx_equal() | 如果两个数字在一定意义上不相等,则此函数引发异常 |
assert_array_almost_equal() | 如果两个数组的指定精度不相等,此函数将引发异常 |
assert_array_equal() | 如果两个数组不相等,此函数将引发异常。 |
assert_array_less() | 如果两个数组的形状不同,并且第一个数组的元素严格小于第二个数组的元素,则此函数引发异常 |
assert_equal() | 如果两个对象不相等,则此函数引发异常 |
assert_raises() | 如果使用定义的参数调用的可调用对象未引发指定的异常,则此函数失败 |
assert_warns() | 如果未抛出指定的警告,则此函数失败 |
assert_string_equal() | 此函数断言两个字符串相等 |
assert_allclose() | 如果两个对象不等于期望的公差,则此函数引发断言 |
实战时间 – 断言几乎相等
假设您有两个几乎相等的数字。 让我们使用assert_almost_equal()函数检查它们是否相等:
-
以较低精度调用函数(最多 7 个小数位):
print("Decimal 6", np.testing.assert_almost_equal(0.123456789, 0.123456780, decimal=7))请注意,不会引发异常,如以下结果所示:
Decimal 6 None -
以更高的精度调用该函数(最多 8 个小数位):
print("Decimal 7", np.testing.assert_almost_equal(0.123456789, 0.123456780, decimal=8))结果如下:
Decimal 7 Traceback (most recent call last): … raise AssertionError(msg) AssertionError: Arrays are not almost equal ACTUAL: 0.123456789 DESIRED: 0.12345678
刚刚发生了什么?
我们使用了 NumPy testing包中的assert_almost_equal()函数来检查0.123456789和0.123456780对于不同的十进制精度是否相等。
小测验 - 指定小数精度
Q1. assert_almost_equal()函数的哪个参数指定小数精度?
decimalprecisiontolerancesignificant
近似相等的数组
如果两个数字在一定数量的有效数字下不相等,则assert_approx_equal()函数会引发异常。 该函数引发由以下情况触发的异常:
abs(actual - expected) >= 10**-(significant - 1)
实战时间 – 断言近似相等
让我们从上一个“实战”部分中选取数字,在它们上应用assert_approx_equal()函数:
-
以低重要性调用函数:
print("Significance 8", np.testing.assert_approx_equal (0.123456789, 0.123456780,significant=8))The result is as follows:
Significance 8 None -
以高重要性调用函数:
print("Significance 9", np.testing.assert_approx_equal (0.123456789, 0.123456780, significant=9))该函数引发一个
AssertionError:Significance 9 Traceback (most recent call last): ... raise AssertionError(msg) AssertionError: Items are not equal to 9 significant digits: ACTUAL: 0.123456789 DESIRED: 0.12345678
刚刚发生了什么?
我们使用了 NumPy testing包中的assert_approx_equal()函数来检查0.123456789和0.123456780对于不同的十进制精度是否相等。
几乎相等的数组
如果两个数组在指定的精度下不相等,则assert_array_almost_equal()函数会引发异常。 该函数检查两个数组的形状是否相同。 然后,将数组的值与以下元素进行逐元素比较:
|expected - actual| < 0.5 10-decimal
实战时间 – 断言数组几乎相等
通过向每个数组添加0,用上一个“实战时间”部分的值构成数组:
-
以较低的精度调用该函数:
print("Decimal 8", np.testing.assert_array_almost_equal([0, 0.123456789], [0, 0.123456780], decimal=8))The result is as follows:
Decimal 8 None -
以较高的精度调用该函数:
print("Decimal 9", np.testing.assert_array_almost_equal([0, 0.123456789], [0, 0.123456780], decimal=9))测试产生一个
AssertionError:Decimal 9 Traceback (most recent call last): … assert_array_compare raise AssertionError(msg) AssertionError: Arrays are not almost equal (mismatch 50.0%) x: array([ 0\. , 0.12345679]) y: array([ 0\. , 0.12345678])
刚刚发生了什么?
我们将两个数组与 NumPy array_almost_equal()函数进行了比较。
勇往直前 – 比较不同形状的数组
使用 NumPy array_almost_equal()函数比较具有不同形状的两个数组。
相等的数组
如果两个数组不相等,assert_array_equal()函数将引发异常。 数组的形状必须相等,并且每个数组的元素必须相等。 数组中允许使用 NaN。 或者,可以将数组与array_allclose()函数进行比较。 此函数的参数为绝对公差(atol)和相对公差(rtol)。 对于两个数组a和b,这些参数满足以下方程式:
|a - b| <= (atol + rtol * |b|)
实战时间 – 比较数组
让我们将两个数组与刚才提到的函数进行比较。 我们将重复使用先前“实战”中的数组,并将它们加上 NaN:
-
调用
array_allclose()函数:print("Pass", np.testing.assert_allclose([0, 0.123456789, np.nan], [0, 0.123456780, np.nan], rtol=1e-7, atol=0))The result is as follows:
Pass None -
调用
array_equal()函数:print("Fail", np.testing.assert_array_equal([0, 0.123456789, np.nan], [0, 0.123456780, np.nan]))测试失败,并显示
AssertionError:Fail Traceback (most recent call last): … assert_array_compare raise AssertionError(msg) AssertionError: Arrays are not equal (mismatch 50.0%) x: array([ 0\. , 0.12345679, nan]) y: array([ 0\. , 0.12345678, nan])
刚刚发生了什么?
我们将两个数组与array_allclose()函数和array_equal()函数进行了比较。
排序数组
如果两个数组不具有相同形状的,并且第一个数组的元素严格小于第二个数组的元素,则assert_array_less()函数会引发异常。
实战时间 – 检查数组顺序
让我们检查一个数组是否严格大于另一个数组:
-
用两个严格排序的数组调用
assert_array_less()函数:print("Pass", np.testing.assert_array_less([0, 0.123456789, np.nan], [1, 0.23456780, np.nan]))The result is as follows:
Pass None -
调用
assert_array_less()函数:print("Fail", np.testing.assert_array_less([0, 0.123456789, np.nan], [0, 0.123456780, np.nan]))该测试引发一个异常:
Fail Traceback (most recent call last): ... raise AssertionError(msg) AssertionError: Arrays are not less-ordered (mismatch 100.0%) x: array([ 0\. , 0.12345679, nan]) y: array([ 0\. , 0.12345678, nan])
刚刚发生了什么?
我们使用assert_array_less()函数检查了两个数组的顺序。
对象比较
如果两个对象不相等,则assert_equal()函数将引发异常。 对象不必是 NumPy 数组,它们也可以是列表,元组或字典。
实战时间 – 比较对象
假设您需要比较两个元组。 我们可以使用assert_equal()函数来做到这一点。
调用assert_equal()函数:
print("Equal?", np.testing.assert_equal((1, 2), (1, 3)))
该调用引发错误,因为项目不相等:
Equal?
Traceback (most recent call last):
...
raise AssertionError(msg)
AssertionError:
Items are not equal:
item=1
ACTUAL: 2
DESIRED: 3
刚刚发生了什么?
我们将两个元组与assert_equal()函数进行了比较-由于元组彼此不相等,因此引发了一个例外。
字符串比较
assert_string_equal()函数断言两个字符串相等。 如果测试失败,该函数将引发异常,并显示字符串之间的差异。 字符串字符的大小写很重要。
实战时间 – 比较字符串
让我们比较一下字符串。 这两个字符串都是单词NumPy:
-
调用
assert_string_equal()函数将字符串与自身进行比较。 该测试当然应该通过:print("Pass", np.testing.assert_string_equal("NumPy", "NumPy"))测试通过:
Pass None -
调用
assert_string_equal()函数将一个字符串与另一个字母相同但大小写不同的字符串进行比较。 此测试应引发异常:print("Fail", np.testing.assert_string_equal("NumPy", "Numpy"))测试引发错误:
Fail Traceback (most recent call last): … raise AssertionError(msg) AssertionError: Differences in strings: - NumPy? ^ + Numpy? ^
刚刚发生了什么?
我们将两个字符串与assert_string_equal()函数进行了比较。 当外壳不匹配时,该测试引发了异常。
浮点比较
计算机中浮点数的表示形式不准确。 比较浮点数时,这会导致问题。 assert_array_almost_equal_nulp()和assert_array_max_ulp() NumPy 函数提供一致的浮点比较。浮点数的最低精度的单位(ULP),根据 IEEE 754 规范,是基本算术运算所需的半精度。 您可以将此与标尺进行比较。 公制标尺通常具有毫米的刻度,但超过该刻度则只能估计半毫米。
机器ε是浮点算术中最大的相对舍入误差。 机器ε等于 ULP 相对于 1。NumPy finfo()函数使我们能够确定机器ε。 Python 标准库还可以为您提供机器的ε值。 该值应与 NumPy 给出的值相同。
实战时间 – 使用assert_array_almost_equal_nulp来比较
让我们看到assert_array_almost_equal_nulp()函数的作用:
-
使用
finfo()函数确定机器epsilon:eps = np.finfo(float).eps print("EPS", eps)ε将如下所示:EPS 2.22044604925e-16 -
使用
assert_almost_equal_nulp()函数将1.0与1 + epsilon进行比较。 对1 + 2 * epsilon执行相同的操作:print("1", np.testing.assert_array_almost_equal_nulp(1.0, 1.0 + eps)) print("2", np.testing.assert_array_almost_equal_nulp(1.0, 1.0 + 2 * eps))The result is as follows:
1 None 2 Traceback (most recent call last): … assert_array_almost_equal_nulp raise AssertionError(msg) AssertionError: X and Y are not equal to 1 ULP (max is 2)
刚刚发生了什么?
我们通过finfo()函数确定了机器ε。 然后,我们将1.0与1 + epsilon与assert_almost_equal_nulp()函数进行了比较。 但是,该测试通过了,添加另一个ε导致异常。
更多使用 ULP 的浮点比较
assert_array_max_ulp()函数允许您指定允许的 ULP 数量的上限。 maxulp参数接受整数作为限制。 默认情况下,此参数的值为 1。
实战时间 – 使用最大值 2 的比较
让我们进行与先前“实战”部分相同的事情,但在必要时, 指定maxulp为2:
-
使用
finfo()函数确定机器epsilon:eps = np.finfo(float).eps print("EPS", eps)The epsilon would be as follows:
EPS 2.22044604925e-16 -
按照前面的“实战时间”部分中进行的比较,但是将
assert_array_max_ulp()函数与相应的maxulp值一起使用:print("1", np.testing.assert_array_max_ulp(1.0, 1.0 + eps)) print("2", np.testing.assert_array_max_ulp(1.0, 1 + 2 * eps, maxulp=2))输出为 ,如下所示:
1 1.0 2 2.0
刚刚发生了什么?
我们比较了与之前“实战”部分相同的值,但在第二次比较中指定了2的maxulp。 通过将assert_array_max_ulp()函数与适当的maxulp值一起使用,这些测试通过了 ULP 数量返回值。
单元测试
单元测试是自动化测试,它测试一小段代码,通常是函数或方法。 Python 具有用于单元测试的PyUnit API。 作为 NumPy 的用户,我们可以利用之前在操作中看到的assert函数。
实战时间 – 编写单元测试
我们将为一个简单的阶乘函数编写测试 。 测试将检查所谓的快乐路径和异常状况。
-
首先编写阶乘函数:
import numpy as np import unittest def factorial(n): if n == 0: return 1 if n < 0: raise ValueError, "Unexpected negative value" return np.arange(1, n+1).cumprod()该代码使用
arange()和cumprod()函数创建数组并计算累积乘积,但是我们添加了一些边界条件检查。 -
现在我们将编写单元测试。 让我们写一个包含单元测试的类。 它从标准测试 Pytho 的
unittest模块扩展了TestCase类。 测试具有以下三个属性的阶乘函数的调用:-
正数,正确的方式
-
边界条件 0
-
负数,这将导致错误
class FactorialTest(unittest.TestCase): def test_factorial(self): #Test for the factorial of 3 that should pass. self.assertEqual(6, factorial(3)[-1]) np.testing.assert_equal(np.array([1, 2, 6]), factorial(3)) def test_zero(self): #Test for the factorial of 0 that should pass. self.assertEqual(1, factorial(0)) def test_negative(self): #Test for the factorial of negative numbers that should fail. # It should throw a ValueError, but we expect IndexError self.assertRaises(IndexError, factorial(-10))如以下输出所示,我们将其中一项测试失败了:
$ python unit_test.py .E. ====================================================================== ERROR: test_negative (__main__.FactorialTest) ---------------------------------------------------------------------- Traceback (most recent call last): File "unit_test.py", line 26, in test_negative self.assertRaises(IndexError, factorial(-10)) File "unit_test.py", line 9, in factorial raise ValueError, "Unexpected negative value" ValueError: Unexpected negative value ---------------------------------------------------------------------- Ran 3 tests in 0.003s FAILED (errors=1)
-
刚刚发生了什么?
我们对阶乘函数代码进行了一些满意的路径测试。 我们让边界条件测试故意失败(请参阅unit_test.py):
import numpy as np
import unittest
def factorial(n):
if n == 0:
return 1
if n < 0:
raise ValueError, "Unexpected negative value"
return np.arange(1, n+1).cumprod()
class FactorialTest(unittest.TestCase):
def test_factorial(self):
#Test for the factorial of 3 that should pass.
self.assertEqual(6, factorial(3)[-1])
np.testing.assert_equal(np.array([1, 2, 6]), factorial(3))
def test_zero(self):
#Test for the factorial of 0 that should pass.
self.assertEqual(1, factorial(0))
def test_negative(self):
#Test for the factorial of negative numbers that should fail.
# It should throw a ValueError, but we expect IndexError
self.assertRaises(IndexError, factorial(-10))
if __name__ == '__main__':
unittest.main()
Nose测试装饰器
鼻子是嘴巴上方的器官,人类和动物用来呼吸和闻味。 它也是一个 Python 框架 ,使(单元)测试变得更加容易。 Nose可帮助您组织测试。 根据nose文档:
“将收集与
testMatch正则表达式(默认值:(?:^|[b_.-])[Tt]est)匹配的任何 python 源文件,目录或包作为测试。”
Nose大量使用装饰器。 Python 装饰器是指示有关方法或函数的注释。 numpy.testing模块具有许多装饰器。 下表显示了numpy.testing模块中的不同装饰器:
| 装饰器 | 描述 |
|---|---|
numpy.testing.decorators.deprecated | 运行测试时,此函数过滤弃用警告 |
numpy.testing.decorators.knownfailureif | 此函数基于条件引发KnownFailureTest异常 |
numpy.testing.decorators.setastest | 此装饰器标记测试函数或未测试函数 |
numpy.testing.decorators.skipif | 此函数根据条件引发一个SkipTest异常 |
numpy.testing.decorators.slow | 此函数将测试函数或方法标记为缓慢 |
另外,我们可以调用decorate_methods()函数将修饰符应用于与正则表达式或字符串匹配的类的方法。
实战时间 – 装饰测试函数
我们将直接将@setastest装饰器应用于测试函数。 然后,我们将相同的装饰器应用于方法以将其禁用。 另外,我们将跳过其中一项测试,并通过另一项测试。 首先,安装nose以防万一。
-
用
setuptools安装nose:$ [sudo] easy_install nose或点子:
$ [sudo] pip install nose -
将一个函数当作测试,将另一个函数当作不是测试:
@setastest(False) def test_false(): pass @setastest(True) def test_true(): pass -
使用
@skipif装饰器跳过测试。 让我们使用一个总是导致测试被跳过的条件:@skipif(True) def test_skip(): pass -
添加一个始终通过的测试函数。 然后,使用
@knownfailureif装饰器对其进行装饰,以使测试始终失败:@knownfailureif(True) def test_alwaysfail(): pass -
使用通常应由
nose执行的方法定义一些test类:class TestClass(): def test_true2(self): pass class TestClass2(): def test_false2(self): pass -
让我们从上一步中禁用第二个测试方法:
decorate_methods(TestClass2, setastest(False), 'test_false2') -
使用以下命令运行测试:
$ nosetests -v decorator_setastest.py decorator_setastest.TestClass.test_true2 ... ok decorator_setastest.test_true ... ok decorator_test.test_skip ... SKIP: Skipping test: test_skipTest skipped due to test condition decorator_test.test_alwaysfail ... ERROR ====================================================================== ERROR: decorator_test.test_alwaysfail ---------------------------------------------------------------------- Traceback (most recent call last): File "…/nose/case.py", line 197, in runTest self.test(*self.arg) File …/numpy/testing/decorators.py", line 213, in knownfailer raise KnownFailureTest(msg) KnownFailureTest: Test skipped due to known failure ---------------------------------------------------------------------- Ran 4 tests in 0.001s FAILED (SKIP=1, errors=1)
刚刚发生了什么?
我们将某些函数和方法修饰为非测试形式,以便它们被鼻子忽略。 我们跳过了一项测试,也没有通过另一项测试。 我们通过直接使用装饰器并使用decorate_methods()函数(请参见decorator_test.py)来完成此操作:
from numpy.testing.decorators import setastest
from numpy.testing.decorators import skipif
from numpy.testing.decorators import knownfailureif
from numpy.testing import decorate_methods
@setastest(False)
def test_false():
pass
@setastest(True)
def test_true():
pass
@skipif(True)
def test_skip():
pass
@knownfailureif(True)
def test_alwaysfail():
pass
class TestClass():
def test_true2(self):
pass
class TestClass2():
def test_false2(self):
pass
decorate_methods(TestClass2, setastest(False), 'test_false2')
文档字符串
Doctests 是嵌入在 Python 代码中的字符串,类似于交互式会话。 这些字符串可用于测试某些假设或仅提供示例。 numpy.testing模块具有运行这些测试的函数。
实战时间 – 执行文档测试
让我们写一个简单示例,该示例应该计算众所周知的阶乘,但并不涵盖所有可能的边界条件。 换句话说,某些测试将失败。
-
docstring看起来像您在 Python Shell 中看到的文本(包括提示)。 吊装其中一项测试失败,只是为了看看会发生什么:""" Test for the factorial of 3 that should pass. >>> factorial(3) 6 Test for the factorial of 0 that should fail. >>> factorial(0) 1 """ -
编写以下 NumPy 代码行:
return np.arange(1, n+1).cumprod()[-1]我们希望此代码不时出于演示目的而失败。
-
例如,通过在 Python Shell 中调用
numpy.testing模块的rundocs()函数来运行doctest:>>> from numpy.testing import rundocs >>> rundocs('docstringtest.py') Traceback (most recent call last): File "<stdin>", line 1, in <module> File "…/numpy/testing/utils.py", line 998, in rundocs raise AssertionError("Some doctests failed:\n%s" % "\n".join(msg)) AssertionError: Some doctests failed: ********************************************************************** File "docstringtest.py", line 10, in docstringtest.factorial Failed example: factorial(0) Exception raised: Traceback (most recent call last): File "…/doctest.py", line 1254, in __run compileflags, 1) in test.globs File "<doctest docstringtest.factorial[1]>", line 1, in <module> factorial(0) File "docstringtest.py", line 13, in factorial return np.arange(1, n+1).cumprod()[-1] IndexError: index -1 is out of bounds for axis 0 with size 0
刚刚发生了什么?
我们编写了文档字符串测试,该测试未考虑0和负数。 我们使用numpy.testing模块中的rundocs()函数运行了测试,结果得到了索引错误(请参见docstringtest.py):
import numpy as np
def factorial(n):
"""
Test for the factorial of 3 that should pass.
>>> factorial(3)
6
Test for the factorial of 0 that should fail.
>>> factorial(0)
1
"""
return np.arange(1, n+1).cumprod()[-1]
总结
您在本章中了解了测试和 NumPy 测试工具。 我们介绍了单元测试,文档字符串测试,断言函数和浮点精度。 大多数 NumPy 断言函数都会处理浮点数的复杂性。 我们展示了可以被鼻子使用的 NumPy 装饰器。 装饰器使测试更加容易,并记录了开发人员的意图。
下一章的主题是 matplotlib -- Python 科学的可视化和图形化开源库。
九、matplotlib 绘图
matplotlib是一个非常有用的 Python 绘图库。 它与 NumPy 很好地集成在一起,但是是一个单独的开源项目。 您可以在这个页面上找到漂亮的示例。
matplotlib也具有工具函数,可以从 Yahoo Finance 下载和操纵数据。 我们将看到几个股票图表示例。
本章涵盖以下主题:
- 简单图
- 子图
- 直方图
- 绘图自定义
- 三维图
- 等高线图
- 动画
- 对数图
简单绘图
matplotlib.pyplot包包含用于简单绘图的函数。 重要的是要记住,每个后续函数调用都会更改当前图的状态。 最终,我们想要将图保存在文件中,或使用 show()函数显示。 但是,如果我们在 Qt 或 Wx 后端上运行的 IPython 中,则该图将交互更新,而无需等待show()函数。 这与即时输出文本输出的方式相当。
实战时间 – 绘制多项式函数
为了说明绘图的工作原理,让我们显示一些多项式图。 我们将使用 NumPy 多项式函数poly1d()创建一个多项式。
-
将标准输入值作为多项式系数。 使用 NumPy
poly1d()函数创建多项式:func = np.poly1d(np.array([1, 2, 3, 4]).astype(float)) -
使用 NumPy 和
linspace()函数创建x值。 使用-10到10范围,并创建30等距值:x = np.linspace(-10, 10, 30) -
使用我们在第一步中创建的多项式来计算多项式值:
y = func(x) -
调用
plot()函数; 这样不会立即显示图形:plt.plot(x, y) -
使用
xlabel()函数在x轴上添加标签:plt.xlabel('x') -
使用
ylabel()函数在y轴上添加标签:plt.ylabel('y(x)') -
调用
show()函数显示图形:plt.show()以下是具有多项式系数 1、2、3 和 4 的图:
刚刚发生了什么?
我们在屏幕上显示了多项式的图。 我们在x和y轴上添加了标签(请参见polyplot.py):
import numpy as np
import matplotlib.pyplot as plt
func = np.poly1d(np.array([1, 2, 3, 4]).astype(float))
x = np.linspace(-10, 10, 30)
y = func(x)
plt.plot(x, y)
plt.xlabel('x')
plt.ylabel('y(x)')
plt.show()
小测验 – plot()函数
Q1. plot()函数有什么作用?
- 它在屏幕上显示二维图。
- 它将二维图的图像保存在文件中。
- 它同时执行(1)和(2)。
- 它不执行(1),(2)或(3)。
绘图的格式字符串
plot()函数接受无限数量的参数。 在上一节中,我们给了它两个数组作为参数。 我们也可以通过可选的格式字符串指定线条颜色和样式。 默认情况下,它是蓝色实线,表示为b-,但是您可以指定其他颜色和样式,例如红色破折号。
实战时间 – 绘制多项式及其导数
让我们使用deriv()函数和m作为1绘制多项式及其一阶导数。 我们已经在前面的“实战时间”部分中做了第一部分。 我们希望使用两种不同的线型来识别什么是什么。
-
创建并微分多项式:
func = np.poly1d(np.array([1, 2, 3, 4]).astype(float)) func1 = func.deriv(m=1) x = np.linspace(-10, 10, 30) y = func(x) y1 = func1(x) -
用两种样式绘制多项式及其导数:红色圆圈和绿色虚线。 您无法在本书的印刷版本中看到颜色,因此您将不得不亲自尝试以下代码:
plt.plot(x, y, 'ro', x, y1, 'g--') plt.xlabel('x') plt.ylabel('y') plt.show()具有多项式系数
1,2,3和4的图如下:
刚刚发生了什么?
我们使用两种不同的线型和一次调用plot()函数(请参见polyplot2.py)来绘制多项式及其导数:
import numpy as np
import matplotlib.pyplot as plt
func = np.poly1d(np.array([1, 2, 3, 4]).astype(float))
func1 = func.deriv(m=1)
x = np.linspace(-10, 10, 30)
y = func(x)
y1 = func1(x)
plt.plot(x, y, 'ro', x, y1, 'g--')
plt.xlabel('x')
plt.ylabel('y')
plt.show()
子图
在某一时刻,一个绘图中将有太多的线。 但是,您仍然希望将所有内容组合在一起。 我们可以通过 subplot()函数执行此操作。 此函数在网格中创建多个图。
实战时间 – 绘制多项式及其导数
让我们绘制一个多项式及其一阶和二阶导数。 为了清楚起见,我们将进行三个子图绘制:
-
使用以下代码创建多项式及其导数:
func = np.poly1d(np.array([1, 2, 3, 4]).astype(float)) x = np.linspace(-10, 10, 30) y = func(x) func1 = func.deriv(m=1) y1 = func1(x) func2 = func.deriv(m=2) y2 = func2(x) -
使用
subplot()函数创建多项式的第一个子图。 此函数的第一个参数是行数,第二个参数是列数,第三个参数是以 1 开头的索引号。或者,将这三个参数合并为一个数字,例如311。 子图将组织成三行一列。 为子图命名为Polynomial。 画一条红色实线:plt.subplot(311) plt.plot(x, y, 'r-') plt.title("Polynomial") -
使用
subplot()函数创建一阶导数的第三子图。 为子图命名为First Derivativ。 使用一行蓝色三角形:plt.subplot(312) plt.plot(x, y1, 'b^') plt.title("First Derivative") -
使用
subplot()函数创建第二个导数的第二个子图。 给子图标题为"Second Derivative"。 使用一行绿色圆圈:plt.subplot(313) plt.plot(x, y2, 'go') plt.title("Second Derivative") plt.xlabel('x') plt.ylabel('y') plt.show()多项式系数为 1、2、3 和 4 的三个子图如下:
刚刚发生了什么?
我们在三行一列中使用三种不同的线型和三个子图绘制了多项式及其一阶和二阶导数(请参见polyplot3.py):
import numpy as np
import matplotlib.pyplot as plt
func = np.poly1d(np.array([1, 2, 3, 4]).astype(float))
x = np.linspace(-10, 10, 30)
y = func(x)
func1 = func.deriv(m=1)
y1 = func1(x)
func2 = func.deriv(m=2)
y2 = func2(x)
plt.subplot(311)
plt.plot(x, y, 'r-')
plt.title("Polynomial")
plt.subplot(312)
plt.plot(x, y1, 'b^')
plt.title("First Derivative")
plt.subplot(313)
plt.plot(x, y2, 'go')
plt.title("Second Derivative")
plt.xlabel('x')
plt.ylabel('y')
plt.show()
财务
matplotlib可以帮助监视我们的股票投资。 matplotlib.finance包具有工具,我们可以使用这些工具从 Yahoo Finance 网站下载股票报价。 然后,我们可以将数据绘制为烛台。
实战时间 – 绘制一年的股票报价
我们可以使用matplotlib.finance包绘制一年的股票报价数据。 这需要连接到 Yahoo Finance,这是数据源。
-
通过从今天减去一年来确定开始日期:
from matplotlib.dates import DateFormatter from matplotlib.dates import DayLocator from matplotlib.dates import MonthLocator from matplotlib.finance import quotes_historical_yahoo from matplotlib.finance import candlestick import sys from datetime import date import matplotlib.pyplot as plt today = date.today() start = (today.year - 1, today.month, today.day) -
我们需要创建所谓的定位器。 来自
matplotlib.dates包的这些对象在x轴上定位了几个月和几天:alldays = DayLocator() months = MonthLocator() -
创建一个日期格式化程序以格式化
x轴上的日期。 此格式化程序创建一个字符串,其中包含月份和年份的简称:month_formatter = DateFormatter("%b %Y") -
使用以下代码从 Yahoo Finance 下载股票报价数据:
quotes = quotes_historical_yahoo(symbol, start, today) -
创建一个
matplotlib Figure对象-这是绘图组件的顶级容器:fig = plt.figure() -
在该图中添加子图:
ax = fig.add_subplot(111) -
将
x轴上的主定位器设置为月份定位器。 此定位器负责x轴上的大刻度:ax.xaxis.set_major_locator(months) -
将
x轴上的次要定位器设置为天定位器。 此定位器负责x轴上的小滴答声:ax.xaxis.set_minor_locator(alldays) -
将
x轴上的主要格式器设置为月份格式器。 此格式化程序负责x轴上大刻度的标签:ax.xaxis.set_major_formatter(month_formatter) -
matplotlib.finance包中的函数使我们可以显示烛台。 使用报价数据创建烛台。 可以指定烛台的宽度。 现在,使用默认值:
```py
candlestick(ax, quotes)
```
11. 将x轴上的标签格式化为日期。 这将旋转标签在x轴上,以使其更适合:
```py
fig.autofmt_xdate()
plt.show()
```
`DISH`(**磁盘网络**)的烛台图显示如下:

刚刚发生了什么?
我们从 Yahoo Finance 下载了年的数据。 我们使用烛台绘制了这些数据的图表(请参见candlesticks.py):
from matplotlib.dates import DateFormatter
from matplotlib.dates import DayLocator
from matplotlib.dates import MonthLocator
from matplotlib.finance import quotes_historical_yahoo
from matplotlib.finance import candlestick
import sys
from datetime import date
import matplotlib.pyplot as plt
today = date.today()
start = (today.year - 1, today.month, today.day)
alldays = DayLocator()
months = MonthLocator()
month_formatter = DateFormatter("%b %Y")
symbol = 'DISH'
if len(sys.argv) == 2:
symbol = sys.argv[1]
quotes = quotes_historical_yahoo(symbol, start, today)
fig = plt.figure()
ax = fig.add_subplot(111)
ax.xaxis.set_major_locator(months)
ax.xaxis.set_minor_locator(alldays)
ax.xaxis.set_major_formatter(month_formatter)
candlestick(ax, quotes)
fig.autofmt_xdate()
plt.show()
直方图
直方图可视化数值数据的分布。 matplotlib具有方便的hist()函数 ,可绘制直方图。 hist()函数有两个主要参数-包含数据和条数的数组。
实战时间 – 绘制股价分布图
让我们绘制 Yahoo Finance 的股票价格 , 的分布图。
-
下载一年前的数据:
today = date.today() start = (today.year - 1, today.month, today.day) quotes = quotes_historical_yahoo(symbol, start, today) -
上一步中的报价数据存储在 Python 列表中。 将其转换为 NumPy 数组并提取收盘价:
quotes = np.array(quotes) close = quotes.T[4] -
用合理数量的条形图绘制直方图:
plt.hist(close, np.sqrt(len(close))) plt.show()DISH 的直方图如下所示:
刚刚发生了什么?
我们将 DISH 的股价分布绘制为直方图 (请参见stockhistogram.py):
from matplotlib.finance import quotes_historical_yahoo
import sys
from datetime import date
import matplotlib.pyplot as plt
import numpy as np
today = date.today()
start = (today.year - 1, today.month, today.day)
symbol = 'DISH'
if len(sys.argv) == 2:
symbol = sys.argv[1]
quotes = quotes_historical_yahoo(symbol, start, today)
quotes = np.array(quotes)
close = quotes.T[4]
plt.hist(close, np.sqrt(len(close)))
plt.show()
勇往直前 - 画钟形曲线
使用平均价格和标准差覆盖钟形曲线(与高斯或正态分布有关)。 当然只是练习。
对数图
当数据具有较宽范围的值时,对数图很有用。 matplotlib具有函数semilogx()(对数x轴),semilogy()(对数y轴)和loglog()(x和y轴为对数)。
实战时间 – 绘制股票交易量
股票交易量变化很大,因此让我们以对数标度进行绘制。 首先,我们需要从 Yahoo Finance 下载历史数据,提取日期和交易量,创建定位符和日期格式化程序,然后创建图形并将其添加到子图中。 我们已经在上一个“实战时间”部分中完成了这些步骤,因此我们将在此处跳过 。
使用对数刻度绘制体积:
plt.semilogy(dates, volume)
现在,设置定位器并将x轴格式化为日期。 这些步骤的说明也可以在前面的“实战时间”部分中找到。
使用对数刻度的 DISH 的股票交易量显示如下:
刚刚发生了什么?
我们使用对数比例(参见logy.py)绘制了股票交易量 :
from matplotlib.finance import quotes_historical_yahoo
from matplotlib.dates import DateFormatter
from matplotlib.dates import DayLocator
from matplotlib.dates import MonthLocator
import sys
from datetime import date
import matplotlib.pyplot as plt
import numpy as np
today = date.today()
start = (today.year - 1, today.month, today.day)
symbol = 'DISH'
if len(sys.argv) == 2:
symbol = sys.argv[1]
quotes = quotes_historical_yahoo(symbol, start, today)
quotes = np.array(quotes)
dates = quotes.T[0]
volume = quotes.T[5]
alldays = DayLocator()
months = MonthLocator()
month_formatter = DateFormatter("%b %Y")
fig = plt.figure()
ax = fig.add_subplot(111)
plt.semilogy(dates, volume)
ax.xaxis.set_major_locator(months)
ax.xaxis.set_minor_locator(alldays)
ax.xaxis.set_major_formatter(month_formatter)
fig.autofmt_xdate()
plt.show()
散点图
散点图在同一数据集中显示两个数值变量的值。 matplotlib scatter()函数创建散点图。 (可选)我们可以在图中指定数据点的颜色和大小以及 alpha 透明度。
实战时间 – 用散点图绘制价格和数量回报
我们可以轻松地绘制股票价格和交易量回报的散点图。 同样,从 Yahoo Finance 下载必要的数据。
-
上一步中的报价数据存储在 Python 列表中。 将此转换为 NumPy 数组并提取关闭和体积值:
dates = quotes.T[4] volume = quotes.T[5] -
计算收盘价和批量收益:
ret = np.diff(close)/close[:-1] volchange = np.diff(volume)/volume[:-1] -
创建一个 matplotlib 图形对象:
fig = plt.figure() -
在该图中添加子图:
ax = fig.add_subplot(111) -
创建散点图,将数据点的颜色链接到收盘价,将大小链接到体积变化:
ax.scatter(ret, volchange, c=ret * 100, s=volchange * 100, alpha=0.5) -
设置图的标题并在其上放置网格:
ax.set_title('Close and volume returns') ax.grid(True) plt.show()DISH 的散点图如下所示:
刚刚发生了什么?
我们绘制了 DISH 收盘价和成交量回报的散点图 (请参见scatterprice.py):
from matplotlib.finance import quotes_historical_yahoo
import sys
from datetime import date
import matplotlib.pyplot as plt
import numpy as np
today = date.today()
start = (today.year - 1, today.month, today.day)
symbol = 'DISH'
if len(sys.argv) == 2:
symbol = sys.argv[1]
quotes = quotes_historical_yahoo(symbol, start, today)
quotes = np.array(quotes)
close = quotes.T[4]
volume = quotes.T[5]
ret = np.diff(close)/close[:-1]
volchange = np.diff(volume)/volume[:-1]
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(ret, volchange, c=ret * 100, s=volchange * 100, alpha=0.5)
ax.set_title('Close and volume returns')
ax.grid(True)
plt.show()
填充区域
fill_between()函数用指定的颜色填充绘图区域。 我们可以选择一个可选的 Alpha 通道值。 该函数还具有where参数,以便我们可以根据条件对区域进行着色。
实战时间 – 根据条件遮蔽绘图区域
假设您要在股票图表的某个区域遮蔽,该区域的收盘价低于平均水平,而其颜色高于高于均值的颜色。 fill_between()函数是工作的最佳选择。 我们将再次省略以下步骤:下载一年前的历史数据,提取日期和收盘价以及创建定位器和日期格式化程序。
-
创建一个 matplotlib
Figure对象:fig = plt.figure() -
在该图中添加子图:
ax = fig.add_subplot(111) -
绘制收盘价:
ax.plot(dates, close) -
根据值是低于平均价格还是高于平均价格,使用不同的颜色对低于收盘价的地块区域进行阴影处理:
plt.fill_between(dates, close.min(), close, where=close>close.mean(), facecolor="green", alpha=0.4) plt.fill_between(dates, close.min(), close, where=close<close.mean(), facecolor="red", alpha=0.4)现在,我们可以通过设置定位器并将
x轴值格式化为日期来完成绘制,如图所示。 使用 DISH 的条件阴影的股票价格如下:
刚刚发生了什么?
我们用与高于均值(请参见fillbetween.py)不同的颜色,来着色股票图表中收盘价低于平均水平的区域:
from matplotlib.finance import quotes_historical_yahoo
from matplotlib.dates import DateFormatter
from matplotlib.dates import DayLocator
from matplotlib.dates import MonthLocator
import sys
from datetime import date
import matplotlib.pyplot as plt
import numpy as np
today = date.today()
start = (today.year - 1, today.month, today.day)
symbol = 'DISH'
if len(sys.argv) == 2:
symbol = sys.argv[1]
quotes = quotes_historical_yahoo(symbol, start, today)
quotes = np.array(quotes)
dates = quotes.T[0]
close = quotes.T[4]
alldays = DayLocator()
months = MonthLocator()
month_formatter = DateFormatter("%b %Y")
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(dates, close)
plt.fill_between(dates, close.min(), close, where=close>close.mean(), facecolor="green", alpha=0.4)
plt.fill_between(dates, close.min(), close, where=close<close.mean(), facecolor="red", alpha=0.4)
ax.xaxis.set_major_locator(months)
ax.xaxis.set_minor_locator(alldays)
ax.xaxis.set_major_formatter(month_formatter)
ax.grid(True)
fig.autofmt_xdate()
plt.show()
图例和标注
图例和标注对于良好的绘图至关重要。 我们可以使用legend()函数创建透明的图例,然后让matplotlib找出放置它们的位置。 同样,通过annotate()函数,我们可以准确地在图形上进行标注。 有大量的标注和箭头样式。
实战时间 – 使用图例和标注
在第 3 章,“熟悉常用函数”中,我们学习了如何计算股票价格的 EMA。 我们将绘制股票的收盘价及其三只 EMA 的收盘价。 为了阐明绘图,我们将添加一个图例。 我们还将用标注指示两个平均值的交叉。 为了避免重复,再次省略了某些步骤。
-
返回第 3 章“熟悉常用函数”,如果需要,并查看 EMA 算法。 计算并绘制 9,12 和 15 周期的 EMA:
emas = [] for i in range(9, 18, 3): weights = np.exp(np.linspace(-1., 0., i)) weights /= weights.sum() ema = np.convolve(weights, close)[i-1:-i+1] idx = (i - 6)/3 ax.plot(dates[i-1:], ema, lw=idx, label="EMA(%s)" % (i)) data = np.column_stack((dates[i-1:], ema)) emas.append(np.rec.fromrecords( data, names=["dates", "ema"]))请注意,
plot()函数调用需要图例标签。 我们将移动平均值存储在记录数组中,以进行下一步。 -
让我们找到前两个移动均线的交叉点:
first = emas[0]["ema"].flatten() second = emas[1]["ema"].flatten() bools = np.abs(first[-len(second):] - second)/second < 0.0001 xpoints = np.compress(bools, emas[1]) -
现在我们有了交叉点,用箭头标注它们。 确保标注文本稍微偏离交叉点:
for xpoint in xpoints: ax.annotate('x', xy=xpoint, textcoords='offset points', xytext=(-50, 30), arrowprops=dict(arrowstyle="->")) -
添加图例,然后让
matplotlib决定将其放置在何处:leg = ax.legend(loc='best', fancybox=True)) -
通过设置 Alpha 通道值使图例透明:
leg.get_frame().set_alpha(0.5)带有图例和标注的股票价格和移动均线如下所示:
刚刚发生了什么?
我们绘制了股票的收盘价及其三个 EMA。 我们在剧情中添加了图例。 我们用标注标注了前两个平均值的交叉点(请参见emalegend.py):
from matplotlib.finance import quotes_historical_yahoo
from matplotlib.dates import DateFormatter
from matplotlib.dates import DayLocator
from matplotlib.dates import MonthLocator
import sys
from datetime import date
import matplotlib.pyplot as plt
import numpy as np
today = date.today()
start = (today.year - 1, today.month, today.day)
symbol = 'DISH'
if len(sys.argv) == 2:
symbol = sys.argv[1]
quotes = quotes_historical_yahoo(symbol, start, today)
quotes = np.array(quotes)
dates = quotes.T[0]
close = quotes.T[4]
fig = plt.figure()
ax = fig.add_subplot(111)
emas = []
for i in range(9, 18, 3):
weights = np.exp(np.linspace(-1., 0., i))
weights /= weights.sum()
ema = np.convolve(weights, close)[i-1:-i+1]
idx = (i - 6)/3
ax.plot(dates[i-1:], ema, lw=idx, label="EMA(%s)" % (i))
data = np.column_stack((dates[i-1:], ema))
emas.append(np.rec.fromrecords(data, names=["dates", "ema"]))
first = emas[0]["ema"].flatten()
second = emas[1]["ema"].flatten()
bools = np.abs(first[-len(second):] - second)/second < 0.0001
xpoints = np.compress(bools, emas[1])
for xpoint in xpoints:
ax.annotate('x', xy=xpoint, textcoords='offset points',
xytext=(-50, 30),
arrowprops=dict(arrowstyle="->"))
leg = ax.legend(loc='best', fancybox=True)
leg.get_frame().set_alpha(0.5)
alldays = DayLocator()
months = MonthLocator()
month_formatter = DateFormatter("%b %Y")
ax.plot(dates, close, lw=1.0, label="Close")
ax.xaxis.set_major_locator(months)
ax.xaxis.set_minor_locator(alldays)
ax.xaxis.set_major_formatter(month_formatter)
ax.grid(True)
fig.autofmt_xdate()
plt.show()
三维绘图
三维图非常壮观,因此我们也必须在此处进行介绍。 对于三维图,我们需要一个与3D投影关联的Axes3D对象。
实战时间 – 三维绘图
我们将绘制一个简单的三维函数:
-
使用 3D 关键字为绘图指定三维投影:
ax = fig.add_subplot(111, projection='3d') -
要创建方形二维网格,请使用
meshgrid()函数初始化x和y值:u = np.linspace(-1, 1, 100) x, y = np.meshgrid(u, u) -
我们将为表面图指定行跨度,列跨度和颜色图。 步幅决定了表面砖的尺寸。 颜色图的选择取决于风格:
ax.plot_surface(x, y, z, rstride=4, cstride=4, cmap=cm.YlGnBu_r)结果是以下三维图:
刚刚发生了什么?
我们创建了一个三维函数的绘图(请参见three_d.py):
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
u = np.linspace(-1, 1, 100)
x, y = np.meshgrid(u, u)
z = x ** 2 + y ** 2
ax.plot_surface(x, y, z, rstride=4, cstride=4, cmap=cm.YlGnBu_r)
plt.show()
等高线图
matplotlib等高线三维图有两种样式-填充的和未填充的。 等高线图使用所谓的等高线。 您可能熟悉地理地图上的等高线。 在此类地图中,等高线连接了海拔相同高度的点。 我们可以使用contour()函数创建法线等高线图。 对于填充的等高线图,我们使用contourf()函数。
实战时间 – 绘制填充的等高线图
我们将在前面的“实战时间”部分中绘制三维数学函数的填充等高线图 。 代码也非常相似。 一个主要区别是我们不再需要3D投影参数。 要绘制填充的等高线图,请使用以下代码行:
ax.contourf(x, y, z)
这为我们提供了以下填充等高线图:
刚刚发生了什么?
我们创建了三维数学函数的填充等高线图(请参见contour.py):
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
fig = plt.figure()
ax = fig.add_subplot(111)
u = np.linspace(-1, 1, 100)
x, y = np.meshgrid(u, u)
z = x ** 2 + y ** 2
ax.contourf(x, y, z)
plt.show()
动画
matplotlib通过特殊的动画模块提供精美的动画函数。 我们需要定义一个用于定期更新屏幕的回调函数。 我们还需要一个函数来生成要绘制的数据。
实战时间 – 动画绘图
我们将绘制三个随机数据集 ,并将它们显示为圆形,点和三角形。 但是,我们将仅使用随机值更新其中两个数据集。
-
以不同的颜色绘制三个随机数据集,如圆形,点和三角形:
circles, triangles, dots = ax.plot(x, 'ro', y, 'g^', z, 'b.') -
调用此函数可以定期更新屏幕。 使用新的
y值更新两个图:def update(data): circles.set_ydata(data[0]) triangles.set_ydata(data[1]) return circles, triangles -
使用 NumPy 生成随机数据:
def generate(): while True: yield np.random.rand(2, N)以下是运行中的动画的快照:
刚刚发生了什么?
我们创建了一个随机数据点的动画 (请参见animation.py):
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
fig = plt.figure()
ax = fig.add_subplot(111)
N = 10
x = np.random.rand(N)
y = np.random.rand(N)
z = np.random.rand(N)
circles, triangles, dots = ax.plot(x, 'ro', y, 'g^', z, 'b.')
ax.set_ylim(0, 1)
plt.axis('off')
def update(data):
circles.set_ydata(data[0])
triangles.set_ydata(data[1])
return circles, triangles
def generate():
while True: yield np.random.rand(2, N)
anim = animation.FuncAnimation(fig, update, generate, interval=150)
plt.show()
总结
本章是关于matplotlib的-Python 绘图库。 我们涵盖了简单图,直方图,图自定义,子图,三维图,等高线图和对数图。 您还看到了一些显示股票走势图的示例。 显然,我们只是刮擦了表面,只是看到了冰山的一角。 matplotlib的功能非常丰富,因此我们没有足够的空间来覆盖 Latex 支持,极坐标支持和其他功能。
matplotlib的作者 John Hunter 于 2012 年 8 月去世。该书的一位技术评论家建议提及John Hunter 纪念基金。 NumFocus 基金会设立的纪念基金为我们(约翰·亨特的工作迷)提供了一个“回馈”的机会。 同样,有关更多详细信息,请查看前面的 NumFocus 网站链接。
下一章将介绍 SciPy,这是一个基于 NumPy 构建的科学 Python 框架。