Lucidrains Series Projects: Source Code Analysis (Part 29)
.\lucidrains\electra-pytorch\examples\glue\run.py
""" 在 GLUE 上对库模型进行序列分类微调(Bert、XLM、XLNet、RoBERTa、Albert、XLM-RoBERTa)。"""
import argparse
import glob
import json
import logging
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from metrics import glue_compute_metrics as compute_metrics
from processors import glue_convert_examples_to_features as convert_examples_to_features
from processors import glue_output_modes as output_modes
from processors import glue_processors as processors
from processors import glue_tasks_num_labels as task_num_labels
logger = logging.getLogger(__name__)
class TokenizerAdapter:
def __init__(self, tokenizer, pad_token, cls_token="[CLS]", sep_token="[SEP]"):
self.tokenizer = tokenizer
self.pad_token = pad_token
self.cls_token = cls_token
self.sep_token = sep_token
def convert_tokens_to_ids(self, tokens):
return self.tokenizer.convert_tokens_to_ids(tokens)
def truncate_sequences(
self,
ids,
pair_ids,
num_tokens_to_remove,
truncation_strategy,
stride,
):
assert len(ids) > num_tokens_to_remove
window_len = min(len(ids), stride + num_tokens_to_remove)
overflowing_tokens = ids[-window_len:]
ids = ids[:-num_tokens_to_remove]
return (ids, pair_ids, overflowing_tokens)
def encode_plus(self, text, text_pair, add_special_tokens, max_length, return_token_type_ids):
token_ids_0 = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
len_ids = len(token_ids_0)
if text_pair:
token_ids_1 = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text_pair))
len_pair_ids = len(token_ids_1)
else:
token_ids_1 = None
len_pair_ids = 0
assert add_special_tokens
num_special_tokens_to_add = (2 if not text_pair else 3)
total_len = len_ids + len_pair_ids + num_special_tokens_to_add
if max_length and total_len > max_length:
token_ids_0, token_ids_1, overflowing_tokens = self.truncate_sequences(
token_ids_0,
pair_ids=token_ids_1,
num_tokens_to_remove=total_len - max_length,
truncation_strategy='only_first',
stride=0,
)
cls = [self.tokenizer.vocab[self.cls_token]]
sep = [self.tokenizer.vocab[self.sep_token]]
if not text_pair:
input_ids = cls + token_ids_0 + sep
token_type_ids = len(cls + token_ids_0 + sep) * [0]
else:
input_ids = cls + token_ids_0 + sep + token_ids_1 + sep
token_type_ids = len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
assert len(input_ids) <= max_length
return {"input_ids": input_ids, "token_type_ids": token_type_ids}
def __len__(self):
return len(self.tokenizer.vocab)
def save_pretrained(self, outputdir):
pass
def wrap_tokenizer(tokenizer, pad_token):
return TokenizerAdapter(tokenizer, pad_token)
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
""" Create a schedule with a learning rate that decreases linearly after
linearly increasing during a warmup period.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(
0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
)
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
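To make the schedule concrete, here is a minimal sketch (not part of the repository; the SGD optimizer, base learning rate and step counts are placeholders) showing the warmup-then-linear-decay behaviour of the helper above:

# illustrative sketch only -- optimizer, lr and step counts are made up
params = [torch.nn.Parameter(torch.zeros(1))]
sgd = torch.optim.SGD(params, lr=1.0)
sched = get_linear_schedule_with_warmup(sgd, num_warmup_steps=10, num_training_steps=100)
lrs = []
for _ in range(100):
    sgd.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])
# lrs rises linearly from 0.1 to 1.0 over the first 10 steps, then decays linearly towards 0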
def train(args, train_dataset, model, tokenizer):
""" Train the model """
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
if args.max_steps > 0:
t_total = args.max_steps
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
else:
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
)
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
os.path.join(args.model_name_or_path, "scheduler.pt")
):
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
)
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size
* args.gradient_accumulation_steps
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
)
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
if os.path.exists(args.model_name_or_path):
try:
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
except ValueError:
global_step = 0
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = trange(
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0],
)
set_seed(args)
return global_step, tr_loss / global_step
def evaluate(args, model, tokenizer, prefix=""):
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)
results = {}
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
os.makedirs(eval_output_dir)
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
model = torch.nn.DataParallel(model)
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
eval_loss = 0.0
nb_eval_steps = 0
preds = None
out_label_ids = None
for batch in tqdm(eval_dataloader, desc="Evaluating"):
model.eval()
batch = tuple(t.to(args.device) for t in batch)
with torch.no_grad():
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
if args.model_type != "distilbert":
inputs["token_type_ids"] = (
batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
)
outputs = model(**inputs)
tmp_eval_loss, logits = outputs[:2]
eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1
if preds is None:
preds = logits.detach().cpu().numpy()
out_label_ids = inputs["labels"].detach().cpu().numpy()
else:
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps
if args.output_mode == "classification":
preds = np.argmax(preds, axis=1)
print(preds)
elif args.output_mode == "regression":
preds = np.squeeze(preds)
result = compute_metrics(eval_task, preds, out_label_ids)
results.update(result)
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key]))
return results
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
if args.local_rank not in [-1, 0] and not evaluate:
torch.distributed.barrier()
processor = processors[task]()
output_mode = output_modes[task]
cached_features_file = os.path.join(
args.data_dir,
"cached_{}_{}_{}_{}".format(
"dev" if evaluate else "train",
list(filter(None, args.model_name_or_path.split("/"))).pop(),
str(args.max_seq_length),
str(task),
),
)
if os.path.exists(cached_features_file) and not args.overwrite_cache:
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
else:
logger.info("Creating features from dataset file at %s", args.data_dir)
label_list = processor.get_labels()
if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
label_list[1], label_list[2] = label_list[2], label_list[1]
examples = (
processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
)
features = convert_examples_to_features(
examples,
tokenizer,
label_list=label_list,
max_length=args.max_seq_length,
output_mode=output_mode,
pad_on_left=False,
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
pad_token_segment_id=0,
)
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
if args.local_rank == 0 and not evaluate:
torch.distributed.barrier()
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
if output_mode == "classification":
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
elif output_mode == "regression":
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
return dataset
def main(task='MRPC', seed=42, ckpt='output/pretrain/2020-08-28-02-41-37/ckpt/60000'):
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_dir",
default=f'data/glue_data/{task}',
type=str,
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
)
parser.add_argument(
"--model_type",
default="bert",
type=str,
)
parser.add_argument(
"--model_name_or_path",
default=ckpt,
type=str,
)
parser.add_argument(
"--vocab_path",
default='data/vocab.txt',
type=str,
)
parser.add_argument(
"--task_name",
default=task,
type=str,
help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
)
parser.add_argument(
"--output_dir",
default='output/glue',
type=str,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
default="",
type=str,
help="Where do you want to store the pre-trained models downloaded from s3",
)
parser.add_argument(
"--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.",
)
parser.add_argument("--do_train", default=True, help="Whether to run training.")
parser.add_argument("--do_eval", default=True, help="Whether to run eval on the dev set.")
parser.add_argument(
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.",
)
parser.add_argument(
"--do_lower_case", default=True, help="Set this flag if you are using an uncased model.",
)
parser.add_argument(
"--per_gpu_train_batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.",
)
parser.add_argument(
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--learning_rate", default=2e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
)
parser.add_argument(
"--max_steps",
default=-1,
type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
parser.add_argument(
"--eval_all_checkpoints",
action="store_true",
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
)
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
parser.add_argument(
"--overwrite_output_dir", default=True, help="Overwrite the content of the output directory",
)
parser.add_argument(
"--overwrite_cache", default=True, help="Overwrite the cached training and evaluation sets",
)
parser.add_argument("--seed", type=int, default=seed, help="random seed for initialization")
parser.add_argument(
"--fp16",
action="store_true",
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
)
parser.add_argument(
"--fp16_opt_level",
type=str,
default="O1",
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html",
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
args = parser.parse_args()
if (
os.path.exists(args.output_dir)
and os.listdir(args.output_dir)
and args.do_train
and not args.overwrite_output_dir
):
raise ValueError(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
args.output_dir
)
)
if args.server_ip and args.server_port:
import ptvsd
print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = 1
args.device = device
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank,
device,
args.n_gpu,
bool(args.local_rank != -1),
args.fp16,
)
set_seed(args)
args.task_name = args.task_name.lower()
if args.task_name not in processors:
raise ValueError("Task not found: %s" % (args.task_name))
processor = processors[args.task_name]()
args.output_mode = output_modes[args.task_name]
label_list = processor.get_labels()
num_labels = len(label_list)
if args.local_rank not in [-1, 0]:
torch.distributed.barrier()
    from transformers import AutoConfig, AutoModelForSequenceClassification, WEIGHTS_NAME
args.model_type = args.model_type.lower()
config = AutoConfig.from_pretrained(
args.model_name_or_path,
num_labels=num_labels,
finetuning_task=args.task_name,
cache_dir=args.cache_dir if args.cache_dir else None,
)
model = AutoModelForSequenceClassification.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
cache_dir=args.cache_dir if args.cache_dir else None,
)
from pretraining.openwebtext.dataset import new_tokenizer
tokenizer = wrap_tokenizer(new_tokenizer(args.vocab_path), pad_token='[PAD]')
if args.local_rank == 0:
torch.distributed.barrier()
model.to(args.device)
logger.info("Training/evaluation parameters %s", args)
if args.do_train:
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
logger.info("Saving model checkpoint to %s", args.output_dir)
model_to_save = (
model.module if hasattr(model, "module") else model
)
model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
model = model_to_save
model.to(args.device)
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
)
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=prefix)
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
results.update(result)
return results
if __name__ == "__main__":
main()
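As a quick illustration of the adapter above (not part of the repository; the toy vocabulary and sentence are placeholders), encode_plus returns the plain dict that glue_convert_examples_to_features consumes:

import tempfile
from pretraining.openwebtext.dataset import new_tokenizer

toy_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "the", "quick", "brown", "fox"]
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("\n".join(toy_vocab))          # hypothetical vocab file for this sketch
    vocab_path = f.name

adapter = wrap_tokenizer(new_tokenizer(vocab_path), pad_token="[PAD]")
enc = adapter.encode_plus("the quick brown fox", None, add_special_tokens=True,
                          max_length=16, return_token_type_ids=True)
# enc["input_ids"]      -> [2, 5, 6, 7, 8, 3]   ([CLS] the quick brown fox [SEP])
# enc["token_type_ids"] -> [0, 0, 0, 0, 0, 0]   (single sentence, all segment 0)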
.\lucidrains\electra-pytorch\examples\glue\utils.py
import copy
import csv
import dataclasses
import json
import logging
from dataclasses import dataclass
from typing import Optional
is_torch_available = lambda: True
is_tf_available = lambda: False
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class InputExample:
"""
A single training/test example for simple sequence classification.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
guid: str
text_a: str
text_b: Optional[str] = None
label: Optional[str] = None
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
class InputFeatures(object):
"""
A single set of features of data.
Args:
input_ids: Indices of input sequence tokens in the vocabulary.
attention_mask: Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens.
token_type_ids: Segment token indices to indicate first and second portions of the inputs.
label: Label corresponding to the input
"""
def __init__(self, input_ids, attention_mask=None, token_type_ids=None, label=None):
self.input_ids = input_ids
self.attention_mask = attention_mask
self.token_type_ids = token_type_ids
self.label = label
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_example_from_tensor_dict(self, tensor_dict):
"""Gets an example from a dict with tensorflow tensors
Args:
tensor_dict: Keys and values should match the corresponding Glue
tensorflow_dataset examples.
"""
raise NotImplementedError()
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
def tfds_map(self, example):
"""Some tensorflow_datasets datasets are not formatted the same way the GLUE datasets are.
This method converts examples to the correct format."""
if len(self.get_labels()) > 1:
example.label = self.get_labels()[int(example.label)]
return example
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r", encoding="utf-8-sig") as f:
return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
class SingleSentenceClassificationProcessor(DataProcessor):
""" Generic processor for a single sentence classification data set."""
def __init__(self, labels=None, examples=None, mode="classification", verbose=False):
self.labels = [] if labels is None else labels
self.examples = [] if examples is None else examples
self.mode = mode
self.verbose = verbose
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
if isinstance(idx, slice):
return SingleSentenceClassificationProcessor(labels=self.labels, examples=self.examples[idx])
return self.examples[idx]
@classmethod
def create_from_csv(
cls, file_name, split_name="", column_label=0, column_text=1, column_id=None, skip_first_row=False, **kwargs
):
processor = cls(**kwargs)
processor.add_examples_from_csv(
file_name,
split_name=split_name,
column_label=column_label,
column_text=column_text,
column_id=column_id,
skip_first_row=skip_first_row,
overwrite_labels=True,
overwrite_examples=True,
)
return processor
@classmethod
def create_from_examples(cls, texts_or_text_and_labels, labels=None, **kwargs):
processor = cls(**kwargs)
processor.add_examples(texts_or_text_and_labels, labels=labels)
return processor
def add_examples_from_csv(
self,
file_name,
split_name="",
column_label=0,
column_text=1,
column_id=None,
skip_first_row=False,
overwrite_labels=False,
overwrite_examples=False,
):
lines = self._read_tsv(file_name)
if skip_first_row:
lines = lines[1:]
texts = []
labels = []
ids = []
for (i, line) in enumerate(lines):
texts.append(line[column_text])
labels.append(line[column_label])
if column_id is not None:
ids.append(line[column_id])
else:
guid = "%s-%s" % (split_name, i) if split_name else "%s" % i
ids.append(guid)
return self.add_examples(
texts, labels, ids, overwrite_labels=overwrite_labels, overwrite_examples=overwrite_examples
)
def add_examples(
self, texts_or_text_and_labels, labels=None, ids=None, overwrite_labels=False, overwrite_examples=False
):
assert labels is None or len(texts_or_text_and_labels) == len(labels)
assert ids is None or len(texts_or_text_and_labels) == len(ids)
if ids is None:
ids = [None] * len(texts_or_text_and_labels)
if labels is None:
labels = [None] * len(texts_or_text_and_labels)
examples = []
added_labels = set()
for (text_or_text_and_label, label, guid) in zip(texts_or_text_and_labels, labels, ids):
if isinstance(text_or_text_and_label, (tuple, list)) and label is None:
text, label = text_or_text_and_label
else:
text = text_or_text_and_label
added_labels.add(label)
examples.append(InputExample(guid=guid, text_a=text, text_b=None, label=label))
if overwrite_examples:
self.examples = examples
else:
self.examples.extend(examples)
if overwrite_labels:
self.labels = list(added_labels)
else:
self.labels = list(set(self.labels).union(added_labels))
return self.examples
def get_features(
self,
tokenizer,
max_length=None,
pad_on_left=False,
pad_token=0,
mask_padding_with_zero=True,
return_tensors=None,
.\lucidrains\electra-pytorch\pretraining\openwebtext\arg.py
import argparse
import dataclasses
__all__ = ('Arg', 'Int', 'Float', 'Bool', 'Str', 'Choice', 'parse_to')
class Arg:
def __init__(self, **kwargs):
super().__init__()
self.kwargs = kwargs
class Int(Arg):
def __init__(self, **kwargs):
super().__init__(type=int, **kwargs)
class Float(Arg):
def __init__(self, **kwargs):
super().__init__(type=float, **kwargs)
class Bool(Arg):
def __init__(self, **kwargs):
super().__init__(type=bool, **kwargs)
class Str(Arg):
def __init__(self, **kwargs):
super().__init__(type=str, **kwargs)
class _MetaChoice(type):
def __getitem__(self, item):
return self(choices=list(item), type=item)
class Choice(Arg, metaclass=_MetaChoice):
def __init__(self, choices, **kwargs):
super().__init__(choices=choices, **kwargs)
def parse_to(container_class, **kwargs):
def mangle_name(name):
return '--' + name.replace('_', '-')
parser = argparse.ArgumentParser(description=container_class.__doc__)
for field in dataclasses.fields(container_class):
name = field.name
default = field.default
value_or_class = field.type
if isinstance(value_or_class, type):
value = value_or_class(default=default)
else:
value = value_or_class
value.kwargs['default'] = default
parser.add_argument(
mangle_name(name), **value.kwargs)
arg_dict = parser.parse_args(**kwargs)
return container_class(**vars(arg_dict))
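A short usage sketch of parse_to (the Config dataclass below is invented for illustration): each annotated field becomes an argparse option whose flag is the field name with underscores turned into dashes.

@dataclasses.dataclass
class Config:
    data_dir: Str = 'data/demo'     # becomes --data-dir
    n_steps: Int = 1000             # becomes --n-steps
    lr: Float = 5e-4                # becomes --lr

cfg = parse_to(Config, args=['--n-steps', '10'])
# cfg == Config(data_dir='data/demo', n_steps=10, lr=0.0005)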
.\lucidrains\electra-pytorch\pretraining\openwebtext\dataset.py
import math
import os
import random
from dataclasses import dataclass
from itertools import chain
from functools import partial
from pathlib import Path
import numpy as np
import torch
import torch.utils.data
from openwebtext import tokenization
class ExampleBuilder:
"""Given a stream of input text, creates pretraining examples."""
def __init__(self, vocab, max_length):
self._vocab = vocab
self._current_sentences = []
self._current_length = 0
self._max_length = max_length
self._target_length = max_length
def add_line(self, bert_tokids):
"""Adds a line of text to the current example being built."""
self._current_sentences.append(bert_tokids)
self._current_length += len(bert_tokids)
if self._current_length >= self._target_length:
return self._create_example()
return None
def _create_example(self):
"""Creates a pre-training example from the current list of sentences."""
if random.random() < 0.1:
first_segment_target_length = 100000
else:
first_segment_target_length = (self._target_length - 3) // 2
first_segment = []
second_segment = []
for sentence in self._current_sentences:
if (len(first_segment) == 0 or
len(first_segment) + len(sentence) < first_segment_target_length or
(len(second_segment) == 0 and
len(first_segment) < first_segment_target_length and
random.random() < 0.5)):
first_segment += sentence
else:
second_segment += sentence
first_segment = first_segment[:self._max_length - 2]
second_segment = second_segment[:max(0, self._max_length - len(first_segment) - 3)]
self._current_sentences = []
self._current_length = 0
if random.random() < 0.05:
self._target_length = random.randint(5, self._max_length)
else:
self._target_length = self._max_length
return self._make_tf_example(first_segment, second_segment)
def _make_tf_example(self, first_segment, second_segment):
"""将两个文本“段”转换为tf.train.Example。"""
vocab = self._vocab
input_ids = [vocab["[CLS]"]] + first_segment + [vocab["[SEP]"]
segment_ids = [0] * len(input_ids)
if second_segment:
input_ids += second_segment + [vocab["[SEP]"]]
segment_ids += [1] * (len(second_segment) + 1)
input_mask = [1] * len(input_ids)
input_ids += [0] * (self._max_length - len(input_ids))
input_mask += [0] * (self._max_length - len(input_mask))
        segment_ids += [0] * (self._max_length - len(segment_ids))
def create_int_feature(tensors):
return torch.tensor(tensors)
tf_example = {
"input_ids": create_int_feature(input_ids),
"input_mask": create_int_feature(input_mask),
"segment_ids": create_int_feature(segment_ids)
}
return tf_example
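A small sketch of how ExampleBuilder is driven (the toy vocab and token ids are placeholders): lines of token ids are buffered until roughly max_length tokens have accumulated, at which point add_line returns a padded example dict instead of None.

toy_vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "[MASK]": 4}
builder = ExampleBuilder(toy_vocab, max_length=16)
example = None
while example is None:
    example = builder.add_line([5, 6, 7, 8, 9])   # pretend token ids for one sentence
# example["input_ids"], example["input_mask"], example["segment_ids"] are
# torch tensors of length 16, zero-padded at the end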
class OpenWebTextDataset(torch.utils.data.IterableDataset):
def __init__(self, feature_set_paths, n_tensors_per_file):
        self.feature_set_paths = feature_set_paths
        self.n_tensors_per_file = n_tensors_per_file
@staticmethod
def parse_file(file_index):
try:
features = torch.load(str(file_index))
yield from features
except RuntimeError:
raise RuntimeError(f'Corrupted file {file_index}')
def __len__(self):
return len(self.feature_set_paths) * self.n_tensors_per_file
def __iter__(self):
return chain.from_iterable(map(self.parse_file, self.feature_set_paths))
class ExampleBuilderDataset(torch.utils.data.IterableDataset):
def __init__(self, dataset, builder):
        self.dataset = dataset
        self.builder = builder
def __len__(self):
return len(self.dataset)
def __iter__(self):
def create_example():
while True:
token_ids = list(next(self.dataset).cpu().numpy())
example = self.builder.add_line(token_ids)
if example:
return example
while True:
yield create_example()
def cycle(iterable):
while True:
for x in iterable:
yield x
def new_tokenizer(vocab_file, do_lower_case=True):
return tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
def parse_tokenizer(tokenizer, text):
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
def create_tokenizer(vocab_file, do_lower_case=True):
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
return partial(parse_tokenizer, tokenizer)
def load_owt(owt_dir, n_tensors_per_file):
owt_dir_path = Path(owt_dir)
feature_set_paths = [owt_dir_path / feature_set_path for feature_set_path in os.listdir(owt_dir_path)]
np.random.shuffle(feature_set_paths)
assert len(feature_set_paths) > 0
return OpenWebTextDataset(feature_set_paths, n_tensors_per_file=n_tensors_per_file)
def wrap_example_builder(dataset, vocab, max_length):
return ExampleBuilderDataset(cycle(iter(dataset)), ExampleBuilder(vocab, max_length))
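Putting the module together (an illustrative sketch; the temporary directory, toy vocab and token ids are placeholders): load_owt streams saved token-id tensors from disk, and wrap_example_builder turns that stream into an endless supply of fixed-length pretraining examples.

import tempfile

owt_dir = Path(tempfile.mkdtemp())
torch.save([torch.tensor([5, 6, 7, 8, 9])] * 64, owt_dir / 'feature_set_0_0.pt')

toy_vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "[MASK]": 4}
ds = wrap_example_builder(load_owt(owt_dir, n_tensors_per_file=64), toy_vocab, max_length=16)
example = next(iter(ds))   # dict with 'input_ids', 'input_mask' and 'segment_ids' tensors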
.\lucidrains\electra-pytorch\pretraining\openwebtext\preprocess.py
import logging
import math
import multiprocessing
import os
import random
import tarfile
from dataclasses import dataclass
from itertools import chain
from functools import partial
from pathlib import Path
import numpy as np
import torch
import torch.utils.data
from pretraining.openwebtext import arg
from pretraining.openwebtext import tokenization
logger = logging.getLogger(__name__)
def parse_tokenizer(tokenizer, text):
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
def create_tokenizer(vocab_file, do_lower_case=True):
tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
return partial(parse_tokenizer, tokenizer)
def preprocess_owt(tokenizer, src_dir, tmp_dir, trg_dir, n_dataset_building_processes, n_tensors_per_file, max_seq_length=None):
logger.info(f'Writing features to {trg_dir}.')
os.makedirs(trg_dir, exist_ok=False)
trg_dir = Path(trg_dir)
src_dir = Path(src_dir)
tmp_dir = Path(tmp_dir)
archives = os.listdir(src_dir)
n_archives_per_job = math.ceil(len(archives) / n_dataset_building_processes)
job_archives = [
archives[i * n_archives_per_job : (i + 1) * n_archives_per_job]
for i in range(n_dataset_building_processes)
]
logger.info(f'Processing {len(archives)} archives.')
assert len(archives) > 0
if n_dataset_building_processes == 1:
feature_set_paths = preprocess_owt_job(tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length, job_id=0)
else:
pool = multiprocessing.Pool(processes=n_dataset_building_processes)
preprocess_owt_job_partial = partial(preprocess_owt_job, tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length)
feature_sets = pool.map(preprocess_owt_job_partial, range(n_dataset_building_processes))
feature_set_paths = [file_path for feature_set in feature_sets for file_path in feature_set]
return feature_set_paths
def preprocess_owt_job(tokenizer, src_dir, tmp_dir, trg_dir, job_archives, n_tensors_per_file, max_seq_length, job_id=0):
'''
OpenWebText is saved under the following format:
openwebtext.zip
|-> archive_xxx.zip
|-> file_xxx.txt
|-> file_xxz.txt
...
|-> archive_xxz.zip
|-> file_xxy.txt
...
...
'''
os.makedirs(tmp_dir, exist_ok=True)
feature_index = 0
feature_set_paths = []
features = []
for archive_id, archive in enumerate(job_archives[job_id]):
if os.path.isdir(src_dir / archive):
logger.info(f'Ignoring rogue directory {src_dir / archive}.')
continue
logger.info(f'Job {job_id}: Processing {archive_id}/{len(job_archives[job_id])} {src_dir / archive}.')
with tarfile.open(src_dir / archive) as t:
extracted_archive = tmp_dir / f'{archive}-extracted'
t.extractall(extracted_archive)
for file in os.listdir(extracted_archive):
file_path = extracted_archive / file
with open(file_path, 'r') as f:
for line in f.readlines():
line = line.strip()
if len(line) > 2:
encoding = tokenizer(line)
features.append(torch.tensor(encoding))
while len(features) > n_tensors_per_file:
feature_set_path = trg_dir / f'feature_set_{job_id}_{feature_index}.pt'
torch.save(features[:n_tensors_per_file], feature_set_path)
features = features[n_tensors_per_file:]
feature_index += 1
feature_set_paths.append(feature_set_path)
if len(features) > 0:
feature_set_path = trg_dir / f'feature_set_{job_id}_{feature_index}.pt'
torch.save(features, feature_set_path)
feature_set_paths.append(feature_set_path)
return feature_set_paths
@dataclass(frozen=True)
class Args:
src_dir: arg.Str = 'data/openwebtext'
trg_dir: arg.Str = 'data/openwebtext_features'
tmp_dir: arg.Str = '/tmp/owt'
vocab_file: arg.Str = 'data/vocab.txt'
n_dataset_building_processes: arg.Int = 32
n_tensors_per_file: arg.Int = 2048
max_seq_length: arg.Int = 128
def main():
args = arg.parse_to(Args)
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO
)
tokenizer = create_tokenizer(args.vocab_file)
preprocess_owt(tokenizer=tokenizer, src_dir=args.src_dir, tmp_dir=args.tmp_dir, trg_dir=args.trg_dir, n_dataset_building_processes=args.n_dataset_building_processes, n_tensors_per_file=args.n_tensors_per_file, max_seq_length=args.max_seq_length)
if __name__ == '__main__':
main()
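Per the arg module, each Args field maps to a dashed CLI flag (mangle_name turns src_dir into --src-dir), so a run might look like the hypothetical command below, assuming the repository root is on PYTHONPATH. The output is a set of feature_set_{job_id}_{feature_index}.pt files, each holding a list of up to n_tensors_per_file token-id tensors, which dataset.load_owt later consumes.

# hypothetical invocation -- flag names derive from the Args dataclass fields above:
#   python pretraining/openwebtext/preprocess.py \
#       --src-dir data/openwebtext --trg-dir data/openwebtext_features \
#       --vocab-file data/vocab.txt --n-dataset-building-processes 4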
.\lucidrains\electra-pytorch\pretraining\openwebtext\pretrain.py
import os
import sys
dir_path = os.path.dirname(os.path.realpath(__file__))
parent_dir_path = os.path.abspath(os.path.join(dir_path, os.pardir))
sys.path.insert(0, parent_dir_path)
import random
import logging
from time import time
from dataclasses import dataclass
import numpy as np
import torch
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader
from electra_pytorch import Electra
from openwebtext import arg
from openwebtext.dataset import load_owt, new_tokenizer, wrap_example_builder
logger = logging.getLogger(__name__)
@dataclass
class Args:
data_dir: arg.Str = 'data/openwebtext_features'
data_vocab_file: arg.Str = 'data/vocab.txt'
data_n_tensors_per_file: arg.Int = 2048
data_max_seq_length: arg.Int = 128
gpu: arg.Int = 0
gpu_enabled: arg.Bool = True
gpu_deterministic: arg.Bool = False
gpu_mixed_precision: arg.Bool = False
distributed_port: arg.Int = 8888
distributed_enabled: arg.Bool = True
distributed_world_size: arg.Int = 4
model_generator: arg.Str = 'pretraining/openwebtext/small_generator.json'
model_discriminator: arg.Str = 'pretraining/openwebtext/small_discriminator.json'
model_mask_prob: arg.Float = 0.15
opt_lr: arg.Float = 5e-4
opt_batch_size: arg.Int = 128 // (distributed_world_size if distributed_enabled else 1)
opt_warmup_steps: arg.Int = 10_000
opt_num_training_steps: arg.Int = 200_000
step_log: arg.Int = 10
step_ckpt: arg.Int = 10_000
def train(rank, args):
if args.distributed_enabled:
torch.distributed.init_process_group(
backend='nccl',
init_method='env://',
world_size=args.distributed_world_size,
rank=rank)
if args.gpu_enabled:
device = torch.device('cuda:{}'.format(rank))
else:
device = torch.device('cpu')
    is_master = (not args.distributed_enabled) or rank == 0
set_gpus(rank)
set_seed(rank)
set_cuda(deterministic=args.gpu_deterministic)
output_dir = f'{args.output_dir}/{rank}'
os.makedirs(output_dir, exist_ok=False)
setup_logging(filename=f'{output_dir}/output.log', console=is_master)
tokenizer = new_tokenizer(vocab_file=args.data_vocab_file)
vocab_size = len(tokenizer.vocab)
ds_train = wrap_example_builder(dataset=load_owt(owt_dir=args.data_dir, n_tensors_per_file=args.data_n_tensors_per_file), vocab=tokenizer.vocab, max_length=args.data_max_seq_length)
pad_token_id = tokenizer.vocab['[PAD]']
mask_token_id = tokenizer.vocab['[MASK]']
cls_token_id = tokenizer.vocab['[CLS]']
sep_token_id = tokenizer.vocab['[SEP]']
assert pad_token_id == 0
assert cls_token_id == 101
assert sep_token_id == 102
assert mask_token_id == 103
def collate_batch(examples):
input_ids = torch.nn.utils.rnn.pad_sequence([example['input_ids'] for example in examples], batch_first=True, padding_value=pad_token_id)
input_mask = torch.nn.utils.rnn.pad_sequence([example['input_mask'] for example in examples], batch_first=True, padding_value=pad_token_id)
segment_ids = torch.nn.utils.rnn.pad_sequence([example['segment_ids'] for example in examples], batch_first=True, padding_value=pad_token_id)
return input_ids, input_mask, segment_ids
def cycle(iterable):
while True:
for x in iterable:
yield x
ds_train_loader = iter(cycle(DataLoader(ds_train, batch_size=args.opt_batch_size, collate_fn=collate_batch)))
def to_distributed_model(model):
return model if not args.distributed_enabled else torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True)
def tie_weights(generator, discriminator):
generator.electra.embeddings.word_embeddings = discriminator.electra.embeddings.word_embeddings
generator.electra.embeddings.position_embeddings = discriminator.electra.embeddings.position_embeddings
generator.electra.embeddings.token_type_embeddings = discriminator.electra.embeddings.token_type_embeddings
class LogitsAdapter(torch.nn.Module):
def __init__(self, adaptee):
super().__init__()
self.adaptee = adaptee
def forward(self, *args, **kwargs):
return self.adaptee(*args, **kwargs)[0]
from transformers import AutoConfig, ElectraForMaskedLM, ElectraForPreTraining
generator = ElectraForMaskedLM(AutoConfig.from_pretrained(args.model_generator))
discriminator = ElectraForPreTraining(AutoConfig.from_pretrained(args.model_discriminator))
tie_weights(generator, discriminator)
model = to_distributed_model(Electra(
LogitsAdapter(generator),
LogitsAdapter(discriminator),
num_tokens = vocab_size,
mask_token_id = mask_token_id,
pad_token_id = pad_token_id,
mask_prob = args.model_mask_prob,
        mask_ignore_token_ids = [tokenizer.vocab['[CLS]'], tokenizer.vocab['[SEP]']],
random_token_prob = 0.0).to(device))
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
def lr_lambda(current_step):
learning_rate = max(0.0, 1. - (float(current_step) / float(num_training_steps)))
learning_rate *= min(1.0, float(current_step) / float(num_warmup_steps))
return learning_rate
return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_params_without_weight_decay_ln(named_params, weight_decay):
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{
'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)],
'weight_decay': weight_decay,
},
{
'params': [p for n, p in named_params if any(nd in n for nd in no_decay)],
'weight_decay': 0.0,
},
]
return optimizer_grouped_parameters
optimizer = torch.optim.AdamW(get_params_without_weight_decay_ln(model.named_parameters(), weight_decay=0.1), lr=args.opt_lr, betas=(0.9, 0.999), eps=1e-08)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.opt_warmup_steps, num_training_steps=args.opt_num_training_steps)
scaler = torch.cuda.amp.GradScaler(enabled=args.gpu_mixed_precision)
t, steps_s, eta_m = time(), 0., 0
for step in range(args.opt_num_training_steps+1):
input_ids, input_mask, segment_ids = next(ds_train_loader)
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
assert input_ids.shape[1] <= args.data_max_seq_length
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=args.gpu_mixed_precision):
loss, loss_mlm, loss_disc, acc_gen, acc_disc, disc_labels, disc_pred = model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
scheduler.step()
metrics = {
'step': (step, '{:8d}'),
'loss': (loss.item(), '{:8.5f}'),
'loss_mlm': (loss_mlm.item(), '{:8.5f}'),
'loss_disc': (loss_disc.item(), '{:8.5f}'),
'acc_gen': (acc_gen.item(), '{:5.3f}'),
'acc_disc': (acc_disc.item(), '{:5.3f}'),
'lr': (scheduler.get_last_lr()[0], '{:8.7f}'),
'steps': (steps_s, '{:4.1f}/s'),
'eta': (eta_m, '{:4d}m'),
}
if step % args.step_log == 0:
sep = ' ' * 2
            logger.info(sep.join([f'{k}: {v[1].format(v[0])}' for (k, v) in metrics.items()]))
if step > 0 and step % 100 == 0:
t2 = time()
steps_s = 100. / (t2 - t)
eta_m = int(((args.opt_num_training_steps - step) / steps_s) // 60)
t = t2
if step % 200 == 0:
logger.info(np.array2string(disc_labels[0].cpu().numpy(), threshold=sys.maxsize, max_line_width=sys.maxsize))
logger.info(np.array2string(disc_pred[0].cpu().numpy(), threshold=sys.maxsize, max_line_width=sys.maxsize))
if step > 0 and step % args.step_ckpt == 0 and is_master:
discriminator.electra.save_pretrained(f'{args.output_dir}/ckpt/{step}')
def set_gpus(gpu):
torch.cuda.set_device(gpu)
def set_seed(seed):
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
    # also seed CUDA when a GPU is available
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def set_cuda(deterministic=True):
    # toggle cuDNN determinism when CUDA is available
if torch.cuda.is_available():
torch.backends.cudnn.deterministic = deterministic
torch.backends.cudnn.benchmark = not deterministic
def get_exp_id(file):
    # the experiment id is the script's base filename without its extension
return os.path.splitext(os.path.basename(file))[0]
def get_output_dir(exp_id):
    import datetime
    t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    # build a timestamped output directory path
    output_dir = os.path.join('output/' + exp_id, t)
    # create the output directory if it does not already exist
    os.makedirs(output_dir, exist_ok=True)
    # return the output directory path
return output_dir
def setup_logging(filename, console=True):
    # use one "timestamp : message" format for every handler
    log_format = logging.Formatter("%(asctime)s : %(message)s")
    # grab the root logger and drop any pre-existing handlers
    logger = logging.getLogger()
    logger.handlers = []
    # always log to the given file
    file_handler = logging.FileHandler(filename)
    file_handler.setFormatter(log_format)
    logger.addHandler(file_handler)
    # optionally mirror the log to stdout
    if console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(log_format)
        logger.addHandler(console_handler)
    # log at INFO level and return the configured logger
    logger.setLevel(logging.INFO)
return logger
def copy_source(file, output_dir):
    import shutil
    # copy the source file into the output directory
shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file)))
def main():
    # derive the experiment id from this script's filename
    exp_id = get_exp_id(__file__)
    # create a timestamped output directory for this run
    output_dir = get_output_dir(exp_id)
    os.makedirs(output_dir, exist_ok=True)
    # the checkpoint directory must not already exist
    os.makedirs(f'{output_dir}/ckpt', exist_ok=False)
    # copy this script into the output directory
    copy_source(__file__, output_dir)
    # parse the command line into the Args dataclass
    args = arg.parse_to(Args)
    # record the output directory and experiment id on the args object
    args.output_dir = output_dir
    args.exp_id = exp_id
    # spawn one training process per GPU when distributed training is enabled
    if args.distributed_enabled:
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = str(args.distributed_port)
        torch.multiprocessing.spawn(train, nprocs=args.distributed_world_size, args=(args,))
    else:
        # otherwise train in a single process on the configured GPU
train(rank=args.gpu, args=args)
if __name__ == '__main__':
main()
.\lucidrains\electra-pytorch\pretraining\openwebtext\tokenization.py
"""Tokenization classes, the same as used for BERT."""
import collections
import unicodedata
def convert_to_unicode(text):
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
def printable_text(text):
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
def load_vocab(vocab_file):
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
class BasicTokenizer(object):
def __init__(self, do_lower_case=True):
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
if ((cp >= 0x4E00 and cp <= 0x9FFF) or
(cp >= 0x3400 and cp <= 0x4DBF) or
(cp >= 0x20000 and cp <= 0x2A6DF) or
(cp >= 0x2A700 and cp <= 0x2B73F) or
(cp >= 0x2B740 and cp <= 0x2B81F) or
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or
(cp >= 0x2F800 and cp <= 0x2FA1F)):
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
            already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
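A closing usage sketch of the tokenizer (the toy vocab file is a placeholder; in the scripts above data/vocab.txt plays this role): tokenize applies BasicTokenizer and then WordpieceTokenizer, and pieces not in the vocabulary fall back to [UNK].

import tempfile

toy_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "un", "##aff", "##able", "runs"]
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("\n".join(toy_vocab))          # hypothetical vocab file for this sketch
    vocab_path = f.name

tokenizer = FullTokenizer(vocab_file=vocab_path, do_lower_case=True)
tokens = tokenizer.tokenize("Unaffable runs!")
# tokens -> ['un', '##aff', '##able', 'runs', '[UNK]']   ('!' is not in the toy vocab)
ids = tokenizer.convert_tokens_to_ids(tokens)
# ids    -> [5, 6, 7, 8, 1]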