System Design in Practice 185: Designing a Columnar Store


Abstract: This article dissects the core architecture, key algorithms, and engineering practices of a columnar storage system, providing a complete design along with the points interviewers tend to probe.

Have you ever wondered how complex the engineering challenges behind a columnar store really are?

1. Requirements Analysis

Functional requirements

  • Columnar storage format: store data column by column to speed up analytical queries
  • Compression: efficient per-column data compression
  • Query optimization: column pruning and predicate pushdown
  • OLAP support: optimized aggregation and analytical queries
  • Partitioning: data partitioning and parallel processing
  • Vectorized execution: batch-oriented data processing
  • Metadata management: per-column statistics and indexes

Non-functional requirements

  • Query latency: analytical queries complete in under a second
  • Compression ratio: at least 5:1
  • Scan throughput: GB/s-class data scans
  • Concurrency: multiple users querying simultaneously
  • Scalability: petabyte-scale storage
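
To sanity-check these targets, a quick back-of-envelope calculation helps. The raw data size and cluster size below are illustrative assumptions, not requirements; the point is that a full scan of compressed petabyte-scale data takes over an hour even at tens of GB/s, which is why the column pruning and partitioning in the functional list are essential:

```python
# Back-of-envelope check of the targets above (raw size and cluster size are
# illustrative assumptions, not part of the requirements).
raw_tb = 1024.0                          # assume 1 PB of raw data
compressed_tb = raw_tb / 5               # 5:1 compression -> ~205 TB on disk

cluster_gbps = 20 * 2.0                  # assume 20 nodes x 2 GB/s scan each

full_scan_s = compressed_tb * 1024 / cluster_gbps      # scanning everything
pruned_tb = compressed_tb * (3 / 50) * 0.01            # 3 of 50 columns, 1% of partitions
pruned_scan_s = pruned_tb * 1024 / cluster_gbps

print(f"full scan: {full_scan_s:.0f} s, pruned scan: {pruned_scan_s:.1f} s")
```

Even the pruned scan only gets close to seconds; hitting sub-second latency additionally relies on caching and statistics-based skipping, covered later.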

2. System Architecture

Overall architecture

┌─────────────────────────────────────────────────────────┐
│                    Query Layer                          │
│  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐       │
│  │    SQL      │ │  Analytics  │ │    BI       │       │
│  └─────────────┘ └─────────────┘ └─────────────┘       │
└─────────────────────────────────────────────────────────┘
                            │
┌─────────────────────────────────────────────────────────┐
│                 Execution Engine                        │
│  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐       │
│  │Query Parser │ │ Optimizer   │ │ Vectorized  │       │
│  │             │ │             │ │ Executor    │       │
│  └─────────────┘ └─────────────┘ └─────────────┘       │
└─────────────────────────────────────────────────────────┘
                            │
┌─────────────────────────────────────────────────────────┐
│                  Storage Engine                         │
│  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐       │
│  │Column Store │ │ Compression │ │ Index Mgr   │       │
│  └─────────────┘ └─────────────┘ └─────────────┘       │
└─────────────────────────────────────────────────────────┘
                            │
┌─────────────────────────────────────────────────────────┐
│                   File System                          │
│  ┌─────────────┐ ┌─────────────┐ ┌─────────────┐       │
│  │ Parquet     │ │    ORC      │ │   Delta     │       │
│  │   Files     │ │   Files     │ │   Files     │       │
│  └─────────────┘ └─────────────┘ └─────────────┘       │
└─────────────────────────────────────────────────────────┘

3. Core Component Design

3.1 Columnar storage format

# Computing column statistics below is O(N) time and O(N) space
# (the distinct-value set grows with the column).

import os
import pickle
import re
import struct
import time
from dataclasses import dataclass
from typing import List, Dict, Any, Optional

import numpy as np

@dataclass
class ColumnChunk:
    column_name: str
    data_type: str
    values: List[Any]
    null_bitmap: Optional[bytes] = None
    statistics: Optional['ColumnStatistics'] = None
    encoding: str = 'PLAIN'
    compression: str = 'SNAPPY'
    
    def __post_init__(self):
        if self.statistics is None:
            self.statistics = self._compute_statistics()
    
    def _compute_statistics(self) -> 'ColumnStatistics':
        non_null_values = [v for v in self.values if v is not None]
        
        if not non_null_values:
            return ColumnStatistics(
                null_count=len(self.values),
                min_value=None,
                max_value=None,
                distinct_count=0
            )
        
        return ColumnStatistics(
            null_count=len(self.values) - len(non_null_values),
            min_value=min(non_null_values),
            max_value=max(non_null_values),
            distinct_count=len(set(non_null_values))
        )

@dataclass
class ColumnStatistics:
    null_count: int
    min_value: Any
    max_value: Any
    distinct_count: int

class ColumnarTable:
    def __init__(self, table_name: str):
        self.table_name = table_name
        self.columns = {}  # column_name -> ColumnChunk
        self.row_count = 0
        self.schema = {}  # column_name -> data_type
    
    def add_column(self, column_chunk: ColumnChunk):
        self.columns[column_chunk.column_name] = column_chunk
        self.schema[column_chunk.column_name] = column_chunk.data_type
        
        # Update the row count
        if self.row_count == 0:
            self.row_count = len(column_chunk.values)
        elif self.row_count != len(column_chunk.values):
            raise ValueError("Column length mismatch")
    
    def get_column(self, column_name: str) -> Optional[ColumnChunk]:
        return self.columns.get(column_name)
    
    def project(self, column_names: List[str]) -> 'ColumnarTable':
        """列投影操作"""
        projected_table = ColumnarTable(f"{self.table_name}_projected")
        
        for column_name in column_names:
            if column_name in self.columns:
                projected_table.add_column(self.columns[column_name])
        
        return projected_table
    
    def filter(self, predicate_func) -> 'ColumnarTable':
        """行过滤操作"""
        # 计算满足条件的行索引
        matching_indices = []
        for i in range(self.row_count):
            row_data = {col_name: chunk.values[i] 
                       for col_name, chunk in self.columns.items()}
            if predicate_func(row_data):
                matching_indices.append(i)
        
        # Build the filtered table
        filtered_table = ColumnarTable(f"{self.table_name}_filtered")
        
        for column_name, chunk in self.columns.items():
            filtered_values = [chunk.values[i] for i in matching_indices]
            filtered_chunk = ColumnChunk(
                column_name=column_name,
                data_type=chunk.data_type,
                values=filtered_values,
                encoding=chunk.encoding,
                compression=chunk.compression
            )
            filtered_table.add_column(filtered_chunk)
        
        return filtered_table

class ColumnarStorageEngine:
    def __init__(self, storage_path: str):
        self.storage_path = storage_path
        self.compressor = ColumnCompressor()
        self.encoder = ColumnEncoder()
        self.metadata_manager = MetadataManager(storage_path)
    
    def write_table(self, table: ColumnarTable) -> str:
        """写入列式表"""
        table_path = os.path.join(self.storage_path, f"{table.table_name}.col")
        
        with open(table_path, 'wb') as f:
            # Write the table header
            self._write_table_header(f, table)
            
            # Write each column, recording its byte offset
            column_offsets = {}
            for column_name, chunk in table.columns.items():
                offset = f.tell()
                column_offsets[column_name] = offset
                self._write_column_chunk(f, chunk)
            
            # Record the schema and column offsets in the metadata store
            self.metadata_manager.update_table_metadata(
                table.table_name, table.schema, column_offsets
            )
        
        return table_path
    
    def read_table(self, table_name: str, 
                   column_names: Optional[List[str]] = None) -> ColumnarTable:
        """Read a columnar table, optionally restricted to selected columns."""
        table_path = os.path.join(self.storage_path, f"{table_name}.col")
        metadata = self.metadata_manager.get_table_metadata(table_name)
        
        if not metadata:
            raise ValueError(f"Table {table_name} not found")
        
        # If no columns are requested, read them all
        if column_names is None:
            column_names = list(metadata['schema'].keys())
        
        table = ColumnarTable(table_name)
        
        with open(table_path, 'rb') as f:
            # Skip the table header
            self._skip_table_header(f)
            
            # Seek to and read each requested column
            for column_name in column_names:
                if column_name in metadata['column_offsets']:
                    offset = metadata['column_offsets'][column_name]
                    f.seek(offset)
                    chunk = self._read_column_chunk(f, column_name)
                    table.add_column(chunk)
        
        return table
    
    def _write_column_chunk(self, file, chunk: ColumnChunk):
        # Encode the values
        encoded_data = self.encoder.encode(chunk.values, chunk.encoding)
        
        # Compress the encoded bytes
        compressed_data = self.compressor.compress(encoded_data, chunk.compression)
        
        # Write the column header
        column_header = {
            'column_name': chunk.column_name,
            'data_type': chunk.data_type,
            'encoding': chunk.encoding,
            'compression': chunk.compression,
            'row_count': len(chunk.values),
            'compressed_size': len(compressed_data),
            'statistics': chunk.statistics
        }
        
        header_data = pickle.dumps(column_header)
        file.write(len(header_data).to_bytes(4, 'big'))
        file.write(header_data)
        
        # Write the compressed payload
        file.write(compressed_data)
    
    def _read_column_chunk(self, file, column_name: str) -> ColumnChunk:
        # Read the column header
        header_size = int.from_bytes(file.read(4), 'big')
        header_data = file.read(header_size)
        column_header = pickle.loads(header_data)
        
        # Read the compressed payload
        compressed_data = file.read(column_header['compressed_size'])
        
        # Decompress
        encoded_data = self.compressor.decompress(
            compressed_data, column_header['compression']
        )
        
        # Decode back into Python values
        values = self.encoder.decode(
            encoded_data, column_header['encoding'], column_header['data_type']
        )
        
        return ColumnChunk(
            column_name=column_name,
            data_type=column_header['data_type'],
            values=values,
            encoding=column_header['encoding'],
            compression=column_header['compression'],
            statistics=column_header['statistics']
        )
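
Stripped of encoding and file I/O, the core access pattern of ColumnarTable above — projection touches only the requested columns, and a filter computes matching row indices once, then gathers them from every column — can be sketched standalone. This is a simplified toy using plain lists, not the classes above:

```python
# Minimal standalone sketch of columnar projection and filtering.
# Data lives as one Python list per column; a "table" is a dict of those lists.
table = {
    "city":  ["NY", "SF", "NY", "LA"],
    "sales": [100,  250,  300,  50],
    "year":  [2023, 2023, 2024, 2024],
}

def project(tbl, cols):
    # Column pruning: only the requested columns are ever touched.
    return {c: tbl[c] for c in cols}

def filter_rows(tbl, col, pred):
    # Evaluate the predicate against ONE column, then gather the surviving
    # row indices from every remaining column -- the columnar filter pattern.
    idx = [i for i, v in enumerate(tbl[col]) if pred(v)]
    return {c: [vals[i] for i in idx] for c, vals in tbl.items()}

ny = filter_rows(project(table, ["city", "sales"]), "city", lambda v: v == "NY")
print(ny)  # {'city': ['NY', 'NY'], 'sales': [100, 300]}
```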

3.2 Column encoding and compression

class ColumnEncoder:
    def encode(self, values: List[Any], encoding: str) -> bytes:
        if encoding == 'PLAIN':
            return self._plain_encoding(values)
        elif encoding == 'DICTIONARY':
            return self._dictionary_encoding(values)
        elif encoding == 'RLE':
            return self._run_length_encoding(values)
        elif encoding == 'DELTA':
            return self._delta_encoding(values)
        elif encoding == 'BIT_PACKED':
            return self._bit_packed_encoding(values)
        else:
            raise ValueError(f"Unsupported encoding: {encoding}")
    
    def decode(self, data: bytes, encoding: str, data_type: str) -> List[Any]:
        if encoding == 'PLAIN':
            return self._plain_decoding(data, data_type)
        elif encoding == 'DICTIONARY':
            return self._dictionary_decoding(data)
        elif encoding == 'RLE':
            return self._run_length_decoding(data)
        elif encoding == 'DELTA':
            return self._delta_decoding(data)
        elif encoding == 'BIT_PACKED':
            return self._bit_packed_decoding(data)
        else:
            raise ValueError(f"Unsupported encoding: {encoding}")
    
    def _dictionary_encoding(self, values: List[Any]) -> bytes:
        """Dictionary encoding: effective for columns with many repeated values."""
        # Build the dictionary
        unique_values = list(set(values))
        value_to_id = {value: i for i, value in enumerate(unique_values)}
        
        # Replace each value with its dictionary ID
        encoded_ids = [value_to_id[value] for value in values]
        
        # Serialize the dictionary followed by the encoded IDs
        result = bytearray()
        
        # Write the dictionary size
        result.extend(len(unique_values).to_bytes(4, 'big'))
        
        # Write the dictionary entries
        for value in unique_values:
            value_bytes = pickle.dumps(value)
            result.extend(len(value_bytes).to_bytes(4, 'big'))
            result.extend(value_bytes)
        
        # Write the encoded IDs
        result.extend(len(encoded_ids).to_bytes(4, 'big'))
        for id_val in encoded_ids:
            result.extend(id_val.to_bytes(4, 'big'))
        
        return bytes(result)
    
    def _run_length_encoding(self, values: List[Any]) -> bytes:
        """Run-length encoding: effective for long runs of identical values."""
        if not values:
            return b''
        
        result = bytearray()
        current_value = values[0]
        count = 1
        
        for i in range(1, len(values)):
            if values[i] == current_value:
                count += 1
            else:
                # Write the finished run: value followed by its count
                value_bytes = pickle.dumps(current_value)
                result.extend(len(value_bytes).to_bytes(4, 'big'))
                result.extend(value_bytes)
                result.extend(count.to_bytes(4, 'big'))
                
                current_value = values[i]
                count = 1
        
        # Write the final run
        value_bytes = pickle.dumps(current_value)
        result.extend(len(value_bytes).to_bytes(4, 'big'))
        result.extend(value_bytes)
        result.extend(count.to_bytes(4, 'big'))
        
        return bytes(result)
    
    def _delta_encoding(self, values: List[Any]) -> bytes:
        """Delta encoding: effective for sorted or slowly changing numeric columns."""
        if not values or not all(isinstance(v, (int, float)) for v in values):
            return self._plain_encoding(values)
        
        result = bytearray()
        
        # Write the first value as-is
        first_value = values[0]
        result.extend(struct.pack('d', float(first_value)))
        
        # Write successive deltas
        for i in range(1, len(values)):
            delta = values[i] - values[i-1]
            result.extend(struct.pack('d', float(delta)))
        
        return bytes(result)
    
    def _bit_packed_encoding(self, values: List[Any]) -> bytes:
        """Bit packing: effective for small non-negative integers."""
        if not values or not all(isinstance(v, int) and v >= 0 for v in values):
            return self._plain_encoding(values)
        
        # Bits needed per value (at least 1, so an all-zero column still round-trips)
        max_value = max(values)
        bits_per_value = max(1, max_value.bit_length())
        
        if bits_per_value > 32:
            return self._plain_encoding(values)
        
        result = bytearray()
        result.extend(bits_per_value.to_bytes(1, 'big'))
        result.extend(len(values).to_bytes(4, 'big'))
        
        # Pack values into a bit buffer, flushing whole bytes
        bit_buffer = 0
        bits_in_buffer = 0
        
        for value in values:
            bit_buffer = (bit_buffer << bits_per_value) | value
            bits_in_buffer += bits_per_value
            
            while bits_in_buffer >= 8:
                byte_value = (bit_buffer >> (bits_in_buffer - 8)) & 0xFF
                result.append(byte_value)
                bits_in_buffer -= 8
        
        # Flush remaining bits, left-aligned in the final byte
        if bits_in_buffer > 0:
            byte_value = (bit_buffer << (8 - bits_in_buffer)) & 0xFF
            result.append(byte_value)
        
        return bytes(result)
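
The intuition behind RLE is easiest to see without the byte-level framing. The sketch below mirrors `_run_length_encoding` above but keeps runs as `(value, count)` tuples instead of pickled bytes, so the round trip is easy to inspect:

```python
# Standalone run-length encoding round trip, mirroring _run_length_encoding
# above but with (value, count) tuples instead of a byte stream.
def rle_encode(values):
    runs = []
    for v in values:
        if runs and runs[-1][0] == v:
            # Extend the current run
            runs[-1] = (v, runs[-1][1] + 1)
        else:
            # Start a new run
            runs.append((v, 1))
    return runs

def rle_decode(runs):
    out = []
    for v, n in runs:
        out.extend([v] * n)
    return out

col = ["US", "US", "US", "DE", "DE", "US"]
runs = rle_encode(col)
print(runs)  # [('US', 3), ('DE', 2), ('US', 1)]
```

Six values collapse into three runs; on sorted low-cardinality columns (country codes, status flags) the win is far larger, which is why sort order matters so much for columnar compression.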

class ColumnCompressor:
    def __init__(self):
        self.compressors = {
            'SNAPPY': self._snappy_compress,
            'GZIP': self._gzip_compress,
            'LZ4': self._lz4_compress,
            'ZSTD': self._zstd_compress
        }
        
        self.decompressors = {
            'SNAPPY': self._snappy_decompress,
            'GZIP': self._gzip_decompress,
            'LZ4': self._lz4_decompress,
            'ZSTD': self._zstd_decompress
        }
    
    def compress(self, data: bytes, algorithm: str) -> bytes:
        if algorithm in self.compressors:
            return self.compressors[algorithm](data)
        else:
            return data  # No compression
    
    def decompress(self, data: bytes, algorithm: str) -> bytes:
        if algorithm in self.decompressors:
            return self.decompressors[algorithm](data)
        else:
            return data  # No compression
    
    def _snappy_compress(self, data: bytes) -> bytes:
        try:
            import snappy
            return snappy.compress(data)
        except ImportError:
            return data  # Fall back to no compression
    
    def _snappy_decompress(self, data: bytes) -> bytes:
        try:
            import snappy
            return snappy.decompress(data)
        except ImportError:
            return data
    
    def _gzip_compress(self, data: bytes) -> bytes:
        import gzip
        return gzip.compress(data)
    
    def _gzip_decompress(self, data: bytes) -> bytes:
        import gzip
        return gzip.decompress(data)
    
    def _lz4_compress(self, data: bytes) -> bytes:
        try:
            import lz4.frame
            return lz4.frame.compress(data)
        except ImportError:
            return data
    
    def _lz4_decompress(self, data: bytes) -> bytes:
        try:
            import lz4.frame
            return lz4.frame.decompress(data)
        except ImportError:
            return data
    
    def _zstd_compress(self, data: bytes) -> bytes:
        try:
            import zstandard
            return zstandard.ZstdCompressor().compress(data)
        except ImportError:
            return data
    
    def _zstd_decompress(self, data: bytes) -> bytes:
        try:
            import zstandard
            return zstandard.ZstdDecompressor().decompress(data)
        except ImportError:
            return data
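
Why does columnar layout compress so well? Because values of one column sit next to each other, so redundancy is local. The standalone demo below uses stdlib `gzip` as a stand-in for Snappy/LZ4 and contrasts a low-cardinality column against high-entropy data:

```python
import gzip
import os

# A repetitive column (think country codes) vs. incompressible random bytes.
repetitive = ("US\x00" * 10000).encode()   # 30,000 bytes, tiny alphabet
random_ish = os.urandom(30000)             # 30,000 bytes of noise

for name, data in [("repetitive", repetitive), ("random", random_ish)]:
    c = gzip.compress(data)
    print(f"{name}: {len(data)} -> {len(c)} bytes "
          f"(ratio {len(data) / len(c):.1f}:1)")
```

The repetitive column shrinks by orders of magnitude while the random buffer barely changes, which is why combining a light encoding (dictionary/RLE) with a general-purpose compressor is the standard columnar pipeline.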

3.3 Vectorized query execution

class VectorizedExecutor:
    def __init__(self):
        self.batch_size = 1024
    
    def execute_query(self, query: 'AnalyticalQuery', 
                      table: ColumnarTable) -> 'QueryResult':
        # 1. Column pruning
        projected_table = self._apply_projection(table, query.select_columns)
        
        # 2. Predicate pushdown
        filtered_table = self._apply_filters(projected_table, query.where_conditions)
        
        # 3. Aggregation
        if query.group_by_columns or query.aggregations:
            result = self._apply_aggregation(filtered_table, query)
        else:
            result = filtered_table
        
        # 4. Sorting (helper not shown in this excerpt)
        if query.order_by:
            result = self._apply_sorting(result, query.order_by)
        
        # 5. Limit (helper not shown in this excerpt)
        if query.limit:
            result = self._apply_limit(result, query.limit)
        
        return QueryResult(result)  # thin result wrapper (definition not shown)
    
    def _apply_projection(self, table: ColumnarTable, 
                         select_columns: List[str]) -> ColumnarTable:
        """列投影优化"""
        if not select_columns:
            return table
        
        return table.project(select_columns)
    
    def _apply_filters(self, table: ColumnarTable, 
                      conditions: List['FilterCondition']) -> ColumnarTable:
        """Vectorized filtering via a boolean bitmap."""
        if not conditions:
            return table
        
        # Build the filter bitmap by AND-ing per-condition bitmaps
        filter_bitmap = np.ones(table.row_count, dtype=bool)
        
        for condition in conditions:
            column_bitmap = self._evaluate_condition_vectorized(table, condition)
            filter_bitmap = filter_bitmap & column_bitmap
        
        # Apply the combined bitmap (helper not shown in this excerpt)
        return self._apply_bitmap_filter(table, filter_bitmap)
    
    def _evaluate_condition_vectorized(self, table: ColumnarTable, 
                                     condition: 'FilterCondition') -> np.ndarray:
        """Evaluate one condition against an entire column at once."""
        column = table.get_column(condition.column_name)
        if not column:
            return np.zeros(table.row_count, dtype=bool)
        
        # Convert to a NumPy array so comparisons run vectorized
        values = np.array(column.values)
        
        if condition.operator == '=':
            return values == condition.value
        elif condition.operator == '!=':
            return values != condition.value
        elif condition.operator == '<':
            return values < condition.value
        elif condition.operator == '<=':
            return values <= condition.value
        elif condition.operator == '>':
            return values > condition.value
        elif condition.operator == '>=':
            return values >= condition.value
        elif condition.operator == 'IN':
            return np.isin(values, condition.value)
        elif condition.operator == 'LIKE':
            # SQL LIKE: translate % to .* and require a full-string match
            if isinstance(condition.value, str):
                pattern = condition.value.replace('%', '.*')
                return np.array([bool(re.fullmatch(pattern, str(v))) for v in values])
        
        return np.zeros(table.row_count, dtype=bool)
    
    def _apply_aggregation(self, table: ColumnarTable, 
                          query: 'AnalyticalQuery') -> ColumnarTable:
        """Vectorized aggregation."""
        if not query.group_by_columns:
            # Whole-table aggregation
            return self._global_aggregation(table, query.aggregations)
        else:
            # Grouped aggregation
            return self._group_by_aggregation(table, query.group_by_columns, 
                                            query.aggregations)
    
    def _global_aggregation(self, table: ColumnarTable, 
                           aggregations: List['AggregationSpec']) -> ColumnarTable:
        """Whole-table aggregation over each requested column."""
        result_table = ColumnarTable(f"{table.table_name}_aggregated")
        
        for agg in aggregations:
            column = table.get_column(agg.column_name)
            if not column:
                continue
            
            values = np.array([v for v in column.values if v is not None])
            
            if agg.function == 'COUNT':
                result_value = len(column.values) - column.statistics.null_count
            elif agg.function == 'SUM':
                result_value = np.sum(values)
            elif agg.function == 'AVG':
                result_value = np.mean(values)
            elif agg.function == 'MIN':
                result_value = np.min(values)
            elif agg.function == 'MAX':
                result_value = np.max(values)
            elif agg.function == 'STDDEV':
                result_value = np.std(values)
            else:
                continue
            
            result_column = ColumnChunk(
                column_name=f"{agg.function}_{agg.column_name}",
                data_type='DOUBLE',
                values=[result_value]
            )
            result_table.add_column(result_column)
        
        return result_table
    
    def _group_by_aggregation(self, table: ColumnarTable, 
                             group_by_columns: List[str],
                             aggregations: List['AggregationSpec']) -> ColumnarTable:
        """Grouped aggregation (hash aggregation over tuple keys)."""
        # Build the grouping key for each row
        group_keys = []
        for i in range(table.row_count):
            key = tuple(table.get_column(col).values[i] for col in group_by_columns)
            group_keys.append(key)
        
        # Bucket row indices by group key
        groups = {}
        for i, key in enumerate(group_keys):
            if key not in groups:
                groups[key] = []
            groups[key].append(i)
        
        # Aggregate each group
        result_data = {}
        
        # Initialize the result columns
        for col in group_by_columns:
            result_data[col] = []
        
        for agg in aggregations:
            result_data[f"{agg.function}_{agg.column_name}"] = []
        
        # Compute the aggregate values per group
        for group_key, row_indices in groups.items():
            # Emit the group key values
            for i, col in enumerate(group_by_columns):
                result_data[col].append(group_key[i])
            
            # Emit the aggregate values
            for agg in aggregations:
                column = table.get_column(agg.column_name)
                if column:
                    group_values = [column.values[idx] for idx in row_indices 
                                  if column.values[idx] is not None]
                    
                    if agg.function == 'COUNT':
                        agg_value = len(group_values)
                    elif agg.function == 'SUM':
                        agg_value = sum(group_values)
                    elif agg.function == 'AVG':
                        agg_value = sum(group_values) / len(group_values) if group_values else 0
                    elif agg.function == 'MIN':
                        agg_value = min(group_values) if group_values else None
                    elif agg.function == 'MAX':
                        agg_value = max(group_values) if group_values else None
                    else:
                        agg_value = None
                    
                    result_data[f"{agg.function}_{agg.column_name}"].append(agg_value)
        
        # Assemble the result table
        result_table = ColumnarTable(f"{table.table_name}_grouped")
        for col_name, values in result_data.items():
            chunk = ColumnChunk(
                column_name=col_name,
                data_type='DOUBLE' if col_name not in group_by_columns else 'STRING',
                values=values
            )
            result_table.add_column(chunk)
        
        return result_table

@dataclass
class AnalyticalQuery:
    select_columns: List[str]
    where_conditions: List['FilterCondition']
    group_by_columns: Optional[List[str]] = None
    aggregations: Optional[List['AggregationSpec']] = None
    order_by: Optional[List['OrderBySpec']] = None
    limit: Optional[int] = None

@dataclass
class FilterCondition:
    column_name: str
    operator: str  # =, !=, <, <=, >, >=, IN, LIKE
    value: Any

@dataclass
class AggregationSpec:
    function: str  # COUNT, SUM, AVG, MIN, MAX, STDDEV
    column_name: str

@dataclass
class OrderBySpec:
    column_name: str
    ascending: bool = True
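
The payoff of evaluating predicates on whole columns, as `_evaluate_condition_vectorized` does, is that NumPy runs each comparison in a tight C loop instead of a per-row Python loop. A standalone sketch of combining two conditions into one boolean bitmap:

```python
import numpy as np

# Vectorized evaluation of: year = 2024 AND sales > 80
year  = np.array([2023, 2023, 2024, 2024, 2024])
sales = np.array([100,  250,  300,  50,   90])

bitmap = (year == 2024) & (sales > 80)   # one boolean per row, no Python loop
print(bitmap)                            # [False False  True False  True]

# Applying the bitmap to any column gathers the surviving rows.
print(sales[bitmap].tolist())            # [300, 90]
print(int(bitmap.sum()), "rows match")
```

Note the single `&` between bitmaps: NumPy's element-wise AND, exactly how `_apply_filters` folds multiple conditions into `filter_bitmap`.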

3.4 Query optimizer

class ColumnarQueryOptimizer:
    def __init__(self):
        self.statistics_manager = StatisticsManager()
    
    def optimize_query(self, query: AnalyticalQuery, 
                      table_metadata: Dict) -> AnalyticalQuery:
        # 1. Predicate pushdown
        query = self._optimize_predicate_pushdown(query, table_metadata)
        
        # 2. Column pruning
        query = self._optimize_column_pruning(query)
        
        # 3. Partition pruning
        query = self._optimize_partition_pruning(query, table_metadata)
        
        # 4. Aggregation rewrites
        query = self._optimize_aggregation(query)
        
        return query
    
    def _optimize_predicate_pushdown(self, query: AnalyticalQuery, 
                                   table_metadata: Dict) -> AnalyticalQuery:
        """谓词下推优化"""
        optimized_conditions = []
        
        for condition in query.where_conditions:
            column_stats = table_metadata.get('column_statistics', {}).get(condition.column_name)
            
            if column_stats:
                # Check whether statistics alone decide the condition
                if self._can_skip_condition(condition, column_stats):
                    continue  # Condition provably holds for every row: drop it
                
                if self._is_always_false(condition, column_stats):
                    # Condition provably matches no row: return an empty result
                    return self._create_empty_result_query()  # helper not shown
            
            optimized_conditions.append(condition)
        
        query.where_conditions = optimized_conditions
        return query
    
    def _optimize_column_pruning(self, query: AnalyticalQuery) -> AnalyticalQuery:
        """列剪枝优化"""
        required_columns = set(query.select_columns)
        
        # Columns referenced by WHERE conditions
        for condition in query.where_conditions:
            required_columns.add(condition.column_name)
        
        # GROUP BY columns
        if query.group_by_columns:
            required_columns.update(query.group_by_columns)
        
        # Aggregated columns
        if query.aggregations:
            for agg in query.aggregations:
                required_columns.add(agg.column_name)
        
        # ORDER BY columns
        if query.order_by:
            for order_spec in query.order_by:
                required_columns.add(order_spec.column_name)
        
        # Restrict the query to the required columns only
        query.select_columns = list(required_columns)
        return query
    
    def _optimize_partition_pruning(self, query: AnalyticalQuery, 
                                  table_metadata: Dict) -> AnalyticalQuery:
        """分区剪枝优化"""
        partition_columns = table_metadata.get('partition_columns', [])
        
        if not partition_columns:
            return query
        
        # Collect WHERE conditions that reference partition columns
        partition_filters = {}
        for condition in query.where_conditions:
            if condition.column_name in partition_columns:
                partition_filters[condition.column_name] = condition
        
        # Record them so the executor can skip non-matching partitions
        if partition_filters:
            query.partition_filters = partition_filters
        
        return query
    
    def _can_skip_condition(self, condition: FilterCondition, 
                           column_stats: ColumnStatistics) -> bool:
        """A condition can be dropped when statistics prove it holds for every row."""
        if condition.operator == '>':
            return condition.value < column_stats.min_value
        elif condition.operator == '<':
            return condition.value > column_stats.max_value
        
        return False
    
    def _is_always_false(self, condition: FilterCondition, 
                        column_stats: ColumnStatistics) -> bool:
        """A condition is unsatisfiable when statistics prove no row can match."""
        if condition.operator == '<':
            return condition.value <= column_stats.min_value
        elif condition.operator == '>':
            return condition.value >= column_stats.max_value
        elif condition.operator == '=':
            # An equality probe outside [min, max] can never match
            return (condition.value < column_stats.min_value or 
                   condition.value > column_stats.max_value)
        
        return False

class StatisticsManager:
    def __init__(self):
        self.column_statistics = {}
    
    def collect_statistics(self, table: ColumnarTable):
        """收集表统计信息"""
        for column_name, chunk in table.columns.items():
            self.column_statistics[column_name] = chunk.statistics
    
    def get_column_statistics(self, column_name: str) -> Optional[ColumnStatistics]:
        return self.column_statistics.get(column_name)
    
    def estimate_selectivity(self, condition: FilterCondition) -> float:
        """估算条件的选择性"""
        column_stats = self.get_column_statistics(condition.column_name)
        
        if not column_stats:
            return 0.5  # Default selectivity when no statistics exist
        
        if condition.operator == '=':
            return 1.0 / column_stats.distinct_count
        elif condition.operator in ['<', '<=', '>', '>=']:
            # Estimate range selectivity by linear interpolation over [min, max]
            if isinstance(condition.value, (int, float)):
                range_size = column_stats.max_value - column_stats.min_value
                if range_size > 0:
                    if condition.operator in ['<', '<=']:
                        return (condition.value - column_stats.min_value) / range_size
                    else:
                        return (column_stats.max_value - condition.value) / range_size
        
        return 0.5  # Default selectivity
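
A worked example of the formulas in `estimate_selectivity` makes the uniform-distribution assumption explicit (the statistics values below are made up for illustration):

```python
# Worked example of the selectivity formulas above, assuming uniform data.
min_v, max_v, distinct = 0, 1000, 200    # hypothetical column statistics

# Equality: each distinct value is assumed equally likely.
eq_sel = 1.0 / distinct                  # = 0.005, i.e. 0.5% of rows

# Range "value < 250": linear interpolation over [min, max].
lt_sel = (250 - min_v) / (max_v - min_v) # = 0.25, i.e. 25% of rows

print(f"= selectivity: {eq_sel}, < 250 selectivity: {lt_sel}")
```

The optimizer would order the equality predicate first: shrinking the intermediate result to 0.5% of rows before applying the 25% range filter minimizes the data each later operator touches. Both estimates degrade on skewed data, which is why production systems supplement min/max/distinct with histograms.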

3.5 Partition management

class PartitionManager:
    def __init__(self, storage_path: str):
        self.storage_path = storage_path
        self.partition_metadata = {}
    
    def create_partitioned_table(self, table: ColumnarTable, 
                                partition_columns: List[str]) -> Dict[str, str]:
        """创建分区表"""
        partitions = {}
        
        # Group rows by their partition-key values
        partition_groups = self._group_by_partition_keys(table, partition_columns)
        
        # Write one file per partition (naming/construction helpers not shown)
        for partition_key, partition_data in partition_groups.items():
            partition_name = self._generate_partition_name(table.table_name, partition_key)
            partition_table = self._create_partition_table(partition_name, partition_data)
            
            # Persist the partition
            storage_engine = ColumnarStorageEngine(self.storage_path)
            partition_path = storage_engine.write_table(partition_table)
            partitions[partition_key] = partition_path
        
        # Record the partition layout in metadata
        self.partition_metadata[table.table_name] = {
            'partition_columns': partition_columns,
            'partitions': partitions
        }
        
        return partitions
    
    def _group_by_partition_keys(self, table: ColumnarTable, 
                               partition_columns: List[str]) -> Dict[tuple, Dict]:
        """按分区键分组数据"""
        partition_groups = {}
        
        for i in range(table.row_count):
            # Build the partition key for this row
            partition_key = tuple(
                table.get_column(col).values[i] for col in partition_columns
            )
            
            if partition_key not in partition_groups:
                partition_groups[partition_key] = {col: [] for col in table.columns.keys()}
            
            # Append this row's values to its partition
            for column_name, chunk in table.columns.items():
                partition_groups[partition_key][column_name].append(chunk.values[i])
        
        return partition_groups
    
    def query_partitions(self, table_name: str, 
                        partition_filters: Dict[str, FilterCondition]) -> List[str]:
        """根据分区过滤条件查询相关分区"""
        if table_name not in self.partition_metadata:
            return []
        
        metadata = self.partition_metadata[table_name]
        partition_columns = metadata['partition_columns']
        
        matching_partitions = []
        
        for partition_key, partition_path in metadata['partitions'].items():
            # Check whether this partition's key satisfies every filter
            matches = True
            for i, column in enumerate(partition_columns):
                if column in partition_filters:
                    condition = partition_filters[column]
                    partition_value = partition_key[i]
                    
                    if not self._evaluate_partition_condition(partition_value, condition):
                        matches = False
                        break
            
            if matches:
                matching_partitions.append(partition_path)
        
        return matching_partitions
    
    def _evaluate_partition_condition(self, partition_value: Any, 
                                    condition: FilterCondition) -> bool:
        """评估分区条件"""
        if condition.operator == '=':
            return partition_value == condition.value
        elif condition.operator == '!=':
            return partition_value != condition.value
        elif condition.operator == '<':
            return partition_value < condition.value
        elif condition.operator == '<=':
            return partition_value <= condition.value
        elif condition.operator == '>':
            return partition_value > condition.value
        elif condition.operator == '>=':
            return partition_value >= condition.value
        elif condition.operator == 'IN':
            return partition_value in condition.value
        
        # Unknown operators match by default, so partitions are never pruned incorrectly
        return True
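The two steps above (grouping rows by partition key, then evaluating pruning conditions) can be sketched self-contained. Here a "table" is just a dict of column-name to value-list and a condition is an operator string plus a value; these are hypothetical stand-ins for the article's ColumnarTable and FilterCondition classes:

```python
import operator
from collections import defaultdict

def group_by_partition_keys(columns, partition_columns):
    """Group row data by partition key (mirrors _group_by_partition_keys)."""
    groups = defaultdict(lambda: {name: [] for name in columns})
    row_count = len(next(iter(columns.values())))
    for i in range(row_count):
        key = tuple(columns[c][i] for c in partition_columns)
        for name, values in columns.items():
            groups[key][name].append(values[i])
    return dict(groups)

# The if/elif chain of _evaluate_partition_condition as a dispatch table.
_OPS = {
    "=": operator.eq, "!=": operator.ne,
    "<": operator.lt, "<=": operator.le,
    ">": operator.gt, ">=": operator.ge,
    "IN": lambda v, values: v in values,
}

def evaluate_partition_condition(partition_value, op, value):
    fn = _OPS.get(op)
    # Unknown operators match conservatively: never prune a partition wrongly.
    return True if fn is None else fn(partition_value, value)

table = {
    "date": ["2024-01-01", "2024-01-01", "2024-01-02"],
    "amount": [10, 20, 30],
}
parts = group_by_partition_keys(table, ["date"])
assert parts[("2024-01-01",)]["amount"] == [10, 20]
assert evaluate_partition_condition("2024-01-02", ">=", "2024-01-01")
```

Grouping is O(rows × columns) per flush; real systems avoid the per-row loop by partitioning in batches, but the resulting layout is the same.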

4. Performance Optimization

4.1 Caching Strategy

import time
from typing import Optional

# ColumnChunk is assumed to be defined earlier in the article
class ColumnarCache:
    def __init__(self, max_memory_mb: int = 1024):
        self.max_memory = max_memory_mb * 1024 * 1024
        self.current_memory = 0
        self.cache = {}  # (table_name, column_name) -> ColumnChunk
        self.access_times = {}
        self.column_sizes = {}
    
    def get_column(self, table_name: str, column_name: str) -> Optional[ColumnChunk]:
        cache_key = (table_name, column_name)
        
        if cache_key in self.cache:
            self.access_times[cache_key] = time.time()
            return self.cache[cache_key]
        
        return None
    
    def put_column(self, table_name: str, column_chunk: ColumnChunk):
        cache_key = (table_name, column_chunk.column_name)
        
        # Estimate the column's in-memory size
        column_size = self._estimate_column_size(column_chunk)
        
        # Evict entries until the new column fits within the memory budget
        while self.current_memory + column_size > self.max_memory and self.cache:
            self._evict_lru_column()
        
        # Add to the cache
        self.cache[cache_key] = column_chunk
        self.column_sizes[cache_key] = column_size
        self.access_times[cache_key] = time.time()
        self.current_memory += column_size
    
    def _evict_lru_column(self):
        # Find the least recently used column
        lru_key = min(self.access_times.keys(), key=lambda k: self.access_times[k])
        
        # Remove it from the cache and release its memory
        del self.cache[lru_key]
        self.current_memory -= self.column_sizes[lru_key]
        del self.column_sizes[lru_key]
        del self.access_times[lru_key]
    
    def _estimate_column_size(self, column_chunk: ColumnChunk) -> int:
        # Rough estimate of the column's memory footprint
        return len(column_chunk.values) * 8  # assume 8 bytes per value
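The same byte-budgeted LRU policy can be expressed more compactly with `collections.OrderedDict`, which tracks recency order without explicit timestamps. A self-contained sketch (`LRUByteCache` and the chunk strings are illustrative names, not part of the article's code):

```python
from collections import OrderedDict

class LRUByteCache:
    """LRU cache with a byte budget instead of an entry-count limit."""

    def __init__(self, max_bytes):
        self.max_bytes = max_bytes
        self.current = 0
        self.entries = OrderedDict()  # key -> (value, size_in_bytes)

    def get(self, key):
        if key not in self.entries:
            return None
        self.entries.move_to_end(key)  # mark as most recently used
        return self.entries[key][0]

    def put(self, key, value, size):
        if key in self.entries:
            self.current -= self.entries.pop(key)[1]
        # Evict from the LRU end until the new entry fits
        while self.entries and self.current + size > self.max_bytes:
            _, (_, evicted_size) = self.entries.popitem(last=False)
            self.current -= evicted_size
        self.entries[key] = (value, size)
        self.current += size

cache = LRUByteCache(max_bytes=100)
cache.put(("t", "a"), "chunk_a", 60)
cache.put(("t", "b"), "chunk_b", 60)  # exceeds the budget, evicts ("t", "a")
```

The trade-off versus the timestamp approach above is purely about bookkeeping: `OrderedDict` gives O(1) eviction, while `min()` over `access_times` is O(n) per eviction.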

5. Summary

Columnar storage design centers on optimizing analytical query performance, achieving high compression ratios, and supporting vectorized execution. Combining a columnar file format, purpose-built encoding and compression algorithms, and query optimization techniques yields a high-performance analytical database system.

Key design points:

  • A columnar format optimized for analytical queries
  • Multiple encoding and compression algorithms to improve storage efficiency
  • A vectorized execution engine to boost query performance
  • An intelligent query optimizer that reduces the data scanned
  • Partitioning strategies that support large-scale data processing
  • A caching layer that speeds up access to hot data
  • Statistics-driven query optimization
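As an illustration of the encoding point above, here is run-length encoding, one common columnar codec that works well on sorted or low-cardinality columns. This is a generic sketch, not the article's specific implementation:

```python
def rle_encode(values):
    """Collapse consecutive equal values into [value, run_length] pairs."""
    runs = []
    for v in values:
        if runs and runs[-1][0] == v:
            runs[-1][1] += 1
        else:
            runs.append([v, 1])
    return runs

def rle_decode(runs):
    """Expand [value, run_length] pairs back into the original column."""
    out = []
    for v, n in runs:
        out.extend([v] * n)
    return out

col = ["US"] * 4 + ["EU"] * 2
assert rle_decode(rle_encode(col)) == col
assert len(rle_encode(col)) == 2  # 6 values stored as 2 runs
```

This is why analytical systems sort data within row groups before encoding: longer runs mean a higher compression ratio.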

🎯 Scenario

You open an app backed by this columnar storage service. Behind this seemingly simple action, the system faces three core challenges:

  • Challenge 1, high concurrency: how to keep latency low at millions of QPS?
  • Challenge 2, high availability: how to keep the service running when nodes fail?
  • Challenge 3, data consistency: how to keep data correct in a distributed environment?

📈 Capacity Estimation

Assume 10 million DAU with 50 requests per user per day.

| Metric | Value |
| --- | --- |
| Total data volume | 10 TB+ |
| Daily write volume | ~100 GB |
| Write TPS | ~50K/s |
| Read QPS | ~200K/s |
| P99 read latency | < 10 ms |
| Node count | 10-50 |
| Replication factor | 3 |
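A quick back-of-envelope check on the table above. The arithmetic follows directly from the stated assumptions; the peak-to-average multiplier is an inference, not a measurement:

```python
# Assumed inputs from the capacity table
dau = 10_000_000                 # daily active users
requests_per_user_per_day = 50

daily_requests = dau * requests_per_user_per_day  # 500M requests/day
avg_qps = daily_requests / 86_400                 # ~5,787 QPS on average

# The table's ~200K read QPS therefore implies a peak-to-average
# multiplier (traffic burstiness plus internal read fan-out) of
# roughly 35x over this average.
```

Sizing for average QPS alone would badly under-provision; capacity plans should always start from peak load.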

❓ Common Interview Questions

Q1: What are the core design principles of columnar storage?

See the architecture section above. The core principles are high availability (automatic failure recovery), high performance (low latency and high throughput), scalability (horizontal scaling), and consistency (data correctness guarantees). In an interview, expand on each with concrete scenarios.

Q2: What are the main challenges of columnar storage at large scale?

  1) Performance bottlenecks: as data and request volumes grow, a single node can no longer keep up; 2) consistency: guaranteeing data consistency in a distributed environment; 3) failure recovery: automatic failover and data recovery when nodes fail; 4) operational complexity: cluster management, monitoring, and upgrades.

Q3: How do you keep columnar storage highly available?

  1) Multi-replica redundancy (at least 3 replicas); 2) automatic failure detection and failover (heartbeats + leader election); 3) data durability and backups; 4) rate limiting and graceful degradation (to prevent cascading failures); 5) multi-datacenter / active-active deployment.

Q4: What are the key performance-optimization techniques for columnar storage?

  1) Caching (avoid repeated computation and IO); 2) asynchronous processing (move non-critical work off the hot path); 3) batching (fewer network round trips); 4) data sharding (parallel processing); 5) connection pool reuse.

Q5: How does columnar storage compare with alternative approaches?

See the comparison table below. Key selection factors: team technology stack, data scale, latency requirements, consistency needs, and operational cost. There is no silver bullet; weigh the trade-offs against the business scenario.



| Option | Complexity | Cost | Suitability |
| --- | --- | --- | --- |
| Option 1 | Simple implementation | Low | Small scale |
| Option 2 | Moderate complexity | Medium | Medium scale |
| Option 3 | High complexity ⭐ recommended | High | Large-scale production |

🚀 Architecture Evolution Path

Stage 1: Single-node MVP (under 100K users)

  • Monolithic application + single-node database
  • Prioritize feature validation and fast iteration
  • Suited to: early-stage product validation

Stage 2: Basic distributed deployment (100K to 1M users)

  • Horizontally scaled application layer (stateless services + load balancing)
  • Primary/replica database split (read/write separation)
  • Redis cache for hot data
  • Suited to: the business growth phase

Stage 3: Production-grade high availability (over 1M users)

  • Microservice decomposition with independent deployment and scaling
  • Database sharding and partitioning (split along business dimensions)
  • Message queues to decouple asynchronous workflows
  • Multi-datacenter deployment with geographic disaster recovery
  • Full-stack monitoring + automated operations

✅ Architecture Design Checklist

| Check Item | Notes |
| --- | --- |
| High availability | Multi-replica deployment, automatic failover, 99.9% SLA |
| Scalability | Stateless services scale horizontally; data layer is sharded |
| Data consistency | Strong consistency on core paths, eventual consistency elsewhere |
| Security | Authentication/authorization + encryption + audit logs |
| Monitoring & alerting | Metrics + logging + tracing (the three pillars) |
| Disaster recovery & backup | Multi-datacenter deployment, regular backups, RPO < 1 minute |
| Performance optimization | Multi-level caching + async processing + connection pools |
| Canary releases | Gradual rollout by user/region, fast rollback |

⚖️ Key Trade-off Analysis

🔴 Trade-off 1: Consistency vs. Availability

  • Strong consistency (CP): for scenarios that cannot tolerate errors, such as financial transactions
  • High availability (AP): for scenarios that tolerate brief inconsistency, such as social feeds
  • This system's choice: strong consistency on core paths, eventual consistency on non-core paths

🔴 Trade-off 2: Synchronous vs. Asynchronous

  • Synchronous processing: low latency but limited throughput; suited to core interactive paths
  • Asynchronous processing: high throughput but added latency; suited to background computation
  • This system's choice: synchronous on core paths, asynchronous on non-core paths