持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第1天,点击查看活动详情
python作为一个脚本语言,使用非常便捷,语法的兼容性也很高,学习成本相对较低,但是其功能却非常完善。但是python也有一个致命的缺点,python在运行过程中如果存在循环,那么python的整个效率就会变低。
有时候程序段较长时对于整个代码的分析并不是特别的容易,因此,我们可以利用一个现成的工具进行程序运行过程中的耗时分析。
首先,写一个比较耗时的程序做测试,本例就用蒙特卡洛模拟作为分析对象,其含义是掷一个骰子m次,获得6点至少n次的概率是多少?由于当前的计算设备相比于之前已经快了很多倍,因此为了在执行过程中体现耗时的主要函数部分,我们在实验过程中将次数进行重复,扩大到10的几次方倍数,统计实验的概率。
函数的代码如下 `
import cProfile, pstats
import random
import time
def diceX_py(N, ndice, nX):
'''
蒙特卡洛模拟,扔一个骰子m次,获得6点至少n次的概率
:param N: 实验重复N次
:param ndice: 对应m
:param nX: 对应n
:return: 重复N次执行蒙特卡洛得到的概率
'''
M = 0
for i in range(N):
nflag = 0
for j in range(ndice):
num = random.randint(1, 6)
if num == 6:
nflag += 1
if nflag >= nX:
M += 1
p = float(M) / N
return p
def call_diceX_py():
'''
调用概率模拟函数
:return:
'''
t0 = time.time()
p = diceX_py(10000, 5, 3)
t1 = time.time()
print(p)
print("time used %f"%(t1 - t0))
`
这个函数包括其调用,主要就是今天的测试对象,测试过程中我们可以将测试的结果保存在文件中进行输出,分析代码相对较为简单,主要执行过程被封装在cProfile库函数中,因此我们进需要在上述代码的最后加入下面四行代码即可完成对于函数的耗时分析。代码如下
statement = "diceX_py(10000, 5, 3)"
cProfile.runctx(statement, globals(), locals(), 'calTime.dat')
s = pstats.Stats('calTime.dat')
s.strip_dirs().sort_stats('time').print_stats(30)
上面的程序执行结果如下
通过结果分析,程序执行一共耗费了0.086秒,调用过程一共发生了266693次,虽然代码仅有几行,其调用量还是非常庞大的。假设我们将实验重复的次数增加一个数量级,再看一下耗时情况。实验重复次数由10000增加为100000时,统计结果如下,由原来的0.086变为0.832.近似增加了10倍
继续扩大一个数量级,时间由0.832变为8.291,增加了近似10倍。
通过上述测试发现,在其它不变,仅增加单一变量情况下,计算时间近似线性增加,因此对于算法的优化非常重要。后面的博客中将陆续介绍如何将普通的代码通过cuda平台进行加速计算。 最后放上项目的完整代码
import cProfile, pstats
import random
import time
def diceX_py(N, ndice, nX):
'''
蒙特卡洛模拟,扔一个骰子m次,获得6点至少n次的概率
:param N: 实验重复N次
:param ndice: 对应m
:param nX: 对应n
:return: 重复N次执行蒙特卡洛得到的概率
'''
M = 0
for i in range(N):
nflag = 0
for j in range(ndice):
num = random.randint(1, 6)
if num == 6:
nflag += 1
if nflag >= nX:
M += 1
p = float(M) / N
return p
def call_diceX_py():
'''
调用概率模拟函数
:return:
'''
t0 = time.time()
p = diceX_py(1000000, 5, 3)
t1 = time.time()
print(p)
print("time used %f"%(t1 - t0))
statement = "diceX_py(1000000, 5, 3)"
cProfile.runctx(statement, globals(), locals(), 'calTime.dat')
s = pstats.Stats('calTime.dat')
s.strip_dirs().sort_stats('time').print_stats(30)