Lucidrains Series Projects Source Code Analysis (Part 29)

.\lucidrains\electra-pytorch\examples\glue\run.py

# File encoding: UTF-8
# Copyright: the Google AI Language Team Authors, the HuggingFace Inc. team, and NVIDIA CORPORATION
# Licensed under the Apache License, Version 2.0; see http://www.apache.org/licenses/LICENSE-2.0 for details
# You may not use this file except in compliance with the License or with written permission
# The software is distributed on an "AS IS" basis, without warranties or conditions of any kind, express or implied
# See the License for the specific language governing permissions and limitations

""" Fine-tuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa)."""

# Import the required libraries
import argparse
import glob
import json
import logging
import os
import random

import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

# Import the custom metric computation function
from metrics import glue_compute_metrics as compute_metrics
# Import the feature-conversion function
from processors import glue_convert_examples_to_features as convert_examples_to_features
# Import the output modes
from processors import glue_output_modes as output_modes
# Import the processors
from processors import glue_processors as processors
# Import the number of labels per task
from processors import glue_tasks_num_labels as task_num_labels

# Set up the logger
logger = logging.getLogger(__name__)

##################################################
# adapters for the Google-style GLUE code

# Tokenizer adapter class
class TokenizerAdapter:
    def __init__(self, tokenizer, pad_token, cls_token="[CLS]", sep_token="[SEP]"):
        self.tokenizer = tokenizer
        self.pad_token = pad_token
        self.cls_token = cls_token
        self.sep_token = sep_token

    # Convert tokens to ids
    def convert_tokens_to_ids(self, tokens):
        return self.tokenizer.convert_tokens_to_ids(tokens)

    # Truncate sequences
    def truncate_sequences(
        self,
        ids,
        pair_ids,
        num_tokens_to_remove,
        truncation_strategy,
        stride,
    ):
        # Make sure ids is longer than the number of tokens to remove
        assert len(ids) > num_tokens_to_remove
        # Compute the window length
        window_len = min(len(ids), stride + num_tokens_to_remove)
        # Collect the overflowing tokens
        overflowing_tokens = ids[-window_len:]
        # Truncate ids
        ids = ids[:-num_tokens_to_remove]

        return (ids, pair_ids, overflowing_tokens)
    # Encode the input text(s) into token ids and token type ids
    def encode_plus(self, text, text_pair, add_special_tokens, max_length, return_token_type_ids):

        # Tokenize the first text and convert it to token ids
        token_ids_0 = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
        len_ids = len(token_ids_0)
        # If a second text is given, tokenize it and convert it to token ids as well
        if text_pair:
            token_ids_1 = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text_pair))
            len_pair_ids = len(token_ids_1)
        else:
            token_ids_1 = None
            len_pair_ids = 0

 
        # Truncate the text
        assert add_special_tokens
        num_special_tokens_to_add = (2 if not text_pair else 3)
        total_len = len_ids + len_pair_ids + num_special_tokens_to_add
        # If the total length exceeds max_length, truncate
        if max_length and total_len > max_length:
            token_ids_0, token_ids_1, overflowing_tokens = self.truncate_sequences(
                token_ids_0,
                pair_ids=token_ids_1,
                num_tokens_to_remove=total_len - max_length,
                truncation_strategy='only_first', # TODO(nijkamp): is this the correct truncation strategy for all GLUE tasks?
                stride=0,
            )


        # Add the special tokens
        cls = [self.tokenizer.vocab[self.cls_token]]
        sep = [self.tokenizer.vocab[self.sep_token]]

        if not text_pair:

            input_ids = cls + token_ids_0 + sep
            token_type_ids = len(cls + token_ids_0 + sep) * [0]

        else:

            input_ids = cls + token_ids_0 + sep + token_ids_1 + sep
            token_type_ids = len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

        assert len(input_ids) <= max_length

        # Return the encoded result
        return {"input_ids": input_ids, "token_type_ids": token_type_ids}

    # Return the size of the tokenizer's vocabulary
    def __len__(self):
        return len(self.tokenizer.vocab)

    # Save the pretrained tokenizer to the given directory (a no-op here)
    def save_pretrained(self, outputdir):
        pass
# Wrap the given tokenizer and pad_token in a TokenizerAdapter and return it
def wrap_tokenizer(tokenizer, pad_token):
    return TokenizerAdapter(tokenizer, pad_token)
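
A minimal usage sketch of the adapter (hypothetical, not part of the original file; it assumes the WordPiece tokenizer exposed by pretraining/openwebtext/dataset.py and a standard BERT vocab file, the same import this script uses later in main()):

# Hypothetical usage sketch of TokenizerAdapter
from pretraining.openwebtext.dataset import new_tokenizer

adapter = wrap_tokenizer(new_tokenizer('data/vocab.txt'), pad_token='[PAD]')
encoded = adapter.encode_plus(
    "The quick brown fox.",
    "It jumps over the lazy dog.",
    add_special_tokens=True,
    max_length=32,
    return_token_type_ids=True,
)
# input_ids start with [CLS], the two texts are separated by [SEP],
# and token_type_ids mark the second segment with 1s
print(encoded["input_ids"])
print(encoded["token_type_ids"])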


##################################################
# distilled Google-like/HF glue code

# Set the random seed for reproducibility
def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

# Create a learning-rate schedule that increases linearly during warmup and then decreases linearly
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """ Create a schedule with a learning rate that decreases linearly after
    linearly increasing during a warmup period.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
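
A small sketch (hypothetical numbers) that makes the shape of this schedule visible: with lr=1.0 the printed learning rate equals the multiplier returned by lr_lambda, ramping up over the warmup steps and then decaying linearly to zero.

# Hypothetical sketch: inspect the warmup/decay multipliers of the schedule above
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = AdamW(params, lr=1.0)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=4, num_training_steps=10)

for step in range(10):
    optimizer.step()
    scheduler.step()
    print(step, round(scheduler.get_last_lr()[0], 3))  # 0.25, 0.5, 0.75, 1.0, then linear decay to 0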

# Train the model
def train(args, train_dataset, model, tokenizer):
    """ Train the model """

    # Set the total training batch size
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    # Build the training data sampler
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    # Build the training data loader
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    # Compute the total number of training steps
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare the optimizer and scheduler (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check whether saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
        os.path.join(args.model_name_or_path, "scheduler.pt")
    ):
        # Load the optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    # If mixed-precision (fp16) training is enabled
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # Multi-GPU training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
        )

    # Start training
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    # Log the total train batch size (with parallelism, distributed training and gradient accumulation)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    # Log the number of gradient accumulation steps
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    # Log the total number of optimization steps
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check whether training is being resumed from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # Set global_step to the global_step of the last checkpoint saved in the model path
        try:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    # Zero the model's gradients
    model.zero_grad()
    # Build the epoch iterator
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    # Return global_step and the average training loss (tr_loss / global_step)
    return global_step, tr_loss / global_step

def evaluate(args, model, tokenizer, prefix=""):
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
    eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)

        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # Multi-GPU evaluation
        if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
            model = torch.nn.DataParallel(model)

        # Eval!
        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
                inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
                if args.model_type != "distilbert":
                    inputs["token_type_ids"] = (
                        batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
                    )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
                outputs = model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

        eval_loss = eval_loss / nb_eval_steps
        if args.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
            print(preds)
        elif args.output_mode == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(eval_task, preds, out_label_ids)
        results.update(result)

        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return results


def load_and_cache_examples(args, task, tokenizer, evaluate=False):
    if args.local_rank not in [-1, 0] and not evaluate:
        torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset; the others use the cache

    processor = processors[task]()
    output_mode = output_modes[task]
    # Load data features from the cache or from the dataset file
    cached_features_file = os.path.join(
        args.data_dir,
        "cached_{}_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
            str(task),
        ),
    )
    # If the cached file exists and the cache is not being overwritten
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        # Log that features are being loaded from the cached file
        logger.info("Loading features from cached file %s", cached_features_file)
        # Load the features from the cached file
        features = torch.load(cached_features_file)
    else:
        # Log that features are being created from the dataset file
        logger.info("Creating features from dataset file at %s", args.data_dir)
        # Get the label list
        label_list = processor.get_labels()
        # If the task is mnli or mnli-mm and the model type is roberta or xlmroberta
        if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
            # HACK (label indices are swapped in the RoBERTa pretrained model)
            label_list[1], label_list[2] = label_list[2], label_list[1]
        # Get the examples
        examples = (
            processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
        )
        # Convert the examples to features
        features = convert_examples_to_features(
            examples,
            tokenizer,
            label_list=label_list,
            max_length=args.max_seq_length,
            output_mode=output_mode,
            pad_on_left=False,  # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=0,
        )
        # If the local rank is -1 or 0
        if args.local_rank in [-1, 0]:
            # Log that features are being saved to the cached file
            logger.info("Saving features into cached file %s", cached_features_file)
            # Save the features to the cached file
            torch.save(features, cached_features_file)

    # If the local rank is 0 and we are not evaluating
    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training processes the dataset; the others use the cache
        torch.distributed.barrier()

    # Convert to tensors and build the dataset
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
    # If the output mode is classification
    if output_mode == "classification":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
    # If the output mode is regression
    elif output_mode == "regression":
        all_labels = torch.tensor([f.label for f in features], dtype=torch.float)

    # Build the TensorDataset
    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
    return dataset

# Main function with default arguments task='MRPC', seed=42, ckpt='output/pretrain/2020-08-28-02-41-37/ckpt/60000'
def main(task='MRPC', seed=42, ckpt='output/pretrain/2020-08-28-02-41-37/ckpt/60000'):
    # Create the argument parser
    parser = argparse.ArgumentParser()

    # Required parameters
    # Input data directory; it should contain the .tsv files (or other data files) for the task
    parser.add_argument(
        "--data_dir",
        default=f'data/glue_data/{task}',
        type=str,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    # Model type, default "bert"
    parser.add_argument(
        "--model_type",
        default="bert",
        type=str,
    )
    # Model name or path, default ckpt
    parser.add_argument(
        "--model_name_or_path",
        default=ckpt,
        type=str,
    )
    # Vocabulary file path, default 'data/vocab.txt'
    parser.add_argument(
        "--vocab_path",
        default='data/vocab.txt',
        type=str,
    )
    # Task name, default task
    parser.add_argument(
        "--task_name",
        default=task,
        type=str,
        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
    )
    # Output directory, default 'output/glue'
    parser.add_argument(
        "--output_dir",
        default='output/glue',
        type=str,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    # Cache directory, default empty string
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    # Maximum sequence length, default 128
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    # Whether to run training, default True
    parser.add_argument("--do_train", default=True, help="Whether to run training.")
    # Whether to run eval on the dev set, default True
    parser.add_argument("--do_eval", default=True, help="Whether to run eval on the dev set.")
    # Whether to run evaluation during training at each logging step
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.",
    )
    # Whether an uncased (lowercased) model is being used, default True
    parser.add_argument(
        "--do_lower_case", default=True, help="Set this flag if you are using an uncased model.",
    )

    # Per-GPU/CPU training batch size, default 32
    parser.add_argument(
        "--per_gpu_train_batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.",
    )
    # Per-GPU/CPU evaluation batch size, default 8
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
    )
    # Number of gradient accumulation steps, default 1
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    # Initial learning rate for Adam, default 2e-5
    parser.add_argument("--learning_rate", default=2e-5, type=float, help="The initial learning rate for Adam.")
    # Weight decay, default 0.0
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    # Epsilon for the Adam optimizer, default 1e-8
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    # Maximum gradient norm, default 1.0
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    # Total number of training epochs, default 3.0
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
    )
    # Maximum number of steps, default -1
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    # Number of linear warmup steps, default 0
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

    # Log every X update steps, default 500
    parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
    # Save a checkpoint every X update steps, default 500
    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
    # Evaluate all checkpoints that share model_name's prefix and end with a step number
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    # Avoid using CUDA even when it is available
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    # Overwrite the content of the output directory
    parser.add_argument(
        "--overwrite_output_dir", default=True, help="Overwrite the content of the output directory",
    )
    # Overwrite the cached training and evaluation sets
    parser.add_argument(
        "--overwrite_cache", default=True, help="Overwrite the cached training and evaluation sets",
    )
    # Random seed for initialization
    parser.add_argument("--seed", type=int, default=seed, help="random seed for initialization")

    # Whether to use 16-bit (mixed) precision
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    # Apex AMP optimization level for fp16
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    # local_rank for distributed training
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    # Server IP address for remote debugging
    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
    # Server port for remote debugging
    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
    # Parse the arguments
    args = parser.parse_args()

    # Raise a ValueError if the output directory exists, is not empty, training is requested, and overwriting is disabled
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Set up remote debugging if requested
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Set up CUDA, GPU and distributed training
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 1
    args.device = device

    # Set up logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set the random seed
    set_seed(args)

    # Prepare the GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load the pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training downloads the model and vocab

    from transformers import AutoConfig, AutoModelForSequenceClassification
    args.model_type = args.model_type.lower()
    config = AutoConfig.from_pretrained(
        args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    # Load the sequence classification model from the pretrained checkpoint
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    # Import the custom tokenizer
    from pretraining.openwebtext.dataset import new_tokenizer
    # Wrap the tokenizer and set the padding token
    tokenizer = wrap_tokenizer(new_tokenizer(args.vocab_path), pad_token='[PAD]')

    # If the local rank is 0, synchronize with the other processes
    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training downloads the model and vocab

    # Move the model to the target device
    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        # Load and cache the training examples
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        # Train the model and get the global step and training loss
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if the default names are used for the model, it can be reloaded with from_pretrained()
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Create the output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save the trained model, configuration and tokenizer using `save_pretrained()`
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save the training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # Load the fine-tuned model and vocabulary
        model = model_to_save
        # TODO(nijkamp): we skip model serialization
        # model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
        # tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        # TODO(nijkamp): we skip model serialization
        # tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

            # TODO(nijkamp): we skip model serialization
            # model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(args, model, tokenizer, prefix=prefix)
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)

    return results

# Call main() when the script is executed directly
if __name__ == "__main__":
    main()

.\lucidrains\electra-pytorch\examples\glue\utils.py

# File encoding: utf-8
# Copyright notice, including the authors and teams
# Copyright, all rights reserved
# Licensed under the Apache License, Version 2.0; you may not use this file except in compliance with the License
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" basis, without warranties or conditions of any kind, express or implied
# See the License for the specific language governing permissions and limitations

# Import the necessary libraries
import copy
import csv
import dataclasses
import json
import logging
from dataclasses import dataclass
from typing import Optional

# Define the is_torch_available and is_tf_available helpers
is_torch_available = lambda: True
is_tf_available = lambda: False

# Get the logger
logger = logging.getLogger(__name__)

# InputExample dataclass representing a single training/test example
@dataclass(frozen=True)
class InputExample:
    """
    A single training/test example for simple sequence classification.

    Args:
        guid: Unique id for the example.
        text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
        text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
        label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
    """

    guid: str
    text_a: str
    text_b: Optional[str] = None
    label: Optional[str] = None

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"

# InputFeatures class representing a single set of data features
class InputFeatures(object):
    """
    A single set of features of data.

    Args:
        input_ids: Indices of input sequence tokens in the vocabulary.
        attention_mask: Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            Usually  ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens.
        token_type_ids: Segment token indices to indicate first and second portions of the inputs.
        label: Label corresponding to the input
    """

    def __init__(self, input_ids, attention_mask=None, token_type_ids=None, label=None):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.label = label

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
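
A quick sketch (hypothetical values, not part of the original file) of how these two containers are constructed and serialized:

# Hypothetical sketch: build an InputExample and an InputFeatures and dump them as JSON
example = InputExample(guid="train-1", text_a="The cat sat on the mat.", text_b=None, label="1")
print(example.to_json_string())

features = InputFeatures(
    input_ids=[101, 1996, 4937, 102],  # assumed WordPiece ids for "[CLS] the cat [SEP]"
    attention_mask=[1, 1, 1, 1],
    token_type_ids=[0, 0, 0, 0],
    label=1,
)
print(features.to_json_string())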

# DataProcessor base class for data converters of sequence classification datasets
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_example_from_tensor_dict(self, tensor_dict):
        """Gets an example from a dict with tensorflow tensors
        Args:
            tensor_dict: Keys and values should match the corresponding Glue
                tensorflow_dataset examples.
        """
        raise NotImplementedError()

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()
    # Convert the given example to the format expected by the GLUE datasets
    def tfds_map(self, example):
        """Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are.
        This method converts examples to the correct format."""
        # If there is more than one label, map the example's label index to the corresponding label value
        if len(self.get_labels()) > 1:
            example.label = self.get_labels()[int(example.label)]
        # Return the converted example
        return example

    # Read a tab-separated values file
    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        # Open the file and read it as UTF-8
        with open(input_file, "r", encoding="utf-8-sig") as f:
            # Read the contents with the csv module, using tab as the delimiter
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

class SingleSentenceClassificationProcessor(DataProcessor):
    """ Generic processor for a single sentence classification data set."""

    def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
        # Initialize the labels, examples, mode and verbosity
        self.labels = [] if labels is None else labels
        self.examples = [] if examples is None else examples
        self.mode = mode
        self.verbose = verbose

    def __len__(self):
        # Return the number of examples
        return len(self.examples)

    def __getitem__(self, idx):
        # Return the example(s) at the given index
        if isinstance(idx, slice):
            return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
        return self.examples[idx]

    @classmethod
    def create_from_csv(
        cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
    ):
        # Create a processor from a CSV file
        processor = cls(**kwargs)
        processor.add_examples_from_csv(
            file_name,
            split_name=split_name,
            column_label=column_label,
            column_text=column_text,
            column_id=column_id,
            skip_first_row=skip_first_row,
            overwrite_labels=True,
            overwrite_examples=True,
        )
        return processor

    @classmethod
    def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
        # Create a processor from a list of examples
        processor = cls(**kwargs)
        processor.add_examples(texts_or_text_and_labels, labels=labels)
        return processor

    def add_examples_from_csv(
        self,
        file_name,
        split_name="",
        column_label=0,
        column_text=1,
        column_id=None,
        skip_first_row=False,
        overwrite_labels=False,
        overwrite_examples=False,
    ):
        # Add examples from a CSV file
        lines = self._read_tsv(file_name)
        if skip_first_row:
            lines = lines[1:]
        texts = []
        labels = []
        ids = []
        for (i, line) in enumerate(lines):
            texts.append(line[column_text])
            labels.append(line[column_label])
            if column_id is not None:
                ids.append(line[column_id])
            else:
                guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
                ids.append(guid)

        return self.add_examples(
            texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
        )

    def add_examples(
        self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
    ):
        # Add examples
        assert labels is None or len(texts_or_text_and_labels) == len(labels)
        assert ids is None or len(texts_or_text_and_labels) == len(ids)
        if ids is None:
            ids = [None] * len(texts_or_text_and_labels)
        if labels is None:
            labels = [None] * len(texts_or_text_and_labels)
        examples = []
        added_labels = set()
        for (text_or_text_and_label, label, guid) in zip(texts_or_text_and_labels, labels, ids):
            if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
                text, label = text_or_text_and_label
            else:
                text = text_or_text_and_label
            added_labels.add(label)
            examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))

        # Update examples
        if overwrite_examples:
            self.examples = examples
        else:
            self.examples.extend(examples)

        # Update labels
        if overwrite_labels:
            self.labels = list(added_labels)
        else:
            self.labels = list(set(self.labels).union(added_labels))

        return self.examples

    def get_features(
        self,
        tokenizer,
        max_length=None,
        pad_on_left=False,
        pad_token=0,
        mask_padding_with_zero=True,
        return_tensors=None,

.\lucidrains\electra-pytorch\pretraining\openwebtext\arg.py

# Import the necessary modules
import argparse
import dataclasses

# Publicly exported names
__all__ = ('Arg', 'Int', 'Float', 'Bool', 'Str', 'Choice', 'parse_to')

# Base argument class
class Arg:
    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs

# Integer argument class
class Int(Arg):
    def __init__(self, **kwargs):
        super().__init__(type=int, **kwargs)

# Float argument class
class Float(Arg):
    def __init__(self, **kwargs):
        super().__init__(type=float, **kwargs)

# Boolean argument class
class Bool(Arg):
    def __init__(self, **kwargs):
        super().__init__(type=bool, **kwargs)

# String argument class
class Str(Arg):
    def __init__(self, **kwargs):
        super().__init__(type=str, **kwargs)

# Metaclass that lets Choice be subscripted with the type holding the choices
class _MetaChoice(type):
    def __getitem__(self, item):
        return self(choices=list(item), type=item)

# Choice argument class
class Choice(Arg, metaclass=_MetaChoice):
    def __init__(self, choices, **kwargs):
        super().__init__(choices=choices, **kwargs)

# Parse command-line arguments and fill them into the given container class
def parse_to(container_class, **kwargs):
    # Convert a field name into its command-line option form
    def mangle_name(name):
        return '--' + name.replace('_', '-')

    # Create the argument parser
    parser = argparse.ArgumentParser(description=container_class.__doc__)
    # Iterate over the container class's fields
    for field in dataclasses.fields(container_class):
        name = field.name
        default = field.default
        value_or_class = field.type
        # If the field type is a class, instantiate it with the default value
        if isinstance(value_or_class, type):
            value = value_or_class(default=default)
        else:
            value = value_or_class
            value.kwargs['default'] = default
        # Add the argument to the parser
        parser.add_argument(
            mangle_name(name), **value.kwargs)

    # Parse the arguments and return a populated instance of the container class
    arg_dict = parser.parse_args(**kwargs)
    return container_class(**vars(arg_dict))
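
A small usage sketch (hypothetical fields) showing how a frozen dataclass pairs with parse_to; the Args classes in preprocess.py and pretrain.py below follow exactly this pattern:

# Hypothetical sketch: declare arguments as dataclass fields typed with the Arg helpers
from dataclasses import dataclass

@dataclass(frozen=True)
class DemoArgs:
    vocab_file: Str = 'data/vocab.txt'  # exposed as --vocab-file
    batch_size: Int = 32                # exposed as --batch-size

# parse_to builds an argparse parser from the fields and returns a DemoArgs instance,
# e.g. `python demo.py --batch-size 64` -> DemoArgs(vocab_file='data/vocab.txt', batch_size=64)
args = parse_to(DemoArgs)
print(args)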

.\lucidrains\electra-pytorch\pretraining\openwebtext\dataset.py

import math
import os
import random
from dataclasses import dataclass
from itertools import chain
from functools import partial
from pathlib import Path

import numpy as np

import torch
import torch.utils.data

from openwebtext import tokenization


class ExampleBuilder:
    """Given a stream of input text, creates pretraining examples."""

    def __init__(self, vocab, max_length):
        # Initialize the ExampleBuilder with a vocabulary and a maximum length
        self._vocab = vocab
        self._current_sentences = []  # sentences of the example currently being built
        self._current_length = 0  # length of the example currently being built
        self._max_length = max_length  # maximum length
        self._target_length = max_length  # target length

    def add_line(self, bert_tokids):
        """Adds a line of text to the current example being built."""
        # Add a line of text to the example currently being built
        self._current_sentences.append(bert_tokids)  # append the BERT token ids to the current sentence list
        self._current_length += len(bert_tokids)  # update the current example length
        if self._current_length >= self._target_length:
            return self._create_example()  # build an example once the target length is reached
        return None

    def _create_example(self):
        """Creates a pre-training example from the current list of sentences."""
        # Small chance of only having one segment, as in classification tasks
        if random.random() < 0.1:
            first_segment_target_length = 100000
        else:
            # -3 accounts for the [CLS]/[SEP] tokens that are not yet in the input text
            first_segment_target_length = (self._target_length - 3) // 2

        first_segment = []  # first segment
        second_segment = []  # second segment
        for sentence in self._current_sentences:
            # Add to the first segment if it is empty, if the sentence fits within the target length,
            # or (with 50% probability) if it would exceed the target length but the second segment is still empty
            if (len(first_segment) == 0 or
                len(first_segment) + len(sentence) < first_segment_target_length or
                (len(second_segment) == 0 and
                len(first_segment) < first_segment_target_length and
                random.random() < 0.5)):
                first_segment += sentence  # add the sentence to the first segment
            else:
                second_segment += sentence  # add the sentence to the second segment

        # Trim to the maximum length, accounting for the [CLS]/[SEP] tokens not yet added
        first_segment = first_segment[:self._max_length - 2]
        second_segment = second_segment[:max(0, self._max_length - len(first_segment) - 3)]

        # Prepare to start building the next example
        self._current_sentences = []  # clear the current sentence list
        self._current_length = 0  # reset the current length
        # Small chance of choosing a random length instead of the maximum length
        if random.random() < 0.05:
            self._target_length = random.randint(5, self._max_length)
        else:
            self._target_length = self._max_length

        return self._make_tf_example(first_segment, second_segment)  # build the example (kept in the original TF-style format)
    def _make_tf_example(self, first_segment, second_segment):
        """Converts two "segments" of text into a tf.train.Example."""
        # Get the vocabulary
        vocab = self._vocab
        # Build the input token id sequence, including the [CLS] and [SEP] tokens
        input_ids = [vocab["[CLS]"]] + first_segment + [vocab["[SEP]"]]
        # Initialize the segment ids to all zeros
        segment_ids = [0] * len(input_ids)
        # If there is a second text segment
        if second_segment:
            # Append the second segment's token ids and its segment ids
            input_ids += second_segment + [vocab["[SEP]"]]
            segment_ids += [1] * (len(second_segment) + 1)
        # Initialize the input mask to all ones
        input_mask = [1] * len(input_ids)
        # Pad the token ids, input mask and segment ids up to the maximum length
        input_ids += [0] * (self._max_length - len(input_ids))
        input_mask += [0] * (self._max_length - len(input_mask))
        segment_ids += [0] * (self._max_length - len(segment_ids))

        # Helper that turns a list of ints into a feature (a torch tensor here)
        def create_int_feature(tensors):
            return torch.tensor(tensors)

        # Build the example dict (in place of a tf.train.Example)
        tf_example = {
            "input_ids": create_int_feature(input_ids),
            "input_mask": create_int_feature(input_mask),
            "segment_ids": create_int_feature(segment_ids)
        }
        return tf_example
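
A short sketch (hypothetical text; it assumes create_tokenizer and new_tokenizer defined later in this file plus a standard BERT vocab) of how ExampleBuilder packs tokenized lines into fixed-length examples:

# Hypothetical sketch: feed tokenized lines to ExampleBuilder until it emits an example
tokenize = create_tokenizer('data/vocab.txt')  # text -> list of token ids
builder = ExampleBuilder(vocab=new_tokenizer('data/vocab.txt').vocab, max_length=128)

example = None
while example is None:
    # add_line returns None until the accumulated length reaches the target length
    example = builder.add_line(tokenize("Electra trains a discriminator to spot replaced tokens."))

print(example["input_ids"].shape)    # torch.Size([128])
print(example["segment_ids"].sum())  # number of positions assigned to the second segment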

# OpenWebTextDataset: an IterableDataset over pre-tokenized feature files
class OpenWebTextDataset(torch.utils.data.IterableDataset):
    # Constructor taking feature_set_paths and n_tensors_per_file
    def __init__(self, feature_set_paths, n_tensors_per_file):
        # Store feature_set_paths on the instance
        self.feature_set_paths = feature_set_paths
        # Store n_tensors_per_file on the instance (used by __len__ below)
        self.n_tensors_per_file = n_tensors_per_file

    # Static method that parses a single feature file
    @staticmethod
    def parse_file(file_index):
        # Try to load the file contents as features
        try:
            features = torch.load(str(file_index))
            # Yield the features one by one
            yield from features
        # Catch RuntimeError
        except RuntimeError:
            # Re-raise with the offending file path
            raise RuntimeError(f'Corrupted file {file_index}')

    # Return the length of the dataset
    def __len__(self):
        return len(self.feature_set_paths) * self.n_tensors_per_file

    # Return an iterator over all features
    def __iter__(self):
        # Apply parse_file to every path and flatten the results with chain.from_iterable
        return chain.from_iterable(map(self.parse_file, self.feature_set_paths))


# ExampleBuilderDataset: an IterableDataset that turns a stream of token ids into pretraining examples
class ExampleBuilderDataset(torch.utils.data.IterableDataset):
    # Constructor taking a dataset and an ExampleBuilder
    def __init__(self, dataset, builder):
        # Store the dataset on the instance
        self.dataset = dataset
        # Store the builder on the instance (used by __iter__ below)
        self.builder = builder

    # Return the length of the dataset
    def __len__(self):
        return len(self.dataset)

    # Yield pretraining examples indefinitely
    def __iter__(self):
        # Inner helper that builds a single example
        def create_example():
            # Loop until the builder emits an example
            while True:
                # Take the next dataset element, move it to the CPU and convert it to a list of ids
                token_ids = list(next(self.dataset).cpu().numpy())
                # Feed the token ids to the builder; return the example once one is produced
                example = self.builder.add_line(token_ids)
                if example:
                    return example

        # Yield examples forever
        while True:
            # Yield the result of create_example one at a time
            yield create_example()


# cycle: endlessly iterate over an iterable
def cycle(iterable):
    # Loop forever
    while True:
        # Yield the elements of iterable one by one
        for x in iterable:
            yield x


# new_tokenizer: build a FullTokenizer from a vocab file
def new_tokenizer(vocab_file, do_lower_case=True):
    # Return a FullTokenizer built with vocab_file and do_lower_case
    return tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)


# parse_tokenizer: tokenize text and convert it to token ids
def parse_tokenizer(tokenizer, text):
    # Convert the text to token ids and return them
    return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))


# create_tokenizer: return a callable that maps text to token ids
def create_tokenizer(vocab_file, do_lower_case=True):
    # Build a FullTokenizer
    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
    # Return a partial that applies parse_tokenizer with this tokenizer
    return partial(parse_tokenizer, tokenizer)


# load_owt: build an OpenWebTextDataset from a directory of feature files
def load_owt(owt_dir, n_tensors_per_file):
    # Convert owt_dir to a Path
    owt_dir_path = Path(owt_dir)
    # Collect all file paths under owt_dir_path and shuffle them
    feature_set_paths = [owt_dir_path / feature_set_path for feature_set_path in os.listdir(owt_dir_path)]
    np.random.shuffle(feature_set_paths)
    # Make sure at least one feature file was found
    assert len(feature_set_paths) > 0
    # Return an OpenWebTextDataset over the feature files
    return OpenWebTextDataset(feature_set_paths, n_tensors_per_file=n_tensors_per_file)


# wrap_example_builder: wrap a dataset of token ids in an ExampleBuilderDataset
def wrap_example_builder(dataset, vocab, max_length):
    # Return an ExampleBuilderDataset over cycle(iter(dataset)), using ExampleBuilder(vocab, max_length)
    return ExampleBuilderDataset(cycle(iter(dataset)), ExampleBuilder(vocab, max_length))
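
Putting the helpers together, a hedged sketch (hypothetical paths, mirroring what pretraining/openwebtext/pretrain.py does) of how this module feeds a DataLoader:

# Hypothetical sketch: wire load_owt + wrap_example_builder into a DataLoader
from torch.utils.data import DataLoader

tokenizer = new_tokenizer('data/vocab.txt')
ds = wrap_example_builder(
    dataset=load_owt(owt_dir='data/openwebtext_features', n_tensors_per_file=2048),
    vocab=tokenizer.vocab,
    max_length=128,
)

def collate(examples):
    # every example is a dict of fixed-length tensors, so stacking is enough here
    return {k: torch.stack([e[k] for e in examples]) for k in examples[0]}

loader = DataLoader(ds, batch_size=8, collate_fn=collate)
batch = next(iter(loader))
print(batch['input_ids'].shape)  # torch.Size([8, 128])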

.\lucidrains\electra-pytorch\pretraining\openwebtext\preprocess.py

import logging
import math
import multiprocessing
import os
import random
import tarfile
from dataclasses import dataclass
from itertools import chain
from functools import partial
from pathlib import Path

import numpy as np

import torch
import torch.utils.data

from pretraining.openwebtext import arg
from pretraining.openwebtext import tokenization


logger = logging.getLogger(__name__)


def parse_tokenizer(tokenizer, text):
    return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))


def create_tokenizer(vocab_file, do_lower_case=True):
    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
    return partial(parse_tokenizer, tokenizer)


def preprocess_owt(tokenizer, src_dir, tmp_dir, trg_dir, n_dataset_building_processes, n_tensors_per_file, max_seq_length=None):
    # Preamble
    logger.info(f'Writing features to {trg_dir}.')
    os.makedirs(trg_dir, exist_ok=False)

    # Crunch files
    trg_dir = Path(trg_dir)
    src_dir = Path(src_dir)
    tmp_dir = Path(tmp_dir)
    archives = os.listdir(src_dir)
    n_archives_per_job = math.ceil(len(archives) / n_dataset_building_processes)
    job_archives = [
        archives[i * n_archives_per_job : (i + 1) * n_archives_per_job]
        for i in range(n_dataset_building_processes)
    ]

    logger.info(f'Processing {len(archives)} archives.')
    assert len(archives) > 0

    if n_dataset_building_processes == 1:
        feature_set_paths = preprocess_owt_job(tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length, job_id=0)
    else:
        pool = multiprocessing.Pool(processes=n_dataset_building_processes)
        preprocess_owt_job_partial = partial(preprocess_owt_job, tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length)
        feature_sets = pool.map(preprocess_owt_job_partial, range(n_dataset_building_processes))
        feature_set_paths = [file_path for feature_set in feature_sets for file_path in feature_set]

    return feature_set_paths


def preprocess_owt_job(tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length, job_id=0):
    '''
    OpenWebText is saved under the following format:
    openwebtext.zip
        |-> archive_xxx.zip
            |-> file_xxx.txt
            |-> file_xxz.txt
            ...
        |-> archive_xxz.zip
            |-> file_xxy.txt
            ...
        ...
    '''

    # Preamble
    os.makedirs(tmp_dir, exist_ok=True)

    # Process
    feature_index = 0
    feature_set_paths = []
    features = []
    for archive_id, archive in enumerate(job_archives[job_id]):
        if os.path.isdir(src_dir / archive):
            logger.info(f'Ignoring rogue directory {src_dir / archive}.')
            continue

        logger.info(f'Job {job_id}: Processing {archive_id}/{len(job_archives[job_id])} {src_dir / archive}.')

        with tarfile.open(src_dir / archive) as t:
            extracted_archive = tmp_dir / f'{archive}-extracted'
            t.extractall(extracted_archive)

        for file in os.listdir(extracted_archive):
            file_path = extracted_archive / file

            with open(file_path, 'r') as f:
                for line in f.readlines():
                    line = line.strip()
                    if len(line) > 2:
                        encoding = tokenizer(line)
                        features.append(torch.tensor(encoding))

        while len(features) > n_tensors_per_file:
            feature_set_path = trg_dir / f'feature_set_{job_id}_{feature_index}.pt'
            torch.save(features[:n_tensors_per_file], feature_set_path)
            features = features[n_tensors_per_file:]
            feature_index += 1
            feature_set_paths.append(feature_set_path)

    # Serialize
    # If the feature list is not empty
    if len(features) > 0:
        # Build the feature set path from the job id and the feature index
        feature_set_path = trg_dir / f'feature_set_{job_id}_{feature_index}.pt'
        # Save the features with torch to that path
        torch.save(features, feature_set_path)
        # Append the feature set path to the list
        feature_set_paths.append(feature_set_path)

    # Return the list of feature set paths
    return feature_set_paths

# Frozen (immutable) Args dataclass holding the default parameters
@dataclass(frozen=True)
class Args:
    src_dir: arg.Str = 'data/openwebtext'  # source directory, default 'data/openwebtext'
    trg_dir: arg.Str = 'data/openwebtext_features'  # target directory for the features, default 'data/openwebtext_features'
    tmp_dir: arg.Str = '/tmp/owt'  # temporary directory, default '/tmp/owt'
    vocab_file: arg.Str = 'data/vocab.txt'  # vocabulary file, default 'data/vocab.txt'
    n_dataset_building_processes: arg.Int = 32  # number of dataset-building processes, default 32
    n_tensors_per_file: arg.Int = 2048  # number of tensors per output file, default 2048
    max_seq_length: arg.Int = 128  # maximum sequence length, default 128

# Main entry point
def main():
    # Parse the command-line arguments into Args
    args = arg.parse_to(Args)

    # Configure the logger format and level
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO
    )

    # Build the tokenizer
    tokenizer = create_tokenizer(args.vocab_file)
    # Preprocess the OpenWebText dataset
    preprocess_owt(tokenizer=tokenizer, src_dir=args.src_dir, tmp_dir=args.tmp_dir, trg_dir=args.trg_dir, n_dataset_building_processes=args.n_dataset_building_processes, n_tensors_per_file=args.n_tensors_per_file, max_seq_length=args.max_seq_length)

# Call main() when the script is executed directly
if __name__ == '__main__':
    main()

.\lucidrains\electra-pytorch\pretraining\openwebtext\pretrain.py

# Import the necessary libraries
import os
import sys

# Absolute path of the directory containing this file
dir_path = os.path.dirname(os.path.realpath(__file__))
# Absolute path of its parent directory
parent_dir_path = os.path.abspath(os.path.join(dir_path, os.pardir))
# Insert the parent directory into the module search path
sys.path.insert(0, parent_dir_path)

# Import the remaining libraries
import random
import logging
from time import time
from dataclasses import dataclass

import numpy as np

import torch
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader

from electra_pytorch import Electra

from openwebtext import arg
from openwebtext.dataset import load_owt, new_tokenizer, wrap_example_builder

logger = logging.getLogger(__name__)

########################################################################################################
## args

# Argument class
@dataclass
class Args:
    data_dir: arg.Str = 'data/openwebtext_features'
    data_vocab_file: arg.Str = 'data/vocab.txt'
    data_n_tensors_per_file: arg.Int = 2048
    data_max_seq_length: arg.Int = 128

    gpu: arg.Int = 0
    gpu_enabled: arg.Bool = True
    gpu_deterministic: arg.Bool = False
    gpu_mixed_precision: arg.Bool = False
    distributed_port: arg.Int = 8888
    distributed_enabled: arg.Bool = True
    distributed_world_size: arg.Int = 4

    model_generator: arg.Str = 'pretraining/openwebtext/small_generator.json'
    model_discriminator: arg.Str = 'pretraining/openwebtext/small_discriminator.json'
    model_mask_prob: arg.Float = 0.15

    opt_lr: arg.Float = 5e-4
    opt_batch_size: arg.Int = 128 // (distributed_world_size if distributed_enabled else 1)
    opt_warmup_steps: arg.Int = 10_000
    opt_num_training_steps: arg.Int = 200_000

    step_log: arg.Int = 10
    step_ckpt: arg.Int = 10_000
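
These Args are presumably turned into command-line flags with arg.parse_to from pretraining/openwebtext/arg.py (the same pattern as preprocess.py above); a hedged sketch of reading them:

# Hypothetical sketch: parse Args from the command line, mirroring preprocess.py's main()
# e.g. `python pretrain.py --opt-lr 2e-4 --opt-warmup-steps 5000`
args = arg.parse_to(Args)
print(args.opt_lr, args.opt_batch_size, args.opt_num_training_steps)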


########################################################################################################
## train

# Training function
def train(rank, args):

    #######################
    ## distributed

    # Initialize the process group if distributed training is enabled
    if args.distributed_enabled:
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='env://',
            world_size=args.distributed_world_size,
            rank=rank)
    # Select this rank's GPU if GPUs are enabled
    if args.gpu_enabled:
        device = torch.device('cuda:{}'.format(rank))
    else:
        device = torch.device('cpu')

    # Determine whether the current process is the master
    is_master = True if not args.distributed_enabled else args.distributed_enabled and rank == 0


    #######################
    ## preamble

    # Set the GPU
    set_gpus(rank)
    # Set the random seed
    set_seed(rank)
    # Configure CUDA
    set_cuda(deterministic=args.gpu_deterministic)

    # Create the output directory
    output_dir = f'{args.output_dir}/{rank}'
    os.makedirs(output_dir, exist_ok=False)

    # Set up logging
    setup_logging(filename=f'{output_dir}/output.log', console=is_master)


    #######################
    ## dataset

    # Build the tokenizer
    tokenizer = new_tokenizer(vocab_file=args.data_vocab_file)
    vocab_size = len(tokenizer.vocab)
    # Load the dataset
    ds_train = wrap_example_builder(dataset=load_owt(owt_dir=args.data_dir, n_tensors_per_file=args.data_n_tensors_per_file), vocab=tokenizer.vocab, max_length=args.data_max_seq_length)

    # Look up the special token ids
    pad_token_id = tokenizer.vocab['[PAD]']
    mask_token_id = tokenizer.vocab['[MASK]']
    cls_token_id = tokenizer.vocab['[CLS]']
    sep_token_id = tokenizer.vocab['[SEP]']

    # Check that the special token ids match the expected BERT vocabulary
    assert pad_token_id == 0
    assert cls_token_id == 101
    assert sep_token_id == 102
    assert mask_token_id == 103

    # Batch collation function
    def collate_batch(examples):
        input_ids = torch.nn.utils.rnn.pad_sequence([example['input_ids'] for example in examples], batch_first=True, padding_value=pad_token_id)
        input_mask = torch.nn.utils.rnn.pad_sequence([example['input_mask'] for example in examples], batch_first=True, padding_value=pad_token_id)
        segment_ids = torch.nn.utils.rnn.pad_sequence([example['segment_ids'] for example in examples], batch_first=True, padding_value=pad_token_id)
        return input_ids, input_mask, segment_ids

    # Data loader that cycles over the dataset forever
    def cycle(iterable):
        while True:
            for x in iterable:
                yield x

    ds_train_loader = iter(cycle(DataLoader(ds_train, batch_size=args.opt_batch_size, collate_fn=collate_batch)))


    #######################
    ## model
    # Return the model unchanged when distributed mode is disabled, otherwise wrap it in DistributedDataParallel
    def to_distributed_model(model):
        return model if not args.distributed_enabled else torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True)

    # Tie the generator's and discriminator's embedding weights together
    def tie_weights(generator, discriminator):
        generator.electra.embeddings.word_embeddings = discriminator.electra.embeddings.word_embeddings
        generator.electra.embeddings.position_embeddings = discriminator.electra.embeddings.position_embeddings
        generator.electra.embeddings.token_type_embeddings = discriminator.electra.embeddings.token_type_embeddings

    # Adapter that returns only the logits from the wrapped model's output
    class LogitsAdapter(torch.nn.Module):
        def __init__(self, adaptee):
            super().__init__()
            self.adaptee = adaptee

        def forward(self, *args, **kwargs):
            return self.adaptee(*args, **kwargs)[0]

    # Import the required model classes and configs
    from transformers import AutoConfig, ElectraForMaskedLM, ElectraForPreTraining

    # Build the generator and discriminator models
    generator = ElectraForMaskedLM(AutoConfig.from_pretrained(args.model_generator))
    discriminator = ElectraForPreTraining(AutoConfig.from_pretrained(args.model_discriminator))

    # Tie the generator's and discriminator's weights together
    tie_weights(generator, discriminator)

    # Build the (possibly distributed) Electra model and set its parameters
    model = to_distributed_model(Electra(
        LogitsAdapter(generator),
        LogitsAdapter(discriminator),
        num_tokens = vocab_size,
        mask_token_id = mask_token_id,
        pad_token_id = pad_token_id,
        mask_prob = args.model_mask_prob,
        mask_ignore_token_ids = [tokenizer.vocab['[CLS]'], tokenizer.vocab['[SEP]']],
        random_token_prob = 0.0).to(device))

    #######################
    ## optimizer

    # linear learning-rate schedule with a warmup phase
    def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
        def lr_lambda(current_step):
            learning_rate = max(0.0, 1. - (float(current_step) / float(num_training_steps)))
            learning_rate *= min(1.0, float(current_step) / float(num_warmup_steps))
            return learning_rate
        return LambdaLR(optimizer, lr_lambda, last_epoch)
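
    # Illustrative sketch (not part of the original script): the lr_lambda multiplier ramps up
    # linearly over the warmup steps and decays linearly to zero by the final step. The check
    # below re-evaluates the same formula with hypothetical warmup/total values.
    _warmup, _total = 10, 100
    _mult = lambda s: max(0.0, 1. - float(s) / _total) * min(1.0, float(s) / _warmup)
    assert abs(_mult(5) - 0.475) < 1e-9   # halfway through warmup, decay already slightly applied
    assert abs(_mult(10) - 0.9) < 1e-9    # warmup complete
    assert _mult(100) == 0.0              # fully decayed at the end of training
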

    # split parameters into two groups: biases and LayerNorm weights get no weight decay
    def get_params_without_weight_decay_ln(named_params, weight_decay):
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)],
                'weight_decay': weight_decay,
            },
            {
                'params': [p for n, p in named_params if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0,
            },
        ]
        return optimizer_grouped_parameters
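
    # Illustrative sketch (not part of the original script): biases and LayerNorm weights land in
    # the second group with weight_decay=0.0; everything else keeps the configured decay. The
    # named parameters below are hypothetical stand-ins following the HuggingFace naming scheme.
    _named = [('dense.weight', torch.nn.Parameter(torch.zeros(1))),
              ('dense.bias', torch.nn.Parameter(torch.zeros(1))),
              ('LayerNorm.weight', torch.nn.Parameter(torch.zeros(1))),
              ('LayerNorm.bias', torch.nn.Parameter(torch.zeros(1)))]
    _groups = get_params_without_weight_decay_ln(_named, weight_decay=0.1)
    assert len(_groups[0]['params']) == 1  # only dense.weight is decayed
    assert len(_groups[1]['params']) == 3  # dense.bias, LayerNorm.weight and LayerNorm.bias are not
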

    # build the optimizer, learning-rate scheduler, and AMP gradient scaler
    optimizer = torch.optim.AdamW(get_params_without_weight_decay_ln(model.named_parameters(), weight_decay=0.1), lr=args.opt_lr, betas=(0.9, 0.999), eps=1e-08)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.opt_warmup_steps, num_training_steps=args.opt_num_training_steps)
    scaler = torch.cuda.amp.GradScaler(enabled=args.gpu_mixed_precision)

    #######################
    ## train

    # track wall-clock time, step throughput, and the estimated minutes to completion
    t, steps_s, eta_m = time(), 0., 0
    # main training loop: forward, backward, gradient clipping, optimizer and scheduler steps
    for step in range(args.opt_num_training_steps+1):
        # fetch the next batch from the training loader
        input_ids, input_mask, segment_ids = next(ds_train_loader)

        # move the batch to the target device
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)

        # the sequence length must not exceed the configured maximum
        assert input_ids.shape[1] <= args.data_max_seq_length

        # reset gradients
        optimizer.zero_grad()

        # forward pass under mixed precision, returning the losses and accuracies
        with torch.cuda.amp.autocast(enabled=args.gpu_mixed_precision):
            loss, loss_mlm, loss_disc, acc_gen, acc_disc, disc_labels, disc_pred = model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids)

        # backward pass, gradient clipping, and optimizer/scheduler updates
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # collect training metrics for logging
        metrics = {
            'step': (step, '{:8d}'),
            'loss': (loss.item(), '{:8.5f}'),
            'loss_mlm': (loss_mlm.item(), '{:8.5f}'),
            'loss_disc': (loss_disc.item(), '{:8.5f}'),
            'acc_gen': (acc_gen.item(), '{:5.3f}'),
            'acc_disc': (acc_disc.item(), '{:5.3f}'),
            'lr': (scheduler.get_last_lr()[0], '{:8.7f}'),
            'steps': (steps_s, '{:4.1f}/s'),
            'eta': (eta_m, '{:4d}m'),
        }

        # log the metrics every args.step_log steps
        if step % args.step_log == 0:
            sep = ' ' * 2
            logger.info(sep.join([f'{k}: {v[1].format(v[0])}' for (k, v) in metrics.items()]))

        # every 100 steps, update the step throughput and the estimated remaining time
        if step > 0 and step % 100 == 0:
            t2 = time()
            steps_s = 100. / (t2 - t)
            eta_m = int(((args.opt_num_training_steps - step) / steps_s) // 60)
            t = t2

        # every 200 steps, log the discriminator labels and predictions for one example
        if step % 200 == 0:
            logger.info(np.array2string(disc_labels[0].cpu().numpy(), threshold=sys.maxsize, max_line_width=sys.maxsize))
            logger.info(np.array2string(disc_pred[0].cpu().numpy(), threshold=sys.maxsize, max_line_width=sys.maxsize))

        # periodically save a checkpoint of the discriminator (master process only)
        if step > 0 and step % args.step_ckpt == 0 and is_master:
            discriminator.electra.save_pretrained(f'{args.output_dir}/ckpt/{step}')
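
# Illustrative sketch (not part of the original script): each checkpoint written above is a
# standard HuggingFace ElectraModel directory, so it can be reloaded later for downstream
# fine-tuning. The helper name and the checkpoint path argument are hypothetical.
def load_pretrained_discriminator(ckpt_dir):
    from transformers import ElectraModel
    return ElectraModel.from_pretrained(ckpt_dir)
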
# set which GPU this process runs on
def set_gpus(gpu):
    torch.cuda.set_device(gpu)

# seed all sources of randomness
def set_seed(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # if CUDA is available, also seed the CUDA RNGs
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

# configure whether CUDA/cuDNN runs deterministically
def set_cuda(deterministic=True):
    # if CUDA is available, toggle cuDNN determinism vs. benchmarking
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = not deterministic

# derive the experiment id from the script filename
def get_exp_id(file):
    # return the file's base name without its extension
    return os.path.splitext(os.path.basename(file))[0]

# build a timestamped output directory for this experiment
def get_output_dir(exp_id):
    # import the datetime module
    import datetime
    t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    # assemble the output directory path
    output_dir = os.path.join('output/' + exp_id, t)
    # create the output directory if it does not exist
    os.makedirs(output_dir, exist_ok=True)
    # return the output directory path
    return output_dir

# configure logging to a file, and optionally to the console
def setup_logging(filename, console=True):
    # define the log message format
    log_format = logging.Formatter("%(asctime)s : %(message)s")
    # get the root logger
    logger = logging.getLogger()
    # clear any existing handlers
    logger.handlers = []
    # create a file handler
    file_handler = logging.FileHandler(filename)
    # set the file handler's format
    file_handler.setFormatter(log_format)
    # attach the file handler to the logger
    logger.addHandler(file_handler)
    # if console output is requested
    if console:
        # create a console handler writing to stdout
        console_handler = logging.StreamHandler(sys.stdout)
        # set the console handler's format
        console_handler.setFormatter(log_format)
        # attach the console handler to the logger
        logger.addHandler(console_handler)
        # set the logger's level to INFO
        logger.setLevel(logging.INFO)
    # return the configured logger
    return logger

# copy the source file into the output directory
def copy_source(file, output_dir):
    # import the shutil module
    import shutil
    # copy this source file into the output directory
    shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file)))

# entry point
def main():

    # preamble
    # derive the experiment id
    exp_id = get_exp_id(__file__)
    # build the timestamped output directory
    output_dir = get_output_dir(exp_id)
    # create the output directory if it does not exist
    os.makedirs(output_dir, exist_ok=True)
    # create the checkpoint directory
    os.makedirs(f'{output_dir}/ckpt', exist_ok=False)
    # copy this script into the output directory
    copy_source(__file__, output_dir)

    # args
    # parse the command-line arguments
    args = arg.parse_to(Args)
    # record the output directory and experiment id on the args
    args.output_dir = output_dir
    args.exp_id = exp_id

    # distributed
    # if distributed training is enabled
    if args.distributed_enabled:
        # set the master address and port
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = str(args.distributed_port)
        # spawn one training process per GPU
        torch.multiprocessing.spawn(train, nprocs=args.distributed_world_size, args=(args,))
    else:
        # single-process training
        train(rank=args.gpu, args=args)

# call main() when this script is run directly
if __name__ == '__main__':
    main()

.\lucidrains\electra-pytorch\pretraining\openwebtext\tokenization.py

# Set the file encoding to utf-8
# Copyright notice
# Licensed under the Apache License, Version 2.0
# Link to obtain a copy of the License
# Distributed on an "AS IS" BASIS, without warranties or conditions of any kind, either express or implied
# See the License for the specific language governing permissions and limitations

"""Tokenization classes, the same as used for BERT."""

# Import the required libraries
import collections
import unicodedata

# Convert the input text to Unicode (if it isn't already), assuming utf-8 input
def convert_to_unicode(text):
    if isinstance(text, str):
        return text
    elif isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    else:
        raise ValueError("Unsupported string type: %s" % (type(text)))

# Return text in a form suitable for printing or logging
def printable_text(text):
    if isinstance(text, str):
        return text
    elif isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    else:
        raise ValueError("Unsupported string type: %s" % (type(text)))

# Load a vocabulary file into an ordered dictionary (token -> id)
def load_vocab(vocab_file):
    vocab = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r") as reader:
        while True:
            token = convert_to_unicode(reader.readline())
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab
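
# Illustrative sketch (not part of the original file): the vocab file has one token per
# line, and each token's id is simply its line number. The tiny vocab below is hypothetical.
import tempfile

_tmp = tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False)
_tmp.write('[PAD]\n[UNK]\n[CLS]\n[SEP]\nhello\n')
_tmp.close()
_vocab = load_vocab(_tmp.name)
assert _vocab['[PAD]'] == 0 and _vocab['hello'] == 4
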

# Convert a sequence of [tokens|ids] using the given vocab
def convert_by_vocab(vocab, items):
    output = []
    for item in items:
        output.append(vocab[item])
    return output

# Convert tokens to ids
def convert_tokens_to_ids(vocab, tokens):
    return convert_by_vocab(vocab, tokens)

# Convert ids back to tokens
def convert_ids_to_tokens(inv_vocab, ids):
    return convert_by_vocab(inv_vocab, ids)

# Basic whitespace tokenization
def whitespace_tokenize(text):
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens

# Full tokenizer: basic tokenization followed by WordPiece
class FullTokenizer(object):
    def __init__(self, vocab_file, do_lower_case=True):
        self.vocab = load_vocab(vocab_file)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text):
            for sub_token in self.wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)
        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab(self.inv_vocab, ids)
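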

# Basic tokenizer: cleanup, lower-casing, accent stripping, and punctuation splitting
class BasicTokenizer(object):
    def __init__(self, do_lower_case=True):
        self.do_lower_case = do_lower_case
    def tokenize(self, text):
        """Tokenizes a piece of text."""
        # convert the text to unicode
        text = convert_to_unicode(text)
        # clean up the raw text
        text = self._clean_text(text)

        # added for multilingual and Chinese models: put spaces around CJK characters
        text = self._tokenize_chinese_chars(text)

        # split on whitespace to get the original token list
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            # lower-case the token if requested
            if self.do_lower_case:
                token = token.lower()
                # strip accents from the token
                token = self._run_strip_accents(token)
            # split the token on punctuation
            split_tokens.extend(self._run_split_on_punc(token))

        # re-join and split on whitespace to produce the final output tokens
        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        # normalize to NFD and drop combining marks (accents)
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # check whether the codepoint lies in one of the CJK Unicode blocks
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or
                (cp >= 0x3400 and cp <= 0x4DBF) or
                (cp >= 0x20000 and cp <= 0x2A6DF) or
                (cp >= 0x2A700 and cp <= 0x2B73F) or
                (cp >= 0x2B740 and cp <= 0x2B81F) or
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or
                (cp >= 0x2F800 and cp <= 0x2FA1F)):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            # drop invalid and control characters; normalize whitespace to a single space
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
class WordpieceTokenizer(object):
    """Runs WordPiece tokenziation."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        # store the vocabulary, the unknown token, and the per-word character limit
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
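
# Illustrative sketch (not part of the original file): WordPiece greedily matches the longest
# prefix found in the vocab and marks non-initial pieces with "##". The toy vocab is hypothetical.
_wp = WordpieceTokenizer(vocab={"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3})
assert _wp.tokenize("unaffable") == ["un", "##aff", "##able"]
assert _wp.tokenize("xyz") == ["[UNK]"]  # no piece of "xyz" is in the vocab, so it maps to [UNK]
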


def _is_whitespace(char):
    """Checks whether `chars` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `chars` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `chars` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
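
# Illustrative end-to-end sketch (not part of the original file): FullTokenizer chains
# BasicTokenizer and WordpieceTokenizer, then maps pieces to ids via the vocab file.
# The tiny vocab and temporary file below are hypothetical.
import tempfile

_vocab_lines = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "the", "quick", "brown", "fox", "##es"]
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as _f:
    _f.write("\n".join(_vocab_lines) + "\n")

_tok = FullTokenizer(vocab_file=_f.name, do_lower_case=True)
assert _tok.tokenize("The quick brown foxes") == ["the", "quick", "brown", "fox", "##es"]
assert _tok.convert_tokens_to_ids(["the", "fox"]) == [4, 7]
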