一、基础类型提示:从简单到复杂
基本类型与容器
from typing import List, Dict, Tuple, Set, Optional, Union, Any
from collections.abc import Sequence, Mapping, Iterable
from dataclasses import dataclass
import datetime
def process_numbers(numbers: List[int]) -> float:
    """Return the arithmetic mean of ``numbers`` (typing.List annotation style).

    An empty list yields 0.0 instead of raising ZeroDivisionError.
    """
    if not numbers:
        return 0.0
    return sum(numbers) / len(numbers)
def process_numbers_modern(numbers: list[int]) -> float:
    """Return the mean of ``numbers``, annotated with the builtin ``list[int]`` (3.9+).

    An empty list yields 0.0.
    """
    count = len(numbers)
    return sum(numbers) / count if count else 0.0
def find_item(items: List[str], target: str) -> Optional[int]:
    """Return the index of the first occurrence of ``target``, or None if it is absent."""
    try:
        position: Optional[int] = items.index(target)
    except ValueError:
        # list.index signals "not found" with ValueError; map that to None.
        position = None
    return position
def parse_value(value: Union[str, int, float]) -> float:
    """Convert a str, int or float input to float.

    ``float()`` already accepts all three input types, so no per-type
    branching is needed.

    Raises:
        ValueError: if ``value`` is a string that is not a valid float literal.
    """
    return float(value)
from typing import Literal
# Only these five method strings type-check as an HttpMethod (not enforced at runtime).
HttpMethod = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]


def make_request(
    method: HttpMethod,
    url: str,
    data: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Pretend to send a request: log the method and URL, then return a fixed success dict.

    ``data`` is accepted for signature completeness but is not used.
    """
    message = f"Making {method} request to {url}"
    print(message)
    response: Dict[str, Any] = {"status": "success"}
    return response
Vector = List[float]
Matrix = List[Vector]


def matrix_multiply(a: Matrix, b: Matrix) -> Matrix:
    """Return the matrix product ``a @ b`` of two row-major matrices.

    Args:
        a: An m x n matrix (list of m rows, each of length n).
        b: An n x p matrix (list of n rows, each of length p).

    Returns:
        The m x p product as a new list of rows.

    Raises:
        ValueError: if ``b`` is ragged or a row of ``a`` does not have
            exactly ``len(b)`` entries. Without this check, surplus columns
            of ``a`` were silently ignored and the result was wrong.
    """
    inner = len(b)
    width = len(b[0]) if b else 0
    if any(len(row) != width for row in b):
        raise ValueError("b must be rectangular (all rows the same length)")
    for index, row in enumerate(a):
        if len(row) != inner:
            raise ValueError(
                f"incompatible shapes: row {index} of a has {len(row)} "
                f"entries but b has {inner} rows"
            )
    # Transpose b once so each product entry is a dot product of a row and a column.
    columns = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in columns] for row in a]
# Demo: annotated module-level variables feeding process_numbers.
nums: List[int] = [1, 2, 3, 4, 5]
result: float = process_numbers(nums)
print(f"平均值: {result}")
二、泛型与类型变量:创建可重用的类型
自定义泛型类与泛型函数
from typing import TypeVar, Generic, Iterator
from abc import ABC, abstractmethod
# Unconstrained type variables: each can be solved to any type.
T = TypeVar('T')
U = TypeVar('U')
# Constrained type variable: solved to exactly one of int, float or complex.
Num = TypeVar('Num', int, float, complex)


class Stack(Generic[T]):
    """Last-in, first-out container whose items all share the type ``T``."""

    def __init__(self) -> None:
        # Backing list; its last element is the top of the stack.
        self._items: List[T] = []

    def push(self, item: T) -> None:
        """Put ``item`` on top of the stack."""
        self._items.append(item)

    def pop(self) -> T:
        """Remove and return the top item.

        Raises:
            IndexError: if the stack is empty.
        """
        if len(self._items) == 0:
            raise IndexError("Pop from empty stack")
        return self._items.pop()

    def peek(self) -> T:
        """Return the top item without removing it.

        Raises:
            IndexError: if the stack is empty.
        """
        if len(self._items) == 0:
            raise IndexError("Peek from empty stack")
        top = self._items[-1]
        return top

    def __len__(self) -> int:
        """Number of items currently on the stack."""
        return len(self._items)

    def __iter__(self) -> Iterator[T]:
        """Iterate from the bottom of the stack to the top."""
        return iter(self._items)
# Parameterising the generic class: the type checker now checks push()
# arguments and infers peek()/pop() results as int.
int_stack: Stack[int] = Stack()
int_stack.push(1)
int_stack.push(2)
print(f"栈顶: {int_stack.peek()}")
# The same class specialised for str.
str_stack: Stack[str] = Stack()
str_stack.push("hello")
str_stack.push("world")
def first_item(items: Sequence[T]) -> T:
    """Return the element at index 0; its static type is the sequence's element type.

    Indexing an empty list or tuple raises IndexError.
    """
    head = items[0]
    return head
# Separate type variables for keys and values, so the return type swaps them.
K = TypeVar('K')
V = TypeVar('V')


def reverse_dict(d: Dict[K, V]) -> Dict[V, K]:
    """Return a new dict that maps each value of ``d`` back to its key.

    Values must be hashable. If several keys share a value, the key that
    comes last in ``d``'s iteration order wins.
    """
    return dict(zip(d.values(), d.keys()))
def add_numbers(a: Num, b: Num) -> Num:
    """Add two values typed with the constrained type variable ``Num``.

    ``Num`` can only be solved to int, float or complex, so non-numeric
    arguments are rejected by the type checker (there is no runtime check).
    """
    total = a + b
    return total


print(add_numbers(1, 2))        # Num solved as int
print(add_numbers(3.14, 2.5))   # Num solved as float
三、Protocol:结构性子类型
定义和使用协议
from typing import Protocol, runtime_checkable
from dataclasses import dataclass
class Drawable(Protocol):
    """Structural interface: any object with draw(), a read-only ``area`` and scale().

    Classes satisfy it by shape alone; they do not need to inherit from it.
    """

    def draw(self) -> None:
        """Render the object."""

    @property
    def area(self) -> float:
        """Area of the object."""

    def scale(self, factor: float) -> None:
        """Resize the object in place by ``factor``."""
@dataclass
class Circle:
    """Circle that structurally satisfies the Drawable protocol."""

    radius: float

    def draw(self) -> None:
        """Print a description of the circle."""
        print(f"绘制圆形,半径: {self.radius}")

    @property
    def area(self) -> float:
        """Area pi * r**2, using math.pi rather than the 3.14159 approximation."""
        import math

        return math.pi * self.radius ** 2

    def scale(self, factor: float) -> None:
        """Multiply the radius by ``factor`` in place."""
        self.radius *= factor
@dataclass
class Rectangle:
    """Rectangle that structurally satisfies the Drawable protocol."""

    width: float
    height: float

    def draw(self) -> None:
        """Print a description of the rectangle."""
        label = f"绘制矩形,{self.width}×{self.height}"
        print(label)

    @property
    def area(self) -> float:
        """Width times height."""
        w, h = self.width, self.height
        return w * h

    def scale(self, factor: float) -> None:
        """Multiply both sides by ``factor`` in place."""
        self.width, self.height = self.width * factor, self.height * factor
def render_all(shapes: List[Drawable]) -> None:
    """Draw each shape in order, printing its area right after drawing it."""
    for item in shapes:
        item.draw()
        area = item.area
        print(f"面积: {area}")
@runtime_checkable
class Serializable(Protocol):
    """Protocol for objects with ``to_json`` and a ``from_json`` classmethod.

    ``@runtime_checkable`` lets ``isinstance`` test for the presence of these
    members (names only, not signatures).
    """

    def to_json(self) -> str:
        """Serialise the object to a JSON string."""

    @classmethod
    def from_json(cls, json_str: str) -> 'Serializable':
        """Build an instance from a JSON string."""
def save_objects(objects: List[Serializable]) -> str:
    """Serialise each object with ``to_json`` and return a JSON array of those strings.

    Each element is the string returned by ``to_json``, so the result is
    a JSON array of JSON-encoded strings.
    """
    import json

    encoded = [item.to_json() for item in objects]
    return json.dumps(encoded)
class EventHandler(Protocol):
    """Any callable that accepts an event dict and returns None."""

    def __call__(self, event: Dict[str, Any]) -> None:
        """Handle ``event``."""


class EventDispatcher:
    """Registry of handlers; dispatch() calls every handler with the event."""

    def __init__(self) -> None:
        # Handlers in registration order.
        self._handlers: List[EventHandler] = list()

    def add_handler(self, handler: EventHandler) -> None:
        """Register ``handler``. Registering it twice makes it run twice."""
        self._handlers += [handler]

    def dispatch(self, event: Dict[str, Any]) -> None:
        """Call each registered handler with ``event``, in registration order.

        An exception raised by a handler propagates and skips the remaining handlers.
        """
        for callback in self._handlers:
            callback(event)
# Circle and Rectangle never inherit from Drawable; they type-check as
# Drawable only because they provide matching draw/area/scale members.
shapes: List[Drawable] = [
    Circle(5.0),
    Rectangle(3.0, 4.0)
]
render_all(shapes)
四、TypedDict:类型化的字典
定义结构化字典
from typing import TypedDict, NotRequired, Required
from datetime import datetime
# Functional TypedDict syntax: equivalent to the class form. All keys are
# required because total defaults to True. 'created_at' holds the datetime
# class imported via `from datetime import datetime`.
UserDict = TypedDict('UserDict', {
    'id': int,
    'username': str,
    'email': str,
    'created_at': datetime,
    'is_active': bool,
})
class ProductDict(TypedDict, total=False):
id: NotRequired[int]
name: Required[str]
price: float
description: NotRequired[str]
tags: NotRequired[List[str]]
class BaseConfig(TypedDict):
    """Connection endpoint; both keys are required (total defaults to True)."""

    host: str
    port: int


class DatabaseConfig(BaseConfig, total=False):
    """Database settings.

    ``host`` and ``port`` stay required because they are inherited from the
    total ``BaseConfig``. Every key declared here is optional because of
    ``total=False``, so ``pool_size`` needs no ``NotRequired`` marker.
    """

    database: str
    username: str
    password: str
    pool_size: int
def create_user(data: UserDict) -> None:
    """Print a short summary of a user record: username, email and active flag.

    Each key is read just before its line is printed.
    """
    username = data['username']
    print(f"创建用户: {username}")
    email = data['email']
    print(f"邮箱: {email}")
    active = data['is_active']
    print(f"活跃: {active}")
# A dict literal checked against UserDict: the class is total, so every
# key must be present, and each value must match its declared type.
user_data: UserDict = {
    'id': 1,
    'username': 'alice',
    'email': 'alice@example.com',
    'created_at': datetime.now(),
    'is_active': True
}
create_user(user_data)
def create_response_typed(
success: bool,
data: Optional[Dict[str, Any]] = None,
message: Optional[str] = None
) -> Dict[str, Any]:
"""创建API响应"""
ResponseDict = TypedDict('ResponseDict', {
'success': bool,
'data': NotRequired[Dict[str, Any]],
'message': NotRequired[str],
'timestamp': datetime
})
response: ResponseDict = {
'success': success,
'timestamp': datetime.now()
}
if data:
response['data'] = data
if message:
response['message'] = message
return response
五、高级类型:重载、泛型和自引用
类型重载与自引用类型
from typing import overload, TypeVar, Type, Any
from decimal import Decimal
@overload
def parse_number(value: str) -> Decimal: ...
@overload
def parse_number(value: int) -> int: ...
@overload
def parse_number(value: float) -> Decimal: ...
def parse_number(value: Any) -> Any:
    """Parse ``value`` into a number; the overloads give the type checker precise return types.

    str -> Decimal, int -> the same int, float -> Decimal built from ``str(value)``.

    Raises:
        TypeError: for any other input type.
    """
    if isinstance(value, str):
        return Decimal(value)
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        # Go through str() so the Decimal uses the float's short repr
        # rather than its exact binary expansion.
        return Decimal(str(value))
    raise TypeError(f"不支持的类型: {type(value)}")


decimal_val = parse_number("3.14")  # inferred as Decimal
int_val = parse_number(42)          # inferred as int
from typing import ForwardRef
class TreeNode:
    """N-ary tree node whose annotations refer to its own class via string forward references."""

    def __init__(self, value: Any, children: Optional[List['TreeNode']] = None):
        self.value = value
        # Copy the caller's list so add_child() never mutates it. The previous
        # `children or []` shared non-empty lists but not empty ones.
        self.children: List[TreeNode] = list(children) if children is not None else []

    def add_child(self, node: 'TreeNode') -> None:
        """Append ``node`` as the last child."""
        self.children.append(node)

    def find(self, value: Any) -> Optional['TreeNode']:
        """Return the first node in depth-first pre-order whose value == ``value``, else None.

        Uses an explicit stack instead of recursion, so deep trees do not hit
        the interpreter's recursion limit.
        """
        pending: List[TreeNode] = [self]
        while pending:
            node = pending.pop()
            if node.value == value:
                return node
            # Push children reversed so the leftmost child is visited next.
            pending.extend(reversed(node.children))
        return None
class GraphNode:
    """Graph vertex identified by ``id``, holding a list of neighbouring vertices."""

    def __init__(self, id: str) -> None:
        self.id = id
        # Adjacent nodes; starts empty. The annotation is not evaluated
        # inside a function, so the class can name itself without quotes.
        self.neighbors: List[GraphNode] = []
# Type of the value stored in each node. It was previously bound to
# LinkedNode, which made the demo's LinkedNode[int] a type error
# (int is not a LinkedNode).
A = TypeVar('A')


class LinkedNode(Generic[A]):
    """Singly linked list node; the type parameter is the type of ``value``."""

    def __init__(self, value: A, next: Optional['LinkedNode[A]'] = None):
        self.value = value
        self.next = next

    def traverse(self) -> Iterator[A]:
        """Yield the values from this node to the end of the list."""
        current: Optional[LinkedNode[A]] = self
        while current is not None:
            yield current.value
            current = current.next


head = LinkedNode[int](1, LinkedNode[int](2, LinkedNode[int](3)))
print(list(head.traverse()))
六、运行时类型检查与验证
Pydantic集成
from pydantic import BaseModel, Field, validator, root_validator
from typing import Optional, List
from datetime import datetime
import re
class UserModel(BaseModel):
    """User schema validated with the pydantic v1 API (validator / root_validator / Config).

    Field constraints reject bad ids, usernames, emails and ages; the
    validators add cross-field rules.
    """

    id: int = Field(gt=0, description="用户ID必须为正数")
    username: str = Field(min_length=3, max_length=50)
    # The dot before the TLD must be escaped: an unescaped '.' matches any
    # character, so addresses such as "a@bcom" used to pass.
    email: str = Field(regex=r'^[\w.-]+@[\w.-]+\.\w+$')
    age: Optional[int] = Field(None, ge=0, le=150)
    tags: List[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=datetime.now)

    @validator('username')
    def username_alphanumeric(cls, v):
        """Reject usernames containing characters other than letters and digits.

        Uses ``str.isalnum``, which also accepts non-ASCII letters and digits.
        """
        if not v.isalnum():
            raise ValueError('用户名必须只包含字母和数字')
        return v

    @validator('age')
    def validate_age(cls, v, values):
        """Require the ``admin`` user to be at least 18.

        ``values`` holds previously validated fields; ``username`` is absent
        from it when its own validation failed.
        """
        if v is not None and 'username' in values:
            if values['username'] == 'admin' and v < 18:
                raise ValueError('管理员必须年满18岁')
        return v

    @root_validator
    def validate_model(cls, values):
        """Whole-model check: the username must appear inside the email address."""
        if 'email' in values and 'username' in values:
            username = values['username']
            email = values['email']
            if username not in email:
                raise ValueError('用户名必须在邮箱地址中')
        return values

    class Config:
        anystr_strip_whitespace = True  # strip surrounding whitespace from str fields
        use_enum_values = True          # store enum fields by value
        extra = 'forbid'                # unknown input keys are a validation error
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
# Demo: build, serialise and parse UserModel instances. Any exception
# (including validation failures) is caught and printed.
try:
    user = UserModel(
        id=1,
        username="alice123",
        email="alice123@example.com",
        age=25,
        tags=["premium", "active"]
    )
    # pydantic v1 serialisation API: .json() / .dict()
    print(user.json(indent=2))
    user_dict = user.dict()
    print(f"用户字典: {user_dict}")
    # Parse and validate directly from a JSON string.
    json_str = '{"id": 2, "username": "bob", "email": "bob@example.com"}'
    user2 = UserModel.parse_raw(json_str)
    print(f"从JSON创建: {user2}")
except Exception as e:
    # NOTE(review): this also catches non-validation errors; narrowing to
    # pydantic's ValidationError would let genuine bugs surface.
    print(f"验证错误: {e}")
class PostModel(BaseModel):
    """Blog post that embeds a fully validated author (UserModel)."""

    id: int
    title: str = Field(min_length=1, max_length=200)
    content: str
    author: UserModel
    tags: List[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: Optional[datetime] = None

    @validator('tags')
    def validate_tags(cls, v):
        """Allow at most 10 tags; return them lower-cased with surrounding whitespace stripped."""
        if len(v) <= 10:
            return [raw.lower().strip() for raw in v]
        raise ValueError('最多10个标签')
class BusinessRules(BaseModel):
    """Order-level business rules validated across fields."""

    discount_code: Optional[str] = None
    total_amount: float = Field(gt=0)
    items: List[Dict[str, Any]]

    @root_validator
    def validate_discount(cls, values):
        """Validate discount codes: the code must start with ``DISCOUNT_`` and
        needs a total of at least 100.

        A post root validator still runs when individual fields failed. In
        that case a failed ``total_amount`` is missing from ``values``, so
        the threshold check is skipped instead of evaluating ``None < 100``
        (a TypeError).
        """
        discount_code = values.get('discount_code')
        total_amount = values.get('total_amount')
        if discount_code:
            if not discount_code.startswith('DISCOUNT_'):
                raise ValueError('无效的折扣码格式')
            if total_amount is not None and total_amount < 100:
                raise ValueError('折扣码需要最低消费100元')
        return values
七、mypy高级配置与插件
mypy配置与自定义规则
# Example mypy.ini for this tutorial, bound to a name so it can be printed or
# written to disk (a bare string statement is simply discarded).
# python_version is 3.11 because this file imports typing.NotRequired and
# typing.Required, which only exist from Python 3.11; with 3.9 mypy reports
# those imports as errors.
MYPY_CONFIG_EXAMPLE = """
[mypy]
python_version = 3.11
warn_return_any = True
warn_unused_configs = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
check_untyped_defs = True
disallow_untyped_decorators = True
no_implicit_optional = True
warn_redundant_casts = True
warn_unused_ignores = True
warn_no_return = True
warn_unreachable = True
strict_equality = True
# 插件
plugins = pydantic.mypy
# 模块特定配置
[mypy-pydantic.*]
ignore_missing_imports = True
[mypy-tests.*]
ignore_errors = True
disallow_untyped_defs = False
"""
from typing import Optional, Callable, Type
from mypy.plugin import Plugin, MethodContext
from mypy.types import Type as MypyType
from mypy.nodes import Expression
class CustomPlugin(Plugin):
    """Example mypy plugin that checks calls to ``app.models.User.validate``."""

    def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], MypyType]]:
        """Return a hook for the method we want to check, or None for all others."""
        if fullname == "app.models.User.validate":
            return self.validate_user_hook
        return None

    def validate_user_hook(self, ctx: MethodContext) -> MypyType:
        """Report an error unless ``validate`` is called with exactly one argument."""
        # ctx.args has one inner list per *formal* parameter, holding the actual
        # argument expressions mapped to it. len(ctx.args) is therefore fixed
        # by the method signature, so count the actual expressions instead.
        actual_count = sum(len(group) for group in ctx.args)
        if actual_count != 1:
            ctx.api.fail("validate方法需要一个参数", ctx.context)
        return ctx.default_return_type


def plugin(version: str) -> Type[Plugin]:
    """Entry point mypy calls when this module is listed under ``plugins``."""
    return CustomPlugin
import subprocess
import json
from pathlib import Path
def run_mypy_analysis(file_path: Path) -> Dict[str, Any]:
    """Run the ``mypy`` executable on ``file_path`` and collect its output.

    Returns:
        ``exit_code``: mypy's return code; ``errors``: every non-blank stdout
        line, in order; ``stderr``: raw stderr text.

    Raises:
        FileNotFoundError: if no ``mypy`` executable can be found on PATH.
    """
    command = ["mypy", str(file_path), "--no-error-summary", "--show-column-numbers"]
    completed = subprocess.run(command, capture_output=True, text=True)
    non_blank = [line for line in completed.stdout.split('\n') if line.strip()]
    return {
        "exit_code": completed.returncode,
        "errors": non_blank,
        "stderr": completed.stderr,
    }
def generate_type_coverage_report(directory: Path) -> Dict[str, Any]:
    """Generate a type-annotation coverage report for ``directory`` using mypy.

    mypy has no ``--json`` flag, so the previous invocation was rejected by
    mypy's argument parser and never produced data. This version asks mypy
    for a ``--linecount-report`` in a temporary directory and parses its
    ``linecount.txt``.

    Args:
        directory: Directory (or file) to analyse.

    Returns:
        A dict with ``overall_coverage`` (0.0-1.0), ``typed_lines``,
        ``total_lines`` and ``files`` (per-module stats keyed by module name).
        All keys are always present; the numbers are zero when mypy produced
        no report (for example, mypy not installed or no Python files).
    """
    import sys
    import tempfile

    with tempfile.TemporaryDirectory() as report_dir:
        # Run mypy from the current interpreter so the report reflects this
        # environment; a missing mypy yields no report rather than an exception.
        subprocess.run(
            [
                sys.executable, "-m", "mypy", str(directory),
                "--no-error-summary",
                "--linecount-report", report_dir,
            ],
            capture_output=True,
            text=True,
        )
        report_file = Path(report_dir) / "linecount.txt"
        report_text = report_file.read_text(encoding="utf-8") if report_file.is_file() else ""
    return _summarize_linecount_report(report_text)


def _summarize_linecount_report(report_text: str) -> Dict[str, Any]:
    """Turn the text of mypy's linecount.txt into per-module and overall coverage figures."""
    files: Dict[str, Dict[str, Any]] = {}
    typed_lines = 0
    total_lines = 0
    for line in report_text.splitlines():
        # Expected columns: typed_lines total_lines annotated_funcs total_funcs name
        # (per mypy's linecount reporter; typed lines are an estimate).
        fields = line.split(None, 4)
        if len(fields) != 5 or fields[4] == "total":
            continue  # skip blank/malformed lines and the summary row
        try:
            typed, total = int(fields[0]), int(fields[1])
        except ValueError:
            continue
        files[fields[4]] = {
            "typed_lines": typed,
            "total_lines": total,
            "coverage": typed / total if total else 0.0,
        }
        typed_lines += typed
        total_lines += total
    return {
        "overall_coverage": typed_lines / total_lines if total_lines else 0.0,
        "typed_lines": typed_lines,
        "total_lines": total_lines,
        "files": files,
    }
if __name__ == "__main__":
    # Type-check this file and print any errors mypy reports.
    file_path = Path(__file__)
    result = run_mypy_analysis(file_path)
    print(f"mypy退出码: {result['exit_code']}")
    if result['errors']:
        print("发现类型错误:")
        for error in result['errors']:
            print(f" {error}")
    # Coverage report for the current working directory.
    project_dir = Path.cwd()
    report = generate_type_coverage_report(project_dir)
    # Tolerate a report without 'overall_coverage' (e.g. an empty dict when no
    # report could be produced); indexing it directly raised KeyError.
    coverage = report.get('overall_coverage')
    if coverage is None:
        print("\n类型覆盖率: 无法生成报告")
    else:
        print(f"\n类型覆盖率: {coverage:.1%}")
八、实际应用:类型安全的API设计
from typing import Generic, TypeVar, Optional, List
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel, Field
import asyncpg
from pydantic.generics import GenericModel

T = TypeVar('T')


class ApiResponse(GenericModel, Generic[T]):
    """Uniform API envelope; ``ApiResponse[User]`` validates ``data`` as a User.

    This file uses the pydantic v1 API (``Field(regex=...)``, ``validator``,
    ``.dict()``). Per the pydantic v1 generic-models docs, a model must
    inherit ``GenericModel`` for ``ApiResponse[X]`` to create a specialised
    model class; plain ``BaseModel`` + ``Generic`` does not do that.
    """

    success: bool
    data: Optional[T] = None
    error: Optional[str] = None
    message: Optional[str] = None
class Repository(Generic[T]):
    """Generic async repository over an asyncpg pool (repository pattern).

    The table name is the lower-cased model class name (``User`` -> ``user``).
    Identifiers are double-quoted in SQL, so dict keys passed to ``create``
    must match the column names exactly, including case.
    """

    def __init__(self, model: Type[T], db_pool: asyncpg.Pool):
        self.model = model
        self.db_pool = db_pool
        self.table_name = model.__name__.lower()

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Double-quote an SQL identifier, escaping embedded quotes.

        Table and column names cannot be bound as ``$n`` parameters. Quoting
        also lets reserved words such as ``user`` be used as table names.
        """
        return '"' + name.replace('"', '""') + '"'

    async def create(self, data: Dict[str, Any]) -> T:
        """Insert ``data`` as a new row and return it as a model instance.

        Values are bound as ``$1..$n`` parameters, in ``data``'s key order.
        An empty ``data`` inserts a row of column defaults.
        """
        table = self._quote_ident(self.table_name)
        if data:
            columns = ", ".join(self._quote_ident(column) for column in data)
            placeholders = ", ".join(f"${i}" for i in range(1, len(data) + 1))
            query = f"INSERT INTO {table} ({columns}) VALUES ({placeholders}) RETURNING *"
        else:
            # "INSERT INTO t () VALUES ()" is invalid SQL in PostgreSQL.
            query = f"INSERT INTO {table} DEFAULT VALUES RETURNING *"
        async with self.db_pool.acquire() as conn:
            row = await conn.fetchrow(query, *data.values())
        return self.model(**dict(row))

    async def find_by_id(self, id: int) -> Optional[T]:
        """Return the row whose ``id`` equals ``id`` as a model instance, or None."""
        query = f"SELECT * FROM {self._quote_ident(self.table_name)} WHERE id = $1"
        async with self.db_pool.acquire() as conn:
            row = await conn.fetchrow(query, id)
        return self.model(**dict(row)) if row is not None else None
app = FastAPI(title="类型安全API示例")

# Email pattern shared by the models below. The dot before the TLD is
# escaped: an unescaped '.' matches any character, so "a@bcom" was accepted.
_EMAIL_PATTERN = r'^[\w.-]+@[\w.-]+\.\w+$'


class User(BaseModel):
    """User as returned by the API (no password fields)."""

    id: int
    username: str = Field(min_length=3)
    email: str = Field(regex=_EMAIL_PATTERN)
    is_active: bool = True


class CreateUserRequest(BaseModel):
    """Request body for creating a user; ``password`` is the plaintext input."""

    username: str = Field(min_length=3)
    email: str = Field(regex=_EMAIL_PATTERN)
    password: str = Field(min_length=8)
@app.post("/users", response_model=ApiResponse[User])
async def create_user(
    request: CreateUserRequest,
    repo: Repository[User] = Depends(get_user_repository)
) -> ApiResponse[User]:
    """POST /users: store a new user and return it wrapped in an ApiResponse.

    The plaintext password is replaced by ``password_hash`` before insertion.
    A unique-constraint violation becomes HTTP 400. Any other exception is
    reported as ``success=False`` with the exception text.
    """
    try:
        password_hash = hash_password(request.password)
        record = {**request.dict(exclude={'password'}), 'password_hash': password_hash}
        created = await repo.create(record)
        return ApiResponse[User](success=True, data=created)
    except asyncpg.exceptions.UniqueViolationError:
        raise HTTPException(400, "用户名或邮箱已存在")
    except Exception as exc:
        # NOTE(review): str(exc) exposes internal error details (e.g. database
        # messages) to clients; consider logging it and returning a generic message.
        return ApiResponse[User](success=False, error=str(exc))
@app.get("/users/{user_id}", response_model=ApiResponse[User])
async def get_user(
    user_id: int,
    repo: Repository[User] = Depends(get_user_repository)
) -> ApiResponse[User]:
    """GET /users/{user_id}: return the user, or ``success=False`` if no row matches."""
    found = await repo.find_by_id(user_id)
    return (
        ApiResponse[User](success=True, data=found)
        if found
        else ApiResponse[User](success=False, error="用户不存在")
    )
核心优势:
- 代码文档化:类型即文档
- 错误预防:在编码时捕获类型错误
- IDE支持:更好的自动完成和重构
- 维护性:使大型代码库更易维护
关键特性:
- 基础类型:int, str, List, Dict等
- 高级类型:Union, Optional, Literal, TypedDict
- 泛型:TypeVar, Generic, 自定义泛型类
- 协议:结构性子类型(静态鸭子类型)
- 运行时验证:Pydantic集成