开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第 7 天,点击查看活动详情
本文主要关注如何将各种Python对象储存为本地文件,并反之从本地文件加载Python对象。 (注意:一般情况下Python读写的工具都需要统一,如果可以跨工具使用的话,我会在对应内容的位置说明)
最近更新时间:2023.2.5 最早更新时间:2022.8.18
@[toc]
1. 使用Python3原生函数读写文件流
Python3使用原生函数open()
可以直接打开本地文件,返回值是文件流。
参数:
- 文件路径
- 打开模式,默认为
r
只读。其他可选项:w
写入,a
添加,rb
/wb
后面的b
指对二进制文件的处理1 encoding
:编码格式,常用选项为utf-8
或gbk
有两种常见写法,一种是将open()
作为命令,对返回的文件流进行处理,最后要记得close()
;一种是将open()
作为上下文管理器,如with open('file.txt') as f:
语句下包裹的代码运行之间自动打开文件流,运行完毕后自动关闭。
(如果对with语句之外的f进行I/O操作,将会报:ValueError: I/O operation on closed file.
这个bug)
对文件流的操作:
readlines()
对于文本文件,就是返回全部内容,列表格式,每行文字是一个元素read()
对于文本文件,就是返回全部内容,字符串格式write(str)
写入一个字符串对象writelines(obj)
写入一个可迭代对象的所有元素,obj
需要元素是字符串。注意:1. 不会自动换行。2. 集合对象也可以写入,但顺序随机;以字符串为键的字典对象也可以写入,但将只写入键值,具体的顺序我不确定。close()
关闭文件流(如果使用with open()
就不用显式关闭文件流)
2. 使用json包
加载本地文件到内存中:json.load(文件流)
将Python对象储存到本地:json.dump(Python对象,文件流)
(文件流是通过open()
函数打开的)
将字符串对象转换为dict对象:json.loads(str)
将dict对象转换为字符串:json.dumps(obj)
dump()
和dumps()
的共有入参:
ensure_ascii
:默认置True, 这会导致转换得到的字符串无法用肉眼直接阅读。所以一般都会显式置False
使用JSON来储存数据的优势在于跨平台、跨语言。
3. 使用pickle包
pickle包官方文档:docs.python.org/3/library/p…
常用的导入包代码:import pickle as pk
将Python对象储存为本地文件:pk.dump(Python对象,文件流)
加载本地文件到内存中:pk.load(文件流)
(文件流是通过open()
函数打开的)
4. 使用csv包
csv包官方文档:csv — CSV File Reading and Writing — Python 3.11.0 documentation
import csv
with open(CSV文件名,newline='') as csvfile:
spamwreader=csv.reader(csvfile)
for row in spamreader:
#一行数据,列表对象,每个元素是该行的一个cell
5. 使用numpy包
5.1 一次性序列化多个对象
习惯以.npz
后缀存储
官方文档:numpy.org/devdocs/ref… numpy.org/devdocs/ref…
6. 使用scipy包
6.1 scipy.sparse
习惯以.npz
后缀存储
储存对象:save_npz()
(官方文档:docs.scipy.org/doc/scipy/r…)
import scipy.sparse
sparse_matrix = scipy.sparse.csc_matrix(np.array([[0, 0, 3], [4, 0, 0]]))
scipy.sparse.save_npz('/tmp/sparse_matrix.npz', sparse_matrix)
加载本地对象:load_npz()
(官方文档:docs.scipy.org/doc/scipy/r…)
import scipy.sparse
sparse_matrix = scipy.sparse.load_npz('/tmp/sparse_matrix.npz')
7. 使用pandas包
8. 使用sklearn包
9. 使用PyTorch包
习惯以.pt
或.pth
后缀存储
PyTorch储存与加载模型的官方教程:Saving and Loading Models — PyTorch Tutorials 1.12.1+cu102 documentation 其他参考资料:python - How do I save a trained model in PyTorch? - Stack Overflow
将对象储存到磁盘:torch.save(obj,path)
将磁盘对象加载到内存:torch.load(path)
(path可以是路径字符串或文件流)
load()
入参:
map_location
:可以是函数、torch.device、字符串或字典,指定对象存储的设备位置。
获取模型参数(返回state_dict,匹配模型层到参数张量的字典文件,只包括可学习的那些。优化器对象也有这个):model.state_dict()
optimizer.state_dict()
将模型参数加载回模型:model.load(state_dict)
所以直接储存模型参数就是:torch.save(model.state_dict(), path)
直接加载模型参数就是:model.load_state_dict(torch.load(path))
需要注意的一个情况是:如果在每个epoch后,都根据当前指标,保存最好指标下的epoch的checkpoint,因为state_dict
是mutable对象OrderedDict
,所以直接引用(best_state = model.state_dict()
)的话会跟着模型的当前指标变化,因此需要深拷贝(best_state = copy.deepcopy(model.state_dict())
)2