Machine-Learning-Mastery-Python-教程-七-

74 阅读24分钟

Machine Learning Mastery Python 教程(七)

原文:Machine Learning Mastery

协议:CC BY-NC-SA 4.0

理解 Python 中的 Traceback

原文:machinelearningmastery.com/understanding-traceback-in-python/

当 Python 程序中发生异常时,通常会打印 traceback。知道如何阅读 traceback 可以帮助你轻松识别错误并进行修复。在本教程中,我们将看到 traceback 可以告诉你什么。

完成本教程后,你将了解:

  • 如何读取 traceback

  • 如何在没有异常的情况下打印调用栈

  • Traceback 中未显示的内容

使用我的新书 Python for Machine Learning 启动你的项目,包括逐步教程和所有示例的Python 源代码文件。

让我们开始吧!

理解 Python 中的 Traceback

图片由 Marten Bjork 提供,部分权利保留

教程概述

本教程分为四部分;它们是:

  1. 简单程序的调用层次结构

  2. 异常时的 Traceback

  3. 手动触发 traceback

  4. 模型训练中的一个示例

简单程序的调用层次结构

让我们考虑一个简单的程序:

def indentprint(x, indent=0, prefix="", suffix=""):
    if isinstance(x, dict):
        printdict(x, indent, prefix, suffix)
    elif isinstance(x, list):
        printlist(x, indent, prefix, suffix)
    elif isinstance(x, str):
        printstring(x, indent, prefix, suffix)
    else:
        printnumber(x, indent, prefix, suffix)

def printdict(x, indent, prefix, suffix):
    spaces = " " * indent
    print(spaces + prefix + "{")
    for n, key in enumerate(x):
        comma = "," if n!=len(x)-1 else ""
        indentprint(x[key], indent+2, str(key)+": ", comma)
    print(spaces + "}" + suffix)

def printlist(x, indent, prefix, suffix):
    spaces = " " * indent
    print(spaces + prefix + "[")
    for n, item in enumerate(x):
        comma = "," if n!=len(x)-1 else ""
        indentprint(item, indent+2, "", comma)
    print(spaces + "]" + suffix)

def printstring(x, indent, prefix, suffix):
    spaces = " " * indent
    print(spaces + prefix + '"' + str(x) + '"' + suffix)

def printnumber(x, indent, prefix, suffix):
    spaces = " " * indent
    print(spaces + prefix + str(x) + suffix)

data = {
    "a": [{
        "p": 3, "q": 4,
        "r": [3,4,5],
    },{
        "f": "foo", "g": 2.71
    },{
        "u": None, "v": "bar"
    }],
    "c": {
        "s": ["fizz", 2, 1.1],
        "t": []
    },
}

indentprint(data)

这个程序将带有缩进的 Python 字典 data 打印出来。它的输出如下:

{
  a: [
    {
      p: 3,
      q: 4,
      r: [
        3,
        4,
        5
      ]
    },
    {
      f: "foo",
      g: 2.71
    },
    {
      u: None,
      v: "bar"
    }
  ],
  c: {
    s: [
      "fizz",
      2,
      1.1
    ],
    t: [
    ]
  }
}

这是一个短程序,但函数之间相互调用。如果我们在每个函数的开头添加一行,我们可以揭示输出是如何随着控制流产生的:

def indentprint(x, indent=0, prefix="", suffix=""):
    print(f'indentprint(x, {indent}, "{prefix}", "{suffix}")')
    if isinstance(x, dict):
        printdict(x, indent, prefix, suffix)
    elif isinstance(x, list):
        printlist(x, indent, prefix, suffix)
    elif isinstance(x, str):
        printstring(x, indent, prefix, suffix)
    else:
        printnumber(x, indent, prefix, suffix)

def printdict(x, indent, prefix, suffix):
    print(f'printdict(x, {indent}, "{prefix}", "{suffix}")')
    spaces = " " * indent
    print(spaces + prefix + "{")
    for n, key in enumerate(x):
        comma = "," if n!=len(x)-1 else ""
        indentprint(x[key], indent+2, str(key)+": ", comma)
    print(spaces + "}" + suffix)

def printlist(x, indent, prefix, suffix):
    print(f'printlist(x, {indent}, "{prefix}", "{suffix}")')
    spaces = " " * indent
    print(spaces + prefix + "[")
    for n, item in enumerate(x):
        comma = "," if n!=len(x)-1 else ""
        indentprint(item, indent+2, "", comma)
    print(spaces + "]" + suffix)

def printstring(x, indent, prefix, suffix):
    print(f'printstring(x, {indent}, "{prefix}", "{suffix}")')
    spaces = " " * indent
    print(spaces + prefix + '"' + str(x) + '"' + suffix)

def printnumber(x, indent, prefix, suffix):
    print(f'printnumber(x, {indent}, "{prefix}", "{suffix}")')
    spaces = " " * indent
    print(spaces + prefix + str(x) + suffix)

输出将被更多信息搞乱:

indentprint(x, 0, "", "")
printdict(x, 0, "", "")
{
indentprint(x, 2, "a: ", ",")
printlist(x, 2, "a: ", ",")
  a: [
indentprint(x, 4, "", ",")
printdict(x, 4, "", ",")
    {
indentprint(x, 6, "p: ", ",")
printnumber(x, 6, "p: ", ",")
      p: 3,
indentprint(x, 6, "q: ", ",")
printnumber(x, 6, "q: ", ",")
      q: 4,
indentprint(x, 6, "r: ", "")
printlist(x, 6, "r: ", "")
      r: [
indentprint(x, 8, "", ",")
printnumber(x, 8, "", ",")
        3,
indentprint(x, 8, "", ",")
printnumber(x, 8, "", ",")
        4,
indentprint(x, 8, "", "")
printnumber(x, 8, "", "")
        5
      ]
    },
indentprint(x, 4, "", ",")
printdict(x, 4, "", ",")
    {
indentprint(x, 6, "f: ", ",")
printstring(x, 6, "f: ", ",")
      f: "foo",
indentprint(x, 6, "g: ", "")
printnumber(x, 6, "g: ", "")
      g: 2.71
    },
indentprint(x, 4, "", "")
printdict(x, 4, "", "")
    {
indentprint(x, 6, "u: ", ",")
printnumber(x, 6, "u: ", ",")
      u: None,
indentprint(x, 6, "v: ", "")
printstring(x, 6, "v: ", "")
      v: "bar"
    }
  ],
indentprint(x, 2, "c: ", "")
printdict(x, 2, "c: ", "")
  c: {
indentprint(x, 4, "s: ", ",")
printlist(x, 4, "s: ", ",")
    s: [
indentprint(x, 6, "", ",")
printstring(x, 6, "", ",")
      "fizz",
indentprint(x, 6, "", ",")
printnumber(x, 6, "", ",")
      2,
indentprint(x, 6, "", "")
printnumber(x, 6, "", "")
      1.1
    ],
indentprint(x, 4, "t: ", "")
printlist(x, 4, "t: ", "")
    t: [
    ]
  }
}

现在我们知道了每个函数调用的顺序。这就是调用栈的概念。在我们运行函数中的一行代码时,我们想知道是什么调用了这个函数。

异常时的 Traceback

如果我们在代码中犯了一个错别字,例如:

def printdict(x, indent, prefix, suffix):
    spaces = " " * indent
    print(spaces + prefix + "{")
    for n, key in enumerate(x):
        comma = "," if n!=len(x)-1 else ""
        indentprint(x[key], indent+2, str(key)+": ", comma)
    print(spaces + "}") + suffix

错误在最后一行,其中闭合括号应该在行末,而不是在任何 + 之前。print() 函数的返回值是 Python 的 None 对象。将内容添加到 None 会触发异常。

如果你使用 Python 解释器运行这个程序,你将看到:

{
  a: [
    {
      p: 3,
      q: 4,
      r: [
        3,
        4,
        5
      ]
    }
Traceback (most recent call last):
  File "tb.py", line 52, in 
    indentprint(data)
  File "tb.py", line 3, in indentprint
    printdict(x, indent, prefix, suffix)
  File "tb.py", line 16, in printdict
    indentprint(x[key], indent+2, str(key)+": ", comma)
  File "tb.py", line 5, in indentprint
    printlist(x, indent, prefix, suffix)
  File "tb.py", line 24, in printlist
    indentprint(item, indent+2, "", comma)
  File "tb.py", line 3, in indentprint
    printdict(x, indent, prefix, suffix)
  File "tb.py", line 17, in printdict
    print(spaces + "}") + suffix
TypeError: unsupported operand type(s) for +: 'NoneType' and 'str'

以“Traceback (most recent call last):”开头的行是 traceback。它是你的程序在遇到异常时的。在上述示例中,traceback 以“最近的调用最后”顺序显示。因此你的主函数在顶部,而触发异常的函数在底部。所以我们知道问题出在函数 printdict() 内部。

通常,你会在 traceback 的末尾看到错误消息。在这个例子中,是由于将 None 和字符串相加触发的 TypeError。但 traceback 的帮助到此为止。你需要弄清楚哪个是 None,哪个是字符串。通过阅读 traceback,我们也知道触发异常的函数 printdict() 是由 indentprint() 调用的,indentprint() 又由 printlist() 调用,依此类推。

如果你在 Jupyter notebook 中运行这段代码,输出如下:

{
  a: [
    {
      p: 3,
      q: 4,
      r: [
        3,
        4,
        5
      ]
    }
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/6z/w0ltb1ss08l593y5xt9jyl1w0000gn/T/ipykernel_37031/2508041071.py in 
----> 1 indentprint(x)

/var/folders/6z/w0ltb1ss08l593y5xt9jyl1w0000gn/T/ipykernel_37031/2327707064.py in indentprint(x, indent, prefix, suffix)
      1 def indentprint(x, indent=0, prefix="", suffix=""):
      2     if isinstance(x, dict):
----> 3         printdict(x, indent, prefix, suffix)
      4     elif isinstance(x, list):
      5         printlist(x, indent, prefix, suffix)

/var/folders/6z/w0ltb1ss08l593y5xt9jyl1w0000gn/T/ipykernel_37031/2327707064.py in printdict(x, indent, prefix, suffix)
     14     for n, key in enumerate(x):
     15         comma = "," if n!=len(x)-1 else ""
---> 16         indentprint(x[key], indent+2, str(key)+": ", comma)
     17     print(spaces + "}") + suffix
     18 

/var/folders/6z/w0ltb1ss08l593y5xt9jyl1w0000gn/T/ipykernel_37031/2327707064.py in indentprint(x, indent, prefix, suffix)
      3         printdict(x, indent, prefix, suffix)
      4     elif isinstance(x, list):
----> 5         printlist(x, indent, prefix, suffix)
      6     elif isinstance(x, str):
      7         printstring(x, indent, prefix, suffix)

/var/folders/6z/w0ltb1ss08l593y5xt9jyl1w0000gn/T/ipykernel_37031/2327707064.py in printlist(x, indent, prefix, suffix)
     22     for n, item in enumerate(x):
     23         comma = "," if n!=len(x)-1 else ""
---> 24         indentprint(item, indent+2, "", comma)
     25     print(spaces + "]" + suffix)
     26 

/var/folders/6z/w0ltb1ss08l593y5xt9jyl1w0000gn/T/ipykernel_37031/2327707064.py in indentprint(x, indent, prefix, suffix)
      1 def indentprint(x, indent=0, prefix="", suffix=""):
      2     if isinstance(x, dict):
----> 3         printdict(x, indent, prefix, suffix)
      4     elif isinstance(x, list):
      5         printlist(x, indent, prefix, suffix)

/var/folders/6z/w0ltb1ss08l593y5xt9jyl1w0000gn/T/ipykernel_37031/2327707064.py in printdict(x, indent, prefix, suffix)
     15         comma = "," if n!=len(x)-1 else ""
     16         indentprint(x[key], indent+2, str(key)+": ", comma)
---> 17     print(spaces + "}") + suffix
     18 
     19 def printlist(x, indent, prefix, suffix):

TypeError: unsupported operand type(s) for +: 'NoneType' and 'str'

信息本质上是相同的,但它提供了每个函数调用前后的行。

想要开始学习 Python 进行机器学习吗?

现在立即报名我的免费 7 天电子邮件速成课程(包含示例代码)。

点击注册并获得免费 PDF 电子书版课程。

手动触发 traceback

打印 traceback 最简单的方法是添加 raise 语句来手动创建异常。但这也会终止你的程序。如果我们希望在任何时间打印栈,即使没有任何异常,我们可以使用以下方法:

import traceback

def printdict(x, indent, prefix, suffix):
    spaces = " " * indent
    print(spaces + prefix + "{")
    for n, key in enumerate(x):
        comma = "," if n!=len(x)-1 else ""
        indentprint(x[key], indent+2, str(key)+": ", comma)
    traceback.print_stack()    # print the current call stack
    print(spaces + "}" + suffix)

traceback.print_stack() 将打印当前调用栈。

但确实,我们通常只在出现错误时才打印栈(以便了解为什么会这样)。更常见的用例如下:

import traceback
import random

def compute():
    n = random.randint(0, 10)
    m = random.randint(0, 10)
    return n/m

def compute_many(n_times):
    try:
        for _ in range(n_times):
            x = compute()
        print(f"Completed {n_times} times")
    except:
        print("Something wrong")
        traceback.print_exc()

compute_many(100)

这是重复计算函数的典型模式,例如蒙特卡洛模拟。但如果我们不够小心,可能会遇到一些错误,如上例中的除零错误。问题是,在更复杂的计算情况下,你不能轻易发现缺陷。例如上面的情况,问题隐藏在 compute() 的调用中。因此,理解错误的产生方式是有帮助的。但同时,我们希望处理错误的情况,而不是让整个程序终止。如果我们使用 try-catch 构造,traceback 默认不会打印。因此,我们需要使用 traceback.print_exc() 语句手动打印。

实际上,我们可以使 traceback 更加详细。由于 traceback 是调用栈,我们可以检查调用栈中的每个函数,并检查每一层中的变量。在这种复杂的情况下,这是我通常用来做更详细跟踪的函数:

def print_tb_with_local():
    """Print stack trace with local variables. This does not need to be in
    exception. Print is using the system's print() function to stderr.
    """
    import traceback, sys
    tb = sys.exc_info()[2]
    stack = []
    while tb:
        stack.append(tb.tb_frame)
        tb = tb.tb_next()
    traceback.print_exc()
    print("Locals by frame, most recent call first", file=sys.stderr)
    for frame in stack:
        print("Frame {0} in {1} at line {2}".format(
            frame.f_code.co_name,
            frame.f_code.co_filename,
            frame.f_lineno), file=sys.stderr)
        for key, value in frame.f_locals.items():
            print("\t%20s = " % key, file=sys.stderr)
            try:
                if '__repr__' in dir(value):
                    print(value.__repr__(), file=sys.stderr)
                elif '__str__' in dir(value):
                    print(value.__str__(), file=sys.stderr)
                else:
                    print(value, file=sys.stderr)
            except:
                print("", file=sys.stderr)

模型训练的一个示例

traceback 中报告的调用栈有一个限制:你只能看到 Python 函数。这对于你编写的程序应该没问题,但许多大型 Python 库的一部分是用其他语言编写并编译成二进制的。例如 Tensorflow。所有底层操作都是以二进制形式存在以提升性能。因此,如果你运行以下代码,你会看到不同的内容:

import numpy as np

sequence = np.arange(0.1, 1.0, 0.1)  # 0.1 to 0.9
n_in = len(sequence)
sequence = sequence.reshape((1, n_in, 1))

# define model
import tensorflow as tf
from tensorflow.keras.layers import LSTM, RepeatVector, Dense, TimeDistributed, Input
from tensorflow.keras import Sequential, Model

model = Sequential([
    LSTM(100, activation="relu", input_shape=(n_in+1, 1)),
    RepeatVector(n_in),
    LSTM(100, activation="relu", return_sequences=True),
    TimeDistributed(Dense(1))
])
model.compile(optimizer="adam", loss="mse")

model.fit(sequence, sequence, epochs=300, verbose=0)

模型中第一个 LSTM 层的 input_shape 参数应该是 (n_in, 1) 以匹配输入数据,而不是 (n_in+1, 1)。这段代码在你调用最后一行时将打印以下错误:

Traceback (most recent call last):
  File "trback3.py", line 20, in 
    model.fit(sequence, sequence, epochs=300, verbose=0)
  File "/usr/local/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1129, in autograph_handler
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    File "/usr/local/lib/python3.9/site-packages/keras/engine/training.py", line 878, in train_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.9/site-packages/keras/engine/training.py", line 867, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.9/site-packages/keras/engine/training.py", line 860, in run_step  **
        outputs = model.train_step(data)
    File "/usr/local/lib/python3.9/site-packages/keras/engine/training.py", line 808, in train_step
        y_pred = self(x, training=True)
    File "/usr/local/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.9/site-packages/keras/engine/input_spec.py", line 263, in assert_input_compatibility
        raise ValueError(f'Input {input_index} of layer "{layer_name}" is '

    ValueError: Input 0 of layer "sequential" is incompatible with the layer: expected shape=(None, 10, 1), found shape=(None, 9, 1)

如果你查看追溯信息,你无法真正看到完整的调用栈。例如,你知道你调用了 model.fit(),但第二个帧来自一个名为 error_handler() 的函数。在这里,你无法看到 fit() 函数如何触发了这个函数。这是因为 Tensorflow 被高度优化了。许多内容隐藏在编译代码中,Python 解释器无法看到。

在这种情况下,耐心阅读追溯信息并找出原因的线索是至关重要的。当然,错误信息通常也会给你一些有用的提示。

进一步阅读

如果你想更深入了解该主题,本节提供了更多资源。

书籍

Python 官方文档

总结

在本教程中,你学习了如何读取和打印 Python 程序的追溯信息。

具体来说,你学习了:

  • 追溯信息告诉你什么

  • 如何在程序的任何点打印追溯信息而不引发异常

在下一篇文章中,我们将学习如何在 Python 调试器中导航调用栈。

在机器学习项目中使用 Kaggle

原文:machinelearningmastery.com/using-kaggle-in-machine-learning-projects/

你可能听说过 Kaggle 数据科学竞赛,但你知道 Kaggle 还有许多其他功能可以帮助你完成下一个机器学习项目吗?对于寻找数据集以进行下一个机器学习项目的人,Kaggle 允许你访问他人公开的数据集并分享你自己的数据集。对于希望构建和训练自己机器学习模型的人,Kaggle 还提供了一个浏览器内的笔记本环境和一些免费的 GPU 小时。你还可以查看其他人的公开笔记本!

除了网站之外,Kaggle 还提供了一个命令行界面(CLI),你可以在命令行中使用它来访问和下载数据集。

让我们立即深入探索 Kaggle 所提供的内容!

完成本教程后,你将学到:

  • 什么是 Kaggle?

  • 如何将 Kaggle 作为你机器学习工作流的一部分

  • 使用 Kaggle API 的命令行界面(CLI)

通过我的新书 《Python 机器学习》快速启动你的项目,包括逐步教程所有示例的 Python 源代码文件。

让我们开始吧!!

在机器学习项目中使用 Kaggle

图片由Stefan Widua提供。保留部分权利。

概述

本教程分为五个部分;它们是:

  • 什么是 Kaggle?

  • 设置 Kaggle 笔记本

  • 使用带有 GPU/TPU 的 Kaggle 笔记本

  • 在 Kaggle 笔记本中使用 Kaggle 数据集

  • 使用 Kaggle CLI 工具中的 Kaggle 数据集

什么是 Kaggle?

Kaggle 可能以其举办的数据科学竞赛而最为人知,其中一些竞赛提供了五位数的奖池,并有数百支队伍参赛。除了这些竞赛,Kaggle 还允许用户发布和搜索数据集,这些数据集可以用于他们的机器学习项目。要使用这些数据集,你可以在浏览器中使用 Kaggle 笔记本或 Kaggle 的公共 API 来下载数据集,然后在你的机器学习项目中使用这些数据集。

Kaggle 竞赛

此外,Kaggle 还提供一些课程和讨论页面,供你学习更多关于机器学习的知识,并与其他机器学习从业者交流!

在本文的其余部分,我们将重点介绍如何利用 Kaggle 的数据集和笔记本来帮助我们在自己的机器学习项目中工作或寻找新的项目。

设置 Kaggle 笔记本

要开始使用 Kaggle 笔记本,你需要创建一个 Kaggle 账户,可以使用现有的 Google 账户或使用你的电子邮件创建一个。

然后,前往“代码”页面。

Kaggle 首页的左侧边栏,代码标签

然后你将能够看到你自己的笔记本以及其他人发布的公共笔记本。要创建自己的笔记本,点击“新建笔记本”。

Kaggle 代码页面

这将创建你的新笔记本,它看起来像一个 Jupyter 笔记本,具有许多类似的命令和快捷键。

Kaggle 笔记本

你还可以通过前往“文件 -> 编辑器类型”在笔记本编辑器和脚本编辑器之间切换。

更改 Kaggle 笔记本中的编辑器类型

将编辑器类型更改为脚本会显示如下内容:

Kaggle 笔记本脚本编辑器类型

想要开始学习 Python 用于机器学习吗?

现在就获取我的免费 7 天电子邮件速成课程(包含示例代码)。

点击注册并获得课程的免费 PDF 电子书版本。

使用 GPUs/TPUs 的 Kaggle

谁不喜欢用于机器学习项目的免费 GPU 时间呢? GPUs 可以大幅加速机器学习模型的训练和推断,尤其是深度学习模型。

Kaggle 提供了一些免费的 GPUs 和 TPUs 配额,你可以用来进行你的项目。在撰写本文时,验证手机号码后每周 GPU 的可用时间为 30 小时,TPU 的可用时间为 20 小时。

要将加速器附加到你的笔记本,请前往“设置 ▷ 环境 ▷ 偏好设置”。

更改 Kaggle 笔记本环境偏好设置

你将被要求通过手机号码验证你的账户。

验证手机号码

然后会出现一个页面,列出你剩余的使用量,并提到开启 GPUs 会减少可用的 CPUs 数量,因此在进行神经网络训练/推断时才可能是一个好主意。

向 Kaggle 笔记本添加 GPU 加速器

使用 Kaggle 数据集与 Kaggle 笔记本

机器学习项目是数据饥饿的怪物,找到当前项目的数据集或寻找新的项目数据集总是一项繁琐的工作。幸运的是,Kaggle 拥有由用户和比赛提供的丰富数据集。这些数据集对寻找当前机器学习项目的数据或寻找新项目创意的人来说都是宝贵的财富。

让我们探索如何将这些数据集添加到我们的 Kaggle 笔记本中。

首先,点击右侧边栏中的“添加数据”。

将数据集添加到 Kaggle 笔记本环境中

应该会出现一个窗口,显示一些公开可用的数据集,并提供将自己的数据集上传以供 Kaggle 笔记本使用的选项。

在 Kaggle 数据集中进行搜索

我将使用经典的泰坦尼克数据集作为本教程的示例,您可以通过在窗口右上角的搜索框中输入搜索词来找到它。

使用“Titanic”关键词过滤的 Kaggle 数据集

之后,数据集可以在笔记本中使用。要访问文件,请查看文件的路径并在其前面加上../input/{path}。例如,泰坦尼克数据集的文件路径是:

../input/titanic/train_and_test2.csv

在笔记本中,我们可以使用以下命令读取数据:

import pandas

pandas.read_csv("../input/titanic/train_and_test2.csv")

这将从文件中获取数据:

在 Kaggle 笔记本中使用泰坦尼克数据集

使用 Kaggle CLI 工具操作 Kaggle 数据集

Kaggle 还拥有一个公共 API 及 CLI 工具,我们可以用来下载数据集、参与比赛等。我们将探讨如何使用 CLI 工具设置和下载 Kaggle 数据集。

要开始,请使用以下命令安装 CLI 工具:

pip install kaggle

对于 Mac/Linux 用户,您可能需要:

pip install --user kaggle

然后,您需要创建一个 API 令牌进行身份验证。请访问 Kaggle 网页,点击右上角的个人资料图标,然后进入帐户。

进入 Kaggle 帐户设置

从那里,向下滚动到创建新的 API 令牌:

为 Kaggle 公共 API 生成新的 API 令牌

这将下载一个 kaggle.json 文件,您将用它来通过 Kaggle CLI 工具进行身份验证。您必须将其放置在正确的位置以使其正常工作。对于 Linux/Mac/Unix 系统,应放置在 ~/.kaggle/kaggle.json,对于 Windows 用户,应放置在 C:\Users\<Windows 用户名>\.kaggle\kaggle.json。如果放错位置并在命令行中调用 kaggle,将会出现错误:

OSError: Could not find kaggle.json. Make sure it’s location in … Or use the environment method

现在,让我们开始下载这些数据集吧!

要使用搜索词(如 titanic)搜索数据集,我们可以使用:

kaggle datasets list -s titanic

搜索 titanic,我们得到:

$ kaggle datasets list -s titanic
ref                                                          title                                           size  lastUpdated          downloadCount  voteCount  usabilityRating
-----------------------------------------------------------  ---------------------------------------------  -----  -------------------  -------------  ---------  ---------------
datasets/heptapod/titanic                                    Titanic                                         11KB  2017-05-16 08:14:22          37681        739  0.7058824
datasets/azeembootwala/titanic                               Titanic                                         12KB  2017-06-05 12:14:37          13104        145  0.8235294
datasets/brendan45774/test-file                              Titanic dataset                                 11KB  2021-12-02 16:11:42          19348        251  1.0
datasets/rahulsah06/titanic                                  Titanic                                         34KB  2019-09-16 14:43:23           3619         43  0.6764706
datasets/prkukunoor/TitanicDataset                           Titanic                                        135KB  2017-01-03 22:01:13           4719         24  0.5882353
datasets/hesh97/titanicdataset-traincsv                      Titanic-Dataset (train.csv)                     22KB  2018-02-02 04:51:06          54111        377  0.4117647
datasets/fossouodonald/titaniccsv                            Titanic csv                                      1KB  2016-11-07 09:44:58           8615         50  0.5882353
datasets/broaniki/titanic                                    titanic                                        717KB  2018-01-30 04:08:45           8004        128  0.1764706
datasets/pavlofesenko/titanic-extended                       Titanic extended dataset (Kaggle + Wikipedia)  134KB  2019-03-06 09:53:24           8779        130  0.9411765
datasets/jamesleslie/titanic-cleaned-data                    Titanic: cleaned data                           36KB  2018-11-21 11:50:18           4846         53  0.7647059
datasets/kittisaks/testtitanic                               test titanic                                    22KB  2017-03-13 15:13:12           1658         32  0.64705884
datasets/yasserh/titanic-dataset                             Titanic Dataset                                 22KB  2021-12-24 14:53:06           1011         25  1.0
datasets/abhinavralhan/titanic                               titanic                                         22KB  2017-07-30 11:07:55            628         11  0.8235294
datasets/cities/titanic123                                   Titanic Dataset Analysis                        22KB  2017-02-07 23:15:54           1585         29  0.5294118
datasets/brendan45774/gender-submisson                       Titanic: all ones csv file                      942B  2021-02-12 19:18:32            459         34  0.9411765
datasets/harunshimanto/titanic-solution-for-beginners-guide  Titanic Solution for Beginner's Guide           34KB  2018-03-12 17:47:06           1444         21  0.7058824
datasets/ibrahimelsayed182/titanic-dataset                   Titanic dataset                                  6KB  2022-01-27 07:41:54            334          8  1.0
datasets/sureshbhusare/titanic-dataset-from-kaggle           Titanic DataSet from Kaggle                     33KB  2017-10-12 04:49:39           2688         27  0.4117647
datasets/shuofxz/titanic-machine-learning-from-disaster      Titanic: Machine Learning from Disaster         33KB  2017-10-15 10:05:34           3867         55  0.29411766
datasets/vinicius150987/titanic3                             The Complete Titanic Dataset                   277KB  2020-01-04 18:24:11           1459         23  0.64705884

要下载列表中的第一个数据集,我们可以使用:

kaggle datasets download -d heptapod/titanic --unzip

使用 Jupyter 笔记本来读取文件,类似于 Kaggle 笔记本示例,给我们提供了:

在 Jupyter 笔记本中使用 Titanic 数据集

当然,某些数据集的大小非常大,您可能不希望将它们保留在自己的磁盘上。尽管如此,这是 Kaggle 提供的免费资源之一,供您的机器学习项目使用!

进一步阅读

此部分提供了更多资源,如果您对深入研究此主题感兴趣。

摘要

在本教程中,您学到了 Kaggle 是什么,我们如何使用 Kaggle 获取数据集,甚至在 Kaggle 笔记本中使用一些免费的 GPU/TPU 实例。您还看到了我们如何使用 Kaggle API 的 CLI 工具下载数据集,以便在本地环境中使用。

具体来说,您学到了:

  • 什么是 Kaggle

  • 如何在 Kaggle 笔记本中使用 GPU/TPU 加速器

  • 如何在 Kaggle 笔记本中使用 Kaggle 数据集或使用 Kaggle 的 CLI 工具下载它们

Python 中的网页爬取

原文:machinelearningmastery.com/web-crawling-in-python/

以前,收集数据是一项繁琐的工作,有时非常昂贵。机器学习项目离不开数据。幸运的是,如今我们可以利用大量的网络数据来创建数据集。我们可以从网络上复制数据来构建数据集。我们可以手动下载文件并保存到磁盘。但通过自动化数据采集,我们可以更高效地完成这项工作。Python 中有几种工具可以帮助实现自动化。

完成本教程后,您将学习到:

  • 如何使用 requests 库通过 HTTP 读取在线数据

  • 如何使用 pandas 读取网页上的表格

  • 如何使用 Selenium 模拟浏览器操作

用我的新书 《Python 机器学习》 来启动您的项目,其中包括逐步教程和所有示例的Python 源代码文件。

让我们开始吧!!

Python 中的网页爬取

图片由 Ray Bilcliff 提供。保留所有权利。

概述

本教程分为三个部分;它们是:

  • 使用 requests 库

  • 使用 pandas 读取网页上的表格

  • 使用 Selenium 读取动态内容

使用 Requests 库

当我们谈到编写 Python 程序来从网络读取数据时,不可避免地,我们需要使用requests库。您需要安装它(以及我们稍后将介绍的 BeautifulSoup 和 lxml):

pip install requests beautifulsoup4 lxml

它为您提供了一个界面,使您可以轻松地与网页进行交互。

一个非常简单的用例是从 URL 读取网页:

import requests

# Lat-Lon of New York
URL = "https://weather.com/weather/today/l/40.75,-73.98"
resp = requests.get(URL)
print(resp.status_code)
print(resp.text)
200
<!doctype html><html dir="ltr" lang="en-US"><head>
      <meta data-react-helmet="true" charset="utf-8"/><meta data-react-helmet="true"
name="viewport" content="width=device-width, initial-scale=1, viewport-fit=cover"/>
...

如果您对 HTTP 比较熟悉,您可能会记得状态码 200 表示请求成功完成。然后我们可以读取响应。在上面的例子中,我们读取了文本响应并获取了网页的 HTML。如果是 CSV 或其他文本数据,我们可以在响应对象的text属性中获取它们。例如,这就是如何从联邦储备经济数据中读取 CSV 文件:

import io
import pandas as pd
import requests

URL = "https://fred.stlouisfed.org/graph/fredgraph.csv?id=T10YIE&cosd=2017-04-14&coed=2022-04-14"
resp = requests.get(URL)
if resp.status_code == 200:
   csvtext = resp.text
   csvbuffer = io.StringIO(csvtext)
   df = pd.read_csv(csvbuffer)
   print(df)
            DATE T10YIE
0     2017-04-17   1.88
1     2017-04-18   1.85
2     2017-04-19   1.85
3     2017-04-20   1.85
4     2017-04-21   1.84
...          ...    ...
1299  2022-04-08   2.87
1300  2022-04-11   2.91
1301  2022-04-12   2.86
1302  2022-04-13    2.8
1303  2022-04-14   2.89

[1304 rows x 2 columns]

如果数据是 JSON 格式,我们可以将其作为文本读取,或者让requests为您解码。例如,以下是从 GitHub 拉取 JSON 格式的数据并将其转换为 Python 字典的操作:

import requests

URL = "https://api.github.com/users/jbrownlee"
resp = requests.get(URL)
if resp.status_code == 200:
    data = resp.json()
    print(data)
{'login': 'jbrownlee', 'id': 12891, 'node_id': 'MDQ6VXNlcjEyODkx',
'avatar_url': 'https://avatars.githubusercontent.com/u/12891?v=4',
'gravatar_id': '', 'url': 'https://api.github.com/users/jbrownlee',
'html_url': 'https://github.com/jbrownlee',
...
'company': 'Machine Learning Mastery', 'blog': 'https://machinelearningmastery.com',
'location': None, 'email': None, 'hireable': None,
'bio': 'Making developers awesome at machine learning.', 'twitter_username': None,
'public_repos': 5, 'public_gists': 0, 'followers': 1752, 'following': 0,
'created_at': '2008-06-07T02:20:58Z', 'updated_at': '2022-02-22T19:56:27Z'
}

但如果 URL 返回的是一些二进制数据,比如 ZIP 文件或 JPEG 图像,您需要从content属性中获取它们,因为这是二进制数据。例如,这就是如何下载一张图片(维基百科的标志):

import requests

URL = "https://en.wikipedia.org/static/images/project-logos/enwiki.png"
wikilogo = requests.get(URL)
if wikilogo.status_code == 200:
    with open("enwiki.png", "wb") as fp:
        fp.write(wikilogo.content)

既然我们已经获得了网页,应该如何提取数据?这超出了requests库的功能,但我们可以使用其他库来帮助完成。根据我们想要指定数据的方式,有两种方法可以实现。

第一种方法是将 HTML 视为一种 XML 文档,并使用 XPath 语言提取元素。在这种情况下,我们可以利用 lxml 库首先创建文档对象模型(DOM),然后通过 XPath 进行搜索:

...
from lxml import etree

# Create DOM from HTML text
dom = etree.HTML(resp.text)
# Search for the temperature element and get the content
elements = dom.xpath("//span[@data-testid='TemperatureValue' and contains(@class,'CurrentConditions')]")
print(elements[0].text)
61°

XPath 是一个字符串,用于指定如何查找一个元素。lxml 对象提供了一个 xpath() 函数,用于搜索匹配 XPath 字符串的 DOM 元素,这可能会有多个匹配项。上述 XPath 意味着查找任何具有 <span> 标签且属性 data-testid 匹配 “TemperatureValue” 和 class 以 “CurrentConditions” 开头的 HTML 元素。我们可以通过检查 HTML 源代码在浏览器的开发者工具中(例如,下面的 Chrome 截图)了解到这一点。

本示例旨在找到纽约市的温度,由我们从该网页获取的特定元素提供。我们知道 XPath 匹配的第一个元素就是我们需要的,我们可以读取 <span> 标签中的文本。

另一种方法是使用 HTML 文档上的 CSS 选择器,我们可以利用 BeautifulSoup 库:

...
from bs4 import BeautifulSoup

soup = BeautifulSoup(resp.text, "lxml")
elements = soup.select('span[data-testid="TemperatureValue"][class^="CurrentConditions"]')
print(elements[0].text)
61°

上述过程中,我们首先将 HTML 文本传递给 BeautifulSoup。BeautifulSoup 支持各种 HTML 解析器,每种解析器都有不同的能力。在上述过程中,我们使用 lxml 库作为解析器,正如 BeautifulSoup 推荐的(它也通常是最快的)。CSS 选择器是一种不同的迷你语言,与 XPath 相比有其优缺点。上面的选择器与我们在之前示例中使用的 XPath 是相同的。因此,我们可以从第一个匹配的元素中获取相同的温度。

下面是一个完整的代码示例,根据网页上的实时信息打印纽约市当前温度:

import requests
from lxml import etree

# Reading temperature of New York
URL = "https://weather.com/weather/today/l/40.75,-73.98"
resp = requests.get(URL)

if resp.status_code == 200:
    # Using lxml
    dom = etree.HTML(resp.text)
    elements = dom.xpath("//span[@data-testid='TemperatureValue' and contains(@class,'CurrentConditions')]")
    print(elements[0].text)

    # Using BeautifulSoup
    soup = BeautifulSoup(resp.text, "lxml")
    elements = soup.select('span[data-testid="TemperatureValue"][class^="CurrentConditions"]')
    print(elements[0].text)

正如你可以想象的那样,你可以通过定期运行这个脚本来收集温度的时间序列。同样,我们可以自动从各种网站收集数据。这就是我们如何为机器学习项目获取数据的方法。

使用 Pandas 读取网页上的表格

很多时候,网页会使用表格来承载数据。如果页面足够简单,我们甚至可以跳过检查它以找到 XPath 或 CSS 选择器,直接使用 pandas 一次性获取页面上的所有表格。这可以用一行代码简单实现:

import pandas as pd

tables = pd.read_html("https://www.federalreserve.gov/releases/h15/")
print(tables)
[                               Instruments 2022Apr7 2022Apr8 2022Apr11 2022Apr12 2022Apr13
0          Federal funds (effective) 1 2 3     0.33     0.33      0.33      0.33      0.33
1                 Commercial Paper 3 4 5 6      NaN      NaN       NaN       NaN       NaN
2                             Nonfinancial      NaN      NaN       NaN       NaN       NaN
3                                  1-month     0.30     0.34      0.36      0.39      0.39
4                                  2-month     n.a.     0.48      n.a.      n.a.      n.a.
5                                  3-month     n.a.     n.a.      n.a.      0.78      0.78
6                                Financial      NaN      NaN       NaN       NaN       NaN
7                                  1-month     0.49     0.45      0.46      0.39      0.46
8                                  2-month     n.a.     n.a.      0.60      0.71      n.a.
9                                  3-month     0.85     0.81      0.75      n.a.      0.86
10                   Bank prime loan 2 3 7     3.50     3.50      3.50      3.50      3.50
11      Discount window primary credit 2 8     0.50     0.50      0.50      0.50      0.50
12              U.S. government securities      NaN      NaN       NaN       NaN       NaN
13   Treasury bills (secondary market) 3 4      NaN      NaN       NaN       NaN       NaN
14                                  4-week     0.21     0.20      0.21      0.19      0.23
15                                 3-month     0.68     0.69      0.78      0.74      0.75
16                                 6-month     1.12     1.16      1.22      1.18      1.17
17                                  1-year     1.69     1.72      1.75      1.67      1.67
18            Treasury constant maturities      NaN      NaN       NaN       NaN       NaN
19                               Nominal 9      NaN      NaN       NaN       NaN       NaN
20                                 1-month     0.21     0.20      0.22      0.21      0.26
21                                 3-month     0.68     0.70      0.77      0.74      0.75
22                                 6-month     1.15     1.19      1.23      1.20      1.20
23                                  1-year     1.78     1.81      1.85      1.77      1.78
24                                  2-year     2.47     2.53      2.50      2.39      2.37
25                                  3-year     2.66     2.73      2.73      2.58      2.57
26                                  5-year     2.70     2.76      2.79      2.66      2.66
27                                  7-year     2.73     2.79      2.84      2.73      2.71
28                                 10-year     2.66     2.72      2.79      2.72      2.70
29                                 20-year     2.87     2.94      3.02      2.99      2.97
30                                 30-year     2.69     2.76      2.84      2.82      2.81
31                    Inflation indexed 10      NaN      NaN       NaN       NaN       NaN
32                                  5-year    -0.56    -0.57     -0.58     -0.65     -0.59
33                                  7-year    -0.34    -0.33     -0.32     -0.36     -0.31
34                                 10-year    -0.16    -0.15     -0.12     -0.14     -0.10
35                                 20-year     0.09     0.11      0.15      0.15      0.18
36                                 30-year     0.21     0.23      0.27      0.28      0.30
37  Inflation-indexed long-term average 11     0.23     0.26      0.30      0.30      0.33,       0               1
0  n.a.  Not available.]

pandas 中的 read_html() 函数读取一个 URL,并查找页面上的所有表格。每个表格都被转换为 pandas DataFrame,然后将所有表格以列表形式返回。在此示例中,我们正在读取来自联邦储备系统的各种利率,该页面上只有一个表格。表格列由 pandas 自动识别。

可能并非所有表格都是我们感兴趣的。有时,网页仅仅使用表格作为格式化页面的一种方式,但 pandas 可能无法聪明地识别这一点。因此,我们需要测试并挑选 read_html() 函数返回的结果。

想开始使用 Python 进行机器学习吗?

现在就参加我的免费 7 天邮件速成课程(附样例代码)。

点击注册,同时获取课程的免费 PDF 电子书版本。

使用 Selenium 读取动态内容

现代网页中有很大一部分充满了 JavaScript。这虽然提供了更炫的体验,但却成为了提取数据时的障碍。一个例子是 Yahoo 的主页,如果我们只是加载页面并查找所有新闻标题,那么看到的新闻数量远远少于在浏览器中看到的:

import requests

# Read Yahoo home page
URL = "https://www.yahoo.com/"
resp = requests.get(URL)
dom = etree.HTML(resp.text)

# Print news headlines
elements = dom.xpath("//h3/a[u[@class='StretchedBox']]")
for elem in elements:
    print(etree.tostring(elem, method="text", encoding="unicode"))

这是因为像这样的网站依赖 JavaScript 来填充内容。像 AngularJS 或 React 这样的著名 web 框架驱动了这一类别。Python 库,如 requests,无法理解 JavaScript。因此,你会看到不同的结果。如果你想从网页中获取的数据是其中之一,你可以研究 JavaScript 如何被调用,并在你的程序中模拟浏览器的行为。但这可能过于繁琐,难以实现。

另一种方法是让真实浏览器读取网页,而不是使用 requests。这正是 Selenium 可以做到的。在我们可以使用它之前,我们需要安装这个库:

pip install selenium

但 Selenium 只是一个控制浏览器的框架。你需要在你的计算机上安装浏览器以及将 Selenium 连接到浏览器的驱动程序。如果你打算使用 Chrome,你还需要下载并安装ChromeDriver。你需要将驱动程序放在可执行路径中,以便 Selenium 可以像正常命令一样调用它。例如,在 Linux 中,你只需从下载的 ZIP 文件中获取 chromedriver 可执行文件,并将其放在 /usr/local/bin 中。

类似地,如果你使用的是 Firefox,你需要GeckoDriver。有关设置 Selenium 的更多细节,你应该参考其文档

之后,你可以使用 Python 脚本来控制浏览器行为。例如:

import time
from selenium import webdriver
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.common.by import By

# Launch Chrome browser in headless mode
options = webdriver.ChromeOptions()
options.add_argument("headless")
browser = webdriver.Chrome(options=options)

# Load web page
browser.get("https://www.yahoo.com")
# Network transport takes time. Wait until the page is fully loaded
def is_ready(browser):
    return browser.execute_script(r"""
        return document.readyState === 'complete'
    """)
WebDriverWait(browser, 30).until(is_ready)

# Scroll to bottom of the page to trigger JavaScript action
browser.execute_script("window.scrollTo(0, document.body.scrollHeight);")
time.sleep(1)
WebDriverWait(browser, 30).until(is_ready)

# Search for news headlines and print
elements = browser.find_elements(By.XPATH, "//h3/a[u[@class='StretchedBox']]")
for elem in elements:
    print(elem.text)

# Close the browser once finish
browser.close()

上述代码的工作方式如下。我们首先以无头模式启动浏览器,这意味着我们要求 Chrome 启动但不在屏幕上显示。如果我们想远程运行脚本,这一点很重要,因为可能没有图形用户界面支持。请注意,每个浏览器的开发方式不同,因此我们使用的选项语法特定于 Chrome。如果我们使用 Firefox,代码将会是这样的:

options = webdriver.FirefoxOptions()
options.set_headless()
browser = webdriver.Firefox(firefox_options=options)

在我们启动浏览器后,我们给它一个 URL 进行加载。但由于网络传输页面需要时间,浏览器也需要时间来渲染,因此我们应该等到浏览器准备好后再进行下一步操作。我们通过使用 JavaScript 来检测浏览器是否完成渲染。我们让 Selenium 执行 JavaScript 代码,并使用execute_script()函数告诉我们结果。我们利用 Selenium 的WebDriverWait工具运行代码直到成功或直到 30 秒超时。随着页面加载,我们滚动到页面底部,以便触发 JavaScript 加载更多内容。然后我们无条件等待一秒钟,以确保浏览器触发了 JavaScript,再等到页面再次准备好。之后,我们可以使用 XPath(或使用 CSS 选择器)提取新闻标题元素。由于浏览器是一个外部程序,我们需要在脚本中负责关闭它。

使用 Selenium 与使用requests库在几个方面有所不同。首先,你的 Python 代码中不会直接包含网页内容。相反,每当你需要时,你都要引用浏览器的内容。因此,find_elements() 函数返回的网页元素指的是外部浏览器中的对象,因此我们在完成使用之前不能关闭浏览器。其次,所有操作都应基于浏览器交互,而不是网络请求。因此,你需要通过模拟键盘和鼠标操作来控制浏览器。但作为回报,你可以使用完整功能的浏览器并支持 JavaScript。例如,你可以使用 JavaScript 检查页面上元素的大小和位置,这只有在 HTML 元素渲染后才能知道。

Selenium 框架提供了很多功能,但我们在这里无法一一覆盖。它功能强大,但由于它与浏览器连接,使用起来比requests库更为复杂且速度较慢。通常,这是一种从网络获取信息的最后手段。

进一步阅读

另一个在 Python 中非常有名的网页爬取库是 Scrapy,它类似于将requests库和 BeautifulSoup 合并成一个库。网页协议复杂。有时我们需要管理网页 cookies 或使用 POST 方法提供额外数据。这些都可以通过 requests 库的不同函数或附加参数完成。以下是一些资源供你深入了解:

文章
API 文档
书籍

总结

在这个教程中,你了解了我们可以用来从网络获取内容的工具。

具体来说,你学到了:

  • 如何使用 requests 库发送 HTTP 请求并从响应中提取数据

  • 如何从 HTML 构建文档对象模型,以便在网页上找到一些特定的信息

  • 如何使用 pandas 快速轻松地读取网页上的表格

  • 如何使用 Selenium 控制浏览器处理网页上的动态内容

你的 Python 项目的 web 框架

原文:machinelearningmastery.com/web-frameworks-for-your-python-projects/

当我们完成一个 Python 项目并推出供其他人使用时,最简单的方法是将项目呈现为命令行程序。如果你想让它更友好,可能需要为你的程序开发一个 GUI,以便人们可以在运行时通过鼠标点击进行互动。开发 GUI 可能很困难,因为人机交互的模型复杂。因此,折衷方案是为你的程序创建一个网页界面。这相比于纯命令行程序需要额外的工作,但不像使用 Qt5 库那样繁重。在这篇文章中,我们将展示网页界面的细节以及如何轻松地为你的程序提供一个界面。

完成本教程后,你将学到:

  • 从一个简单的例子看 Flask 框架

  • 使用 Dash 完全用 Python 构建交互式网页

  • 一个 web 应用程序如何运行

通过我的新书 《Python 机器学习》启动你的项目,包括 逐步教程 和所有示例的 Python 源代码 文件。

让我们开始吧!

你的 Python 项目的 web 框架

图片由 Quang Nguyen Vinh 提供。保留了一些权利。

概述

本教程分为五个部分,它们是:

  • Python 和网络

  • Flask 用于 web API 应用程序

  • Dash 用于交互式小部件

  • Dash 中的轮询

  • 结合 Flask 和 Dash

Python 和网络

网络通过超文本传输协议(HTTP)进行服务。Python 的标准库支持与 HTTP 的交互。如果你只是想用 Python 运行一个 web 服务器,没有比进入一个文件目录并运行命令更简单的了。

python -m http.server

这通常会在 8000 端口启动一个 web 服务器。如果目录中存在 index.html,那将是我们在相同计算机上使用地址 http://localhost:8000/ 打开浏览器时提供的默认页面。

这个内置的 web 服务器非常适合快速设置 web 服务器(例如,让网络上的另一台计算机下载一个文件)。但如果我们想做更多的事情,比如拥有一些动态内容,它将不够用。

在深入细节之前,让我们回顾一下我们在谈到网页界面时希望实现的目标。首先,现代网页将是一个与用户互动的界面,用于传播信息。这不仅意味着从服务器发送信息,还包括接收用户的输入。浏览器能够以美观的方式呈现信息。

另外,我们还可以使用不带浏览器的网页。例如,使用 web 协议下载文件。在 Linux 中,我们有著名的 wget 工具来完成这个任务。另一个例子是查询信息或向服务器传递信息。例如,在 AWS EC2 实例中,你可以在地址 http://169.254.169.254/latest/meta-data/ 检查机器实例的 元数据(其中 169.254.169.254 是 EC2 机器上可用的特殊 IP 地址)。在 Linux 实例中,我们可以使用 curl 工具进行检查。其输出将不是 HTML,而是机器可读的纯文本格式。有时,我们将其称为 web API,因为我们像使用远程执行函数一样使用它。

这两种是网页应用中的不同范式。第一种需要编写用户与服务器之间交互的代码。第二种需要在 URL 上设置各种端点,以便用户可以使用不同的地址请求不同的内容。在 Python 中,有第三方库可以完成这两种任务。

想开始学习 Python 机器学习吗?

现在就参加我的免费 7 天电子邮件速成课程(附示例代码)。

点击注册并获取免费的 PDF 电子书版本课程。

Flask 用于 Web API 应用程序

允许我们用 Python 编写程序来构建基于网页的应用程序的工具被称为 web 框架。有很多这样的框架。Django 可能是最著名的一个。然而,不同的 web 框架的学习曲线可能差异很大。一些 web 框架假设你使用的是模型-视图设计,你需要理解其背后的原理才能明白如何使用它。

作为机器学习从业者,你可能希望做一些快速的、不太复杂的,但又足够强大以满足许多使用场景的事情。Flask 可能是这个类别中的一个不错选择。

Flask 是一个轻量级的 web 框架。你可以将它作为一个命令运行,并将其用作 Python 模块。假设我们想编写一个 web 服务器,报告任何用户指定时区的当前时间。可以通过 Flask 以简单的方式实现:

from datetime import datetime
import pytz
from flask import Flask

app = Flask("time now")

@app.route("/now/<path:timezone>")
def timenow(timezone):
    try:
        zone = pytz.timezone(timezone)
        now = datetime.now(zone)
        return now.strftime("%Y-%m-%d %H:%M:%S %z %Z\n")
    except pytz.exceptions.UnknownTimeZoneError:
        return f"Unknown time zone: {timezone}\n"

app.run()

将以上内容保存到 server.py 或任何你喜欢的文件名中,然后在终端中运行它。你将看到以下内容:

 * Serving Flask app 'time now' (lazy loading)
 * Environment: production
   WARNING: This is a development server. Do not use it in a production deployment.
   Use a production WSGI server instead.
 * Debug mode: off
 * Running on http://127.0.0.1:5000 (Press CTRL+C to quit)

这意味着你的脚本现在正在 http://127.0.0.1:5000 作为一个 web 服务器运行。它将永远服务于 web 请求,直到你用 Ctrl-C 中断它。

如果你打开另一个终端并查询 URL,例如,在 Linux 中使用 curl

$ curl http://127.0.0.1:5000/now/Asia/Tokyo
2022-04-20 13:29:42 +0900 JST

你会在屏幕上看到以你请求的时区(在这个例子中是 Asia/Tokyo)打印的时间,你可以在维基百科上查看所有支持的时区列表 在维基百科。函数返回的字符串将是 URL 返回的内容。如果时区无法识别,你会看到“未知时区”消息,如上面代码中的 except 块所返回的。

如果我们想稍微扩展一下,假设在未提供时区的情况下使用 UTC,我们只需向函数中添加另一个装饰器:

from datetime import datetime
import pytz
from flask import Flask

app = Flask("time now")

@app.route('/now', defaults={'timezone': ''})
@app.route("/now/<path:timezone>")
def timenow(timezone):
    try:
        if not timezone:
            zone = pytz.utc
        else:
            zone = pytz.timezone(timezone)
        now = datetime.now(zone)
        return now.strftime("%Y-%m-%d %H:%M:%S %z %Z\n")
    except pytz.exceptions.UnknownTimeZoneError:
        return f"Unknown timezone: {timezone}\n"

app.run()

重启服务器后,我们可以看到如下结果:

$ curl http://127.0.0.1:5000/now/Asia/Tokyo
2022-04-20 13:37:27 +0900 JST
$ curl http://127.0.0.1:5000/now/Asia/Tok
Unknown timezone: Asia/Tok
$ curl http://127.0.0.1:5000/now
2022-04-20 04:37:29 +0000 UTC

如今,许多这样的应用程序会返回一个 JSON 字符串以表示更复杂的数据,但从技术上讲,任何东西都可以被传递。如果你希望创建更多的 Web API,只需定义你的函数以返回数据,并像上面的例子一样用@app.route()进行装饰。

用于交互式小部件的 Dash

Flask 提供的 Web 端点非常强大。很多 Web 应用程序都是这样做的。例如,我们可以使用 HTML 编写网页用户界面,并用 Javascript 处理用户交互。一旦用户触发事件,我们可以让 Javascript 处理任何 UI 更改,并通过发送数据到一个端点创建一个 AJAX 调用并等待回复。AJAX 调用是异步的;因此,当接收到 Web 服务器的响应(通常在几分之一秒内)时,Javascript 会再次被触发,以进一步更新 UI,让用户了解情况。

然而,随着网页界面的复杂度越来越高,编写 Javascript 代码可能会变得繁琐。因此,有许多客户端库可以简化这一过程。有些库简化了 Javascript 编程,例如 jQuery。有些库改变了 HTML 和 Javascript 的交互方式,例如 ReactJS。但由于我们正在用 Python 开发机器学习项目,能够在不使用 Javascript 的情况下开发一个交互式网页应用将是非常棒的。Dash 就是为此而设的工具。

让我们考虑一个机器学习的例子:我们希望使用 MNIST 手写数字数据集来训练一个手写数字识别器。LeNet5 模型在这项任务中非常有名。但我们希望让用户微调 LeNet5 模型,重新训练它,然后用于识别。训练一个简单的 LeNet5 模型只需几行代码:

import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, AveragePooling2D, Flatten
from tensorflow.keras.utils import to_categorical

# Load MNIST digits
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Reshape data to (n_samples, height, width, n_channel)
X_train = np.expand_dims(X_train, axis=3).astype("float32")
X_test = np.expand_dims(X_test, axis=3).astype("float32")

# One-hot encode the output
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

# LeNet5 model
model = Sequential([
    Conv2D(6, (5,5), activation="tanh",
           input_shape=(28,28,1), padding="same"),
    AveragePooling2D((2,2), strides=2),
    Conv2D(16, (5,5), activation="tanh"),
    AveragePooling2D((2,2), strides=2),
    Conv2D(120, (5,5), activation="tanh"),
    Flatten(),
    Dense(84, activation="tanh"),
    Dense(10, activation="softmax")
])

# Train the model
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=100, batch_size=32)

在这段代码中,我们可以更改几个超参数,例如激活函数、训练的优化器、训练轮次和批量大小。我们可以在 Dash 中创建一个界面,让用户更改这些参数并重新训练模型。这个界面将以 HTML 呈现,但用 Python 编码:

...
from flask import Flask
from dash import Dash, html, dcc

# default values
model_data = {
    "activation": "relu",
    "optimizer": "adam",
    "epochs": 100,
    "batchsize": 32,
}
...
server = Flask("mlm")
app = Dash(server=server)
app.layout = html.Div(
    id="parent",
    children=[
        html.H1(
            children="LeNet5 training",
            style={"textAlign": "center"}
        ),
        html.Div(
            className="flex-container",
            children=[
                html.Div(children=[
                    html.Div(id="activationdisplay", children="Activation:"),
                    dcc.Dropdown(
                        id="activation",
                        options=[
                            {"label": "Rectified linear unit", "value": "relu"},
                            {"label": "Hyperbolic tangent", "value": "tanh"},
                            {"label": "Sigmoidal", "value": "sigmoid"},
                        ],
                        value=model_data["activation"]
                    )
                ]),
                html.Div(children=[
                    html.Div(id="optimizerdisplay", children="Optimizer:"),
                    dcc.Dropdown(
                        id="optimizer",
                        options=[
                            {"label": "Adam", "value": "adam"},
                            {"label": "Adagrad", "value": "adagrad"},
                            {"label": "Nadam", "value": "nadam"},
                            {"label": "Adadelta", "value": "adadelta"},
                            {"label": "Adamax", "value": "adamax"},
                            {"label": "RMSprop", "value": "rmsprop"},
                            {"label": "SGD", "value": "sgd"},
                            {"label": "FTRL", "value": "ftrl"},
                        ],
                        value=model_data["optimizer"]
                    ),
                ]),
                html.Div(children=[
                    html.Div(id="epochdisplay", children="Epochs:"),
                    dcc.Slider(1, 200, 1, marks={1: "1", 100: "100", 200: "200"},
                               value=model_data["epochs"], id="epochs"),
                ]),
                html.Div(children=[
                    html.Div(id="batchdisplay", children="Batch size:"),
                    dcc.Slider(1, 128, 1, marks={1: "1", 128: "128"},
                               value=model_data["batchsize"], id="batchsize"),
                ]),
            ]
        ),
        html.Button(id="train", n_clicks=0, children="Train"),
    ]
)

在这里,我们设置了一个基于 Flask 服务器的 Dash 应用程序。上面的代码主要用于设置 Dash 应用程序的布局,该布局将在网页浏览器中显示。布局顶部有一个标题,底部有一个按钮(标签为“Train”),中间有一个包含多个选项小部件的大框。布局中有一个用于激活函数的下拉框,一个用于训练优化器的下拉框,以及两个滑块,一个用于轮次,一个用于批量大小。布局如下所示:

如果你熟悉 HTML 开发,你可能注意到我们上面使用了许多<div>元素。此外,我们还向一些元素提供了style参数,以改变它们在浏览器中的渲染方式。实际上,我们将这些 Python 代码保存到文件server.py中,并创建了一个文件assets/main.css,其内容如下:

CSS

.flex-container {
    display: flex;
    padding: 5px;
    flex-wrap: nowrap;
    background-color: #EEEEEE;
}

.flex-container > * {
    flex-grow: 1
}

当运行此代码时,我们可以使四个不同的用户选项水平对齐。

在创建了 HTML 前端之后,关键是让用户通过从下拉列表中选择或移动滑块来更改超参数。然后,在用户点击“训练”按钮后,我们启动模型训练。我们定义训练函数如下:

...
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, AveragePooling2D, Flatten
from tensorflow.keras.callbacks import EarlyStopping

def train():
    activation = model_data["activation"]
    model = Sequential([
        Conv2D(6, (5, 5), activation=activation,
               input_shape=(28, 28, 1), padding="same"),
        AveragePooling2D((2, 2), strides=2),
        Conv2D(16, (5, 5), activation=activation),
        AveragePooling2D((2, 2), strides=2),
        Conv2D(120, (5, 5), activation=activation),
        Flatten(),
        Dense(84, activation=activation),
        Dense(10, activation="softmax")
    ])
    model.compile(loss="categorical_crossentropy",
                  optimizer=model_data["optimizer"],
                  metrics=["accuracy"])
    earlystop = EarlyStopping(monitor="val_loss", patience=3,
                              restore_best_weights=True)
    history = model.fit(
            X_train, y_train, validation_data=(X_test, y_test),
            epochs=model_data["epochs"],
            batch_size=model_data["batchsize"],
            verbose=0, callbacks=[earlystop])
    return model, history

这个函数依赖于一个外部字典model_data来获取参数和数据集,例如X_trainy_train,这些是在函数外部定义的。它将创建一个新模型,训练它,并返回带有训练历史的模型。我们只需在浏览器上的“训练”按钮被点击时运行此函数即可。我们在fit()函数中设置verbose=0,以要求训练过程不要向屏幕打印任何内容,因为它应该在服务器上运行,而用户则在浏览器中查看。用户无法看到服务器上的终端输出。我们还可以进一步显示训练周期中的损失和评估指标历史。这是我们需要做的:

...
import pandas as pd
import plotly.express as px
from dash.dependencies import Input, Output, State

...
app.layout = html.Div(
    id="parent",
    children=[
        ...
        html.Button(id="train", n_clicks=0, children="Train"),
        dcc.Graph(id="historyplot"),
    ]
)

...
@app.callback(Output("historyplot", "figure"),
              Input("train", "n_clicks"),
              State("activation", "value"),
              State("optimizer", "value"),
              State("epochs", "value"),
              State("batchsize", "value"),
              prevent_initial_call=True)
def train_action(n_clicks, activation, optimizer, epoch, batchsize):
    model_data.update({
        "activation": activation,
        "optimizer": optimizer,
        "epoch": epoch,
        "batchsize": batchsize,
    })
    model, history = train()
    model_data["model"] = model  # keep the trained model
    history = pd.DataFrame(history.history)
    fig = px.line(history, title="Model training metrics")
    fig.update_layout(xaxis_title="epochs",
                      yaxis_title="metric value", legend_title="metrics")
    return fig

我们首先在网页上添加一个Graph组件来显示我们的训练指标。Graph组件不是标准的 HTML 元素,而是 Dash 组件。Dash 提供了许多这样的组件,作为其主要特性。Dash 是 Plotly 的姊妹项目,Plotly 是一个类似于 Bokeh 的可视化库,将交互式图表渲染到 HTML 中。Graph组件用于显示 Plotly 图表。

然后我们定义了一个函数train_action(),并用我们 Dash 应用程序的回调函数装饰它。函数train_action()接受多个输入(模型超参数)并返回一个输出。在 Dash 中,输出通常是一个字符串,但我们在这里返回一个 Plotly 图形对象。回调装饰器要求我们指定输入和输出。这些是由其 ID 字段指定的网页组件,以及作为输入或输出的属性。在此示例中,除了输入和输出,我们还需要一些称为“状态”的额外数据。

在 Dash 中,输入是触发操作的因素。在这个示例中,Dash 中的一个按钮会记住它被按下的次数,这个次数存储在组件的属性n_clicks中。所以我们将这个属性的变化声明为触发该函数的因素。类似地,当这个函数返回时,图形对象将替换Graph组件。状态参数作为非触发参数提供给这个函数。指定输出、输入和状态的顺序非常重要,因为这是回调装饰器所期望的,以及我们定义的函数的参数顺序。

我们不会详细解释 Plotly 的语法。如果你了解了像 Bokeh 这样的可视化库的工作原理,查阅 Plotly 的文档后,应该不会很难将你的知识适应到 Plotly 上。

但是,我们需要提到 Dash 回调函数的一点:当网页首次加载时,所有回调函数会被调用一次,因为组件是新创建的。由于所有组件的属性从不存在到有了一些值,因此它们会触发事件。如果我们不希望它们在页面加载时被调用(例如,在这种情况下,我们不希望耗时的训练过程在用户确认超参数之前开始),我们需要在装饰器中指定prevent_initial_call=True

我们可以进一步通过使超参数选择变得交互化来迈出一步。这是礼貌的,因为你会对用户的操作提供反馈。由于我们已经为每个选择组件的标题有一个<div>元素,我们可以利用它来提供反馈,通过创建以下函数:

...

@app.callback(Output(component_id="epochdisplay", component_property="children"),
              Input(component_id="epochs", component_property="value"))
def update_epochs(value):
    return f"Epochs: {value}"

@app.callback(Output("batchdisplay", "children"),
              Input("batchsize", "value"))
def update_batchsize(value):
    return f"Batch size: {value}"

@app.callback(Output("activationdisplay", "children"),
              Input("activation", "value"))
def update_activation(value):
    return f"Activation: {value}"

@app.callback(Output("optimizerdisplay", "children"),
              Input("optimizer", "value"))
def update_optimizer(value):
    return f"Optimizer: {value}"

这些函数很简单,返回一个字符串,这个字符串会成为<div>元素的“子元素”。我们还展示了第一个函数装饰器中的命名参数,以防你希望更明确。

把所有内容整合在一起,以下是可以通过网页接口控制模型训练的完整代码:

import numpy as np
import pandas as pd
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, AveragePooling2D, Flatten
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping

import plotly.express as px
from dash import Dash, html, dcc
from dash.dependencies import Input, Output, State
from flask import Flask

server = Flask("mlm")
app = Dash(server=server)
# Load MNIST digits
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = np.expand_dims(X_train, axis=3).astype("float32")
X_test = np.expand_dims(X_test, axis=3).astype("float32")
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

model_data = {
    "activation": "relu",
    "optimizer": "adam",
    "epochs": 100,
    "batchsize": 32,
}

def train():
    activation = model_data["activation"]
    model = Sequential([
        Conv2D(6, (5, 5), activation=activation,
               input_shape=(28, 28, 1), padding="same"),
        AveragePooling2D((2, 2), strides=2),
        Conv2D(16, (5, 5), activation=activation),
        AveragePooling2D((2, 2), strides=2),
        Conv2D(120, (5, 5), activation=activation),
        Flatten(),
        Dense(84, activation=activation),
        Dense(10, activation="softmax")
    ])
    model.compile(loss="categorical_crossentropy",
                  optimizer=model_data["optimizer"],
                  metrics=["accuracy"])
    earlystop = EarlyStopping(monitor="val_loss", patience=3,
                              restore_best_weights=True)
    history = model.fit(
            X_train, y_train, validation_data=(X_test, y_test),
            epochs=model_data["epochs"],
            batch_size=model_data["batchsize"],
            verbose=0, callbacks=[earlystop])
    return model, history

app.layout = html.Div(
    id="parent",
    children=[
        html.H1(
            children="LeNet5 training",
            style={"textAlign": "center"}
        ),
        html.Div(
            className="flex-container",
            children=[
                html.Div(children=[
                    html.Div(id="activationdisplay"),
                    dcc.Dropdown(
                        id="activation",
                        options=[
                            {"label": "Rectified linear unit", "value": "relu"},
                            {"label": "Hyperbolic tangent", "value": "tanh"},
                            {"label": "Sigmoidal", "value": "sigmoid"},
                        ],
                        value=model_data["activation"]
                    )
                ]),
                html.Div(children=[
                    html.Div(id="optimizerdisplay"),
                    dcc.Dropdown(
                        id="optimizer",
                        options=[
                            {"label": "Adam", "value": "adam"},
                            {"label": "Adagrad", "value": "adagrad"},
                            {"label": "Nadam", "value": "nadam"},
                            {"label": "Adadelta", "value": "adadelta"},
                            {"label": "Adamax", "value": "adamax"},
                            {"label": "RMSprop", "value": "rmsprop"},
                            {"label": "SGD", "value": "sgd"},
                            {"label": "FTRL", "value": "ftrl"},
                        ],
                        value=model_data["optimizer"]
                    ),
                ]),
                html.Div(children=[
                    html.Div(id="epochdisplay"),
                    dcc.Slider(1, 200, 1, marks={1: "1", 100: "100", 200: "200"},
                               value=model_data["epochs"], id="epochs"),
                ]),
                html.Div(children=[
                    html.Div(id="batchdisplay"),
                    dcc.Slider(1, 128, 1, marks={1: "1", 128: "128"},
                               value=model_data["batchsize"], id="batchsize"),
                ]),
            ]
        ),
        html.Button(id="train", n_clicks=0, children="Train"),
        dcc.Graph(id="historyplot"),
    ]
)

@app.callback(Output(component_id="epochdisplay", component_property="children"),
              Input(component_id="epochs", component_property="value"))
def update_epochs(value):
    model_data["epochs"] = value
    return f"Epochs: {value}"

@app.callback(Output("batchdisplay", "children"),
              Input("batchsize", "value"))
def update_batchsize(value):
    model_data["batchsize"] = value
    return f"Batch size: {value}"

@app.callback(Output("activationdisplay", "children"),
              Input("activation", "value"))
def update_activation(value):
    model_data["activation"] = value
    return f"Activation: {value}"

@app.callback(Output("optimizerdisplay", "children"),
              Input("optimizer", "value"))
def update_optimizer(value):
    model_data["optimizer"] = value
    return f"Optimizer: {value}"

@app.callback(Output("historyplot", "figure"),
              Input("train", "n_clicks"),
              State("activation", "value"),
              State("optimizer", "value"),
              State("epochs", "value"),
              State("batchsize", "value"),
              prevent_initial_call=True)
def train_action(n_clicks, activation, optimizer, epoch, batchsize):
    model_data.update({
        "activation": activation,
        "optimizer": optimizer,
        "epcoh": epoch,
        "batchsize": batchsize,
    })
    model, history = train()
    model_data["model"] = model  # keep the trained model
    history = pd.DataFrame(history.history)
    fig = px.line(history, title="Model training metrics")
    fig.update_layout(xaxis_title="epochs",
                      yaxis_title="metric value", legend_title="metrics")
    return fig

# run server, with hot-reloading
app.run_server(debug=True, threaded=True)

上述代码的最后一行是运行 Dash 应用程序,就像我们在上一节中运行 Flask 应用程序一样。run_server() 函数的 debug=True 参数用于“热重载”,这意味着每当 Dash 检测到我们的脚本已更改时,它会重新加载所有内容。这在我们在另一个窗口编辑代码时非常方便,因为它不需要我们终止 Dash 服务器并重新运行。threaded=True 是要求 Dash 服务器在处理多个请求时以多线程运行。一般来说,不建议 Python 程序使用多线程,因为全局解释器锁的问题。但在 Web 服务器环境中,由于大多数时候服务器在等待 I/O,所以是可以接受的。如果不是多线程,选项将是多进程运行。我们不能在单线程和单进程中运行服务器,因为即使我们只为一个用户提供服务,浏览器也会同时启动多个 HTTP 查询(例如,请求我们上面创建的 CSS 文件时加载网页)。

在 Dash 中进行轮询

如果我们用中等数量的 epochs 运行上述 Dash 应用程序,它将花费相当长的时间来完成。我们希望看到它运行,而不仅仅在完成后更新图表。有一种方法可以要求 Dash 向我们的浏览器推送更新,但这需要一个插件(例如,dash_devices 包可以做到这一点)。但我们也可以要求浏览器拉取任何更新。这种设计称为轮询

在我们上面定义的 train() 函数中,我们设置 verbose=0 来跳过终端输出。但是我们仍然需要了解训练过程的进度。在 Keras 中,这可以通过自定义回调函数来完成。我们可以如下定义一个:

...
from tensorflow.keras.callbacks import Callback

train_status = {
    "running": False,
    "epoch": 0,
    "batch": 0,
    "batch metric": None,
    "last epoch": None,
}

class ProgressCallback(Callback):
    def on_train_begin(self, logs=None):
        train_status["running"] = True
        train_status["epoch"] = 0
    def on_train_end(self, logs=None):
        train_status["running"] = False
    def on_epoch_begin(self, epoch, logs=None):
        train_status["epoch"] = epoch
        train_status["batch"] = 0
    def on_epoch_end(self, epoch, logs=None):
        train_status["last epoch"] = logs
    def on_train_batch_begin(self, batch, logs=None):
        train_status["batch"] = batch
    def on_train_batch_end(self, batch, logs=None):
        train_status["batch metric"] = logs

def train():
    ...
    history = model.fit(
            X_train, y_train, validation_data=(X_test, y_test),
            epochs=model_data["epochs"],
            batch_size=model_data["batchsize"],
            verbose=0, callbacks=[earlystop, ProgressCallback()])
    return model, history

如果我们为 Keras 模型的 fit() 函数提供此类的实例,这个类的成员函数将在训练周期、epoch 或批次的开始或结束时被调用。在函数内部我们可以做很多事情。在 epoch 或批次结束时,函数的 logs 参数是损失和验证指标的字典。因此,我们定义了一个全局字典对象来记住这些指标。

现在,我们可以随时检查字典 train_status 来了解模型训练的进度,我们可以修改我们的网页来显示它:

...

app.layout = html.Div(
    id="parent",
    children=[
        ...
        html.Button(id="train", n_clicks=0, children="Train"),
        html.Pre(id="progressdisplay"),
        dcc.Interval(id="trainprogress", n_intervals=0, interval=1000),
        dcc.Graph(id="historyplot"),
    ]
)

import json

@app.callback(Output("progressdisplay", "children"),
              Input("trainprogress", "n_intervals"))
def update_progress(n):
    return json.dumps(train_status, indent=4)

我们创建一个不可见组件 dcc.Interval(),它每隔 1000 毫秒(= 1 秒)自动更改其属性 n_intervals。然后我们在我们的“Train”按钮下创建一个 <pre> 元素,并命名为 progressdisplay。每当 Interval 组件触发时,我们将 train_status 字典转换为 JSON 字符串并显示在那个 <pre> 元素中。如果你愿意,你可以创建一个小部件来显示这些信息。Dash 提供了几个小部件。

仅仅通过这些更改,当您的模型训练完成时,您的浏览器将看起来像这样:

以下是完整的代码。不要忘记你还需要 assets/main.css 文件以正确渲染网页:

import json

import numpy as np
import pandas as pd
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, AveragePooling2D, Flatten
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import Callback, EarlyStopping

import plotly.express as px
from dash import Dash, html, dcc
from dash.dependencies import Input, Output, State
from flask import Flask

server = Flask("mlm")
app = Dash(server=server)

# Load MNIST digits
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = np.expand_dims(X_train, axis=3).astype("float32")
X_test = np.expand_dims(X_test, axis=3).astype("float32")
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

model_data = {
    "activation": "relu",
    "optimizer": "adam",
    "epochs": 100,
    "batchsize": 32,
}

train_status = {
    "running": False,
    "epoch": 0,
    "batch": 0,
    "batch metric": None,
    "last epoch": None,
}

class ProgressCallback(Callback):
    def on_train_begin(self, logs=None):
        train_status["running"] = True
        train_status["epoch"] = 0
    def on_train_end(self, logs=None):
        train_status["running"] = False
    def on_epoch_begin(self, epoch, logs=None):
        train_status["epoch"] = epoch
        train_status["batch"] = 0
    def on_epoch_end(self, epoch, logs=None):
        train_status["last epoch"] = logs
    def on_train_batch_begin(self, batch, logs=None):
        train_status["batch"] = batch
    def on_train_batch_end(self, batch, logs=None):
        train_status["batch metric"] = logs

def train():
    activation = model_data["activation"]
    model = Sequential([
        Conv2D(6, (5, 5), activation=activation,
               input_shape=(28, 28, 1), padding="same"),
        AveragePooling2D((2, 2), strides=2),
        Conv2D(16, (5, 5), activation=activation),
        AveragePooling2D((2, 2), strides=2),
        Conv2D(120, (5, 5), activation=activation),
        Flatten(),
        Dense(84, activation=activation),
        Dense(10, activation="softmax")
    ])
    model.compile(loss="categorical_crossentropy",
                  optimizer=model_data["optimizer"],
                  metrics=["accuracy"])
    earlystop = EarlyStopping(monitor="val_loss", patience=3,
                              restore_best_weights=True)
    history = model.fit(
            X_train, y_train, validation_data=(X_test, y_test),
            epochs=model_data["epochs"],
            batch_size=model_data["batchsize"],
            verbose=0, callbacks=[earlystop, ProgressCallback()])
    return model, history

app.layout = html.Div(
    id="parent",
    children=[
        html.H1(
            children="LeNet5 training",
            style={"textAlign": "center"}
        ),
        html.Div(
            className="flex-container",
            children=[
                html.Div(children=[
                    html.Div(id="activationdisplay"),
                    dcc.Dropdown(
                        id="activation",
                        options=[
                            {"label": "Rectified linear unit", "value": "relu"},
                            {"label": "Hyperbolic tangent", "value": "tanh"},
                            {"label": "Sigmoidal", "value": "sigmoid"},
                        ],
                        value=model_data["activation"]
                    )
                ]),
                html.Div(children=[
                    html.Div(id="optimizerdisplay"),
                    dcc.Dropdown(
                        id="optimizer",
                        options=[
                            {"label": "Adam", "value": "adam"},
                            {"label": "Adagrad", "value": "adagrad"},
                            {"label": "Nadam", "value": "nadam"},
                            {"label": "Adadelta", "value": "adadelta"},
                            {"label": "Adamax", "value": "adamax"},
                            {"label": "RMSprop", "value": "rmsprop"},
                            {"label": "SGD", "value": "sgd"},
                            {"label": "FTRL", "value": "ftrl"},
                        ],
                        value=model_data["optimizer"]
                    ),
                ]),
                html.Div(children=[
                    html.Div(id="epochdisplay"),
                    dcc.Slider(1, 200, 1, marks={1: "1", 100: "100", 200: "200"},
                               value=model_data["epochs"], id="epochs"),
                ]),
                html.Div(children=[
                    html.Div(id="batchdisplay"),
                    dcc.Slider(1, 128, 1, marks={1: "1", 128: "128"},
                               value=model_data["batchsize"], id="batchsize"),
                ]),
            ]
        ),
        html.Button(id="train", n_clicks=0, children="Train"),
        html.Pre(id="progressdisplay"),
        dcc.Interval(id="trainprogress", n_intervals=0, interval=1000),
        dcc.Graph(id="historyplot"),
    ]
)

@app.callback(Output(component_id="epochdisplay", component_property="children"),
              Input(component_id="epochs", component_property="value"))
def update_epochs(value):
    return f"Epochs: {value}"

@app.callback(Output("batchdisplay", "children"),
              Input("batchsize", "value"))
def update_batchsize(value):
    return f"Batch size: {value}"

@app.callback(Output("activationdisplay", "children"),
              Input("activation", "value"))
def update_activation(value):
    return f"Activation: {value}"

@app.callback(Output("optimizerdisplay", "children"),
              Input("optimizer", "value"))
def update_optimizer(value):
    return f"Optimizer: {value}"

@app.callback(Output("historyplot", "figure"),
              Input("train", "n_clicks"),
              State("activation", "value"),
              State("optimizer", "value"),
              State("epochs", "value"),
              State("batchsize", "value"),
              prevent_initial_call=True)
def train_action(n_clicks, activation, optimizer, epoch, batchsize):
    model_data.update({
        "activation": activation,
        "optimizer": optimizer,
        "epoch": epoch,
        "batchsize": batchsize,
    })
    model, history = train()
    model_data["model"] = model  # keep the trained model
    history = pd.DataFrame(history.history)
    fig = px.line(history, title="Model training metrics")
    fig.update_layout(xaxis_title="epochs",
                      yaxis_title="metric value", legend_title="metrics")
    return fig

@app.callback(Output("progressdisplay", "children"),
              Input("trainprogress", "n_intervals"))
def update_progress(n):
    return json.dumps(train_status, indent=4)

# run server, with hot-reloading
app.run_server(debug=True, threaded=True)

结合 Flask 和 Dash

你也可以提供一个网页界面来 使用 训练好的模型吗?当然可以。如果模型接受一些数字输入,这会更容易,因为我们只需在页面上提供一个输入框元素。在这种情况下,由于这是一个手写数字识别模型,我们需要一种方法在浏览器中提供图像,并将其传递给服务器上的模型。只有这样,我们才能获得结果并显示出来。我们可以选择两种方式来实现这一点:我们可以让用户上传一个数字图像供模型识别,或者让用户直接在浏览器中绘制图像。

在 HTML5 中,我们有一个 <canvas> 元素,允许我们在网页上绘制或显示像素。我们可以利用这个元素让用户在上面绘制,然后将其转换为 28×28 的数字矩阵,并将其发送到服务器端,让模型进行预测并显示预测结果。

这样做不是 Dash 的工作,因为我们想要读取 <canvas> 元素并将其转换为正确格式的矩阵。我们将在 Javascript 中完成这项工作。但之后,我们会在一个网页 URL 中调用模型,就像我们在文章开头所描述的那样。一个带有参数的查询会被发送,服务器的响应将是我们的模型识别出的数字。

在后台,Dash 使用 Flask,根 URL 指向 Dash 应用程序。我们可以创建一个使用模型的 Flask 端点,如下所示:

...
@server.route("/recognize", methods=["POST"])
def recognize():
    if not model_data.get("model"):
        return "Please train your model."
    matrix = json.loads(request.form["matrix"])
    matrix = np.asarray(matrix).reshape(1, 28, 28)
    proba = model_data["model"].predict(matrix).reshape(-1)
    result = np.argmax(proba)
    return "Digit "+str(result)

正如我们所回忆的,变量 server 是我们构建 Dash 应用程序的 Flask 服务器。我们使用其装饰器创建一个端点。由于我们要传递一个 28×28 矩阵作为参数,因此我们使用 HTTP POST 方法,这对于大块数据更为适合。POST 方法提供的数据不会成为 URL 的一部分。因此,我们没有在 @server.route() 装饰器中设置路径参数。相反,我们通过 request.form["matrix"] 读取数据,其中 "matrix" 是我们传递的参数名称。然后我们假设字符串为 JSON 格式,将其转换为数字列表,并进一步转换为 NumPy 数组,然后传递给模型以预测数字。我们将训练好的模型保存在 model_data["model"] 中,但我们可以通过检查该训练模型是否存在并在不存在时返回错误消息,使上述代码更健壮。

要修改网页,我们只需添加一些组件:

app.layout = html.Div(
    id="parent",
    children=[
        ...
        dcc.Graph(id="historyplot"),
        html.Div(
            className="flex-container",
            id="predict",
            children=[
                html.Div(
                    children=html.Canvas(id="writing"),
                    style={"textAlign": "center"}
                ),
                html.Div(id="predictresult", children="?"),
                html.Pre(
                    id="lastinput",
                ),
            ]
        ),
        html.Div(id="dummy", style={"display": "none"}),
    ]
)

底部是一个隐藏的 <div> 元素,我们稍后将使用它。主要部分是另一个 <div> 元素,其中包含三个项目,即一个 <canvas> 元素(ID 为 "writing"),一个 <div> 元素(ID 为 "predictresult")用于显示结果,以及一个 <pre> 元素(ID 为 "lastinput")用于显示我们传递给服务器的矩阵。

由于这些元素不是由 Dash 处理的,我们不需要在 Python 中创建更多的函数。相反,我们需要创建一个 JavaScript 文件 assets/main.js 以便与这些组件进行交互。Dash 应用程序会自动加载 assets 目录下的所有内容,并在网页加载时将其发送给用户。我们可以用纯 JavaScript 编写这些内容,但为了使代码更简洁,我们将使用 jQuery。因此,我们需要告诉 Dash 我们将在这个 Web 应用程序中使用 jQuery:

...
app = Dash(server=server,
           external_scripts=[
               "https://code.jquery.com/jquery-3.6.0.min.js"
           ])

external_scripts 参数是一个 URL 列表,这些 URL 指向将在网页加载之前作为附加脚本加载的资源。因此,我们通常会在这里提供库,但将我们自己的代码保持在外部。

我们自己的 JavaScript 代码将是一个单独的函数,因为它在网页完全加载后被调用:

JavaScript

function pageinit() {
	// Set up canvas object
	var canvas = document.getElementById("writing");
	canvas.width = parseInt($("#writing").css("width"));
	canvas.height = parseInt($("#writing").css("height"));
	var context = canvas.getContext("2d");  // to remember drawing
	context.strokeStyle = "#FF0000";        // draw in bright red
	context.lineWidth = canvas.width / 15;  // thickness adaptive to canvas size

	...
};

我们首先在 JavaScript 中设置 <canvas> 元素。这些设置是特定于我们需求的。首先,我们将以下内容添加到 assets/main.css 中:

CSS

canvas#writing {
    width: 300px;
    height: 300px;
    margin: auto;
    padding: 10px;
    border: 3px solid #7f7f7f;
    background-color: #FFFFFF;
}

这将宽度和高度固定为 300 像素,以使我们的画布成为正方形,同时进行其他美观上的微调。由于最终我们会将手写的内容转换为 28×28 像素的图像,以适应模型的期望,所以我们在画布上写的每一笔都不能过于细。因此,我们将笔画宽度设置为与画布大小相关。

仅有这些还不足以使我们的画布可用。假设我们从未在移动设备上使用它,而只在桌面浏览器上使用,绘图是通过鼠标点击和移动完成的。我们需要定义鼠标点击在画布上执行的操作。因此,我们将以下功能添加到 JavaScript 代码中:

JavaScript

function pageinit() {
	...

	// Canvas reset by timeout
	var timeout = null; // holding the timeout event
	var reset = function() {
		// clear the canvas
		context.clearRect(0, 0, canvas.width, canvas.height);
	}

	// Set up drawing with mouse
	var mouse = {x:0, y:0}; // to remember the coordinate w.r.t. canvas
	var onPaint = function() {
		clearTimeout(timeout);
		// event handler for mouse move in canvas
		context.lineTo(mouse.x, mouse.y);
		context.stroke();
	};

	// HTML5 Canvas mouse event - in case of desktop browser
	canvas.addEventListener("mousedown", function(e) {
		clearTimeout(timeout);
		// mouse down, begin path at current mouse position
		context.moveTo(mouse.x, mouse.y);
		context.beginPath();
		// all mouse move from now on should be painted
		canvas.addEventListener("mousemove", onPaint, false);
	}, false);
	canvas.addEventListener("mousemove", function(e) {
		// mouse move remember position w.r.t. canvas
		mouse.x = e.pageX - this.offsetLeft;
		mouse.y = e.pageY - this.offsetTop;
	}, false);
	canvas.addEventListener("mouseup", function(e) {
		clearTimeout(timeout);
		// all mouse move from now on should NOT be painted
		canvas.removeEventListener("mousemove", onPaint, false);
		// read drawing into image
		var img = new Image(); // on load, this will be the canvas in same WxH
		img.onload = function() {
			// Draw the 28x28 to top left corner of canvas
			context.drawImage(img, 0, 0, 28, 28);
			// Extract data: Each pixel becomes a RGBA value, hence 4 bytes each
			var data = context.getImageData(0, 0, 28, 28).data;
			var input = [];
			for (var i=0; i<data.length; i += 4) {
				// scan each pixel, extract first byte (R component)
				input.push(data[i]);
			};

			// TODO: use "input" for prediction
		};
		img.src = canvas.toDataURL("image/png");
		timeout = setTimeout(reset, 5000); // clear canvas after 5 sec
	}, false);
};

这有点啰嗦,但基本上我们要求监听画布上的三个鼠标事件,即按下鼠标按钮、移动鼠标和释放鼠标按钮。这三个事件组合在一起就是我们在画布上绘制一笔的方式。

首先,我们添加到 <canvas> 元素上的 mousemove 事件处理器仅仅是为了记住 JavaScript 对象 mouse 中当前的鼠标位置。

然后在 mousedown 事件处理器中,我们从最新的鼠标位置开始绘图上下文。由于绘图已经开始,所有后续的鼠标移动都应该在画布上绘制。我们定义了 onPaint 函数,以将线段扩展到画布上当前的鼠标位置。现在这个函数被注册为 mousemove 事件的附加事件处理器。

最后,mouseup 事件处理程序用于处理用户完成一次绘制并释放鼠标按钮的情况。所有后续的鼠标移动不应在画布上绘制,因此我们需要移除 onPaint 函数的事件处理程序。然后,当我们完成一次绘制时,这 可能是 一个完成的数字,因此我们想将其提取为 28×28 像素版本。这可以很容易完成。我们只需在 Javascript 中创建一个新的 Image 对象,并将整个画布加载到其中。当完成后,Javascript 会自动调用与之关联的 onload 函数。在其中,我们将这个 Image 对象转化为 28×28 像素,并绘制到我们 context 对象的左上角。然后我们逐像素读取它(每个像素将是 0 到 255 的 RGB 值,但由于我们使用红色绘制,我们只关心红色通道)到 Javascript 数组 input 中。我们只需将这个 input 数组传递给我们的模型,然后可以进行预测。

我们不想创建任何额外的按钮来清除我们的画布或提交我们的数字进行识别。因此,我们希望如果用户在 5 秒内没有绘制任何新内容,画布会自动清除。这是通过 Javascript 函数 setTimeout()clearTimeout() 实现的。我们创建一个 reset 函数来清除画布,该函数将在 mouseup 事件后 5 秒触发。而这个计划调用的 reset 函数会在超时之前发生绘制事件时被取消。同样,每当发生 mouseup 事件时,识别也会自动进行。

给定我们有一个 28×28 像素的输入数据被转化为一个 Javascript 数组,我们可以直接使用我们用 Flask 创建的 recognize 端点。如果我们能看到我们传递给 recognize 的内容以及它返回的结果会很有帮助。所以我们在 ID 为 lastinput<pre> 元素中显示输入数据,并在 ID 为 predictresult<div> 元素中显示 recognize 端点返回的结果。这可以通过稍微扩展 mouseup 事件处理程序轻松完成。

JavaScript

function pageinit() {
	canvas.addEventListener("mouseup", function(e) {
		...
		img.onload = function() {
            ...
			var input = [];
			for (var i=0; i<data.length; i += 4) {
				// scan each pixel, extract first byte (R component)
				input.push(data[i]);
			};
			var matrix = [];
			for (var i=0; i<input.length; i+=28) {
				matrix.push(input.slice(i, i+28).toString());
			};
			$("#lastinput").html("[[" + matrix.join("],<br/>[") + "]]");
			// call predict function with the matrix
			predict(input);
		};
		img.src = canvas.toDataURL("image/png");
		setTimeout(reset, 5000); // clear canvas after 5 sec
	}, false);

	function predict(input) {
		$.ajax({
			type: "POST",
			url: "/recognize",
			data: {"matrix": JSON.stringify(input)},
			success: function(result) {
				$("#predictresult").html(result);
			}
		});
	};
};

我们定义了一个新的 Javascript 函数 predict(),它会发起一个 AJAX 调用到我们用 Flask 设置的 recognize 端点。它使用 POST 方法,数据 matrix 赋值为 Javascript 数组的 JSON 版本。我们不能直接在 HTTP 请求中传递数组,因为一切必须被序列化。当 AJAX 调用返回时,我们更新 <div> 元素以显示结果。

这个 predict() 函数是由 mouseup 事件处理程序调用的,当我们完成将 28×28 像素图像转化为数字数组时。同时,我们将一个版本写入 <pre> 元素,仅用于显示目的。

到这里,我们的应用程序已经完成。但我们仍然需要在 Dash 应用程序加载时调用pageinit()函数。实际上,Dash 应用程序使用 React 来进行延迟渲染,因此我们不应该将pageinit()函数挂钩到document.onload事件处理程序上,否则我们会发现我们要找的组件不存在。正确的方法是在 Dash 应用程序完全加载时调用 JavaScript 函数是设置一个客户端回调,即由浏览器端 JavaScript 处理的回调,而不是服务器端的 Python。我们在 Python 程序server.py中添加以下函数调用:

...
app.clientside_callback(
    "pageinit",
    Output("dummy", "children"),
    Input("dummy", "children")
)

clientside_callback()函数不是作为装饰器使用,而是作为完整的函数调用。它将 JavaScript 函数作为第一个参数,将OutputInput对象作为第二和第三个参数,类似于回调装饰器的情况。由于这个原因,我们在网页布局中创建了一个隐藏的虚拟组件,以帮助在页面加载时触发 JavaScript 函数,所有 Dash 回调会被调用一次,除非prevent_initial_call=True作为回调的一个参数。

现在我们一切就绪。我们可以运行server.py脚本来启动我们的 Web 服务器,它将加载assets/目录下的两个文件。打开浏览器访问 Dash 应用程序报告的 URL,我们可以更改超参数并训练模型,然后使用模型进行预测。

综合起来,以下是我们 JavaScript 部分的完整代码,保存为assets/main.js

JavaScript

function pageinit() {
	// Set up canvas object
	var canvas = document.getElementById("writing");
	canvas.width = parseInt($("#writing").css("width"));
	canvas.height = parseInt($("#writing").css("height"));
	var context = canvas.getContext("2d");  // to remember drawing
	context.strokeStyle = "#FF0000";        // draw in bright red
	context.lineWidth = canvas.width / 15;  // thickness adaptive to canvas size

	// Canvas reset by timeout
	var timeout = null; // holding the timeout event
	var reset = function() {
		// clear the canvas
		context.clearRect(0, 0, canvas.width, canvas.height);
	}

	// Set up drawing with mouse
	var mouse = {x:0, y:0}; // to remember the coordinate w.r.t. canvas
	var onPaint = function() {
		clearTimeout(timeout);
		// event handler for mousemove in canvas
		context.lineTo(mouse.x, mouse.y);
		context.stroke();
	};

	// HTML5 Canvas mouse event - in case of desktop browser
	canvas.addEventListener("mousedown", function(e) {
		clearTimeout(timeout);
		// mousedown, begin path at mouse position
		context.moveTo(mouse.x, mouse.y);
		context.beginPath();
		// all mousemove from now on should be painted
		canvas.addEventListener("mousemove", onPaint, false);
	}, false);
	canvas.addEventListener("mousemove", function(e) {
		// mousemove remember position w.r.t. canvas
		mouse.x = e.pageX - this.offsetLeft;
		mouse.y = e.pageY - this.offsetTop;
	}, false);
	canvas.addEventListener("mouseup", function(e) {
		clearTimeout(timeout);
		// all mousemove from now on should NOT be painted
		canvas.removeEventListener("mousemove", onPaint, false);
		// read drawing into image
		var img = new Image(); // on load, this will be the canvas in same WxH
		img.onload = function() {
			// Draw the 28x28 to top left corner of canvas
			context.drawImage(img, 0, 0, 28, 28);
			// Extract data: Each pixel becomes a RGBA value, hence 4 bytes each
			var data = context.getImageData(0, 0, 28, 28).data;
			var input = [];
			for (var i=0; i<data.length; i += 4) {
				// scan each pixel, extract first byte (R component)
				input.push(data[i]);
			};
			var matrix = [];
			for (var i=0; i<input.length; i+=28) {
				matrix.push(input.slice(i, i+28).toString());
			};
			$("#lastinput").html("[[" + matrix.join("],\n[") + "]]");
			// call predict function with the matrix
			predict(input);
		};
		img.src = canvas.toDataURL("image/png");
		timeout = setTimeout(reset, 5000); // clear canvas after 5 sec
	}, false);

	function predict(input) {
		$.ajax({
			type: "POST",
			url: "/recognize",
			data: {"matrix": JSON.stringify(input)},
			success: function(result) {
				$("#predictresult").html(result);
			}
		});
	};
};

以下是 CSS 的完整代码,assets/main.csspre#lastinput部分是使用较小的字体显示我们的输入矩阵):

CSS

.flex-container {
    display: flex;
    padding: 5px;
    flex-wrap: nowrap;
    background-color: #EEEEEE;
}

.flex-container > * {
    flex-grow: 1
}

canvas#writing {
    width: 300px;
    height: 300px;
    margin: auto;
    padding: 10px;
    border: 3px solid #7f7f7f;
    background-color: #FFFFFF;
}

pre#lastinput {
    font-size: 50%;
}

以下是主要的 Python 程序,server.py

import json

import numpy as np
import pandas as pd
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Conv2D, Dense, AveragePooling2D, Flatten
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import Callback, EarlyStopping

import plotly.express as px
from dash import Dash, html, dcc
from dash.dependencies import Input, Output, State
from flask import Flask, request

server = Flask("mlm")
app = Dash(server=server,
           external_scripts=[
               "https://code.jquery.com/jquery-3.6.0.min.js"
           ])

# Load MNIST digits
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = np.expand_dims(X_train, axis=3).astype("float32")
X_test = np.expand_dims(X_test, axis=3).astype("float32")
y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

model_data = {
    "activation": "relu",
    "optimizer": "adam",
    "epochs": 100,
    "batchsize": 32,
    "model": load_model("lenet5.h5"),
}
train_status = {
    "running": False,
    "epoch": 0,
    "batch": 0,
    "batch metric": None,
    "last epoch": None,
}

class ProgressCallback(Callback):
    def on_train_begin(self, logs=None):
        train_status["running"] = True
        train_status["epoch"] = 0
    def on_train_end(self, logs=None):
        train_status["running"] = False
    def on_epoch_begin(self, epoch, logs=None):
        train_status["epoch"] = epoch
        train_status["batch"] = 0
    def on_epoch_end(self, epoch, logs=None):
        train_status["last epoch"] = logs
    def on_train_batch_begin(self, batch, logs=None):
        train_status["batch"] = batch
    def on_train_batch_end(self, batch, logs=None):
        train_status["batch metric"] = logs

def train():
    activation = model_data["activation"]
    model = Sequential([
        Conv2D(6, (5, 5), activation=activation,
               input_shape=(28, 28, 1), padding="same"),
        AveragePooling2D((2, 2), strides=2),
        Conv2D(16, (5, 5), activation=activation),
        AveragePooling2D((2, 2), strides=2),
        Conv2D(120, (5, 5), activation=activation),
        Flatten(),
        Dense(84, activation=activation),
        Dense(10, activation="softmax")
    ])
    model.compile(loss="categorical_crossentropy",
                  optimizer=model_data["optimizer"],
                  metrics=["accuracy"])
    earlystop = EarlyStopping(monitor="val_loss", patience=3,
                              restore_best_weights=True)
    history = model.fit(
            X_train, y_train, validation_data=(X_test, y_test),
            epochs=model_data["epochs"],
            batch_size=model_data["batchsize"],
            verbose=0, callbacks=[earlystop, ProgressCallback()])
    return model, history

app.layout = html.Div(
    id="parent",
    children=[
        html.H1(
            children="LeNet5 training",
            style={"textAlign": "center"}
        ),
        html.Div(
            className="flex-container",
            children=[
                html.Div(children=[
                    html.Div(id="activationdisplay"),
                    dcc.Dropdown(
                        id="activation",
                        options=[
                            {"label": "Rectified linear unit", "value": "relu"},
                            {"label": "Hyperbolic tangent", "value": "tanh"},
                            {"label": "Sigmoidal", "value": "sigmoid"},
                        ],
                        value=model_data["activation"]
                    )
                ]),
                html.Div(children=[
                    html.Div(id="optimizerdisplay"),
                    dcc.Dropdown(
                        id="optimizer",
                        options=[
                            {"label": "Adam", "value": "adam"},
                            {"label": "Adagrad", "value": "adagrad"},
                            {"label": "Nadam", "value": "nadam"},
                            {"label": "Adadelta", "value": "adadelta"},
                            {"label": "Adamax", "value": "adamax"},
                            {"label": "RMSprop", "value": "rmsprop"},
                            {"label": "SGD", "value": "sgd"},
                            {"label": "FTRL", "value": "ftrl"},
                        ],
                        value=model_data["optimizer"]
                    ),
                ]),
                html.Div(children=[
                    html.Div(id="epochdisplay"),
                    dcc.Slider(1, 200, 1, marks={1: "1", 100: "100", 200: "200"},
                               value=model_data["epochs"], id="epochs"),
                ]),
                html.Div(children=[
                    html.Div(id="batchdisplay"),
                    dcc.Slider(1, 128, 1, marks={1: "1", 128: "128"},
                               value=model_data["batchsize"], id="batchsize"),
                ]),
            ]
        ),
        html.Button(id="train", n_clicks=0, children="Train"),
        html.Pre(id="progressdisplay"),
        dcc.Interval(id="trainprogress", n_intervals=0, interval=1000),
        dcc.Graph(id="historyplot"),
        html.Div(
            className="flex-container",
            id="predict",
            children=[
                html.Div(
                    children=html.Canvas(id="writing"),
                    style={"textAlign": "center"}
                ),
                html.Div(id="predictresult", children="?"),
                html.Pre(
                    id="lastinput",
                ),
            ]
        ),
        html.Div(id="dummy", style={"display": "none"}),
    ]
)

@app.callback(Output(component_id="epochdisplay", component_property="children"),
              Input(component_id="epochs", component_property="value"))
def update_epochs(value):
    model_data["epochs"] = value
    return f"Epochs: {value}"

@app.callback(Output("batchdisplay", "children"),
              Input("batchsize", "value"))
def update_batchsize(value):
    model_data["batchsize"] = value
    return f"Batch size: {value}"

@app.callback(Output("activationdisplay", "children"),
              Input("activation", "value"))
def update_activation(value):
    model_data["activation"] = value
    return f"Activation: {value}"

@app.callback(Output("optimizerdisplay", "children"),
              Input("optimizer", "value"))
def update_optimizer(value):
    model_data["optimizer"] = value
    return f"Optimizer: {value}"

@app.callback(Output("historyplot", "figure"),
              Input("train", "n_clicks"),
              State("activation", "value"),
              State("optimizer", "value"),
              State("epochs", "value"),
              State("batchsize", "value"),
              prevent_initial_call=True)
def train_action(n_clicks, activation, optimizer, epoch, batchsize):
    model_data.update({
        "activation": activation,
        "optimizer": optimizer,
        "epoch": epoch,
        "batchsize": batchsize,
    })
    model, history = train()
    model_data["model"] = model  # keep the trained model
    history = pd.DataFrame(history.history)
    fig = px.line(history, title="Model training metrics")
    fig.update_layout(xaxis_title="epochs",
                      yaxis_title="metric value", legend_title="metrics")
    return fig

@app.callback(Output("progressdisplay", "children"),
              Input("trainprogress", "n_intervals"))
def update_progress(n):
    return json.dumps(train_status, indent=4)

app.clientside_callback(
    "function() { pageinit(); };",
    Output("dummy", "children"),
    Input("dummy", "children")
)

@server.route("/recognize", methods=["POST"])
def recognize():
    if not model_data.get("model"):
        return "Please train your model."
    matrix = json.loads(request.form["matrix"])
    matrix = np.asarray(matrix).reshape(1, 28, 28)
    proba = model_data["model"].predict(matrix).reshape(-1)
    result = np.argmax(proba)
    return "Digit "+str(result)

# run server, with hot-reloading
app.run_server(debug=True, threaded=True)

如果我们运行所有这些,我们应该看到如下屏幕:

深入阅读

目前有大量的 Web 框架可用,Flask 只是其中之一。另一个流行的框架是 CherryPy。如果你想深入了解,以下是相关资源。

书籍
文章
APIs 和软件

总结

在本教程中,你学习了如何使用 Dash 库在 Python 中轻松构建网页应用。你还学会了如何使用 Flask 创建一些网页 API。具体来说,你学习了

  • 网页应用的机制

  • 我们如何使用 Dash 来构建一个由网页组件触发的简单网页应用

  • 我们如何使用 Flask 创建网页 API

  • 如何在 Javascript 中构建网页应用,并在使用我们用 Flask 构建的网页 API 的浏览器上运行