Python-数值计算-三-

73 阅读59分钟

Python 数值计算(三)

原文:annas-archive.org/md5/9e81efca12eeaa9a42c1e05702b8c0c0

译者:飞龙

协议:CC BY-NC-SA 4.0

第六章:你好,绘图世界!

学习编程时,我们通常从打印“Hello world!”消息开始。那么,对于包含数据、坐标轴、标签、线条和刻度的图形,我们应该如何开始呢?

本章概述了 Matplotlib 的功能和最新特性。我们将引导你完成 Matplotlib 绘图环境的设置。你将学习如何创建一个简单的折线图、查看并保存你的图形。在本章结束时,你将足够自信开始创建自己的图表,并准备好在接下来的章节中学习自定义和更高级的技巧。

来,向图表世界说“你好!”

本章涵盖的主题包括:

  • 什么是 Matplotlib?

  • 绘制第一个简单的折线图

  • 将数据加载到 Matplotlib 中

  • 导出图形

你好,Matplotlib!

欢迎来到 Matplotlib 2.0 的世界!按照本章中的简单示例,绘制你第一个“Hello world”图表。

什么是 Matplotlib?

Matplotlib 是一个多功能的 Python 库,用于生成数据可视化图形。凭借丰富的图表类型和精致的样式选项,它非常适合创建专业的图表,用于演示和科学出版物。Matplotlib 提供了一种简单的方法来生成适合不同目的的图形,从幻灯片、高清海报打印、动画到基于 Web 的交互式图表。除了典型的二维图表,Matplotlib 还支持基本的三维绘图。

在开发方面,Matplotlib 的层次化类结构和面向对象的绘图接口使绘图过程直观且系统化。虽然 Matplotlib 提供了一个用于实时交互的本地图形用户界面,但它也可以轻松集成到流行的基于 IPython 的交互式开发环境中,如 Jupyter Notebook 和 PyCharm。

Matplotlib 2.0 的新特性

Matplotlib 2.0 引入了许多改进,包括默认样式的外观、图像支持和文本渲染速度。我们选取了一些重要的变化稍后会详细介绍。所有新变化的详细信息可以在文档网站上找到:matplotlib.org/devdocs/users/whats_new.html

如果你已经在使用 Matplotlib 的旧版本,可能需要更加关注这一部分来更新你的编码习惯。如果你完全是 Matplotlib 甚至 Python 的新手,可以直接跳过这部分,先开始使用 Matplotlib,稍后再回来查看。

默认样式的更改

Matplotlib 2.0 版本最显著的变化是默认样式的改变。你可以在这里查看更改的详细列表:matplotlib.org/devdocs/users/dflt_style_changes.html

颜色循环

为了快速绘图而不必为每个数据系列单独设置颜色,Matplotlib 使用一个称为默认属性循环的颜色列表,每个系列将被分配循环中的一个默认颜色。在 Matplotlib 2.0 中,该列表已经从原来的红色、绿色、蓝色、青色、品红色、黄色和黑色(记作['b', 'g', 'r', 'c', 'm', 'y', 'k'])变更为 Tableau 软件引入的当前 category10 色板。顾名思义,新的色板有 10 种不同的颜色,适用于分类显示。通过导入 Matplotlib 并在 Python 中调用matplotlib.rcParams['axes.prop_cycle'],可以访问该列表。

色图

色图对于展示渐变非常有用。黄色到蓝色的“viridis”色图现在是 Matplotlib 2.0 中的默认色图。这个感知统一的色图比经典的“jet”色图更好地展示了数值的视觉过渡。下面是两种色图的对比:

除了默认的感知连续色图外,现在还可以使用定性色图将值分组为不同类别:

散点图

散点图中的点默认尺寸更大,且不再有黑色边缘,从而提供更清晰的视觉效果。如果没有指定颜色,则每个数据系列将使用默认颜色循环中的不同颜色:

图例

虽然早期版本将图例设置在右上角,但 Matplotlib 2.0 默认将图例位置设置为“最佳”。它自动避免图例与数据重叠。图例框还具有圆角、较浅的边缘和部分透明的背景,以便将读者的焦点保持在数据上。经典和当前默认样式中的平方数曲线展示了这一情况:

线条样式

线条样式中的虚线模式现在可以根据线条宽度进行缩放,以显示更粗的虚线以增加清晰度:

来自文档(matplotlib.org/users/dflt_…

填充边缘和颜色

就像前面散点图中的点一样,大多数填充元素默认不再有黑色边缘,使图形看起来更简洁:

字体

默认字体已经从“Bitstream Vera Sans”更改为“DejaVu Sans”。当前的字体支持更多的国际字符、数学符号和符号字符,包括表情符号。

改进的功能或性能

Matplotlib 2.0 展示了改善用户体验的新特性,包括速度、输出质量和资源使用的优化。

改进的颜色转换 API 和 RGBA 支持

Alpha 通道,现在完全支持 Matplotlib 2.0,用于指定透明度。

改进的图像支持

Matplotlib 2.0 现在使用更少的内存和更少的数据类型转换重新采样图像。

更快的文本渲染

据称,Agg 后端的文本渲染速度提高了 20%。我们将在第九章中进一步讨论后端内容,添加交互性和动画图表

默认动画编解码器的更改

为了生成动画图表的视频输出,现在默认使用更高效的编解码器 H.264 代替 MPEG-4。由于 H.264 具有更高的压缩率,较小的输出文件大小允许更长的视频记录时间,并减少加载它们所需的时间和网络数据。H.264 视频的实时播放通常比 MPEG-4 编码的视频更加流畅且质量更高。

设置的更改

为了方便性、一致性或避免意外结果,Matplotlib v2.0 中更改了一些设置。

新的配置参数(rcParams)

新增了一些参数,如 date.autoformatter.year 用于日期时间字符串格式化。

样式参数黑名单

不再允许样式文件配置与样式无关的设置,以防止意外后果。这些参数包括以下内容:

'interactive', 'backend', 'backend.qt4', 'webagg.port', 'webagg.port_retries', 'webagg.open_in_browser', 'backend_fallback', 'toolbar', 'timezone', 'datapath', 'figure.max_open_warning', 'savefig.directory', tk.window_focus', 'docstring.hardcopy'

Axes 属性关键字的更改

Axes 属性 axisbgaxis_bgcolorfacecolor 替代,以保持关键字一致性。

绘制我们的第一个图表

我们将从一个简单的平方曲线的线性图开始,即 y = x²

加载数据以进行绘图

为了可视化数据,我们应该从“拥有”一些数据开始。虽然我们假设你手头有一些不错的数据可以展示,但我们将简要展示如何在 Python 中加载数据以进行绘图。

数据结构

有几个常见的数据结构我们将不断遇到。

列表

列表是 Python 中用于存储一组值的基本数据类型。列表是通过将元素值放入方括号中创建的。为了重用我们的列表,我们可以给它起个名字并像这样存储它:

evens = [2,4,6,8,10]

当我们希望获取更大范围的系列时,例如,为了让平方曲线更平滑,获取更多的数据点,我们可以使用 Python 的 range() 函数:

evens = range(2,102,2)

此命令将给出从 2 到 100(包含)的所有偶数,并将其存储在名为 evens 的列表中。

Numpy 数组

很多时候,我们处理的是更复杂的数据。如果你需要一个包含多个列的矩阵,或者想对集合中的所有元素进行数学操作,那么 numpy 就是你的选择:

import numpy as np

我们根据约定将numpy缩写为np,以保持代码简洁。

np.array() 将支持的数据类型(在此例中是列表)转换为 Numpy 数组。为了从我们的 evens 列表中生成一个 numpy 数组,我们这样做:

np.array(evens)

pandas 数据框

当我们在矩阵中有一些非数值标签或值时,pandas 数据框非常有用。它不像 Numpy 那样要求数据类型统一。列可以命名。还有一些函数,比如 melt()pivot_table(),它们在重塑表格以便分析和绘图时提供了便利。

要将一个列表转换为 pandas 数据框,我们可以做如下操作:

import pandas as pd
pd.DataFrame(evens)

你也可以将一个 numpy 数组转换为 pandas 数据框。

从文件加载数据

虽然这一切能让你复习我们将要处理的数据结构,但在现实生活中,我们不是发明数据,而是从数据源读取它。制表符分隔的纯文本文件是最简单且最常见的数据输入类型。假设我们有一个名为 evens.txt 的文件,里面包含了前面提到的偶数。该文件有两列。第一列只记录不必要的信息,我们想要加载第二列的数据。

这就是假文本文件的样子:

基本的 Python 方式

我们可以初始化一个空列表,逐行读取文件,拆分每一行,并将第二个元素附加到我们的列表中:

evens = []
with open as f:
    for line in f.readlines():
        evens.append(line.split()[1])

当然,你也可以使用一行代码来实现:

evens = [int(x.split()[1]) for x in open('evens.txt').readlines()]

我们只是尝试一步步走,遵循 Python 的 Zen(禅哲学):简单优于复杂。

Numpy 方式

当我们有一个只有两列的文件,并且只需要读取一列时,这非常简单,但当我们拥有一个包含成千上万列和行的扩展表格,并且想要将其转换为 Numpy 矩阵时,它可能会变得更加繁琐。

Numpy 提供了一个标准的一行代码解决方案:

import numpy as np
np.loadtxt(‘evens.txt’,delimiter=’\t’,usecols=1,dtype=np.int32)

第一个参数是数据文件的路径。delimiter 参数指定用于分隔值的字符串,这里是一个制表符。因为 numpy.loadtxt() 默认情况下将任何空白符分隔的值拆分成列,所以这个参数在这里可以省略。我们为演示设置了它。

对于 usecolsdtype,它们分别指定要读取哪些列以及每列对应的数据类型,你可以传递单个值,也可以传递一个序列(例如列表)来读取多列。

Numpy 默认情况下还会跳过以 # 开头的行,这通常表示注释或标题行。你可以通过设置 comment 参数来更改这种行为。

pandas 方式

类似于 Numpy,pandas 提供了一种简单的方法将文本文件加载到 pandas 数据框中:

import pandas as pd
pd.read_csv(usecols=1)

这里的分隔符可以用 sepdelimiter 来表示,默认情况下它是逗号 ,CSV 代表 逗号分隔值)。

关于如何处理不同的数据格式、数据类型和错误,有一长串不太常用的选项可供选择。你可以参考文档:pandas.pydata.org/pandas-docs/stable/generated/pandas.read_csv.html。除了平面 CSV 文件外,Pandas 还提供了读取其他常见数据格式的内置函数,例如 Excel、JSON、HTML、HDF5、SQL 和 Google BigQuery。

为了保持对数据可视化的关注,本书将不深入探讨数据清洗的方法,但这对于数据科学来说是一个非常有用的生存技能。如果你感兴趣,可以查阅关于使用 Python 进行数据处理的相关资源。

导入 Matplotlib 的 pyplot 模块

Matplotlib 包包含许多模块,其中包括控制美学的 artist 模块和用于设置默认值的 rcParams 模块。Pyplot 模块是我们主要处理的绘图接口,它以面向对象的方式创建数据图表。

按惯例,我们在导入时使用plt这个缩写:

import matplotlib.pylot as plt

别忘了运行 Jupyter Notebook 的单元格魔法 %matplotlib inline,以便将图形嵌入输出中。

不要使用 pylab 模块!

现在不推荐使用 pylab 模块,通常被 面向对象OO)接口所替代。虽然 pylab 通过在一个命名空间下导入matplotlib.pyplotnumpy提供了一些便利,但现在许多在线的 pylab 示例仍然存在,但最好分别调用 Matplotlib.pyplotnumpy 模块。

绘制曲线

绘制列表的折线图可以简单到:

plt.plot(evens)

当只指定一个参数时,Pyplot 假定我们输入的数据位于 y 轴,并自动选择 x 轴的刻度。

要绘制图表,调用plt.plot(x,y),其中xy是数据点的 x 坐标和 y 坐标:

plt.plot(evens,evens**2)

要为曲线添加图例标签,我们在 plot 函数中添加标签信息:

plt.plot(evens,evens**2,label = 'x²')
plt.legend()

查看图形

现在,别忘了调用plt.show()来显示图形!

保存图形

现在我们已经绘制了第一个图形。让我们保存我们的工作!当然,我们不想依赖截图。这里有一个简单的方法,通过调用 pyplot.savefig() 来完成。

如果你既想在屏幕上查看图像,又想将其保存在文件中,记得在调用 pyplot.show() 之前先调用 pyplot.savefig(),以确保你不会保存一个空白画布。

设置输出格式

pyplot.savefig() 函数接受输出文件的路径,并自动以指定的扩展名输出。例如,pyplot.savefig('output.png') 会生成一个 PNG 图像。如果没有指定扩展名,默认会生成 SVG 图像。如果指定的格式不受支持,比如.doc,会抛出一个 ValueError Python 异常:

PNG(便携式网络图形)

与另一种常见的图像文件格式 JPEG 相比,PNG 的优势在于允许透明背景。PNG 被大多数图像查看器和处理程序广泛支持。

PDF(便携式文档格式)

PDF 是一种标准的文档格式,你不必担心阅读器的可用性。然而,大多数办公软件不支持将 PDF 作为图像导入。

SVG(可伸缩矢量图形)

SVG 是一种矢量图形格式,可以在不失去细节的情况下进行缩放。因此,可以在较小的文件大小下获得更好的质量。它与 HTML5 兼容,适合用于网页。但某些基础图像查看器可能不支持它。

Post(Postscript)

Postscript 是一种用于电子出版的页面描述语言。它对于批量处理图像以进行出版非常有用。

Gimp 绘图工具包GDK)的光栅图形渲染在 2.0 版本中已被弃用,这意味着像 JPG 和 TIFF 这样的图像格式不再由默认后端支持。我们将在后面更详细地讨论后端。

调整分辨率

分辨率衡量图像记录的细节。它决定了你可以在不失去细节的情况下放大图像的程度。具有较高分辨率的图像在较大尺寸下保持较高质量,但文件大小也会更大。

根据用途,你可能希望以不同的分辨率输出图形。分辨率是通过每英寸的颜色像素**点数(dpi)**来衡量的。你可以通过在pyplot.savefig()函数中指定dpi参数来调整输出图形的分辨率,例如:

plt.savefig('output.png',dpi=300)

虽然较高的分辨率能提供更好的图像质量,但它也意味着更大的文件大小,并且需要更多的计算机资源。以下是一些关于你应设置图像分辨率多高的参考:

  • 幻灯片演示:96 dpi+

以下是微软针对不同屏幕大小的 PowerPoint 演示文稿图形分辨率建议:support.microsoft.com/en-us/help/827745/how-to-change-the-export-resolution-of-a-powerpoint-slide

屏幕高度(像素)分辨率(dpi)
72096(默认)
750100
1125150
1500200
1875250
2250300
  • 海报展示:300 dpi+

  • 网络:72 dpi+(推荐使用可以响应式缩放的 SVG)

摘要

在本章中,你学习了如何使用 Matplotlib 绘制一个简单的折线图。我们设置了环境,导入了数据,并将图形输出为不同格式的图像。在下一章,你将学习如何可视化在线数据。

第七章:可视化在线数据

到目前为止,我们已经介绍了使用 Matplotlib 创建和定制图表的基础知识。在本章中,我们将通过在专门主题中的示例,开始了解更高级的 Matplotlib 使用方法。

在考虑可视化某个概念时,需要仔细考虑以下重要因素:

  • 数据来源

  • 数据过滤和处理

  • 选择适合数据的图表类型:

    • 可视化数据趋势:

      • 折线图、区域图和堆叠区域图
    • 可视化单变量分布:

      • 条形图、直方图和核密度估计图
    • 可视化双变量分布:

      • 散点图、KDE 密度图和六边形图
    • 可视化类别数据:

      • 类别散点图、箱线图、蜂群图、小提琴图
  • 调整图形美学以有效讲述故事

我们将通过使用人口统计和财务数据来讨论这些主题。首先,我们将讨论从 应用程序编程接口API)获取数据时的典型数据格式。接下来,我们将探索如何将 Matplotlib 2.0 与 Pandas、Scipy 和 Seaborn 等其他 Python 包结合使用,以实现不同数据类型的可视化。

常见的 API 数据格式

许多网站通过 API 提供数据,API 是通过标准化架构连接应用程序的桥梁。虽然我们这里不打算详细讨论如何使用 API,因为网站特定的文档通常可以在线找到;但我们将展示在许多 API 中使用的三种最常见的数据格式。

CSV

CSV逗号分隔值)是最古老的文件格式之一,它在互联网存在之前就已经被引入。然而,随着其他高级格式如 JSON 和 XML 的流行,CSV 格式现在逐渐被淘汰。顾名思义,数据值由逗号分隔。预安装的 csv 包和 pandas 包包含读取和写入 CSV 格式数据的类。这个 CSV 示例定义了一个包含两个国家的总人口表:

Country,Time,Sex,Age,Value
United Kingdom,1950,Male,0-4,2238.735
United States of America,1950,Male,0-4,8812.309

JSON

JSONJavaScript 对象表示法)因其高效性和简洁性,近年来越来越受欢迎。JSON 允许指定数字、字符串、布尔值、数组和对象。Python 提供了默认的 json 包来解析 JSON。另外,pandas.read_json 类可以用于将 JSON 导入为 Pandas 数据框。前面的总人口表可以通过以下 JSON 示例表示:

{
 "population": [
 {
 "Country": "United Kingdom",
 "Time": 1950,
 "Sex", "Male",
 "Age", "0-4",
 "Value",2238.735
 },{
 "Country": "United States of America",
 "Time": 1950,
 "Sex", "Male",
 "Age", "0-4",
 "Value",8812.309
 },
 ]
}

XML

XML可扩展标记语言)是数据格式中的瑞士军刀,已成为 Microsoft Office、Apple iWork、XHTML、SVG 等的默认容器。XML 的多功能性有其代价,因为它使得 XML 变得冗长且较慢。Python 中有多种解析 XML 的方法,但建议使用 xml.etree.ElementTree,因为它提供了 Python 风格的接口,并且有高效的 C 后端支持。本书不打算介绍 XML 解析,但其他地方有很好的教程(例如 eli.thegreenplace.net/2012/03/15/processing-xml-in-python-with-elementtree)。

例如,相同的人口表可以转换为 XML 格式:

<?xml version='1.0' encoding='utf-8'?>
<populations>
 <population>
 <Country>United Kingdom</Country> 
 <Time>1950</Time>
 <Sex>Male</Sex>
 <Age>0-4</Age>
 <Value>2238.735</Value>
 </population>
 <population>
 <Country>United States of America</Country>
 <Time>1950</Time>
 <Sex>Male</Sex>
 <Age>0-4</Age>
 <Value>8812.309</Value>
 </population>
</populations>

介绍 pandas

除了 NumPy 和 SciPy,pandas 是 Python 中最常见的科学计算库之一。其作者旨在使 pandas 成为任何语言中最强大、最灵活的开源数据分析和处理工具,实际上,他们几乎实现了这一目标。其强大且高效的库与数据科学家的需求完美契合。像其他 Python 包一样,Pandas 可以通过 PyPI 轻松安装:

pip install pandas

Matplotlib 在 1.5 版本中首次引入,支持将 pandas DataFrame 作为输入应用于各种绘图类。Pandas DataFrame 是一种强大的二维标签数据结构,支持索引、查询、分组、合并以及其他一些常见的关系数据库操作。DataFrame 类似于电子表格,因为 DataFrame 的每一行包含一个实例的不同变量,而每一列则包含一个特定变量在所有实例中的向量。

pandas DataFrame 支持异构数据类型,如字符串、整数和浮点数。默认情况下,行按顺序索引,列由 pandas Series 组成。可以通过 index 和 columns 属性指定可选的行标签或列标签。

导入在线人口数据(CSV 格式)

让我们首先来看一下将在线 CSV 文件导入 pandas DataFrame 的步骤。在这个例子中,我们将使用联合国经济和社会事务部在 2015 年发布的年度人口总结数据集。该数据集还包含了面向 2100 年的人口预测数据:

import numpy as np # Python scientific computing package
import pandas as pd # Python data analysis package

# URL for Annual Population by Age and Sex - Department of Economic
# and Social Affairs, United Nations
source = "https://github.com/PacktPublishing/Matplotlib-2.x-By-Example/blob/master/WPP2015_DB04_Population_Annual.zip"

# Pandas support both local or online files 
data = pd.read_csv(source, header=0, compression='zip', encoding='latin_1') 

# Show the first five rows of the DataFrame
data.head() 

代码的预期输出如下所示:

LocIDLocationVarIDVariantTimeMidPeriodSexIDSexAgeGrpAgeGrpStartAgeGrpSpanValue
04阿富汗2中等19501950.51男性0-405630.044
14阿富汗2中等19501950.51男性5-955516.205
24阿富汗2中等19501950.51男性10-14105461.378
34阿富汗2中等19501950.51男性15-19155414.368
44阿富汗2中等19501950.51男性20-24205374.110

pandas.read_csv 类极为多功能,支持列标题、自定义分隔符、各种压缩格式(例如,.gzip.bz2.zip.xz)、不同的文本编码等。读者可以参考文档页面(pandas.pydata.org/pandas-docs/stable/generated/pandas.read_csv.html)获取更多信息。

通过调用 Pandas DataFrame 对象的 .head() 函数,我们可以快速查看数据的前五行。

在本章中,我们将把这个人口数据集与 Quandl 中的其他数据集合并。不过,Quandl 使用三字母国家代码(ISO 3166 alpha-3)来表示地理位置;因此我们需要相应地重新格式化地点名称。

pycountry 包是根据 ISO 3166 标准转换国家名称的优秀选择。同样,pycountry 可以通过 PyPI 安装:

pip install pycountry 

继续之前的代码示例,我们将为数据框添加一个新的 country 列:

from pycountry import countries

def get_alpha_3(location):
    """Convert full country name to three letter code (ISO 3166 alpha-3)

    Args:
        location: Full location name
    Returns:
        three letter code or None if not found"""

    try:
        return countries.get(name=location).alpha_3
    except:
        return None

# Add a new country column to the dataframe
population_df['country'] = population_df['Location'].apply(lambda x: get_alpha_3(x))
population_df.head()

代码的预期输出如下所示:

-LocIDLocationVarIDVariantTimeMidPeriodSexIDSexAgeGrpAgeGrpStartAgeGrpSpanValuecountry
04阿富汗2中等19501950.51男性0-405630.044AFG
14阿富汗2中等19501950.51男性5-955516.205AFG
24阿富汗2中等19501950.51男性10-14105461.378AFG
34阿富汗2中等19501950.51男性15-19155414.368AFG
44阿富汗2中等19501950.51男性20-24205374.110AFG

导入在线财务数据(JSON 格式)

在本章中,我们还将利用 Quandl 的 API 提取财务数据,并创建有洞察力的可视化图表。如果你不熟悉 Quandl,它是一个财务和经济数据仓库,存储了来自数百家出版商的数百万个数据集。Quandl 最棒的地方在于,这些数据集通过统一的 API 进行交付,无需担心如何正确解析数据。匿名用户每天可以进行最多 50 次 API 调用,注册用户可获得最多 500 次免费的 API 调用。读者可以在www.quandl.com/?modal=register注册免费 API 密钥。

在 Quandl 中,每个数据集都有一个唯一的 ID,这个 ID 在每个搜索结果网页上由 Quandl Code 定义。例如,Quandl 代码 GOOG/NASDAQ_SWTX 定义了 Google 财务发布的历史 NASDAQ 指数数据。每个数据集有三种格式可用——CSV、JSON 和 XML。

尽管 Quandl 提供了官方的 Python 客户端库,但为了演示导入 JSON 数据的一般过程,我们将不使用它。根据 Quandl 的文档,我们可以通过以下 API 调用获取 JSON 格式的数据表:

GET https://www.quandl.com/api/v3/datasets/{Quandl code}/data.json

让我们尝试从 Quandl 获取巨无霸指数数据。

from urllib.request import urlopen
import json
import time
import pandas as pd

def get_bigmac_codes():
    """Get a Pandas DataFrame of all codes in the Big Mac index dataset

    The first column contains the code, while the second header
    contains the description of the code.

    for example, 
    ECONOMIST/BIGMAC_ARG,Big Mac Index - Argentina
    ECONOMIST/BIGMAC_AUS,Big Mac Index - Australia
    ECONOMIST/BIGMAC_BRA,Big Mac Index - Brazil

    Returns:
        codes: Pandas DataFrame of Quandl dataset codes"""

    codes_url = "https://www.quandl.com/api/v3/databases/ECONOMIST/codes"
    codes = pd.read_csv(codes_url, header=None, names=['Code', 'Description'], 
                        compression='zip', encoding='latin_1')

    return codes

def get_quandl_dataset(api_key, code):
    """Obtain and parse a quandl dataset in Pandas DataFrame format

    Quandl returns dataset in JSON format, where data is stored as a 
    list of lists in response['dataset']['data'], and column headers
    stored in response['dataset']['column_names'].

    for example, {'dataset': {...,
             'column_names': ['Date',
                              'local_price',
                              'dollar_ex',
                              'dollar_price',
                              'dollar_ppp',
                              'dollar_valuation',
                              'dollar_adj_valuation',
                              'euro_adj_valuation',
                              'sterling_adj_valuation',
                              'yen_adj_valuation',
                              'yuan_adj_valuation'],
             'data': [['2017-01-31',
                       55.0,
                       15.8575,
                       3.4683903515687,
                       10.869565217391,
                       -31.454736135007,
                       6.2671477203176,
                       8.2697553162259,
                       29.626894343348,
                       32.714616745128,
                       13.625825886047],
                      ['2016-07-31',
                       50.0,
                       14.935,
                       3.3478406427854,
                       9.9206349206349,
                       -33.574590420925,
                       2.0726096168216,
                       0.40224795003514,
                       17.56448458418,
                       19.76377270142,
                       11.643103380531]
                      ],
             'database_code': 'ECONOMIST',
             'dataset_code': 'BIGMAC_ARG',
             ... }}

    A custom column--country is added to denote the 3-letter country code.

    Args:
        api_key: Quandl API key
        code: Quandl dataset code

    Returns:
        df: Pandas DataFrame of a Quandl dataset

    """
    base_url = "https://www.quandl.com/api/v3/datasets/"
    url_suffix = ".json?api_key="

    # Fetch the JSON response 
    u = urlopen(base_url + code + url_suffix + api_key)
    response = json.loads(u.read().decode('utf-8'))

    # Format the response as Pandas Dataframe
    df = pd.DataFrame(response['dataset']['data'], columns=response['dataset']['column_names'])

    # Label the country code
    df['country'] = code[-3:]

    return df

quandl_dfs = []
codes = get_bigmac_codes()

# Replace this with your own API key
api_key = "INSERT YOUR KEY HERE" 

for code in codes.Code:
    # Get the DataFrame of a Quandl dataset
    df = get_quandl_dataset(api_key, code)

    # Store in a list
    quandl_dfs.append(df)

    # Prevents exceeding the API speed limit
    time.sleep(2)

# Concatenate the list of dataframes into a single one    
bigmac_df = pd.concat(quandl_dfs)
bigmac_df.head()

预期的输出如下:

-日期本地价格美元汇率美元价格美元 PPP美元估值美元调整估值欧元调整估值英镑调整估值日元调整估值人民币调整估值国家
02017-01-3155.015.857503.46839010.869565-31.4547366.267158.2697629.626932.714613.6258阿根廷
12016-07-3150.014.935003.3478419.920635-33.5745902.072610.40224817.564519.763811.6431阿根廷
22016-01-3133.013.809252.3897036.693712-51.527332-24.8619-18.714-18.72090.40859-17.029阿根廷
32015-07-3128.09.135003.0651345.845511-36.009727-4.7585-0.357918-6.0109130.86095.02868阿根廷
42015-01-3128.08.610003.2520335.845511-32.1078810.540242-0.804495-2.4946834.39056.01183阿根廷

巨无霸指数是《经济学人》于 1986 年发明的,用来轻松检查货币是否在正确的水平。它基于购买力平价PPP)理论,并被认为是货币汇率在 PPP 下的非正式衡量标准。它通过与类似商品和服务的价格对比来衡量货币的价值,在这种情况下,是巨无霸的价格。市场汇率下不同的价格意味着某种货币被低估或高估。

从 Quandl API 解析 JSON 的代码稍微复杂一些,因此额外的解释可能有助于你理解它。第一个函数get_bigmac_codes()解析 Quandl Economist 数据库中所有可用数据集代码的列表,并将其作为 pandas DataFrame 返回。与此同时,第二个函数get_quandl_dataset(api_key, code)将 Quandl 数据集 API 查询的 JSON 响应转换为 pandas DataFrame。所有获取的数据集都通过pandas.concat()进行拼接。

可视化数据趋势

一旦我们导入了这两个数据集,就可以开始进一步的可视化之旅。让我们从绘制 1950 年到 2017 年的全球人口趋势开始。为了根据某一列的值选择行,我们可以使用以下语法:df[df.variable_name == "target"]df[df['variable_name'] == "target"],其中df是数据框对象。其他条件运算符,如大于 > 或小于 <,也支持。可以使用“与”运算符&或“或”运算符|将多个条件语句链在一起。

为了聚合某一年内所有年龄组的人口数据,我们将依赖 df.groupby().sum(),如以下示例所示:

import matplotlib.pyplot as plt

# Select the aggregated population data from the world for both genders,
# during 1950 to 2017.
selected_data = data[(data.Location == 'WORLD') & (data.Sex == 'Both') & (data.Time <= 2017) ]

# Calculate aggregated population data across all age groups for each year 
# Set as_index=False to avoid the Time variable to be used as index
grouped_data = selected_data.groupby('Time', as_index=False).sum()

# Generate a simple line plot of population vs time
fig = plt.figure()
plt.plot(grouped_data.Time, grouped_data.Value)

# Label the axis
plt.xlabel('Year')
plt.ylabel('Population (thousands)')

plt.show()

区域图和堆叠区域图

有时,我们可能希望通过为线图下方的区域填充颜色来增加视觉冲击力。可以通过 fill_between 类来实现这一点:

fill_between(x, y1, y2=0, where=None, interpolate=False, step=None)

默认情况下,当未指定 y2 时,fill_between 会为 y=0 和曲线之间的区域着色。可以通过使用 whereinterpolatestep 等关键字参数来指定更复杂的着色行为。读者可以通过以下链接获取更多信息:matplotlib.org/examples/pylab_examples/fill_between_demo.html

让我们尝试通过区分男女来绘制一个更详细的图表。我们将探讨男性和女性对人口增长的相对贡献。为此,我们可以使用 stackplot 类绘制堆叠区域图:

# Select the aggregated population data from the world for each gender,
# during 1950 to 2017.
male_data = data[(data.Location == 'WORLD') & (data.Sex == 'Male') & (data.Time <= 2017) ]
female_data = data[(data.Location == 'WORLD') & (data.Sex == 'Female') & (data.Time <= 2017) ]

# Calculate aggregated population data across all age groups for each year 
# Set as_index=False to avoid the Time variable to be used as index
grouped_male_data = male_data.groupby('Time', as_index=False).sum()
grouped_female_data = female_data.groupby('Time', as_index=False).sum()

# Create two subplots with shared y-axis (sharey=True)
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12,4), sharey=True)

# Generate a simple line plot of population vs time,
# then shade the area under the line in sky blue.
ax1.plot(grouped_data.Time, grouped_data.Value)
ax1.fill_between(grouped_data.Time, grouped_data.Value, color='skyblue')

# Use set_xlabel() or set_ylabel() instead to set the axis label of an
# axes object
ax1.set_xlabel('Year')
ax1.set_ylabel('Population (thousands)')

# Generate a stacked area plot of population vs time
ax2.stackplot(grouped_male_data.Time, grouped_male_data.Value, grouped_female_data.Value)

# Add a figure legend
ax2.legend(['Male', 'Female'], loc='upper left')

# Set the x-axis label only this time
ax2.set_xlabel('Year')
plt.show()

介绍 Seaborn

Seaborn 是由 Michael Waskom 开发的一个统计可视化库,建立在 Matplotlib 之上。它提供了用于可视化类别变量、单变量分布和双变量分布的便捷函数。对于更复杂的图表,提供了多种统计方法,如线性回归模型和聚类算法。像 Matplotlib 一样,Seaborn 也支持 Pandas 数据框作为输入,并自动进行必要的切片、分组、聚合以及统计模型拟合,从而生成有用的图形。

这些 Seaborn 函数旨在通过最小化的参数集,通过 API 创建出版级质量的图形,同时保持 Matplotlib 完整的自定义功能。事实上,Seaborn 中的许多函数在调用时会返回一个 Matplotlib 轴或网格对象。因此,Seaborn 是 Matplotlib 的得力伙伴。要通过 PyPI 安装 Seaborn,可以在终端中运行以下命令:

pip install pandas

Seaborn 将在本书中以 sns 导入。本节不会是 Seaborn 的文档,而是从 Matplotlib 用户的角度,概述 Seaborn 的功能。读者可以访问 Seaborn 的官方网站 (seaborn.pydata.org/index.html) 获取更多信息。

可视化单变量分布

Seaborn 使得可视化数据集分布的任务变得更加容易。以之前讨论的人口数据为例,让我们通过绘制条形图来查看 2017 年不同国家的人口分布情况:

import seaborn as sns
import matplotlib.pyplot as plt

# Extract USA population data in 2017
current_population = population_df[(population_df.Location 
                                    == 'United States of America') & 
                                   (population_df.Time == 2017) &
                                   (population_df.Sex != 'Both')]

# Population Bar chart 
sns.barplot(x="AgeGrp",y="Value", hue="Sex", data = current_population)

# Use Matplotlib functions to label axes rotate tick labels
ax = plt.gca()
ax.set(xlabel="Age Group", ylabel="Population (thousands)")
ax.set_xticklabels(ax.xaxis.get_majorticklabels(), rotation=45)
plt.title("Population Barchart (USA)")

# Show the figure
plt.show()

Seaborn 中的条形图

seaborn.barplot() 函数显示一系列数据点作为矩形条。如果每组有多个点,则在条形顶部显示置信区间,以指示点估计的不确定性。与大多数其他 Seaborn 函数一样,支持各种输入数据格式,如 Python 列表、Numpy 数组、pandas Series 和 pandas DataFrame。

展示人口结构的更传统方式是通过人口金字塔。

那么什么是人口金字塔?顾名思义,它是显示人口年龄分布的金字塔形绘图。它可以粗略地分为三类,即压缩型、稳定型和扩张型,分别用于经历负增长、稳定增长和快速增长的人口。例如,压缩型人口的年轻人比例较低,因此金字塔底部看起来受限。稳定型人口的年轻人和中年组相对较多。而扩张型人口则有大量年轻人,从而导致金字塔底部扩大。

我们可以通过在两个共享 y 轴的子图上绘制两个条形图来构建人口金字塔:

import seaborn as sns
import matplotlib.pyplot as plt

# Extract USA population data in 2017
current_population = population_df[(population_df.Location 
                                    == 'United States of America') & 
                                   (population_df.Time == 2017) &
                                   (population_df.Sex != 'Both')]

# Change the age group to descending order
current_population = current_population.iloc[::-1]

# Create two subplots with shared y-axis
fig, axes = plt.subplots(ncols=2, sharey=True)

# Bar chart for male
sns.barplot(x="Value",y="AgeGrp", color="darkblue", ax=axes[0],
            data = current_population[(current_population.Sex == 'Male')])
# Bar chart for female
sns.barplot(x="Value",y="AgeGrp", color="darkred", ax=axes[1],
            data = current_population[(current_population.Sex == 'Female')])

# Use Matplotlib function to invert the first chart
axes[0].invert_xaxis()

# Use Matplotlib function to show tick labels in the middle
axes[0].yaxis.tick_right()

# Use Matplotlib functions to label the axes and titles
axes[0].set_title("Male")
axes[1].set_title("Female")
axes[0].set(xlabel="Population (thousands)", ylabel="Age Group")
axes[1].set(xlabel="Population (thousands)", ylabel="")
fig.suptitle("Population Pyramid (USA)")

# Show the figure
plt.show()

由于 Seaborn 建立在 Matplotlib 的坚实基础之上,我们可以使用 Matplotlib 的内置函数轻松定制绘图。在前面的例子中,我们使用 matplotlib.axes.Axes.invert_xaxis() 将男性人口图水平翻转,然后使用 matplotlib.axis.YAxis.tick_right() 将刻度标签位置改为右侧。我们进一步使用 matplotlib.axes.Axes.set_title()matplotlib.axes.Axes.set()matplotlib.figure.Figure.suptitle() 组合定制了绘图的标题和轴标签。

我们尝试通过将行 population_df.Location == 'United States of America' 更改为 population_df.Location == 'Cambodia'population_df.Location == 'Japan' 来绘制柬埔寨和日本的人口金字塔。你能把金字塔分类到三类人口金字塔中的一类吗?

为了看到 Seaborn 如何简化相对复杂绘图的代码,让我们看看如何使用原始 Matplotlib 实现类似的绘图。

首先,像之前基于 Seaborn 的示例一样,我们创建具有共享 y 轴的两个子图:

fig, axes = plt.subplots(ncols=2, sharey=True)

接下来,我们使用 matplotlib.pyplot.barh() 绘制水平条形图,并设置刻度的位置和标签,然后调整子图间距:

# Get a list of tick positions according to the data bins
y_pos = range(len(current_population.AgeGrp.unique()))

# Horizontal barchart for male
axes[0].barh(y_pos, current_population[(current_population.Sex ==
             'Male')].Value, color="darkblue")

# Horizontal barchart for female
axes[1].barh(y_pos, current_population[(current_population.Sex == 
             'Female')].Value, color="darkred")

# Show tick for each data point, and label with the age group
axes[0].set_yticks(y_pos)
axes[0].set_yticklabels(current_population.AgeGrp.unique())

# Increase spacing between subplots to avoid clipping of ytick labels
plt.subplots_adjust(wspace=0.3)

最后,我们使用相同的代码进一步定制图形的外观和感觉:

# Invert the first chart
axes[0].invert_xaxis()

# Show tick labels in the middle
axes[0].yaxis.tick_right()

# Label the axes and titles
axes[0].set_title("Male")
axes[1].set_title("Female")
axes[0].set(xlabel="Population (thousands)", ylabel="Age Group")
axes[1].set(xlabel="Population (thousands)", ylabel="")
fig.suptitle("Population Pyramid (USA)")

# Show the figure
plt.show()

与基于 Seaborn 的代码相比,纯 Matplotlib 实现需要额外的代码行来定义刻度位置、刻度标签和子图间距。对于一些其他包含额外统计计算(如线性回归、皮尔逊相关)的 Seaborn 图表类型,代码的简化更加明显。因此,Seaborn 是一个“开箱即用”的统计可视化包,使用户可以写出更简洁的代码。

Seaborn 中的直方图和分布拟合

在人口示例中,原始数据已经分为不同的年龄组。如果数据没有被分组(例如大麦指数数据),该怎么办呢?事实证明,seaborn.distplot可以帮助我们将数据分组,并显示相应的直方图。让我们看一下这个例子:

import seaborn as sns
import matplotlib.pyplot as plt

# Get the BigMac index in 2017
current_bigmac = bigmac_df[(bigmac_df.Date == "2017-01-31")]

# Plot the histogram
ax = sns.distplot(current_bigmac.dollar_price)
plt.show()

seaborn.distplot函数期望输入的是 pandas Series、单维度的 numpy.array 或者 Python 列表。然后,它根据 Freedman-Diaconis 规则确定箱子的大小,最后在直方图上拟合核密度估计KDE)。

KDE 是一种非参数方法,用于估计变量的分布。我们还可以提供一个参数分布,例如贝塔分布、伽马分布或正态分布,作为fit参数。

在这个例子中,我们将拟合来自scipy.stats包的正态分布到大麦指数数据集:

from scipy import stats

ax = sns.distplot(current_bigmac.dollar_price, kde=False, fit=stats.norm)
plt.show()

可视化双变量分布

我们应当记住,大麦指数在不同国家之间并不能直接比较。通常,我们会预期贫穷国家的商品比富裕国家的便宜。为了更公平地呈现该指数,最好显示大麦价格与国内生产总值GDP)人均的关系。

我们将从 Quandl 的世界银行世界发展指标WWDI)数据集中获取人均 GDP 数据。基于之前获取 Quandl JSON 数据的代码示例,你能尝试将其修改为下载人均 GDP 数据集吗?

对于不耐烦的人,这里是完整的代码:

import urllib
import json
import pandas as pd
import time
from urllib.request import urlopen

def get_gdp_dataset(api_key, country_code):
    """Obtain and parse a quandl GDP dataset in Pandas DataFrame format
    Quandl returns dataset in JSON format, where data is stored as a 
    list of lists in response['dataset']['data'], and column headers
    stored in response['dataset']['column_names'].

    Args:
        api_key: Quandl API key
        country_code: Three letter code to represent country

    Returns:
        df: Pandas DataFrame of a Quandl dataset
    """
    base_url = "https://www.quandl.com/api/v3/datasets/"
    url_suffix = ".json?api_key="

    # Compose the Quandl API dataset code to get GDP per capita
    # (constant 2000 US$) dataset
    gdp_code = "WWDI/" + country_code + "_NY_GDP_PCAP_KD"

    # Parse the JSON response from Quandl API
    # Some countries might be missing, so we need error handling code
    try:
        u = urlopen(base_url + gdp_code + url_suffix + api_key)
    except urllib.error.URLError as e:
        print(gdp_code,e)
        return None

    response = json.loads(u.read().decode('utf-8'))

    # Format the response as Pandas Dataframe
    df = pd.DataFrame(response['dataset']['data'], columns=response['dataset']['column_names'])

    # Add a new country code column
    df['country'] = country_code

    return df

api_key = "INSERT YOUR KEY HERE"
quandl_dfs = []

# Loop through all unique country code values in the BigMac index DataFrame
for country_code in bigmac_df.country.unique():
    # Fetch the GDP dataset for the corresponding country 
    df = get_gdp_dataset(api_key, country_code)

    # Skip if the response is empty
    if df is None:
        continue

    # Store in a list DataFrames
    quandl_dfs.append(df)

    # Prevents exceeding the API speed limit
    time.sleep(2)

# Concatenate the list of DataFrames into a single one 
gdp_df = pd.concat(quandl_dfs)
gdp_df.head()

预期输出:

WWDI/EUR_NY_GDP_PCAP_KD HTTP Error 404: Not Found
WWDI/SIN_NY_GDP_PCAP_KD HTTP Error 404: Not Found
WWDI/ROC_NY_GDP_PCAP_KD HTTP Error 404: Not Found
WWDI/UAE_NY_GDP_PCAP_KD HTTP Error 404: Not Found
日期国家
02015-12-3110501.660269ARG
12014-12-3110334.780146ARG
22013-12-3110711.229530ARG
32012-12-3110558.265365ARG
42011-12-3110780.342508ARG

我们可以看到,人均 GDP 数据集在四个地理位置上不可用,但我们现在可以忽略这一点。

接下来,我们将使用pandas.merge()合并包含大麦指数和人均 GDP 的两个 DataFrame。WWDI 人均 GDP 数据集的最新记录是在 2015 年底收集的,所以我们将其与同年对应的大麦指数数据集配对。

对于熟悉 SQL 语言的用户,pandas.merge()支持四种模式,即左连接、右连接、内连接和外连接。由于我们只关心两个 DataFrame 中都有匹配国家的行,所以我们将选择内连接:

merged_df = pd.merge(bigmac_df[(bigmac_df.Date == "2015-01-31")], gdp_df[(gdp_df.Date == "2015-12-31")], how='inner', on='country')
merged_df.head()
Date_xlocal_pricedollar_exdollar_pricedollar_pppdollar_valuationdollar_adj_valuationeuro_adj_valuationsterling_adj_valuationyen_adj_valuationyuan_adj_valuationcountryDate_yValue
02015-01-3128.008.6100003.2520335.845511-32.1078810.540242-0.804495-2.4946834.39056.01183ARG2015-12-3110501.660269
12015-01-315.301.2272204.3187051.106472-9.839144-17.8995-18.9976-20.37789.74234-13.4315AUS2015-12-3154688.445933
22015-01-3113.502.5927505.2068272.8183728.70201968.455566.202463.3705125.17277.6231BRA2015-12-3111211.891104
32015-01-312.890.6615944.3682350.603340-8.8051153.112571.73343037.82898.72415GBR2015-12-3141182.619517
42015-01-315.701.2285504.6396161.189979-3.139545-2.34134-3.64753-5.2892830.53872.97343CAN2015-12-3150108.065004

Seaborn 中的散点图

散点图是科学和商业世界中最常见的图表之一。它尤其适用于显示两个变量之间的关系。虽然我们可以简单地使用matplotlib.pyplot.scatter来绘制散点图,但我们也可以使用 Seaborn 来构建具有更多高级功能的类似图表。

seaborn.regplot()seaborn.lmplot()这两个函数以散点图的形式显示线性关系、回归线,以及回归线周围的 95% 置信区间。两者的主要区别在于,lmplot()结合了regplot()FacetGrid,使我们能够创建带有颜色编码或分面散点图,显示三个或更多变量对之间的交互作用。我们将在本章和下一章展示lmplot()的使用。

seaborn.regplot()的最简单形式支持 numpy 数组、pandas Series 或 pandas DataFrame 作为输入。可以通过指定fit_reg=False来移除回归线和置信区间。

我们将研究一个假设,即在较贫穷的国家巨无霸更便宜,反之亦然,并检查巨无霸指数与人均 GDP 之间是否存在相关性:

import seaborn as sns
import matplotlib.pyplot as plt

# seaborn.regplot() returns matplotlib.Axes object
ax = sns.regplot(x="Value", y="dollar_price", data=merged_df, fit_reg=False)
ax.set_xlabel("GDP per capita (constant 2000 US$)")
ax.set_ylabel("BigMac index (US$)")

plt.show()

预期的输出:

到目前为止一切顺利!看起来巨无霸指数与人均 GDP 正相关。让我们重新启用回归线,并标注一些巨无霸指数值极端的国家:

ax = sns.regplot(x="Value", y="dollar_price", data=merged_df)
ax.set_xlabel("GDP per capita (constant 2000 US$)")
ax.set_ylabel("BigMac index (US$)")

# Label the country code for those who demonstrate extreme BigMac index
for row in merged_df.itertuples():
    if row.dollar_price >= 5 or row.dollar_price <= 2:
        ax.text(row.Value,row.dollar_price+0.1,row.country)

plt.show()

这是预期的输出:

我们可以看到,许多国家的数据都落在回归线的置信区间内。根据每个国家的人均 GDP 水平,线性回归模型预测了相应的巨无霸指数。如果实际指数偏离回归模型,则货币价值可能出现低估或高估的迹象。

通过标注那些显示极高或极低值的国家,我们可以清晰地看到,即使考虑到 GDP 差异,巴西和瑞士的巨无霸平均价格被高估,而印度、俄罗斯和乌克兰的价格则被低估。

由于 Seaborn 并不是一个用于统计分析的包,我们需要依赖其他包,如scipy.statsstatsmodels,来获得回归模型的参数。在下一个示例中,我们将从回归模型中获取slopeintercept参数,并为高于或低于回归线的点应用不同的颜色:

from scipy.stats import linregress

ax = sns.regplot(x="Value", y="dollar_price", data=merged_df)
ax.set_xlabel("GDP per capita (constant 2000 US$)")
ax.set_ylabel("BigMac index (US$)")

# Calculate linear regression parameters
slope, intercept, r_value, p_value, std_err = linregress(merged_df.Value, merged_df.dollar_price)

colors = []
for row in merged_df.itertuples():
    if row.dollar_price > row.Value * slope + intercept:
        # Color markers as darkred if they are above the regression line
        color = "darkred"
    else:
        # Color markers as darkblue if they are below the regression line
        color = "darkblue"

    # Label the country code for those who demonstrate extreme BigMac index
    if row.dollar_price >= 5 or row.dollar_price <= 2:
        ax.text(row.Value,row.dollar_price+0.1,row.country)

    # Highlight the marker that corresponds to China
    if row.country == "CHN":
        t = ax.text(row.Value,row.dollar_price+0.1,row.country)
        color = "yellow"

    colors.append(color)

# Overlay another scatter plot on top with marker-specific color
ax.scatter(merged_df.Value, merged_df.dollar_price, c=colors)

# Label the r squared value and p value of the linear regression model.
# transform=ax.transAxes indicates that the coordinates are given relative
# to the axes bounding box, with 0,0 being the lower left of the axes
# and 1,1 the upper right.
ax.text(0.1, 0.9, "$r²={0:.3f}, p={1:.3e}$".format(r_value ** 2, p_value), transform=ax.transAxes)

plt.show()

与普遍观点相反,2015 年中国的货币似乎并没有显著低估,因为其标记完全落在回归线的 95%置信区间内。

为了更好地展示数值的分布,我们可以通过seaborn.jointplot()xy值的直方图与散点图结合起来:

# seaborn.jointplot() returns a seaborn.JointGrid object
g = sns.jointplot(x="Value", y="dollar_price", data=merged_df)

# Provide custom axes labels through accessing the underlying axes object
# We can get matplotlib.axes.Axes of the scatter plot by calling g.ax_joint
g.ax_joint.set_xlabel("GDP per capita (constant 2000 US$)")
g.ax_joint.set_ylabel("BigMac index (US$)")

# Set the title and adjust the margin
g.fig.suptitle("Relationship between GDP per capita and BigMac Index")
g.fig.subplots_adjust(top=0.9)

plt.show()

通过在jointplot中额外指定kind参数为regresidhexkde,我们可以迅速将图表类型分别更改为回归图、残差图、六边形图或 KDE 轮廓图。

在此给出一个重要声明:根据我们手头的数据,现在下结论关于货币估值仍然为时过早!劳动力成本、租金、原材料成本和税收等不同的商业因素都会对巨无霸的定价模型产生影响,但这些内容超出了本书的范围。

可视化分类数据

在本章的最后,我们来整合一下到目前为止我们处理过的所有数据集。还记得在本章开头我们简要介绍过三种人口结构类别(即收缩型、稳定型和扩展型)吗?

在本节中,我们将实现一个简单的算法,将人口分类为三种类别之一。之后,我们将探索不同的可视化分类数据的技术。

在线上,大多数参考文献只讨论了人口金字塔的可视化分类(例如,www.populationeducation.org/content/what-are-different-types-population-pyramids)。确实存在基于聚类的方法(例如,Korenjak-Cˇ erne, Kejžar, Batagelj (2008)。人口金字塔的聚类。Informatica. 32.),但是迄今为止,人口类别的数学定义很少被讨论。我们将在下一个示例中构建一个基于“0-4”和“50-54”年龄组之间人口比例的简单分类器:

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Select total population for each country in 2015
current_population = population_df[(population_df.Time == 2015) &
                                   (population_df.Sex == 'Both')]

# A list for storing the population type for each country
pop_type_list = []

# Look through each country in the BigMac index dataset
for country in merged_df.country.unique():
    # Make sure the country also exist in the GDP per capita dataset
    if not country in current_population.country.values:
       continue

    # Calculate the ratio of population between "0-4" and "50-54"
    # age groups
    young = current_population[(current_population.country == country) &
                               (current_population.AgeGrp == "0-4")].Value

    midage = current_population[(current_population.country == country) &
                                (current_population.AgeGrp == "50-54")].Value

    ratio = float(young) / float(midage)

    # Classify the populations based on arbitrary ratio thresholds
    if ratio < 0.8:
        pop_type = "constrictive"
    elif ratio < 1.2 and ratio >= 0.8:
        pop_type = "stable"
    else:
        pop_type = "expansive"

    pop_type_list.append([country, ratio, pop_type])

# Convert the list to Pandas DataFrame
pop_type_df = pd.DataFrame(pop_type_list, columns=['country','ratio','population type'])

# Merge the BigMac index DataFrame with population type DataFrame
merged_df2 = pd.merge(merged_df, pop_type_df, how='inner', on='country')
merged_df2.head()

期望的输出如下:

Date_xlocal_pricedollar_exdollar_pricedollar_pppdollar_valuationdollar_adj_valuationeuro_adj_valuationsterling_adj_valuationyen_adj_valuationyuan_adj_valuationcountryDate_yValueratiopopulation type
02015-01-3128.008.6100003.2520335.845511-32.1078810.540242-0.804495-2.4946834.39056.01183ARG2015-12-3110501.6602691.695835扩张
12015-01-315.301.2272204.3187051.106472-9.839144-17.8995-18.9976-20.37789.74234-13.4315AUS2015-12-3154688.4459330.961301稳定
22015-01-3113.502.5927505.2068272.8183728.70201968.455566.202463.3705125.17277.6231BRA2015-12-3111211.8911041.217728扩张
32015-01-312.890.6615944.3682350.603340-8.8051153.112571.73343037.82898.72415GBR2015-12-3141182.6195170.872431稳定
42015-01-315.701.2285504.6396161.189979-3.139545-2.34134-3.64753-5.2892830.53872.97343CAN2015-12-3150108.0650040.690253收缩

分类散点图

通过将数据分类,我们可以检查不同人口类型是否展示出不同的Big Mac 指数分布。

我们可以使用seaborn.lmplot来解析数据并创建一个分类散点图。回顾一下,lmplot()regplot()FacetGrid结合,用于在分面网格或颜色编码的散点图中可视化三对或更多的变量。在接下来的示例中,我们将把人口类型变量分配给lmplot()colrowhue参数。让我们来看一下结果:

# Horizontal faceted grids (col="population type")
g = sns.lmplot(x="Value", y="dollar_price", col="population type", data=merged_df2)
g.set_xlabels("GDP per capita (constant 2000 US$)")
g.set_ylabels("BigMac index (US$)")

plt.show()

上述代码片段生成了:

另外,如果我们在代码片段中将row="population type"替换为col="population type",将会生成以下图表:

最后,通过将col="population type"更改为hue="population type",将生成一个颜色编码的分类散点图:

实际上,colrowhue 可以结合使用,创建丰富的分面网格。当数据中存在多个维度时,这特别有用。关于分面网格的更多讨论将在下一章中介绍。

条形图和蜂群图

条形图本质上是一个散点图,其中 x- 轴表示一个分类变量。条形图的典型用法是在每个数据点上应用一个小的随机抖动值,使得数据点之间的间隔更加清晰:

# Strip plot with jitter value
ax = sns.stripplot(x="population type", y="dollar_price", data=merged_df2, jitter=True)
ax.set_xlabel("Population type")
ax.set_ylabel("BigMac index (US$)")

plt.show()

蜂群图与条形图非常相似,然而点的位置会自动调整以避免重叠,即使没有应用抖动值。这些图像像蜜蜂围绕某个位置飞舞,因此也被称为蜂群图。

如果我们将前面代码片段中的 Seaborn 函数调用从sns.stripplot更改为sns.swarmplot,结果将会变成这样:

箱型图和小提琴图

条形图和蜂群图的数据显示方式使得比较变得困难。假设你想找出稳定型或收缩型人口类型的中位数 BigMac 指数值哪个更高。你能基于前面两个示例图进行判断吗?

你可能会认为收缩型组的中位数值较高,因为它有更高的最大数据点,但实际上,稳定型组的中位数值更高。

是否有更好的图表类型来比较分类数据的分布?来看看这个!我们来尝试一下箱型图:

# Box plot
ax = sns.boxplot(x="population type", y="dollar_price", data=merged_df2)
ax.set_xlabel("Population type")
ax.set_ylabel("BigMac index (US$)")

plt.show()

预期输出:

箱型图的框表示数据的四分位数,中心线表示中位数值,胡须表示数据的完整范围。那些偏离上四分位数或下四分位数超过 1.5 倍四分位距的数据点被视为异常值,并以飞点形式显示。

小提琴图将我们数据的核密度估计与箱型图结合在一起。箱型图和小提琴图都显示了中位数和四分位数范围,但小提琴图更进一步,通过显示适合数据的完整估计概率分布来展示更多信息。因此,我们可以判断数据中是否存在峰值,并且还可以比较它们的相对幅度。

如果我们将代码片段中的 Seaborn 函数调用从sns.boxplot更改为sns.violinplot,结果将会像这样:

我们还可以将条形图或蜂群图叠加在箱型图或蜂群图之上,从而兼得两者的优点。这里是一个示例代码:

# Prepare a box plot
ax = sns.boxplot(x="population type", y="dollar_price", data=merged_df2)

# Overlay a swarm plot on top of the same axes
sns.swarmplot(x="population type", y="dollar_price", data=merged_df2, color="w", ax=ax)
ax.set_xlabel("Population type")
ax.set_ylabel("BigMac index (US$)")

plt.show()

预期输出:

控制 Seaborn 图形美学

虽然我们可以使用 Matplotlib 自定义图形的美学,但 Seaborn 提供了几个方便的函数来简化定制。如果您使用的是 Seaborn 0.8 或更高版本,必须在导入后显式调用 seaborn.set(),以启用 Seaborn 默认的美观主题。在较早版本中,seaborn.set() 在导入时隐式调用。

预设主题

Seaborn 中的五种默认主题,即 darkgrid、whitegrid、dark、white 和 ticks,可以通过调用 seaborn.set_style() 函数来选择。

必须在发出任何绘图命令之前调用 seaborn.set_style(),以便正确显示主题。

从图形中移除脊柱

要删除或调整脊柱的位置,可以使用 seaborn.despine 函数。默认情况下,图形的顶部和右侧脊柱被移除,可以通过设置 left=Truebottom=True 来移除其他脊柱。通过使用偏移和修剪参数,还可以调整脊柱的位置。

seaborn.despine 必须在调用 Seaborn 绘图函数之后调用。

这里是 seaborn.despine 函数中不同参数组合的结果:

改变图形的大小

要控制图形的高度和宽度,我们也可以依赖 matplotlib.pyplot.figure(figsize=(WIDTH,HEIGHT))

在此示例中,我们将把之前的直方图示例的大小更改为宽 8 英寸,高 4 英寸:

import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats

# Note: Codes related to data preparation are skipped for brevity
# Reset all previous theme settings to defaults
sns.set()

# Change the size to 8 inches wide and 4 inches tall
fig = plt.figure(figsize=(8,4))

# We are going to reuse current_bigmac that was generated earlier
# Plot the histogram
ax = sns.distplot(current_bigmac.dollar_price)
plt.show()

以下是前述代码的预期输出:

Seaborn 还提供了 seaborn.set_context() 函数来控制图表元素的比例。有四种预设的上下文,分别是 paper、notebook、talk 和 poster,它们按大小递增排列。默认情况下,选择的是 Notebook 风格。以下是将上下文设置为 poster 的示例:

# Reset all previous theme settings to defaults
sns.set()

# Set Seaborn context to poster
sns.set_context("poster")

# We are going to reuse current_bigmac that was generated earlier
# Plot the histogram
ax = sns.distplot(current_bigmac.dollar_price)
plt.show()

以下是前述代码的预期输出:

微调图形的样式

Seaborn 图形中的几乎每个元素都可以通过 seaborn.set 进一步自定义。以下是支持的参数列表:

  • context:预设的上下文之一——{paper, notebook, talk, poster}。

  • style:轴样式之一——{darkgrid, whitegrid, dark, white, ticks}。

  • palette:在 seaborn.pydata.org/generated/seaborn.color_palette.html#seaborn.color_palette 中定义的色板之一。

  • font:支持的字体或字体家族名称,如 serif、sans-serif、cursive、fantasy 或 monospace。欲了解更多信息,请访问 matplotlib.org/api/font_manager_api.html

  • font_scale:字体元素的独立缩放因子。

  • rc:额外 rc 参数映射的字典。要获取所有 rc 参数的完整列表,可以运行 seaborn.axes_style()

当前使用的预设上下文或轴样式中未定义的 RC 参数无法被覆盖。有关 seaborn.set() 的更多信息,请访问 seaborn.pydata.org/generated/seaborn.set.html#seaborn.set

让我们尝试增加字体比例、增加 KDE 图的线宽,并改变几个图表元素的颜色:

# Get a dictionary of all parameters that can be changed
sns.axes_style()

"""
Returns
{'axes.axisbelow': True,
 'axes.edgecolor': '.8',
 'axes.facecolor': 'white',
 'axes.grid': True,
 'axes.labelcolor': '.15',
 'axes.linewidth': 1.0,
 'figure.facecolor': 'white',
 'font.family': [u'sans-serif'],
 'font.sans-serif': [u'Arial',
 u'DejaVu Sans',
 u'Liberation Sans',
 u'Bitstream Vera Sans',
 u'sans-serif'],
 'grid.color': '.8',
 'grid.linestyle': u'-',
 'image.cmap': u'rocket',
 'legend.frameon': False,
 'legend.numpoints': 1,
 'legend.scatterpoints': 1,
 'lines.solid_capstyle': u'round',
 'text.color': '.15',
 'xtick.color': '.15',
 'xtick.direction': u'out',
 'xtick.major.size': 0.0,
 'xtick.minor.size': 0.0,
 'ytick.color': '.15',
 'ytick.direction': u'out',
 'ytick.major.size': 0.0,
 'ytick.minor.size': 0.0}
 """

# Increase the font scale to 2, change the grid color to light grey, 
# and axes label color to dark blue
sns.set(context="notebook", 
 style="darkgrid",
 font_scale=2, 
 rc={'grid.color': '0.6', 
 'axes.labelcolor':'darkblue',
 "lines.linewidth": 2.5})

# Plot the histogram
ax = sns.distplot(current_bigmac.dollar_price)
plt.show()

该代码生成以下直方图:

到目前为止,我们只介绍了控制全局美学的函数。如果我们只想改变某个特定图表的样式呢?

幸运的是,大多数 Seaborn 绘图函数都提供了专门的参数来定制样式。这也意味着并没有一个适用于所有 Seaborn 绘图函数的通用样式教程。然而,我们可以仔细查看这段 seaborn.distplot() 的代码示例,以了解大概:

# Note: Codes related to data preparation and imports are skipped for
# brevity
# Reset the style
sns.set(context="notebook", style="darkgrid")

# Plot the histogram with custom style
ax = sns.distplot(current_bigmac.dollar_price,
                 kde_kws={"color": "g", 
                          "linewidth": 3, 
                          "label": "KDE"},
                 hist_kws={"histtype": "step", 
                           "alpha": 1, 
                           "color": "k",
                           "label": "histogram"})

plt.show()

预期结果:

一些 Seaborn 函数支持更加直接的美学定制方法。例如,seaborn.barplot 可以通过关键字参数,如 facecoloredgecolorecolorlinewidth,传递给底层的 matplotlib.pyplot.bar 函数:

# Note: Codes related to data preparation and imports are skipped
# for brevity
# Population Bar chart 
sns.barplot(x="AgeGrp",y="Value", hue="Sex",
            linewidth=2, edgecolor="w",
            data = current_population)

# Use Matplotlib functions to label axes rotate tick labels
ax = plt.gca()
ax.set(xlabel="Age Group", ylabel="Population (thousands)")
ax.set_xticklabels(ax.xaxis.get_majorticklabels(), rotation=45)
plt.title("Population Barchart (USA)")

# Show the figure
plt.show()

更多关于颜色的内容

颜色可能是图表风格中最重要的方面,因此它值得单独设立一个小节。有许多优秀的资源讨论了选择颜色在可视化中的原则(例如,betterfigures.org/2015/06/23/picking-a-colour-scale-for-scientific-graphics/earthobservatory.nasa.gov/blogs/elegantfigures/2013/08/05/subtleties-of-color-part-1-of-6/)。官方的 Matplotlib 文档也包含了关于颜色映射的良好概述(matplotlib.org/users/colormaps.html)。

有效使用颜色可以增加足够的对比度,使某些内容突出并吸引观众的注意力。颜色还可以唤起情感;例如,红色通常与重要或激情相关,而绿色通常与自然或稳定相关。如果你想通过图表传递一个故事,务必尝试使用合适的配色方案。据估计,8%的男性和 0.5%的女性患有红绿色盲,因此我们在选择颜色时也需要考虑到这些人群。

配色方案和调色板

Seaborn 提供了三种常见的颜色调色板——定性、发散性和连续性:

  • 定性调色板最适用于具有离散级别或名义/分类数据的数据。可以通过向seaborn.color_palette提供 Matplotlib 颜色列表来创建自定义定性调色板。

  • 分歧调色板用于突出图形中的低值和高值,具有中性色的中点。可以通过将两个色调值以及可选的亮度和饱和度值传递给seaborn.diverging_palette函数来创建自定义分歧调色板。

  • 顺序调色板通常用于量化数据,这些数据在低到高之间连续变化。

    可以通过向seaborn.light_paletteseaborn.dark_palette提供单一的 Matplotlib 颜色,创建自定义顺序调色板,这将生成一个从浅色或深色的去饱和值逐渐变化到种子颜色的调色板。

在下一个示例中,我们将绘制最常用的定性、分歧和顺序调色板,以及一些自定义调色板:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

def palplot(pal, ax):
    """Plot the values in a color palette as a horizontal array.
    Adapted from seaborn.palplot

    Args:
        p : seaborn color palette
        ax : axes to plot the color palette
    """
    n = len(pal) 
    ax.imshow(np.arange(n).reshape(1, n),
              cmap=ListedColormap(list(pal)),
              interpolation="nearest", aspect="auto")
    ax.set_xticks(np.arange(n) - .5)
    ax.set_yticks([-.5, .5])
    ax.set_xticklabels([])
    ax.set_yticklabels([])

palettes = {"qualitative": ["deep", "pastel", "bright", "dark", 
                            "colorblind", "Accent", "Paired", 
                            "Set1", "Set2", "Set3", "Pastel1", 
                            "Pastel2", "Dark2"],
            "diverging": ["BrBG", "PiYG", "PRGn", "PuOr", "RdBu", 
                          "RdBu_r", "RdGy", "RdGy_r", "RdYlGn", 
                          "coolwarm"],
            "sequential": ["husl", "Greys", "Blues", "BuGn_r", 
                           "GnBu_d", "plasma", "viridis","cubehelix"]}

#Reset to default Seaborn style
sns.set()

# Create one subplot per palette, the x-axis is shared
fig, axarr = plt.subplots(13, 3, sharex=True, figsize=(12,11))

# Plot 9 color blocks for each palette
for i, palette_type in enumerate(palettes.keys()):
    for j, palette in enumerate(palettes[palette_type]):
        pal = sns.color_palette(palettes[palette_type][j], 9)
        palplot(pal, axarr[j,i])
        axarr[j,i].set_xlabel(palettes[palette_type][j])

# Plot a few more custom diverging palette
custom_diverging_palette = [
 sns.diverging_palette(220, 20, n=9),
 sns.diverging_palette(10, 220, sep=80, n=9),
 sns.diverging_palette(145, 280, s=85, l=25, n=9)
]

for i, palette in enumerate(custom_diverging_palette):
    palplot(palette, axarr[len(palettes["diverging"])+i,1])
    axarr[len(palettes["diverging"])+i,1].set_xlabel("custom diverging 
    {}".format(i+1))

# Plot a few more custom sequential palette
other_custom_palette = [
 sns.light_palette("green", 9),
 sns.light_palette("green", 9, reverse=True),
 sns.dark_palette("navy", 9),
 sns.dark_palette("navy", 9, reverse=True),
 sns.color_palette(["#49a17a","#4aae82","#4eb98a","#55c091","#c99b5f",
 "#cbb761","#c5cc62","#accd64","#94ce65"])
]

for i, palette in enumerate(other_custom_palette):
    palplot(palette, axarr[len(palettes["sequential"])+i,2])
    axarr[len(palettes["sequential"])+i,2].set_xlabel("custom sequential
    {}".format(i+1))

# Reduce unnecessary margin space
plt.tight_layout()

# Show the plot
plt.show()

预期的输出如下:

要更改 Seaborn 图形的配色方案,我们可以使用大多数 Seaborn 函数中提供的colorpalette参数。color参数支持应用于所有元素的单一颜色;而palette参数支持一系列颜色,用于区分hue变量的不同水平。

一些 Seaborn 函数仅支持color参数(例如,分布图),而其他函数可以同时支持colorpalette(例如,条形图和箱型图)。读者可以参考官方文档查看哪些参数是受支持的。

以下三个代码片段演示了如何在分布图(dist plot)、条形图(bar plot)和箱型图(box plot)中使用colorpalette参数:

# Note: Codes related to data preparation and imports are skipped
# for brevity
# Change the color of histogram and KDE line to darkred
ax = sns.distplot(current_bigmac.dollar_price, color="darkred")
plt.show()

current_population = population_df[(population_df.Location == 'United States of America') & 
                                   (population_df.Time == 2017) &
                                   (population_df.Sex != 'Both')]
# Change the color palette of the bar chart to Paired 
sns.barplot(x="AgeGrp",y="Value", hue="Sex", palette="Paired", data = current_population)
# Rotate tick labels by 30 degree
plt.setp(plt.gca().get_xticklabels(), rotation=30, horizontalalignment='right') 
plt.show()

# Note: Codes related to data preparation and imports are skipped
# for brevity
# Change the color palette of the bar chart to Set2 from color
# brewer library
ax = sns.boxplot(x="population type", y="dollar_price", palette="Set2", data=merged_df2)
plt.show()

总结

你刚刚学会了如何使用多功能的 Pandas 包解析 CSV 或 JSON 格式的在线数据。你进一步学习了如何筛选、子集化、合并和处理数据,以获取见解。现在,你已经掌握了可视化时间序列、单变量、双变量和分类数据的知识。本章最后介绍了若干有用的技巧,以自定义图形美学,从而有效地讲述故事。

呼!我们刚刚完成了一个长篇章节,赶紧去吃个汉堡,休息一下,放松一下吧。

第八章:可视化多变量数据

当我们拥有包含许多变量的大数据时,第七章中 可视化在线数据的图表类型可能不再是有效的数据可视化方式。我们可能会尝试在单一图表中尽可能多地压缩变量,但过度拥挤或杂乱的细节很快就会超出人类的视觉感知能力。

本章旨在介绍多变量数据可视化技术;这些技术使我们能够更好地理解数据的分布以及变量之间的关系。以下是本章的概述:

  • 从 Quandl 获取日终(EOD)股票数据

  • 二维分面图:

    • Seaborn 中的因子图

    • Seaborn 中的分面网格

    • Seaborn 中的配对图

  • 其他二维多变量图:

    • Seaborn 中的热力图

    • matplotlib.finance 中的蜡烛图:

      • 可视化各种股市指标
    • 构建综合股票图表

  • 三维图表:

    • 散点图

    • 条形图

    • 使用 Matplotlib 3D 的注意事项

首先,我们将讨论分面图,这是一种用于可视化多变量数据的分而治之的方法。这种方法的要义是将输入数据切分成不同的分面,每个可视化面板中只展示少数几个属性。通过在减少的子集上查看变量,这样可以减少视觉上的杂乱。有时,在二维图表中找到合适的方式来表示多变量数据是困难的。因此,我们还将介绍 Matplotlib 中的三维绘图函数。

本章使用的数据来自 Quandl 的日终(EOD)股票数据库。首先让我们从 Quandl 获取数据。

从 Quandl 获取日终(EOD)股票数据

由于我们将广泛讨论股票数据,请注意,我们不保证所呈现内容的准确性、完整性或有效性;也不对可能发生的任何错误或遗漏负责。数据、可视化和分析仅以“原样”方式提供,仅用于教育目的,不附带任何形式的声明、保证或条件。因此,出版商和作者不对您使用内容承担任何责任。需要注意的是,过去的股票表现不能预测未来的表现。读者还应意识到股票投资的风险,并且不应根据本章内容做出任何投资决策。此外,建议读者在做出投资决策之前,对个别股票进行独立研究。

我们将调整第七章《可视化在线数据》中的 Quandl JSON API 代码,以便从 Quandl 获取 EOD 股票数据。我们将获取 2017 年 1 月 1 日至 2017 年 6 月 30 日之间六只股票代码的历史股市数据:苹果公司(EOD/AAPL)、宝洁公司(EOD/PG)、强生公司(EOD/JNJ)、埃克森美孚公司(EOD/XOM)、国际商业机器公司(EOD/IBM)和微软公司(EOD/MSFT)。同样,我们将使用默认的urllibjson模块来处理 Quandl API 调用,接着将数据转换为 Pandas DataFrame:

from urllib.request import urlopen
import json
import pandas as pd

def get_quandl_dataset(api_key, code, start_date, end_date):
    """Obtain and parse a quandl dataset in Pandas DataFrame format

    Quandl returns dataset in JSON format, where data is stored as a 
    list of lists in response['dataset']['data'], and column headers
    stored in response['dataset']['column_names'].

    Args:
        api_key: Quandl API key
        code: Quandl dataset code

    Returns:
        df: Pandas DataFrame of a Quandl dataset

    """
    base_url = "https://www.quandl.com/api/v3/datasets/"
    url_suffix = ".json?api_key="
    date = "&start_date={}&end_date={}".format(start_date, end_date)

    # Fetch the JSON response 
    u = urlopen(base_url + code + url_suffix + api_key + date)
    response = json.loads(u.read().decode('utf-8'))

    # Format the response as Pandas Dataframe
    df = pd.DataFrame(response['dataset']['data'], columns=response['dataset']
    ['column_names'])

    return df

# Input your own API key here
api_key = "INSERT YOUR KEY HERE"

# Quandl code for six US companies
codes = ["EOD/AAPL", "EOD/PG", "EOD/JNJ", "EOD/XOM", "EOD/IBM", "EOD/MSFT"]
start_date = "2017-01-01"
end_date = "2017-06-30"

dfs = []
# Get the DataFrame that contains the EOD data for each company
for code in codes:
    df = get_quandl_dataset(api_key, code, start_date, end_date)
    df["Company"] = code[4:]
    dfs.append(df)

# Concatenate all dataframes into a single one
stock_df = pd.concat(dfs)

# Sort by ascending order of Company then Date
stock_df = stock_df.sort_values(["Company","Date"])
stock_df.head()
-日期开盘最高最低收盘成交量分红拆股调整后开盘调整后最高调整后最低调整后收盘调整后成交量公司
1242017-01-03115.80116.3300114.76116.1528781865.00.01.0114.833750115.359328113.802428115.18083028781865.0AAPL
1232017-01-04115.85116.5100115.75116.0221118116.00.01.0114.883333115.537826114.784167115.05191421118116.0AAPL
1222017-01-05115.92116.8642115.81116.6122193587.00.01.0114.952749115.889070114.843667115.63699122193587.0AAPL
1212017-01-06116.78118.1600116.47117.9131751900.00.01.0115.805573117.174058115.498159116.92614431751900.0AAPL
1202017-01-09117.95119.4300117.94118.9933561948.00.01.0116.965810118.433461116.955894117.99713233561948.0AAPL

数据框包含每只股票的开盘价、最高价、最低价和收盘价OHLC)。此外,还提供了额外信息;例如,分红列反映了当天的现金分红值。拆股列显示当天如果发生了拆股事件,新的股票与旧股票的比例。调整后的价格考虑了分配或公司行为引起的价格波动,假设所有这些行动已被再投资到当前股票中。有关这些列的更多信息,请查阅 Quandl 文档页面。

按行业分组公司

正如你可能注意到的,三家公司(AAPL、IBM 和 MSFT)是科技公司,而剩余三家公司则不是。股市分析师通常根据行业将公司分组,以便深入了解。让我们尝试按行业对公司进行标记:

# Classify companies by industry
tech_companies = set(["AAPL","IBM","MSFT"])
stock_df['Industry'] = ["Tech" if c in tech_companies else "Others" for c in stock_df['Company']]

转换日期为支持的格式

stock_df中的Date列以一系列 Python 字符串的形式记录。尽管 Seaborn 可以在某些函数中使用字符串格式的日期,但 Matplotlib 则不能。为了使日期更适合数据处理和可视化,我们需要将这些值转换为 Matplotlib 支持的浮动数字:

from matplotlib.dates import date2num

# Convert Date column from string to Python datetime object,
# then to float number that is supported by Matplotlib.
stock_df["Datetime"] = date2num(pd.to_datetime(stock_df["Date"], format="%Y-%m-%d").tolist())

获取收盘价的百分比变化

接下来,我们想要计算相对于前一天收盘价的收盘价变化。Pandas 中的pct_change()函数使得这个任务变得非常简单:

import numpy as np

# Calculate percentage change versus the previous close
stock_df["Close_change"] = stock_df["Close"].pct_change()
# Since the DataFrame contain multiple companies' stock data, 
# the first record in the "Close_change" should be changed to
# NaN in order to prevent referencing the price of incorrect company.
stock_df.loc[stock_df["Date"]=="2017-01-03", "Close_change"] = np.NaN
stock_df.head()

二维分面图

我们将介绍三种创建分面图的主要方法:seaborn.factorplot()seaborn.FacetGrid()seaborn.pairplot()。在上一章当我们讨论seaborn.lmplot()时,你可能已经见过一些分面图。实际上,seaborn.lmplot()函数将seaborn.regplot()seaborn.FacetGrid()结合在一起,并且数据子集的定义可以通过huecolrow参数进行调整。

我们将介绍三种创建分面图的主要方法:seaborn.factorplot()seaborn.FacetGrid()seaborn.pairplot()。这些函数在定义分面时与seaborn.lmplot()的工作方式非常相似。

Seaborn 中的因子图

seaborn.factorplot()的帮助下,我们可以通过调节kind参数,将类别点图、箱线图、小提琴图、条形图或条纹图绘制到seaborn.FacetGrid()上。factorplot的默认绘图类型是点图。与 Seaborn 中的其他绘图函数不同,后者支持多种输入数据格式,factorplot仅支持 pandas DataFrame 作为输入,而变量/列名可以作为字符串传递给xyhuecolrow

import seaborn as sns
import matplotlib.pyplot as plt

sns.set(style="ticks")

# Plot EOD stock closing price vs Date for each company.
# Color of plot elements is determined by company name (hue="Company"),
# plot panels are also arranged in columns accordingly (col="Company").
# The col_wrap parameter determines the number of panels per row (col_wrap=3).
g = sns.factorplot(x="Date", y="Close", 
                   hue="Company", col="Company", 
                   data=stock_df, col_wrap=3)

plt.show()

上面的图存在几个问题。

首先,纵横比(长度与高度之比)对于时间序列图来说稍显不理想。较宽的图形将使我们能够观察到在这一时间段内的微小变化。我们将通过调整aspect参数来解决这个问题。

其次,线条和点的粗细过大,从而遮盖了一些图中的细节。我们可以通过调整scale参数来减小这些视觉元素的大小。

最后,刻度线之间太近,且刻度标签重叠。绘图完成后,sns.factorplot()返回一个 FacetGrid,在代码中表示为g。我们可以通过调用FacetGrid对象中的相关函数进一步调整图形的美学,比如刻度位置和标签:

# Increase the aspect ratio and size of each panel
g = sns.factorplot(x="Date", y="Close", 
                   hue="Company", col="Company", 
                   data=stock_df,
                   col_wrap=3, size=3,
                   scale=0.5, aspect=1.5)

# Thinning of ticks (select 1 in 10)
locs, labels = plt.xticks()
g.set(xticks=locs[0::10], xticklabels=labels[0::10])

# Rotate the tick labels to prevent overlap
g.set_xticklabels(rotation=30)

# Reduce the white space between plots
g.fig.subplots_adjust(wspace=.1, hspace=.2)
plt.show()

# Create faceted plot separated by industry
g = sns.factorplot(x="Date", y="Close", 
                   hue="Company", col="Industry", 
                   data=stock_df, size=4, 
                   aspect=1.5, scale=0.5)

locs, labels = plt.xticks()
g.set(xticks=locs[0::10], xticklabels=labels[0::10])
g.set_xticklabels(rotation=30)
plt.show()

Seaborn 中的分面网格

到目前为止,我们已经提到过FacetGrid几次,但它到底是什么呢?

正如您所知,FacetGrid是一个用于对数据进行子集化和绘制绘图面板的引擎,由将变量分配给hue参数的行和列来确定。虽然我们可以使用lmplotfactorplot等包装函数轻松地在FacetGrid上搭建绘图,但更灵活的方法是从头开始构建 FacetGrid。为此,我们首先向FacetGrid对象提供一个 pandas DataFrame,并通过colrowhue参数指定布局网格的方式。然后,我们可以通过调用FacetGrid对象的map()函数为每个面板分配一个 Seaborn 或 Matplotlib 绘图函数:

# Create a FacetGrid
g = sns.FacetGrid(stock_df, col="Company", hue="Company",
                  size=3, aspect=2, col_wrap=2)

# Map the seaborn.distplot function to the panels,
# which shows a histogram of closing prices.
g.map(sns.distplot, "Close")

# Label the axes
g.set_axis_labels("Closing price (US Dollars)", "Density")

plt.show()

我们还可以向绘图函数提供关键字参数:

g = sns.FacetGrid(stock_df, col="Company", hue="Company",
                  size=3, aspect=2.2, col_wrap=2)

# We can supply extra kwargs to the plotting function.
# Let's turn off KDE line (kde=False), and plot raw 
# frequency of bins only (norm_hist=False).
# By setting rug=True, tick marks that denotes the
# density of data points will be shown in the bottom.
g.map(sns.distplot, "Close", kde=False, norm_hist=False, rug=True)

g.set_axis_labels("Closing price (US Dollars)", "Density")

plt.show()

FacetGrid不仅限于使用 Seaborn 绘图函数;让我们尝试将老式的Matplotlib.pyplot.plot()函数映射到FacetGrid上:

from matplotlib.dates import DateFormatter

g = sns.FacetGrid(stock_df, hue="Company", col="Industry",
                  size=4, aspect=1.5, col_wrap=2)

# plt.plot doesn't support string-formatted Date,
# so we need to use the Datetime column that we
# prepared earlier instead.
g.map(plt.plot, "Datetime", "Close", marker="o", markersize=3, linewidth=1)
g.add_legend()

# We can access individual axes through g.axes[column]
# or g.axes[row,column] if multiple rows are present.
# Let's adjust the tick formatter and rotate the tick labels
# in each axes.
for col in range(2):
    g.axes[col].xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
    plt.setp(g.axes[col].get_xticklabels(), rotation=30)

g.set_axis_labels("", "Closing price (US Dollars)")
plt.show()

Seaborn 中的 pair plot

对角线轴上将显示一系列直方图,以显示该列中变量的分布:

# Show a pairplot of three selected variables (vars=["Open", "Volume", "Close"])
g = sns.pairplot(stock_df, hue="Company", 
                 vars=["Open", "Volume", "Close"])

plt.show()

我们可以调整绘图的许多方面。在下一个示例中,我们将增加纵横比,将对角线上的绘图类型更改为 KDE 绘图,并使用关键字参数调整绘图的美学效果:

# Adjust the aesthetics of the plot
g = sns.pairplot(stock_df, hue="Company", 
                 aspect=1.5, diag_kind="kde", 
                 diag_kws=dict(shade=True),
                 plot_kws=dict(s=15, marker="+"),
                 vars=["Open", "Volume", "Close"])

plt.show()

与基于FacetGrid的其他绘图类似,我们可以定义要在每个面板中显示的变量。我们还可以手动定义对我们重要的比较,而不是通过设置x_varsy_vars参数进行全对全比较。如果需要更高的灵活性来定义比较组,也可以直接使用seaborn.PairGrid()

# Manually defining the comparisons that we are interested.
g = sns.pairplot(stock_df, hue="Company", aspect=1.5,
                 x_vars=["Open", "Volume"],
                 y_vars=["Close", "Close_change"])

plt.show()

其他二维多变量图

当我们需要可视化更多变量或样本时,FacetGrid、factor plot 和 pair plot 可能会占用大量空间。如果您希望最大化空间效率,则有两种特殊的绘图类型非常方便 - 热力图和蜡烛图。

Seaborn 中的热力图

热力图是显示大量数据的极其紧凑的方式。在金融世界中,色块编码可以让投资者快速了解哪些股票上涨或下跌。在科学世界中,热力图允许研究人员可视化成千上万基因的表达水平。

seaborn.heatmap()函数期望以 2D 列表、2D Numpy 数组或 pandas DataFrame 作为输入。如果提供了列表或数组,我们可以通过xticklabelsyticklabels分别提供列和行标签。另一方面,如果提供了 DataFrame,则将使用列标签和索引值分别标记列和行。

为了开始,我们将使用热图绘制六只股票的表现概览。我们将股票表现定义为与前一个收盘价相比的收盘价变化。这些信息在本章前面已经计算过(即 Close_change 列)。不幸的是,我们不能直接将整个 DataFrame 提供给 seaborn.heatmap(),因为它需要公司名称作为列,日期作为索引,收盘价变化作为数值。

如果你熟悉 Microsoft Excel,你可能有使用透视表的经验,这是总结特定变量水平或数值的强大技巧。pandas 也包含了类似的功能。以下代码片段使用了 Pandas.DataFrame.pivot() 函数来创建透视表:

stock_change = stock_df.pivot(index='Date', columns='Company', values='Close_change')
stock_change = stock_change.loc["2017-06-01":"2017-06-30"]
stock_change.head()
公司日期AAPLIBMJNJMSFTPGXOM
2017-06-010.0027490.0002620.0041330.0037230.0004540.002484
2017-06-020.014819-0.0040610.0100950.0236800.005220-0.014870
2017-06-05-0.0097780.0023680.0021530.0072460.0016930.007799
2017-06-060.003378-0.0002620.0036050.0033200.0006760.013605
2017-06-070.005957-0.009123-0.000611-0.001793-0.000338-0.003694

透视表创建完成后,我们可以继续绘制第一个热图:

ax = sns.heatmap(stock_change)
plt.show()

默认的热图实现并不够紧凑。当然,我们可以通过plt.figure(figsize=(width, height))来调整图形大小;我们还可以切换方形参数来创建方形的块。为了方便视觉识别,我们可以在块周围添加一条细边框。

根据美国股市的惯例,绿色表示价格上涨,红色表示价格下跌。因此,我们可以调整cmap参数来调整颜色图。然而,Matplotlib 和 Seaborn 都没有包含红绿颜色图,所以我们需要自己创建一个:

在第七章《可视化在线数据》末尾,我们简要介绍了创建自定义颜色图的函数。这里我们将使用seaborn.diverging_palette()来创建红绿颜色图,它要求我们为颜色图的负值和正值指定色调、饱和度和亮度(husl)。你还可以使用以下代码在 Jupyter Notebook 中启动交互式小部件,帮助选择颜色:

%matplotlib notebook

import seaborn as sns

sns.choose_diverging_palette(as_cmap=True)

# Create a new red-green color map using the husl color system
# h_neg and h_pos determines the hue of the extents of the color map.
# s determines the color saturation
# l determines the lightness
# sep determines the width of center point
# In addition, we need to set as_cmap=True as the cmap parameter of 
# sns.heatmap expects matplotlib colormap object.
rdgn = sns.diverging_palette(h_neg=10, h_pos=140, s=80, l=50,
                             sep=10, as_cmap=True)

# Change to square blocks (square=True), add a thin
# border (linewidths=.5), and change the color map
# to follow US stocks market convention (cmap="RdGn").
ax = sns.heatmap(stock_change, cmap=rdgn,
                 linewidths=.5, square=True)

# Prevent x axes label from being cropped
plt.tight_layout()
plt.show()

当颜色是唯一的区分因素时,可能很难分辨数值间的小差异。为每个颜色块添加文本注释可能有助于读者理解差异的大小:

fig = plt.figure(figsize=(6,8))

# Set annot=True to overlay the values.
# We can also assign python format string to fmt. 
# For example ".2%" refers to percentage values with
# two decimal points.
ax = sns.heatmap(stock_change, cmap=rdgn,
                 annot=True, fmt=".2%",
                 linewidths=.5, cbar=False)
plt.show()

matplotlib.finance 中的蜡烛图

正如您在本章的第一部分所看到的,我们的数据集包含每个交易日的开盘价、收盘价以及最高和最低价格。到目前为止,我们描述的任何图表都无法在单个图表中描述所有这些变量的趋势。

在金融界,蜡烛图几乎是描述股票、货币和商品在一段时间内价格变动的默认选择。每个蜡烛图由实体组成,描述开盘和收盘价,以及展示特定交易日最高和最低价格的延伸影线。如果收盘价高于开盘价,则蜡烛图通常为黑色。相反,如果收盘价低于开盘价,则为红色。交易员可以根据颜色的组合和蜡烛图实体的边界推断开盘和收盘价。

在以下示例中,我们将准备一个苹果公司在我们的 DataFrame 最近 50 个交易日的蜡烛图。我们还将应用刻度格式化程序来标记日期的刻度:

import matplotlib.pyplot as plt
from matplotlib.dates import date2num, WeekdayLocator, DayLocator, DateFormatter, MONDAY
from matplotlib.finance import candlestick_ohlc

# Extract stocks data for AAPL.
# candlestick_ohlc expects Date (in floating point number), Open, High, Low,
# Close columns only
# So we need to select the useful columns first using DataFrame.loc[]. Extra 
# columns can exist, 
# but they are ignored. Next we get the data for the last 50 trading only for 
# simplicity of plots.
candlestick_data = stock_df[stock_df["Company"]=="AAPL"]\
                       .loc[:, ["Datetime", "Open", "High", "Low", "Close",
                       "Volume"]]\
                       .iloc[-50:]

# Create a new Matplotlib figure
fig, ax = plt.subplots()

# Prepare a candlestick plot
candlestick_ohlc(ax, candlestick_data.values, width=0.6)

ax.xaxis.set_major_locator(WeekdayLocator(MONDAY)) # major ticks on the mondays
ax.xaxis.set_minor_locator(DayLocator()) # minor ticks on the days
ax.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
ax.xaxis_date() # treat the x data as dates
# rotate all ticks to vertical
plt.setp(ax.get_xticklabels(), rotation=90, horizontalalignment='right')

ax.set_ylabel('Price (US $)') # Set y-axis label
plt.show()

从 Matplotlib 2.0 开始,matplotlib.finance 已被弃用。读者应该将来使用mpl_financegithub.com/matplotlib/mpl_finance)。然而,截至本章撰写时,mpl_finance 尚未在 PyPI 上提供,因此我们暂时还是使用matplotlib.finance

可视化各种股市指标

当前形式的蜡烛图有些单调。交易员通常会叠加股票指标,如平均真实范围ATR)、布林带、商品通道指数CCI)、指数移动平均EMA)、移动平均收敛背离MACD)、相对强弱指数RSI)以及各种其他技术分析的统计数据。

Stockstats(github.com/jealous/stockstats)是一个用于计算这些指标/统计数据以及更多内容的优秀包。它封装了 pandas 的数据框架,并在访问时动态生成这些统计数据。要使用stockstats,我们只需通过 PyPI 安装它:pip install stockstats

接下来,我们可以通过stockstats.StockDataFrame.retype()将 pandas DataFrame 转换为 stockstats DataFrame。然后,可以按照StockDataFrame["variable_timeWindow_indicator"]的模式访问大量股票指标。例如,StockDataFrame['open_2_sma']将给出开盘价的 2 天简单移动平均线。一些指标可能有快捷方式,请查阅官方文档获取更多信息:

from stockstats import StockDataFrame

# Convert to StockDataFrame
# Need to pass a copy of candlestick_data to StockDataFrame.retype
# Otherwise the original candlestick_data will be modified
stockstats = StockDataFrame.retype(candlestick_data.copy())

# 5-day exponential moving average on closing price
ema_5 = stockstats["close_5_ema"]
# 20-day exponential moving average on closing price
ema_20 = stockstats["close_20_ema"]
# 50-day exponential moving average on closing price
ema_50 = stockstats["close_50_ema"]
# Upper Bollinger band
boll_ub = stockstats["boll_ub"]
# Lower Bollinger band
boll_lb = stockstats["boll_lb"]
# 7-day Relative Strength Index
rsi_7 = stockstats['rsi_7']
# 14-day Relative Strength Index
rsi_14 = stockstats['rsi_14']

准备好股票指标后,我们可以将它们叠加在同一个蜡烛图上:

import datetime
import matplotlib.pyplot as plt
from matplotlib.dates import date2num, WeekdayLocator, DayLocator, DateFormatter, MONDAY
from matplotlib.finance import candlestick_ohlc

# Create a new Matplotlib figure
fig, ax = plt.subplots()

# Prepare a candlestick plot
candlestick_ohlc(ax, candlestick_data.values, width=0.6)

# Plot stock indicators in the same plot
ax.plot(candlestick_data["Datetime"], ema_5, lw=1, label='EMA (5)')
ax.plot(candlestick_data["Datetime"], ema_20, lw=1, label='EMA (20)')
ax.plot(candlestick_data["Datetime"], ema_50, lw=1, label='EMA (50)')
ax.plot(candlestick_data["Datetime"], boll_ub, lw=2, linestyle="--", label='Bollinger upper')
ax.plot(candlestick_data["Datetime"], boll_lb, lw=2, linestyle="--", label='Bollinger lower')

ax.xaxis.set_major_locator(WeekdayLocator(MONDAY)) # major ticks on 
# the mondays
ax.xaxis.set_minor_locator(DayLocator()) # minor ticks on the days
ax.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
ax.xaxis_date() # treat the x data as dates
# rotate all ticks to vertical
plt.setp(ax.get_xticklabels(), rotation=90, horizontalalignment='right')

ax.set_ylabel('Price (US $)') # Set y-axis label

# Limit the x-axis range from 2017-4-23 to 2017-7-1
datemin = datetime.date(2017, 4, 23)
datemax = datetime.date(2017, 7, 1)
ax.set_xlim(datemin, datemax)

plt.legend() # Show figure legend
plt.tight_layout()
plt.show()

创建全面的股票图表

在以下详细示例中,我们将应用到目前为止讲解的多种技巧,创建一个更全面的股票图表。除了前面的图表外,我们还将添加一条线图来显示相对强弱指数RSI)以及一条柱状图来显示交易量。一个特殊的市场事件(markets.businessinsider.com/news/stocks/apple-stock-price-falling-new-iphone-speed-2017-6-1002082799)也将在图表中做注释:

如果你仔细观察图表,你可能会注意到一些缺失的日期。这些日期通常是非交易日或公共假期,它们在我们的数据框中没有出现。

import datetime
import matplotlib.pyplot as plt
from matplotlib.dates import date2num, WeekdayLocator, DayLocator, DateFormatter, MONDAY
from matplotlib.finance import candlestick_ohlc
from matplotlib.ticker import FuncFormatter

# FuncFormatter to convert tick values to Millions
def millions(x, pos):
    return '%dM' % (x/1e6)

# Create 3 subplots spread acrosee three rows, with shared x-axis. 
# The height ratio is specified via gridspec_kw
fig, axarr = plt.subplots(nrows=3, ncols=1, sharex=True, figsize=(8,8),
                          gridspec_kw={'height_ratios':[3,1,1]})

# Prepare a candlestick plot in the first axes
candlestick_ohlc(axarr[0], candlestick_data.values, width=0.6)

# Overlay stock indicators in the first axes
axarr[0].plot(candlestick_data["Datetime"], ema_5, lw=1, label='EMA (5)')
axarr[0].plot(candlestick_data["Datetime"], ema_20, lw=1, label='EMA (20)')
axarr[0].plot(candlestick_data["Datetime"], ema_50, lw=1, label='EMA (50)')
axarr[0].plot(candlestick_data["Datetime"], boll_ub, lw=2, linestyle="--", label='Bollinger upper')
axarr[0].plot(candlestick_data["Datetime"], boll_lb, lw=2, linestyle="--", label='Bollinger lower')

# Display RSI in the second axes
axarr[1].axhline(y=30, lw=2, color = '0.7') # Line for oversold threshold
axarr[1].axhline(y=50, lw=2, linestyle="--", color = '0.8') # Neutral RSI
axarr[1].axhline(y=70, lw=2, color = '0.7') # Line for overbought threshold
axarr[1].plot(candlestick_data["Datetime"], rsi_7, lw=2, label='RSI (7)')
axarr[1].plot(candlestick_data["Datetime"], rsi_14, lw=2, label='RSI (14)')

# Display trade volume in the third axes
axarr[2].bar(candlestick_data["Datetime"], candlestick_data['Volume'])

# Mark the market reaction to the Bloomberg news
# https://www.bloomberg.com/news/articles/2017-06-09/apple-s-new
# -iphones-said-to-miss-out-on-higher-speed-data-links
# http://markets.businessinsider.com/news/stocks/apple-stock-price
# -falling-new-iphone-speed-2017-6-1002082799
axarr[0].annotate("Bloomberg News",
                  xy=(datetime.date(2017, 6, 9), 155), xycoords='data',
                  xytext=(25, 10), textcoords='offset points', size=12,
                  arrowprops=dict(arrowstyle="simple",
                  fc="green", ec="none"))

# Label the axes
axarr[0].set_ylabel('Price (US $)')
axarr[1].set_ylabel('RSI')
axarr[2].set_ylabel('Volume (US $)')

axarr[2].xaxis.set_major_locator(WeekdayLocator(MONDAY)) # major ticks on the mondays
axarr[2].xaxis.set_minor_locator(DayLocator()) # minor ticks on the days
axarr[2].xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
axarr[2].xaxis_date() # treat the x data as dates
axarr[2].yaxis.set_major_formatter(FuncFormatter(millions)) # Change the y-axis ticks to millions
plt.setp(axarr[2].get_xticklabels(), rotation=90, horizontalalignment='right') # Rotate x-tick labels by 90 degree

# Limit the x-axis range from 2017-4-23 to 2017-7-1
datemin = datetime.date(2017, 4, 23)
datemax = datetime.date(2017, 7, 1)
axarr[2].set_xlim(datemin, datemax)

# Show figure legend
axarr[0].legend()
axarr[1].legend()

# Show figure title
axarr[0].set_title("AAPL (Apple Inc.) NASDAQ", loc='left')

# Reduce unneccesary white space
plt.tight_layout()
plt.show()

三维(3D)图表

通过过渡到三维空间,在创建可视化时,你可能会享有更大的创作自由度。额外的维度还可以在单一图表中容纳更多信息。然而,有些人可能会认为,当三维图形被投影到二维表面(如纸张)时,三维不过是一个视觉噱头,因为它会模糊数据点的解读。

在 Matplotlib 版本 2 中,尽管三维 API 有了显著的进展,但依然存在一些令人烦恼的错误或问题。我们将在本章的最后讨论一些解决方法。确实有更强大的 Python 3D 可视化包(如 MayaVi2、Plotly 和 VisPy),但如果你希望使用同一个包同时绘制 2D 和 3D 图,或者希望保持其 2D 图的美学,使用 Matplotlib 的三维绘图功能是很好的选择。

大多数情况下,Matplotlib 中的三维图与二维图有相似的结构。因此,在本节中我们不会讨论每种三维图类型。我们将重点介绍三维散点图和柱状图。

三维散点图

在第六章,《你好,绘图世界!》中,我们已经探索了二维散点图。在这一节中,让我们尝试创建一个三维散点图。在此之前,我们需要一些三维数据点(xyz):

import pandas as pd

source = "https://raw.githubusercontent.com/PointCloudLibrary/data/master/tutorials/ism_train_cat.pcd"
cat_df = pd.read_csv(source, skiprows=11, delimiter=" ", names=["x","y","z"], encoding='latin_1') 
cat_df.head()
xyz
0-17.03417818.97228240.482403
1-16.88148121.81545144.156799
2-16.74958218.15491134.131474
3-16.87691920.59828636.271809
4-16.84934017.40371142.993984

要声明一个三维图,我们首先需要从mpl_toolkits中的mplot3d扩展导入Axes3D对象,它负责在二维平面中渲染三维图表。然后,在创建子图时,我们需要指定projection='3d'

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(cat_df.x, cat_df.y, cat_df.z)

plt.show()

瞧,强大的 3D 散点图。猫目前正在占领互联网。根据《纽约时报》的报道,猫是“互联网的基本构建单元”(www.nytimes.com/2014/07/23/upshot/what-the-internet-can-see-from-your-cat-pictures.html)。毫无疑问,它们也应该在本章中占有一席之地。

与 2D 版本的 scatter() 相反,当创建 3D 散点图时,我们需要提供 X、Y 和 Z 坐标。然而,2D scatter() 支持的参数也可以应用于 3D scatter()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Change the size, shape and color of markers
ax.scatter(cat_df.x, cat_df.y, cat_df.z, s=4, c="g", marker="o")

plt.show()

要更改 3D 图的视角和仰角,我们可以使用 view_init()azim 参数指定 X-Y 平面上的方位角,而 elev 指定仰角。当方位角为 0 时,X-Y 平面将从你的北侧看起来。同时,方位角为 180 时,你将看到 X-Y 平面的南侧:

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(cat_df.x, cat_df.y, cat_df.z,s=4, c="g", marker="o")

# elev stores the elevation angle in the z plane azim stores the 
# azimuth angle in the x,y plane
ax.view_init(azim=180, elev=10)

plt.show()

3D 条形图

我们引入了烛台图来展示开盘-最高-最低-收盘OHLC)金融数据。此外,可以使用 3D 条形图来展示随时间变化的 OHLC。下图展示了绘制 5 天 OHLC 条形图的典型示例:

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

# Get 1 and every fifth row for the 5-day AAPL OHLC data
ohlc_5d = stock_df[stock_df["Company"]=="AAPL"].iloc[1::5, :]

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Create one color-coded bar chart for Open, High, Low and Close prices.
for color, col, z in zip(['r', 'g', 'b', 'y'], ["Open", "High", "Low", 
                          "Close"], [30, 20, 10, 0]):
    xs = np.arange(ohlc_5d.shape[0])
    ys = ohlc_5d[col]
    # Assign color to the bars
    colors = [color] * len(xs)
    ax.bar(xs, ys, zs=z, zdir='y', color=colors, alpha=0.8, width=5)

plt.show()

设置刻度和标签的方法与其他 Matplotlib 绘图函数类似:

fig = plt.figure(figsize=(9,7))
ax = fig.add_subplot(111, projection='3d')

# Create one color-coded bar chart for Open, High, Low and Close prices.
for color, col, z in zip(['r', 'g', 'b', 'y'], ["Open", "High", "Low", 
                          "Close"], [30, 20, 10, 0]):
    xs = np.arange(ohlc_5d.shape[0])
    ys = ohlc_5d[col]
    # Assign color to the bars 
    colors = [color] * len(xs)
    ax.bar(xs, ys, zs=z, zdir='y', color=colors, alpha=0.8)

# Manually assign the ticks and tick labels
ax.set_xticks(np.arange(ohlc_5d.shape[0]))
ax.set_xticklabels(ohlc_5d["Date"], rotation=20,
                   verticalalignment='baseline',
                   horizontalalignment='right',
                   fontsize='8')
ax.set_yticks([30, 20, 10, 0])
ax.set_yticklabels(["Open", "High", "Low", "Close"])

# Set the z-axis label
ax.set_zlabel('Price (US $)')

# Rotate the viewport
ax.view_init(azim=-42, elev=31)
plt.tight_layout()
plt.show()

Matplotlib 3D 的注意事项

由于缺乏真正的 3D 图形渲染后端(如 OpenGL)和适当的算法来检测 3D 对象的交叉点,Matplotlib 的 3D 绘图能力并不强大,但对于典型应用来说仅仅够用。在官方 Matplotlib FAQ 中(matplotlib.org/mpl_toolkits/mplot3d/faq.html),作者指出 3D 图可能在某些角度看起来不正确。此外,我们还报告了如果设置了 zlim,mplot3d 会无法裁剪条形图的问题(github.com/matplotlib/matplotlib/issues/8902;另见 github.com/matplotlib/matplotlib/issues/209)。在没有改进 3D 渲染后端的情况下,这些问题很难解决。

为了更好地说明后一个问题,让我们尝试在之前的 3D 条形图中的 plt.tight_layout() 上方添加 ax.set_zlim3d(bottom=110, top=150)

显然,柱状图超出了坐标轴的下边界。我们将尝试通过以下解决方法解决后一个问题:

# FuncFormatter to add 110 to the tick labels
def major_formatter(x, pos):
    return "{}".format(x+110)

fig = plt.figure(figsize=(9,7))
ax = fig.add_subplot(111, projection='3d')

# Create one color-coded bar chart for Open, High, Low and Close prices.
for color, col, z in zip(['r', 'g', 'b', 'y'], ["Open", "High", "Low", 
                          "Close"], [30, 20, 10, 0]):
    xs = np.arange(ohlc_5d.shape[0])
    ys = ohlc_5d[col]

    # Assign color to the bars 
    colors = [color] * len(xs)

    # Truncate the y-values by 110
    ax.bar(xs, ys-110, zs=z, zdir='y', color=colors, alpha=0.8)

# Manually assign the ticks and tick labels
ax.set_xticks(np.arange(ohlc_5d.shape[0]))
ax.set_xticklabels(ohlc_5d["Date"], rotation=20,
                   verticalalignment='baseline',
                   horizontalalignment='right',
                   fontsize='8')

# Set the z-axis label
ax.set_yticks([30, 20, 10, 0])
ax.set_yticklabels(["Open", "High", "Low", "Close"])
ax.zaxis.set_major_formatter(FuncFormatter(major_formatter))
ax.set_zlabel('Price (US $)')

# Rotate the viewport
ax.view_init(azim=-42, elev=31)

plt.tight_layout()
plt.show()

基本上,我们将 y 值截断了 110,然后使用刻度格式化器(major_formatter)将刻度值恢复到原始值。对于三维散点图,我们可以简单地移除超过 set_zlim3d() 边界的数据点,以生成正确的图形。然而,这些解决方法可能并不适用于所有类型的三维图形。

总结

你已经成功掌握了将多变量数据以二维和三维形式可视化的技术。尽管本章中的大部分示例围绕股票交易这一主题展开,但数据处理和可视化方法也可以轻松应用于其他领域。特别是,用于在多个面上可视化多变量数据的分治法在科学领域中非常有用。

我们没有过多探讨 Matplotlib 的三维绘图功能,因为它尚未完善。对于简单的三维图形,Matplotlib 已经足够了。如果我们使用同一个库来绘制二维和三维图形,可以减少学习曲线。如果你需要更强大的三维绘图功能,建议你查看 MayaVi2、Plotly 和 VisPy。