FastAPI 数据科学应用构建指南(二)
原文:
annas-archive.org/md5/a1f4ad3f5a4649378151351d58ad6e73译者:飞龙
第四章:在 FastAPI 中管理 Pydantic 数据模型
本章将详细讲解如何使用 Pydantic 定义数据模型,这是 FastAPI 使用的底层数据验证库。我们将解释如何在不重复代码的情况下实现相同模型的变种,得益于类的继承。最后,我们将展示如何将自定义数据验证逻辑实现到 Pydantic 模型中。
本章我们将涵盖以下主要内容:
-
使用 Pydantic 定义模型及其字段类型
-
使用类继承创建模型变种
-
使用 Pydantic 添加自定义数据验证
-
使用 Pydantic 对象
技术要求
要运行代码示例,你需要一个 Python 虚拟环境,我们在 第一章,Python 开发环境设置 中进行了设置。
你可以在专门的 GitHub 仓库中找到本章的所有代码示例,链接为:github.com/PacktPublishing/Building-Data-Science-Applications-with-FastAPI-Second-Edition/tree/main/chapter04。
使用 Pydantic 定义模型及其字段类型
Pydantic 是一个强大的库,用于通过 Python 类和类型提示定义数据模型。这种方法使得这些类与静态类型检查完全兼容。此外,由于它是常规的 Python 类,我们可以使用继承,并且还可以定义我们自己的方法来添加自定义逻辑。
在 第三章,使用 FastAPI 开发 RESTful API 中,你学习了如何使用 Pydantic 定义数据模型的基础:你需要定义一个继承自 BaseModel 的类,并将所有字段列为类的属性,每个字段都有一个类型提示来强制其类型。
在本节中,我们将重点关注模型定义,并查看我们在定义字段时可以使用的所有可能性。
标准字段类型
我们将从定义标准类型字段开始,这只涉及简单的类型提示。让我们回顾一下一个表示个人信息的简单模型。你可以在以下代码片段中看到它:
chapter04_standard_field_types_01.py
from pydantic import BaseModelclass Person(BaseModel):
first_name: str
last_name: str
age: int
正如我们所说,你只需要写出字段的名称,并使用预期的类型对其进行类型提示。当然,我们不仅限于标量类型:我们还可以使用复合类型,如列表和元组,或像 datetime 和 enum 这样的类。在下面的示例中,你可以看到一个使用这些更复杂类型的模型:
chapter04_standard_field_types_02.py
from datetime import datefrom enum import Enum
from pydantic import BaseModel, ValidationError
class Gender(str, Enum):
MALE = "MALE"
FEMALE = "FEMALE"
NON_BINARY = "NON_BINARY"
class Person(BaseModel):
first_name: str
last_name: str
gender: Gender
birthdate: date
interests: list[str]
在这个示例中有三点需要注意。
首先,我们使用标准 Python Enum 类作为 gender 字段的类型。这允许我们指定一组有效值。如果输入的值不在该枚举中,Pydantic 会引发错误,如以下示例所示:
chapter04_standard_field_types_02.py
# Invalid gendertry:
Person(
first_name="John",
last_name="Doe",
gender="INVALID_VALUE",
birthdate="1991-01-01",
interests=["travel", "sports"],
)
except ValidationError as e:
print(str(e))
如果你运行前面的示例,你将得到如下输出:
1 validation error for Persongender
value is not a valid enumeration member; permitted: 'MALE', 'FEMALE', 'NON_BINARY' (type=type_error.enum; enum_values=[<Gender.MALE: 'MALE'>, <Gender.FEMALE: 'FEMALE'>, <Gender.NON_BINARY: 'NON_BINARY'>])
实际上,这正是我们在第三章《使用 FastAPI 开发 RESTful API》中所做的,用以限制 path 参数的允许值。
然后,我们将 date Python 类作为 birthdate 字段的类型。Pydantic 能够自动解析以 ISO 格式字符串或时间戳整数给出的日期和时间,并实例化一个合适的 date 或 datetime 对象。当然,如果解析失败,你也会得到一个错误。你可以在以下示例中进行实验:
chapter04_standard_field_types_02.py
# Invalid birthdatetry:
Person(
first_name="John",
last_name="Doe",
gender=Gender.MALE,
birthdate="1991-13-42",
interests=["travel", "sports"],
)
except ValidationError as e:
print(str(e))
这是输出结果:
1 validation error for Personbirthdate
invalid date format (type=value_error.date)
最后,我们将 interests 定义为一个字符串列表。同样,Pydantic 会检查该字段是否是有效的字符串列表。
显然,如果一切正常,我们将得到一个 Person 实例,并能够访问正确解析的字段。这就是我们在以下代码片段中展示的内容:
chapter04_standard_field_types_02.py
# Validperson = Person(
first_name="John",
last_name="Doe",
gender=Gender.MALE,
birthdate="1991-01-01",
interests=["travel", "sports"],
)
# first_name='John' last_name='Doe' gender=<Gender.MALE: 'MALE'> birthdate=datetime.date(1991, 1, 1) interests=['travel', 'sports']
print(person)
如你所见,这非常强大,我们可以拥有相当复杂的字段类型。但这还不是全部:字段本身可以是 Pydantic 模型,允许你拥有子对象!在以下代码示例中,我们扩展了前面的代码片段,添加了一个 address 字段:
chapter04_standard_field_types_03.py
class Address(BaseModel): street_address: str
postal_code: str
city: str
country: str
class Person(BaseModel):
first_name: str
last_name: str
gender: Gender
birthdate: date
interests: list[str]
address: Address
我们只需定义另一个 Pydantic 模型,并将其作为类型提示使用。现在,你可以使用已经有效的Address实例来实例化Person,或者更好的是,使用字典。在这种情况下,Pydantic 会自动解析它并根据地址模型进行验证。
在下面的代码片段中,我们尝试输入一个无效的地址:
chapter04_standard_field_types_03.py
# Invalid addresstry:
Person(
first_name="John",
last_name="Doe",
gender=Gender.MALE,
birthdate="1991-01-01",
interests=["travel", "sports"],
address={
"street_address": "12 Squirell Street",
"postal_code": "424242",
"city": "Woodtown",
# Missing country
},
)
except ValidationError as e:
print(str(e))
这将生成以下验证错误:
1 validation error for Personaddress -> country
field required (type=value_error.missing)
Pydantic 清晰地显示了子对象中缺失的字段。再次强调,如果一切顺利,我们将获得一个Person实例及其关联的Address,如下面的代码片段所示:
chapter04_standard_field_types_03.py
# Validperson = Person(
first_name="John",
last_name="Doe",
gender=Gender.MALE,
birthdate="1991-01-01",
interests=["travel", "sports"],
address={
"street_address": "12 Squirell Street",
"postal_code": "424242",
"city": "Woodtown",
"country": "US",
},
)
print(person)
可选字段和默认值
到目前为止,我们假设在实例化模型时,每个字段都必须提供。然而,通常情况下,有些值我们希望是可选的,因为它们可能对每个对象实例并不相关。有时,我们还希望为未指定的字段设置默认值。
正如你可能猜到的,这可以通过| None类型注解非常简单地完成,如以下代码片段所示:
chapter04_optional_fields_default_values_01.py
from pydantic import BaseModelclass UserProfile(BaseModel):
nickname: str
location: str | None = None
subscribed_newsletter: bool = True
当定义一个字段时,使用| None类型提示,它接受None值。如上面的代码所示,默认值可以通过将值放在等号后面简单地赋值。
但要小心:不要为动态类型(如日期时间)赋予默认值。如果这样做,日期时间实例化只会在模型导入时评估一次。这样一来,你实例化的所有对象都会共享相同的值,而不是每次都生成一个新的值。你可以在以下示例中观察到这种行为:
chapter04_optional_fields_default_values_02.py
class Model(BaseModel): # Don't do this.
# This example shows you why it doesn't work.
d: datetime = datetime.now()
o1 = Model()
print(o1.d)
time.sleep(1) # Wait for a second
o2 = Model()
print(o2.d)
print(o1.d < o2.d) # False
即使我们在实例化o1和o2之间等待了 1 秒钟,d日期时间仍然是相同的!这意味着日期时间只在类被导入时评估一次。
如果你想要一个默认的列表,比如l: list[str] = ["a", "b", "c"],你也会遇到同样的问题。注意,这不仅仅适用于 Pydantic 模型,所有的 Python 对象都会存在这个问题,所以你应该牢记这一点。
那么,我们该如何赋予动态默认值呢?幸运的是,Pydantic 提供了一个Field函数,允许我们为字段设置一些高级选项,其中包括为创建动态值设置工厂。在展示这个之前,我们首先会介绍一下Field函数。
在第三章《使用 FastAPI 开发 RESTful API》中,我们展示了如何对请求参数应用一些验证,检查一个数字是否在某个范围内,或一个字符串是否匹配正则表达式。实际上,这些选项直接来自 Pydantic!我们可以使用相同的技术对模型的字段进行验证。
为此,我们将使用 Pydantic 的Field函数,并将其结果作为字段的默认值。在下面的示例中,我们定义了一个Person模型,其中first_name和last_name是必填字段,必须至少包含三个字符,age是一个可选字段,必须是介于0和120之间的整数。我们在下面的代码片段中展示了该模型的实现:
chapter04_fields_validation_01.py
from pydantic import BaseModel, Field, ValidationErrorclass Person(BaseModel):
first_name: str = Field(..., min_length=3)
last_name: str = Field(..., min_length=3)
age: int | None = Field(None, ge=0, le=120)
如你所见,语法与我们之前看到的Path、Query和Body非常相似。第一个位置参数定义了字段的默认值。如果字段是必填的,我们使用省略号...。然后,关键字参数用于设置字段的选项,包括一些基本的验证。
你可以在官方 Pydantic 文档中查看Field接受的所有参数的完整列表,网址为pydantic-docs.helpmanual.io/usage/schema/#field-customization。
动态默认值
在上一节中,我们曾提醒你不要将动态值设置为默认值。幸运的是,Pydantic 在Field函数中提供了default_factory参数来处理这种用例。这个参数要求你传递一个函数,这个函数将在模型实例化时被调用。因此,每次你创建一个新对象时,生成的对象将在运行时进行评估。你可以在以下示例中看到如何使用它:
chapter04_fields_validation_02.py
from datetime import datetimefrom pydantic import BaseModel, Field
def list_factory():
return ["a", "b", "c"]
class Model(BaseModel):
l: list[str] = Field(default_factory=list_factory)
d: datetime = Field(default_factory=datetime.now)
l2: list[str] = Field(default_factory=list)
你只需将一个函数传递给这个参数。不要在其上放置参数:当你实例化新对象时,Pydantic 会自动调用这个函数。如果你需要使用特定的参数调用一个函数,你需要将它包装成自己的函数,正如我们为list_factory所做的那样。
还请注意,默认值所使用的第一个位置参数(如None或...)在这里完全省略了。这是有道理的:同时使用默认值和工厂是不一致的。如果你将这两个参数一起设置,Pydantic 会抛出错误。
使用 Pydantic 类型验证电子邮件地址和 URL
为了方便,Pydantic 提供了一些类,可以作为字段类型来验证一些常见模式,例如电子邮件地址或 URL。
在以下示例中,我们将使用EmailStr和HttpUrl来验证电子邮件地址和 HTTP URL。
要使EmailStr工作,你需要一个可选的依赖项email-validator,你可以使用以下命令安装:
(venv)$ pip install email-validator
这些类的工作方式与其他类型或类相同:只需将它们作为字段的类型提示使用。你可以在以下代码片段中看到这一点:
chapter04_pydantic_types_01.py
from pydantic import BaseModel, EmailStr, HttpUrl, ValidationErrorclass User(BaseModel):
email: EmailStr
website: HttpUrl
在以下示例中,我们检查电子邮件地址是否被正确验证:
chapter04_pydantic_types_01.py
# Invalid emailtry:
User(email="jdoe", website="https://www.example.com")
except ValidationError as e:
print(str(e))
你将看到以下输出:
1 validation error for Useremail
value is not a valid email address (type=value_error.email)
我们还检查了 URL 是否被正确解析,如下所示:
chapter04_pydantic_types_01.py
# Invalid URLtry:
User(email="jdoe@example.com", website="jdoe")
except ValidationError as e:
print(str(e))
你将看到以下输出:
1 validation error for Userwebsite
invalid or missing URL scheme (type=value_error.url.scheme)
如果你查看下面的有效示例,你会发现 URL 被解析为一个对象,这样你就可以访问它的不同部分,比如协议或主机名:
chapter04_pydantic_types_01.py
# Validuser = User(email="jdoe@example.com", website="https://www.example.com")
# email='jdoe@example.com' website=HttpUrl('https://www.example.com', scheme='https', host='www.example.com', tld='com', host_type='domain')
print(user)
Pydantic 提供了一套非常丰富的类型,可以帮助你处理各种情况。我们邀请你查阅官方文档中的完整列表:pydantic-docs.helpmanual.io/usage/types/#pydantic-types。
现在你对如何通过使用更高级的类型或利用验证功能来细化定义 Pydantic 模型有了更清晰的了解。正如我们所说,这些模型是 FastAPI 的核心,你可能需要为同一个实体定义多个变体,以应对不同的情况。在接下来的部分中,我们将展示如何做到这一点,同时最小化重复。
使用类继承创建模型变体
在第三章,使用 FastAPI 开发 RESTful API中,我们看到一个例子,在这个例子中我们需要定义 Pydantic 模型的两个变体,以便将我们想要存储在后端的数据和我们想要展示给用户的数据分开。这是 FastAPI 中的一个常见模式:你定义一个用于创建的模型,一个用于响应的模型,以及一个用于存储在数据库中的数据模型。
我们在以下示例中展示了这种基本方法:
chapter04_model_inheritance_01.py
from pydantic import BaseModelclass PostCreate(BaseModel):
title: str
content: str
class PostRead(BaseModel):
id: int
title: str
content: str
class Post(BaseModel):
id: int
title: str
content: str
nb_views: int = 0
这里我们有三个模型,涵盖了三种情况:
-
PostCreate将用于POST端点来创建新帖子。我们期望用户提供标题和内容;然而,标识符(ID)将由数据库自动确定。 -
PostRead将用于我们检索帖子数据时。我们当然希望获取它的标题和内容,还希望知道它在数据库中的关联 ID。 -
Post将包含我们希望存储在数据库中的所有数据。在这里,我们还想存储查看次数,但希望将其保密,以便内部进行统计。
你可以看到这里我们重复了很多,特别是 title 和 content 字段。在包含许多字段和验证选项的大型示例中,这可能会迅速变得难以管理。
避免这种情况的方法是利用模型继承。方法很简单:找出每个变种中共有的字段,并将它们放入一个模型中,作为所有其他模型的基类。然后,你只需从这个模型继承来创建变体,并添加特定的字段。在以下示例中,我们可以看到使用这种方法后的结果:
chapter04_model_inheritance_02.py
from pydantic import BaseModelclass PostBase(BaseModel):
title: str
content: str
class PostCreate(PostBase):
pass
class PostRead(PostBase):
id: int
class Post(PostBase):
id: int
nb_views: int = 0
现在,每当你需要为整个实体添加一个字段时,所需要做的就是将其添加到 PostBase 模型中,如下所示的代码片段所示。
如果你希望在模型中定义方法,这也是非常方便的。记住,Pydantic 模型是普通的 Python 类,因此你可以根据需要实现尽可能多的方法!
chapter04_model_inheritance_03.py
class PostBase(BaseModel): title: str
content: str
def excerpt(self) -> str:
return f"{self.content[:140]}..."
在 PostBase 中定义 excerpt 方法意味着它将会在每个模型变种中都可用。
虽然这种继承方法不是强制要求的,但它大大有助于防止代码重复,并最终减少错误。我们将在下一节看到,使用自定义验证方法时,它将显得更加有意义。
使用 Pydantic 添加自定义数据验证
到目前为止,我们已经看到了如何通过 Field 参数或 Pydantic 提供的自定义类型为模型应用基本验证。然而,在一个实际项目中,你可能需要为特定情况添加自定义验证逻辑。Pydantic 允许通过定义 validators 来实现这一点,验证方法可以应用于字段级别或对象级别。
在字段级别应用验证
这是最常见的情况:为单个字段定义验证规则。要在 Pydantic 中定义验证规则,我们只需要在模型中编写一个静态方法,并用 validator 装饰器装饰它。作为提醒,装饰器是一种语法糖,它允许用通用逻辑包装函数或类,而不会影响可读性。
以下示例检查出生日期,确保这个人不超过 120 岁:
chapter04_custom_validation_01.py
from datetime import datefrom pydantic import BaseModel, ValidationError, validator
class Person(BaseModel):
first_name: str
last_name: str
birthdate: date
@validator("birthdate")
def valid_birthdate(cls, v: date):
delta = date.today() - v
age = delta.days / 365
if age > 120:
raise ValueError("You seem a bit too old!")
return v
如你所见,validator 是一个静态类方法(第一个参数,cls,是类本身),v 参数是要验证的值。它由 validator 装饰器装饰,要求第一个参数是需要验证的参数的名称。
Pydantic 对此方法有两个要求,如下所示:
-
如果值根据你的逻辑不合法,你应该抛出一个
ValueError错误并提供明确的错误信息。 -
否则,你应该返回将被赋值给模型的值。请注意,它不需要与输入值相同:你可以根据需要轻松地更改它。这实际上是我们将在接下来的章节中做的,在 Pydantic 解析之前应用验证。
在对象级别应用验证
很多时候,一个字段的验证依赖于另一个字段——例如,检查密码确认是否与密码匹配,或在某些情况下强制要求某个字段为必填项。为了允许这种验证,我们需要访问整个对象的数据。为此,Pydantic 提供了 root_validator 装饰器,如下面的代码示例所示:
chapter04_custom_validation_02.py
from pydantic import BaseModel, EmailStr, ValidationError, root_validatorclass UserRegistration(BaseModel):
email: EmailStr
password: str
password_confirmation: str
@root_validator()
def passwords_match(cls, values):
password = values.get("password")
password_confirmation = values.get("password_confirmation")
if password != password_confirmation:
raise ValueError("Passwords don't match")
return values
使用此装饰器的方法类似于 validator 装饰器。静态类方法与 values 参数一起调用,values 是一个 字典,包含所有字段。这样,你可以获取每个字段并实现你的逻辑。
再次强调,Pydantic 对此方法有两个要求,如下所示:
-
如果根据你的逻辑,值不合法,你应该抛出一个
ValueError错误并提供明确的错误信息。 -
否则,你应该返回一个
values字典,这个字典将被赋值给模型。请注意,你可以根据需要在这个字典中修改某些值。
在 Pydantic 解析之前应用验证
默认情况下,验证器在 Pydantic 完成解析工作之后运行。这意味着你得到的值已经符合你指定的字段类型。如果类型不正确,Pydantic 会抛出错误,而不会调用你的验证器。
然而,有时你可能希望提供一些自定义解析逻辑,以允许你转换那些对于所设置类型来说原本不正确的输入值。在这种情况下,你需要在 Pydantic 解析器之前运行你的验证器:这就是 validator 中 pre 参数的作用。
在下面的示例中,我们展示了如何将一个由逗号分隔的字符串转换为列表:
chapter04_custom_validation_03.py
from pydantic import BaseModel, validatorclass Model(BaseModel):
values: list[int]
@validator("values", pre=True)
def split_string_values(cls, v):
if isinstance(v, str):
return v.split(",")
return v
m = Model(values="1,2,3")
print(m.values) # [1, 2, 3]
你可以看到,在这里我们的验证器首先检查我们是否有一个字符串。如果有,我们将逗号分隔的字符串进行拆分,并返回结果列表;否则,我们直接返回该值。Pydantic 随后会运行它的解析逻辑,因此你仍然可以确保如果 v 是无效值,会抛出错误。
使用 Pydantic 对象
在使用 FastAPI 开发 API 接口时,你可能会处理大量的 Pydantic 模型实例。接下来,你需要实现逻辑,将这些对象与服务进行连接,比如数据库或机器学习模型。幸运的是,Pydantic 提供了一些方法,使得这个过程变得非常简单。我们将回顾一些开发过程中常用的使用场景。
将对象转换为字典
这可能是你在 Pydantic 对象上执行最多的操作:将其转换为一个原始字典,这样你就可以轻松地将其发送到另一个 API,或者例如用在数据库中。你只需在对象实例上调用 dict 方法。
以下示例重用了我们在本章的标准字段类型部分看到的 Person 和 Address 模型:
chapter04_working_pydantic_objects_01.py
person = Person( first_name="John",
last_name="Doe",
gender=Gender.MALE,
birthdate="1991-01-01",
interests=["travel", "sports"],
address={
"street_address": "12 Squirell Street",
"postal_code": "424242",
"city": "Woodtown",
"country": "US",
},
)
person_dict = person.dict()
print(person_dict["first_name"]) # "John"
print(person_dict["address"]["street_address"]) # "12 Squirell Street"
如你所见,调用 dict 就足以将所有数据转换为字典。子对象也会递归地被转换:address 键指向一个包含地址属性的字典。
有趣的是,dict 方法支持一些参数,允许你选择要转换的属性子集。你可以指定你希望包括的属性,或者希望排除的属性,正如下面的代码片段所示:
chapter04_working_pydantic_objects_02.py
person_include = person.dict(include={"first_name", "last_name"})print(person_include) # {"first_name": "John", "last_name": "Doe"}
person_exclude = person.dict(exclude={"birthdate", "interests"})
print(person_exclude)
include 和 exclude 参数期望一个集合,集合中包含你希望包含或排除的字段的键。
对于像 address 这样的嵌套结构,你也可以使用字典来指定要包含或排除的子字段,以下示例演示了这一点:
chapter04_working_pydantic_objects_02.py
person_nested_include = person.dict( include={
"first_name": ...,
"last_name": ...,
"address": {"city", "country"},
}
)
# {"first_name": "John", "last_name": "Doe", "address": {"city": "Woodtown", "country": "US"}}
print(person_nested_include)
结果的 address 字典仅包含城市和国家。请注意,当使用这种语法时,像 first_name 和 last_name 这样的标量字段必须与省略号 ... 一起使用。
如果你经常进行某种转换,将其放入一个方法中以便于随时重用是很有用的,以下示例演示了这一点:
chapter04_working_pydantic_objects_03.py
class Person(BaseModel): first_name: str
last_name: str
gender: Gender
birthdate: date
interests: list[str]
address: Address
def name_dict(self):
return self.dict(include={"first_name", "last_name"})
从子类对象创建实例
在 通过类继承创建模型变体 这一节中,我们研究了根据具体情况创建特定模型类的常见模式。特别地,你会有一个专门用于创建端点的模型,其中只有创建所需的字段,以及一个包含我们想要存储的所有字段的数据库模型。
让我们再看一下 Post 示例:
chapter04_working_pydantic_objects_04.py
class PostBase(BaseModel): title: str
content: str
class PostCreate(PostBase):
pass
class PostRead(PostBase):
id: int
class Post(PostBase):
id: int
nb_views: int = 0
假设我们有一个创建端点的 API。在这种情况下,我们会得到一个只有 title 和 content 的 PostCreate 实例。然而,在将其存储到数据库之前,我们需要构建一个适当的 Post 实例。
一种方便的做法是同时使用 dict 方法和解包语法。在以下示例中,我们使用这种方法实现了一个创建端点:
chapter04_working_pydantic_objects_04.py
@app.post("/posts", status_code=status.HTTP_201_CREATED, response_model=PostRead)async def create(post_create: PostCreate):
new_id = max(db.posts.keys() or (0,)) + 1
post = Post(id=new_id, **post_create.dict())
db.posts[new_id] = post
return post
如你所见,路径操作函数为我们提供了一个有效的PostCreate对象。然后,我们想将其转换为Post对象。
我们首先确定缺失的id属性,这是由数据库提供的。在这里,我们使用基于字典的虚拟数据库,因此我们只需取数据库中已存在的最大键并将其递增。在实际情况下,这个值会由数据库自动确定。
这里最有趣的一行是Post实例化。你可以看到,我们首先使用关键字参数分配缺失的字段,然后解包post_create的字典表示。提醒一下,**在函数调用中的作用是将像{"title": "Foo", "content": "Bar"}这样的字典转换为像title="Foo", content="Bar"这样的关键字参数。这是一种非常方便和动态的方式,将我们已有的所有字段设置到新的模型中。
请注意,我们还在路径操作装饰器中设置了response_model参数。我们在第三章,使用 FastAPI 开发 RESTful API中解释了这一点,但基本上,它提示 FastAPI 构建一个只包含PostRead字段的 JSON 响应,即使我们最终返回的是一个Post实例。
部分更新实例
在某些情况下,你可能需要允许部分更新。换句话说,你允许最终用户仅向你的 API 发送他们想要更改的字段,并省略不需要更改的字段。这是实现PATCH端点的常见方式。
为此,你首先需要一个特殊的 Pydantic 模型,所有字段都标记为可选,这样在缺少某个字段时不会引发错误。让我们看看在我们的Post示例中这是什么样的:
chapter04_working_pydantic_objects_05.py
class PostBase(BaseModel): title: str
content: str
class PostPartialUpdate(BaseModel):
title: str | None = None
content: str | None = None
我们现在能够实现一个端点,接受Post字段的子集。由于这是一个更新操作,我们将通过其 ID 从数据库中检索现有的帖子。然后,我们需要找到一种方法,只更新负载中的字段,保持其他字段不变。幸运的是,Pydantic 再次提供了便捷的方法和选项来解决这个问题。
让我们看看如何在以下示例中实现这样的端点:
chapter04_working_pydantic_objects_05.py
@app.patch("/posts/{id}", response_model=PostRead)async def partial_update(id: int, post_update: PostPartialUpdate):
try:
post_db = db.posts[id]
updated_fields = post_update.dict(exclude_unset=True)
updated_post = post_db.copy(update=updated_fields)
db.posts[id] = updated_post
return updated_post
except KeyError:
raise HTTPException(status.HTTP_404_NOT_FOUND)
我们的路径操作函数接受两个参数:id属性(来自路径)和PostPartialUpdate实例(来自请求体)。
首先要做的是检查这个id属性是否存在于数据库中。由于我们使用字典作为虚拟数据库,访问一个不存在的键会引发KeyError。如果发生这种情况,我们只需抛出一个HTTPException并返回404状态码。
现在是有趣的部分:更新现有对象。你可以看到,首先要做的是使用dict方法将PostPartialUpdate转换为字典。然而,这次我们将exclude_unset参数设置为True。这样做的效果是,Pydantic 不会在结果字典中输出未提供的字段:我们只会得到用户在有效负载中发送的字段。
然后,在我们现有的post_db数据库实例上,调用copy方法。这个方法是克隆 Pydantic 对象到另一个实例的一个有用方法。这个方法的好处在于它甚至接受一个update参数。这个参数期望一个字典,包含所有在复制过程中应该更新的字段:这正是我们想用updated_fields字典来做的!
就这样!我们现在有了一个更新过的post实例,只有在有效负载中需要的更改。你在使用 FastAPI 开发时,可能会经常使用exclude_unset参数和copy方法,所以一定要记住它们——它们会让你的工作更轻松!
总结
恭喜你!你已经学习了 FastAPI 的另一个重要方面:使用 Pydantic 设计和管理数据模型。现在,你应该对创建模型、应用字段级验证、使用内建选项和类型,以及实现你自己的验证方法有信心。你还了解了如何在对象级别应用验证,检查多个字段之间的一致性。你还学会了如何利用模型继承来避免在定义模型变体时出现代码重复。最后,你学会了如何正确处理 Pydantic 模型实例,从而以高效且可读的方式进行转换和更新。
到现在为止,你几乎已经掌握了 FastAPI 的所有功能。现在有一个最后非常强大的功能等着你去学习:依赖注入。这允许你定义自己的逻辑和数值,并将它们直接注入到路径操作函数中,就像你对路径参数和有效负载对象所做的那样,你可以在项目的任何地方重用它们。这是下一章的内容。
第五章:FastAPI 中的依赖注入
在本章中,我们将重点讨论 FastAPI 最有趣的部分之一:依赖注入。你将会看到,它是一种强大且易于阅读的方式,用于在项目中重用逻辑。事实上,它将允许你为项目创建复杂的构建模块,这些模块可以在整个逻辑中重复使用。认证系统、查询参数验证器或速率限制器是依赖项的典型用例。在 FastAPI 中,依赖注入甚至可以递归调用另一个依赖项,从而允许你从基础功能构建高级模块。到本章结束时,你将能够为 FastAPI 创建自己的依赖项,并在项目的多个层次上使用它们。
在本章中,我们将涵盖以下主要主题:
-
什么是依赖注入?
-
创建和使用函数依赖项
-
使用类创建和使用带参数的依赖项
-
在路径、路由器和全局级别使用依赖项
技术要求
要运行代码示例,你需要一个 Python 虚拟环境,我们在第一章中设置了该环境,Python 开发 环境设置。
你可以在专门的 GitHub 仓库中找到本章的所有代码示例:github.com/PacktPublishing/Building-Data-Science-Applications-with-FastAPI-Second-Edition/tree/main/chapter05。
什么是依赖注入?
一般来说,依赖注入是一种自动实例化对象及其依赖项的系统。开发者的责任是仅提供对象创建的声明,让系统在运行时解析所有的依赖链并创建实际的对象。
FastAPI 允许你通过在路径操作函数的参数中声明它们,仅声明你希望使用的对象和变量。事实上,我们在前几章中已经使用了依赖注入。在以下示例中,我们使用 Header 函数来检索 user-agent 头信息:
chapter05_what_is_dependency_injection_01.py
from fastapi import FastAPI, Headerapp = FastAPI()
@app.get("/")
async def header(user_agent: str = Header(...)):
return {"user_agent": user_agent}
内部来说,Header 函数具有一些逻辑,可以自动获取 request 对象,检查是否存在所需的头信息,返回其值,或者在不存在时抛出错误。然而,从开发者的角度来看,我们并不知道它是如何处理所需的对象的:我们只需要获取我们所需的值。这就是 依赖注入。
诚然,你可以通过在 request 对象的 headers 字典中选取 user-agent 属性来在函数体中轻松地重现这个示例。然而,依赖注入方法相比之下有许多优势:
-
意图明确:你可以在不阅读函数代码的情况下,知道端点在请求数据中期望什么。
-
你有一个明确的关注点分离:端点的逻辑和更通用的逻辑之间的头部检索及其关联的错误处理不会污染其他逻辑;它在依赖函数中自包含。此外,它可以轻松地在其他端点中重用。
-
在 FastAPI 中,它被用来生成 OpenAPI 架构,以便自动生成的文档可以清楚地显示此端点所需的参数。
换句话说,每当你需要一些工具逻辑来检索或验证数据、进行安全检查,或调用你在应用中多次需要的外部逻辑时,依赖是一个理想的选择。
FastAPI 很大程度上依赖于这个依赖注入系统,并鼓励开发者使用它来实现他们的构建模块。如果你来自其他 Web 框架,比如 Flask 或 Express,可能会有些困惑,但你肯定会很快被它的强大和相关性所说服。
为了说服你,我们现在将看到如何创建和使用你自己的依赖,首先从函数形式开始。
创建并使用一个函数依赖
在 FastAPI 中,依赖可以被定义为一个函数或一个可调用的类。在本节中,我们将重点关注函数,因为它们是你最可能经常使用的。
正如我们所说,依赖是将一些逻辑封装起来的方式,这些逻辑会获取一些子值或子对象,处理它们,并最终返回一个将被注入到调用端点中的值。
让我们来看第一个示例,我们定义一个函数依赖来获取分页查询参数 skip 和 limit:
chapter05_function_dependency_01.py
async def pagination(skip: int = 0, limit: int = 10) -> tuple[int, int]: return (skip, limit)
@app.get("/items")
async def list_items(p: tuple[int, int] = Depends(pagination)):
skip, limit = p
return {"skip": skip, "limit": limit}
这个示例有两个部分:
- 首先,我们有依赖定义,带有
pagination函数。你会看到我们定义了两个参数,skip和limit,它们是具有默认值的整数。这些将是我们端点的查询参数。我们定义它们的方式与在路径操作函数中定义的方式完全相同。这就是这种方法的美妙之处:FastAPI 会递归地处理依赖中的参数,并根据需要与请求数据(如查询参数或头部)进行匹配。
我们只需将这些值作为一个元组返回。
- 第二,我们有路径操作函数
list_items,它使用了pagination依赖。你可以看到,使用方法与我们为头部或正文值所做的非常相似:我们定义了结果参数的名称,并使用函数结果作为默认值。对于依赖,我们使用Depends函数。它的作用是将函数作为参数传递,并在调用端点时执行它。子依赖会被自动发现并执行。
在该端点中,我们将分页直接作为一个元组返回。
让我们使用以下命令运行这个示例:
$ uvicorn chapter05_function_dependency_01:app
现在,我们将尝试调用 /items 端点,看看它是否能够获取查询参数。你可以使用以下 HTTPie 命令来尝试:
$ http "http://localhost:8000/items?limit=5&skip=10"HTTP/1.1 200 OK
content-length: 21
content-type: application/json
date: Tue, 15 Nov 2022 08:33:46 GMT
server: uvicorn
{
"limit": 5,
"skip": 10
}
limit 和 skip 查询参数已经通过我们的函数依赖正确地获取。你也可以尝试不带查询参数调用该端点,并注意它会返回默认值。
依赖返回值的类型提示
你可能已经注意到,我们在路径操作的参数中必须对依赖的结果进行类型提示,即使我们已经为依赖函数本身进行了类型提示。不幸的是,这是 FastAPI 及其 Depends 函数的一个限制,Depends 函数无法传递依赖函数的类型。因此,我们必须手动对结果进行类型提示,就像我们在这里所做的那样。
就这样!如你所见,在 FastAPI 中创建和使用依赖非常简单直接。当然,你现在可以在多个端点中随意重用它,正如你在其余示例中所看到的那样。
chapter05_function_dependency_01.py
@app.get("/things")async def list_things(p: tuple[int, int] = Depends(pagination)):
skip, limit = p
return {"skip": skip, "limit": limit}
在这些依赖中,我们可以做更复杂的事情,就像在常规路径操作函数中一样。在以下示例中,我们为这些分页参数添加了一些验证,并将 limit 限制为 100:
chapter05_function_dependency_02.py
async def pagination( skip: int = Query(0, ge=0),
limit: int = Query(10, ge=0),
) -> tuple[int, int]:
capped_limit = min(100, limit)
return (skip, capped_limit)
如你所见,我们的依赖开始变得更加复杂:
-
我们在参数中添加了
Query函数来增加验证约束;现在,如果skip或limit是负整数,系统将抛出422错误。 -
我们确保
limit最多为100。
我们的路径操作函数中的代码不需要修改;我们清楚地将端点的逻辑与分页参数的更通用逻辑分开。
让我们来看另一个典型的依赖项使用场景:获取一个对象或引发404错误。
获取对象或引发 404 错误
在 REST API 中,你通常会有端点用于根据路径中的标识符获取、更新和删除单个对象。在每个端点中,你很可能会有相同的逻辑:尝试从数据库中检索这个对象,或者如果它不存在,就引发一个404错误。这是一个非常适合使用依赖项的场景!在以下示例中,你将看到如何实现它:
chapter05_function_dependency_03.py
async def get_post_or_404(id: int) -> Post: try:
return db.posts[id]
except KeyError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
依赖项的定义很简单:它接受一个参数,即我们想要获取的帖子的 ID。它将从相应的路径参数中提取。然后,我们检查它是否存在于我们的虚拟字典数据库中:如果存在,我们返回它;否则,我们会引发一个404状态码的 HTTP 异常。
这是这个示例的关键要点:你可以在依赖项中引发错误。在执行端点逻辑之前,检查某些前置条件是非常有用的。另一个典型的例子是身份验证:如果端点需要用户认证,我们可以通过检查令牌或 cookie,在依赖项中引发401错误。
现在,我们可以在每个 API 端点中使用这个依赖项,如下例所示:
chapter05_function_dependency_03.py
@app.get("/posts/{id}")async def get(post: Post = Depends(get_post_or_404)):
return post
@app.patch("/posts/{id}")
async def update(post_update: PostUpdate, post: Post = Depends(get_post_or_404)):
updated_post = post.copy(update=post_update.dict())
db.posts[post.id] = updated_post
return updated_post
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete(post: Post = Depends(get_post_or_404)):
db.posts.pop(post.id)
如你所见,我们只需要定义post参数并在get_post_or_404依赖项上使用Depends函数。然后,在路径操作逻辑中,我们可以确保手头有post对象,我们可以集中处理核心逻辑,现在变得非常简洁。例如,get端点只需要返回该对象。
在这种情况下,唯一需要注意的是不要忘记这些端点路径中的 ID 参数。根据 FastAPI 的规则,如果你在路径中没有设置这个参数,它会自动被视为查询参数,这不是我们想要的。你可以在第三章,使用 FastAPI 开发 RESTful API的路径参数部分找到更多详细信息。
这就是函数依赖项的全部内容。正如我们所说,它们是 FastAPI 项目的主要构建块。然而,在某些情况下,你可能需要在这些依赖项中设置一些参数——例如,来自环境变量的值。为此,我们可以定义类依赖项。
使用类创建和使用带参数的依赖项
在前一部分中,我们将依赖项定义为常规函数,这在大多数情况下效果良好。然而,你可能需要为依赖项设置一些参数,以便精细调整其行为。由于函数的参数是由依赖注入系统设置的,我们无法向函数添加额外的参数。
在分页示例中,我们添加了一些逻辑将限制值设定为 100。如果我们想要动态设置这个最大限制值,该如何操作呢?
解决方案是创建一个作为依赖项使用的类。这样,我们可以通过 __init__ 方法等设置类属性,并在依赖项的逻辑中使用它们。这些逻辑将会在类的 __call__ 方法中定义。如果你还记得我们在第二章的可调用对象部分中学到的内容,你会知道它使对象可调用,也就是说,它可以像常规函数一样被调用。事实上,这就是 Depends 对依赖项的所有要求:可调用。我们将利用这一特性,通过类来创建一个带参数的依赖项。
在下面的示例中,我们使用类重新实现了分页示例,这使得我们可以动态设置最大限制:
chapter05_class_dependency_01.py
class Pagination: def __init__(self, maximum_limit: int = 100):
self.maximum_limit = maximum_limit
async def __call__(
self,
skip: int = Query(0, ge=0),
limit: int = Query(10, ge=0),
) -> tuple[int, int]:
capped_limit = min(self.maximum_limit, limit)
return (skip, capped_limit)
正如你所看到的,__call__ 方法中的逻辑与我们在前一个示例中定义的函数相同。唯一的区别是,我们可以从类的属性中获取最大限制值,这些属性可以在对象初始化时设置。
然后,你可以简单地创建该类的实例,并在路径操作函数中使用 Depends 作为依赖项,就像你在以下代码块中看到的那样:
chapter05_class_dependency_01.py
pagination = Pagination(maximum_limit=50)@app.get("/items")
async def list_items(p: tuple[int, int] = Depends(pagination)):
skip, limit = p
return {"skip": skip, "limit": limit}
在这里,我们硬编码了 50 的值,但我们完全可以从配置文件或环境变量中获取这个值。
类依赖的另一个优点是它可以在内存中保持局部值。如果我们需要进行一些繁重的初始化逻辑,例如加载一个机器学习模型,我们希望在启动时只做一次。然后,可调用的部分只需调用已加载的模型来进行预测,这应该是非常快速的。
使用类方法作为依赖项
即使__call__方法是实现类依赖的最直接方式,你也可以直接将方法传递给Depends。实际上,正如我们所说,它只需要一个可调用对象作为参数,而类方法是一个完全有效的可调用对象!
如果你有一些公共参数或逻辑需要在稍微不同的情况下重用,这种方法非常有用。例如,你可以有一个使用 scikit-learn 训练的预训练机器学习模型。在应用决策函数之前,你可能想根据输入数据应用不同的预处理步骤。
要做到这一点,只需将你的逻辑写入一个类方法,并通过点符号将其传递给Depends函数。
你可以在以下示例中看到这一点,我们为分页依赖项实现了另一种样式,使用page和size参数,而不是skip和limit:
chapter05_class_dependency_02.py
class Pagination: def __init__(self, maximum_limit: int = 100):
self.maximum_limit = maximum_limit
async def skip_limit(
self,
skip: int = Query(0, ge=0),
limit: int = Query(10, ge=0),
) -> tuple[int, int]:
capped_limit = min(self.maximum_limit, limit)
return (skip, capped_limit)
async def page_size(
self,
page: int = Query(1, ge=1),
size: int = Query(10, ge=0),
) -> tuple[int, int]:
capped_size = min(self.maximum_limit, size)
return (page, capped_size)
这两种方法的逻辑非常相似。我们只是在查看不同的查询参数。然后,在我们的路径操作函数中,我们将/items端点设置为使用skip/limit样式,而/things端点将使用page/size样式:
chapter05_class_dependency_02.py
pagination = Pagination(maximum_limit=50)@app.get("/items")
async def list_items(p: tuple[int, int] = Depends(pagination.skip_limit)):
skip, limit = p
return {"skip": skip, "limit": limit}
@app.get("/things")
async def list_things(p: tuple[int, int] = Depends(pagination.page_size)):
page, size = p
return {"page": page, "size": size}
正如你所看到的,我们只需通过点符号将所需的方法传递给pagination对象。
总结来说,类依赖方法比函数依赖方法更为先进,但在需要动态设置参数、执行繁重的初始化逻辑或在多个依赖项之间重用公共逻辑时非常有用。
到目前为止,我们假设我们关心依赖项的返回值。虽然大多数情况下确实如此,但你可能偶尔需要调用依赖项以检查某些条件,但并不需要返回值。FastAPI 允许这种用例,接下来我们将看到这个功能。
在路径、路由器和全局级别使用依赖项
如我们所说,依赖项是创建 FastAPI 项目构建模块的推荐方式,它允许你在多个端点间重用逻辑,同时保持代码的最大可读性。到目前为止,我们已将依赖项应用于单个端点,但我们能否将这种方法扩展到整个路由器?甚至是整个 FastAPI 应用程序?事实上,我们可以!
这样做的主要动机是能够在多个路由上应用一些全局请求验证或执行副作用逻辑,而无需在每个端点上都添加依赖项。通常,身份验证方法或速率限制器可能是这个用例的很好的候选者。
为了向你展示它是如何工作的,我们将实现一个简单的依赖项,并在以下所有示例中使用它。你可以在以下示例中看到它:
chapter05_path_dependency_01.py
def secret_header(secret_header: str | None = Header(None)) -> None: if not secret_header or secret_header != "SECRET_VALUE":
raise HTTPException(status.HTTP_403_FORBIDDEN)
这个依赖项将简单地查找请求中名为 Secret-Header 的头部。如果它缺失或不等于 SECRET_VALUE,它将引发 403 错误。请注意,这种方法仅用于示例;有更好的方式来保护你的 API,我们将在第七章中讨论,在 FastAPI 中管理身份验证和安全性。
在路径装饰器上使用依赖项
直到现在,我们一直假设我们总是对依赖项的返回值感兴趣。正如我们的 secret_header 依赖项在这里清楚地显示的那样,这并非总是如此。这就是为什么你可以将依赖项添加到路径操作装饰器上,而不是传递参数。你可以在以下示例中看到如何操作:
chapter05_path_dependency_01.py
@app.get("/protected-route", dependencies=[Depends(secret_header)])async def protected_route():
return {"hello": "world"}
路径操作装饰器接受一个参数 dependencies,该参数期望一个依赖项列表。你会发现,就像为依赖项传递参数一样,你需要用 Depends 函数包装你的函数(或可调用对象)。
现在,每当调用 /protected-route 路由时,依赖项将被调用并检查所需的头部信息。
如你所料,由于 dependencies 是一个列表,你可以根据需要添加任意数量的依赖项。
这很有趣,但如果我们想保护一整组端点呢?手动为每个端点添加可能会有点繁琐且容易出错。幸运的是,FastAPI 提供了一种方法来实现这一点。
在整个路由器上使用依赖项
如果你记得在 第三章中的 使用多个路由器结构化一个更大的项目 部分,使用 FastAPI 开发 RESTful API,你就知道你可以在项目中创建多个路由器,以清晰地拆分 API 的不同部分,并将它们“连接”到你的主 FastAPI 应用程序。这是通过 APIRouter 类和 FastAPI 类的 include_router 方法来完成的。
使用这种方法,将一个依赖项注入整个路由器是很有趣的,这样它会在该路由器的每个路由上被调用。你有两种方法可以做到这一点:
- 在
APIRouter类上设置dependencies参数,正如以下示例所示:
chapter05_router_dependency_01.py
router = APIRouter(dependencies=[Depends(secret_header)])@router.get("/route1")
async def router_route1():
return {"route": "route1"}
@router.get("/route2")
async def router_route2():
return {"route": "route2"}
app = FastAPI()
app.include_router(router, prefix="/router")
- 在
include_router方法上设置dependencies参数,正如以下示例所示:
chapter05_router_dependency_02.py
router = APIRouter()@router.get("/route1")
async def router_route1():
return {"route": "route1"}
@router.get("/route2")
async def router_route2():
return {"route": "route2"}
app = FastAPI()
app.include_router(router, prefix="/router", dependencies=[Depends(secret_header)])
在这两种情况下,dependencies 参数都期望一个依赖项的列表。你可以看到,就像传递依赖项作为参数一样,你需要用 Depends 函数将你的函数(或可调用对象)包装起来。当然,由于它是一个列表,如果需要,你可以添加多个依赖项。
现在,如何选择这两种方法呢?在这两种情况下,效果完全相同,所以我们可以说其实并不重要。从哲学角度来看,我们可以说,如果依赖项在这个路由器的上下文中是必要的,我们应该在 APIRouter 类上声明依赖项。换句话说,我们可以问自己这个问题,如果我们独立运行这个路由器,是否没有这个依赖项就无法工作?如果这个问题的答案是否,那么你可能应该在 APIRouter 类上设置依赖项。否则,在 include_router 方法中声明它可能更有意义。但再说一次,这只是一个思想选择,它不会改变你 API 的功能,因此你可以选择你更舒适的方式。
现在,我们能够为整个路由器设置依赖项。在某些情况下,为整个应用程序声明依赖项也可能很有趣!
在整个应用程序中使用依赖项
如果你有一个实现了某些日志记录或限流功能的依赖项,例如,将其应用到你 API 的每个端点可能会很有意义。幸运的是,FastAPI 允许这样做,正如以下示例所示:
chapter05_global_dependency_01.py
app = FastAPI(dependencies=[Depends(secret_header)])@app.get("/route1")
async def route1():
return {"route": "route1"}
@app.get("/route2")
async def route2():
return {"route": "route2"}
再次强调,你只需直接在主 FastAPI 类上设置 dependencies 参数。现在,依赖项应用于你 API 中的每个端点!
在 图 5*.1* 中,我们提出了一个简单的决策树,用于确定你应该在哪个级别注入依赖项:
图 5.1 – 我应该在哪个级别注入我的依赖项?
摘要
恭喜!现在你应该已经熟悉了 FastAPI 最具标志性的特性之一:依赖注入。通过实现自己的依赖项,你可以将希望在整个 API 中重用的常见逻辑与端点的逻辑分开。这样可以使你的项目清晰可维护,同时保持最大的可读性;只需将依赖项声明为路径操作函数的参数即可,这将帮助你理解意图,而无需阅读函数体。
这些依赖项可以是简单的包装器,用于检索和验证请求参数,也可以是执行机器学习任务的复杂服务。多亏了基于类的方法,你确实可以设置动态参数或为最复杂的任务保持局部状态。
最后,这些依赖项还可以在路由器或全局级别上使用,允许你对一组路由或整个应用程序执行常见逻辑或检查。
这就是本书第一部分的结束!你现在已经熟悉了 FastAPI 的主要特性,并且应该能够使用这个框架编写干净且高性能的 REST API。
在下一部分中,我们将带你的知识提升到新的高度,并展示如何实现和部署一个强大、安全且经过测试的 Web 后端。第一章将专注于数据库,大多数 API 都需要能够读取和写入数据。
第二部分:使用 FastAPI 构建和部署完整的 Web 后端
本节的目标是向你展示如何使用 FastAPI 构建一个真实世界的后端,该后端能够读取和写入数据,进行用户认证,并且经过充分测试,且为生产环境正确配置。
本节包括以下章节:
-
第六章,数据库和异步 ORM
-
第七章,在 FastAPI 中管理认证和安全性
-
第八章,在 FastAPI 中定义 WebSockets 实现双向互动通信
-
第九章,使用 pytest 和 HTTPX 异步测试 API
-
第十章,部署 FastAPI 项目
第六章:数据库和异步 ORM
REST API 的主要目标当然是读写数据。到目前为止,我们只使用了 Python 和 FastAPI 提供的工具,允许我们构建可靠的端点来处理和响应请求。然而,我们尚未能够有效地检索和持久化这些信息:我们还没有 数据库。
本章的目标是展示你如何在 FastAPI 中与不同类型的数据库及相关库进行交互。值得注意的是,FastAPI 对数据库是完全无关的:你可以使用任何你想要的系统,并且集成工作由你负责。这就是为什么我们将回顾两种不同的数据库集成方式:使用 对象关系映射(ORM)系统连接 SQL 数据库,以及使用 NoSQL 数据库。
本章我们将讨论以下主要主题:
-
关系型数据库和 NoSQL 数据库概述
-
使用 SQLAlchemy ORM 与 SQL 数据库进行通信
-
使用 Motor 与 MongoDB 数据库进行通信
技术要求
对于本章,你将需要一个 Python 虚拟环境,正如我们在 第一章 中设置的,Python 开发 环境设置。
对于 使用 Motor 与 MongoDB 数据库进行通信 部分,你需要在本地计算机上运行 MongoDB 服务器。最简单的方法是将其作为 Docker 容器运行。如果你以前从未使用过 Docker,我们建议你参考官方文档中的 入门教程,链接为 docs.docker.com/get-started/。完成这些步骤后,你将能够使用以下简单命令运行 MongoDB 服务器:
$ docker run -d --name fastapi-mongo -p 27017:27017 mongo:6.0
MongoDB 服务器实例将通过端口 27017 在你的本地计算机上提供。
你可以在本书专门的 GitHub 仓库中找到本章的所有代码示例,地址为 github.com/PacktPublishing/Building-Data-Science-Applications-with-FastAPI-Second-Edition/tree/main/chapter06。
关系型数据库和 NoSQL 数据库概述
数据库的作用是以结构化的方式存储数据,保持数据的完整性,并提供查询语言,使你在应用程序需要时能够检索这些数据。
如今,选择适合你网站项目的数据库时,你有两个主要选择:关系型数据库,及其相关的 SQL 查询语言,和 NoSQL 数据库,它们与第一类数据库相对立。
选择适合你项目的技术由你来决定,因为这在很大程度上取决于你的需求和要求。在本节中,我们将概述这两类数据库的主要特点和功能,并尝试为你提供一些选择适合项目的数据库的见解。
关系型数据库
关系型数据库自 1970 年代以来就存在,并且随着时间的推移证明了它们的高效性和可靠性。它们几乎与 SQL 不可分离,SQL 已成为查询此类数据库的事实标准。即使不同数据库引擎之间有一些差异,大多数语法是通用的,简单易懂,足够灵活,可以表达复杂的查询。
关系型数据库实现了关系模型:应用的每个实体或对象都存储在表中。例如,如果我们考虑一个博客应用,我们可以有表示用户、帖子和评论的表。
每个表都会有多个列,表示实体的属性。如果我们考虑帖子,可能会有一个标题、发布日期和内容。在这些表中,会有多行,每行表示这种类型的一个实体;每篇帖子将有自己的行。
关系型数据库的一个关键点是,如其名称所示,关系。每个表可以与其他表建立关系,表中的行可以引用其他表中的行。在我们的示例中,一篇帖子可以与写它的用户相关联。类似地,一条评论可以与其相关的帖子关联。
这样做的主要动机是避免重复。事实上,如果我们在每篇帖子上都重复用户的姓名或邮箱,这并不是很高效。如果需要修改某个信息,我们就得通过每篇帖子修改,这容易出错并危及数据一致性。因此,我们更倾向于在帖子中引用用户。那么,我们该如何实现这一点呢?
通常,关系型数据库中的每一行都有一个标识符,称为主键。这个键在表中是唯一的,允许你唯一标识这一行。因此,可以在另一个表中使用这个键来引用它。我们称之为外键:外键之所以叫做外,是因为它引用了另一个表。
图 6.1展示了使用实体-关系图表示这种数据库模式的方式。请注意,每个表都有自己的主键,名为id。Post表通过user_id外键引用一个用户。类似地,Comment表通过user_id和post_id外键分别引用一个用户和一篇帖子:
图 6.1 – 博客应用的关系型数据库模式示例
在一个应用中,你可能希望检索一篇帖子,以及与之相关的评论和用户。为了实现这一点,我们可以执行一个连接查询,根据外键返回所有相关记录。关系型数据库旨在高效地执行此类任务;然而,如果模式更加复杂,这些操作可能会变得昂贵。这就是为什么在设计关系型模式及其查询时需要小心谨慎的原因。
NoSQL 数据库
所有非关系型的数据库引擎都属于 NoSQL 范畴。这是一个相当模糊的术语,涵盖了不同类型的数据库:键值存储,例如 Redis;图数据库,例如 Neo4j;以及面向文档的数据库,例如 MongoDB。也就是说,当我们谈论“NoSQL 数据库”时,通常是指面向文档的数据库。它们是我们关注的对象。
面向文档的数据库摒弃了关系型架构,试图将给定对象的所有信息存储在一个文档中。因此,执行联接查询的情况非常少见,通常也更为困难。
这些文档存储在集合中。与关系型数据库不同,集合中的文档可能没有相同的属性:关系型数据库中的表有定义好的模式,而集合可以接受任何类型的文档。
图 6.2 显示了我们之前博客示例的表示,已经调整为面向文档的数据库结构。在这种配置中,我们选择了一个集合用于用户,另一个集合用于帖子。然而,请注意,评论现在是帖子的组成部分,直接作为一个列表包含在内:
图 6.2 — 博客应用的面向文档的架构示例
要检索一篇帖子及其所有评论,你不需要执行联接查询:所有数据只需一个查询即可获取。这是开发面向文档数据库的主要动机:通过减少查看多个集合的需求来提高查询性能。特别是,它们在处理具有巨大数据规模和较少结构化数据的应用(如社交网络)时表现出了极大的价值。
你应该选择哪一个?
正如我们在本节引言中提到的,你选择数据库引擎很大程度上取决于你的应用和需求。关系型数据库和面向文档的数据库之间的详细比较超出了本书的范围,但我们可以看一下你需要考虑的一些要素。
关系型数据库非常适合存储结构化数据,且实体之间存在大量关系。此外,它们在任何情况下都会维护数据的一致性,即使在发生错误或硬件故障时也不例外。然而,你必须精确定义模式,并考虑迁移系统,以便在需求变化时更新你的模式。
另一方面,面向文档的数据库不需要你定义模式:它们接受任何文档结构,因此如果你的数据高度可变或你的项目尚未成熟,它们会很方便。其缺点是,它们在数据一致性方面要求较低,可能导致数据丢失或不一致。
对于小型和中型应用程序,选择并不太重要:关系型数据库和面向文档的数据库都经过了高度优化,在这些规模下都会提供出色的性能。
接下来,我们将展示如何使用 FastAPI 处理这些不同类型的数据库。当我们在第二章中介绍异步 I/O 时,Python 编程特性,我们提到过选择你用来执行 I/O 操作的库是很重要的。当然,在这种情况下,数据库尤为重要!
尽管在 FastAPI 中使用经典的非异步库是完全可行的,但你可能会错过框架的一个关键方面,无法达到它所能提供的最佳性能。因此,在本章中,我们将只专注于异步库。
使用 SQLAlchemy ORM 与 SQL 数据库进行通信
首先,我们将讨论如何使用 SQLAlchemy 库处理关系型数据库。SQLAlchemy 已经存在多年,并且是 Python 中处理 SQL 数据库时最受欢迎的库。从版本 1.4 开始,它也原生支持异步。
理解这个库的关键点是,它由两个部分组成:
-
SQLAlchemy Core,提供了读取和写入 SQL 数据库数据的所有基本功能
-
SQLAlchemy ORM,提供对 SQL 概念的强大抽象
虽然你可以选择只使用 SQLAlchemy Core,但通常使用 ORM 更为方便。ORM 的目标是抽象出表和列的 SQL 概念,这样你只需要处理 Python 对象。ORM 的作用是将这些对象映射到它们所属的表和列,并自动生成相应的 SQL 查询。
第一步是安装这个库:
(venv) $ pip install "sqlalchemy[asyncio,mypy]"
请注意,我们添加了两个可选依赖项:asyncio和mypy。第一个确保安装了异步支持所需的工具。
第二个是一个为 mypy 提供特殊支持的插件,专门用于 SQLAlchemy。ORM 在后台做了很多“魔法”事情,这些对于类型检查器来说很难理解。有了这个插件,mypy 能够学会识别这些构造。
正如我们在介绍中所说,存在许多 SQL 引擎。你可能听说过 PostgreSQL 和 MySQL,它们是最受欢迎的引擎之一。另一个有趣的选择是 SQLite,它是一个小型引擎,所有数据都存储在你电脑上的单个文件中,不需要复杂的服务器软件。它非常适合用于测试和实验。为了让 SQLAlchemy 能够与这些引擎进行通信,你需要安装相应的驱动程序。根据你的引擎,这里是你需要安装的异步驱动程序:
-
PostgreSQL:
(venv) $ pip install asyncpg -
MySQL:
(venv) $ pip install aiomysql -
SQLite:
(venv) $ pip install aiosqlite
在本节的其余部分,我们将使用 SQLite 数据库。我们将一步步展示如何设置完整的数据库交互。图 6.4展示了项目的结构:
图 6.3 – FastAPI 和 SQLAlchemy 项目结构
创建 ORM 模型
首先,您需要定义您的 ORM 模型。每个模型是一个 Python 类,其属性代表表中的列。数据库中的实际实体将是该类的实例,您可以像访问任何其他对象一样访问其数据。在幕后,SQLAlchemy ORM 的作用是将 Python 对象与数据库中的行链接起来。让我们来看一下我们博客文章模型的定义:
models.py
from datetime import datetimefrom sqlalchemy import DateTime, Integer, String, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
pass
class Post(Base):
__tablename__ = "posts"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
publication_date: Mapped[datetime] = mapped_column(
DateTime, nullable=False, default=datetime.now
)
title: Mapped[str] = mapped_column(String(255), nullable=False)
content: Mapped[str] = mapped_column(Text, nullable=False)
第一步是创建一个继承自 DeclarativeBase 的 Base 类。我们所有的模型都将继承自这个类。在内部,SQLAlchemy 使用它来将所有有关数据库模式的信息集中在一起。这就是为什么在整个项目中只需要创建一次,并始终使用相同的 Base 类。
接下来,我们必须定义我们的 Post 类。再次注意,它是如何从 Base 类继承的。在这个类中,我们可以以类属性的形式定义每一列。它们是通过 mapped_column 函数来赋值的,这个函数帮助我们定义列的类型及其相关属性。例如,我们将 id 列定义为一个自增的整数主键,这在 SQL 数据库中非常常见。
请注意,我们不会详细介绍 SQLAlchemy 提供的所有类型和选项。只需知道它们与 SQL 数据库通常提供的类型非常相似。您可以在官方文档中查看完整的列表,如下所示:
-
您可以在
docs.sqlalchemy.org/en/20/core/type_basics.html#generic-camelcase-types找到类型的列表。 -
您可以在
docs.sqlalchemy.org/en/20/orm/mapping_api.html#sqlalchemy.orm.mapped_column找到mapped_column参数的列表。
这里另一个值得注意的有趣点是,我们为每个属性添加了类型提示,这些类型与我们列的 Python 类型对应。这将极大地帮助我们在开发过程中:例如,如果我们尝试获取帖子对象的 title 属性,类型检查器会知道它是一个字符串。为了使这一点生效,请注意,我们将每个类型都包裹在 Mapped 类中。这是 SQLAlchemy 提供的一个特殊类,类型检查器可以通过它了解数据的底层类型,当我们将一个 MappedColumn 对象分配给它时。
这是在 SQLAlchemy 2.0 中声明模型的方式
我们将在本节中展示的声明模型的方式是 SQLAlchemy 2.0 中引入的最新方式。
如果你查看网上较老的教程或文档,你可能会看到一种略有不同的方法,其中我们将属性分配给Column对象。虽然这种旧风格在 SQLAlchemy 2.0 中仍然有效,但它应该被视为过时的。
现在我们有了一个帮助我们读写数据库中帖子数据的模型。然而,正如你现在所知道的,使用 FastAPI 时,我们还需要 Pydantic 模型,以便验证输入数据并在 API 中输出正确的表示。如果你需要复习这部分内容,可以查看第三章,使用 FastAPI 开发 RESTful API。
定义 Pydantic 模型
正如我们所说的,如果我们想正确验证进出 FastAPI 应用的数据,我们需要使用 Pydantic 模型。在 ORM 上下文中,它们将帮助我们在 ORM 模型之间来回转换。这一节的关键要点是:我们将使用 Pydantic 模型来验证和序列化数据,但数据库通信将通过 ORM 模型完成。
为了避免混淆,我们现在将 Pydantic 模型称为模式。当我们谈论模型时,我们指的是 ORM 模型。
这就是为什么那些模式的定义被放置在schemas.py模块中的原因,如下所示:
schemas.py
from datetime import datetimefrom pydantic import BaseModel, Field
class PostBase(BaseModel):
title: str
content: str
publication_date: datetime = Field(default_factory=datetime.now)
class Config:
orm_mode = True
class PostPartialUpdate(BaseModel):
title: str | None = None
content: str | None = None
class PostCreate(PostBase):
pass
class PostRead(PostBase):
id: int
上面的代码对应我们在第四章中解释的模式,在 FastAPI 中管理 Pydantic 数据模型。
但有一个新内容:你可能已经注意到Config子类,它是在PostBase中定义的。这是为 Pydantic 模式添加一些配置选项的一种方式。在这里,我们将orm_mode选项设置为True。顾名思义,这是一个使 Pydantic 与 ORM 更好配合的选项。在标准设置下,Pydantic 被设计用来解析字典中的数据:如果它想解析title属性,它会使用d["title"]。然而,在 ORM 中,我们通过点号表示法(o.title)来像访问对象一样访问属性。启用 ORM 模式后,Pydantic 就能使用这种风格。
连接到数据库
现在我们的模型和模式已经准备好了,我们必须设置 FastAPI 应用和数据库引擎之间的连接。为此,我们将创建一个database.py模块,并在其中放置我们需要的对象:
database.py
from collections.abc import AsyncGeneratorfrom sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from chapter06.sqlalchemy.models import Base
DATABASE_URL = "sqlite+aiosqlite:///chapter06_sqlalchemy.db"
engine = create_async_engine(DATABASE_URL)
async_session_maker = async_sessionmaker(engine, expire_on_commit=False)
在这里,你可以看到我们已经将连接字符串设置在DATABASE_URL变量中。通常,它由以下几个部分组成:
-
数据库引擎。在这里,我们使用
sqlite。 -
可选的驱动程序,后面带有加号。这里,我们设置为
aiosqlite。在异步环境中,必须指定我们想要使用的异步驱动程序。否则,SQLAlchemy 会回退到标准的同步驱动程序。 -
可选的身份验证信息。
-
数据库服务器的主机名。在 SQLite 的情况下,我们只需指定将存储所有数据的文件路径。
你可以在官方 SQLAlchemy 文档中找到该格式的概述:docs.sqlalchemy.org/en/20/core/engines.html#database-urls。
然后,我们使用create_async_engine函数和这个 URL 创建引擎。引擎是一个对象,SQLAlchemy 将在其中管理与数据库的连接。此时,重要的是要理解,尚未建立任何连接:我们只是声明了相关内容。
然后,我们有一个更为复杂的代码行来定义async_session_maker变量。我们不会深入讨论async_sessionmaker函数的细节。只需知道它返回一个函数,允许我们生成与数据库引擎绑定的会话。
什么是会话?它是由 ORM 定义的概念。会话将与数据库建立实际连接,并代表一个区域,在该区域中它将存储你从数据库中读取的所有对象以及你定义的所有将在数据库中写入的对象。它是 ORM 概念和基础 SQL 查询之间的代理。
在构建 HTTP 服务器时,我们通常在请求开始时打开一个新的会话,并在响应请求时关闭它。因此,每个 HTTP 请求代表与数据库的一个工作单元。这就是为什么我们必须定义一个 FastAPI 依赖项,其作用是提供一个新的会话给我们:
database.py
async def get_async_session() -> AsyncGenerator[AsyncSession, None]: async with async_session_maker() as session:
yield session
将它作为依赖项将大大帮助我们在实现路径操作函数时。
到目前为止,我们还没有机会讨论with语法。在 Python 中,这被称为with块,对象会自动执行设置逻辑。当你退出该块时,它会执行拆解逻辑。你可以在 Python 文档中阅读更多关于上下文管理器的信息:docs.python.org/3/reference/datamodel.html#with-statement-context-managers。
在我们的案例中,async_session_maker作为上下文管理器工作。它负责打开与数据库的连接等操作。
注意,我们在这里通过使用yield定义了一个生成器。这一点很重要,因为它确保了会话在请求结束前保持打开状态。如果我们使用一个简单的return语句,上下文管理器会立即关闭。使用yield时,我们确保只有在请求和端点逻辑被 FastAPI 完全处理后,才会退出上下文管理器。
使用依赖注入来获取数据库实例
你可能会想,为什么我们不直接在路径操作函数中调用async_session_maker,而是使用依赖注入。这是可行的,但当我们尝试实现单元测试时会非常困难。实际上,将这个实例替换为模拟对象或测试数据库将变得非常困难。通过使用依赖注入,FastAPI 使得我们可以轻松地将其替换为另一个函数。我们将在第九章,使用 pytest 和 HTTPX 异步测试 API中详细了解这一点。
在这个模块中我们必须定义的最后一个函数是create_all_tables。它的目标是创建数据库中的表模式。如果我们不这么做,数据库将是空的,无法保存或检索数据。像这样创建模式是一种简单的做法,只适用于简单的示例和实验。在实际应用中,你应该有一个合适的迁移系统,确保你的数据库模式保持同步。我们将在本章稍后学习如何为 SQLAlchemy 设置迁移系统。
为了确保在应用启动时创建我们的模式,我们必须在app.py模块中调用这个函数:
app.py
@contextlib.asynccontextmanagerasync def lifespan(app: FastAPI):
await create_all_tables()
yield
创建对象
让我们从向数据库中插入新对象开始。主要的挑战是接受 Pydantic 模式作为输入,将其转换为 SQLAlchemy 模型,并将其保存到数据库中。让我们回顾一下这个过程,如下例所示:
app.py
@app.post( "/posts", response_model=schemas.PostRead, status_code=status.HTTP_201_CREATED
)
async def create_post(
post_create: schemas.PostCreate, session: AsyncSession = Depends(get_async_session)
) -> Post:
post = Post(**post_create.dict())
session.add(post)
await session.commit()
return post
在这里,我们有一个POST端点,接受我们的PostCreate模式。注意,我们通过get_async_session依赖注入了一个新的 SQLAlchemy 会话。核心逻辑包括两个操作。
首先,我们将post_create转换为完整的Post模型对象。为此,我们可以简单地调用 Pydantic 的dict方法,并用**解包它,直接赋值给属性。此时,文章还没有保存到数据库中:我们需要告诉会话有关它的信息。
第一步是通过add方法将其添加到会话中。现在,post 已经进入会话内存,但尚未存储在数据库中。通过调用commit方法,我们告诉会话生成适当的 SQL 查询并在数据库上执行它们。正如我们所预料的那样,我们发现需要await此方法:我们对数据库进行了 I/O 操作,因此它是异步操作。
最后,我们可以直接返回post对象。你可能会惊讶于我们直接返回了一个 SQLAlchemy ORM 对象,而不是 Pydantic 模式。FastAPI 如何正确地序列化它并保留我们指定的属性呢?如果你留心一下,你会看到我们在路径操作装饰器中设置了response_model属性。正如你可能从 第三章的响应模型部分回想起来的那样,使用 FastAPI 开发 RESTful API,你就能理解发生了什么:FastAPI 会自动处理将 ORM 对象转化为指定模式的过程。正因为如此,我们需要启用 Pydantic 的orm_mode,正如前面一节所示!
从这里,你可以看到实现过程非常直接。现在,让我们来检索这些数据吧!
获取和筛选对象
通常,REST API 提供两种类型的端点来读取数据:一种用于列出对象,另一种用于获取特定对象。这正是我们接下来要回顾的内容!
在下面的示例中,你可以看到我们如何实现列出对象的端点:
app.py
@app.get("/posts", response_model=list[schemas.PostRead])async def list_posts(
pagination: tuple[int, int] = Depends(pagination),
session: AsyncSession = Depends(get_async_session),
) -> Sequence[Post]:
skip, limit = pagination
select_query = select(Post).offset(skip).limit(limit)
result = await session.execute(select_query)
return result.scalars().all()
该操作分为两步执行。首先,我们构建一个查询。SQLAlchemy 的select函数允许我们开始定义查询。方便的是,我们可以直接将model类传递给它:它会自动理解我们所谈论的表格。接下来,我们可以应用各种方法和筛选条件,这些与纯 SQL 中的操作是相似的。在这里,我们能够通过offset和limit应用我们的分页参数。
然后,我们使用一个新的会话对象的execute方法执行此查询(该会话对象再次通过我们的依赖注入)。由于我们是从数据库中读取数据,这是一项异步操作。
由此,我们得到一个result对象。这个对象是 SQLAlchemy 的Result类的实例。它不是我们直接的帖子列表,而是表示 SQL 查询结果的一个集合。这就是为什么我们需要调用scalars和all。第一个方法会确保我们获得实际的Post对象,而第二个方法会将它们作为一个序列返回。
再次说明,我们可以直接返回这些 SQLAlchemy ORM 对象:感谢response_model设置,FastAPI 会将它们转化为正确的模式。
现在,让我们看看如何通过 ID 获取单个 post:
app.py
@app.get("/posts/{id}", response_model=schemas.PostRead)async def get_post(post: Post = Depends(get_post_or_404)) -> Post:
return post
这是一个简单的GET端点,期望在路径参数中提供帖子的 ID。实现非常简单:我们只是返回帖子。大部分逻辑在get_post_or_404依赖项中,我们将在应用程序中经常重复使用它。以下是它的实现:
app.py
async def get_post_or_404( id: int, session: AsyncSession = Depends(get_async_session)
) -> Post:
select_query = select(Post).where(Post.id == id)
result = await session.execute(select_query)
post = result.scalar_one_or_none()
if post is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return post
如你所见,这与我们在列表端点看到的内容非常相似。我们同样从构建一个选择查询开始,但这次,我们添加了一个where子句,以便只检索与所需 ID 匹配的帖子。这个子句本身可能看起来有些奇怪。
首先,我们必须设置我们想要比较的实际列。事实上,当你直接访问model类的属性时,比如Post.id,SQLAlchemy 会自动理解你在引用列。
然后,我们使用等号运算符来比较列与我们实际的id变量。它看起来像是一个标准的比较,会产生一个布尔值,而不是一个 SQL 语句!在一般的 Python 环境中,确实是这样。然而,SQLAlchemy 的开发者在这里做了一些聪明的事情:他们重载了标准运算符,使其产生 SQL 表达式而不是比较对象。这正是我们在第二章的Python 编程特性中看到的内容。
现在,我们可以简单地执行查询并在结果集上调用scalar_one_or_none。这是一个方便的快捷方式,告诉 SQLAlchemy 如果存在单个对象则返回它,否则返回None。
如果结果是None,我们可以抛出一个404错误:没有帖子匹配这个 ID。否则,我们可以简单地返回帖子。
更新和删除对象
最后,我们将展示如何更新和删除现有对象。你会发现这只是操作 ORM 对象并在session上调用正确方法的事情。
检查以下代码,并审查update端点的实现:
app.py
@app.patch("/posts/{id}", response_model=schemas.PostRead)async def update_post(
post_update: schemas.PostPartialUpdate,
post: Post = Depends(get_post_or_404),
session: AsyncSession = Depends(get_async_session),
) -> Post:
post_update_dict = post_update.dict(exclude_unset=True)
for key, value in post_update_dict.items():
setattr(post, key, value)
session.add(post)
await session.commit()
return post
这里,主要需要注意的是,我们将直接操作我们想要修改的帖子。这是使用 ORM 时的一个关键点:实体是可以按需修改的对象。当你对数据满意时,可以将其持久化到数据库中。这正是我们在这里所做的:我们通过get_post_or_404获取帖子的最新表示。然后,我们将post_update架构转换为字典,并遍历这些属性,将它们设置到我们的 ORM 对象上。最后,我们可以将其保存在会话中并提交到数据库,就像我们在创建时所做的那样。
当你想删除一个对象时,同样的概念也适用:当你拥有一个实例时,可以将其传递给session的delete方法,从而安排它的删除。你可以通过以下示例查看这一过程:
app.py
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)async def delete_post(
post: Post = Depends(get_post_or_404),
session: AsyncSession = Depends(get_async_session),
):
await session.delete(post)
await session.commit()
在这些示例中,你看到我们总是在写操作后调用commit:你的更改必须被写入数据库,否则它们将仅停留在会话内存中并丢失。
添加关系
正如我们在本章开头提到的,关系型数据库关心的是数据及其关系。你经常需要创建与其他实体相关联的实体。例如,在一个博客应用中,评论是与其相关的帖子关联的。在这一部分,我们将讨论如何使用 SQLAlchemy ORM 设置这种关系。
首先,我们需要为评论定义一个新模型。这个新模型必须放在Post模型之上。稍后我们会解释为什么这很重要。你可以在以下示例中查看它的定义:
models.py
class Comment(Base): __tablename__ = "comments"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
post_id: Mapped[int] = mapped_column(ForeignKey("posts.id"), nullable=False)
publication_date: Mapped[datetime] = mapped_column(
DateTime, nullable=False, default=datetime.now
)
content: Mapped[str] = mapped_column(Text, nullable=False)
post: Mapped["Post"] = relationship("Post", back_populates="comments")
这里的重要点是post_id列,它是ForeignKey类型。这是一个特殊类型,告诉 SQLAlchemy 自动处理该列的类型和相关约束。我们只需要提供它所指向的表和列名。
但这只是定义中的 SQL 部分。现在我们需要告诉 ORM 我们的Comment对象与Post对象之间存在关系。这就是post属性的目的,它被分配给relationship函数。它是 SQLAlchemy ORM 暴露的一个特殊函数,用来定义模型之间的关系。它不会在 SQL 定义中创建一个新列——这是ForeignKey列的作用——但它允许我们通过comment.post直接获取与评论相关联的Post对象。你还可以看到我们定义了back_populates参数。它允许我们执行相反的操作——也就是说,从一个post获取评论列表。这个选项的名称决定了我们用来访问评论的属性名。这里,它是post.comments。
前向引用类型提示
如果你查看post属性的类型提示,你会看到我们正确地将其设置为Post类。然而,我们将其放在了引号中:post: "Post" = …。
这就是所谓的Post在Comment之后定义。如果我们忘记了引号,Python 会抱怨,因为我们试图访问一个尚未存在的东西。为了解决这个问题,我们可以将其放在引号中。类型检查器足够智能,可以理解你指的是什么。
现在,如果你查看以下Post模型,你会看到我们添加了一个内容:
models.py
class Post(Base): __tablename__ = "posts"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
publication_date: Mapped[datetime] = mapped_column(
DateTime, nullable=False, default=datetime.now
)
title: Mapped[str] = mapped_column(String(255), nullable=False)
content: Mapped[str] = mapped_column(Text, nullable=False)
comments: Mapped[list[Comment]] = relationship("Comment", cascade="all, delete")
我们还定义了镜像关系,并注意以我们为back_populates选择的相同名称命名。这次,我们还设置了cascade参数,它允许我们定义 ORM 在删除帖子时的行为:我们是应该隐式删除评论,还是将它们保留为孤立的?在这个例子中,我们选择了删除它们。请注意,这与 SQL 的CASCADE DELETE构造不完全相同:它具有相同的效果,但将由 ORM 在 Python 代码中处理,而不是由 SQL 数据库处理。
关于关系有很多选项,所有这些选项都可以在官方文档中找到:docs.sqlalchemy.org/en/20/orm/relationship_api.html#sqlalchemy.orm.relationship.
再次强调,添加这个comments属性并不会改变 SQL 定义:它只是为 ORM 在 Python 端做的连接。
现在,我们可以为评论实体定义 Pydantic 模式。它们非常直接,因此我们不会深入讨论细节。但请注意我们是如何将comments属性添加到PostRead模式中的:
schemas.py
class PostRead(PostBase): id: int
comments: list[CommentRead]
确实,在 REST API 中,有些情况下自动检索实体的相关对象是有意义的。在这里,能够在一次请求中获取帖子的评论会很方便。这个架构将允许我们序列化评论以及帖子数据*。
现在,我们将实现一个端点来创建新的评论。以下示例展示了这一点:
app.py
@app.post( "/posts/{id}/comments",
response_model=schemas.CommentRead,
status_code=status.HTTP_201_CREATED,
)
async def create_comment(
comment_create: schemas.CommentCreate,
post: Post = Depends(get_post_or_404),
session: AsyncSession = Depends(get_async_session),
) -> Comment:
comment = Comment(**comment_create.dict(), post=post)
session.add(comment)
await session.commit()
return comment
这个端点已定义,因此我们需要直接在路径中设置帖子 ID。它允许我们重用get_post_or_404依赖项,并且如果尝试向不存在的帖子添加评论时,会自动触发404错误。
除此之外,它与本章中创建对象部分的内容非常相似。这里唯一需要注意的是,我们手动设置了这个新comment对象的post属性。由于关系定义的存在,我们可以直接分配post对象,ORM 将自动在post_id列中设置正确的值。
之前我们提到过,我们希望同时检索帖子及其评论。为了实现这一点,我们在获取帖子时需要稍微调整一下查询。以下示例展示了我们为get_post_or_404函数所做的调整,但对列表端点也是一样的:
app.py
async def get_post_or_404( id: int, session: AsyncSession = Depends(get_async_session)
) -> Post:
select_query = (
select(Post).options(selectinload(Post.comments)).where(Post.id == id)
)
result = await session.execute(select_query)
post = result.scalar_one_or_none()
if post is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return post
如你所见,我们添加了对options的调用,并使用了selectinload构造。这是告诉 ORM 在执行查询时自动检索帖子的相关评论的一种方式。如果我们不这么做,就会出错。为什么?因为我们的查询是异步的。但我们从头开始讲。
在经典的同步 ORM 上下文中,你可以这样做:
comments = post.comments
如果comments在第一次请求时没有被加载,同步 ORM 将隐式地对 SQL 数据库执行一个新查询。这对用户是不可见的,但实际上会进行 I/O 操作。这被称为懒加载,它是 SQLAlchemy 中关系的默认行为。
然而,在异步上下文中,I/O 操作不能隐式执行:我们必须显式地*等待(await)*它们。这就是为什么如果你忘记在第一次查询时显式加载关系,系统会报错的原因。当 Pydantic 尝试序列化PostRead模式时,它将尝试访问post.comments,但是 SQLAlchemy 无法执行这个隐式查询。
因此,在使用异步(async)时,你需要在关系上执行预加载(eager loading),以便直接从 ORM 对象访问。诚然,这比同步版本不太方便。然而,它有一个巨大的优势:你可以精确控制执行的查询。事实上,使用同步 ORM 时,某些端点可能因为代码执行了数十个隐式查询而导致性能不佳。而使用异步 ORM 时,你可以确保所有内容都在单个或少数几个查询中加载。这是一种权衡,但从长远来看,它可能会带来好处。
可以在关系中配置预加载(eager loading)
如果你确定无论上下文如何,你始终需要加载实体的相关对象,你可以直接在relationship函数中定义预加载策略。这样,你就无需在每个查询中设置它。你可以在官方文档中阅读更多关于此的信息:docs.sqlalchemy.org/en/20/orm/relationship_api.html#sqlalchemy.orm.relationship.params.lazy。
本质上,处理 SQLAlchemy ORM 中的关系就是这些。你已经看到,关键在于正确定义关系,以便 ORM 可以理解对象之间是如何关联的。
使用 Alembic 设置数据库迁移系统
在开发应用程序时,你很可能会对数据库模式进行更改,添加新表、新列或修改现有的列。当然,如果你的应用程序已经投入生产,你不希望删除所有数据并重新创建数据库模式:你希望它能迁移到新的模式。为此任务开发了相关工具,本节将学习如何设置Alembic,它是 SQLAlchemy 的创作者所开发的库。让我们来安装这个库:
(venv) $ pip install alembic
完成此操作后,你将能够使用alembic命令来管理此迁移系统。在开始一个新项目时,首先需要初始化迁移环境,该环境包括一组文件和目录,Alembic 将在其中存储其配置和迁移文件。在项目的根目录下,运行以下命令:
(venv) $ alembic init alembic
这将会在项目的根目录下创建一个名为alembic的目录。你可以在图 6.4所示的示例仓库中查看该命令的执行结果:
图 6.4 – Alembic 迁移环境结构
这个文件夹将包含所有迁移的配置以及迁移脚本本身。它应该与代码一起提交,以便你有这些文件版本的记录。
另外,请注意,它创建了一个alembic.ini文件,其中包含了 Alembic 的所有配置选项。我们将在此文件中查看一个重要的设置:sqlalchemy.url。你可以在以下代码中看到:
alembic.ini
sqlalchemy.url = sqlite:///chapter06_sqlalchemy_relationship.db
可以预见的是,这是你数据库的连接字符串,它将接收迁移查询。它遵循我们之前看到的相同约定。在这里,我们设置了 SQLite 数据库。但是,请注意,我们没有设置aiosqlite驱动程序:Alembic 只能与同步驱动程序一起使用。这并不是什么大问题,因为它仅在执行迁移的专用脚本中运行。
接下来,我们将重点关注env.py文件。它是一个包含 Alembic 初始化迁移引擎和执行迁移的所有逻辑的 Python 脚本。作为 Python 脚本,它允许我们精细定制 Alembic 的执行。暂时我们保持默认配置,除了一个小改动:我们会导入我们的Base对象。你可以通过以下示例查看:
env.py
from chapter06.sqlalchemy_relationship.models import Base# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
默认情况下,该文件定义了一个名为target_metadata的变量,初始值为None。在这里,我们将其修改为指向从models模块导入的Base.metadata对象。但为什么要这么做呢?回想一下,Base是一个 SQLAlchemy 对象,包含了数据库架构的所有信息。通过将它提供给 Alembic,迁移系统将能够自动生成迁移脚本,只需查看你的架构!这样,你就不必从头编写迁移脚本了。
一旦你对数据库架构进行了更改,可以运行以下命令生成新的迁移脚本:
(venv) $ alembic revision --autogenerate -m "Initial migration"
这将根据你的架构更改创建一个新的脚本文件,并将命令反映到versions目录中。该文件定义了两个函数:upgrade和downgrade。你可以在以下代码片段中查看upgrade:
eabd3f9c5b64_initial_migration.py
def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"posts",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("publication_date", sa.DateTime(), nullable=False),
sa.Column("title", sa.String(length=255), nullable=False),
sa.Column("content", sa.Text(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"comments",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("post_id", sa.Integer(), nullable=False),
sa.Column("publication_date", sa.DateTime(), nullable=False),
sa.Column("content", sa.Text(), nullable=False),
sa.ForeignKeyConstraint(
["post_id"],
["posts.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
这个函数在我们应用迁移时执行。它描述了创建 posts 和 comments 表所需的操作,包括所有列和约束。
现在,让我们来看一下这个文件中的另一个函数,downgrade:
eabd3f9c5b64_initial_migration.py
def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ###
op.drop_table("comments")
op.drop_table("posts")
# ### end Alembic commands ###
这个函数描述了回滚迁移的操作,以便数据库恢复到之前的状态。这一点非常重要,因为如果迁移过程中出现问题,或者你需要恢复到应用程序的旧版本,你可以在不破坏数据的情况下做到这一点。
自动生成并不能检测到所有问题
请记住,尽管自动生成非常有帮助,但它并不总是准确的,有时它无法检测到模糊的变化。例如,如果你重命名了一个列,它会删除旧列并创建一个新的列。因此,该列中的数据将丢失!这就是为什么你应该始终仔细审查迁移脚本,并为类似这种极端情况做出必要的修改。
最后,你可以使用以下命令将迁移应用到数据库:
(venv) $ alembic upgrade head
这将运行所有尚未应用到数据库中的迁移,直到最新的版本。值得注意的是,在这个过程中,Alembic 会在数据库中创建一个表,以便它可以记住所有已应用的迁移:这就是它如何检测需要运行的脚本。
一般来说,当你在数据库上运行此类命令时,应该极其小心,特别是在生产环境中。如果犯了错误,可能会发生非常糟糕的事情,甚至丢失宝贵的数据。在在生产数据库上运行迁移之前,你应该始终在测试环境中进行测试,并确保有最新且有效的备份。
这只是对 Alembic 及其强大迁移系统的简短介绍。我们强烈建议你阅读它的文档,以了解所有机制,特别是关于迁移脚本的操作。请参考 alembic.sqlalchemy.org/en/latest/index.html。
这就是本章 SQLAlchemy 部分的内容!它是一个复杂但强大的库,用于处理 SQL 数据库。接下来,我们将离开关系型数据库的世界,探索如何与文档导向型数据库 MongoDB 进行交互。
使用 Motor 与 MongoDB 数据库进行通信
正如我们在本章开头提到的,使用文档导向型数据库(例如 MongoDB)与使用关系型数据库有很大的不同。首先,你不需要提前配置模式:它遵循你插入到其中的数据结构。在 FastAPI 中,这使得我们的工作稍微轻松一些,因为我们只需处理 Pydantic 模型。然而,关于文档标识符,还有一些细节需要我们注意。接下来我们将讨论这一点。
首先,我们将安装 Motor,这是一个用于与 MongoDB 异步通信的库,并且是 MongoDB 官方支持的。运行以下命令:
(venv) $ pip install motor
完成这部分工作后,我们可以开始实际操作了!
创建与 MongoDB ID 兼容的模型
正如我们在本节介绍中提到的,MongoDB 用于存储文档的标识符存在一些困难。事实上,默认情况下,MongoDB 会为每个文档分配一个 _id 属性,作为集合中的唯一标识符。这导致了两个问题:
-
在 Pydantic 模型中,如果一个属性以下划线开头,则被认为是私有的,因此不会作为数据字段使用。
-
_id被编码为一个二进制对象,称为ObjectId,而不是简单的整数或字符串。它通常以类似608d1ee317c3f035100873dc的字符串形式表示。Pydantic 或 FastAPI 默认不支持这种类型的对象。
这就是为什么我们需要一些样板代码来确保这些标识符能够与 Pydantic 和 FastAPI 一起工作。首先,在下面的示例中,我们创建了一个 MongoBaseModel 基类,用来处理定义 id 字段:
models.py
class MongoBaseModel(BaseModel): id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
class Config:
json_encoders = {ObjectId: str}
首先,我们需要定义一个 id 字段,类型为 PyObjectId。这是一个在前面代码中定义的自定义类型。我们不会深入讨论它的实现细节,但只需知道它是一个类,使得 ObjectId 成为 Pydantic 兼容的类型。我们将此类定义为该字段的默认工厂。有趣的是,这种标识符允许我们在客户端生成它们,这与传统的关系型数据库中自动递增的整数不同,在某些情况下可能非常有用。
最有趣的参数是 alias。这是 Pydantic 的一个选项,允许我们在序列化过程中更改字段的名称。在这个例子中,当我们在 MongoBaseModel 的实例上调用 dict 方法时,标识符将被设置为 _id 键,这也是 MongoDB 所期望的名称。这解决了第一个问题。
接着,我们添加了 Config 子类并设置了 json_encoders 选项。默认情况下,Pydantic 完全不了解我们的 PyObjectId 类型,因此无法正确地将其序列化为 JSON。这个选项允许我们使用一个函数映射自定义类型,以便在序列化时调用它们。在这里,我们只是将其转换为字符串(因为 ObjectId 实现了 __str__ 魔法方法)。这解决了 Pydantic 的第二个问题。
我们的 Pydantic 基础模型已经完成!现在,我们可以将其作为 base 类,而不是 BaseModel,来创建我们实际的数据模型。然而请注意,PostPartialUpdate 并没有继承它。实际上,我们不希望这个模型中有 id 字段;否则,PATCH 请求可能会替换文档的 ID,进而导致奇怪的问题。
连接到数据库
现在我们的模型已经准备好,我们可以设置与 MongoDB 服务器的连接。这非常简单,仅涉及类的实例化,代码示例如下所示,在 database.py 模块中:
database.py
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase# Connection to the whole server
motor_client = AsyncIOMotorClient("mongodb://localhost:27017")
# Single database instance
database = motor_client["chapter06_mongo"]
def get_database() -> AsyncIOMotorDatabase:
return database
在这里,你可以看到 AsyncIOMotorClient 仅仅需要一个连接字符串来连接到你的数据库。通常,它由协议、后面的身份验证信息和数据库服务器的主机名组成。你可以在官方的 MongoDB 文档中查看这个格式的概述:docs.mongodb.com/manual/reference/connection-string/。
然而,要小心。与我们迄今讨论的库不同,这里实例化的客户端并没有绑定到任何数据库——也就是说,它只是一个与整个服务器的连接。因此,我们需要第二行代码:通过访问 chapter06_mongo 键,我们得到了一个数据库实例。值得注意的是,MongoDB 并不要求你提前创建数据库:如果数据库不存在,它会自动创建。
接着,我们创建一个简单的函数来返回这个数据库实例。我们将把这个函数作为依赖项,在路径操作函数中获取这个实例。我们在使用 SQLAlchemy 与 SQL 数据库通信 ORM 部分中解释了这种模式的好处。
就这样!我们现在可以对数据库执行查询了!
插入文档
我们将首先演示如何实现一个端点来创建帖子。本质上,我们只需要将转换后的 Pydantic 模型插入字典中:
app.py
@app.post("/posts", response_model=Post, status_code=status.HTTP_201_CREATED)async def create_post(
post_create: PostCreate, database: AsyncIOMotorDatabase = Depends(get_database)
) -> Post:
post = Post(**post_create.dict())
await database["posts"].insert_one(post.dict(by_alias=True))
post = await get_post_or_404(post.id, database)
return post
传统上,这是一个接受 PostCreate 模型格式负载的 POST 端点。此外,我们通过之前编写的依赖项注入数据库实例。
在路径操作本身中,你可以看到我们从 PostCreate 数据实例化了一个 Post。如果你有字段只在 Post 中出现并且需要初始化,这通常是一个好做法。
然后,我们有了查询。为了从我们的 MongoDB 数据库中检索一个集合,我们只需要通过名称像访问字典一样获取它。再次强调,如果该集合不存在,MongoDB 会自动创建它。如你所见,面向文档的数据库在架构方面比关系型数据库更加轻量!在这个集合中,我们可以调用 insert_one 方法插入单个文档。它期望一个字典来将字段映射到它们的值。因此,Pydantic 对象的 dict 方法再次成为我们的好朋友。然而,在这里,我们看到了一些新东西:我们用 by_alias 参数设置为 True 来调用它。默认情况下,Pydantic 会使用真实的字段名序列化对象,而不是别名。但是,我们确实需要 MongoDB 数据库中的 _id 标识符。使用这个选项,Pydantic 将使用别名作为字典中的键。
为了确保我们在字典中有一个真实且最新的文档表示,我们可以通过我们的 get_post_or_404 函数从数据库中检索一个。我们将在下一部分中查看这一点是如何工作的。
依赖关系就像函数
在这一部分中,我们使用 get_post_or_404 作为常规函数来检索我们新创建的博客帖子。这完全没问题:依赖项内部没有隐藏或魔法逻辑,因此你可以随意重用它们。唯一需要记住的是,由于你不在依赖注入的上下文中,因此必须手动提供每个参数。
获取文档
当然,从数据库中检索数据是 REST API 工作的重要部分。在这一部分中,我们将演示如何实现两个经典的端点——即列出帖子和获取单个帖子。让我们从第一个开始,看看它的实现:
app.py
@app.get("/posts", response_model=list[Post])async def list_posts(
pagination: tuple[int, int] = Depends(pagination),
database: AsyncIOMotorDatabase = Depends(get_database),
) -> list[Post]:
skip, limit = pagination
query = database["posts"].find({}, skip=skip, limit=limit)
results = [Post(**raw_post) async for raw_post in query]
return results
最有趣的部分是第二行,我们在这里定义了查询。在获取posts集合后,我们调用了find方法。第一个参数应该是过滤查询,遵循 MongoDB 语法。由于我们想要获取所有文档,所以将其留空。然后,我们有一些关键字参数,允许我们应用分页参数。
MongoDB 返回的是一个字典列表形式的结果,将字段映射到它们的值。这就是为什么我们添加了一个列表推导式结构来将它们转回Post实例——以便 FastAPI 能够正确序列化它们。
你可能注意到这里有些令人惊讶的地方:与我们通常做法不同,我们并没有直接等待查询。相反,我们在列表推导式中加入了async关键字。确实,在这种情况下,Motor 返回了async关键字,我们在遍历时必须加上它。
现在,让我们看一下获取单个帖子的端点。下面的示例展示了它的实现:
app.py
@app.get("/posts/{id}", response_model=Post)async def get_post(post: Post = Depends(get_post_or_404)) -> Post:
return post
如你所见,这是一个简单的GET端点,它接受id作为路径参数。大部分逻辑的实现都在可复用的get_post_or_404依赖中。你可以在这里查看它的实现:
app.py
async def get_post_or_404( id: ObjectId = Depends(get_object_id),
database: AsyncIOMotorDatabase = Depends(get_database),
) -> Post:
raw_post = await database["posts"].find_one({"_id": id})
if raw_post is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return Post(**raw_post)
逻辑与我们在列表端点中看到的非常相似。然而,这次我们调用了find_one方法,并使用查询来匹配帖子标识符:键是我们要过滤的文档属性的名称,值是我们正在寻找的值。
这个方法返回一个字典形式的文档,若不存在则返回None。在这种情况下,我们抛出一个适当的404错误。
最后,我们在返回之前将其转换回Post模型。
你可能已经注意到,我们是通过依赖get_object_id获取id的。实际上,FastAPI 会从路径参数中返回一个字符串。如果我们尝试用字符串形式的id进行查询,MongoDB 将无法与实际的二进制 ID 匹配。这就是为什么我们使用另一个依赖来将作为字符串表示的标识符(例如608d1ee317c3f035100873dc)转换为合适的ObjectId。
顺便提一下,这里有一个非常好的嵌套依赖的例子:端点使用get_post_or_404依赖,它本身从get_object_id获取一个值。你可以在下面的示例中查看这个依赖的实现:
app.py
async def get_object_id(id: str) -> ObjectId: try:
return ObjectId(id)
except (errors.InvalidId, TypeError):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
在这里,我们只是从路径参数中提取id字符串,并尝试将其重新实例化为ObjectId。如果它不是一个有效值,我们会捕获相应的错误,并将其视为404错误。
这样,我们就解决了 MongoDB 标识符格式带来的所有挑战。现在,让我们讨论如何更新和删除文档。
更新和删除文档
现在我们将回顾更新和删除文档的端点。逻辑还是一样,只需要从请求负载构建适当的查询。
让我们从PATCH端点开始,您可以在以下示例中查看:
app.py
@app.patch("/posts/{id}", response_model=Post)async def update_post(
post_update: PostPartialUpdate,
post: Post = Depends(get_post_or_404),
database: AsyncIOMotorDatabase = Depends(get_database),
) -> Post:
await database["posts"].update_one(
{"_id": post.id}, {"$set": post_update.dict(exclude_unset=True)}
)
post = await get_post_or_404(post.id, database)
return post
在这里,我们使用update_one方法来更新一条文档。第一个参数是过滤查询,第二个参数是要应用于文档的实际操作。同样,它遵循 MongoDB 的语法:$set操作允许我们通过传递update字典,仅修改我们希望更改的字段。
DELETE端点更简单;它只是一个查询,您可以在以下示例中看到:
app.py
@app.delete("/posts/{id}", status_code=status.HTTP_204_NO_CONTENT)async def delete_post(
post: Post = Depends(get_post_or_404),
database: AsyncIOMotorDatabase = Depends(get_database),
):
await database["posts"].delete_one({"_id": post.id})
delete_one方法期望过滤查询作为第一个参数。
就是这样!当然,这里我们只是演示了最简单的查询类型,但 MongoDB 有一个非常强大的查询语言,允许你执行更复杂的操作。如果你不熟悉这个,我们建议你阅读官方文档中的精彩介绍:docs.mongodb.com/manual/crud。
嵌套文档
在本章开始时,我们提到过,与关系型数据库不同,基于文档的数据库旨在将与实体相关的所有数据存储在一个文档中。在我们当前的示例中,如果我们希望将评论与帖子一起存储,我们只需要添加一个列表,每个项目就是评论数据。
在本节中,我们将实现这一行为。你会看到 MongoDB 的工作方式使得这变得非常简单。
我们将从向Post模型添加一个新的comments属性开始。您可以在以下示例中查看:
models.py
class Post(PostBase): comments: list[Comment] = Field(default_factory=list)
这个字段只是一个Comment的列表。我们不会深入讨论评论模型,因为它们非常简单。请注意,我们使用list函数作为此属性的默认工厂。当我们创建一个没有设置评论的Post时,默认会实例化一个空列表。
现在我们已经有了模型,我们可以实现一个端点来创建新的评论。你可以在下面的示例中查看:
app.py
@app.post( "/posts/{id}/comments", response_model=Post, status_code=status.HTTP_201_CREATED
)
async def create_comment(
comment: CommentCreate,
post: Post = Depends(get_post_or_404),
database: AsyncIOMotorDatabase = Depends(get_database),
) -> Post:
await database["posts"].update_one(
{"_id": post.id}, {"$push": {"comments": comment.dict()}}
)
post = await get_post_or_404(post.id, database)
return post
正如我们之前所做的,我们将端点嵌套在单个帖子的路径下。因此,如果该帖子存在,我们可以重新使用get_post_or_404来检索我们要添加评论的帖子。
然后,我们触发一个update_one查询:这次,使用$push操作符。这个操作符对于向列表属性添加元素非常有用。也有可用的操作符用于从列表中移除元素。你可以在官方文档中找到每个update操作符的描述:docs.mongodb.com/manual/reference/operator/update/。
就这样!我们甚至不需要修改其余的代码。因为评论已经包含在整个文档中,当我们在数据库中查询帖子时,我们总是能够检索到它们。此外,我们的Post模型现在期待一个comments属性,所以 Pydantic 会自动处理它们的序列化。
这部分关于 MongoDB 的内容到此结束。你已经看到,它可以非常快速地集成到 FastAPI 应用中,特别是由于其非常灵活的架构。
总结
恭喜!你已经达到了掌握如何使用 FastAPI 构建 REST API 的另一个重要里程碑。正如你所知道的,数据库是每个系统中不可或缺的一部分;它们允许你以结构化的方式保存数据,并通过强大的查询语言精确而可靠地检索数据。现在,无论是关系型数据库还是文档导向型数据库,你都能在 FastAPI 中充分利用它们的强大功能。
现在可以进行更严肃的操作了;用户可以向你的系统发送和检索数据。然而,这也带来了一个新的挑战:这些数据需要受到保护,以确保它们能够保持私密和安全。这正是我们在下一章将讨论的内容:如何认证用户并为 FastAPI 配置最大安全性。
第七章:在 FastAPI 中管理身份验证和安全性
大多数时候,你不希望互联网上的每个人都能访问你的 API,而不对他们能创建或读取的数据设置任何限制。这就是为什么你至少需要用私有令牌保护你的应用程序,或者拥有一个合适的身份验证系统来管理授予每个用户的权限。在本章中,我们将看到 FastAPI 如何提供安全依赖,帮助我们通过遵循不同的标准来检索凭证,这些标准直接集成到自动文档中。我们还将构建一个基本的用户注册和身份验证系统来保护我们的 API 端点。
最后,我们将讨论当你想从浏览器中的 Web 应用程序调用 API 时需要解决的安全挑战——特别是 CORS 和 CSRF 攻击的风险。
在本章中,我们将讨论以下主要内容:
-
FastAPI 中的安全依赖
-
获取用户并生成访问令牌
-
为经过身份验证的用户保护 API 端点
-
使用访问令牌保护端点
-
配置 CORS 并防止 CSRF 攻击
技术要求
对于本章内容,你将需要一个 Python 虚拟环境,正如我们在第一章中设置的那样,Python 开发 环境设置。
你可以在专门的 GitHub 仓库中找到本章的所有代码示例,地址是github.com/PacktPublishing/Building-Data-Science-Applications-with-FastAPI-Second-Edition/tree/main/chapter07。
FastAPI 中的安全依赖
为了保护 REST API,以及更广泛的 HTTP 端点,已经提出了许多标准。以下是最常见的一些标准的非详尽列表:
-
Authorization。该值由Basic关键字组成,后跟以Base64编码的用户凭证。这是一种非常简单的方案,但并不太安全,因为密码会出现在每个请求中。 -
Cookies:Cookies 是一个在客户端(通常是在 Web 浏览器上)存储静态数据的有用方式,这些数据会在每次请求时发送到服务器。通常,一个 cookie 包含一个会话令牌,服务器可以验证并将其与特定用户关联。
-
Authorization头:在 REST API 上下文中,可能是最常用的头部,它仅仅是通过 HTTPAuthorization头发送一个令牌。该令牌通常以方法关键字(如Bearer)为前缀。在服务器端,可以验证这个令牌并将其与特定用户关联。
每个标准都有其优缺点,并适用于特定的使用场景。
如你所知,FastAPI 主要是关于依赖注入和可调用项,它们会在运行时被自动检测并调用。身份验证方法也不例外:FastAPI 默认提供了大部分的安全依赖。
首先,让我们学习如何从任意头部检索访问令牌。为此,我们可以使用ApiKeyHeader依赖项,如下例所示:
chapter07_api_key_header.py
from fastapi import Depends, FastAPI, HTTPException, statusfrom fastapi.security import APIKeyHeader
API_TOKEN = "SECRET_API_TOKEN"
app = FastAPI()
api_key_header = APIKeyHeader(name="Token")
@app.get("/protected-route")
async def protected_route(token: str = Depends(api_key_header)):
if token != API_TOKEN:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return {"hello": "world"}
在这个简单的示例中,我们硬编码了一个令牌API_TOKEN,并检查头部传递的令牌是否等于这个令牌,之后才授权调用端点。为了做到这一点,我们使用了APIKeyHeader安全依赖项,它专门用于从头部检索值。它是一个类依赖项,可以通过参数实例化。特别地,它接受name参数,该参数保存它将要查找的头部名称。
然后,在我们的端点中,我们注入了这个依赖项来获取令牌的值。如果它等于我们的令牌常量,我们就继续执行端点逻辑。否则,我们抛出403错误。
我们在《第五章》中的路径、路由器和全局依赖项部分的示例,FastAPI 中的依赖注入,与这个示例并没有太大不同。我们只是从一个任意的头部中检索值并进行等式检查。那么,为什么要使用专门的依赖项呢?有两个原因:
-
首先,检查头部是否存在并检索其值的逻辑包含在
APIKeyHeader中。当你到达端点时,可以确定已检索到令牌值;否则,将抛出403错误。 -
第二个,可能也是最重要的,事情是它被 OpenAPI 架构检测到,并包含在其交互式文档中。这意味着使用此依赖项的端点将显示一个锁定图标,表示这是一个受保护的端点。此外,你将能够访问一个界面来输入你的令牌,如下图所示。令牌将自动包含在你从文档发出的请求中:
图 7.1 – 在交互式文档中的令牌授权
当然,你可以将检查令牌值的逻辑封装在自己的依赖项中,以便在各个端点之间重用,如下例所示:
chapter07_api_key_header_dependency.py
async def api_token(token: str = Depends(APIKeyHeader(name="Token"))): if token != API_TOKEN:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@app.get("/protected-route", dependencies=[Depends(api_token)])
async def protected_route():
return {"hello": "world"}
这种依赖关系非常适合用作路由器或全局依赖项,以保护一组路由,正如我们在第五章《FastAPI 中的依赖注入》中看到的那样。
这是为你的 API 添加授权的一个非常基本的示例。在这个示例中,我们没有用户管理;我们只是检查令牌是否与一个常量值对应。虽然它对于不打算由最终用户调用的私有微服务来说可能有用,但这种方法不应被认为非常安全。
首先,确保你的 API 始终通过 HTTPS 提供服务,以确保令牌不会在头部暴露。然后,如果这是一个私有微服务,你还应该考虑不要公开暴露它到互联网上,并确保只有受信任的服务器才能调用它。由于你不需要用户向这个服务发起请求,因此它比一个简单的令牌密钥要安全得多,因为后者可能会被盗取。
当然,大多数情况下,你会希望通过用户自己的个人访问令牌来验证真实用户,从而让他们访问自己的数据。你可能已经使用过实现这种典型模式的服务:
-
首先,你必须在该服务上注册一个账户,通常是通过提供你的电子邮件地址和密码。
-
接下来,你可以使用相同的电子邮件地址和密码登录该服务。该服务会检查电子邮件地址是否存在以及密码是否有效。
-
作为交换,服务会为你提供一个会话令牌,可以在后续请求中使用它来验证身份。这样,你就不需要在每次请求时都提供电子邮件地址和密码,这样既麻烦又危险。通常,这种会话令牌有一个有限的生命周期,这意味着一段时间后你需要重新登录。这可以减少会话令牌被盗时的安全风险。
在下一部分,你将学习如何实现这样的系统。
将用户及其密码安全地存储在数据库中
将用户实体存储在数据库中与存储任何其他实体并没有区别,你可以像在第六章《数据库和异步 ORM》中一样实现它。你必须特别小心的唯一事项就是密码存储。你绝不能将密码以明文形式存储在数据库中。为什么?如果不幸地,某个恶意的人成功进入了你的数据库,他们将能够获取所有用户的密码。由于许多人会在多个地方使用相同的密码,他们在其他应用程序和网站上的账户安全将受到严重威胁。
为了避免这种灾难,我们可以对密码应用加密哈希函数。这些函数的目标是将密码字符串转换为哈希值。设计这个的目的是让从哈希值中恢复原始数据几乎不可能。因此,即使你的数据库被入侵,密码依然安全。
当用户尝试登录时,我们只需计算他们输入的密码的哈希值,并将其与我们数据库中的哈希值进行比较。如果匹配,则意味着密码正确。
现在,让我们学习如何使用 FastAPI 和 SQLAlchemy ORM 来实现这样的系统。
创建模型
我们从为用户创建 SQLAlchemy ORM 模型开始,如下所示:
models.py
class User(Base): __tablename__ = "users"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
email: Mapped[str] = mapped_column(
String(1024), index=True, unique=True, nullable=False
)
hashed_password: Mapped[str] = mapped_column(String(1024), nullable=False)
为了简化这个示例,我们在模型中仅考虑了 ID、电子邮件地址和密码。请注意,我们对 email 列添加了唯一约束,以确保数据库中不会有重复的电子邮件。
接下来,我们可以实现相应的 Pydantic 模式:
schemas.py
class UserBase(BaseModel): email: EmailStr
class Config:
orm_mode = True
class UserCreate(UserBase):
password: str
class User(UserBase):
id: int
hashed_password: str
class UserRead(UserBase):
id: int
如你所见,UserCreate 和 User 之间有一个主要区别:前者接受我们在注册时会进行哈希处理的明文密码,而后者仅会在数据库中保留哈希后的密码。我们还会确保在 UserRead 中不包含 hashed_password,因此哈希值不会出现在 API 响应中。尽管哈希数据应当是不可解读的,但一般不建议泄露这些数据。
哈希密码
在我们查看注册端点之前,让我们先实现一些用于哈希密码的重要工具函数。幸运的是,已有一些库提供了最安全、最高效的算法来完成这项任务。在这里,我们将使用 passlib。你可以安装它以及 argon2_cffi,这是写作时最安全的哈希函数之一:
(venv) $ pip install passlib argon2_cffi
现在,我们只需要实例化 passlib 类,并封装它们的一些函数,以简化我们的工作:
password.py
from passlib.context import CryptContextpwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
def get_password_hash(password: str) -> str:
return pwd_context.hash(password)
CryptContext 是一个非常有用的类,因为它允许我们使用不同的哈希算法。如果有一天,出现比 argon2 更好的算法,我们只需将其添加到我们的允许的模式中。新密码将使用新算法进行哈希,但现有密码仍然可以识别(并可选择升级为新算法)。
实现注册路由
现在,我们具备了创建合适注册路由的所有要素。再次强调,它将与我们之前看到的非常相似。唯一需要记住的是,在将密码插入到数据库之前,我们必须先对其进行哈希处理。
让我们看一下实现:
app.py
@app.post( "/register", status_code=status.HTTP_201_CREATED, response_model=schemas.UserRead
)
async def register(
user_create: schemas.UserCreate, session: AsyncSession = Depends(get_async_session)
) -> User:
hashed_password = get_password_hash(user_create.password)
user = User(
*user_create.dict(exclude={"password"}), hashed_password=hashed_password
)
try:
session.add(user)
await session.commit()
except exc.IntegrityError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Email already exists"
)
return user
如你所见,我们在将用户插入数据库之前,对输入的密码调用了get_password_hash。请注意,我们捕获了可能出现的exc.IntegrityError异常,这意味着我们正在尝试插入一个已存在的电子邮件。
此外,请注意我们设置了response_model为UserRead。通过这样做,我们确保hashed_password不会出现在输出中。
太棒了!我们现在有了一个合适的用户模型,用户可以通过我们的 API 创建新账户。下一步是允许用户登录并为其提供访问令牌。
获取用户并生成访问令牌
在成功注册后,下一步是能够登录:用户将发送其凭证并接收一个身份验证令牌,以访问 API。在这一部分,我们将实现允许此操作的端点。基本上,我们将从请求有效载荷中获取凭证,使用给定的电子邮件检索用户并验证其密码。如果用户存在且密码有效,我们将生成一个访问令牌并将其返回在响应中。
实现数据库访问令牌
首先,让我们思考一下这个访问令牌的性质。它应该是一个数据字符串,能够唯一标识一个用户,并且无法被恶意第三方伪造。在这个示例中,我们将采用一种简单但可靠的方法:我们将生成一个随机字符串,并将其存储在数据库中的专用表中,同时设置外键引用到用户。
这样,当一个经过身份验证的请求到达时,我们只需检查令牌是否存在于数据库中,并寻找相应的用户。这个方法的优势是令牌是集中管理的,如果它们被泄露,可以轻松作废;我们只需要从数据库中删除它们。
第一步是为这个新实体实现 SQLAlchemy ORM 模型:
models.py
class AccessToken(Base): __tablename__ = "access_tokens"
access_token: Mapped[str] = mapped_column(
String(1024), primary_key=True, default=generate_token
)
user_id: Mapped[int] = mapped_column(ForeignKey("users.id"), nullable=False)
expiration_date: Mapped[datetime] = mapped_column(
DateTime, nullable=False, default=get_expiration_date
)
user: Mapped[User] = relationship("User", lazy="joined")
我们定义了三个列:
-
access_token:这是将在请求中传递以进行身份验证的字符串。请注意,我们将generate_token函数定义为默认工厂;它是一个简单的先前定义的函数,用于生成随机安全密码。在底层,它依赖于标准的secrets模块。 -
user_id:指向users表的外键,用于标识与此令牌对应的用户。 -
expiration_date:访问令牌将到期并且不再有效的日期和时间。为访问令牌设置到期日期总是一个好主意,以减轻其被盗的风险。在这里,get_expiration_date工厂设置了默认的有效期为 24 小时。
我们还不要忘记定义关系,这样我们可以直接从访问令牌对象访问用户实体。请注意,我们默认设置了一种急加载策略,因此在查询访问令牌时始终检索用户。如果需要其背后的原理,请参阅第六章**,数据库和异步 ORM中的添加关系部分。
在这里我们不需要 Pydantic 模式,因为访问令牌将通过特定方法创建和序列化。
实现登录端点
现在,让我们考虑登录端点。其目标是接收请求有效载荷中的凭据,检索相应的用户,检查密码并生成新的访问令牌。除了一个事项外,它的实现非常直接:用于处理请求的模型。通过下面的示例你将明白为什么:
app.py
@app.post("/token")async def create_token(
form_data: OAuth2PasswordRequestForm = Depends(OAuth2PasswordRequestForm),
session: AsyncSession = Depends(get_async_session),
):
email = form_data.username
password = form_data.password
user = await authenticate(email, password, session)
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
token = await create_access_token(user, session)
return {"access_token": token.access_token, "token_type": "bearer"}
正如您所见,我们通过 FastAPI 的安全模块中提供的OAuth2PasswordRequestForm模块检索请求数据。它期望在表单编码中有几个字段,特别是username和password,而不是 JSON。
为什么我们要使用这个类?使用这个类的主要好处是它完全集成到 OpenAPI 模式中。这意味着交互式文档能够自动检测到它,并在授权按钮后显示适当的身份验证表单,如下面的截图所示:
图 7.2 – 交互式文档中的 OAuth2 授权
但这还不是全部:它还能自动获取返回的访问令牌,并在后续请求中设置正确的授权头。身份验证过程由交互式文档透明处理。
这个类遵循 OAuth2 协议,这意味着你还需要包含客户端 ID 和密钥字段。我们不会在这里学习如何实现完整的 OAuth2 协议,但请注意,FastAPI 提供了所有正确实现它所需的工具。对于我们的项目,我们将只使用用户名和密码。请注意,根据协议,字段被命名为用户名,无论我们是使用电子邮件地址来识别用户与否。这不是大问题,我们只需在获取它时记住这一点。
剩下的路径操作函数相当简单:首先,我们尝试根据电子邮件和密码获取用户。如果没有找到相应的用户,我们将抛出401错误。否则,我们会在返回之前生成一个新的访问令牌。请注意,响应结构中还包括token_type属性。这使得交互式文档能够自动设置授权头。
在下面的示例中,我们将查看authenticate和create_access_token函数的实现。我们不会深入细节,因为它们非常简单:
authentication.py
async def authenticate(email: str, password: str, session: AsyncSession) -> User | None: query = select(User).where(User.email == email)
result = await session.execute(query)
user: User | None = result.scalar_one_or_none()
if user is None:
return None
if not verify_password(password, user.hashed_password):
return None
return user
async def create_access_token(user: User, session: AsyncSession) -> AccessToken:
access_token = AccessToken(user=user)
session.add(access_token)
await session.commit()
return access_token
请注意,我们定义了一个名为verify_password的函数来检查密码的有效性。再一次,它在后台使用passlib,该库负责比较密码的哈希值。
密码哈希升级
为了简化示例,我们实现了一个简单的密码比较。通常,最好在这个阶段实现一个机制来升级密码哈希。假设引入了一个新的、更强大的哈希算法。我们可以借此机会使用这个新算法对密码进行哈希处理并将其存储在数据库中。passlib包含一个函数,可以在一次操作中验证和升级哈希。你可以通过以下文档了解更多内容:passlib.readthedocs.io/en/stable/narr/context-tutorial.html#integrating-hash-migration。
我们几乎达成了目标!用户现在可以登录并获取新的访问令牌。接下来,我们只需要实现一个依赖项来检索Authorization头并验证这个令牌!
使用访问令牌保护端点
之前,我们学习了如何实现一个简单的依赖项来保护带有头部的端点。在这里,我们也会从请求头中获取令牌,但接下来,我们需要检查数据库,看看它是否有效。如果有效,我们将返回相应的用户。
让我们看看我们的依赖项是什么样子的:
app.py
async def get_current_user( token: str = Depends(OAuth2PasswordBearer(tokenUrl="/token")),
session: AsyncSession = Depends(get_async_session),
) -> User:
query = select(AccessToken).where(
AccessToken.access_token == token,
AccessToken.expiration_date >= datetime.now(tz=timezone.utc),
)
result = await session.execute(query)
access_token: AccessToken | None = result.scalar_one_or_none()
if access_token is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
return access_token.user
首先需要注意的是,我们使用了来自 FastAPI 的 OAuth2PasswordBearer 依赖项。它与我们在前一节中看到的 OAuth2PasswordRequestForm 配合使用。它不仅检查 Authorization 头中的访问令牌,还告知 OpenAPI 架构获取新令牌的端点是 /token。这就是 tokenUrl 参数的目的。通过这一点,自动化文档可以自动调用我们之前看到的登录表单中的访问令牌端点。
然后我们使用 SQLAlchemy 执行了数据库查询。我们应用了两个条件:一个用于匹配我们获得的令牌,另一个确保过期时间是在未来。如果在数据库中找不到相应的记录,我们会抛出一个 401 错误。否则,我们会返回与访问令牌相关的用户。
就这样!我们的整个身份验证系统完成了。现在,我们可以通过简单地注入这个依赖项来保护我们的端点。我们甚至可以访问用户数据,从而根据当前用户量身定制响应。你可以在以下示例中看到这一点:
app.py
@app.get("/protected-route", response_model=schemas.UserRead)async def protected_route(user: User = Depends(get_current_user)):
return user
至此,你已经学会了如何从头开始实现完整的注册和身份验证系统。我们故意保持其简单,以便专注于最重要的点,但这为你扩展提供了一个良好的基础。
我们在这里展示的模式是适合 REST API 的良好范例,这些 API 是由其他客户端程序外部调用的。然而,你可能希望通过一个非常常见的软件来调用你的 API:浏览器。在这种情况下,有一些额外的安全考虑需要处理。
配置 CORS 并防止 CSRF 攻击
如今,许多软件都被设计为通过使用 HTML、CSS 和 JavaScript 构建的界面在浏览器中使用。传统上,Web 服务器负责处理浏览器请求并返回 HTML 响应,以供用户查看。这是 Django 等框架的常见用例。
近年来,随着 JavaScript 框架如 Angular、React 和 Vue 的出现,这一模式正在发生变化。我们现在往往会看到前端和后端的明确分离,前端是一个由 JavaScript 驱动的高度互动的用户界面,后端则负责数据存储、检索以及执行业务逻辑。这是 REST API 擅长的任务!从 JavaScript 代码中,用户界面可以向你的 API 发送请求并处理结果进行展示。
然而,我们仍然需要处理身份验证:我们希望用户能够登录前端应用,并向 API 发送经过身份验证的请求。虽然如我们到目前为止看到的Authorization头可以工作,但在浏览器中处理身份验证有一个更好的方法:Cookies!
Cookies(浏览器 Cookie)旨在将用户信息存储在浏览器内存中,并在每次请求发送到你的服务器时自动发送。多年来它们得到了支持,浏览器也集成了很多机制来确保它们的安全和可靠。
然而,这也带来了一些安全挑战。网站是黑客的常见攻击目标,多年来已经出现了很多攻击方式。
最典型的攻击方式之一是跨站请求伪造(CSRF)。在这种情况下,攻击者会在另一个网站上尝试欺骗当前已在你的应用程序中认证的用户,向你的服务器发起请求。由于浏览器通常会在每次请求时发送 Cookie,你的服务器无法识别出请求实际上是伪造的。由于这些恶意请求是用户自己无意中发起的,因此这类攻击并不旨在窃取数据,而是执行改变应用状态的操作,比如更改电子邮件地址或进行转账。
显然,我们应该为这些风险做好准备,并采取措施来减轻它们。
理解 CORS 并在 FastAPI 中进行配置
当你有一个明确分离的前端应用和 REST API 后端时,它们通常不会来自同一个子域。例如,前端可能来自www.myapplication.com,而 REST API 来自api.myapplication.com。正如我们在介绍中提到的,我们希望从前端应用程序通过 JavaScript 向该 API 发起请求。
然而,浏览器不允许跨域****资源共享(CORS)的HTTP 请求,即域 A 无法向域 B 发起请求。这遵循了所谓的同源策略。一般来说,这是一个好事,因为它是防止 CSRF 攻击的第一道屏障。
为了体验这种行为,我们将运行一个简单的例子。在我们的示例仓库中,chapter07/cors 文件夹包含一个名为 app_without_cors.py 的 FastAPI 应用程序和一个简单的 HTML 文件 index.html,该文件包含一些用于执行 HTTP 请求的 JavaScript。
首先,让我们使用通常的 uvicorn 命令运行 FastAPI 应用程序:
(venv) $ uvicorn chapter07.cors.app_without_cors:app
这将默认启动 FastAPI 应用程序,端口为 8000。在另一个终端中,我们将使用内置的 Python HTTP 服务器提供 HTML 文件。它是一个简单的服务器,但非常适合快速提供静态文件。我们可以通过以下命令在端口 9000 启动它:
(venv) $ python -m http.server --directory chapter07/cors 9000
启动多个终端
在 Linux 和 macOS 上,你可以通过创建一个新的窗口或标签页来启动一个新的终端。在 Windows 和 WSL 上,如果你使用 Windows 终端应用程序,也可以有多个标签页:apps.microsoft.com/store/detail/windows-terminal/9N0DX20HK701。
否则,你可以简单地点击 开始 菜单中的 Ubuntu 快捷方式来启动另一个终端。
现在我们有两个正在运行的服务器——一个在 localhost:8000,另一个在 localhost:9000。严格来说,由于它们在不同的端口上,它们属于不同的源;因此,这是一个很好的设置,可以尝试跨源 HTTP 请求。
在你的浏览器中,访问 http://localhost:9000。你会看到在 index.html 中实现的简单应用程序,如下图所示:
图 7.3 – 尝试 CORS 策略的简单应用
有两个按钮,可以向我们的 FastAPI 应用程序发起 GET 和 POST 请求,端口为 8000。如果点击其中任意一个按钮,你会在错误区域看到一条消息,显示 获取失败。如果你查看开发者工具中的浏览器控制台,你会发现请求失败的原因是没有 CORS 策略,正如下图所示。这正是我们想要的——默认情况下,浏览器会阻止跨源 HTTP 请求:
图 7.4 – 浏览器控制台中的 CORS 错误
但是,如果你查看正在运行 FastAPI 应用程序的终端,你会看到类似下面的输出:
图 7.5 – 执行简单请求时的 Uvicorn 输出
显然,GET 和 POST 请求已经接收并处理:我们甚至返回了 200 状态。那么,这意味着什么呢?在这种情况下,浏览器确实会将请求发送到服务器。缺乏 CORS 策略只会禁止它读取响应;请求仍然会执行。
这是浏览器认为是 GET、POST 或 HEAD 方法的请求,它们不设置自定义头部或不使用不常见的内容类型的情况。你可以通过访问以下 MDN 页面了解更多关于简单请求及其条件的信息:developer.mozilla.org/en-US/docs/Web/HTTP/CORS#simple_requests。
这意味着,对于简单请求来说,相同来源策略不足以保护我们免受 CSRF 攻击。
你可能已经注意到,我们的简单 Web 应用再次提供了 GET 和 POST 请求的切换功能。在你的 FastAPI 终端中,应该会看到类似于以下的输出:
图 7.6 – Uvicorn 在接收到预检请求时的输出
如你所见,我们的服务器接收到了两个奇怪的 OPTIONS 请求。这是我们所说的具有 application/json 值的 Content-Type 头部,这违反了简单请求的条件。
通过执行这个预检请求,浏览器期望服务器提供有关它可以和不可以执行的跨源 HTTP 请求的信息。由于我们这里没有实现任何内容,我们的服务器无法对这个预检请求作出响应。因此,浏览器在这里停止,并且不会继续执行实际请求。
这基本上就是 CORS:服务器用一组 HTTP 头部响应预检请求,提供关于浏览器是否可以发起请求的信息。从这个意义上说,CORS 并不会让你的应用更安全,恰恰相反:它放宽了一些规则,使得前端应用可以向另一个域上的后端发起请求。因此,正确配置 CORS 至关重要,以免暴露于危险的攻击之下。
幸运的是,使用 FastAPI 这非常简单。我们需要做的就是导入并添加由 Starlette 提供的 CORSMiddleware 类。你可以在以下示例中看到它的实现:
app_with_cors.py
app.add_middleware( CORSMiddleware,
allow_origins=["http://localhost:9000"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
max_age=-1, # Only for the sake of the example. Remove this in your own project.
)
中间件是一种特殊的类,它将全局逻辑添加到 add_middleware 方法中,将这种中间件集成到你的应用中。
在这里,CORSMiddleware 会捕获浏览器发送的预检请求,并返回带有与你配置相对应的 CORS 头部的适当响应。你可以看到,有选项可以精细调整 CORS 策略以满足你的需求。
最重要的可能是 allow_origins,它是允许向你的 API 发起请求的源列表。由于我们的 HTML 应用是从 http://localhost:9000 提供的,因此我们在此参数中填写该地址。如果浏览器尝试从任何其他源发出请求,它将被阻止,因为 CORS 头部不允许。
另一个有趣的参数是 allow_credentials。默认情况下,浏览器不会为跨域 HTTP 请求发送 cookies。如果我们希望向 API 发出认证请求,需要通过此选项来允许此操作。
我们还可以精细调节请求中允许的 HTTP 方法和头部。你可以在官方 Starlette 文档中找到此中间件的完整参数列表:www.starlette.io/middleware/#corsmiddleware。
让我们简要讨论一下 max_age 参数。此参数允许你控制 CORS 响应的缓存时长。在实际请求之前执行预检请求是一个昂贵的操作。为了提高性能,浏览器可以缓存响应,以避免每次都执行此操作。在此,我们将缓存禁用,设置值为 -1,以确保你在这个示例中看到浏览器的行为。在你的项目中,可以删除此参数,以便设置适当的缓存值。
现在,让我们看看启用了 CORS 的应用程序如何在我们的 Web 应用中表现。停止之前的 FastAPI 应用,并使用常规命令运行此应用:
(venv) $ uvicorn chapter07.cors.app_with_cors:app
现在,如果你尝试从 HTML 应用执行请求,你应该会在每种情况下看到有效的响应,无论是否使用 JSON 内容类型。如果你查看 FastAPI 的终端,你应该会看到类似于以下内容的输出:
图 7.7 – 启用 CORS 头的 Uvicorn 输出
前两个请求是“简单请求”,根据浏览器规则,这些请求无需预检请求。接着,我们可以看到启用了 JSON 内容类型的请求。在 GET 和 POST 请求之前,执行了一个 OPTIONS 请求:即预检请求!
多亏了这个配置,你现在可以在前端应用和位于另一个源的后端之间进行跨域 HTTP 请求。再次强调,这并不能提升应用的安全性,但它允许你在确保应用其余部分安全的同时,使这个特定场景得以正常运行。
即便这些策略可以作为抵御 CSRF 的第一道防线,但并不能完全消除风险。事实上,“简单请求”仍然是一个问题:POST 请求是允许的,尽管响应不能被读取,但实际上它是在服务器上执行的。
现在,让我们学习如何实现一种模式,以确保我们完全避免此类攻击:双重提交 Cookie。
实现双重提交 Cookie 以防止 CSRF 攻击
如前所述,当依赖 Cookies 存储用户凭据时,我们容易遭受 CSRF 攻击,因为浏览器会自动将 Cookie 发送到你的服务器。这对于浏览器认为的“简单请求”尤其如此,因为在请求执行之前不会强制执行 CORS 策略。还有其他攻击向量,例如传统的 HTML 表单提交,甚至是图片标签的 src 属性。
由于这些原因,我们需要额外的安全层来缓解这种风险。再次强调,这仅在你计划通过浏览器应用使用 API 并使用 Cookies 进行身份验证时才是必要的。
为了帮助你理解这一点,我们构建了一个新的示例应用程序,使用 Cookie 存储用户访问令牌。它与我们在本章开头看到的应用非常相似;我们只是修改了它,使其从 Cookie 中获取访问令牌,而不是从请求头中获取。
为了使这个示例生效,你需要安装 starlette-csrf 库。我们稍后会解释它的作用。现在,只需运行以下命令:
(venv) $ pip install starlette-csrf
在以下示例中,你可以看到设置了包含访问令牌值的 Cookie 的登录端点:
app.py
@app.post("/login")async def login(
response: Response,
email: str = Form(...),
password: str = Form(...),
session: AsyncSession = Depends(get_async_session),
):
user = await authenticate(email, password, session)
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
token = await create_access_token(user, session)
response.set_cookie(
TOKEN_COOKIE_NAME,
token.access_token,
max_age=token.max_age(),
secure=True,
httponly=True,
samesite="lax",
)
请注意,我们为生成的 Cookie 使用了 Secure 和 HttpOnly 标志。这确保了该 Cookie 仅通过 HTTPS 发送,并且其值不能通过 JavaScript 读取。虽然这不足以防止所有类型的攻击,但对于这种敏感信息来说至关重要。
除此之外,我们还将 SameSite 标志设置为 lax。这是一个相对较新的标志,允许我们控制 Cookie 在跨源上下文中如何发送。lax 是大多数浏览器中的默认值,它允许将 Cookie 发送到 Cookie 域的子域名,但不允许发送到其他站点。从某种意义上讲,它是为防范 CSRF 攻击设计的标准内置保护。然而,目前仍然需要其他 CSRF 缓解技术,比如我们将在此实现的技术。实际上,仍有一些旧版浏览器不兼容 SameSite 标志,依然存在漏洞。
现在,当检查已认证用户时,我们只需从请求中发送的 Cookie 中提取令牌。再次强调,FastAPI 提供了一个安全依赖项,帮助实现这一功能,名为 APIKeyCookie。你可以在以下示例中看到它:
app.py
async def get_current_user( token: str = Depends(APIKeyCookie(name=TOKEN_COOKIE_NAME)),
session: AsyncSession = Depends(get_async_session),
) -> User:
query = select(AccessToken).where(
AccessToken.access_token == token,
AccessToken.expiration_date >= datetime.now(tz=timezone.utc),
)
result = await session.execute(query)
access_token: AccessToken | None = result.scalar_one_or_none()
if access_token is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)
return access_token.user
基本上就是这样!其余的代码保持不变。现在,让我们实现一个端点,允许我们更新经过身份验证的用户的电子邮件地址。您可以在以下示例中看到:
app.py
@app.post("/me", response_model=schemas.UserRead)async def update_me(
user_update: schemas.UserUpdate,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_async_session),
):
user_update_dict = user_update.dict(exclude_unset=True)
for key, value in user_update_dict.items():
setattr(user, key, value)
session.add(user)
await session.commit()
return user
这个实现并不令人惊讶,遵循了我们到目前为止所见的方式。然而,它使我们暴露于 CSRF 威胁中。如您所见,它使用了POST方法。如果我们在浏览器中向该端点发出没有任何特殊头部的请求,它会将其视为普通请求并执行。因此,攻击者可能会更改当前已验证用户的电子邮件地址,这是一个重大威胁。
这正是我们在这里需要 CSRF 保护的原因。在 REST API 的上下文中,最直接的技术是双重提交 cookie 模式。其工作原理如下:
-
用户首先发出一个被认为是安全的方法的请求,通常是一个
GET请求。 -
在响应中,它接收一个包含随机秘密值的 cookie——即 CSRF 令牌。
-
当发出不安全请求时,例如
POST,用户会从 cookie 中读取 CSRF 令牌,并将相同的值放入请求头中。由于浏览器还会发送内存中存储的 cookie,请求将同时在 cookie 和请求头中包含该令牌。这就是为什么称之为双重提交。 -
在处理请求之前,服务器将比较请求头中提供的 CSRF 令牌与 cookie 中存在的令牌。如果匹配,它将继续处理请求。否则,它将抛出一个错误。
这是安全的,原因有二:
-
针对第三方网站的攻击者无法读取他们没有所有权的域名的 cookie。因此,他们无法检索到 CSRF 令牌的值。
-
添加自定义头部违反了“简单请求”的条件。因此,浏览器在发送请求之前必须进行预检请求,从而强制执行 CORS 策略。
这是一个广泛使用的模式,在防止此类风险方面效果良好。这也是为什么我们在本节开始时安装了starlette-csrf:它提供了一个中间件来实现这一点。
我们可以像使用其他中间件一样使用它,以下示例演示了这一点:
app.py
app.add_middleware( CSRFMiddleware,
secret=CSRF_TOKEN_SECRET,
sensitive_cookies={TOKEN_COOKIE_NAME},
cookie_domain="localhost",
)
我们在这里设置了几个重要的参数。首先,我们有一个密钥,它应该是一个强密码,用于签名 CSRF 令牌。然后,我们有sensitive_cookies,这是一个包含应该触发 CSRF 保护的 cookie 名称的集合。如果没有 cookie,或者提供的 cookie 不是关键性的,我们可以绕过 CSRF 检查。如果你有其他的认证方法(如不依赖于 cookie 的授权头),这也很有用,因为这些方法不容易受到 CSRF 攻击。最后,设置 cookie 域名将允许你在不同的子域上获取包含 CSRF 令牌的 cookie;这在跨源情况下是必要的。
这就是你需要准备的必要保护。为了简化获取新 CSRF 令牌的过程,我们实现了一个最小的 GET 端点,叫做/csrf。它的唯一目的是提供一个简单的方式来设置 CSRF 令牌的 cookie。我们可以在加载前端应用时直接调用它。
现在,让我们在我们的环境中试试。正如我们在上一节中所做的那样,我们将会在两个不同的端口上运行 FastAPI 应用程序和简单的 HTML 应用程序。为此,只需运行以下命令:
(venv) $ uvicorn chapter07.csrf.app:app
这将会在8000端口上运行 FastAPI 应用程序。现在,运行以下命令:
(venv) $ python -m http.server --directory chapter07/csrf 9000
前端应用程序现在可以在http://localhost:9000访问。打开它在浏览器中,你应该看到一个类似于以下界面的界面:
图 7.8 – 尝试 CSRF 保护 API 的简单应用
在这里,我们添加了表单来与 API 端点交互:注册、登录获取认证用户,以及更新端点。如果你尝试这些,它们应该没有问题。如果你查看发送的请求,可以看到x-csrftoken中包含了 CSRF 令牌。
在顶部,有一个开关可以防止应用程序在头部发送 CSRF 令牌。如果你禁用它,你会看到所有的POST操作都会导致错误。
太好了!我们现在已经防止了 CSRF 攻击!这里的大部分工作是由中间件完成的,但理解它是如何在后台工作的,以及它如何保护你的应用程序,是很有意思的。然而,请记住,它有一个缺点:它会破坏交互式文档。实际上,它并没有设计成从 cookie 中检索 CSRF 令牌并将其放入每个请求的头部。除非你计划以其他方式进行认证(例如通过头部中的令牌),否则你将无法在文档中直接调用你的端点。
总结
本章内容就到这里,主要介绍了 FastAPI 中的认证和安全性。我们看到,借助 FastAPI 提供的工具,实现一个基本的认证系统是相当简单的。我们展示了一种实现方法,但还有许多其他不错的模式可以用来解决这个问题。然而,在处理这些问题时,始终要牢记安全性,并确保不会将应用程序和用户数据暴露于危险的威胁之中。特别地,我们已经看到,在设计将在浏览器应用中使用的 REST API 时,必须考虑防止 CSRF 攻击。理解 Web 应用程序中所有安全风险的一个好资源是 OWASP Cheat Sheet 系列:cheatsheetseries.owasp.org。
至此,我们已经涵盖了关于 FastAPI 应用开发的大部分重要主题。在下一章中,我们将学习如何使用与 FastAPI 集成的最新技术——WebSockets,它允许客户端和服务器之间进行实时、双向通信。