不幸的是,PyTorch 并没有提供直接的查看 module 所在 device 的方法。原因是一个 module 内的 tensor 所在设备不一定是统一的。
但可以换一个思路,通过模型权重 tensor 所在 device 获知整个 module。最简单的实践:
next(model.parameters()).device
参考来源:python - How to get the device type of a pytorch module conveniently? - Stack Overflow