从零开始的数据科学第二版-三-

91 阅读54分钟

从零开始的数据科学第二版(三)

原文:zh.annas-archive.org/md5/48ab308fc34189a6d7d26b91b72a6df9

译者:飞龙

协议:CC BY-NC-SA 4.0

第九章:获取数据

要写它,用了三个月;要构思它,用了三分钟;要收集其中的数据,用了一生。

F. 斯科特·菲茨杰拉德

要成为一名数据科学家,你需要数据。事实上,作为数据科学家,你将花费大量时间来获取、清理和转换数据。如果必要,你可以自己键入数据(或者如果有下属,让他们来做),但通常这不是你时间的好用法。在本章中,我们将探讨将数据引入 Python 及其转换为正确格式的不同方法。

stdin 和 stdout

如果在命令行中运行 Python 脚本,你可以使用sys.stdinsys.stdout将数据管道通过它们。例如,这是一个读取文本行并返回匹配正则表达式的脚本:

# egrep.py
import sys, re

# sys.argv is the list of command-line arguments
# sys.argv[0] is the name of the program itself
# sys.argv[1] will be the regex specified at the command line
regex = sys.argv[1]

# for every line passed into the script
for line in sys.stdin:
    # if it matches the regex, write it to stdout
    if re.search(regex, line):
        sys.stdout.write(line)

这里有一个示例,它会计算接收到的行数并将其写出:

# line_count.py
import sys

count = 0
for line in sys.stdin:
    count += 1

# print goes to sys.stdout
print(count)

然后你可以使用它们来计算文件中包含数字的行数。在 Windows 中,你会使用:

type SomeFile.txt | python egrep.py "[0-9]" | python line_count.py

在 Unix 系统中,你会使用:

cat SomeFile.txt | python egrep.py "[0-9]" | python line_count.py

管道符号|表示管道字符,意味着“使用左侧命令的输出作为右侧命令的输入”。你可以通过这种方式构建非常复杂的数据处理管道。

注意

如果你使用 Windows,你可能可以在该命令中省略python部分:

type SomeFile.txt | egrep.py "[0-9]" | line_count.py

如果你在 Unix 系统上,这样做需要几个额外步骤。首先在你的脚本的第一行添加一个“shebang” #!/usr/bin/env python。然后,在命令行中使用chmod x egrep.py++将文件设为可执行。

同样地,这是一个计算其输入中单词数量并写出最常见单词的脚本:

# most_common_words.py
import sys
from collections import Counter

# pass in number of words as first argument
try:
    num_words = int(sys.argv[1])
except:
    print("usage: most_common_words.py num_words")
    sys.exit(1)   # nonzero exit code indicates error

counter = Counter(word.lower()                      # lowercase words
                  for line in sys.stdin
                  for word in line.strip().split()  # split on spaces
                  if word)                          # skip empty 'words'

for word, count in counter.most_common(num_words):
    sys.stdout.write(str(count))
    sys.stdout.write("\t")
    sys.stdout.write(word)
    sys.stdout.write("\n")

然后你可以像这样做一些事情:

$ cat the_bible.txt | python most_common_words.py 10
36397	the
30031	and
20163	of
7154	to
6484	in
5856	that
5421	he
5226	his
5060	unto
4297	shall

(如果你使用 Windows,则使用type而不是cat。)

注意

如果你是一名经验丰富的 Unix 程序员,可能已经熟悉各种命令行工具(例如,egrep),这些工具已经内建到你的操作系统中,比从头开始构建更可取。不过,了解自己可以这样做也是很好的。

读取文件

你也可以在代码中直接显式地读取和写入文件。Python 使得处理文件变得非常简单。

文本文件的基础知识

处理文本文件的第一步是使用open获取一个文件对象

# 'r' means read-only, it's assumed if you leave it out
file_for_reading = open('reading_file.txt', 'r')
file_for_reading2 = open('reading_file.txt')

# 'w' is write -- will destroy the file if it already exists!
file_for_writing = open('writing_file.txt', 'w')

# 'a' is append -- for adding to the end of the file
file_for_appending = open('appending_file.txt', 'a')

# don't forget to close your files when you're done
file_for_writing.close()

因为很容易忘记关闭文件,所以你应该总是在with块中使用它们,在块结束时它们将自动关闭:

with open(filename) as f:
    data = function_that_gets_data_from(f)

# at this point f has already been closed, so don't try to use it
process(data)

如果你需要读取整个文本文件,可以使用for循环迭代文件的每一行:

starts_with_hash = 0

with open('input.txt') as f:
    for line in f:                  # look at each line in the file
        if re.match("^#",line):     # use a regex to see if it starts with '#'
            starts_with_hash += 1   # if it does, add 1 to the count

通过这种方式获取的每一行都以换行符结尾,所以在处理之前通常会将其strip掉。

例如,假设你有一个文件,其中包含一个邮箱地址一行,你需要生成一个域名的直方图。正确提取域名的规则有些微妙,可以参考公共后缀列表,但一个很好的初步方法是仅仅取邮箱地址中“@”后面的部分(对于像joel@mail.datasciencester.com这样的邮箱地址,这个方法会给出错误的答案,但在这个例子中我们可以接受这种方法):

def get_domain(email_address: str) -> str:
    """Split on '@' and return the last piece"""
    return email_address.lower().split("@")[-1]

# a couple of tests
assert get_domain('joelgrus@gmail.com') == 'gmail.com'
assert get_domain('joel@m.datasciencester.com') == 'm.datasciencester.com'

from collections import Counter

with open('email_addresses.txt', 'r') as f:
    domain_counts = Counter(get_domain(line.strip())
                            for line in f
                            if "@" in line)

分隔文件

我们刚刚处理的假设的邮箱地址文件每行一个地址。更频繁地,你将使用每行有大量数据的文件。这些文件往往是逗号分隔或制表符分隔的:每行有多个字段,逗号或制表符表示一个字段的结束和下一个字段的开始。

当你的字段中有逗号、制表符和换行符时(这是不可避免的)。因此,你不应该尝试自己解析它们。相反,你应该使用 Python 的csv模块(或 pandas 库,或设计用于读取逗号分隔或制表符分隔文件的其他库)。

警告

永远不要自己解析逗号分隔的文件。你会搞砸一些边缘情况!

如果你的文件没有表头(这意味着你可能希望每行作为一个list,并且需要你知道每一列中包含什么),你可以使用csv.reader来迭代行,每行都会是一个适当拆分的列表。

例如,如果我们有一个制表符分隔的股票价格文件:

6/20/2014   AAPL    90.91
6/20/2014   MSFT    41.68
6/20/2014   FB  64.5
6/19/2014   AAPL    91.86
6/19/2014   MSFT    41.51
6/19/2014   FB  64.34

我们可以用以下方式处理它们:

import csv

with open('tab_delimited_stock_prices.txt') as f:
    tab_reader = csv.reader(f, delimiter='\t')
    for row in tab_reader:
        date = row[0]
        symbol = row[1]
        closing_price = float(row[2])
        process(date, symbol, closing_price)

如果你的文件有表头:

date:symbol:closing_price
6/20/2014:AAPL:90.91
6/20/2014:MSFT:41.68
6/20/2014:FB:64.5

你可以通过初始调用reader.next跳过表头行,或者通过使用csv.DictReader将每一行作为dict(表头作为键)来获取:

with open('colon_delimited_stock_prices.txt') as f:
    colon_reader = csv.DictReader(f, delimiter=':')
    for dict_row in colon_reader:
        date = dict_row["date"]
        symbol = dict_row["symbol"]
        closing_price = float(dict_row["closing_price"])
        process(date, symbol, closing_price)

即使你的文件没有表头,你仍然可以通过将键作为fieldnames参数传递给DictReader来使用它。

你也可以使用csv.writer类似地写出分隔数据:

todays_prices = {'AAPL': 90.91, 'MSFT': 41.68, 'FB': 64.5 }

with open('comma_delimited_stock_prices.txt', 'w') as f:
    csv_writer = csv.writer(f, delimiter=',')
    for stock, price in todays_prices.items():
        csv_writer.writerow([stock, price])

如果你的字段本身包含逗号,csv.writer会处理得很好。但是,如果你自己手动编写的写入器可能不会。例如,如果你尝试:

results = [["test1", "success", "Monday"],
           ["test2", "success, kind of", "Tuesday"],
           ["test3", "failure, kind of", "Wednesday"],
           ["test4", "failure, utter", "Thursday"]]

# don't do this!
with open('bad_csv.txt', 'w') as f:
    for row in results:
        f.write(",".join(map(str, row))) # might have too many commas in it!
        f.write("\n")                    # row might have newlines as well!

你将会得到一个如下的*.csv*文件:

test1,success,Monday
test2,success, kind of,Tuesday
test3,failure, kind of,Wednesday
test4,failure, utter,Thursday

而且没有人能够理解。

网页抓取

另一种获取数据的方式是从网页中抓取数据。事实证明,获取网页很容易;但从中获取有意义的结构化信息却不那么容易。

HTML 及其解析

网页是用 HTML 编写的,文本(理想情况下)被标记为元素及其属性:

<html>
  <head>
    <title>A web page</title>
  </head>
  <body>
    <p id="author">Joel Grus</p>
    <p id="subject">Data Science</p>
  </body>
</html>

在一个完美的世界中,所有网页都会被语义化地标记,为了我们的利益。我们将能够使用诸如“查找idsubject<p>元素并返回其包含的文本”之类的规则来提取数据。但实际上,HTML 通常并不规范,更不用说注释了。这意味着我们需要帮助来理解它。

要从 HTML 中获取数据,我们将使用Beautiful Soup 库,它会构建一个网页上各种元素的树,并提供一个简单的接口来访问它们。在我写这篇文章时,最新版本是 Beautiful Soup 4.6.0,这也是我们将使用的版本。我们还将使用Requests 库,这是一种比 Python 内置的任何东西都更好的方式来进行 HTTP 请求。

Python 内置的 HTML 解析器并不那么宽容,这意味着它不能很好地处理不完全形式的 HTML。因此,我们还将安装html5lib解析器。

确保您处于正确的虚拟环境中,安装库:

python -m pip install beautifulsoup4 requests html5lib

要使用 Beautiful Soup,我们将一个包含 HTML 的字符串传递给BeautifulSoup函数。在我们的示例中,这将是对requests.get调用的结果:

from bs4 import BeautifulSoup
import requests

# I put the relevant HTML file on GitHub. In order to fit
# the URL in the book I had to split it across two lines.
# Recall that whitespace-separated strings get concatenated.
url = ("https://raw.githubusercontent.com/"
       "joelgrus/data/master/getting-data.html")
html = requests.get(url).text
soup = BeautifulSoup(html, 'html5lib')

然后我们可以使用几种简单的方法走得相当远。

我们通常会使用Tag对象,它对应于表示 HTML 页面结构的标签。

例如,要找到第一个<p>标签(及其内容),您可以使用:

first_paragraph = soup.find('p')        # or just soup.p

您可以使用其text属性获取Tag的文本内容:

first_paragraph_text = soup.p.text
first_paragraph_words = soup.p.text.split()

您可以通过将其视为dict来提取标签的属性:

first_paragraph_id = soup.p['id']       # raises KeyError if no 'id'
first_paragraph_id2 = soup.p.get('id')  # returns None if no 'id'

您可以按以下方式一次获取多个标签:

all_paragraphs = soup.find_all('p')  # or just soup('p')
paragraphs_with_ids = [p for p in soup('p') if p.get('id')]

经常,您会想要找到具有特定class的标签:

important_paragraphs = soup('p', {'class' : 'important'})
important_paragraphs2 = soup('p', 'important')
important_paragraphs3 = [p for p in soup('p')
                         if 'important' in p.get('class', [])]

你可以结合这些方法来实现更复杂的逻辑。例如,如果你想找到每个包含在<div>元素内的<span>元素,你可以这样做:

# Warning: will return the same <span> multiple times
# if it sits inside multiple <div>s.
# Be more clever if that's the case.
spans_inside_divs = [span
                     for div in soup('div')     # for each <div> on the page
                     for span in div('span')]   # find each <span> inside it

这些功能的几个特点就足以让我们做很多事情。如果你最终需要做更复杂的事情(或者你只是好奇),请查阅文档

当然,重要的数据通常不会标记为class="important"。您需要仔细检查源 HTML,通过选择逻辑推理,并担心边缘情况,以确保数据正确。让我们看一个例子。

例如:监控国会

DataSciencester 的政策副总裁担心数据科学行业可能会受到监管,并要求您量化国会在该主题上的言论。特别是,他希望您找出所有发表关于“数据”内容的代表。

在发布时,有一个页面链接到所有代表的网站,网址为https://www.house.gov/representatives

如果“查看源代码”,所有指向网站的链接看起来像:

<td>
  <a href="https://jayapal.house.gov">Jayapal, Pramila</a>
</td>

让我们开始收集从该页面链接到的所有 URL:

from bs4 import BeautifulSoup
import requests

url = "https://www.house.gov/representatives"
text = requests.get(url).text
soup = BeautifulSoup(text, "html5lib")

all_urls = [a['href']
            for a in soup('a')
            if a.has_attr('href')]

print(len(all_urls))  # 965 for me, way too many

这返回了太多的 URL。如果你查看它们,我们想要的 URL 以http://https://开头,有一些名称,并且以.house.gov或*.house.gov/*结尾。

这是使用正则表达式的好地方:

import re

# Must start with http:// or https://
# Must end with .house.gov or .house.gov/
regex = r"^https?://.*\.house\.gov/?$"

# Let's write some tests!
assert re.match(regex, "http://joel.house.gov")
assert re.match(regex, "https://joel.house.gov")
assert re.match(regex, "http://joel.house.gov/")
assert re.match(regex, "https://joel.house.gov/")
assert not re.match(regex, "joel.house.gov")
assert not re.match(regex, "http://joel.house.com")
assert not re.match(regex, "https://joel.house.gov/biography")

# And now apply
good_urls = [url for url in all_urls if re.match(regex, url)]

print(len(good_urls))  # still 862 for me

这仍然太多了,因为只有 435 位代表。如果你看一下列表,会发现很多重复。让我们使用set来去重:

good_urls = list(set(good_urls))

print(len(good_urls))  # only 431 for me

总会有几个众议院席位是空缺的,或者可能有一个没有网站的代表。无论如何,这已经足够了。当我们查看这些站点时,大多数都有一个指向新闻稿的链接。例如:

html = requests.get('https://jayapal.house.gov').text
soup = BeautifulSoup(html, 'html5lib')

# Use a set because the links might appear multiple times.
links = {a['href'] for a in soup('a') if 'press releases' in a.text.lower()}

print(links) # {'/media/press-releases'}

注意这是一个相对链接,这意味着我们需要记住原始站点。让我们来做一些抓取:

from typing import Dict, Set

press_releases: Dict[str, Set[str]] = {}

for house_url in good_urls:
    html = requests.get(house_url).text
    soup = BeautifulSoup(html, 'html5lib')
    pr_links = {a['href'] for a in soup('a') if 'press releases'
                                             in a.text.lower()}
    print(f"{house_url}: {pr_links}")
    press_releases[house_url] = pr_links
注意

通常情况下,自由地抓取一个网站是不礼貌的。大多数网站会有一个robots.txt文件,指示您可以多频繁地抓取该站点(以及您不应该抓取的路径),但由于涉及到国会,我们不需要特别礼貌。

如果你看这些内容滚动显示,你会看到很多/media/press-releasesmedia-center/press-releases,以及各种其他地址。其中一个 URL 是https://jayapal.house.gov/media/press-releases

请记住,我们的目标是找出哪些国会议员在其新闻稿中提到了“数据”。我们将编写一个稍微更通用的函数,检查新闻稿页面是否提到了任何给定的术语。

如果你访问该网站并查看源代码,似乎每篇新闻稿都有一个在<p>标签中的片段,所以我们将用它作为我们的第一个尝试:

def paragraph_mentions(text: str, keyword: str) -> bool:
    """
 Returns True if a <p> inside the text mentions {keyword}
 """
    soup = BeautifulSoup(text, 'html5lib')
    paragraphs = [p.get_text() for p in soup('p')]

    return any(keyword.lower() in paragraph.lower()
               for paragraph in paragraphs)

让我们为此写一个快速的测试:

text = """<body><h1>Facebook</h1><p>Twitter</p>"""
assert paragraph_mentions(text, "twitter")       # is inside a <p>
assert not paragraph_mentions(text, "facebook")  # not inside a <p>

最后,我们准备好找到相关的国会议员,并把他们的名字交给副总裁:

for house_url, pr_links in press_releases.items():
    for pr_link in pr_links:
        url = f"{house_url}/{pr_link}"
        text = requests.get(url).text

        if paragraph_mentions(text, 'data'):
            print(f"{house_url}")
            break  # done with this house_url

当我运行这个时,我得到了大约 20 位代表的列表。你的结果可能会有所不同。

注意

如果你看各种“新闻稿”页面,大多数页面都是分页的,每页只有 5 或 10 篇新闻稿。这意味着我们只检索了每位国会议员最近的几篇新闻稿。更彻底的解决方案将迭代每一页,并检索每篇新闻稿的全文。

使用 API

许多网站和 Web 服务提供应用程序编程接口(API),允许您以结构化格式显式请求数据。这样可以避免您必须进行抓取的麻烦!

JSON 和 XML

因为 HTTP 是一个用于传输文本的协议,通过 Web API 请求的数据需要被序列化为字符串格式。通常这种序列化使用JavaScript 对象表示法(JSON)。JavaScript 对象看起来非常类似于 Python 的dict,这使得它们的字符串表示易于解释:

{ "title" : "Data Science Book",
  "author" : "Joel Grus",
  "publicationYear" : 2019,
  "topics" : [ "data", "science", "data science"] }

我们可以使用 Python 的json模块解析 JSON。特别地,我们将使用它的loads函数,将表示 JSON 对象的字符串反序列化为 Python 对象:

import json
serialized = """{ "title" : "Data Science Book",
 "author" : "Joel Grus",
 "publicationYear" : 2019,
 "topics" : [ "data", "science", "data science"] }"""

# parse the JSON to create a Python dict
deserialized = json.loads(serialized)
assert deserialized["publicationYear"] == 2019
assert "data science" in deserialized["topics"]

有时 API 提供者会讨厌你,并且只提供 XML 格式的响应:

<Book>
  <Title>Data Science Book</Title>
  <Author>Joel Grus</Author>
  <PublicationYear>2014</PublicationYear>
  <Topics>
    <Topic>data</Topic>
    <Topic>science</Topic>
    <Topic>data science</Topic>
  </Topics>
</Book>

你可以像从 HTML 中获取数据那样,使用 Beautiful Soup 从 XML 中获取数据;请查看其文档以获取详细信息。

使用未经身份验证的 API

大多数 API 现在要求你先进行身份验证,然后才能使用它们。虽然我们不反对这种策略,但这会产生很多额外的样板代码,使我们的解释变得混乱。因此,我们将首先看一下GitHub 的 API,它可以让你无需身份验证就能进行一些简单的操作:

import requests, json

github_user = "joelgrus"
endpoint = f"https://api.github.com/users/{github_user}/repos"

repos = json.loads(requests.get(endpoint).text)

此时repos是我 GitHub 账户中的公共仓库的 Python dict列表。(随意替换你的用户名并获取你的 GitHub 仓库数据。你有 GitHub 账户,对吧?)

我们可以用这个来找出我最有可能创建仓库的月份和星期几。唯一的问题是响应中的日期是字符串:

"created_at": "2013-07-05T02:02:28Z"

Python 自带的日期解析器不是很好用,所以我们需要安装一个:

python -m pip install python-dateutil

其中你可能只会需要dateutil.parser.parse函数:

from collections import Counter
from dateutil.parser import parse

dates = [parse(repo["created_at"]) for repo in repos]
month_counts = Counter(date.month for date in dates)
weekday_counts = Counter(date.weekday() for date in dates)

同样地,你可以获取我最近五个仓库的语言:

last_5_repositories = sorted(repos,
                             key=lambda r: r["pushed_at"],
                             reverse=True)[:5]

last_5_languages = [repo["language"]
                    for repo in last_5_repositories]

通常情况下,我们不会在低层次(“自己发起请求并解析响应”)处理 API。使用 Python 的好处之一是,几乎任何你有兴趣访问的 API,都已经有人建立了一个库。如果做得好,这些库可以节省你很多访问 API 的复杂细节的麻烦。(如果做得不好,或者当它们基于已失效的 API 版本时,可能会带来巨大的麻烦。)

尽管如此,偶尔你会需要自己编写 API 访问库(或者更有可能,调试为什么别人的库不起作用),因此了解一些细节是很有用的。

寻找 API

如果你需要从特定网站获取数据,请查找该网站的“开发者”或“API”部分以获取详细信息,并尝试在网上搜索“python api”来找到相应的库。

有关 Yelp API、Instagram API、Spotify API 等等,都有相应的库。

如果你在寻找 Python 封装的 API 列表,Real Python 在 GitHub 上有一个很好的列表。

如果找不到你需要的内容,总有一种方法,那就是网页抓取,数据科学家的最后避风港。

示例:使用 Twitter 的 API

Twitter 是一个非常好的数据来源。你可以用它来获取实时新闻,也可以用它来衡量对当前事件的反应。你还可以用它来查找与特定主题相关的链接。你可以用它来做几乎任何你能想到的事情,只要你能访问到它的数据。通过它的 API,你可以获取到它的数据。

要与 Twitter 的 API 交互,我们将使用Twython 库python -m pip install twython)。目前有许多 Python Twitter 库,但这是我使用最成功的一个。当然,也鼓励你探索其他库!

获取凭证

为了使用 Twitter 的 API,你需要获取一些凭据(你需要一个 Twitter 帐户,这样你就可以成为活跃且友好的 Twitter #datascience 社区的一部分)。

警告

像所有与我无法控制的网站相关的说明一样,这些说明可能在某个时候过时,但希望能够一段时间内工作。(尽管自我最初开始写这本书以来,它们已经多次发生变化,所以祝你好运!)

以下是步骤:

  1. 前往 https://developer.twitter.com/

  2. 如果你没有登录,点击“登录”并输入你的 Twitter 用户名和密码。

  3. 点击申请以申请开发者帐户。

  4. 为你自己的个人使用请求访问。

  5. 填写申请。 它需要 300 字(真的)解释为什么你需要访问,所以为了超过限制,你可以告诉他们关于这本书以及你有多么喜欢它。

  6. 等待一段不确定的时间。

  7. 如果你认识在 Twitter 工作的人,请给他们发电子邮件,询问他们是否可以加快你的申请。 否则,继续等待。

  8. 一旦你获得批准,返回到 developer.twitter.com,找到“Apps”部分,然后点击“创建应用程序”。

  9. 填写所有必填字段(同样,如果你需要描述的额外字符,你可以谈论这本书以及你发现它多么有启发性)。

  10. 点击创建。

现在你的应用程序应该有一个“Keys and tokens”选项卡,其中包含一个“Consumer API keys”部分,列出了一个“API key”和一个“API secret key”。 记下这些密钥; 你会需要它们。(另外,保持它们保密! 它们就像密码。)

小心

不要分享密钥,不要在书中发布它们,也不要将它们检入你的公共 GitHub 存储库。 一个简单的解决方案是将它们存储在一个不会被检入的 credentials.json 文件中,并让你的代码使用 json.loads 来检索它们。 另一个解决方案是将它们存储在环境变量中,并使用 os.environ 来检索它们。

使用 Twython

使用 Twitter API 的最棘手的部分是验证身份。(事实上,这是使用许多 API 中最棘手的部分之一。) API 提供商希望确保你被授权访问他们的数据,并且你不会超出他们的使用限制。 他们还想知道谁在访问他们的数据。

认证有点痛苦。 有一种简单的方法,OAuth 2,在你只想做简单搜索时足够使用。 还有一种复杂的方法,OAuth 1,在你想执行操作(例如,发推文)或(特别是对我们来说)连接到 Twitter 流时需要使用。

所以我们被迫使用更复杂的方式,我们会尽可能自动化它。

首先,你需要你的 API 密钥和 API 密钥(有时也称为消费者密钥和消费者密钥)。 我将从环境变量中获取我的,但请随意以任何你希望的方式替换你的:

import os

# Feel free to plug your key and secret in directly
CONSUMER_KEY = os.environ.get("TWITTER_CONSUMER_KEY")
CONSUMER_SECRET = os.environ.get("TWITTER_CONSUMER_SECRET")

现在我们可以实例化客户端:

import webbrowser
from twython import Twython

# Get a temporary client to retrieve an authentication URL
temp_client = Twython(CONSUMER_KEY, CONSUMER_SECRET)
temp_creds = temp_client.get_authentication_tokens()
url = temp_creds['auth_url']

# Now visit that URL to authorize the application and get a PIN
print(f"go visit {url} and get the PIN code and paste it below")
webbrowser.open(url)
PIN_CODE = input("please enter the PIN code: ")

# Now we use that PIN_CODE to get the actual tokens
auth_client = Twython(CONSUMER_KEY,
                      CONSUMER_SECRET,
                      temp_creds['oauth_token'],
                      temp_creds['oauth_token_secret'])
final_step = auth_client.get_authorized_tokens(PIN_CODE)
ACCESS_TOKEN = final_step['oauth_token']
ACCESS_TOKEN_SECRET = final_step['oauth_token_secret']

# And get a new Twython instance using them.
twitter = Twython(CONSUMER_KEY,
                  CONSUMER_SECRET,
                  ACCESS_TOKEN,
                  ACCESS_TOKEN_SECRET)
提示

此时,你可能希望考虑将ACCESS_TOKENACCESS_TOKEN_SECRET保存在安全的地方,这样下次你就不必再经历这个烦琐的过程了。

一旦我们有了一个经过身份验证的Twython实例,我们就可以开始执行搜索:

# Search for tweets containing the phrase "data science"
for status in twitter.search(q='"data science"')["statuses"]:
    user = status["user"]["screen_name"]
    text = status["text"]
    print(f"{user}: {text}\n")

如果你运行这个程序,你应该会得到一些推文,比如:

haithemnyc: Data scientists with the technical savvy &amp; analytical chops to
derive meaning from big data are in demand. http://t.co/HsF9Q0dShP

RPubsRecent: Data Science http://t.co/6hcHUz2PHM

spleonard1: Using #dplyr in #R to work through a procrastinated assignment for
@rdpeng in @coursera data science specialization. So easy and Awesome.

这并不那么有趣,主要是因为 Twitter 搜索 API 只会显示出它觉得最近的结果。在进行数据科学时,更多时候你会想要大量的推文。这就是Streaming API有用的地方。它允许你连接到(部分)巨大的 Twitter firehose。要使用它,你需要使用你的访问令牌进行身份验证。

为了使用 Twython 访问 Streaming API,我们需要定义一个类,该类继承自TwythonStreamer并重写其on_success方法,可能还有其on_error方法:

from twython import TwythonStreamer

# Appending data to a global variable is pretty poor form
# but it makes the example much simpler
tweets = []

class MyStreamer(TwythonStreamer):
    def on_success(self, data):
        """
 What do we do when Twitter sends us data?
 Here data will be a Python dict representing a tweet.
 """
        # We only want to collect English-language tweets
        if data.get('lang') == 'en':
            tweets.append(data)
            print(f"received tweet #{len(tweets)}")

        # Stop when we've collected enough
        if len(tweets) >= 100:
            self.disconnect()

    def on_error(self, status_code, data):
        print(status_code, data)
        self.disconnect()

MyStreamer将连接到 Twitter 流并等待 Twitter 提供数据。每次接收到一些数据(在这里是表示为 Python 对象的推文)时,它都会将其传递给on_success方法,如果推文的语言是英语,则将其追加到我们的tweets列表中,然后在收集到 1,000 条推文后断开流。

唯一剩下的就是初始化它并开始运行:

stream = MyStreamer(CONSUMER_KEY, CONSUMER_SECRET,
                    ACCESS_TOKEN, ACCESS_TOKEN_SECRET)

# starts consuming public statuses that contain the keyword 'data'
stream.statuses.filter(track='data')

# if instead we wanted to start consuming a sample of *all* public statuses
# stream.statuses.sample()

这将持续运行,直到收集到 100 条推文(或遇到错误为止),然后停止,此时你可以开始分析这些推文。例如,你可以找出最常见的标签:

top_hashtags = Counter(hashtag['text'].lower()
                       for tweet in tweets
                       for hashtag in tweet["entities"]["hashtags"])

print(top_hashtags.most_common(5))

每条推文都包含大量的数据。你可以自己探索,或者查看Twitter API 文档

注意

在一个非玩具项目中,你可能不想依赖于内存中的list来存储推文。相反,你可能想把它们保存到文件或数据库中,这样你就能永久地拥有它们。

进一步探索

  • pandas是数据科学家们用来处理数据,特别是导入数据的主要库。

  • Scrapy是一个用于构建复杂网络爬虫的全功能库,可以执行诸如跟踪未知链接等操作。

  • Kaggle拥有大量的数据集。

第十章:处理数据

专家往往拥有比判断更多的数据。

科林·鲍威尔

处理数据既是一门艺术,也是一门科学。我们主要讨论科学部分,但在本章中我们将探讨一些艺术方面。

探索您的数据

在确定您要回答的问题并获取到数据之后,您可能会有冲动立即开始构建模型和获取答案。但您应该抑制这种冲动。您的第一步应该是 探索 您的数据。

探索一维数据

最简单的情况是您有一个一维数据集,它只是一组数字。例如,这些可能是每个用户每天在您的网站上花费的平均分钟数,一系列数据科学教程视频的观看次数,或者您的数据科学图书馆中每本数据科学书籍的页数。

显而易见的第一步是计算一些摘要统计信息。您想知道有多少数据点,最小值,最大值,均值和标准差。

但即使这些也不一定能给您带来很好的理解。一个很好的下一步是创建直方图,将数据分组为离散的 并计算落入每个桶中的点数:

from typing import List, Dict
from collections import Counter
import math

import matplotlib.pyplot as plt

def bucketize(point: float, bucket_size: float) -> float:
    """Floor the point to the next lower multiple of bucket_size"""
    return bucket_size * math.floor(point / bucket_size)

def make_histogram(points: List[float], bucket_size: float) -> Dict[float, int]:
    """Buckets the points and counts how many in each bucket"""
    return Counter(bucketize(point, bucket_size) for point in points)

def plot_histogram(points: List[float], bucket_size: float, title: str = ""):
    histogram = make_histogram(points, bucket_size)
    plt.bar(histogram.keys(), histogram.values(), width=bucket_size)
    plt.title(title)

例如,考虑以下两组数据:

import random
from scratch.probability import inverse_normal_cdf

random.seed(0)

# uniform between -100 and 100
uniform = [200 * random.random() - 100 for _ in range(10000)]

# normal distribution with mean 0, standard deviation 57
normal = [57 * inverse_normal_cdf(random.random())
          for _ in range(10000)]

两者的均值接近 0,标准差接近 58。然而,它们的分布却非常不同。 第 10-1 图 显示了 uniform 的分布:

plot_histogram(uniform, 10, "Uniform Histogram")

而 第 10-2 图 显示了 normal 的分布:

plot_histogram(normal, 10, "Normal Histogram")

均匀分布直方图

第 10-1 图。均匀分布直方图

在这种情况下,两个分布的 maxmin 相差很大,但即使知道这一点也不足以理解它们的 差异

两个维度

现在想象一下,您有一个具有两个维度的数据集。除了每天的分钟数,您可能还有数据科学经验的年数。当然,您希望单独了解每个维度。但您可能还想散点数据。

例如,考虑另一个虚构数据集:

def random_normal() -> float:
    """Returns a random draw from a standard normal distribution"""
    return inverse_normal_cdf(random.random())

xs = [random_normal() for _ in range(1000)]
ys1 = [ x + random_normal() / 2 for x in xs]
ys2 = [-x + random_normal() / 2 for x in xs]

如果你对 ys1ys2 运行 plot_histogram,你会得到类似的图表(实际上,两者均为具有相同均值和标准差的正态分布)。

正态分布直方图

第 10-2 图。正态分布直方图

但每个维度与 xs 的联合分布非常不同,如 第 10-3 图 所示:

plt.scatter(xs, ys1, marker='.', color='black', label='ys1')
plt.scatter(xs, ys2, marker='.', color='gray',  label='ys2')
plt.xlabel('xs')
plt.ylabel('ys')
plt.legend(loc=9)
plt.title("Very Different Joint Distributions")
plt.show()

散射两个不同 ys 的图

第 10-3 图。散射两个不同的 ys

如果你查看相关性,这种差异也会显现:

from scratch.statistics import correlation

print(correlation(xs, ys1))      # about 0.9
print(correlation(xs, ys2))      # about -0.9

多维度

对于许多维度,您可能想了解所有维度之间的关系。一个简单的方法是查看 相关矩阵,其中第 i 行和第 j 列的条目是数据的第 i 维和第 j 维之间的相关性:

from scratch.linear_algebra import Matrix, Vector, make_matrix

def correlation_matrix(data: List[Vector]) -> Matrix:
    """
 Returns the len(data) x len(data) matrix whose (i, j)-th entry
 is the correlation between data[i] and data[j]
 """
    def correlation_ij(i: int, j: int) -> float:
        return correlation(data[i], data[j])

    return make_matrix(len(data), len(data), correlation_ij)

一种更直观的方法(如果维度不多)是制作一个散点图矩阵(参见图 10-4),显示所有成对的散点图。为此,我们将使用plt.subplots,它允许我们创建图表的子图。我们给它行数和列数,它返回一个figure对象(我们将不使用它)和一个二维数组的axes对象(我们将每个都绘制):

# corr_data is a list of four 100-d vectors
num_vectors = len(corr_data)
fig, ax = plt.subplots(num_vectors, num_vectors)

for i in range(num_vectors):
    for j in range(num_vectors):

        # Scatter column_j on the x-axis vs. column_i on the y-axis
        if i != j: ax[i][j].scatter(corr_data[j], corr_data[i])

        # unless i == j, in which case show the series name
        else: ax[i][j].annotate("series " + str(i), (0.5, 0.5),
                                xycoords='axes fraction',
                                ha="center", va="center")

        # Then hide axis labels except left and bottom charts
        if i < num_vectors - 1: ax[i][j].xaxis.set_visible(False)
        if j > 0: ax[i][j].yaxis.set_visible(False)

# Fix the bottom-right and top-left axis labels, which are wrong because
# their charts only have text in them
ax[-1][-1].set_xlim(ax[0][-1].get_xlim())
ax[0][0].set_ylim(ax[0][1].get_ylim())

plt.show()

散点图矩阵

图 10-4. 散点图矩阵

从散点图中可以看出,系列 1 与系列 0 之间存在非常负相关的关系,系列 2 与系列 1 之间存在正相关的关系,而系列 3 只取值 0 和 6,其中 0 对应于系列 2 的小值,6 对应于系列 2 的大值。

这是快速了解哪些变量相关的一种粗略方法(除非你花费数小时调整 matplotlib 以完全按照你想要的方式显示,否则这不是一个快速方法)。

使用 NamedTuples

表示数据的一种常见方式是使用dicts:

import datetime

stock_price = {'closing_price': 102.06,
               'date': datetime.date(2014, 8, 29),
               'symbol': 'AAPL'}

然而,这不太理想的几个原因。这是一种略微低效的表示形式(一个dict涉及一些开销),因此如果你有很多股价,它们将占用比它们应该占用的更多内存。在大多数情况下,这只是一个小考虑。

一个更大的问题是通过dict键访问事物容易出错。以下代码将不会报错,但会执行错误的操作:

# oops, typo
stock_price['cosing_price'] = 103.06

最后,虽然我们可以为统一的字典进行类型注释:

prices: Dict[datetime.date, float] = {}

没有一种有用的方法来注释具有许多不同值类型的字典数据。因此,我们也失去了类型提示的力量。

作为一种替代方案,Python 包含一个namedtuple类,它类似于一个tuple,但具有命名的槽位:

from collections import namedtuple

StockPrice = namedtuple('StockPrice', ['symbol', 'date', 'closing_price'])
price = StockPrice('MSFT', datetime.date(2018, 12, 14), 106.03)

assert price.symbol == 'MSFT'
assert price.closing_price == 106.03

像常规的tuples一样,namedtuples 是不可变的,这意味着一旦创建就无法修改它们的值。偶尔这会成为我们的障碍,但大多数情况下这是件好事。

你会注意到我们还没有解决类型注解的问题。我们可以通过使用类型化的变体NamedTuple来解决:

from typing import NamedTuple

class StockPrice(NamedTuple):
    symbol: str
    date: datetime.date
    closing_price: float

    def is_high_tech(self) -> bool:
        """It's a class, so we can add methods too"""
        return self.symbol in ['MSFT', 'GOOG', 'FB', 'AMZN', 'AAPL']

price = StockPrice('MSFT', datetime.date(2018, 12, 14), 106.03)

assert price.symbol == 'MSFT'
assert price.closing_price == 106.03
assert price.is_high_tech()

现在你的编辑器可以帮助你,就像在图 10-5 中显示的那样。

有用的编辑器

图 10-5. 有用的编辑器
注意

很少有人以这种方式使用NamedTuple。但他们应该!

Dataclasses

Dataclasses 是NamedTuple的一种(某种程度上)可变版本。(我说“某种程度上”,因为NamedTuples 将它们的数据紧凑地表示为元组,而 dataclasses 是常规的 Python 类,只是为您自动生成了一些方法。)

注意

Dataclasses 在 Python 3.7 中是新功能。如果你使用的是旧版本,则本节对你无效。

语法与NamedTuple非常相似。但是,我们使用装饰器而不是从基类继承:

from dataclasses import dataclass

@dataclass
class StockPrice2:
    symbol: str
    date: datetime.date
    closing_price: float

    def is_high_tech(self) -> bool:
        """It's a class, so we can add methods too"""
        return self.symbol in ['MSFT', 'GOOG', 'FB', 'AMZN', 'AAPL']

price2 = StockPrice2('MSFT', datetime.date(2018, 12, 14), 106.03)

assert price2.symbol == 'MSFT'
assert price2.closing_price == 106.03
assert price2.is_high_tech()

正如前面提到的,最大的区别在于我们可以修改 dataclass 实例的值:

# stock split
price2.closing_price /= 2
assert price2.closing_price == 51.03

如果我们尝试修改NamedTuple版本的字段,我们会得到一个AttributeError

这也使我们容易受到我们希望通过不使用dict来避免的错误的影响:

# It's a regular class, so add new fields however you like!
price2.cosing_price = 75  # oops

我们不会使用 dataclasses,但你可能会在野外遇到它们。

清洁和操纵

现实世界的数据。通常你需要在使用之前对其进行一些处理。我们在第九章中看到了这方面的例子。在使用之前,我们必须将字符串转换为floatint。我们必须检查缺失值、异常值和错误数据。

以前,我们在使用数据之前就这样做了:

closing_price = float(row[2])

但是在一个我们可以测试的函数中进行解析可能更少出错:

from dateutil.parser import parse

def parse_row(row: List[str]) -> StockPrice:
    symbol, date, closing_price = row
    return StockPrice(symbol=symbol,
                      date=parse(date).date(),
                      closing_price=float(closing_price))

# Now test our function
stock = parse_row(["MSFT", "2018-12-14", "106.03"])

assert stock.symbol == "MSFT"
assert stock.date == datetime.date(2018, 12, 14)
assert stock.closing_price == 106.03

如果有错误数据怎么办?一个不实际代表数字的“浮点”值?也许你宁愿得到一个None而不是使程序崩溃?

from typing import Optional
import re

def try_parse_row(row: List[str]) -> Optional[StockPrice]:
    symbol, date_, closing_price_ = row

    # Stock symbol should be all capital letters
    if not re.match(r"^[A-Z]+$", symbol):
        return None

    try:
        date = parse(date_).date()
    except ValueError:
        return None

    try:
        closing_price = float(closing_price_)
    except ValueError:
        return None

    return StockPrice(symbol, date, closing_price)

# Should return None for errors
assert try_parse_row(["MSFT0", "2018-12-14", "106.03"]) is None
assert try_parse_row(["MSFT", "2018-12--14", "106.03"]) is None
assert try_parse_row(["MSFT", "2018-12-14", "x"]) is None

# But should return same as before if data is good
assert try_parse_row(["MSFT", "2018-12-14", "106.03"]) == stock

举个例子,如果我们有用逗号分隔的股票价格数据有错误:

AAPL,6/20/2014,90.91
MSFT,6/20/2014,41.68
FB,6/20/3014,64.5
AAPL,6/19/2014,91.86
MSFT,6/19/2014,n/a
FB,6/19/2014,64.34

现在我们可以只读取并返回有效的行了:

import csv

data: List[StockPrice] = []

with open("comma_delimited_stock_prices.csv") as f:
    reader = csv.reader(f)
    for row in reader:
        maybe_stock = try_parse_row(row)
        if maybe_stock is None:
            print(f"skipping invalid row: {row}")
        else:
            data.append(maybe_stock)

并决定我们想要如何处理无效数据。一般来说,三个选择是摆脱它们,返回到源头并尝试修复错误/丢失的数据,或者什么也不做,只是交叉手指。如果数百万行中有一行错误的数据,那么忽略它可能没问题。但是如果一半的行都有错误数据,那就是你需要解决的问题。

下一个好的步骤是检查异常值,使用“探索您的数据”中的技术或通过临时调查来进行。例如,你是否注意到股票文件中的一个日期的年份是 3014 年?这不会(必然)给你一个错误,但很明显是错误的,如果你不注意到它,你会得到混乱的结果。现实世界的数据集有缺失的小数点、额外的零、排版错误以及无数其他问题,你需要解决。(也许官方上不是你的工作,但还有谁会做呢?)

数据操作

数据科学家最重要的技能之一是数据操作。这更像是一种通用方法而不是特定的技术,所以我们只需通过几个示例来让你了解一下。

想象我们有一堆股票价格数据,看起来像这样:

data = [
    StockPrice(symbol='MSFT',
               date=datetime.date(2018, 12, 24),
               closing_price=106.03),
    # ...
]

让我们开始对这些数据提出问题。在此过程中,我们将尝试注意到我们正在做的事情,并抽象出一些工具,使操纵更容易。

例如,假设我们想知道 AAPL 的最高收盘价。让我们将这个问题分解成具体的步骤:

  1. 限制自己只看 AAPL 的行。

  2. 从每行中获取closing_price

  3. 获取那些价格的最大值。

我们可以一次完成所有三个任务使用推导:

max_aapl_price = max(stock_price.closing_price
                     for stock_price in data
                     if stock_price.symbol == "AAPL")

更一般地,我们可能想知道数据集中每支股票的最高收盘价。做到这一点的一种方法是:

  1. 创建一个dict来跟踪最高价格(我们将使用一个defaultdict,对于缺失值返回负无穷大,因为任何价格都将大于它)。

  2. 迭代我们的数据,更新它。

这是代码:

from collections import defaultdict

max_prices: Dict[str, float] = defaultdict(lambda: float('-inf'))

for sp in data:
    symbol, closing_price = sp.symbol, sp.closing_price
    if closing_price > max_prices[symbol]:
        max_prices[symbol] = closing_price

现在我们可以开始询问更复杂的问题,比如数据集中最大和最小的单日百分比变化是多少。百分比变化是price_today / price_yesterday - 1,这意味着我们需要一种将今天价格和昨天价格关联起来的方法。一种方法是按符号分组价格,然后在每个组内:

  1. 按日期排序价格。

  2. 使用zip获取(前一个,当前)对。

  3. 将这些对转换为新的“百分比变化”行。

让我们从按符号分组的价格开始:

from typing import List
from collections import defaultdict

# Collect the prices by symbol
prices: Dict[str, List[StockPrice]] = defaultdict(list)

for sp in data:
    prices[sp.symbol].append(sp)

由于价格是元组,它们将按字段顺序排序:首先按符号,然后按日期,最后按价格。这意味着如果我们有一些价格具有相同的符号,sort将按日期排序(然后按价格排序,但由于每个日期只有一个价格,所以这没有什么效果),这正是我们想要的。

# Order the prices by date
prices = {symbol: sorted(symbol_prices)
          for symbol, symbol_prices in prices.items()}

我们可以用它来计算一系列日对日的变化:

def pct_change(yesterday: StockPrice, today: StockPrice) -> float:
    return today.closing_price / yesterday.closing_price - 1

class DailyChange(NamedTuple):
    symbol: str
    date: datetime.date
    pct_change: float

def day_over_day_changes(prices: List[StockPrice]) -> List[DailyChange]:
    """
 Assumes prices are for one stock and are in order
 """
    return [DailyChange(symbol=today.symbol,
                        date=today.date,
                        pct_change=pct_change(yesterday, today))
            for yesterday, today in zip(prices, prices[1:])]

然后收集它们全部:

all_changes = [change
               for symbol_prices in prices.values()
               for change in day_over_day_changes(symbol_prices)]

在这一点上,找到最大值和最小值很容易:

max_change = max(all_changes, key=lambda change: change.pct_change)
# see e.g. http://news.cnet.com/2100-1001-202143.html
assert max_change.symbol == 'AAPL'
assert max_change.date == datetime.date(1997, 8, 6)
assert 0.33 < max_change.pct_change < 0.34

min_change = min(all_changes, key=lambda change: change.pct_change)
# see e.g. http://money.cnn.com/2000/09/29/markets/techwrap/
assert min_change.symbol == 'AAPL'
assert min_change.date == datetime.date(2000, 9, 29)
assert -0.52 < min_change.pct_change < -0.51

现在我们可以使用这个新的all_changes数据集来找出哪个月份最适合投资科技股。我们只需查看每月的平均每日变化:

changes_by_month: List[DailyChange] = {month: [] for month in range(1, 13)}

for change in all_changes:
    changes_by_month[change.date.month].append(change)

avg_daily_change = {
    month: sum(change.pct_change for change in changes) / len(changes)
    for month, changes in changes_by_month.items()
}

# October is the best month
assert avg_daily_change[10] == max(avg_daily_change.values())

在整本书中,我们将会进行这些操作,通常不会过多显式地提及它们。

重新缩放

许多技术对您数据的尺度很敏感。例如,想象一下,您有一个由数百名数据科学家的身高和体重组成的数据集,您试图识别体型的聚类

直觉上,我们希望聚类表示彼此附近的点,这意味着我们需要某种点之间距离的概念。我们已经有了欧氏distance函数,因此一个自然的方法可能是将(身高,体重)对视为二维空间中的点。考虑表 10-1 中列出的人员。

表 10-1. 身高和体重

人员身高(英寸)身高(厘米)体重(磅)
A63160150
B67170.2160
C70177.8171

如果我们用英寸测量身高,那么 B 的最近邻是 A:

from scratch.linear_algebra import distance

a_to_b = distance([63, 150], [67, 160])        # 10.77
a_to_c = distance([63, 150], [70, 171])        # 22.14
b_to_c = distance([67, 160], [70, 171])        # 11.40

然而,如果我们用厘米测量身高,那么 B 的最近邻将变为 C:

a_to_b = distance([160, 150], [170.2, 160])    # 14.28
a_to_c = distance([160, 150], [177.8, 171])    # 27.53
b_to_c = distance([170.2, 160], [177.8, 171])  # 13.37

显然,如果改变单位会导致结果发生变化,这是一个问题。因此,当维度不可比较时,我们有时会重新缩放我们的数据,使得每个维度的均值为 0,标准差为 1。这实际上消除了单位,将每个维度转换为“均值的标准偏差数”。

首先,我们需要计算每个位置的meanstandard_deviation

from typing import Tuple

from scratch.linear_algebra import vector_mean
from scratch.statistics import standard_deviation

def scale(data: List[Vector]) -> Tuple[Vector, Vector]:
    """returns the mean and standard deviation for each position"""
    dim = len(data[0])

    means = vector_mean(data)
    stdevs = [standard_deviation([vector[i] for vector in data])
              for i in range(dim)]

    return means, stdevs

vectors = [[-3, -1, 1], [-1, 0, 1], [1, 1, 1]]
means, stdevs = scale(vectors)
assert means == [-1, 0, 1]
assert stdevs == [2, 1, 0]

然后我们可以用它们创建一个新的数据集:

def rescale(data: List[Vector]) -> List[Vector]:
    """
 Rescales the input data so that each position has
 mean 0 and standard deviation 1\. (Leaves a position
 as is if its standard deviation is 0.)
 """
    dim = len(data[0])
    means, stdevs = scale(data)

    # Make a copy of each vector
    rescaled = [v[:] for v in data]

    for v in rescaled:
        for i in range(dim):
            if stdevs[i] > 0:
                v[i] = (v[i] - means[i]) / stdevs[i]

    return rescaled

当然,让我们写一个测试来确认rescale是否按我们想的那样工作:

means, stdevs = scale(rescale(vectors))
assert means == [0, 0, 1]
assert stdevs == [1, 1, 0]

如常,您需要根据自己的判断。例如,如果您将一个大量的身高和体重数据集筛选为只有身高在 69.5 英寸和 70.5 英寸之间的人,剩下的变化很可能(取决于您试图回答的问题)只是噪声,您可能不希望将其标准差与其他维度的偏差平等看待。

旁注:tqdm

经常我们会进行需要很长时间的计算。当您进行这样的工作时,您希望知道自己在取得进展并且预计需要等待多长时间。

一种方法是使用 tqdm 库,它生成自定义进度条。我们将在本书的其他部分中多次使用它,所以现在让我们学习一下它是如何工作的。

要开始使用,您需要安装它:

python -m pip install tqdm

你只需要知道几个特性。首先是,在 tqdm.tqdm 中包装的可迭代对象会生成一个进度条:

import tqdm

for i in tqdm.tqdm(range(100)):
    # do something slow
    _ = [random.random() for _ in range(1000000)]

这会生成一个类似于以下输出的结果:

 56%|████████████████████              | 56/100 [00:08<00:06,  6.49it/s]

特别地,它会显示循环的完成部分百分比(尽管如果您使用生成器,它无法这样做),已运行时间以及预计的剩余运行时间。

在这种情况下(我们只是包装了对 range 的调用),您可以直接使用 tqdm.trange

在其运行时,您还可以设置进度条的描述。要做到这一点,您需要在 with 语句中捕获 tqdm 迭代器:

from typing import List

def primes_up_to(n: int) -> List[int]:
    primes = [2]

    with tqdm.trange(3, n) as t:
        for i in t:
            # i is prime if no smaller prime divides it
            i_is_prime = not any(i % p == 0 for p in primes)
            if i_is_prime:
                primes.append(i)

            t.set_description(f"{len(primes)} primes")

    return primes

my_primes = primes_up_to(100_000)

这会添加一个如下描述,其中计数器会随着新的质数被发现而更新:

5116 primes:  50%|████████        | 49529/99997 [00:03<00:03, 15905.90it/s]

使用 tqdm 有时会使您的代码变得不稳定——有时屏幕重绘不良,有时循环会简单地挂起。如果您意外地将 tqdm 循环包装在另一个 tqdm 循环中,可能会发生奇怪的事情。尽管如此,通常它的好处超过这些缺点,因此在我们有运行缓慢的计算时,我们将尝试使用它。

降维

有时数据的“实际”(或有用)维度可能与我们拥有的维度不对应。例如,请考虑图示的数据集 Figure 10-6。

带有“错误”轴的数据

图 10-6. 带有“错误”轴的数据

数据中的大部分变化似乎沿着一个不对应于 x 轴或 y 轴的单一维度发生。

当情况如此时,我们可以使用一种称为主成分分析(PCA)的技术来提取尽可能多地捕获数据变化的一个或多个维度。

注意

在实践中,您不会在这样低维度的数据集上使用此技术。当您的数据集具有大量维度并且您希望找到捕获大部分变化的小子集时,降维大多数时候非常有用。不幸的是,在二维书籍格式中很难说明这种情况。

作为第一步,我们需要转换数据,使得每个维度的均值为 0:

from scratch.linear_algebra import subtract

def de_mean(data: List[Vector]) -> List[Vector]:
    """Recenters the data to have mean 0 in every dimension"""
    mean = vector_mean(data)
    return [subtract(vector, mean) for vector in data]

(如果我们不这样做,我们的技术可能会识别出均值本身,而不是数据中的变化。)

图 10-7 显示了去均值后的示例数据。

去均值后的 PCA 数据。

图 10-7. 去均值后的数据

现在,给定一个去均值的矩阵 X,我们可以问哪个方向捕捉了数据中的最大方差。

具体来说,给定一个方向 d(一个大小为 1 的向量),矩阵中的每一行 xd 方向上延伸 dot(x, d)。并且每个非零向量 w 确定一个方向,如果我们重新缩放它使其大小为 1:

from scratch.linear_algebra import magnitude

def direction(w: Vector) -> Vector:
    mag = magnitude(w)
    return [w_i / mag for w_i in w]

因此,给定一个非零向量 w,我们可以计算由 w 确定的数据集在方向上的方差:

from scratch.linear_algebra import dot

def directional_variance(data: List[Vector], w: Vector) -> float:
    """
 Returns the variance of x in the direction of w
 """
    w_dir = direction(w)
    return sum(dot(v, w_dir) ** 2 for v in data)

我们希望找到最大化这种方差的方向。我们可以使用梯度下降来实现这一点,只要我们有梯度函数:

def directional_variance_gradient(data: List[Vector], w: Vector) -> Vector:
    """
 The gradient of directional variance with respect to w
 """
    w_dir = direction(w)
    return [sum(2 * dot(v, w_dir) * v[i] for v in data)
            for i in range(len(w))]

现在,我们拥有的第一个主成分就是最大化directional_variance函数的方向:

from scratch.gradient_descent import gradient_step

def first_principal_component(data: List[Vector],
                              n: int = 100,
                              step_size: float = 0.1) -> Vector:
    # Start with a random guess
    guess = [1.0 for _ in data[0]]

    with tqdm.trange(n) as t:
        for _ in t:
            dv = directional_variance(data, guess)
            gradient = directional_variance_gradient(data, guess)
            guess = gradient_step(guess, gradient, step_size)
            t.set_description(f"dv: {dv:.3f}")

    return direction(guess)

在去均值的数据集上,这将返回方向 [0.924, 0.383],看起来捕捉了数据变化的主轴(图 10-8)。

带有第一个成分的 PCA 数据。

图 10-8. 第一个主成分

一旦找到了第一个主成分的方向,我们可以将数据投影到这个方向上,以找到该成分的值:

from scratch.linear_algebra import scalar_multiply

def project(v: Vector, w: Vector) -> Vector:
    """return the projection of v onto the direction w"""
    projection_length = dot(v, w)
    return scalar_multiply(projection_length, w)

如果我们想找到更多的成分,我们首先要从数据中移除投影:

from scratch.linear_algebra import subtract

def remove_projection_from_vector(v: Vector, w: Vector) -> Vector:
    """projects v onto w and subtracts the result from v"""
    return subtract(v, project(v, w))

def remove_projection(data: List[Vector], w: Vector) -> List[Vector]:
    return [remove_projection_from_vector(v, w) for v in data]

因为这个示例数据集仅有二维,在移除第一个成分后,剩下的有效是一维的(图 10-9)。

移除第一个主成分后的数据

图 10-9. 移除第一个主成分后的数据

在那一点上,通过在 remove_projection 的结果上重复这个过程,我们可以找到下一个主成分(图 10-10)。

在一个高维数据集上,我们可以迭代地找到我们想要的许多成分:

def pca(data: List[Vector], num_components: int) -> List[Vector]:
    components: List[Vector] = []
    for _ in range(num_components):
        component = first_principal_component(data)
        components.append(component)
        data = remove_projection(data, component)

    return components

然后我们可以将我们的数据转换到由这些成分张成的低维空间中:

def transform_vector(v: Vector, components: List[Vector]) -> Vector:
    return [dot(v, w) for w in components]

def transform(data: List[Vector], components: List[Vector]) -> List[Vector]:
    return [transform_vector(v, components) for v in data]

这种技术有几个原因很有价值。首先,它可以通过消除噪声维度和整合高度相关的维度来帮助我们清理数据。

前两个主成分。

图 10-10. 前两个主成分

第二,当我们提取出数据的低维表示后,我们可以使用多种在高维数据上效果不佳的技术。本书中将展示此类技术的示例。

同时,尽管这种技术可以帮助你建立更好的模型,但也可能使这些模型更难以解释。理解“每增加一年经验,平均增加 1 万美元的薪水”这样的结论很容易。但“第三主成分每增加 0.1,平均薪水增加 1 万美元”则更难理解。

进一步探索

  • 正如在 第九章 结尾提到的,pandas 可能是清洗、处理和操作数据的主要 Python 工具。本章我们手动完成的所有示例,使用 pandas 都可以更简单地实现。《Python 数据分析》(O’Reilly) 由 Wes McKinney 编写,可能是学习 pandas 最好的方式。

  • scikit-learn 提供了各种矩阵分解函数,包括 PCA。

第十一章:机器学习

我总是乐意学习,尽管我并不总是喜欢被教导。

温斯顿·丘吉尔

许多人想象数据科学主要是机器学习,认为数据科学家整天都在构建、训练和调整机器学习模型。(不过,很多这样想的人其实并不知道机器学习是什么。)事实上,数据科学主要是将业务问题转化为数据问题,收集数据、理解数据、清理数据和格式化数据,而机器学习几乎成了事后的事情。尽管如此,它是一个有趣且必不可少的事后步骤,你基本上必须了解它才能从事数据科学工作。

建模

在我们讨论机器学习之前,我们需要谈谈模型

什么是模型?简单来说,它是描述不同变量之间数学(或概率)关系的规范。

例如,如果你正在为你的社交网络站点筹集资金,你可能会建立一个商业模型(通常在电子表格中),该模型接受“用户数量”、“每用户广告收入”和“员工数量”等输入,并输出未来几年的年度利润。烹饪食谱涉及一个模型,将“用餐者数量”和“饥饿程度”等输入与所需的食材量联系起来。如果你曾经在电视上观看扑克比赛,你会知道每位玩家的“获胜概率”是根据模型实时估算的,该模型考虑了到目前为止已经公开的牌和牌堆中牌的分布。

商业模型可能基于简单的数学关系:利润等于收入减去支出,收入等于销售单位数乘以平均价格,等等。食谱模型可能基于试错法——有人在厨房尝试不同的配料组合,直到找到自己喜欢的那一种。而扑克模型则基于概率论、扑克规则以及关于发牌随机过程的一些合理假设。

什么是机器学习?

每个人都有自己的确切定义,但我们将使用机器学习来指代从数据中创建和使用模型的过程。在其他情境下,这可能被称为预测建模数据挖掘,但我们将坚持使用机器学习。通常,我们的目标是利用现有数据开发模型,用于预测新数据的各种结果,比如:

  • 是否是垃圾邮件

  • 信用卡交易是否属于欺诈

  • 哪个广告最有可能被购物者点击

  • 哪支橄榄球队会赢得超级碗

我们将讨论监督模型(其中有一组带有正确答案标签的数据可供学习)和无监督模型(其中没有这些标签)两种模型。还有其他各种类型,比如半监督(其中只有部分数据被标记)、在线(模型需要持续调整以适应新到达的数据)和强化(在做出一系列预测后,模型会收到一个指示其表现如何的信号),这些我们在本书中不会涉及。

现在,即使在最简单的情况下,也有可能有整个宇宙的模型可以描述我们感兴趣的关系。在大多数情况下,我们会自己选择一个参数化的模型家族,然后使用数据来学习某种方式上最优的参数。

例如,我们可能假设一个人的身高(大致上)是他的体重的线性函数,然后使用数据来学习这个线性函数是什么。或者我们可能认为决策树是诊断我们的患者患有哪些疾病的好方法,然后使用数据来学习“最优”这样的树。在本书的其余部分,我们将研究我们可以学习的不同模型家族。

但在此之前,我们需要更好地理解机器学习的基本原理。在本章的其余部分,我们将讨论一些基本概念,然后再讨论模型本身。

过拟合和欠拟合

在机器学习中一个常见的危险是过拟合——生成一个在您训练它的数据上表现良好但泛化性能差的模型。这可能涉及学习数据中的噪音。或者可能涉及学习识别特定输入,而不是实际上对所需输出有预测能力的因素。

这种情况的另一面是拟合不足——产生一个即使在训练数据上表现也不好的模型,尽管通常在这种情况下,您会认为您的模型还不够好,继续寻找更好的模型。

在图 11-1 中,我拟合了三个多项式到一组数据样本中。(不用担心具体方法;我们会在后面的章节中介绍。)

过拟合和欠拟合。

图 11-1. 过拟合和欠拟合

水平线显示了最佳拟合度为 0(即常数)的多项式。它严重拟合不足训练数据。最佳拟合度为 9(即 10 参数)的多项式恰好通过每个训练数据点,但它非常严重过拟合;如果我们再选几个数据点,它很可能会严重偏离。而一次拟合度的线条达到了很好的平衡;它非常接近每个点,如果这些数据是代表性的,那么这条线也很可能接近新数据点。

显然,过于复杂的模型会导致过拟合,并且在训练数据之外不能很好地泛化。那么,我们如何确保我们的模型不会太复杂呢?最基本的方法涉及使用不同的数据来训练模型和测试模型。

这样做的最简单方法是将数据集分割,例如,将其的三分之二用于训练模型,之后我们可以在剩余的三分之一上测量模型的性能:

import random
from typing import TypeVar, List, Tuple
X = TypeVar('X')  # generic type to represent a data point

def split_data(data: List[X], prob: float) -> Tuple[List[X], List[X]]:
    """Split data into fractions [prob, 1 - prob]"""
    data = data[:]                    # Make a shallow copy
    random.shuffle(data)              # because shuffle modifies the list.
    cut = int(len(data) * prob)       # Use prob to find a cutoff
    return data[:cut], data[cut:]     # and split the shuffled list there.

data = [n for n in range(1000)]
train, test = split_data(data, 0.75)

# The proportions should be correct
assert len(train) == 750
assert len(test) == 250

# And the original data should be preserved (in some order)
assert sorted(train + test) == data

通常情况下,我们会有成对的输入变量和输出变量。在这种情况下,我们需要确保将对应的值放在训练数据或测试数据中:

Y = TypeVar('Y')  # generic type to represent output variables

def train_test_split(xs: List[X],
                     ys: List[Y],
                     test_pct: float) -> Tuple[List[X], List[X], List[Y],
                                                                 List[Y]]:
    # Generate the indices and split them
    idxs = [i for i in range(len(xs))]
    train_idxs, test_idxs = split_data(idxs, 1 - test_pct)

    return ([xs[i] for i in train_idxs],  # x_train
            [xs[i] for i in test_idxs],   # x_test
            [ys[i] for i in train_idxs],  # y_train
            [ys[i] for i in test_idxs])   # y_test

如常,我们要确保我们的代码能够正常工作:

xs = [x for x in range(1000)]  # xs are 1 ... 1000
ys = [2 * x for x in xs]       # each y_i is twice x_i
x_train, x_test, y_train, y_test = train_test_split(xs, ys, 0.25)

# Check that the proportions are correct
assert len(x_train) == len(y_train) == 750
assert len(x_test) == len(y_test) == 250

# Check that the corresponding data points are paired correctly
assert all(y == 2 * x for x, y in zip(x_train, y_train))
assert all(y == 2 * x for x, y in zip(x_test, y_test))

之后,您可以做一些像这样的事情:

model = SomeKindOfModel()
x_train, x_test, y_train, y_test = train_test_split(xs, ys, 0.33)
model.train(x_train, y_train)
performance = model.test(x_test, y_test)

如果模型对训练数据过拟合,那么它在(完全分开的)测试数据上的表现希望会非常差。换句话说,如果它在测试数据上表现良好,那么您可以更有信心它是在适应而不是过拟合

然而,有几种情况可能会出错。

第一种情况是测试数据和训练数据中存在的常见模式不会推广到更大的数据集中。

例如,想象一下,您的数据集包含用户活动,每个用户每周一行。在这种情况下,大多数用户会出现在训练数据和测试数据中,并且某些模型可能会学习识别用户而不是发现涉及属性的关系。这并不是一个很大的担忧,尽管我曾经遇到过一次。

更大的问题是,如果您不仅用于评估模型而且用于选择多个模型。在这种情况下,尽管每个单独的模型可能不会过拟合,“选择在测试集上表现最佳的模型”是一个元训练,使得测试集充当第二个训练集。(当然,在测试集上表现最佳的模型在测试集上表现良好。)

在这种情况下,您应该将数据分为三部分:用于构建模型的训练集,用于在训练后的模型中进行选择的验证集,以及用于评估最终模型的测试集。

正确性

当我不从事数据科学时,我涉足医学。在业余时间里,我想出了一种廉价的、无创的测试方法,可以给新生儿做,预测——准确率超过 98%——新生儿是否会患白血病。我的律师说服我这个测试方法无法申请专利,所以我会在这里和大家分享详细信息:只有当宝宝被命名为卢克(听起来有点像“白血病”)时,预测白血病。

如我们所见,这个测试确实有超过 98%的准确率。然而,这是一个非常愚蠢的测试,很好地说明了为什么我们通常不使用“准确性”来衡量(二元分类)模型的好坏。

想象构建一个用于进行二进制判断的模型。这封邮件是垃圾邮件吗?我们应该雇佣这位候选人吗?这位空中旅客是不是秘密的恐怖分子?

针对一组标记数据和这样一个预测模型,每个数据点都属于四个类别之一:

真阳性

“此消息是垃圾邮件,我们正确预测了垃圾邮件。”

假阳性(第一类错误)

“此消息不是垃圾邮件,但我们预测了垃圾邮件。”

假阴性(第二类错误)

“此消息是垃圾邮件,但我们预测了非垃圾邮件。”

真阴性

“此消息不是垃圾邮件,我们正确预测了非垃圾邮件。”

我们通常将这些表示为混淆矩阵中的计数:

垃圾邮件非垃圾邮件
预测“垃圾邮件”真阳性假阳性
预测“非垃圾邮件”假阴性真阴性

让我们看看我的白血病测试如何符合这个框架。 近年来,大约每 1,000 名婴儿中有 5 名被命名为卢克。 白血病的终身患病率约为 1.4%,或每 1,000 人中有 14 人

如果我们相信这两个因素是独立的,并将我的“卢克是用于白血病检测”的测试应用于 1 百万人,我们预计会看到一个混淆矩阵,如下所示:

白血病无白血病总计
“卢克”704,9305,000
非“卢克”13,930981,070995,000
总计14,000986,0001,000,000

我们可以使用这些来计算有关模型性能的各种统计信息。 例如,准确度 定义为正确预测的分数的比例:

def accuracy(tp: int, fp: int, fn: int, tn: int) -> float:
    correct = tp + tn
    total = tp + fp + fn + tn
    return correct / total

assert accuracy(70, 4930, 13930, 981070) == 0.98114

这似乎是一个相当令人印象深刻的数字。 但显然这不是一个好的测试,这意味着我们可能不应该对原始准确性赋予很高的信任。

通常会查看精确度召回率的组合。 精确度衡量我们的阳性预测的准确性:

def precision(tp: int, fp: int, fn: int, tn: int) -> float:
    return tp / (tp + fp)

assert precision(70, 4930, 13930, 981070) == 0.014

召回率衡量了我们的模型识别出的阳性的分数:

def recall(tp: int, fp: int, fn: int, tn: int) -> float:
    return tp / (tp + fn)

assert recall(70, 4930, 13930, 981070) == 0.005

这两个数字都很糟糕,反映出这是一个糟糕的模型。

有时精确度和召回率会结合成F1 分数,其定义为:

def f1_score(tp: int, fp: int, fn: int, tn: int) -> float:
    p = precision(tp, fp, fn, tn)
    r = recall(tp, fp, fn, tn)

    return 2 * p * r / (p + r)

这是调和平均 精度和召回率,必然位于它们之间。

通常,模型的选择涉及精确度和召回率之间的权衡。 当模型在稍微有信心时预测“是”可能会具有很高的召回率但较低的精确度; 仅当模型极度自信时才预测“是”可能会具有较低的召回率和较高的精确度。

或者,您可以将其视为假阳性和假阴性之间的权衡。 说“是”的次数太多会产生大量的假阳性; 说“不”太多会产生大量的假阴性。

想象一下,白血病有 10 个风险因素,而且你拥有的风险因素越多,患白血病的可能性就越大。在这种情况下,你可以想象一系列测试:“如果至少有一个风险因素则预测患白血病”,“如果至少有两个风险因素则预测患白血病”,依此类推。随着阈值的提高,测试的准确性增加(因为拥有更多风险因素的人更有可能患病),而召回率降低(因为越来越少最终患病者将满足阈值)。在这种情况下,选择正确的阈值是找到正确权衡的问题。

偏差-方差权衡

另一种思考过拟合问题的方式是将其视为偏差和方差之间的权衡。

这两者都是在假设你会在来自同一较大总体的不同训练数据集上多次重新训练模型时会发生的情况的度量。

例如,“过拟合和欠拟合”中的零阶模型在几乎任何训练集上都会犯很多错误(从同一总体中抽取),这意味着它有很高的偏差。然而,任意选择的两个训练集应该产生相似的模型(因为任意选择的两个训练集应该具有相似的平均值)。所以我们说它的方差很低。高偏差和低方差通常对应欠拟合。

另一方面,九阶模型完美地适应了训练集。它的偏差非常低,但方差非常高(因为任意两个训练集可能会产生非常不同的模型)。这对应于过拟合。

以这种方式思考模型问题可以帮助你弄清楚当你的模型效果不佳时该怎么做。

如果你的模型存在高偏差(即使在训练数据上表现也很差),可以尝试的一种方法是添加更多特征。从“过拟合和欠拟合”中的零阶模型转换为一阶模型是一个很大的改进。

如果你的模型方差很高,你可以类似地删除特征。但另一个解决方案是获取更多数据(如果可能的话)。

在图 11-2 中,我们将一个九阶多项式拟合到不同大小的样本上。基于 10 个数据点进行的模型拟合到处都是,正如我们之前看到的。如果我们改为在 100 个数据点上训练,过拟合就会减少很多。而在 1,000 个数据点上训练的模型看起来与一阶模型非常相似。保持模型复杂性恒定,拥有的数据越多,过拟合就越困难。另一方面,更多的数据对偏差没有帮助。如果你的模型没有使用足够的特征来捕获数据的规律,那么扔更多数据进去是没有帮助的。

通过增加数据减少方差。

图 11-2. 通过增加数据减少方差

特征提取和选择

正如前面提到的,当你的数据没有足够的特征时,你的模型很可能会欠拟合。而当你的数据有太多特征时,很容易过拟合。但特征是什么,它们从哪里来呢?

特征就是我们向模型提供的任何输入。

在最简单的情况下,特征只是简单地给出。如果你想根据某人的工作经验预测她的薪水,那么工作经验就是你唯一拥有的特征。(尽管正如我们在“过拟合和欠拟合”中看到的那样,如果这有助于构建更好的模型,你可能还会考虑添加工作经验的平方、立方等。)

随着数据变得更加复杂,事情变得更加有趣。想象一下试图构建一个垃圾邮件过滤器来预测邮件是否是垃圾的情况。大多数模型不知道如何处理原始邮件,因为它只是一堆文本。你需要提取特征。例如:

  • 邮件中是否包含Viagra一词?

  • 字母d出现了多少次?

  • 发件人的域名是什么?

对于像这里第一个问题的答案,答案很简单,是一个是或否的问题,我们通常将其编码为 1 或 0。第二个问题是一个数字。第三个问题是从一组离散选项中选择的一个选项。

几乎总是,我们会从数据中提取属于这三类之一的特征。此外,我们拥有的特征类型限制了我们可以使用的模型类型。

  • 我们将在第十三章中构建的朴素贝叶斯分类器适用于像前面列表中的第一个这样的是或否特征。

  • 我们将在第 14 和第十六章中学习的回归模型需要数值特征(可能包括虚拟变量,即 0 和 1)。

  • 我们将在第十七章中探讨的决策树可以处理数值或分类数据。

虽然在垃圾邮件过滤器示例中我们寻找创建特征的方法,但有时我们会寻找删除特征的方法。

例如,你的输入可能是几百个数字的向量。根据情况,将这些特征简化为几个重要的维度可能是合适的(正如在“降维”中所示),然后仅使用这少量的特征。或者可能适合使用一种技术(如我们将在“正则化”中看到的那样),该技术惩罚使用更多特征的模型。

我们如何选择特征?这就是经验和领域专业知识结合起来发挥作用的地方。如果你收到了大量的邮件,那么你可能会意识到某些词语的出现可能是垃圾邮件的良好指标。而你可能还会觉得字母d的数量可能不是衡量邮件是否是垃圾的好指标。但总的来说,你必须尝试不同的方法,这也是乐趣的一部分。

进一步探索

  • 继续阅读!接下来的几章讲述不同类型的机器学习模型。

  • Coursera 的机器学习课程是最早的大规模在线开放课程(MOOC),是深入了解机器学习基础知识的好地方。

  • 统计学习的要素,作者是 Jerome H. Friedman、Robert Tibshirani 和 Trevor Hastie(Springer),是一本可以免费在线下载的经典教材。但请注意:它非常数学化。

第十二章:k-最近邻

如果你想要惹恼你的邻居,就告诉他们关于他们的真相。

皮耶特罗·阿雷蒂诺

想象一下,你试图预测我在下次总统选举中的投票方式。如果你对我一无所知(并且如果你有数据的话),一个明智的方法是看看我的邻居打算如何投票。像我一样住在西雅图,我的邻居们无一例外地计划投票给民主党候选人,这表明“民主党候选人”对我来说也是一个不错的猜测。

现在想象一下,你不仅了解我的地理位置,还知道我的年龄、收入、有几个孩子等等。在我行为受这些因素影响(或者说特征化)的程度上,只看那些在所有这些维度中与我接近的邻居,似乎比看所有邻居更有可能是一个更好的预测器。这就是最近邻分类背后的思想。

模型

最近邻是最简单的预测模型之一。它不做任何数学假设,也不需要任何重型设备。它唯一需要的是:

  • 某种距离的概念

  • 一个假设是相互接近的点是相似的。

在本书中,我们将看到的大多数技术都是整体看待数据集以便学习数据中的模式。而最近邻则故意忽略了很多信息,因为对于每个新点的预测仅依赖于最接近它的少数几个点。

此外,最近邻可能不会帮助你理解你正在研究的现象的驱动因素。根据我的邻居的投票预测我的投票并不能告诉你关于我为什么投票的原因,而一些基于(比如)我的收入和婚姻状况预测我的投票的替代模型可能会很好地做到这一点。

在一般情况下,我们有一些数据点和相应的标签集合。标签可以是TrueFalse,表示每个输入是否满足某些条件,比如“是垃圾邮件?”或“有毒?”或“看起来有趣吗?”或者它们可以是类别,比如电影评级(G,PG,PG-13,R,NC-17)。或者它们可以是总统候选人的名字。或者它们可以是喜欢的编程语言。

在我们的情况下,数据点将是向量,这意味着我们可以使用第四章中的distance函数。

假设我们选择了一个像 3 或 5 这样的数字k。那么,当我们想要对一些新的数据点进行分类时,我们找到k个最近的带标签点,并让它们对新的输出进行投票。

为此,我们需要一个计票的函数。一种可能性是:

from typing import List
from collections import Counter

def raw_majority_vote(labels: List[str]) -> str:
    votes = Counter(labels)
    winner, _ = votes.most_common(1)[0]
    return winner

assert raw_majority_vote(['a', 'b', 'c', 'b']) == 'b'

但这并不会处理带有智能的平局情况。例如,想象我们正在评分电影,而最近的五部电影分别被评为 G、G、PG、PG 和 R。那么 G 有两票,PG 也有两票。在这种情况下,我们有几个选项:

  • 随机挑选一个赢家。

  • 通过距离加权投票并选择加权赢家。

  • 减少k直到找到唯一的赢家。

我们将实现第三种方法:

def majority_vote(labels: List[str]) -> str:
    """Assumes that labels are ordered from nearest to farthest."""
    vote_counts = Counter(labels)
    winner, winner_count = vote_counts.most_common(1)[0]
    num_winners = len([count
                       for count in vote_counts.values()
                       if count == winner_count])

    if num_winners == 1:
        return winner                     # unique winner, so return it
    else:
        return majority_vote(labels[:-1]) # try again without the farthest

# Tie, so look at first 4, then 'b'
assert majority_vote(['a', 'b', 'c', 'b', 'a']) == 'b'

这种方法肯定最终会奏效,因为在最坏的情况下,我们最终只需一个标签,此时那个标签会获胜。

使用这个函数很容易创建一个分类器:

from typing import NamedTuple
from scratch.linear_algebra import Vector, distance

class LabeledPoint(NamedTuple):
    point: Vector
    label: str

def knn_classify(k: int,
                 labeled_points: List[LabeledPoint],
                 new_point: Vector) -> str:

    # Order the labeled points from nearest to farthest.
    by_distance = sorted(labeled_points,
                         key=lambda lp: distance(lp.point, new_point))

    # Find the labels for the k closest
    k_nearest_labels = [lp.label for lp in by_distance[:k]]

    # and let them vote.
    return majority_vote(k_nearest_labels)

让我们看看这是如何工作的。

示例:鸢尾花数据集

Iris 数据集是机器学习的重要数据集。它包含了 150 朵花的测量数据,代表三种鸢尾花物种。对于每朵花,我们有它的花瓣长度、花瓣宽度、萼片长度和萼片宽度,以及它的物种。你可以从https://archive.ics.uci.edu/ml/datasets/iris下载:

import requests

data = requests.get(
  "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
)

with open('iris.dat', 'w') as f:
    f.write(data.text)

数据是逗号分隔的,包含字段:

sepal_length, sepal_width, petal_length, petal_width, class

例如,第一行看起来像:

5.1,3.5,1.4,0.2,Iris-setosa

在这一节中,我们将尝试构建一个模型,可以从前四个测量值预测类别(即物种)。

首先,让我们加载并探索数据。我们的最近邻函数期望一个LabeledPoint,所以让我们用这种方式表示我们的数据:

from typing import Dict
import csv
from collections import defaultdict

def parse_iris_row(row: List[str]) -> LabeledPoint:
    """
 sepal_length, sepal_width, petal_length, petal_width, class
 """
    measurements = [float(value) for value in row[:-1]]
    # class is e.g. "Iris-virginica"; we just want "virginica"
    label = row[-1].split("-")[-1]

    return LabeledPoint(measurements, label)

with open('iris.data') as f:
    reader = csv.reader(f)
    iris_data = [parse_iris_row(row) for row in reader]

# We'll also group just the points by species/label so we can plot them
points_by_species: Dict[str, List[Vector]] = defaultdict(list)
for iris in iris_data:
    points_by_species[iris.label].append(iris.point)

我们希望绘制测量结果,以便查看它们按物种的变化。不幸的是,它们是四维的,这使得绘图变得棘手。我们可以做的一件事是查看每一对测量的散点图(图 12-1)。我不会解释所有的细节,但这是对 matplotlib 更复杂用法的很好示例,所以值得学习:

from matplotlib import pyplot as plt
metrics = ['sepal length', 'sepal width', 'petal length', 'petal width']
pairs = [(i, j) for i in range(4) for j in range(4) if i < j]
marks = ['+', '.', 'x']  # we have 3 classes, so 3 markers

fig, ax = plt.subplots(2, 3)

for row in range(2):
    for col in range(3):
        i, j = pairs[3 * row + col]
        ax[row][col].set_title(f"{metrics[i]} vs {metrics[j]}", fontsize=8)
        ax[row][col].set_xticks([])
        ax[row][col].set_yticks([])

        for mark, (species, points) in zip(marks, points_by_species.items()):
            xs = [point[i] for point in points]
            ys = [point[j] for point in points]
            ax[row][col].scatter(xs, ys, marker=mark, label=species)

ax[-1][-1].legend(loc='lower right', prop={'size': 6})
plt.show()

鸢尾花散点图

图 12-1. 鸢尾花散点图

如果你看这些图,看起来测量结果确实按物种聚类。例如,仅看萼片长度和萼片宽度,你可能无法区分鸢尾花维吉尼亚。但一旦加入花瓣长度和宽度,似乎你应该能够根据最近邻来预测物种。

首先,让我们将数据分成测试集和训练集:

import random
from scratch.machine_learning import split_data

random.seed(12)
iris_train, iris_test = split_data(iris_data, 0.70)
assert len(iris_train) == 0.7 * 150
assert len(iris_test) == 0.3 * 150

训练集将是我们用来分类测试集中点的“邻居”。我们只需选择一个k值,即获得投票权的邻居数。如果太小(考虑k = 1),我们让离群值影响过大;如果太大(考虑k = 105),我们只是预测数据集中最常见的类别。

在真实的应用中(和更多数据),我们可能会创建一个单独的验证集,并用它来选择k。在这里我们只使用k = 5:

from typing import Tuple

# track how many times we see (predicted, actual)
confusion_matrix: Dict[Tuple[str, str], int] = defaultdict(int)
num_correct = 0

for iris in iris_test:
    predicted = knn_classify(5, iris_train, iris.point)
    actual = iris.label

    if predicted == actual:
        num_correct += 1

    confusion_matrix[(predicted, actual)] += 1

pct_correct = num_correct / len(iris_test)
print(pct_correct, confusion_matrix)

在这个简单的数据集上,模型几乎完美地预测了。有一个鸢尾花,它预测为维吉尼亚,但除此之外其他都是完全正确的。

维度灾难

在高维空间中,“k”最近邻算法在处理高维数据时遇到麻烦,这要归因于“维度的诅咒”,其核心问题在于高维空间是广阔的。高维空间中的点往往彼此之间并不接近。通过在各种维度中随机生成“d”维“单位立方体”中的点对,并计算它们之间的距离,可以看出这一点。

生成随机点现在应该是驾轻就熟了:

def random_point(dim: int) -> Vector:
    return [random.random() for _ in range(dim)]

编写一个生成距离的函数也是一样的:

def random_distances(dim: int, num_pairs: int) -> List[float]:
    return [distance(random_point(dim), random_point(dim))
            for _ in range(num_pairs)]

对于从 1 到 100 的每个维度,我们将计算 10,000 个距离,并使用这些距离计算点之间的平均距离以及每个维度中点之间的最小距离(参见图 12-2):

import tqdm
dimensions = range(1, 101)

avg_distances = []
min_distances = []

random.seed(0)
for dim in tqdm.tqdm(dimensions, desc="Curse of Dimensionality"):
    distances = random_distances(dim, 10000)      # 10,000 random pairs
    avg_distances.append(sum(distances) / 10000)  # track the average
    min_distances.append(min(distances))          # track the minimum

维度的诅咒。

图 12-2。维度的诅咒

随着维度的增加,点之间的平均距离也增加。但更为问题的是最近距离与平均距离之间的比率(参见图 12-3):

min_avg_ratio = [min_dist / avg_dist
                 for min_dist, avg_dist in zip(min_distances, avg_distances)]

再谈维度的诅咒。

图 12-3。再谈维度的诅咒

在低维数据集中,最近的点往往比平均距离要接近得多。但是只有当两个点在每个维度上都接近时,这两个点才是接近的,而每增加一个维度——即使只是噪音——都是使每个点与其他每个点的距离更远的机会。当你有很多维度时,最接近的点可能并不比平均距离要接近,所以两个点接近并不意味着太多(除非你的数据具有使其表现得像低维度的大量结构)。

对问题的另一种思考方式涉及到更高维度空间的稀疏性。

如果你在 0 和 1 之间随机选择 50 个数字,你可能会得到单位区间的一个相当好的样本(参见图 12-4)。

一维空间中的 50 个随机点。

图 12-4。一维空间中的 50 个随机点

如果你在单位正方形中随机选择 50 个点,你将得到更少的覆盖(参见图 12-5)。

二维空间中的 50 个随机点。

图 12-5。二维空间中的 50 个随机点

在三维空间中,覆盖更少(参见图 12-6)。

matplotlib 对于四维图表的呈现并不好,所以这是我们所能达到的最远的地方,但你已经可以看到开始出现大量空白区域,没有点靠近它们。在更多维度中——除非你得到指数级更多的数据——这些大量空白区域代表了远离所有你想要用于预测的点的区域。

因此,如果你尝试在更高维度中使用最近邻方法,最好先进行某种降维处理。

三维空间中的 50 个随机点。

图 12-6。三维空间中的 50 个随机点

进一步探索

scikit-learn 有许多最近邻模型。