Python类型提示与静态分析:超越文档的代码契约

42 阅读6分钟

一、基础类型提示:从简单到复杂

基本类型与容器

from typing import List, Dict, Tuple, Set, Optional, Union, Any
from collections.abc import Sequence, Mapping, Iterable
from dataclasses import dataclass
import datetime

# 基本类型提示
def process_numbers(numbers: List[int]) -> float:
    """Return the arithmetic mean of ``numbers``, or ``0.0`` when it is empty."""
    if not numbers:
        return 0.0
    total = sum(numbers)
    return total / len(numbers)

# 使用泛型(Python 3.9+)
def process_numbers_modern(numbers: list[int]) -> float:
    """Return the mean of ``numbers`` (``0.0`` if empty), using built-in generics."""
    mean = 0.0
    if numbers:
        mean = sum(numbers) / len(numbers)
    return mean

# 可选类型和默认值
def find_item(items: List[str], target: str) -> Optional[int]:
    """Return the index of the first occurrence of ``target``, or ``None`` if absent."""
    for position, candidate in enumerate(items):
        if candidate == target:
            return position
    return None

# 联合类型
def parse_value(value: Union[str, int, float]) -> float:
    """Convert a string, int or float input to ``float``.

    ``float()`` accepts all three input types directly, so one conversion
    covers every case.
    """
    converted = float(value)
    return converted

# 字面量类型
from typing import Literal

HttpMethod = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]

def make_request(
    method: HttpMethod,
    url: str,
    data: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Announce a request and return a canned success payload.

    ``method`` is restricted to the literals in ``HttpMethod`` for static
    checkers; ``data`` is accepted but not used by this demo implementation.
    """
    announcement = f"Making {method} request to {url}"
    print(announcement)
    response: Dict[str, Any] = {"status": "success"}
    return response

# 类型别名
Vector = List[float]
Matrix = List[Vector]

def matrix_multiply(a: Matrix, b: Matrix) -> Matrix:
    """Return the matrix product ``a @ b`` as a new list of rows."""
    inner = len(b)
    product: Matrix = []
    for row_index in range(len(a)):
        out_row: Vector = []
        # The column count is read per row, so an empty ``a`` never touches b[0].
        for col_index in range(len(b[0])):
            cell = sum(a[row_index][k] * b[k][col_index] for k in range(inner))
            out_row.append(cell)
        product.append(out_row)
    return product

# Usage example: compute the mean of a typed list and print it
nums: List[int] = [1, 2, 3, 4, 5]
result: float = process_numbers(nums)
print(f"平均值: {result}")

# A static type checker reports the call below as an error
# process_numbers(["a", "b"])  # argument type mismatch

二、泛型与类型变量:创建可重用的类型

自定义泛型类

from typing import TypeVar, Generic, Iterator
from abc import ABC, abstractmethod

# 类型变量
T = TypeVar('T')  # any type
U = TypeVar('U')  # a second, independent type
Num = TypeVar('Num', int, float, complex)  # constrained to numeric types

class Stack(Generic[T]):
    """LIFO stack whose element type is fixed by the type parameter ``T``."""

    def __init__(self) -> None:
        self._items: List[T] = []

    def __len__(self) -> int:
        return len(self._items)

    def __iter__(self) -> Iterator[T]:
        # Iterates bottom-to-top, i.e. in insertion order.
        return iter(self._items)

    def push(self, item: T) -> None:
        """Place ``item`` on top of the stack."""
        self._items.append(item)

    def pop(self) -> T:
        """Remove and return the top item; raise ``IndexError`` if empty."""
        if self._items:
            return self._items.pop()
        raise IndexError("Pop from empty stack")

    def peek(self) -> T:
        """Return the top item without removing it; raise ``IndexError`` if empty."""
        if self._items:
            return self._items[-1]
        raise IndexError("Peek from empty stack")

# Usage example: one stack parameterized with int, another with str
int_stack: Stack[int] = Stack()
int_stack.push(1)
int_stack.push(2)
print(f"栈顶: {int_stack.peek()}")

str_stack: Stack[str] = Stack()
str_stack.push("hello")
str_stack.push("world")

# 泛型函数
def first_item(items: Sequence[T]) -> T:
    """Return the element at index 0 of ``items``."""
    head = items[0]
    return head

# 多个类型变量
K = TypeVar('K')  # key type
V = TypeVar('V')  # value type

def reverse_dict(d: Dict[K, V]) -> Dict[V, K]:
    """Return a new dict mapping each value of ``d`` back to its key.

    Values are assumed to be unique; when they are not, the key seen last
    in iteration order wins.
    """
    inverted: Dict[V, K] = {}
    for key, value in d.items():
        inverted[value] = key
    return inverted

# 约束泛型
def add_numbers(a: Num, b: Num) -> Num:
    """Return ``a + b``; both operands share one of the numeric types allowed by ``Num``."""
    total = a + b
    return total

print(add_numbers(1, 2))     # 3
print(add_numbers(3.14, 2.5))  # 5.640000000000001 (binary floating-point rounding)
# add_numbers("a", "b")  # rejected by a type checker: str is not allowed by Num

三、Protocol:结构性子类型

定义和使用协议

from typing import Protocol, runtime_checkable
from dataclasses import dataclass

# 定义协议(接口)
class Drawable(Protocol):
    """Structural interface for drawable objects.

    A class matches this protocol by providing ``draw``, a read-only ``area``
    property and ``scale`` with these signatures; it does not need to inherit
    from ``Drawable``.
    """
    # Render the object; returns nothing.
    def draw(self) -> None:
        ...
    
    # Read-only surface area of the object.
    @property
    def area(self) -> float:
        ...
    
    # Resize the object in place by ``factor``.
    def scale(self, factor: float) -> None:
        ...

# 实现协议的类
@dataclass
class Circle:
    """Circle shape that structurally satisfies ``Drawable``."""
    radius: float

    def draw(self) -> None:
        """Print a one-line description of the circle."""
        description = f"绘制圆形,半径: {self.radius}"
        print(description)

    @property
    def area(self) -> float:
        """Area computed with the approximation 3.14159 for pi."""
        radius_squared = self.radius ** 2
        return 3.14159 * radius_squared

    def scale(self, factor: float) -> None:
        """Multiply the radius by ``factor`` in place."""
        self.radius = self.radius * factor

@dataclass
class Rectangle:
    """Axis-aligned rectangle that structurally satisfies ``Drawable``."""
    width: float
    height: float

    def draw(self) -> None:
        """Print a one-line description of the rectangle."""
        description = f"绘制矩形,{self.width}×{self.height}"
        print(description)

    @property
    def area(self) -> float:
        """Width times height."""
        w, h = self.width, self.height
        return w * h

    def scale(self, factor: float) -> None:
        """Multiply both sides by ``factor`` in place."""
        self.width = self.width * factor
        self.height = self.height * factor

# 使用协议的函数
def render_all(shapes: List[Drawable]) -> None:
    """Draw every shape in order and print its area after each one."""
    for drawable in shapes:
        drawable.draw()
        current_area = drawable.area
        print(f"面积: {current_area}")

# 运行时可检查的协议
@runtime_checkable
class Serializable(Protocol):
    """Protocol for objects that can round-trip through a JSON string.

    Decorated with ``runtime_checkable`` so it can be used with ``isinstance``.
    """
    # Return the object's JSON representation as a string.
    def to_json(self) -> str:
        ...
    
    # Alternate constructor: build an instance from a JSON string.
    @classmethod
    def from_json(cls, json_str: str) -> 'Serializable':
        ...

def save_objects(objects: List[Serializable]) -> str:
    """Serialize ``objects`` into a JSON array of their ``to_json()`` strings."""
    import json
    serialized = []
    for item in objects:
        serialized.append(item.to_json())
    return json.dumps(serialized)

# 回调协议
class EventHandler(Protocol):
    """Callback protocol: any callable taking one event dict and returning None."""
    # Invoked with the event payload.
    def __call__(self, event: Dict[str, Any]) -> None:
        ...

class EventDispatcher:
    """Keeps a list of handlers and calls each of them for every dispatched event."""

    def __init__(self) -> None:
        self._handlers: List[EventHandler] = []

    def add_handler(self, handler: EventHandler) -> None:
        """Register ``handler``; handlers run in registration order."""
        registered = self._handlers
        registered.append(handler)

    def dispatch(self, event: Dict[str, Any]) -> None:
        """Call every registered handler with ``event``."""
        for callback in self._handlers:
            callback(event)

# Usage example: Circle and Rectangle are accepted as Drawable without inheriting from it
shapes: List[Drawable] = [
    Circle(5.0),
    Rectangle(3.0, 4.0)
]
render_all(shapes)

四、TypedDict:类型化的字典

定义结构化字典

from typing import TypedDict, NotRequired, Required
from datetime import datetime

# 基本TypedDict
class UserDict(TypedDict):
    """Dictionary shape for a user record; every key is required."""
    id: int
    username: str
    email: str
    created_at: datetime
    is_active: bool

# 可选字段
class ProductDict(TypedDict, total=False):
    """Dictionary shape for a product.

    ``total=False`` makes keys optional by default; only keys wrapped in
    ``Required[...]`` must be present.
    """
    id: NotRequired[int]  # optional
    name: Required[str]   # required
    price: float  # no marker, so optional under total=False
    description: NotRequired[str]
    tags: NotRequired[List[str]]

# 继承和混合
class BaseConfig(TypedDict):
    """Connection settings shared by derived config dicts; both keys are required."""
    host: str
    port: int

class DatabaseConfig(BaseConfig, total=False):
    """Database settings extending ``BaseConfig``.

    The keys declared here are optional because of ``total=False``; the
    inherited ``host`` and ``port`` keep the requiredness they have in
    ``BaseConfig``.
    """
    database: str
    username: str
    password: str
    pool_size: NotRequired[int]

# 使用示例
def create_user(data: UserDict) -> None:
    """Print the username, email and active flag of ``data``, one per line."""
    username = data['username']
    print(f"创建用户: {username}")
    # The declared TypedDict keys let a type checker verify these lookups.
    email = data['email']
    print(f"邮箱: {email}")
    active = data['is_active']
    print(f"活跃: {active}")

# Correct usage: every key declared by UserDict is present with the right type
user_data: UserDict = {
    'id': 1,
    'username': 'alice',
    'email': 'alice@example.com',
    'created_at': datetime.now(),
    'is_active': True
}
create_user(user_data)

# A type checker catches both of the following
# user_data['invalid_key']  # error: key is not declared
# user_data['id'] = "not a number"  # error: value type mismatch

# 动态创建TypedDict
def create_response_typed(
    success: bool,
    data: Optional[Dict[str, Any]] = None,
    message: Optional[str] = None
) -> Dict[str, Any]:
    """创建API响应"""
    ResponseDict = TypedDict('ResponseDict', {
        'success': bool,
        'data': NotRequired[Dict[str, Any]],
        'message': NotRequired[str],
        'timestamp': datetime
    })
    
    response: ResponseDict = {
        'success': success,
        'timestamp': datetime.now()
    }
    
    if data:
        response['data'] = data
    if message:
        response['message'] = message
    
    return response

五、高级类型:重载、泛型和自引用

类型重载

from typing import overload, TypeVar, Type, Any
from decimal import Decimal

# 函数重载
@overload
def parse_number(value: str) -> Decimal: ...

@overload
def parse_number(value: int) -> int: ...

@overload
def parse_number(value: float) -> Decimal: ...

def parse_number(value: Any) -> Any:
    """Parse ``value`` into a number whose type depends on the input type.

    Strings and floats become ``Decimal``; ints are returned unchanged.
    Any other type raises ``TypeError``.
    """
    if isinstance(value, str):
        return Decimal(value)
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        # Going through str() avoids carrying the binary float error into Decimal.
        text = str(value)
        return Decimal(text)
    raise TypeError(f"不支持的类型: {type(value)}")

# Usage example: the overloads let a checker pick the return type per call
decimal_val = parse_number("3.14")  # inferred as Decimal
int_val = parse_number(42)          # inferred as int

# 自引用类型
from typing import ForwardRef

class TreeNode:
    """Tree node whose annotations refer to its own class by string name."""

    def __init__(self, value: Any, children: Optional[List['TreeNode']] = None):
        # A falsy ``children`` (None or an empty list) is replaced by a fresh list.
        self.children = children if children else []
        self.value = value

    def add_child(self, node: 'TreeNode') -> None:
        """Append ``node`` to this node's children."""
        self.children.append(node)

    def find(self, value: Any) -> Optional['TreeNode']:
        """Depth-first search for the first node whose value equals ``value``."""
        if self.value == value:
            return self
        for child in self.children:
            if (match := child.find(value)):
                return match
        return None

# 使用字符串形式的前向引用处理自引用(此处无需显式使用 ForwardRef)
class GraphNode:
    """Graph vertex with an identifier and a list of neighbouring vertices."""

    def __init__(self, id: str):
        # The neighbour list is annotated with a string forward reference to this class.
        self.neighbors: List['GraphNode'] = []
        self.id = id

# 泛型自引用
# Element type carried by each node. It was previously bound to LinkedNode,
# which made parameterizations such as LinkedNode[int] invalid for a type
# checker and left ``value`` typed as Any.
A = TypeVar('A')

class LinkedNode(Generic[A]):
    """Singly linked list node holding a value of type ``A``.

    Args:
        value: The payload stored in this node.
        next: The following node, or ``None`` at the end of the list.
    """

    def __init__(self, value: A, next: Optional['LinkedNode[A]'] = None):
        self.value = value
        self.next = next

    def traverse(self) -> Iterator[A]:
        """Yield the values from this node to the end of the list."""
        current: Optional['LinkedNode[A]'] = self
        while current is not None:
            yield current.value
            current = current.next

# Usage example: build a three-node list and collect its values
head = LinkedNode[int](1, LinkedNode[int](2, LinkedNode[int](3)))
print(list(head.traverse()))  # [1, 2, 3]

六、运行时类型检查与验证

Pydantic集成

from pydantic import BaseModel, Field, validator, root_validator
from typing import Optional, List
from datetime import datetime
import re

# Pydantic模型
class UserModel(BaseModel):
    """User record validated at runtime by pydantic.

    Field constraints are declared with ``Field``; the validators below add a
    per-field check for ``username`` and ``age`` and one whole-model check.
    """
    id: int = Field(gt=0, description="用户ID必须为正数")
    username: str = Field(min_length=3, max_length=50)
    # The dot before the final label is escaped so it matches a literal ".".
    # Unescaped, it matched any character and accepted e.g. "a@bXc".
    email: str = Field(regex=r'^[\w.-]+@[\w.-]+\.\w+$')
    age: Optional[int] = Field(None, ge=0, le=150)
    tags: List[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=datetime.now)

    @validator('username')
    def username_alphanumeric(cls, v):
        """Reject usernames containing anything other than letters and digits."""
        if not v.isalnum():
            raise ValueError('用户名必须只包含字母和数字')
        return v

    @validator('age')
    def validate_age(cls, v, values):
        """Check ``age`` against other fields, read from ``values``."""
        if v is not None and 'username' in values:
            # Example rule: the "admin" user must be an adult.
            if values['username'] == 'admin' and v < 18:
                raise ValueError('管理员必须年满18岁')
        return v

    @root_validator
    def validate_model(cls, values):
        """Whole-model check: the username must appear inside the email address."""
        # Only runs the check when both fields are present in ``values``.
        if 'email' in values and 'username' in values:
            username = values['username']
            email = values['email']
            if username not in email:
                raise ValueError('用户名必须在邮箱地址中')
        return values

    class Config:
        # Model-wide options
        anystr_strip_whitespace = True  # strip surrounding whitespace from strings
        use_enum_values = True          # store enum values rather than members
        extra = 'forbid'                # reject fields that are not declared
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }

# Usage example: construct, serialize and re-parse a UserModel
try:
    user = UserModel(
        id=1,
        username="alice123",
        email="alice123@example.com",
        age=25,
        tags=["premium", "active"]
    )
    print(user.json(indent=2))
    
    # Convert to a plain dict
    user_dict = user.dict()
    print(f"用户字典: {user_dict}")
    
    # Build a model from a JSON string
    json_str = '{"id": 2, "username": "bob", "email": "bob@example.com"}'
    user2 = UserModel.parse_raw(json_str)
    print(f"从JSON创建: {user2}")
    
except Exception as e:
    # Any failure in the demo above is printed rather than propagated
    print(f"验证错误: {e}")

# 嵌套模型
class PostModel(BaseModel):
    """Blog post model that nests a ``UserModel`` as its author."""
    id: int
    title: str = Field(min_length=1, max_length=200)
    content: str
    author: UserModel
    tags: List[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: Optional[datetime] = None

    @validator('tags')
    def validate_tags(cls, v):
        """Allow at most 10 tags and return them lower-cased and stripped."""
        if len(v) > 10:
            raise ValueError('最多10个标签')
        normalized = []
        for raw_tag in v:
            normalized.append(raw_tag.lower().strip())
        return normalized

# 复杂验证示例
class BusinessRules(BaseModel):
    """Order-level business rules that span more than one field."""
    discount_code: Optional[str] = None
    total_amount: float = Field(gt=0)
    items: List[Dict[str, Any]]

    @root_validator
    def validate_discount(cls, values):
        """Validate the discount code format and its minimum-spend rule.

        A discount code must start with ``DISCOUNT_`` and needs an order
        total of at least 100.
        """
        discount_code = values.get('discount_code')
        total_amount = values.get('total_amount')

        if discount_code:
            if not discount_code.startswith('DISCOUNT_'):
                raise ValueError('无效的折扣码格式')

            # ``total_amount`` can be missing from ``values`` (presumably when
            # its own field validation failed, so .get() returns None).
            # Comparing None with 100 would raise an unrelated TypeError, so
            # the minimum-spend rule is skipped in that case.
            if total_amount is not None and total_amount < 100:
                raise ValueError('折扣码需要最低消费100元')

        return values

七、mypy高级配置与插件

mypy配置与自定义规则

# mypy.ini 配置文件示例
"""
[mypy]
python_version = 3.9
warn_return_any = True
warn_unused_configs = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
check_untyped_defs = True
disallow_untyped_decorators = True
no_implicit_optional = True
warn_redundant_casts = True
warn_unused_ignores = True
warn_no_return = True
warn_unreachable = True
strict_equality = True

# 插件
plugins = pydantic.mypy

# 模块特定配置
[mypy-pydantic.*]
ignore_missing_imports = True

[mypy-tests.*]
ignore_errors = True
disallow_untyped_defs = False
"""

# 自定义mypy插件示例
from typing import Optional, Callable, Type
from mypy.plugin import Plugin, MethodContext
from mypy.types import Type as MypyType
from mypy.nodes import Expression

class CustomPlugin(Plugin):
    """Example mypy plugin that hooks one specific method call."""

    def validate_user_hook(self, ctx: MethodContext) -> MypyType:
        """Type hook for the validate method: report an error unless ``ctx.args`` has length 1."""
        call_args = ctx.args
        if len(call_args) != 1:
            ctx.api.fail("validate方法需要一个参数", ctx.context)

        # Further custom checks could go here; the default return type is kept.
        return ctx.default_return_type

    def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], MypyType]]:
        """Return the hook for ``app.models.User.validate`` and None for any other method."""
        hook = self.validate_user_hook if fullname == "app.models.User.validate" else None
        return hook

# 使用mypy API进行程序化检查
import subprocess
import json
from pathlib import Path

def run_mypy_analysis(file_path: Path) -> Dict[str, Any]:
    """Run the ``mypy`` executable on ``file_path`` and collect its output.

    Returns a dict with the process exit code, the non-blank stdout lines
    under ``errors``, and the raw stderr text.
    """
    command = ["mypy", str(file_path), "--no-error-summary", "--show-column-numbers"]
    completed = subprocess.run(command, capture_output=True, text=True)

    reported = [line for line in completed.stdout.split('\n') if line.strip()]

    return {
        "exit_code": completed.returncode,
        "errors": reported,
        "stderr": completed.stderr
    }

def generate_type_coverage_report(directory: Path) -> Dict[str, Any]:
    """Run mypy on ``directory`` and summarize per-file typed-line coverage.

    Args:
        directory: Root directory passed to mypy; per-file paths in the
            report are made relative to it where possible.

    Returns:
        A dict that always has the keys ``overall_coverage``, ``typed_lines``,
        ``total_lines`` and ``files``. With no mypy output the totals are zero
        and ``files`` is empty, so callers can index the result safely.

    Raises:
        json.JSONDecodeError: If mypy's stdout is non-empty but not valid JSON.
    """
    # NOTE(review): the ``--json`` flag and the ``typed_lines``/``total_lines``
    # fields read below could not be confirmed against mypy's documented
    # command line - verify against the mypy version in use.
    result = subprocess.run(
        ["mypy", str(directory), "--no-error-summary", "--json"],
        capture_output=True,
        text=True
    )

    files: Dict[str, Dict[str, Any]] = {}
    typed_lines = 0
    total_lines = 0

    data = json.loads(result.stdout) if result.stdout else []
    if not isinstance(data, list):
        # Only a JSON array of per-file records is understood.
        data = []

    for item in data:
        if not isinstance(item, dict) or 'path' not in item:
            continue

        item_path = Path(item['path'])
        try:
            rel_path = item_path.relative_to(directory)
        except ValueError:
            # The path is not located under ``directory`` (for example it is
            # already relative); keep it as reported instead of failing.
            rel_path = item_path

        file_typed = item.get('typed_lines', 0)
        file_total = item.get('total_lines', 0)
        files[str(rel_path)] = {
            'typed_lines': file_typed,
            'total_lines': file_total,
            # max(..., 1) keeps the division defined for empty files.
            'coverage': file_typed / max(item.get('total_lines', 1), 1)
        }
        typed_lines += file_typed
        total_lines += file_total

    overall_coverage = typed_lines / total_lines if total_lines > 0 else 0

    return {
        "overall_coverage": overall_coverage,
        "typed_lines": typed_lines,
        "total_lines": total_lines,
        "files": files
    }

# Usage example
if __name__ == "__main__":
    # Check this file itself
    file_path = Path(__file__)
    result = run_mypy_analysis(file_path)
    print(f"mypy退出码: {result['exit_code']}")
    
    if result['errors']:
        print("发现类型错误:")
        for error in result['errors']:
            print(f"  {error}")
    
    # Build the coverage report for the current working directory
    project_dir = Path.cwd()
    report = generate_type_coverage_report(project_dir)
    # Reads the 'overall_coverage' key directly, so the report must contain it
    print(f"\n类型覆盖率: {report['overall_coverage']:.1%}")

八、实际应用:类型安全的API设计

from typing import Generic, TypeVar, Optional, List
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel, Field
import asyncpg

# 泛型响应模型
T = TypeVar('T')

class ApiResponse(BaseModel, Generic[T]):
    """Generic response envelope; ``T`` is the type of the ``data`` payload."""
    # NOTE(review): other models in this file use `regex=` and `.dict()`,
    # which look like the pydantic v1 API. Under v1, generic models presumably
    # need to derive from `pydantic.generics.GenericModel` - confirm against
    # the installed pydantic version.
    success: bool
    data: Optional[T] = None
    error: Optional[str] = None
    message: Optional[str] = None

# 类型安全的CRUD操作
class Repository(Generic[T]):
    """Generic repository that maps one model class to one database table.

    The table name is the lower-cased model class name.

    Args:
        model: Class used to build result objects from fetched rows.
        db_pool: asyncpg connection pool used for every query.
    """
    def __init__(self, model: Type[T], db_pool: asyncpg.Pool):
        self.model = model
        self.db_pool = db_pool
        self.table_name = model.__name__.lower()

    async def create(self, data: Dict[str, Any]) -> T:
        """Insert one row built from ``data`` and return it as a model instance.

        Args:
            data: Mapping of column name to value. Values are sent as query
                parameters; column names are interpolated into the SQL text
                and must therefore be plain identifiers.

        Raises:
            ValueError: If ``data`` is empty or a key is not a plain identifier.
        """
        if not data:
            # An empty mapping would produce "INSERT INTO t () VALUES ()".
            raise ValueError("data must contain at least one column")

        # Column names cannot be bound as parameters, so reject anything that
        # is not a bare identifier before it reaches the SQL string.
        invalid = [
            name for name in data
            if not (isinstance(name, str) and name.isidentifier())
        ]
        if invalid:
            raise ValueError(f"invalid column name(s): {invalid!r}")

        columns = ', '.join(data.keys())
        placeholders = ', '.join(f'${i+1}' for i in range(len(data)))

        async with self.db_pool.acquire() as conn:
            query = f"""
                INSERT INTO {self.table_name} ({columns})
                VALUES ({placeholders})
                RETURNING *
            """
            row = await conn.fetchrow(query, *data.values())
            return self.model(**dict(row))

    async def find_by_id(self, id: int) -> Optional[T]:
        """Return the row whose ``id`` column equals ``id``, or None if no row matches."""
        async with self.db_pool.acquire() as conn:
            row = await conn.fetchrow(
                f"SELECT * FROM {self.table_name} WHERE id = $1",
                id
            )
            return self.model(**dict(row)) if row else None

# FastAPI应用
app = FastAPI(title="类型安全API示例")

# 数据库模型
class User(BaseModel):
    """User model returned by the API."""
    id: int
    username: str = Field(min_length=3)
    # The dot before the final label is escaped so it matches a literal ".";
    # unescaped, it matched any character.
    email: str = Field(regex=r'^[\w.-]+@[\w.-]+\.\w+$')
    is_active: bool = True

class CreateUserRequest(BaseModel):
    """Request body for creating a user."""
    username: str = Field(min_length=3)
    # The dot before the final label is escaped so it matches a literal ".";
    # unescaped, it matched any character.
    email: str = Field(regex=r'^[\w.-]+@[\w.-]+\.\w+$')
    password: str = Field(min_length=8)

# API端点
@app.post("/users", response_model=ApiResponse[User])
async def create_user(
    request: CreateUserRequest,
    repo: Repository[User] = Depends(get_user_repository)
) -> ApiResponse[User]:
    """Create a user from the request body and return it in an ApiResponse.

    A unique-constraint violation becomes an HTTP 400; any other exception is
    reported as an unsuccessful ApiResponse carrying the error text.
    """
    try:
        # Hash the password first (demo); the plain password is never stored.
        password_hash = hash_password(request.password)

        payload = {
            **request.dict(exclude={'password'}),
            'password_hash': password_hash,
        }

        created = await repo.create(payload)
        return ApiResponse[User](success=True, data=created)

    except asyncpg.exceptions.UniqueViolationError:
        raise HTTPException(400, "用户名或邮箱已存在")
    except Exception as e:
        return ApiResponse[User](success=False, error=str(e))

@app.get("/users/{user_id}", response_model=ApiResponse[User])
async def get_user(
    user_id: int,
    repo: Repository[User] = Depends(get_user_repository)
) -> ApiResponse[User]:
    """Look up a user by id; an unknown id yields an unsuccessful ApiResponse."""
    found = await repo.find_by_id(user_id)
    if found:
        return ApiResponse[User](success=True, data=found)

    return ApiResponse[User](success=False, error="用户不存在")

核心优势:

  1. 代码文档化:类型即文档
  2. 错误预防:在编码时捕获类型错误
  3. IDE支持:更好的自动完成和重构
  4. 维护性:使大型代码库更易维护

关键特性:

  1. 基础类型:int、str、List、Dict
  2. 高级类型:Union、Optional、Literal、TypedDict
  3. 泛型:TypeVar、Generic,自定义泛型类
  4. 协议:结构性子类型,鸭子类型
  5. 运行时验证:Pydantic集成