Machine Learning Mastery Python 教程(一)
Python 代码分析
性能分析是一种确定程序中时间花费的技术。通过这些统计数据,我们可以找到程序的“热点”并考虑改进的方法。有时,意外位置的热点也可能暗示程序中的一个错误。
在本教程中,我们将看到如何使用 Python 中的性能分析功能。具体来说,你将看到:
-
我们如何使用
timeit模块比较小的代码片段 -
我们如何使用
cProfile模块对整个程序进行分析 -
我们如何在现有程序中调用分析器
-
分析器无法做的事情
启动你的项目,请阅读我的新书《机器学习中的 Python》,包括逐步教程和所有示例的Python 源代码文件。
Python 代码分析。照片由Prashant Saini提供。部分权利保留。
教程概述
本教程分为四部分;它们是:
-
分析小片段
-
分析模块
-
在代码中使用分析器
-
注意事项
分析小片段
当你被问到 Python 中完成相同任务的不同方法时,一种观点是检查哪种方法更高效。在 Python 的标准库中,我们有timeit模块,它允许我们进行一些简单的性能分析。
例如,要连接多个短字符串,我们可以使用字符串的join()函数或+运算符。那么,我们如何知道哪个更快呢?请考虑以下 Python 代码:
longstr = ""
for x in range(1000):
longstr += str(x)
这将产生一个长字符串012345....在变量longstr中。另一种写法是:
longstr = "".join([str(x) for x in range(1000)])
为了比较这两者,我们可以在命令行中执行以下操作:
python -m timeit 'longstr=""' 'for x in range(1000): longstr += str(x)'
python -m timeit '"".join([str(x) for x in range(1000)])'
这两个命令将产生以下输出:
1000 loops, best of 5: 265 usec per loop
2000 loops, best of 5: 160 usec per loop
上述命令用于加载timeit模块,并传递一行代码进行测量。在第一种情况下,我们有两行语句,它们作为两个单独的参数传递给timeit模块。按照相同的原理,第一条命令也可以呈现为三行语句(通过将 for 循环拆分成两行),但每行的缩进需要正确地引用:
python -m timeit 'longstr=""' 'for x in range(1000):' ' longstr += str(x)'
timeit的输出是找到多次运行中的最佳性能(默认为 5 次)。每次运行是多次执行提供的语句(次数是动态确定的)。时间以最佳运行中执行一次语句的平均时间来报告。
虽然join函数在字符串连接方面比+运算符更快,但上面的计时结果并不是公平的比较。这是因为我们在循环过程中使用str(x)来即时生成短字符串。更好的做法是如下:
python -m timeit -s 'strings = [str(x) for x in range(1000)]' 'longstr=""' 'for x in strings:' ' longstr += str(x)'
python -m timeit -s 'strings = [str(x) for x in range(1000)]' '"".join(strings)'
产生:
2000 loops, best of 5: 173 usec per loop
50000 loops, best of 5: 6.91 usec per loop
-s选项允许我们提供“设置”代码,该代码在分析之前执行且不计时。在上述代码中,我们在开始循环之前创建了短字符串列表。因此,创建这些字符串的时间不计入“每次循环”的时间。上述结果显示,join()函数比+运算符快两个数量级。-s选项的更常见用法是导入库。例如,我们可以比较 Python 数学模块中的平方根函数与 NumPy,并使用指数运算符**:
python -m timeit '[x**0.5 for x in range(1000)]'
python -m timeit -s 'from math import sqrt' '[sqrt(x) for x in range(1000)]'
python -m timeit -s 'from numpy import sqrt' '[sqrt(x) for x in range(1000)]'
上述结果产生了以下测量,我们可以看到在这个特定的例子中,math.sqrt()是最快的,而numpy.sqrt()是最慢的:
5000 loops, best of 5: 93.2 usec per loop
5000 loops, best of 5: 72.3 usec per loop
200 loops, best of 5: 974 usec per loop
如果你想知道为什么 NumPy 是最慢的,那是因为 NumPy 是为数组优化的。你将在以下替代方案中看到它的卓越速度:
python -m timeit -s 'import numpy as np; x=np.array(range(1000))' 'np.sqrt(x)'
结果如下:
100000 loops, best of 5: 2.08 usec per loop
如果你愿意,你也可以在 Python 代码中运行timeit。例如,下面的代码将类似于上述代码,但会给你每次运行的总原始时间:
import timeit
measurements = timeit.repeat('[x**0.5 for x in range(1000)]', number=10000)
print(measurements)
在上述代码中,每次运行都是执行语句 10,000 次;结果如下。你可以看到在最佳运行中的每次循环约为 98 微秒的结果:
[1.0888952040000106, 0.9799715450000122, 1.0921516899999801, 1.0946189250000202, 1.2792069260000005]
性能分析模块
关注一两个语句的性能是微观的角度。很可能,我们有一个很长的程序,想要查看是什么导致它运行缓慢。这是在考虑替代语句或算法之前发生的情况。
程序运行缓慢通常有两个原因:某一部分运行缓慢,或者某一部分运行次数过多,累计起来耗时过长。我们将这些“性能消耗者”称为热点。我们来看一个例子。考虑以下程序,它使用爬山算法来寻找感知机模型的超参数:
# manually search perceptron hyperparameters for binary classification
from numpy import mean
from numpy.random import randn
from numpy.random import rand
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.linear_model import Perceptron
# objective function
def objective(X, y, cfg):
# unpack config
eta, alpha = cfg
# define model
model = Perceptron(penalty='elasticnet', alpha=alpha, eta0=eta)
# define evaluation procedure
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
# evaluate model
scores = cross_val_score(model, X, y, scoring='accuracy', cv=cv, n_jobs=-1)
# calculate mean accuracy
result = mean(scores)
return result
# take a step in the search space
def step(cfg, step_size):
# unpack the configuration
eta, alpha = cfg
# step eta
new_eta = eta + randn() * step_size
# check the bounds of eta
if new_eta <= 0.0:
new_eta = 1e-8
if new_eta > 1.0:
new_eta = 1.0
# step alpha
new_alpha = alpha + randn() * step_size
# check the bounds of alpha
if new_alpha < 0.0:
new_alpha = 0.0
# return the new configuration
return [new_eta, new_alpha]
# hill climbing local search algorithm
def hillclimbing(X, y, objective, n_iter, step_size):
# starting point for the search
solution = [rand(), rand()]
# evaluate the initial point
solution_eval = objective(X, y, solution)
# run the hill climb
for i in range(n_iter):
# take a step
candidate = step(solution, step_size)
# evaluate candidate point
candidate_eval = objective(X, y, candidate)
# check if we should keep the new point
if candidate_eval >= solution_eval:
# store the new point
solution, solution_eval = candidate, candidate_eval
# report progress
print('>%d, cfg=%s %.5f' % (i, solution, solution_eval))
return [solution, solution_eval]
# define dataset
X, y = make_classification(n_samples=1000, n_features=5, n_informative=2, n_redundant=1, random_state=1)
# define the total iterations
n_iter = 100
# step size in the search space
step_size = 0.1
# perform the hill climbing search
cfg, score = hillclimbing(X, y, objective, n_iter, step_size)
print('Done!')
print('cfg=%s: Mean Accuracy: %f' % (cfg, score))
假设我们将此程序保存到文件hillclimb.py中,我们可以在命令行中按如下方式运行分析器:
python -m cProfile hillclimb.py
输出将如下所示:
>10, cfg=[0.3792455490265847, 0.21589566352848377] 0.78400
>17, cfg=[0.49105438202347707, 0.1342150084854657] 0.79833
>26, cfg=[0.5737524712834843, 0.016749795596210315] 0.80033
>47, cfg=[0.5067828976025809, 0.05280380038497864] 0.80133
>48, cfg=[0.5427345321546029, 0.0049895870979695875] 0.81167
Done!
cfg=[0.5427345321546029, 0.0049895870979695875]: Mean Accuracy: 0.811667
2686451 function calls (2638255 primitive calls) in 5.500 seconds
Ordered by: standard name
ncalls tottime percall cumtime percall filename:lineno(function)
101 0.001 0.000 4.892 0.048 hillclimb.py:11(objective)
1 0.000 0.000 5.501 5.501 hillclimb.py:2(<module>)
100 0.000 0.000 0.001 0.000 hillclimb.py:25(step)
1 0.001 0.001 4.894 4.894 hillclimb.py:44(hillclimbing)
1 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(<module>)
303 0.000 0.000 0.008 0.000 <__array_function__ internals>:2(all)
303 0.000 0.000 0.005 0.000 <__array_function__ internals>:2(amin)
2 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(any)
4 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(atleast_1d)
3333 0.003 0.000 0.018 0.000 <__array_function__ internals>:2(bincount)
103 0.000 0.000 0.001 0.000 <__array_function__ internals>:2(concatenate)
3 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(copyto)
606 0.001 0.000 0.010 0.000 <__array_function__ internals>:2(cumsum)
6 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(dot)
1 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(empty_like)
1 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(inv)
2 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(linspace)
1 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(lstsq)
101 0.000 0.000 0.005 0.000 <__array_function__ internals>:2(mean)
2 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(ndim)
1 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(outer)
1 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(polyfit)
1 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(polyval)
1 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(prod)
303 0.000 0.000 0.002 0.000 <__array_function__ internals>:2(ravel)
2 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(result_type)
303 0.001 0.000 0.001 0.000 <__array_function__ internals>:2(shape)
303 0.000 0.000 0.035 0.000 <__array_function__ internals>:2(sort)
4 0.000 0.000 0.000 0.000 <__array_function__ internals>:2(trim_zeros)
1617 0.002 0.000 0.112 0.000 <__array_function__ internals>:2(unique)
...
程序的正常输出会首先被打印,然后是分析器的统计信息。从第一行,我们可以看到我们程序中的objective()函数已运行 101 次,耗时 4.89 秒。但这 4.89 秒大部分时间都花在了它调用的函数上,该函数总共只花费了 0.001 秒。依赖模块中的函数也被分析。因此,你会看到很多 NumPy 函数。
上述输出很长,可能对你没有帮助,因为很难判断哪个函数是热点。实际上,我们可以对上述输出进行排序。例如,为了查看哪个函数被调用的次数最多,我们可以按ncalls进行排序:
python -m cProfile -s ncalls hillclimb.py
它的输出如下:它表示 Python 字典中的get()函数是使用最频繁的函数(但它在 5.6 秒完成程序中只消耗了 0.03 秒):
2685349 function calls (2637153 primitive calls) in 5.609 seconds
Ordered by: call count
ncalls tottime percall cumtime percall filename:lineno(function)
247588 0.029 0.000 0.029 0.000 {method 'get' of 'dict' objects}
246196 0.028 0.000 0.028 0.000 inspect.py:2548(name)
168057 0.018 0.000 0.018 0.000 {method 'append' of 'list' objects}
161738 0.018 0.000 0.018 0.000 inspect.py:2560(kind)
144431 0.021 0.000 0.029 0.000 {built-in method builtins.isinstance}
142213 0.030 0.000 0.031 0.000 {built-in method builtins.getattr}
...
其他排序选项如下:
| 排序字符串 | 含义 |
|---|---|
| 调用次数 | 调用计数 |
| cumulative | 累积时间 |
| cumtime | 累积时间 |
| file | 文件名 |
| filename | 文件名 |
| module | 文件名 |
| ncalls | 调用次数 |
| pcalls | 原始调用次数 |
| line | 行号 |
| name | 函数名 |
| nfl | 名称/文件/行 |
| stdname | 标准名称 |
| time | 内部时间 |
| tottime | 内部时间 |
如果程序完成需要一些时间,那么为了找到不同排序方式的分析结果,重复运行程序是不合理的。事实上,我们可以保存分析器的统计数据以便进一步处理,方法如下:
python -m cProfile -o hillclimb.stats hillclimb.py
类似于上述情况,它将运行程序。但这不会将统计数据打印到屏幕上,而是将其保存到一个文件中。之后,我们可以像以下这样使用pstats模块打开统计文件,并提供一个提示以操作数据:
python -m pstats hillclimb.stats
例如,我们可以使用排序命令来更改排序顺序,并使用 stats 打印我们看到的内容:
Welcome to the profile statistics browser.
hillclimb.stat% help
Documented commands (type help <topic>):
========================================
EOF add callees callers help quit read reverse sort stats strip
hillclimb.stat% sort ncall
hillclimb.stat% stats hillclimb
Thu Jan 13 16:44:10 2022 hillclimb.stat
2686227 function calls (2638031 primitive calls) in 5.582 seconds
Ordered by: call count
List reduced from 3456 to 4 due to restriction <'hillclimb'>
ncalls tottime percall cumtime percall filename:lineno(function)
101 0.001 0.000 4.951 0.049 hillclimb.py:11(objective)
100 0.000 0.000 0.001 0.000 hillclimb.py:25(step)
1 0.000 0.000 5.583 5.583 hillclimb.py:2(<module>)
1 0.000 0.000 4.952 4.952 hillclimb.py:44(hillclimbing)
hillclimb.stat%
你会注意到上述stats命令允许我们提供一个额外的参数。该参数可以是一个正则表达式,用于搜索函数,以便仅打印匹配的函数。因此,这是一种提供搜索字符串进行过滤的方法。
想要开始使用 Python 进行机器学习吗?
现在就参加我的 7 天免费电子邮件速成课程(附带示例代码)。
点击注册,并获得课程的免费 PDF 电子书版本。
这个pstats浏览器允许我们查看的不仅仅是上述表格。callers和callees命令显示了哪些函数调用了哪些函数,调用了多少次,以及花费了多少时间。因此,我们可以将其视为函数级别统计数据的细分。如果你有很多相互调用的函数,并且想要了解不同场景下时间的分配情况,这很有用。例如,这显示了objective()函数仅由hillclimbing()函数调用,而hillclimbing()函数调用了其他几个函数:
hillclimb.stat% callers objective
Ordered by: call count
List reduced from 3456 to 1 due to restriction <'objective'>
Function was called by...
ncalls tottime cumtime
hillclimb.py:11(objective) <- 101 0.001 4.951 hillclimb.py:44(hillclimbing)
hillclimb.stat% callees hillclimbing
Ordered by: call count
List reduced from 3456 to 1 due to restriction <'hillclimbing'>
Function called...
ncalls tottime cumtime
hillclimb.py:44(hillclimbing) -> 101 0.001 4.951 hillclimb.py:11(objective)
100 0.000 0.001 hillclimb.py:25(step)
4 0.000 0.000 {built-in method builtins.print}
2 0.000 0.000 {method 'rand' of 'numpy.random.mtrand.RandomState' objects}
hillclimb.stat%
在代码中使用分析器
上述示例假设你已经将完整程序保存到一个文件中,并对整个程序进行了分析。有时,我们只关注程序的一部分。例如,如果我们加载了一个大型模块,它需要时间进行引导,并且我们想要从分析器中移除这个部分。在这种情况下,我们可以仅针对某些行调用分析器。以下是一个示例,来自于上述程序的修改:
# manually search perceptron hyperparameters for binary classification
import cProfile as profile
import pstats
from numpy import mean
from numpy.random import randn
from numpy.random import rand
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.linear_model import Perceptron
# objective function
def objective(X, y, cfg):
# unpack config
eta, alpha = cfg
# define model
model = Perceptron(penalty='elasticnet', alpha=alpha, eta0=eta)
# define evaluation procedure
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
# evaluate model
scores = cross_val_score(model, X, y, scoring='accuracy', cv=cv, n_jobs=-1)
# calculate mean accuracy
result = mean(scores)
return result
# take a step in the search space
def step(cfg, step_size):
# unpack the configuration
eta, alpha = cfg
# step eta
new_eta = eta + randn() * step_size
# check the bounds of eta
if new_eta <= 0.0:
new_eta = 1e-8
if new_eta > 1.0:
new_eta = 1.0
# step alpha
new_alpha = alpha + randn() * step_size
# check the bounds of alpha
if new_alpha < 0.0:
new_alpha = 0.0
# return the new configuration
return [new_eta, new_alpha]
# hill climbing local search algorithm
def hillclimbing(X, y, objective, n_iter, step_size):
# starting point for the search
solution = [rand(), rand()]
# evaluate the initial point
solution_eval = objective(X, y, solution)
# run the hill climb
for i in range(n_iter):
# take a step
candidate = step(solution, step_size)
# evaluate candidate point
candidate_eval = objective(X, y, candidate)
# check if we should keep the new point
if candidate_eval >= solution_eval:
# store the new point
solution, solution_eval = candidate, candidate_eval
# report progress
print('>%d, cfg=%s %.5f' % (i, solution, solution_eval))
return [solution, solution_eval]
# define dataset
X, y = make_classification(n_samples=1000, n_features=5, n_informative=2, n_redundant=1, random_state=1)
# define the total iterations
n_iter = 100
# step size in the search space
step_size = 0.1
# perform the hill climbing search with profiling
prof = profile.Profile()
prof.enable()
cfg, score = hillclimbing(X, y, objective, n_iter, step_size)
prof.disable()
# print program output
print('Done!')
print('cfg=%s: Mean Accuracy: %f' % (cfg, score))
# print profiling output
stats = pstats.Stats(prof).strip_dirs().sort_stats("cumtime")
stats.print_stats(10) # top 10 rows
它将输出以下内容:
>0, cfg=[0.3776271076534661, 0.2308364063203663] 0.75700
>3, cfg=[0.35803234662466354, 0.03204434939660264] 0.77567
>8, cfg=[0.3001050823005957, 0.0] 0.78633
>10, cfg=[0.39518618870158934, 0.0] 0.78633
>12, cfg=[0.4291267905390187, 0.0] 0.78633
>13, cfg=[0.4403131521968569, 0.0] 0.78633
>16, cfg=[0.38865272555918756, 0.0] 0.78633
>17, cfg=[0.38871654921891885, 0.0] 0.78633
>18, cfg=[0.4542440671724224, 0.0] 0.78633
>19, cfg=[0.44899743344802734, 0.0] 0.78633
>20, cfg=[0.5855375509507891, 0.0] 0.78633
>21, cfg=[0.5935318064858227, 0.0] 0.78633
>23, cfg=[0.7606367310048543, 0.0] 0.78633
>24, cfg=[0.855444293727846, 0.0] 0.78633
>25, cfg=[0.9505501566826242, 0.0] 0.78633
>26, cfg=[1.0, 0.0244821888204496] 0.79800
Done!
cfg=[1.0, 0.0244821888204496]: Mean Accuracy: 0.798000
2179559 function calls (2140124 primitive calls) in 4.941 seconds
Ordered by: cumulative time
List reduced from 581 to 10 due to restriction <10>
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.001 0.001 4.941 4.941 hillclimb.py:46(hillclimbing)
101 0.001 0.000 4.939 0.049 hillclimb.py:13(objective)
101 0.001 0.000 4.931 0.049 _validation.py:375(cross_val_score)
101 0.002 0.000 4.930 0.049 _validation.py:48(cross_validate)
101 0.005 0.000 4.903 0.049 parallel.py:960(__call__)
101 0.235 0.002 3.089 0.031 parallel.py:920(retrieve)
3030 0.004 0.000 2.849 0.001 _parallel_backends.py:537(wrap_future_result)
3030 0.020 0.000 2.845 0.001 _base.py:417(result)
2602 0.016 0.000 2.819 0.001 threading.py:280(wait)
12447 2.796 0.000 2.796 0.000 {method 'acquire' of '_thread.lock' objects}
注意事项
使用 Tensorflow 模型进行分析可能不会产生你预期的结果,特别是如果你为模型编写了自定义层或自定义函数。如果你正确地完成了这项工作,Tensorflow 应该在执行模型之前构建计算图,因此逻辑将发生变化。因此,分析器输出将不会显示你的自定义类。
对于涉及二进制代码的一些高级模块也是如此。分析器可以看到你调用了一些函数,并将它们标记为“内置”方法,但它无法进一步深入编译代码。
下面是用于 MNIST 分类问题的 LeNet5 模型的简短代码。如果你尝试分析它并打印前 15 行,你会看到一个包装器占据了大部分时间,而无法显示更多内容:
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, AveragePooling2D, Flatten
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
# Load and reshape data to shape of (n_sample, height, width, n_channel)
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = np.expand_dims(X_train, axis=3).astype('float32')
X_test = np.expand_dims(X_test, axis=3).astype('float32')
# One-hot encode the output
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
# LeNet5 model
model = Sequential([
Conv2D(6, (5,5), input_shape=(28,28,1), padding="same", activation="tanh"),
AveragePooling2D((2,2), strides=2),
Conv2D(16, (5,5), activation="tanh"),
AveragePooling2D((2,2), strides=2),
Conv2D(120, (5,5), activation="tanh"),
Flatten(),
Dense(84, activation="tanh"),
Dense(10, activation="softmax")
])
model.summary(line_length=100)
# Training
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
earlystopping = EarlyStopping(monitor="val_loss", patience=2, restore_best_weights=True)
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=20, batch_size=32, callbacks=[earlystopping])
# Evaluate
print(model.evaluate(X_test, y_test, verbose=0))
在下面的结果中,TFE_Py_Execute 被标记为“内置”方法,占用了总运行时间 39.6 秒中的 30.1 秒。注意 tottime 与 cumtime 相同,这意味着从分析器的角度来看,似乎所有时间都花费在这个函数上,并且它没有调用其他函数。这说明了 Python 分析器的局限性。
5962698 function calls (5728324 primitive calls) in 39.674 seconds
Ordered by: cumulative time
List reduced from 12295 to 15 due to restriction <15>
ncalls tottime percall cumtime percall filename:lineno(function)
3212/1 0.013 0.000 39.699 39.699 {built-in method builtins.exec}
1 0.003 0.003 39.699 39.699 mnist.py:4(<module>)
52/4 0.005 0.000 35.470 8.868 /usr/local/lib/python3.9/site-packages/keras/utils/traceback_utils.py:58(error_handler)
1 0.089 0.089 34.334 34.334 /usr/local/lib/python3.9/site-packages/keras/engine/training.py:901(fit)
11075/9531 0.032 0.000 33.406 0.004 /usr/local/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py:138(error_handler)
4689 0.089 0.000 33.017 0.007 /usr/local/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py:882(__call__)
4689 0.023 0.000 32.771 0.007 /usr/local/lib/python3.9/site-packages/tensorflow/python/eager/def_function.py:929(_call)
4688 0.042 0.000 32.134 0.007 /usr/local/lib/python3.9/site-packages/tensorflow/python/eager/function.py:3125(__call__)
4689 0.075 0.000 30.941 0.007 /usr/local/lib/python3.9/site-packages/tensorflow/python/eager/function.py:1888(_call_flat)
4689 0.158 0.000 30.472 0.006 /usr/local/lib/python3.9/site-packages/tensorflow/python/eager/function.py:553(call)
4689 0.034 0.000 30.152 0.006 /usr/local/lib/python3.9/site-packages/tensorflow/python/eager/execute.py:33(quick_execute)
4689 30.105 0.006 30.105 0.006 {built-in method tensorflow.python._pywrap_tfe.TFE_Py_Execute}
3185/24 0.021 0.000 3.902 0.163 <frozen importlib._bootstrap>:1002(_find_and_load)
3169/10 0.014 0.000 3.901 0.390 <frozen importlib._bootstrap>:967(_find_and_load_unlocked)
2885/12 0.009 0.000 3.901 0.325 <frozen importlib._bootstrap_external>:844(exec_module)
最终,Python 的分析器仅提供时间统计信息,而不包括内存使用情况。你可能需要寻找其他库或工具来实现这一目的。
深入阅读
标准库模块 timeit、cProfile 和 pstats 的文档可以在 Python 的文档中找到:
-
timeit模块:docs.python.org/3/library/timeit.html -
cProfile模块和pstats模块:docs.python.org/3/library/profile.html
标准库的分析器非常强大,但不是唯一的。如果你想要更具视觉效果的工具,你可以尝试 Python Call Graph 模块。它可以使用 GraphViz 工具生成函数调用关系图:
- Python Call Graph:
pycallgraph.readthedocs.io/en/master/
无法深入编译代码的限制可以通过不使用 Python 的分析器而是使用针对编译程序的分析器来解决。我最喜欢的是 Valgrind:
- Valgrind:
valgrind.org/
但要使用它,你可能需要重新编译你的 Python 解释器以启用调试支持。
总结
在本教程中,我们了解了什么是分析器以及它能做什么。具体来说,
-
我们知道如何使用
timeit模块比较小代码片段。 -
我们看到 Python 的
cProfile模块可以提供有关时间使用的详细统计数据。 -
我们学会了如何使用
pstats模块对cProfile的输出进行排序或过滤。
部署 Python 项目的第一课
原文:
machinelearningmastery.com/a-first-course-on-deploying-python-projects/
在用 Python 开发项目的艰苦工作之后,我们想与其他人分享我们的项目。可以是你的朋友或同事。也许他们对你的代码不感兴趣,但他们希望运行并实际使用它。例如,你创建了一个回归模型,可以根据输入特征预测一个值。你的朋友希望提供自己的特征,看看你的模型预测了什么值。但随着你的 Python 项目变大,发送一个小脚本给朋友就不那么简单了。可能有许多支持文件、多重脚本,还依赖于一个库列表。正确处理这些问题可能是一个挑战。
完成本教程后,你将学习到:
-
如何通过将代码模块化来增强其部署的简易性
-
如何为你的模块创建一个包,以便我们可以依赖
pip来管理依赖 -
如何使用 venv 模块创建可重复的运行环境
快速启动你的项目,请参考我的新书 Python for Machine Learning,包括 逐步教程 和 Python 源代码 文件,涵盖所有示例。
部署 Python 项目的第一课
图片来源于 Kelly L。版权所有。
概述
本教程分为四个部分,它们是:
-
从开发到部署
-
创建模块
-
从模块到包
-
为你的项目使用 venv
从开发到部署
当我们完成一个 Python 项目时,有时我们不想将其搁置,而是希望将其转变为常规工作。我们可能完成了一个机器学习模型的训练,并积极使用训练好的模型进行预测。我们可能构建了一个时间序列模型,并用它进行下一步预测。然而,新数据每天都在进入,所以我们需要重新训练模型,以适应发展,并保持未来预测的准确性。
无论原因如何,我们需要确保程序按预期运行。然而,这可能比我们想象的要困难得多。一个简单的 Python 脚本可能不是什么大问题,但随着程序变大,依赖增多,许多事情可能会出错。例如,我们使用的库的新版可能会破坏工作流程。或者我们的 Python 脚本可能运行某个外部程序,而在操作系统升级后,该程序可能停止工作。另一种情况是程序依赖于位于特定路径的文件,但我们可能会不小心删除或重命名文件。
我们的程序总是有可能执行失败的。但我们有一些技巧可以使它更稳健,更可靠。
创建模块
在之前的文章中,我们演示了如何使用以下命令检查代码片段的完成时间:
python -m timeit -s 'import numpy as np' 'np.random.random()'
同时,我们还可以将其作为脚本的一部分来使用,并执行以下操作:
import timeit
import numpy as np
time = timeit.timeit("np.random.random()", globals=globals())
print(time)
Python 中的import语句允许你重用定义在另一个文件中的函数,将其视为模块。你可能会想知道我们如何让一个模块不仅提供函数,还能成为一个可执行程序。这是帮助我们部署代码的第一步。如果我们能让模块可执行,用户将无需理解我们的代码结构即可使用它。
如果我们的程序足够大,有多个文件,最好将其打包成一个模块。在 Python 中,模块通常是一个包含 Python 脚本的文件夹,并且有一个明确的入口点。因此,这样更方便传递给其他人,并且更容易理解程序的流程。此外,我们可以为模块添加版本,并让pip跟踪安装的版本。
一个简单的单文件程序可以如下编写:
import random
def main():
n = random.random()
print(n)
if __name__ == "__main__":
main()
如果我们将其保存为randomsample.py在本地目录中,我们可以通过以下方式运行它:
python randomsample.py
或:
python -m randomsample
我们可以通过以下方式在另一个脚本中重用这些函数:
import randomsample
randomsample.main()
这样有效是因为魔法变量__name__ 只有在脚本作为主程序运行时才会是"__main__",而在从另一个脚本导入时不会是。这样,你的机器学习项目可以可能被打包成如下形式:
regressor/
__init__.py
data.json
model.pickle
predict.py
train.py
现在,regressor是一个包含这五个文件的目录。__init__.py是一个空文件,仅用于表示该目录是一个可以import的 Python 模块。脚本train.py如下所示:
import os
import json
import pickle
from sklearn.linear_model import LinearRegression
def load_data():
current_dir = os.path.dirname(os.path.realpath(__file__))
filepath = os.path.join(current_dir, "data.json")
data = json.load(open(filepath))
return data
def train():
reg = LinearRegression()
data = load_data()
reg.fit(data["data"], data["target"])
return reg
predict.py的脚本是:
import os
import pickle
import sys
import numpy as np
def predict(features):
current_dir = os.path.dirname(os.path.realpath(__file__))
filepath = os.path.join(current_dir, "model.pickle")
with open(filepath, "rb") as fp:
reg = pickle.load(fp)
return reg.predict(features)
if __name__ == "__main__":
arr = np.asarray(sys.argv[1:]).astype(float).reshape(1,-1)
y = predict(arr)
print(y[0])
然后,我们可以在regressor/的父目录下运行以下命令来加载数据并训练线性回归模型。然后,我们可以使用 pickle 保存模型:
import pickle
from regressor.train import train
model = train()
with open("model.pickle", "wb") as fp:
pickle.save(model, fp)
如果我们将这个 pickle 文件移动到regressor/目录中,我们还可以在命令行中执行以下操作来运行模型:
python -m regressor.predict 0.186 0 8.3 0 0.62 6.2 58 1.96 6 400 18.1 410 11.5
这里的数值参数是输入特征的向量。如果我们进一步移除if块,即创建一个文件regressor/__main__.py,并使用以下代码:
import sys
import numpy as np
from .predict import predict
if __name__ == "__main__":
arr = np.asarray(sys.argv[1:]).astype(float).reshape(1,-1)
y = predict(arr)
print(y[0])
然后我们可以直接从模块运行模型:
python -m regressor 0.186 0 8.3 0 0.62 6.2 58 1.96 6 400 18.1 410 11.5
注意上例中的form .predict import predict行使用了 Python 的相对导入语法。这应该在模块内部用于从同一模块的其他脚本中导入组件。
想要开始使用 Python 进行机器学习吗?
立即参加我的免费 7 天电子邮件速成课程(附示例代码)。
点击注册,还可以获得课程的免费 PDF 电子书版本。
从模块到包
如果你想将你的 Python 项目作为最终产品进行分发,能够将项目作为包用 pip install 命令安装会很方便。这很容易做到。既然你已经从项目中创建了一个模块,你需要补充一些简单的设置说明。现在你需要创建一个项目目录,并将你的模块放在其中,配上一个 pyproject.toml 文件,一个 setup.cfg 文件和一个 MANIFEST.in 文件。文件结构应如下所示:
project/
pyproject.toml
setup.cfg
MANIFEST.in
regressor/
__init__.py
data.json
model.pickle
predict.py
train.py
我们将使用 setuptools,因为它已成为这项任务的标准。文件 pyproject.toml 用于指定 setuptools:
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
关键信息在 setup.cfg 中提供。我们需要指定模块的名称、版本、一些可选描述、包含的内容和依赖项,例如以下内容:
[metadata]
name = mlm_demo
version = 0.0.1
description = a simple linear regression model
[options]
packages = regressor
include_package_data = True
python_requires = >=3.6
install_requires =
scikit-learn==1.0.2
numpy>=1.22, <1.23
h5py
MANIFEST.in 只是用来指定我们需要包含哪些额外的文件。在没有包含非 Python 脚本的项目中,这个文件可以省略。但在我们的情况下,我们需要包含训练好的模型和数据文件:
include regressor/data.json
include regressor/model.pickle
然后在项目目录中,我们可以使用以下命令将其作为模块安装到我们的 Python 系统中:
pip install .
随后,以下代码在任何地方都能正常工作,因为 regressor 是我们 Python 安装中的一个可访问模块:
import numpy as np
from regressor.predict import predict
X = np.asarray([[0.186,0,8.3,0,0.62,6.2,58,1.96,6,400,18.1,410,11.5]])
y = predict(X)
print(y[0])
在 setup.cfg 中有一些细节值得解释:metadata 部分是为 pip 系统准备的。因此我们将包命名为 mlm_demo,你可以在 pip list 命令的输出中看到这个名称。然而,Python 的模块系统会将模块名称识别为 regressor,如 options 部分所指定。因此,这是你在 import 语句中应使用的名称。通常,为了用户的方便,这两个名称是相同的,这就是为什么人们会互换使用“包”和“模块”这两个术语。类似地,版本 0.0.1 出现在 pip 中,但代码中并未显示。通常将其放在模块目录中的 __init__.py 中,因此你可以在使用它的其他脚本中检查版本:
__version__ = '0.0.1'
options 部分中的 install_requires 是让我们的项目运行的关键。这意味着在安装此模块时,我们还需要安装那些其他模块(如果指定的话)。这可能会创建一个依赖树,但当你运行 pip install 命令时,pip 会处理它。正如你所预期的,我们使用 Python 的比较运算符 == 来指定特定版本。但如果我们可以接受多个版本,我们使用逗号(, )来分隔条件,例如在 numpy 的情况中。
现在你可以将整个项目目录发送给其他人(例如,打包成 ZIP 文件)。他们可以在项目目录中使用 pip install 安装它,然后使用 python -m regressor 运行你的代码,前提是提供了适当的命令行参数。
最后一点:也许你听说过 Python 项目中的requirements.txt文件。它只是一个文本文件,通常放在一个 Python 模块或一些 Python 脚本所在的目录中。它的格式类似于上述提到的依赖项规范。例如,它可能是这样:
scikit-learn==1.0.2
numpy>=1.22, <1.23
h5py
目的是你不想将你的项目做成一个包,但仍希望给出项目所需库及其版本的提示。这个文件可以被pip理解,我们可以用它来设置系统以准备项目:
pip install -r requirements.txt
但这仅适用于开发中的项目,这就是requirements.txt能够提供的所有便利。
使用 venv 管理你的项目
上述方法可能是发布和部署项目的最有效方式,因为你仅包含最关键的文件。这也是推荐的方法,因为它不依赖于平台。如果我们更改 Python 版本或转移到不同的操作系统,这种方法仍然有效(除非某些特定的依赖项禁止我们这样做)。
但有时我们可能希望为项目运行重现一个精确的环境。例如,我们希望一些不能安装的包,而不是要求安装某些包。另外,还有些情况下,我们用pip安装了一个包后,另一个包的安装会打破版本依赖。我们可以用 Python 的venv模块解决这个问题。
venv模块来自 Python 的标准库,用于创建虚拟环境。它不是像 Docker 提供的虚拟机或虚拟化;相反,它会大量修改 Python 操作的路径位置。例如,我们可以在操作系统中安装多个版本的 Python,但虚拟环境总是假设python命令意味着特定版本。另一个例子是,在一个虚拟环境中,我们可以运行pip install来设置一些包在虚拟环境目录中,这不会干扰系统外部的环境。
要开始使用venv,我们可以简单地找到一个合适的位置并运行以下命令:
$ python -m venv myproject
然后将创建一个名为myproject的目录。虚拟环境应该在 shell 中运行(以便可以操作环境变量)。要激活虚拟环境,我们执行以下命令的激活 shell 脚本(例如,在 Linux 和 macOS 的 bash 或 zsh 下):
$ source myproject/bin/activate
此后,你将处于 Python 虚拟环境中。命令python将是你在虚拟环境中创建的命令(如果你在操作系统中安装了多个 Python 版本)。已安装的包将位于myproject/lib/python3.9/site-packages (假设使用 Python 3.9)。当你运行pip install或pip list时,你只会看到虚拟环境中的包。
要离开虚拟环境,我们在 shell 命令行中运行deactivate:
$ deactivate
这被定义为一个 shell 函数。
如果你有多个项目正在开发,并且它们需要不同版本的包(比如 TensorFlow 的不同版本),使用虚拟环境将特别有用。你可以简单地创建一个虚拟环境,激活它,使用 pip install 命令安装所有需要的库的正确版本,然后将你的项目代码放入虚拟环境中。你的虚拟环境目录可能会很大(例如,仅安装 TensorFlow 及其依赖项就会占用接近 1GB 的磁盘空间)。但是,随后将整个虚拟环境目录发送给其他人可以保证执行你的代码的确切环境。如果你不想运行 Docker 服务器,这可以作为 Docker 容器的一种替代方案。
进一步阅读
确实,还有其他工具可以帮助我们整洁地部署项目。前面提到的 Docker 可以是其中之一。Python 标准库中的 zipapp 包也是一个有趣的工具。如果你想深入了解,下面是关于这个主题的资源。
文章
-
Python 教程,第六章,模块
-
关于各种与 venv 相关的包 在 StackOverflow 上的问题
APIs 和软件
-
venv 来自 Python 标准库
总结
在本教程中,你已经看到如何确信地完成我们的项目并交付给另一个用户来运行。具体来说,你学到了:
-
将一组 Python 脚本变成模块的最小改动
-
如何将一个模块转换成用于
pip的包 -
Python 中虚拟环境的概念及其使用方法
Python 中装饰器的温和介绍
原文:
machinelearningmastery.com/a-gentle-introduction-to-decorators-in-python/
在编写代码时,无论我们是否意识到,我们常常会遇到装饰器设计模式。这是一种编程技术,可以在不修改类或函数的情况下扩展它们的功能。装饰器设计模式允许我们轻松混合和匹配扩展。Python 具有根植于装饰器设计模式的装饰器语法。了解如何制作和使用装饰器可以帮助你编写更强大的代码。
在这篇文章中,你将发现装饰器模式和 Python 的函数装饰器。
完成本教程后,你将学到:
-
什么是装饰器模式,为什么它有用
-
Python 的函数装饰器及其使用方法
通过我的新书 《Python 机器学习》,逐步教程 和所有示例的 Python 源代码 文件来快速启动你的项目。
Python 中装饰器的温和介绍
图片由 Olya Kobruseva 提供。保留部分权利。
概述
本教程分为四部分:
-
什么是装饰器模式,为什么它有用?
-
Python 中的函数装饰器
-
装饰器的使用案例
-
一些实用的装饰器示例
什么是装饰器模式,为什么它有用?
装饰器模式是一种软件设计模式,允许我们动态地向类添加功能,而无需创建子类并影响同一类的其他对象的行为。通过使用装饰器模式,我们可以轻松生成我们可能需要的不同功能排列,而无需创建指数增长数量的子类,从而使我们的代码变得越来越复杂和臃肿。
装饰器通常作为我们想要实现的主要接口的子接口来实现,并存储主要接口类型的对象。然后,它将通过覆盖原始接口中的方法并调用存储对象的方法来修改它希望添加某些功能的方法。
装饰器模式的 UML 类图
上图是装饰器设计模式的 UML 类图。装饰器抽象类包含一个OriginalInterface类型的对象;这是装饰器将修改其功能的对象。要实例化我们的具体DecoratorClass,我们需要传入一个实现了OriginalInterface的具体类,然后当我们调用DecoratorClass.method1()方法时,我们的DecoratorClass应修改该对象的method1()的输出。
然而,通过 Python,我们能够简化许多这些设计模式,因为动态类型以及函数和类是头等对象。虽然在不改变实现的情况下修改类或函数仍然是装饰器的关键思想,但我们将在下面探讨 Python 的装饰器语法。
Python 中的函数装饰器
函数装饰器是 Python 中一个极其有用的功能。它建立在函数和类在 Python 中是头等对象的概念之上。
让我们考虑一个简单的例子,即调用一个函数两次。由于 Python 函数是对象,并且我们可以将函数作为参数传递给另一个函数,因此这个任务可以如下完成:
def repeat(fn):
fn()
fn()
def hello_world():
print("Hello world!")
repeat(hello_world)
同样,由于 Python 函数是对象,我们可以创建一个函数来返回另一个函数,即执行另一个函数两次。这可以如下完成:
def repeat_decorator(fn):
def decorated_fn():
fn()
fn()
# returns a function
return decorated_fn
def hello_world():
print ("Hello world!")
hello_world_twice = repeat_decorator(hello_world)
# call the function
hello_world_twice()
上述repeat_decorator()返回的函数是在调用时创建的,因为它依赖于提供的参数。在上述代码中,我们将hello_world函数作为参数传递给repeat_decorator()函数,它返回decorated_fn函数,该函数被分配给hello_world_twice。之后,我们可以调用hello_world_twice(),因为它现在是一个函数。
装饰器模式的理念在这里适用。但我们不需要显式地定义接口和子类。事实上,hello_world是在上述示例中定义为一个函数的名称。没有什么阻止我们将这个名称重新定义为其他名称。因此我们也可以这样做:
def repeat_decorator(fn):
def decorated_fn():
fn()
fn()
# returns a function
return decorated_fn
def hello_world():
print ("Hello world!")
hello_world = repeat_decorator(hello_world)
# call the function
hello_world()
也就是说,我们不是将新创建的函数分配给hello_world_twice,而是覆盖了hello_world。虽然hello_world的名称被重新分配给另一个函数,但之前的函数仍然存在,只是不对我们公开。
实际上,上述代码在功能上等同于以下代码:
# function decorator that calls the function twice
def repeat_decorator(fn):
def decorated_fn():
fn()
fn()
# returns a function
return decorated_fn
# using the decorator on hello_world function
@repeat_decorator
def hello_world():
print ("Hello world!")
# call the function
hello_world()
在上述代码中,@repeat_decorator在函数定义之前意味着将函数传递给repeat_decorator()并将其名称重新分配给输出。也就是说,相当于hello_world = repeat_decorator(hello_world)。@行是 Python 中的装饰器语法。
注意: @ 语法在 Java 中也被使用,但含义不同,它是注解,基本上是元数据而不是装饰器。
我们还可以实现接受参数的装饰器,但这会稍微复杂一些,因为我们需要再多一层嵌套。如果我们扩展上面的例子以定义重复函数调用的次数:
def repeat_decorator(num_repeats = 2):
# repeat_decorator should return a function that's a decorator
def inner_decorator(fn):
def decorated_fn():
for i in range(num_repeats):
fn()
# return the new function
return decorated_fn
# return the decorator that actually takes the function in as the input
return inner_decorator
# use the decorator with num_repeats argument set as 5 to repeat the function call 5 times
@repeat_decorator(5)
def hello_world():
print("Hello world!")
# call the function
hello_world()
repeat_decorator() 接受一个参数并返回一个函数,这个函数是 hello_world 函数的实际装饰器(即,调用 repeat_decorator(5) 返回的是 inner_decorator,其中本地变量 num_repeats = 5 被设置)。上述代码将打印如下内容:
Hello world!
Hello world!
Hello world!
Hello world!
Hello world!
在我们结束本节之前,我们应该记住,装饰器不仅可以应用于函数,也可以应用于类。由于 Python 中的类也是一个对象,我们可以用类似的方式重新定义一个类。
想开始学习 Python 机器学习吗?
现在就来获取我的免费 7 天电子邮件速成课程(附有示例代码)。
点击注册,并免费获得课程的 PDF 电子书版本。
装饰器的使用案例
Python 中的装饰器语法使得装饰器的使用变得更简单。我们使用装饰器的原因有很多,其中一个最常见的用例是隐式地转换数据。例如,我们可以定义一个假设所有操作都基于 numpy 数组的函数,然后创建一个装饰器来确保这一点,通过修改输入:
# function decorator to ensure numpy input
def ensure_numpy(fn):
def decorated_function(data):
# converts input to numpy array
array = np.asarray(data)
# calls fn on input numpy array
return fn(array)
return decorated_function
我们可以进一步修改装饰器,通过调整函数的输出,例如对浮点值进行四舍五入:
# function decorator to ensure numpy input
# and round off output to 4 decimal places
def ensure_numpy(fn):
def decorated_function(data):
array = np.asarray(data)
output = fn(array)
return np.around(output, 4)
return decorated_function
让我们考虑一个求数组和的例子。一个 numpy 数组有内置的 sum() 方法,pandas DataFrame 也是如此。但是,后者是对列求和,而不是对所有元素求和。因此,一个 numpy 数组会得到一个浮点值的和,而 DataFrame 则会得到一个值的向量。但通过上述装饰器,我们可以编写一个函数,使得在这两种情况下都能得到一致的输出:
import numpy as np
import pandas as pd
# function decorator to ensure numpy input
# and round off output to 4 decimal places
def ensure_numpy(fn):
def decorated_function(data):
array = np.asarray(data)
output = fn(array)
return np.around(output, 4)
return decorated_function
@ensure_numpy
def numpysum(array):
return array.sum()
x = np.random.randn(10,3)
y = pd.DataFrame(x, columns=["A", "B", "C"])
# output of numpy .sum() function
print("x.sum():", x.sum())
print()
# output of pandas .sum() funuction
print("y.sum():", y.sum())
print(y.sum())
print()
# calling decorated numpysum function
print("numpysum(x):", numpysum(x))
print("numpysum(y):", numpysum(y))
运行上述代码会得到如下输出:
x.sum(): 0.3948331694737762
y.sum(): A -1.175484
B 2.496056
C -0.925739
dtype: float64
A -1.175484
B 2.496056
C -0.925739
dtype: float64
numpysum(x): 0.3948
numpysum(y): 0.3948
这是一个简单的例子。但是想象一下,如果我们定义一个新函数来计算数组中元素的标准差。我们可以简单地使用相同的装饰器,这样函数也会接受 pandas DataFrame。因此,所有的输入处理代码都被移到了装饰器中。这就是我们如何高效重用代码的方法。
一些实际的装饰器示例
既然我们学习了 Python 中的装饰器语法,那我们来看看可以用它做些什么吧!
备忘录化
有些函数调用我们会重复进行,但这些值很少甚至几乎不变。这可能是对数据相对静态的服务器的调用,或者作为动态编程算法或计算密集型数学函数的一部分。我们可能想要备忘录化这些函数调用,即将它们的输出值存储在虚拟备忘录中以便后续重用。
装饰器是实现备忘录化函数的最佳方式。我们只需要记住函数的输入和输出,但保持函数的行为不变。下面是一个例子:
import pickle
import hashlib
MEMO = {} # To remember the function input and output
def memoize(fn):
def _deco(*args, **kwargs):
# pickle the function arguments and obtain hash as the store keys
key = (fn.__name__, hashlib.md5(pickle.dumps((args, kwargs), 4)).hexdigest())
# check if the key exists
if key in MEMO:
ret = pickle.loads(MEMO[key])
else:
ret = fn(*args, **kwargs)
MEMO[key] = pickle.dumps(ret)
return ret
return _deco
@memoize
def fibonacci(n):
if n in [0, 1]:
return n
else:
return fibonacci(n-1) + fibonacci(n-2)
print(fibonacci(40))
print(MEMO)
在这个示例中,我们实现了memoize()函数以便与全局字典MEMO一起工作,使得函数名与参数组成键,函数的返回值成为值。当调用函数时,装饰器会检查对应的键是否存在于MEMO中,如果存在,则返回存储的值。否则,将调用实际的函数,并将其返回值添加到字典中。
我们使用pickle来序列化输入和输出,并使用hashlib来创建输入的哈希,因为并不是所有东西都可以作为 Python 字典的键(例如,list是不可哈希的类型,因此不能作为键)。将任何任意结构序列化为字符串可以克服这个问题,并确保返回数据是不可变的。此外,对函数参数进行哈希处理可以避免在字典中存储异常长的键(例如,当我们将一个巨大的 numpy 数组传递给函数时)。
上述示例使用fibonacci()演示了记忆化的强大功能。调用fibonacci(n)将生成第 n 个斐波那契数。运行上述示例将产生以下输出,其中我们可以看到第 40 个斐波那契数是 102334155,以及字典MEMO是如何用于存储对函数的不同调用的。
102334155
{('fibonacci', '635f1664f168e2a15b8e43f20d45154b'): b'\x80\x04K\x01.',
('fibonacci', 'd238998870ae18a399d03477dad0c0a8'): b'\x80\x04K\x00.',
('fibonacci', 'dbed6abf8fcf4beec7fc97f3170de3cc'): b'\x80\x04K\x01.',
...
('fibonacci', 'b9954ff996a4cd0e36fffb09f982b08e'): b'\x80\x04\x95\x06\x00\x00\x00\x00\x00\x00\x00J)pT\x02.',
('fibonacci', '8c7aba62def8063cf5afe85f42372f0d'): b'\x80\x04\x95\x06\x00\x00\x00\x00\x00\x00\x00J\xa2\x0e\xc5\x03.',
('fibonacci', '6de8535f23d756de26959b4d6e1f66f6'): b'\x80\x04\x95\x06\x00\x00\x00\x00\x00\x00\x00J\xcb~\x19\x06.'}
你可以尝试去掉上述代码中的@memoize行。你会发现程序运行时间显著增加(因为每次函数调用都会调用两个额外的函数调用,因此它的运行复杂度是 O(2^n),而记忆化情况下为 O(n)),或者你可能会遇到内存不足的问题。
记忆化对那些输出不经常变化的昂贵函数非常有帮助,例如,下面的函数从互联网读取一些股市数据:
...
import pandas_datareader as pdr
@memoize
def get_stock_data(ticker):
# pull data from stooq
df = pdr.stooq.StooqDailyReader(symbols=ticker, start="1/1/00", end="31/12/21").read()
return df
#testing call to function
import cProfile as profile
import pstats
for i in range(1, 3):
print(f"Run {i}")
run_profile = profile.Profile()
run_profile.enable()
get_stock_data("^DJI")
run_profile.disable()
pstats.Stats(run_profile).print_stats(0)
如果实现正确,第一次调用get_stock_data()应该会更昂贵,而后续调用则会便宜得多。上述代码片段的输出结果是:
Run 1
17492 function calls (17051 primitive calls) in 1.452 seconds
Run 2
221 function calls (218 primitive calls) in 0.001 seconds
如果你正在使用 Jupyter notebook,这特别有用。如果需要下载一些数据,将其包装在 memoize 装饰器中。由于开发机器学习项目意味着多次更改代码以查看结果是否有所改善,使用记忆化下载函数可以节省大量不必要的等待时间。
你可以通过将数据存储在数据库中(例如,像 GNU dbm 这样的键值存储或像 memcached 或 Redis 这样的内存数据库)来创建一个更强大的记忆化装饰器。但如果你只需要上述功能,Python 3.2 或更高版本的内置库functools中已经提供了装饰器lru_cache,因此你不需要自己编写:
import functools
import pandas_datareader as pdr
# memoize using lru_cache
@functools.lru_cache
def get_stock_data(ticker):
# pull data from stooq
df = pdr.stooq.StooqDailyReader(symbols=ticker, start="1/1/00", end="31/12/21").read()
return df
# testing call to function
import cProfile as profile
import pstats
for i in range(1, 3):
print(f"Run {i}")
run_profile = profile.Profile()
run_profile.enable()
get_stock_data("^DJI")
run_profile.disable()
pstats.Stats(run_profile).print_stats(0)
注意: lru_cache实现了 LRU 缓存,它将其大小限制为对函数的最新调用(默认 128)。在 Python 3.9 中,还有一个@functools.cache,其大小无限制,不进行 LRU 清除。
函数目录
另一个我们可能希望考虑使用函数装饰器的例子是用于在目录中注册函数。它允许我们将函数与字符串关联,并将这些字符串作为其他函数的参数传递。这是构建一个允许用户提供插件的系统的开始。让我们用一个例子来说明。以下是一个装饰器和我们稍后将使用的函数activate()。假设以下代码保存于文件activation.py中:
# activation.py
ACTIVATION = {}
def register(name):
def decorator(fn):
# assign fn to "name" key in ACTIVATION
ACTIVATION[name] = fn
# return fn unmodified
return fn
return decorator
def activate(x, kind):
try:
fn = ACTIVATION[kind]
return fn(x)
except KeyError:
print("Activation function %s undefined" % kind)
在上面的代码中定义了register装饰器之后,我们现在可以使用它来注册函数并将字符串与之关联。让我们来看一下funcs.py文件:
# funcs.py
from activation import register
import numpy as np
@register("relu")
def relu(x):
return np.where(x>0, x, 0)
@register("sigmoid")
def sigm(x):
return 1/(1+np.exp(-x))
@register("tanh")
def tanh(x):
return np.tanh(x)
我们通过在ACTIVATION字典中建立这种关联,将“relu”,“sigmoid”和“tanh”函数注册到各自的字符串。
现在,让我们看看如何使用我们新注册的函数。
import numpy as np
from activation import activate
# create a random matrix
x = np.random.randn(5,3)
print(x)
# try ReLU activation on the matrix
relu_x = activate(x, "relu")
print(relu_x)
# load the functions, and call ReLU activation again
import funcs
relu_x = activate(x, "relu")
print(relu_x)
这将给我们输出:
[[-0.81549502 -0.81352867 1.41539545]
[-0.28782853 -1.59323543 -0.19824959]
[ 0.06724466 -0.26622761 -0.41893662]
[ 0.47927331 -1.84055276 -0.23147207]
[-0.18005588 -1.20837815 -1.34768876]]
Activation function relu undefined
None
[[0\. 0\. 1.41539545]
[0\. 0\. 0\. ]
[0.06724466 0\. 0\. ]
[0.47927331 0\. 0\. ]
[0\. 0\. 0\. ]]
请注意,在我们到达import func这一行之前,ReLU 激活函数并不存在。因此调用该函数会打印错误信息,结果为None。然后在我们运行那一行import之后,我们就像加载插件模块一样加载了那些定义的函数。之后同样的函数调用给出了我们预期的结果。
请注意,我们从未显式调用模块func中的任何内容,也没有修改activate()的调用。仅仅导入func就使得那些新函数注册并扩展了activate()的功能。使用这种技术允许我们在开发非常大的系统时,只关注一小部分,而不必担心其他部分的互操作性。如果没有注册装饰器和函数目录,添加新的激活函数将需要修改每一个使用激活的函数。
如果你对 Keras 很熟悉,你应该能将上述内容与以下语法产生共鸣:
layer = keras.layers.Dense(128, activation="relu")
model.compile(loss="sparse_categorical_crossentropy",
optimizer="adam",
metrics=["sparse_categorical_accuracy"])
Keras 几乎使用类似性质的装饰器定义了所有组件。因此我们可以通过名称引用构建块。如果没有这种机制,我们必须一直使用以下语法,这让我们需要记住很多组件的位置:
layer = keras.layers.Dense(128, activation=keras.activations.relu)
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(),
optimizer=keras.optimizers.Adam(),
metrics=[keras.metrics.SparseCategoricalAccuracy()])
进一步阅读
本节提供了更多关于该主题的资源,如果你希望深入了解。
文章
-
Python 语言参考,第 8.7 节,函数定义
书籍
- Fluent Python,第二版,作者 Luciano Ramalho
API
- Python 标准库中的functools 模块
总结
在这篇文章中,你了解了装饰器设计模式和 Python 的装饰器语法。你还看到了一些装饰器的具体使用场景,这些可以帮助你的 Python 程序运行得更快或更易扩展。
具体来说,你学习了:
-
装饰器模式的概念以及 Python 中的装饰器语法
-
如何在 Python 中实现装饰器,以便使用装饰器语法
-
使用装饰器来适配函数输入输出、实现记忆化以及在目录中注册函数
Python 序列化的温和介绍
原文:
machinelearningmastery.com/a-gentle-introduction-to-serialization-for-python/
序列化是指将数据对象(例如,Python 对象、Tensorflow 模型)转换为一种格式,使我们可以存储或传输数据,然后在需要时使用反序列化的逆过程重新创建该对象。
数据的序列化有不同的格式,如 JSON、XML、HDF5 和 Python 的 pickle,用于不同的目的。例如,JSON 返回人类可读的字符串形式,而 Python 的 pickle 库则可以返回字节数组。
在这篇文章中,你将学习如何使用 Python 中的两个常见序列化库来序列化数据对象(即 pickle 和 HDF5),例如字典和 Tensorflow 模型,以便于存储和传输。
完成本教程后,你将了解:
-
Python 中的序列化库,如 pickle 和 h5py
-
在 Python 中序列化诸如字典和 Tensorflow 模型的对象
-
如何使用序列化进行记忆化以减少函数调用
快速启动你的项目,通过我的新书 Python for Machine Learning,包括 逐步教程 和所有示例的 Python 源代码 文件。
Python 序列化的温和介绍。图片来源 little plant。版权所有
概述
本教程分为四个部分;它们是:
-
什么是序列化,为什么我们要进行序列化?
-
使用 Python 的 pickle 库
-
在 Python 中使用 HDF5
-
不同序列化方法的比较
什么是序列化,我们为什么要关心它?
想一想如何存储一个整数;你会如何将其存储在文件中或传输?这很简单!我们可以直接将整数写入文件中,然后存储或传输这个文件。
但是现在,如果我们考虑存储一个 Python 对象(例如,一个 Python 字典或一个 Pandas DataFrame),它有一个复杂的结构和许多属性(例如,DataFrame 的列和索引,以及每列的数据类型)呢?你会如何将它存储为一个文件或传输到另一台计算机上?
这就是序列化发挥作用的地方!
序列化是将对象转换为可以存储或传输的格式的过程。在传输或存储序列化数据后,我们能够稍后重建对象,并获得完全相同的结构/对象,这使得我们可以在之后继续使用存储的对象,而不必从头开始重建对象。
在 Python 中,有许多不同的序列化格式可供选择。一个跨多种语言的常见示例是 JSON 文件格式,它是可读的并允许我们存储字典并以相同的结构重新创建它。但 JSON 只能存储基本结构,如列表和字典,并且只能保留字符串和数字。我们不能要求 JSON 记住数据类型(例如,numpy float32 与 float64)。它也无法区分 Python 元组和列表。
更强大的序列化格式存在。接下来,我们将探讨两个常见的 Python 序列化库,即 pickle 和 h5py。
使用 Python 的 Pickle 库
pickle 模块是 Python 标准库的一部分,实现了序列化(pickling)和反序列化(unpickling)Python 对象的方法。
要开始使用 pickle,请在 Python 中导入它:
import pickle
之后,为了序列化一个 Python 对象(如字典)并将字节流存储为文件,我们可以使用 pickle 的 dump() 方法。
test_dict = {"Hello": "World!"}
with open("test.pickle", "wb") as outfile:
# "wb" argument opens the file in binary mode
pickle.dump(test_dict, outfile)
代表test_dict的字节流现在存储在文件“test.pickle”中!
要恢复原始对象,我们使用 pickle 的 load() 方法从文件中读取序列化的字节流。
with open("test.pickle", "rb") as infile:
test_dict_reconstructed = pickle.load(infile)
警告: 仅从您信任的来源反序列化数据,因为在反序列化过程中可能会执行任意恶意代码。
将它们结合起来,以下代码帮助您验证 pickle 可以恢复相同的对象:
import pickle
# A test object
test_dict = {"Hello": "World!"}
# Serialization
with open("test.pickle", "wb") as outfile:
pickle.dump(test_dict, outfile)
print("Written object", test_dict)
# Deserialization
with open("test.pickle", "rb") as infile:
test_dict_reconstructed = pickle.load(infile)
print("Reconstructed object", test_dict_reconstructed)
if test_dict == test_dict_reconstructed:
print("Reconstruction success")
除了将序列化的对象写入 pickle 文件外,我们还可以使用 pickle 的 dumps() 函数在 Python 中获取序列化为字节数组类型的对象:
test_dict_ba = pickle.dumps(test_dict) # b'\x80\x04\x95\x15…
同样,我们可以使用 pickle 的 load 方法将字节数组类型转换回原始对象:
test_dict_reconstructed_ba = pickle.loads(test_dict_ba)
pickle 的一个有用功能是它可以序列化几乎任何 Python 对象,包括用户定义的对象,如下所示:
import pickle
class NewClass:
def __init__(self, data):
print(data)
self.data = data
# Create an object of NewClass
new_class = NewClass(1)
# Serialize and deserialize
pickled_data = pickle.dumps(new_class)
reconstructed = pickle.loads(pickled_data)
# Verify
print("Data from reconstructed object:", reconstructed.data)
上述代码将打印以下内容:
1
Data from reconstructed object: 1
注意,在调用 pickle.loads() 时,类构造函数中的 print 语句没有执行。这是因为它重建了对象,而不是重新创建它。
pickle 甚至可以序列化 Python 函数,因为函数在 Python 中是一级对象:
import pickle
def test():
return "Hello world!"
# Serialize and deserialize
pickled_function = pickle.dumps(test)
reconstructed_function = pickle.loads(pickled_function)
# Verify
print (reconstructed_function()) #prints “Hello, world!”
因此,我们可以利用 pickle 来保存我们的工作。例如,从 Keras 或 scikit-learn 训练的模型可以通过 pickle 序列化并在之后加载,而不是每次使用时都重新训练模型。以下示例展示了我们如何使用 Keras 构建一个 LeNet5 模型来识别 MNIST 手写数字,然后使用 pickle 序列化训练好的模型。之后,我们可以在不重新训练的情况下重建模型,它应该会产生与原始模型完全相同的结果:
import pickle
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, AveragePooling2D, Dropout, Flatten
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
# Load MNIST digits
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Reshape data to (n_samples, height, wiedth, n_channel)
X_train = np.expand_dims(X_train, axis=3).astype("float32")
X_test = np.expand_dims(X_test, axis=3).astype("float32")
# One-hot encode the output
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
# LeNet5 model
model = Sequential([
Conv2D(6, (5,5), input_shape=(28,28,1), padding="same", activation="tanh"),
AveragePooling2D((2,2), strides=2),
Conv2D(16, (5,5), activation="tanh"),
AveragePooling2D((2,2), strides=2),
Conv2D(120, (5,5), activation="tanh"),
Flatten(),
Dense(84, activation="tanh"),
Dense(10, activation="softmax")
])
# Train the model
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
earlystopping = EarlyStopping(monitor="val_loss", patience=4, restore_best_weights=True)
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=100, batch_size=32, callbacks=[earlystopping])
# Evaluate the model
print(model.evaluate(X_test, y_test, verbose=0))
# Pickle to serialize and deserialize
pickled_model = pickle.dumps(model)
reconstructed = pickle.loads(pickled_model)
# Evaluate again
print(reconstructed.evaluate(X_test, y_test, verbose=0))
上述代码将生成如下输出。请注意,原始模型和重建模型的评估分数在最后两行中完全一致:
Epoch 1/100
1875/1875 [==============================] - 15s 7ms/step - loss: 0.1517 - accuracy: 0.9541 - val_loss: 0.0958 - val_accuracy: 0.9661
Epoch 2/100
1875/1875 [==============================] - 15s 8ms/step - loss: 0.0616 - accuracy: 0.9814 - val_loss: 0.0597 - val_accuracy: 0.9822
Epoch 3/100
1875/1875 [==============================] - 16s 8ms/step - loss: 0.0493 - accuracy: 0.9846 - val_loss: 0.0449 - val_accuracy: 0.9853
Epoch 4/100
1875/1875 [==============================] - 17s 9ms/step - loss: 0.0394 - accuracy: 0.9876 - val_loss: 0.0496 - val_accuracy: 0.9838
Epoch 5/100
1875/1875 [==============================] - 17s 9ms/step - loss: 0.0320 - accuracy: 0.9898 - val_loss: 0.0394 - val_accuracy: 0.9870
Epoch 6/100
1875/1875 [==============================] - 16s 9ms/step - loss: 0.0294 - accuracy: 0.9908 - val_loss: 0.0373 - val_accuracy: 0.9872
Epoch 7/100
1875/1875 [==============================] - 21s 11ms/step - loss: 0.0252 - accuracy: 0.9921 - val_loss: 0.0370 - val_accuracy: 0.9879
Epoch 8/100
1875/1875 [==============================] - 18s 10ms/step - loss: 0.0223 - accuracy: 0.9931 - val_loss: 0.0386 - val_accuracy: 0.9880
Epoch 9/100
1875/1875 [==============================] - 15s 8ms/step - loss: 0.0219 - accuracy: 0.9930 - val_loss: 0.0418 - val_accuracy: 0.9871
Epoch 10/100
1875/1875 [==============================] - 15s 8ms/step - loss: 0.0162 - accuracy: 0.9950 - val_loss: 0.0531 - val_accuracy: 0.9853
Epoch 11/100
1875/1875 [==============================] - 15s 8ms/step - loss: 0.0169 - accuracy: 0.9941 - val_loss: 0.0340 - val_accuracy: 0.9895
Epoch 12/100
1875/1875 [==============================] - 15s 8ms/step - loss: 0.0165 - accuracy: 0.9944 - val_loss: 0.0457 - val_accuracy: 0.9874
Epoch 13/100
1875/1875 [==============================] - 15s 8ms/step - loss: 0.0137 - accuracy: 0.9955 - val_loss: 0.0407 - val_accuracy: 0.9879
Epoch 14/100
1875/1875 [==============================] - 16s 8ms/step - loss: 0.0159 - accuracy: 0.9945 - val_loss: 0.0442 - val_accuracy: 0.9871
Epoch 15/100
1875/1875 [==============================] - 16s 8ms/step - loss: 0.0125 - accuracy: 0.9956 - val_loss: 0.0434 - val_accuracy: 0.9882
[0.0340442918241024, 0.9894999861717224]
[0.0340442918241024, 0.9894999861717224]
尽管 pickle 是一个强大的库,但它仍然有其自身的限制。例如,无法 pickle 包括数据库连接和已打开的文件句柄在内的活动连接。这个问题的根源在于重建这些对象需要 pickle 重新建立与数据库/文件的连接,这是 pickle 无法为你做的事情(因为它需要适当的凭证,超出了 pickle 的预期范围)。
想要开始使用 Python 进行机器学习吗?
现在就来参加我的免费 7 天电子邮件速成课程吧(附带示例代码)。
点击注册并获取课程的免费 PDF 电子书版本。
在 Python 中使用 HDF5
层次数据格式 5(HDF5)是一种二进制数据格式。h5py 包是一个 Python 库,提供了对 HDF5 格式的接口。根据 h5py 文档,HDF5 “允许你存储大量数值数据,并且可以轻松地使用 Numpy 对该数据进行操作。”
HDF5 能比其他序列化格式做得更好的是以文件系统的层次结构存储数据。你可以在 HDF5 中存储多个对象或数据集,就像在文件系统中保存多个文件一样。你也可以从 HDF5 中读取特定的数据集,就像从文件系统中读取一个文件而不需要考虑其他文件一样。如果你用 pickle 做这件事,每次加载或创建 pickle 文件时都需要读取和写入所有内容。因此,对于无法完全放入内存的大量数据,HDF5 是一个有利的选择。
要开始使用 h5py,你首先需要安装 h5py 库,可以使用以下命令进行安装:
pip install h5py
或者,如果你正在使用 conda 环境:
conda install h5py
接下来,我们可以开始创建我们的第一个数据集!
import h5py
with h5py.File("test.hdf5", "w") as file:
dataset = file.create_dataset("test_dataset", (100,), type="i4")
这将在文件 test.hdf5 中创建一个名为 “test_dataset” 的新数据集,形状为 (100, ),类型为 int32。h5py 的数据集遵循 Numpy 的语法,因此你可以进行切片、检索、获取形状等操作,类似于 Numpy 数组。
要检索特定索引:
dataset[0] #retrieves element at index 0 of dataset
要从索引 0 到索引 10 获取数据集的片段:
dataset[:10]
如果你在 with 语句之外初始化了 h5py 文件对象,请记得关闭文件!
要从以前创建的 HDF5 文件中读取数据,你可以以 “r” 的方式打开文件进行读取,或者以 “r+” 的方式进行读写:
with h5py.File("test.hdf5", "r") as file:
print (file.keys()) #gets names of datasets that are in the file
dataset = file["test_dataset"]
要组织你的 HDF5 文件,你可以使用组:
with h5py.File("test.hdf5", "w") as file:
# creates new group_1 in file
file.create_group("group_1")
group1 = file["group_1"]
# creates dataset inside group1
group1.create_dataset("dataset1", shape=(10,))
# to access the dataset
dataset = file["group_1"]["dataset1"]
另一种创建组和文件的方式是通过指定要创建的数据集的路径,h5py 也会在该路径上创建组(如果它们不存在):
with h5py.File("test.hdf5", "w") as file:
# creates dataset inside group1
file.create_dataset("group1/dataset1", shape=(10,))
这两段代码片段都会在未创建 group1 的情况下创建它,然后在 group1 中创建 dataset1。
在 Tensorflow 中的 HDF5
要在 Tensorflow Keras 中保存模型为 HDF5 格式,我们可以使用模型的 save() 函数,并将文件名指定为 .h5 扩展名,如下所示:
from tensorflow import keras
# Create model
model = keras.models.Sequential([
keras.layers.Input(shape=(10,)),
keras.layers.Dense(1)
])
model.compile(optimizer="adam", loss="mse")
# using the .h5 extension in the file name specifies that the model
# should be saved in HDF5 format
model.save("my_model.h5")
要加载存储的 HDF5 模型,我们也可以直接使用 Keras 中的函数:
...
model = keras.models.load_model("my_model.h5")
# to check that the model has been successfully reconstructed
print(model.summary)
我们不希望为 Keras 模型使用 pickle 的一个原因是,我们需要一种更灵活的格式,不受特定版本 Keras 的限制。如果我们升级了 Tensorflow 版本,模型对象可能会改变,而 pickle 可能无法给我们一个可工作的模型。另一个原因是保留模型的必要数据。例如,如果我们检查上面创建的 HDF5 文件 my_model.h5,我们可以看到其中存储了以下内容:
/
/model_weights
/model_weights/dense
/model_weights/dense/dense
/model_weights/dense/dense/bias:0
/model_weights/dense/dense/kernel:0
/model_weights/top_level_model_weights
因此,Keras 仅选择对重建模型至关重要的数据。训练好的模型将包含更多数据集,即 /optimizer_weights/ 除了 /model_weights/。Keras 将恢复模型并适当地恢复权重,以给我们一个功能相同的模型。
以上面的例子为例。我们的模型保存在 my_model.h5 中。我们的模型是一个单层的全连接层,我们可以通过以下方式找出该层的内核:
import h5py
with h5py.File("my_model.h5", "r") as infile:
print(infile["/model_weights/dense/dense/kernel:0"][:])
因为我们没有为任何事情训练我们的网络,所以它会给我们初始化层的随机矩阵:
[[ 0.6872471 ]
[-0.51016176]
[-0.5604881 ]
[ 0.3387223 ]
[ 0.52146655]
[-0.6960067 ]
[ 0.38258582]
[-0.05564564]
[ 0.1450575 ]
[-0.3391946 ]]
并且在 HDF5 中,元数据存储在数据旁边。Keras 以 JSON 格式在元数据中存储了网络的架构。因此,我们可以按以下方式复现我们的网络架构:
import json
import h5py
with h5py.File("my_model.h5", "r") as infile:
for key in infile.attrs.keys():
formatted = infile.attrs[key]
if key.endswith("_config"):
formatted = json.dumps(json.loads(formatted), indent=4)
print(f"{key}: {formatted}")
这会产生:
backend: tensorflow
keras_version: 2.7.0
model_config: {
"class_name": "Sequential",
"config": {
"name": "sequential",
"layers": [
{
"class_name": "InputLayer",
"config": {
"batch_input_shape": [
null,
10
],
"dtype": "float32",
"sparse": false,
"ragged": false,
"name": "input_1"
}
},
{
"class_name": "Dense",
"config": {
"name": "dense",
"trainable": true,
"dtype": "float32",
"units": 1,
"activation": "linear",
"use_bias": true,
"kernel_initializer": {
"class_name": "GlorotUniform",
"config": {
"seed": null
}
},
"bias_initializer": {
"class_name": "Zeros",
"config": {}
},
"kernel_regularizer": null,
"bias_regularizer": null,
"activity_regularizer": null,
"kernel_constraint": null,
"bias_constraint": null
}
}
]
}
}
training_config: {
"loss": "mse",
"metrics": null,
"weighted_metrics": null,
"loss_weights": null,
"optimizer_config": {
"class_name": "Adam",
"config": {
"name": "Adam",
"learning_rate": 0.001,
"decay": 0.0,
"beta_1": 0.9,
"beta_2": 0.999,
"epsilon": 1e-07,
"amsgrad": false
}
}
}
模型配置(即我们神经网络的架构)和训练配置(即我们传递给 compile() 函数的参数)存储为一个 JSON 字符串。在上面的代码中,我们使用 json 模块重新格式化它,以便更容易阅读。建议将您的模型保存为 HDF5,而不仅仅是您的 Python 代码,因为正如我们在上面看到的,它包含比代码更详细的网络构建信息。
比较不同序列化方法之间的差异
在上文中,我们看到 pickle 和 h5py 如何帮助序列化我们的 Python 数据。
我们可以使用 pickle 序列化几乎任何 Python 对象,包括用户定义的对象和函数。但 pickle 不是语言通用的。您不能在 Python 之外反序列化它。到目前为止,甚至有 6 个版本的 pickle,旧版 Python 可能无法消费新版本的 pickle 数据。
相反,HDF5 是跨平台的,并且与其他语言如 Java 和 C++ 兼容良好。在 Python 中,h5py 库实现了 Numpy 接口,以便更轻松地操作数据。数据可以在不同语言中访问,因为 HDF5 格式仅支持 Numpy 的数据类型,如浮点数和字符串。我们不能将任意对象(如 Python 函数)存储到 HDF5 中。
进一步阅读
本节提供了更多关于此主题的资源,如果您希望深入了解。
文章
-
C# 编程指南中的序列化,
docs.microsoft.com/en-us/dotnet/csharp/programming-guide/concepts/serialization/ -
保存和加载 Keras 模型,
www.tensorflow.org/guide/keras/save_and_serialize
库
API
-
Tensorflow tf.keras.layers.serialize,
www.tensorflow.org/api_docs/python/tf/keras/layers/serialize -
Tensorflow tf.keras.models.load_model,
www.tensorflow.org/api_docs/python/tf/keras/models/load_model -
Tensorflow tf.keras.models.save_model,
www.tensorflow.org/api_docs/python/tf/keras/models/save_model
总结
在本篇文章中,你将了解什么是序列化以及如何在 Python 中使用库来序列化 Python 对象,例如字典和 Tensorflow Keras 模型。你还学到了两个 Python 序列化库(pickle、h5py)的优缺点。
具体来说,你学到了:
-
什么是序列化,以及它的用途
-
如何在 Python 中开始使用 pickle 和 h5py 序列化库
-
不同序列化方法的优缺点
Python 单元测试的温和介绍
原文:
machinelearningmastery.com/a-gentle-introduction-to-unit-testing-in-python/
单元测试是一种测试软件的方法,关注代码中最小的可测试单元,并测试其正确性。通过单元测试,我们可以验证代码的每个部分,包括可能不对用户公开的辅助函数,是否正常工作并按预期执行。
这个理念是,我们独立检查程序中的每一个小部分,以确保它正常工作。这与回归测试和集成测试形成对比,后者测试程序的不同部分是否协同工作并按预期执行。
在这篇文章中,你将发现如何使用两个流行的单元测试框架:内置的 PyUnit 框架和 PyTest 框架来实现 Python 中的单元测试。
完成本教程后,你将知道:
-
Python 中的单元测试库,如 PyUnit 和 PyTest
-
通过使用单元测试检查预期的函数行为
通过我的新书 《Python 机器学习》,启动你的项目,包括逐步教程和所有示例的Python 源代码文件。
Python 单元测试的温和介绍
图片由Bee Naturalles提供。版权所有。
概述
本教程分为五个部分;它们是:
-
什么是单元测试,它们为什么重要?
-
什么是测试驱动开发(TDD)?
-
使用 Python 内置的 PyUnit 框架
-
使用 PyTest 库
-
单元测试的实际操作
什么是单元测试,它们为什么重要?
记得在学校做数学时,完成不同的算术步骤,然后将它们组合以获得正确答案吗?想象一下,你如何检查每一步的计算是否正确,确保没有粗心的错误或写错任何东西。
现在,把这个理念扩展到代码上!我们不希望不断检查代码以静态验证其正确性,那么你会如何创建测试以确保以下代码片段实际返回矩形的面积?
def calculate_area_rectangle(width, height):
return width * height
我们可以用一些测试示例运行代码,看看它是否返回预期的输出。
这就是单元测试的理念!单元测试是检查单一代码组件的测试,通常模块化为函数,并确保其按预期执行。
单元测试是回归测试的重要组成部分,以确保在对代码进行更改后,代码仍然按预期功能运行,并帮助确保代码的稳定性。在对代码进行更改后,我们可以运行之前创建的单元测试,以确保我们对代码库其他部分的更改没有影响到现有功能。
单元测试的另一个关键好处是它们有助于轻松隔离错误。想象一下运行整个项目并收到一连串的错误。我们该如何调试代码呢?
这就是单元测试的作用。我们可以分析单元测试的输出,查看代码中的任何组件是否出现错误,并从那里开始调试。这并不是说单元测试总能帮助我们找到错误,但它在我们开始查看集成测试中的组件集成之前提供了一个更方便的起点。
在接下来的文章中,我们将展示如何通过测试 Rectangle 类中的函数来进行单元测试:
class Rectangle:
def __init__(self, width, height):
self.width = width
self.height = height
def get_area(self):
return self.width * self.height
def set_width(self, width):
self.width = width
def set_height(self, height):
self.height = height
现在我们已经了解了单元测试的意义,让我们探索一下如何将单元测试作为开发流程的一部分,以及如何在 Python 中实现它们!
测试驱动开发
测试在良好的软件开发中如此重要,以至于甚至存在一个基于测试的软件开发过程——测试驱动开发(TDD)。Robert C. Martin 提出的 TDD 三条规则是:
-
除非是为了让一个失败的单元测试通过,否则你不允许编写任何生产代码。
-
除非是为了让单元测试失败,否则你不允许编写超过必要的单元测试代码,编译失败也是失败。
-
除非是为了让一个失败的单元测试通过,否则你不允许编写比通过一个失败单元测试所需的更多生产代码。
TDD 的关键理念是,我们围绕一组我们创建的单元测试来进行软件开发,这使得单元测试成为 TDD 软件开发过程的核心。这样,你可以确保你为开发的每个组件都制定了测试。
TDD 还偏向于进行更小的测试,这意味着测试更具体,并且每次测试的组件更少。这有助于追踪错误,而且小的测试更易于阅读和理解,因为每次运行中涉及的组件更少。
这并不意味着你必须在你的项目中使用 TDD。但你可以考虑将其作为同时开发代码和测试的方法。
想要开始使用 Python 进行机器学习吗?
立即参加我的免费 7 天邮件速成课程(附示例代码)。
点击注册,还可以获得课程的免费 PDF 电子书版本。
使用 Python 内置的 PyUnit 框架
你可能会想,既然 Python 和其他语言提供了 assert 关键字,我们为什么还需要单元测试框架?单元测试框架有助于自动化测试过程,并允许我们对同一函数运行多个测试,使用不同的参数,检查预期的异常等等。
PyUnit 是 Python 内置的单元测试框架,也是 Python 版的 JUnit 测试框架。要开始编写测试文件,我们需要导入 unittest 库以使用 PyUnit:
import unittest
然后,我们可以开始编写第一个单元测试。PyUnit 中的单元测试结构为 unittest.TestCase 类的子类,我们可以重写 runTest() 方法来执行自己的单元测试,使用 unittest.TestCase 中的不同断言函数检查条件:
class TestGetAreaRectangle(unittest.TestCase):
def runTest(self):
rectangle = Rectangle(2, 3)
self.assertEqual(rectangle.get_area(), 6, "incorrect area")
这就是我们的第一个单元测试!它检查 rectangle.get_area() 方法是否返回宽度 = 2 和长度 = 3 的矩形的正确面积。我们使用 self.assertEqual 而不是简单使用 assert,以便 unittest 库允许测试运行器累积所有测试用例并生成报告。
使用 unittest.TestCase 中的不同断言函数还可以更好地测试不同的行为,例如 self.assertRaises(exception)。这允许我们检查某段代码是否产生了预期的异常。
要运行单元测试,我们在程序中调用 unittest.main(),
...
unittest.main()
由于代码返回了预期的输出,它显示测试成功运行,输出为:
.
----------------------------------------------------------------------
Ran 1 test in 0.003s
OK
完整的代码如下:
import unittest
# Our code to be tested
class Rectangle:
def __init__(self, width, height):
self.width = width
self.height = height
def get_area(self):
return self.width * self.height
def set_width(self, width):
self.width = width
def set_height(self, height):
self.height = height
# The test based on unittest module
class TestGetAreaRectangle(unittest.TestCase):
def runTest(self):
rectangle = Rectangle(2, 3)
self.assertEqual(rectangle.get_area(), 6, "incorrect area")
# run the test
unittest.main()
注意: 在上面,我们的业务逻辑 Rectangle 类和我们的测试代码 TestGetAreaRectangle 被放在一起。实际上,你可以将它们放在不同的文件中,并将业务逻辑 import 到你的测试代码中。这可以帮助你更好地管理代码。
我们还可以在 unittest.TestCase 的一个子类中嵌套多个单元测试,通过在新子类中的方法名前加上 “test” 前缀,例如:
class TestGetAreaRectangle(unittest.TestCase):
def test_normal_case(self):
rectangle = Rectangle(2, 3)
self.assertEqual(rectangle.get_area(), 6, "incorrect area")
def test_negative_case(self):
"""expect -1 as output to denote error when looking at negative area"""
rectangle = Rectangle(-1, 2)
self.assertEqual(rectangle.get_area(), -1, "incorrect negative output")
运行这段代码将给我们第一个错误:
F.
======================================================================
FAIL: test_negative_case (__main__.TestGetAreaRectangle)
expect -1 as output to denote error when looking at negative area
----------------------------------------------------------------------
Traceback (most recent call last):
File "<ipython-input-96-59b1047bb08a>", line 9, in test_negative_case
self.assertEqual(rectangle.get_area(), -1, "incorrect negative output")
AssertionError: -2 != -1 : incorrect negative output
----------------------------------------------------------------------
Ran 2 tests in 0.003s
FAILED (failures=1)
我们可以看到失败的单元测试,即 test_negative_case,如输出中突出显示的内容和 stderr 消息,因为 get_area() 没有返回我们在测试中预期的 -1。
unittest 中定义了许多不同种类的断言函数。例如,我们可以使用 TestCase 类:
def test_geq(self):
"""tests if value is greater than or equal to a particular target"""
self.assertGreaterEqual(self.rectangle.get_area(), -1)
我们甚至可以检查在执行过程中是否抛出了特定的异常:
def test_assert_raises(self):
"""using assertRaises to detect if an expected error is raised when running a particular block of code"""
with self.assertRaises(ZeroDivisionError):
a = 1 / 0
现在,我们来看看如何建立我们的测试。如果我们有一些代码需要在每个测试运行之前执行呢?我们可以重写 unittest.TestCase 中的 setUp 方法。
class TestGetAreaRectangleWithSetUp(unittest.TestCase):
def setUp(self):
self.rectangle = Rectangle(0, 0)
def test_normal_case(self):
self.rectangle.set_width(2)
self.rectangle.set_height(3)
self.assertEqual(self.rectangle.get_area(), 6, "incorrect area")
def test_negative_case(self):
"""expect -1 as output to denote error when looking at negative area"""
self.rectangle.set_width(-1)
self.rectangle.set_height(2)
self.assertEqual(self.rectangle.get_area(), -1, "incorrect negative output")
在上述代码示例中,我们重写了来自 unittest.TestCase 的 setUp() 方法,使用了我们自己的 setUp() 方法来初始化一个 Rectangle 对象。这个 setUp() 方法在每个单元测试之前运行,有助于避免在多个测试依赖相同代码来设置测试时的代码重复。这类似于 JUnit 中的 @Before 装饰器。
同样,我们还可以重写 tearDown() 方法,用于在每个测试之后执行代码。
为了在每个 TestCase 类中只运行一次该方法,我们也可以使用 setUpClass 方法,如下所示:
class TestGetAreaRectangleWithSetUp(unittest.TestCase):
@classmethod
def setUpClass(self):
self.rectangle = Rectangle(0, 0)
上述代码在每个 TestCase 中仅运行一次,而不是像 setUp 那样在每次测试运行时运行一次。
为了帮助我们组织测试并选择要运行的测试集,我们可以将测试用例汇总到测试套件中,这有助于将应一起执行的测试分组到一个对象中:
...
# loads all unit tests from TestGetAreaRectangle into a test suite
calculate_area_suite = unittest.TestLoader() \
.loadTestsFromTestCase(TestGetAreaRectangleWithSetUp)
在这里,我们还介绍了另一种通过使用 unittest.TextTestRunner 类在 PyUnit 中运行测试的方法,该类允许我们运行特定的测试套件。
runner = unittest.TextTestRunner()
runner.run(calculate_area_suite)
这与从命令行运行文件并调用 unittest.main() 的输出相同。
综合所有内容,这就是单元测试的完整脚本:
class TestGetAreaRectangleWithSetUp(unittest.TestCase):
@classmethod
def setUpClass(self):
#this method is only run once for the entire class rather than being run for each test which is done for setUp()
self.rectangle = Rectangle(0, 0)
def test_normal_case(self):
self.rectangle.set_width(2)
self.rectangle.set_height(3)
self.assertEqual(self.rectangle.get_area(), 6, "incorrect area")
def test_geq(self):
"""tests if value is greater than or equal to a particular target"""
self.assertGreaterEqual(self.rectangle.get_area(), -1)
def test_assert_raises(self):
"""using assertRaises to detect if an expected error is raised when running a particular block of code"""
with self.assertRaises(ZeroDivisionError):
a = 1 / 0
这只是 PyUnit 能做的一部分。我们还可以编写测试,查找与正则表达式匹配的异常消息或仅运行一次的 setUp/tearDown 方法(例如 setUpClass)。
使用 PyTest
PyTest 是内置 unittest 模块的替代品。要开始使用 PyTest,您首先需要安装它,可以通过以下方式进行安装:
Shell
pip install pytest
要编写测试,您只需编写以“test”为前缀的函数名,PyTest 的测试发现程序将能够找到您的测试,例如,
def test_normal_case(self):
rectangle = Rectangle(2, 3)
assert rectangle.get_area() == 6, "incorrect area"
您会注意到 PyTest 使用 Python 内置的 assert 关键字,而不是像 PyUnit 那样的一组断言函数,这可能会更方便,因为我们可以避免查找不同的断言函数。
完整的代码如下:
# Our code to be tested
class Rectangle:
def __init__(self, width, height):
self.width = width
self.height = height
def get_area(self):
return self.width * self.height
def set_width(self, width):
self.width = width
def set_height(self, height):
self.height = height
# The test function to be executed by PyTest
def test_normal_case():
rectangle = Rectangle(2, 3)
assert rectangle.get_area() == 6, "incorrect area"
将其保存到文件 test_file.py 中后,我们可以通过以下方式运行 PyTest 单元测试:
Shell
python -m pytest test_file.py
这将给我们以下输出:
=================== test session starts ====================
platform darwin -- Python 3.9.9, pytest-7.0.1, pluggy-1.0.0
rootdir: /Users/MLM
plugins: anyio-3.4.0, typeguard-2.13.2
collected 1 item
test_file.py . [100%]
==================== 1 passed in 0.01s =====================
您可能会注意到,在 PyUnit 中,我们需要通过运行程序或调用 unittest.main() 来触发测试例程。但在 PyTest 中,我们只需将文件传递给模块。PyTest 模块将收集所有以 test 为前缀定义的函数并逐一调用它们。然后,它将验证 assert 语句是否引发了任何异常。这样可以更方便地让测试保持与业务逻辑一起。
PyTest 还支持将函数归组到类中,但类名应该以“Test”作为前缀(大写的 T),例如,
class TestGetAreaRectangle:
def test_normal_case(self):
rectangle = Rectangle(2, 3)
assert rectangle.get_area() == 6, "incorrect area"
def test_negative_case(self):
"""expect -1 as output to denote error when looking at negative area"""
rectangle = Rectangle(-1, 2)
assert rectangle.get_area() == -1, "incorrect negative output"
使用 PyTest 运行此测试将生成以下输出:
=================== test session starts ====================
platform darwin -- Python 3.9.9, pytest-7.0.1, pluggy-1.0.0
rootdir: /Users/MLM
plugins: anyio-3.4.0, typeguard-2.13.2
collected 2 items
test_code.py .F [100%]
========================= FAILURES =========================
_________ TestGetAreaRectangle.test_negative_case __________
self = <test_code.TestGetAreaRectangle object at 0x10f5b3fd0>
def test_negative_case(self):
"""expect -1 as output to denote error when looking at negative area"""
rectangle = Rectangle(-1, 2)
> assert rectangle.get_area() == -1, "incorrect negative output"
E AssertionError: incorrect negative output
E assert -2 == -1
E + where -2 = <bound method Rectangle.get_area of <test_code.Rectangle object at 0x10f5b3df0>>()
E + where <bound method Rectangle.get_area of <test_code.Rectangle object at 0x10f5b3df0>> = <test_code.Rectangle object at 0x10f5b3df0>.get_area
unittest5.py:24: AssertionError
================= short test summary info ==================
FAILED test_code.py::TestGetAreaRectangle::test_negative_case
=============== 1 failed, 1 passed in 0.12s ================
完整的代码如下:
# Our code to be tested
class Rectangle:
def __init__(self, width, height):
self.width = width
self.height = height
def get_area(self):
return self.width * self.height
def set_width(self, width):
self.width = width
def set_height(self, height):
self.height = height
# The test functions to be executed by PyTest
class TestGetAreaRectangle:
def test_normal_case(self):
rectangle = Rectangle(2, 3)
assert rectangle.get_area() == 6, "incorrect area"
def test_negative_case(self):
"""expect -1 as output to denote error when looking at negative area"""
rectangle = Rectangle(-1, 2)
assert rectangle.get_area() == -1, "incorrect negative output"
为我们的测试实现设置和拆解代码,PyTest 具有极其灵活的 fixture 系统,其中 fixture 是具有返回值的函数。PyTest 的 fixture 系统允许在类、模块、包或会话之间共享 fixture,并且 fixture 可以将其他 fixture 作为参数调用。
在这里,我们简单介绍了 PyTest 的 fixture 系统:
@pytest.fixture
def rectangle():
return Rectangle(0, 0)
def test_negative_case(rectangle):
print (rectangle.width)
rectangle.set_width(-1)
rectangle.set_height(2)
assert rectangle.get_area() == -1, "incorrect negative output"
上述代码将 Rectangle 引入为 fixture,PyTest 在 test_negative_case 的参数列表中匹配这个矩形 fixture,并为 test_negative_case 提供来自矩形函数的输出集合。这对每个其他测试也会如此。然而,请注意,fixtures 可以在每个测试中请求多次,每个测试中 fixture 只运行一次,并且结果会被缓存。这意味着在单个测试运行期间对该 fixture 的所有引用都引用了相同的返回值(如果返回值是引用类型,这一点很重要)。
完整代码如下:
import pytest
# Our code to be tested
class Rectangle:
def __init__(self, width, height):
self.width = width
self.height = height
def get_area(self):
return self.width * self.height
def set_width(self, width):
self.width = width
def set_height(self, height):
self.height = height
@pytest.fixture
def rectangle():
return Rectangle(0, 0)
def test_negative_case(rectangle):
print (rectangle.width)
rectangle.set_width(-1)
rectangle.set_height(2)
assert rectangle.get_area() == -1, "incorrect negative output"
与 PyUnit 类似,PyTest 具有许多其他功能,使你能够构建更全面和高级的单元测试。
单元测试的实际应用
现在,我们将探讨单元测试的实际应用。在我们的示例中,我们将测试一个从 Yahoo Finance 获取股票数据的函数,使用 pandas_datareader,并在 PyUnit 中进行测试:
import pandas_datareader.data as web
def get_stock_data(ticker):
"""pull data from stooq"""
df = web.DataReader(ticker, "yahoo")
return df
这个函数通过从 Yahoo Finance 网站爬取数据来获取特定股票代码的股票数据,并返回 pandas DataFrame。这可能会以多种方式失败。例如,数据读取器可能无法返回任何内容(如果 Yahoo Finance 出现故障)或返回一个缺少列或列中缺少数据的 DataFrame(如果来源重构了其网站)。因此,我们应该提供多个测试函数以检查多种失败模式:
import datetime
import unittest
import pandas as pd
import pandas_datareader.data as web
def get_stock_data(ticker):
"""pull data from stooq"""
df = web.DataReader(ticker, 'yahoo')
return df
class TestGetStockData(unittest.TestCase):
@classmethod
def setUpClass(self):
"""We only want to pull this data once for each TestCase since it is an expensive operation"""
self.df = get_stock_data('^DJI')
def test_columns_present(self):
"""ensures that the expected columns are all present"""
self.assertIn("Open", self.df.columns)
self.assertIn("High", self.df.columns)
self.assertIn("Low", self.df.columns)
self.assertIn("Close", self.df.columns)
self.assertIn("Volume", self.df.columns)
def test_non_empty(self):
"""ensures that there is more than one row of data"""
self.assertNotEqual(len(self.df.index), 0)
def test_high_low(self):
"""ensure high and low are the highest and lowest in the same row"""
ohlc = self.df[["Open","High","Low","Close"]]
highest = ohlc.max(axis=1)
lowest = ohlc.min(axis=1)
self.assertTrue(ohlc.le(highest, axis=0).all(axis=None))
self.assertTrue(ohlc.ge(lowest, axis=0).all(axis=None))
def test_most_recent_within_week(self):
"""most recent data was collected within the last week"""
most_recent_date = pd.to_datetime(self.df.index[-1])
self.assertLessEqual((datetime.datetime.today() - most_recent_date).days, 7)
unittest.main()
我们上面的单元测试系列检查了某些列是否存在(test_columns_present)、数据框是否为空(test_non_empty)、"high" 和 "low" 列是否确实是同一行的最高和最低值(test_high_low),以及数据框中最最近的数据是否在过去 7 天内(test_most_recent_within_week)。
想象一下,你正在进行一个消耗股市数据的机器学习项目。拥有一个单元测试框架可以帮助你识别数据预处理是否按预期工作。
使用这些单元测试,我们能够识别函数输出是否发生了重大变化,这可以成为持续集成(CI)过程的一部分。我们还可以根据需要附加其他单元测试,具体取决于我们对该函数的功能的依赖。
为了完整性,这里提供一个 PyTest 的等效版本:
import pytest
# scope="class" tears down the fixture only at the end of the last test in the class, so we avoid rerunning this step.
@pytest.fixture(scope="class")
def stock_df():
# We only want to pull this data once for each TestCase since it is an expensive operation
df = get_stock_data('^DJI')
return df
class TestGetStockData:
def test_columns_present(self, stock_df):
# ensures that the expected columns are all present
assert "Open" in stock_df.columns
assert "High" in stock_df.columns
assert "Low" in stock_df.columns
assert "Close" in stock_df.columns
assert "Volume" in stock_df.columns
def test_non_empty(self, stock_df):
# ensures that there is more than one row of data
assert len(stock_df.index) != 0
def test_most_recent_within_week(self, stock_df):
# most recent data was collected within the last week
most_recent_date = pd.to_datetime(stock_df.index[0])
assert (datetime.datetime.today() - most_recent_date).days <= 7
构建单元测试可能看起来费时且繁琐,但它们可能是任何 CI 流水线的关键部分,并且是捕获早期 bug 的宝贵工具,以避免它们进一步传递到流水线中并变得更难处理。
如果你喜欢它,那么你应该对其进行测试。
— 谷歌的软件工程
进一步阅读
本节提供了更多资源,以便你可以深入了解这个主题。
库
-
unittest 模块(以及 assert 方法列表),
docs.python.org/3/library/unittest.html -
PyTest,
docs.pytest.org/en/7.0.x/
文章
-
测试驱动开发 (TDD),
www.ibm.com/garage/method/practices/code/practice_test_driven_development/ -
Python 单元测试框架,
pyunit.sourceforge.net/pyunit.html
书籍
-
谷歌的软件工程,作者:Titus Winters, Tom Manshreck 和 Hyrum Wright
www.amazon.com/dp/1492082791 -
编程实践,作者:Brian Kernighan 和 Rob Pike(第五章和第六章),
www.amazon.com/dp/020161586X
摘要
在这篇文章中,你了解了什么是单元测试,以及如何使用两个流行的 Python 库(PyUnit 和 PyTest)进行单元测试。你还学会了如何配置单元测试,并看到了数据科学流程中单元测试的一个用例示例。
具体来说,你学到了:
-
什么是单元测试,及其为何有用
-
单元测试如何融入测试驱动开发流程
-
如何使用 PyUnit 和 PyTest 在 Python 中进行单元测试