智能测试数据生成:用GAN合成更真实的业务数据

12 阅读14分钟

关注 霍格沃兹测试学院公众号,回复「资料」, 领取人工智能测试开发技术合集

一、当测试数据遇到天花板

还记得上个月我们遇到的那个棘手问题吗?金融系统迁移测试需要10万条客户交易数据,但合规部门明确禁止使用生产数据,哪怕脱敏也不行。开发团队用规则引擎造的数据又太“完美”——所有交易金额都是整齐的倍数,时间戳均匀分布,用户行为像是同一个模子刻出来的。结果呢?测试倒是通过了,上线后第一周就冒出三个边界条件bug。

这就是传统测试数据生成的困境:要么风险太大,要么真实性不足。今天,我想分享一个我们团队摸索半年的解决方案——用生成对抗网络(GAN)合成既安全又真实的业务数据。

二、为什么是GAN?不仅仅是技术时髦

你可能在想:数据生成方法那么多,为什么偏偏选GAN?

我们试过传统方法:规则模板、概率分布、甚至基于真实数据的变异。但总绕不开两个核心问题:

  1. 业务规则保持困难:造出来的数据单个字段看挺合理,组合起来却违反业务逻辑(比如18岁用户有30年信用卡历史)
  2. 数据关联性丢失:用户属性、行为、时间序列之间的复杂关联被简化

GAN的优势在于它能学习数据背后的真实分布,包括那些我们都没意识到的隐藏模式。举个例子,我们发现在真实电商数据中,凌晨购物的用户更可能选择货到付款,这个模式连产品经理都没总结过,但GAN生成的数据却保留了这一特性。

三、从零搭建你的第一个数据GAN

3.1 环境准备:少走弯路的配置

# requirements.txt
torch==1.9.0
pandas==1.3.0
scikit-learn==0.24.2
numpy==1.21.0
matplotlib==3.4.2
sdv==0.13.0  # 合成数据评估工具

# 硬件建议
"""
- 至少8GB RAM(处理百万级记录时需要16GB+)
- 支持CUDA的GPU不是必须,但能让训练快5-10倍
- 磁盘空间:原始数据的3-5倍(用于缓存和中间结果)
"""

3.2 数据预处理:比模型选择更重要的一步

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, OneHotEncoder

class DataPreprocessor:
    def __init__(self, categorical_threshold=10):
        """
        categorical_threshold: 唯一值少于这个数的视为分类变量
        """
        self.categorical_columns = []
        self.numerical_columns = []
        self.encoders = {}
        
    def analyze_columns(self, df):
        """智能识别列类型"""
        for col in df.columns:
            # 处理日期时间列
            if df[col].dtype == 'datetime64[ns]':
                # 提取时间特征,保留周期模式
                df[f'{col}_year'] = df[col].dt.year
                df[f'{col}_month'] = df[col].dt.month
                df[f'{col}_day'] = df[col].dt.day
                df[f'{col}_weekday'] = df[col].dt.weekday
                df[f'{col}_hour'] = df[col].dt.hour
                self.numerical_columns.extend([
                    f'{col}_year'f'{col}_month'f'{col}_day',
                    f'{col}_weekday'f'{col}_hour'
                ])
                
            # 分类变量识别
            elif df[col].nunique() < self.categorical_threshold:
                self.categorical_columns.append(col)
                
            # 数值变量
            elif pd.api.types.is_numeric_dtype(df[col]):
                # 处理负值和零值(特别是金额类字段)
                if df[col].min() <= 0:
                    # 对数变换前处理非正值
                    df[col] = df[col] - df[col].min() + 1
                self.numerical_columns.append(col)
                
        return df
    
    def fit_transform(self, df):
        """训练预处理管道"""
        df = self.analyze_columns(df.copy())
        
        # 处理分类变量
        transformed_data = []
        for col in self.categorical_columns:
            encoder = OneHotEncoder(sparse=False, handle_unknown='ignore')
            encoded = encoder.fit_transform(df[[col]])
            self.encoders[col] = encoder
            
            # 创建编码后的列名
            for i, category in enumerate(encoder.categories_[0]):
                df[f'{col}_{category}'] = encoded[:, i]
                
        # 处理数值变量(保留原始值供后续反变换)
        self.scaler = StandardScaler()
        df[self.numerical_columns] = self.scaler.fit_transform(
            df[self.numerical_columns]
        )
        
        return df

3.3 核心模型:Conditional Tabular GAN (CTGAN)

我们选择CTGAN而不是原始GAN,因为它能更好地处理混合数据类型(数值+分类):

import torch
import torch.nn as nn
import torch.optim as optim

class Generator(nn.Module):
    """生成器:从噪声生成合成数据"""
    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.1),
            
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.Dropout(0.1),
            
            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim * 4),
            
            nn.Linear(hidden_dim * 4, output_dim),
            nn.Tanh()  # 输出归一化到[-1, 1]
        )
        
    def forward(self, z, conditions=None):
        if conditions isnotNone:
            z = torch.cat([z, conditions], dim=1)
        return self.net(z)

class Discriminator(nn.Module):
    """判别器:区分真实和生成数据"""
    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim * 4),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(hidden_dim * 4, hidden_dim * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(0.2),
            
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.net(x)

class CTGAN:
    def __init__(self, generator, discriminator, device='cuda'):
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)
        self.device = device
        
        # 使用不同的学习率(通常判别器学得更快)
        self.g_optimizer = optim.Adam(
            generator.parameters(), lr=2e-4, betas=(0.50.999)
        )
        self.d_optimizer = optim.Adam(
            discriminator.parameters(), lr=2e-4, betas=(0.50.999)
        )
        
        self.criterion = nn.BCELoss()
        
    def train_step(self, real_data, conditions=None):
        batch_size = real_data.size(0)
        
        # 真实和假标签
        real_labels = torch.ones(batch_size, 1).to(self.device)
        fake_labels = torch.zeros(batch_size, 1).to(self.device)
        
        # ---------------------
        # 训练判别器
        # ---------------------
        self.d_optimizer.zero_grad()
        
        # 真实数据的损失
        real_output = self.discriminator(real_data)
        d_real_loss = self.criterion(real_output, real_labels)
        
        # 生成假数据
        z = torch.randn(batch_size, self.generator.input_dim).to(self.device)
        fake_data = self.generator(z, conditions)
        
        # 假数据的损失
        fake_output = self.discriminator(fake_data.detach())
        d_fake_loss = self.criterion(fake_output, fake_labels)
        
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        self.d_optimizer.step()
        
        # ---------------------
        # 训练生成器
        # ---------------------
        self.g_optimizer.zero_grad()
        
        # 让生成的数据尽可能骗过判别器
        fake_output = self.discriminator(fake_data)
        g_loss = self.criterion(fake_output, real_labels)
        
        # 添加模式正则化(防止模式坍塌)
        g_loss += self._mode_regularization(fake_data)
        
        g_loss.backward()
        self.g_optimizer.step()
        
        return {
            'd_loss': d_loss.item(),
            'g_loss': g_loss.item(),
            'real_score': real_output.mean().item(),
            'fake_score': fake_output.mean().item()
        }
    
    def _mode_regularization(self, fake_data, lambda_mr=0.1):
        """模式正则化:鼓励生成样本的多样性"""
        # 计算批次内样本的相似度
        batch_size = fake_data.size(0)
        if batch_size < 2:
            return0
        
        # 随机选择两个样本计算相似度
        idx1 = torch.randint(0, batch_size, (1,)).item()
        idx2 = torch.randint(0, batch_size, (1,)).item()
        
        similarity = torch.cosine_similarity(
            fake_data[idx1].unsqueeze(0),
            fake_data[idx2].unsqueeze(0)
        )
        
        # 惩罚过高的相似度
        return lambda_mr * torch.relu(similarity - 0.5)

3.4 训练技巧:我们踩过的那些坑

class GANTrainer:
    def __init__(self, model, epochs=1000, batch_size=64):
        self.model = model
        self.epochs = epochs
        self.batch_size = batch_size
        self.losses = []
        
    def train(self, data_loader, conditions_loader=None):
        for epoch in range(self.epochs):
            epoch_d_loss = 0
            epoch_g_loss = 0
            
            for i, real_data in enumerate(data_loader):
                conditions = None
                if conditions_loader:
                    conditions = next(conditions_loader)
                
                # 渐进式训练:前10轮只训练判别器
                if epoch < 10:
                    self.model.d_optimizer.zero_grad()
                    # ... 仅判别器训练代码
                else:
                    metrics = self.model.train_step(real_data, conditions)
                    epoch_d_loss += metrics['d_loss']
                    epoch_g_loss += metrics['g_loss']
                
                # 每100轮降低学习率
                if epoch % 100 == 0and epoch > 0:
                    self._adjust_lr(epoch)
                
                # 防止判别器过强
                if metrics.get('real_score'0) > 0.9:
                    # 跳过一次生成器训练
                    pass
            
            # 保存检查点
            if epoch % 50 == 0:
                self._save_checkpoint(epoch)
                
            # 早停判断
            if self._early_stop(epoch):
                print(f"早停于第{epoch}轮")
                break
    
    def _adjust_lr(self, epoch):
        """学习率调整策略"""
        for param_group in self.model.g_optimizer.param_groups:
            param_group['lr'] *= 0.95
        for param_group in self.model.d_optimizer.param_groups:
            param_group['lr'] *= 0.95
    
    def _early_stop(self, epoch, patience=50):
        """验证集性能不再提升时停止"""
        if epoch < 100:  # 前100轮不早停
            returnFalse
        
        # 检查最近patience轮的生成质量
        recent_losses = self.losses[-patience:]
        if len(recent_losses) < patience:
            returnFalse
            
        # 如果损失不再下降
        if np.std(recent_losses) < 1e-5:
            returnTrue
            
        returnFalse

四、数据评估:如何判断生成数据的好坏?

生成数据不能只看损失函数,我们建立了三层评估体系:

4.1 统计相似度评估

from scipy import stats
from sdv.metrics import CSTest, KSTest

class DataEvaluator:
    def evaluate_statistical_similarity(self, real_df, synthetic_df):
        """统计属性相似度"""
        results = {}
        
        # 1. 分布相似性(KS检验)
        for col in real_df.columns:
            if pd.api.types.is_numeric_dtype(real_df[col]):
                stat, p_value = stats.ks_2samp(
                    real_df[col].dropna(),
                    synthetic_df[col].dropna()
                )
                results[f'ks_{col}'] = {'statistic': stat, 'p_value': p_value}
        
        # 2. 相关性保持度
        real_corr = real_df.corr().abs().mean().mean()
        synth_corr = synthetic_df.corr().abs().mean().mean()
        results['correlation_preservation'] = 1 - abs(real_corr - synth_corr)
        
        # 3. 类别比例保持
        for col in real_df.select_dtypes(include=['object']).columns:
            real_props = real_df[col].value_counts(normalize=True)
            synth_props = synthetic_df[col].value_counts(normalize=True)
            
            # 对齐类别(可能生成数据有新的类别)
            all_cats = set(real_props.index) | set(synth_props.index)
            for cat in all_cats:
                real_val = real_props.get(cat, 0)
                synth_val = synth_props.get(cat, 0)
                results[f'prop_{col}_{cat}'] = abs(real_val - synth_val)
        
        return results

4.2 业务规则保持评估

class BusinessRuleValidator:
    def __init__(self, rules_config):
        """
        rules_config示例:
        {
            'age_income_rule': {
                'condition': 'age < 18',
                'constraint': 'annual_income < 10000'
            },
            'transaction_limit': {
                'condition': 'account_type == "basic"',
                'constraint': 'transaction_amount <= 5000'
            }
        }
        """
        self.rules = rules_config
    
    def validate(self, df):
        violations = {}
        
        for rule_name, rule in self.rules.items():
            condition_mask = df.eval(rule['condition'])
            constrained_data = df[condition_mask]
            
            violation_mask = constrained_data.eval(rule['constraint'])
            violation_count = (~violation_mask).sum()
            
            violations[rule_name] = {
                'total_affected': len(constrained_data),
                'violations': int(violation_count),
                'violation_rate': violation_count / max(len(constrained_data), 1)
            }
        
        return violations

4.3 机器学习效用评估

def evaluate_ml_utility(real_df, synthetic_df, target_column):
    """
    用生成数据训练的模型,在真实数据上测试效果
    如果性能相近,说明生成数据保留了预测模式
    """
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import f1_score
    
    # 用真实数据划分训练测试集
    X_real = real_df.drop(columns=[target_column])
    y_real = real_df[target_column]
    X_train_real, X_test_real, y_train_real, y_test_real = train_test_split(
        X_real, y_real, test_size=0.3, random_state=42
    )
    
    # 用生成数据训练
    X_synth = synthetic_df.drop(columns=[target_column])
    y_synth = synthetic_df[target_column]
    
    # 训练两个模型
    model_real = RandomForestClassifier(n_estimators=100)
    model_real.fit(X_train_real, y_train_real)
    
    model_synth = RandomForestClassifier(n_estimators=100)
    model_synth.fit(X_synth, y_synth)
    
    # 在真实测试集上评估
    real_on_real = f1_score(y_test_real, model_real.predict(X_test_real))
    synth_on_real = f1_score(y_test_real, model_synth.predict(X_test_real))
    
    return {
        'real_data_performance': real_on_real,
        'synthetic_data_performance': synth_on_real,
        'performance_gap': abs(real_on_real - synth_on_real)
    }

4.4 隐私泄露检测

class PrivacyChecker:
    def detect_memorization(self, real_df, synthetic_df, threshold=0.01):
        """
        检测生成数据是否记忆了真实数据
        返回可能泄露的记录
        """
        suspicious_records = []
        
        for _, synth_row in synthetic_df.iterrows():
            # 计算与每个真实记录的相似度
            similarities = []
            for _, real_row in real_df.iterrows():
                sim = self._record_similarity(synth_row, real_row)
                similarities.append(sim)
            
            # 如果与某个真实记录过于相似
            if max(similarities) > threshold:
                idx = np.argmax(similarities)
                suspicious_records.append({
                    'synthetic_index': _,
                    'real_index': idx,
                    'similarity': max(similarities)
                })
        
        return suspicious_records
    
    def _record_similarity(self, row1, row2):
        """计算两条记录的相似度"""
        # 数值字段使用相对误差
        num_cols = row1.select_dtypes(include=[np.number]).index
        num_sim = 0
        for col in num_cols:
            if row1[col] == 0and row2[col] == 0:
                num_sim += 1
            else:
                num_sim += 1 - abs(row1[col] - row2[col]) / (abs(row1[col]) + abs(row2[col]) + 1e-10)
        num_sim /= len(num_cols) if num_cols else1
        
        # 分类字段使用精确匹配
        cat_cols = row1.select_dtypes(include=['object']).index
        cat_sim = sum(row1[col] == row2[col] for col in cat_cols)
        cat_sim /= len(cat_cols) if cat_cols else1
        
        return (num_sim + cat_sim) / 2

人工智能技术学习交流群

伙伴们,对AI测试、大模型评测、质量保障感兴趣吗?我们建了一个 「人工智能测试开发交流群」,专门用来探讨相关技术、分享资料、互通有无。无论你是正在实践还是好奇探索,都欢迎扫码加入,一起抱团成长!期待与你交流!👇

image.png

五、实战:生成电商测试数据

让我们看一个完整例子:

# config.yaml
data_config:
  source: "data/real_transactions.csv"
  row_limit: 50000# 使用5万条真实数据训练
  test_size: 10000# 生成1万条测试数据

model_config:
  epochs: 2000
  batch_size: 256
  latent_dim: 100
  hidden_dim: 512

business_rules:
  - name: "会员等级购买力"
    condition: "会员等级 == '普通'"
    constraint: "订单金额 <= 1000"
  - name: "退货时间限制"
    condition: "退货标记 == True"
    constraint: "下单时间 - 发货时间 <= 30天"

# main.py
def generate_ecommerce_data():
    # 1. 加载和预处理
    preprocessor = DataPreprocessor()
    real_data = pd.read_csv('data/real_transactions.csv')
    processed_data = preprocessor.fit_transform(real_data)
    
    # 2. 训练GAN
    input_dim = processed_data.shape[1]
    generator = Generator(input_dim=100, output_dim=input_dim)
    discriminator = Discriminator(input_dim=input_dim)
    
    ctgan = CTGAN(generator, discriminator)
    trainer = GANTrainer(ctgan, epochs=2000)
    
    # 创建数据加载器
    data_tensor = torch.tensor(processed_data.values, dtype=torch.float32)
    data_loader = torch.utils.data.DataLoader(
        data_tensor, batch_size=256, shuffle=True
    )
    
    trainer.train(data_loader)
    
    # 3. 生成数据
    synthetic_tensors = []
    for _ in range(10):  # 生成10批次
        z = torch.randn(1000100).to(device)
        synthetic = ctgan.generator(z)
        synthetic_tensors.append(synthetic.cpu())
    
    synthetic_data = torch.cat(synthetic_tensors, dim=0)
    
    # 4. 后处理和反标准化
    synthetic_df = pd.DataFrame(
        synthetic_data.detach().numpy(),
        columns=processed_data.columns
    )
    
    # 反变换数值列
    synthetic_df[preprocessor.numerical_columns] = preprocessor.scaler.inverse_transform(
        synthetic_df[preprocessor.numerical_columns]
    )
    
    # 反变换分类列
    for col in preprocessor.categorical_columns:
        # 从one-hot解码
        encoded_cols = [c for c in synthetic_df.columns if c.startswith(f'{col}_')]
        encoded_values = synthetic_df[encoded_cols].values
        original_values = preprocessor.encoders[col].inverse_transform(encoded_values)
        synthetic_df[col] = original_values.flatten()
        synthetic_df.drop(columns=encoded_cols, inplace=True)
    
    # 5. 数据质量检查
    evaluator = DataEvaluator()
    stats_result = evaluator.evaluate_statistical_similarity(
        real_data.sample(10000),
        synthetic_df.sample(10000)
    )
    
    validator = BusinessRuleValidator(business_rules)
    violations = validator.validate(synthetic_df)
    
    print(f"数据生成完成")
    print(f"统计相似度: {stats_result['correlation_preservation']:.3f}")
    print(f"业务规则违反率: {max(v['violation_rate'] for v in violations.values()):.3f}")
    
    return synthetic_df

六、最佳实践和坑点指南

6.1 我们总结的经验

  1. 数据量不是越多越好

    • 10万条高质量数据 > 100万条脏数据
    • 先做好数据清洗,否则GAN会学习噪声
  2. 渐进式训练策略

    • 先用简单架构,确认能收敛
    • 逐步增加网络深度和复杂度
    • 监控生成数据的多样性(警惕模式坍塌)
  3. 领域知识注入

    def inject_domain_knowledge(synthetic_df, knowledge_rules):
        """
        用业务规则修正生成数据
        比如:确保VIP用户平均消费>普通用户
        """
        for rule in knowledge_rules:
            synthetic_df = synthetic_df.eval(rule)
        return synthetic_df
    

6.2 常见问题及解决

问题1:生成数据缺少极端值

  • 现象:所有订单金额都在平均值附近
  • 解决:在潜在空间采样时,增加边缘区域的采样概率

问题2:类别不平衡被放大

  • 现象:罕见类别在生成数据中更罕见或完全消失
  • 解决:使用条件GAN,或对罕见类别过采样

问题3:训练不稳定

  • 现象:损失函数剧烈震荡
  • 解决:使用WGAN-GP、谱归一化等技术

七、在测试体系中的集成方案

最后,如何把这项技术落地到你们的测试体系中?

class SyntheticDataPipeline:
    def __init__(self, config_path):
        self.config = self._load_config(config_path)
        self.models = {}
        
    def generate_for_test_scenario(self, scenario_name, sample_count):
        """为不同测试场景生成定制数据"""
        scenario_config = self.config['scenarios'][scenario_name]
        
        # 加载对应模型
        if scenario_name notin self.models:
            self.models[scenario_name] = self._load_model(scenario_config['model_path'])
        
        # 条件生成
        conditions = self._create_conditions(scenario_config)
        synthetic_data = self.models[scenario_name].generate(
            sample_count, conditions=conditions
        )
        
        # 场景特定后处理
        if'post_process'in scenario_config:
            synthetic_data = self._apply_post_process(
                synthetic_data, scenario_config['post_process']
            )
        
        # 验证并输出
        self._validate_and_export(synthetic_data, scenario_name)
        
        return synthetic_data
    
    def _create_conditions(self, scenario_config):
        """创建测试场景条件"""
        conditions = {}
        if scenario_config.get('stress_test'):
            # 压力测试:生成边界值数据
            conditions['amount_range'] = ('min''max')
            conditions['concurrency'] = 'high'
        elif scenario_config.get('regression_test'):
            # 回归测试:保持分布一致性
            conditions['distribution_preservation'] = True
        return conditions

写在最后

GAN生成测试数据不是银弹,但它确实解决了一些传统方法难以解决的问题。我们在三个项目中落地了这项技术,最明显的改进是:

  1. 发现边界bug的能力提升:生成了更多真实场景中罕见的组合
  2. 测试数据准备时间减少:从平均3人天降到2小时
  3. 数据安全性保障:合规部门认可了这种生成方式

当然,技术债还是要还的。维护GAN模型需要持续投入,特别是当业务规则变化时。我们的经验是,对于核心业务场景,投资是值得的;对于简单场景,传统方法可能更经济。

这项技术还在快速发展中,最近我们在试验扩散模型(Diffusion Models)用于数据生成,初步效果显示在数据多样性上更有优势。但无论如何,记住测试数据的核心原则:不是为了数据而数据,是为了更好的测试而数据

希望这篇分享能帮你少走些弯路。在实际落地中遇到问题,欢迎随时交流。毕竟,好的测试数据,应该像好的测试用例一样——既能验证功能,也能发现未知。

推荐学习

AI智能体实战指南课程,带你从理论跃入实战前线。课程浓缩5大核心场景:从Playwright、Appium实现自动化测试,到Cursor、Codex辅助高效编码;从定制ClawdBot助理,到Dify、Coze搭建智能工作流,乃至用FFmpeg打造短视频。内容直击当下开发与运营的关键需求,助你快速掌握AI智能体落地能力,全面提升工作效率。

image.png

关于我们

霍格沃兹测试开发学社,隶属于 测吧(北京)科技有限公司,是一个面向软件测试爱好者的技术交流社区。

学社围绕现代软件测试工程体系展开,内容涵盖软件测试入门、自动化测试、性能测试、接口测试、测试开发、全栈测试,以及人工智能测试与 AI 在测试工程中的应用实践

我们关注测试工程能力的系统化建设,包括 Python 自动化测试、Java 自动化测试、Web 与 App 自动化、持续集成与质量体系建设,同时探索 AI 驱动的测试设计、用例生成、自动化执行与质量分析方法,沉淀可复用、可落地的测试开发工程经验。

在技术社区与工程实践之外,学社还参与测试工程人才培养体系建设,面向高校提供测试实训平台与实践支持,组织开展  “火焰杯” 软件测试相关技术赛事,并探索以能力为导向的人才培养模式,包括高校学员先学习、就业后付款的实践路径。

同时,学社结合真实行业需求,为在职测试工程师与高潜学员提供名企大厂 1v1 私教服务,用于个性化能力提升与工程实践指导。