【NLP】漏洞类情报类别分类 -- 训练代码阅读以及模型训练

162 阅读 · 约 9 分钟

image.png


持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第9天,点击查看活动详情

前言

之前的文章中,描述了TextCNN的模型结构以及模型代码,今天将记录模型训练代码阅读以及模型训练。

train

import datetime
import os
import time
import numpy as np
import tensorflow as tf
from data_helpers import batch_iter
from text_cnn import TextCNN
import pickle
import codecs

# Plain-text log of the averaged dev-set loss/accuracy written at each evaluation.
# NOTE: inside dev_step() the name `writer` is a parameter (a summary writer) that
# shadows this handle; the training loop's `writer.write(...)` uses this file.
writer = codecs.open('dev_res.txt','w','UTF-8')

def pickle_reader(inputs):
    """Load and return the object pickled in the file at path ``inputs``.

    The file is closed even if unpickling fails.

    WARNING: ``pickle.load`` can execute arbitrary code; only read files this
    project produced itself (e.g. ``vocabulary.pkl``), never untrusted input.
    """
    with open(inputs, 'rb') as f:
        lines = pickle.load(f)
    print("Finish load {}".format(inputs))
    return lines


def pickle_writer(inputs, name):
    """Pickle ``inputs`` into the file at path ``name``, overwriting it.

    Protocol 2 is used so the file can also be read by Python 2.
    The file handle is closed even if pickling fails.
    """
    with open(name, 'wb') as output:
        pickle.dump(inputs, output, protocol=2)
    print("Finish save {}".format(name))

tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes (default: '3,4,5')")
tf.flags.DEFINE_integer("num_filters", 128, "Number of filters per filter size (default: 128)")
tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)")

# Training parameters (help strings state the actual defaults)
tf.flags.DEFINE_integer("batch_size", 32, "Batch Size (default: 32)")
tf.flags.DEFINE_integer("num_epochs", 40, "Number of training epochs (default: 40)")
tf.flags.DEFINE_integer("evaluate_every", 2000, "Evaluate model on dev set after this many steps (default: 2000)")
tf.flags.DEFINE_integer("checkpoint_every", 1000, "Save model after this many steps (default: 1000)")
tf.flags.DEFINE_integer("num_checkpoints", 2, "Number of checkpoints to store (default: 2)")
# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

FLAGS = tf.flags.FLAGS

try:
    # absl-backed tf.flags (TF >= 1.5): {flag name: value}. FLAGS.__flags there
    # maps names to Flag *objects*, which would print as object reprs.
    # May also list flags registered by absl/TF themselves.
    _flag_items = FLAGS.flag_values_dict().items()
except AttributeError:
    # Older tf.flags: FLAGS.__flags maps flag name -> value directly.
    _flag_items = FLAGS.__flags.items()
print("\nParameters:")
for attr, value in sorted(_flag_items):
    print("{}={}".format(attr.upper(), value))
print("")


# Data Preparation
# ==================================================
# train.txt (output of dataPreprocessing.py): one sample per line,
# "<space-separated tokens><split><label>".
with codecs.open(r'E:\TextCNN\data\VlunData\train.txt', 'r', 'UTF-8') as f:
    train_data = f.readlines()
np.random.shuffle(train_data)
urls = []
labels = []
max_len = 0
skipped = 0
for line in train_data:
    parts = line.strip().split('<split>')
    if len(parts) < 2:
        # Blank or malformed line (e.g. a trailing empty line): skip it
        # instead of crashing with IndexError.
        skipped += 1
        continue
    url = parts[0].split()
    max_len = max(max_len, len(url))
    urls.append(url)
    labels.append(parts[1])
if skipped:
    print('skipped {} lines without "<split>"'.format(skipped))
# Split point: the first 80% of the valid (shuffled) samples are used for
# training, the remaining 20% for validation.
index = int(0.8 * float(len(urls)))

print('文本最大长度为 {}'.format(max_len))
# Ids 0/1 are reserved for unknown words and padding; every other token gets
# the next free id in order of first appearance.
vocabulary = {'UNK':0,'PAD':1}
for url in urls:
    for word in url:
        vocabulary.setdefault(word, len(vocabulary))
pickle_writer(vocabulary,'vocabulary.pkl')

def convert2id(inputs, seq_len=None, vocab=None):
    """Map tokenized texts to fixed-length lists of vocabulary ids.

    Args:
        inputs: iterable of token lists (one list per text).
        seq_len: output length; defaults to the module-level ``max_len``.
        vocab: token -> id dict; defaults to the module-level ``vocabulary``.

    Returns:
        One list of ``seq_len`` ints per input. Shorter texts are right-padded
        with PAD (id 1), and tokens missing from ``vocab`` map to UNK (id 0).
        Texts longer than ``seq_len`` are truncated rather than raising
        IndexError, which matters when converting unseen text at prediction time.
    """
    if seq_len is None:
        seq_len = max_len
    if vocab is None:
        vocab = vocabulary
    res = []
    for line in inputs:
        ids = [vocab.get(word, 0) for word in line[:seq_len]]
        ids.extend([1] * (seq_len - len(ids)))
        res.append(ids)
    return res


def convert2label(inputs):
    """Turn string labels into two-class one-hot vectors.

    The label ``'0'`` becomes ``[0, 1]``; any other value becomes ``[1, 0]``.
    """
    return [[0, 1] if label == '0' else [1, 0] for label in inputs]

train_url, train_label = convert2id(urls[:index]), convert2label(labels[:index])
dev_url, dev_label = convert2id(urls[index:]), convert2label(labels[index:])
# Training
# ==================================================

with tf.Graph().as_default():
    session_conf = tf.ConfigProto(
      allow_soft_placement=FLAGS.allow_soft_placement,
      log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        cnn = TextCNN(sequence_length=max_len,
            num_classes=2,
            vocab_size=len(vocabulary),
            embedding_size=300,
            filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
            num_filters=FLAGS.num_filters,
            l2_reg_lambda=FLAGS.l2_reg_lambda)

        # Define Training procedure
        global_step = tf.Variable(0, name="global_step", trainable=False)
        optimizer = tf.train.AdamOptimizer(1e-3)
        grads_and_vars = optimizer.compute_gradients(cnn.loss)
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

        # Keep track of gradient values and sparsity (optional)
        grad_summaries = []
        for g, v in grads_and_vars:
            if g is not None:
                grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g)
                sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                grad_summaries.append(grad_hist_summary)
                grad_summaries.append(sparsity_summary)
        grad_summaries_merged = tf.summary.merge(grad_summaries)

        # Output directory for models and summaries: runs/<unix timestamp>/
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
        print("Writing to {}\n".format(out_dir))

        # Summaries for loss and accuracy
        loss_summary = tf.summary.scalar("loss", cnn.loss)
        acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

        # Train Summaries
        train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
        train_summary_dir = os.path.join(out_dir, "summaries", "train")
        train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

        # Dev summaries
        dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
        dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        def train_step(x_batch, y_batch):
            """
            A single training step (dropout enabled); logs loss/accuracy to
            stdout and the merged training summary to TensorBoard.
            """
            feed_dict = {
              cnn.input_x: x_batch,
              cnn.input_y: y_batch,
              cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
            }
            _, step, summaries, loss, accuracy = sess.run(
                [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy],
                feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
            train_summary_writer.add_summary(summaries, step)

        def dev_step(x_batch, y_batch, writer=None):
            """
            Evaluate one dev batch with dropout disabled; return (loss, accuracy).

            ``writer`` (optional) is a tf.summary.FileWriter; inside this
            function it shadows the module-level ``writer`` file handle.
            """
            feed_dict = {
              cnn.input_x: x_batch,
              cnn.input_y: y_batch,
              cnn.dropout_keep_prob: 1.0
            }
            step, summaries, loss, accuracy = sess.run(
                [global_step, dev_summary_op, cnn.loss, cnn.accuracy],
                feed_dict)
            if writer is not None:
                writer.add_summary(summaries, step)
            return loss, accuracy

        def evaluate_dev():
            """
            Evaluate the whole dev set batch by batch (feeding it at once caused OOM).

            Returns (loss, accuracy) averaged over samples, or None when the dev
            set is empty. Each batch is weighted by its size because the last
            batch is usually smaller, so a plain mean of per-batch values is biased.
            Assumes cnn.loss / cnn.accuracy are means over the batch (TextCNN is
            defined elsewhere) -- TODO confirm.
            """
            if not dev_url:
                return None
            total_loss = 0.0
            total_acc = 0.0
            n_seen = 0
            # Order is irrelevant for evaluation, so skip shuffling.
            dev_batches = batch_iter(list(zip(dev_url, dev_label)), FLAGS.batch_size, 1, shuffle=False)
            for dev_batch in dev_batches:
                x_dev_batch, y_dev_batch = zip(*dev_batch)
                loss, accuracy = dev_step(x_dev_batch, y_dev_batch, writer=dev_summary_writer)
                n = len(x_dev_batch)
                total_loss += loss * n
                total_acc += accuracy * n
                n_seen += n
            return total_loss / n_seen, total_acc / n_seen

        # Generate batches
        batches = batch_iter(
            list(zip(train_url, train_label)), FLAGS.batch_size, FLAGS.num_epochs)
        current_step = 0
        last_saved_step = None
        # Training loop. For each batch...
        for batch in batches:
            x_batch, y_batch = zip(*batch)
            train_step(x_batch, y_batch)
            current_step = tf.train.global_step(sess, global_step)
            if current_step % FLAGS.evaluate_every == 0:
                print("\nEvaluation:")
                dev_metrics = evaluate_dev()
                if dev_metrics is None:
                    print("dev set is empty, skipping evaluation\n")
                else:
                    result = "loss {:g}, acc {:g}".format(*dev_metrics)
                    print(result + '\n')
                    # Module-level dev_res.txt handle (not dev_step's parameter).
                    writer.write(result + '\n')
                    writer.flush()
            if current_step % FLAGS.checkpoint_every == 0:
                path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                last_saved_step = current_step
                print("Saved model checkpoint to {}\n".format(path))

        # Keep the final weights even if the last step is not a multiple of checkpoint_every.
        if current_step and current_step != last_saved_step:
            path = saver.save(sess, checkpoint_prefix, global_step=current_step)
            print("Saved model checkpoint to {}\n".format(path))
        train_summary_writer.close()
        dev_summary_writer.close()
        writer.close()

参数解读:

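  • dev_sample_percentage 测试集比例(本文代码未定义该参数,而是固定按 8:2 切分训练集/验证集)
  • positive_data_file 正样本文件(本文代码未使用,训练数据统一从 train.txt 读取)
  • negative_data_file 负样本文件(同上,本文代码未使用)
  • embedding_dim 词向量维度(本文代码未定义该参数,直接写死为 embedding_size=300)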
  • filter_sizes 卷积核(filter)的尺寸,即每个卷积窗口覆盖的词数,以逗号分隔,默认 "3,4,5"
  • num_filters 每种尺寸的卷积核个数
  • dropout_keep_prob dropout 的保留概率(keep probability),而不是丢弃比例;训练时为 0.5,验证时设为 1.0
  • l2_reg_lambda l2正则化系数
  • batch_size batch 批次大小
  • num_epochs 训练轮数
  • evaluate_every 每隔多少step进行测试
  • checkpoint_every 每隔多少step保存一个checkpoint
  • num_checkpoints 一共保存多少个checkpoint
  • allow_soft_placement 当指定设备无法运行某个 op 时,允许 TensorFlow 自动改用其他可用设备(soft placement)
  • log_device_placement 设为 True 时,会打印每个 tensor、op 被分配到哪个设备(CPU/GPU)上运行;本文默认为 False
# train.txt: one sample per line, "<space-separated tokens><split><label>".
with codecs.open(r'E:\TextCNN\data\VlunData\train.txt', 'r', 'UTF-8') as f:
    train_data = f.readlines()
np.random.shuffle(train_data)
urls = []
labels = []
max_len = 0
skipped = 0
for line in train_data:
    parts = line.strip().split('<split>')
    if len(parts) < 2:
        # Blank or malformed line (e.g. a trailing empty line): skip it
        # instead of crashing with IndexError.
        skipped += 1
        continue
    url = parts[0].split()
    max_len = max(max_len, len(url))
    urls.append(url)
    labels.append(parts[1])
if skipped:
    print('skipped {} lines without "<split>"'.format(skipped))
# Split point: the first 80% of the valid (shuffled) samples are used for
# training, the remaining 20% for validation.
index = int(0.8 * float(len(urls)))

print('文本最大长度为 {}'.format(max_len))
# Ids 0/1 are reserved for unknown words and padding; every other token gets
# the next free id in order of first appearance.
vocabulary = {'UNK':0,'PAD':1}
for url in urls:
    for word in url:
        vocabulary.setdefault(word, len(vocabulary))
pickle_writer(vocabulary,'vocabulary.pkl')

train_data 加载训练数据,即经 dataPreprocessing.py 处理后的训练数据。每行是分词后(以空格分隔)的文本,后接 <split> 和标签(1 表示“系统”类,0 表示“应用”类),格式如下:

apple   os   x 是 一款 苹果 分发 的 基于 bsd 的 操作系统 。<split>0
cpanel 是 美国 cpanel 公司 的 一套 基于 web 的 自动化 主机 托管 平台 。 该 平台 主要 用于 自动化 管理 网站 和 服务器 。<split>0
mozilla   firefox 是 美国 mozilla 基金会 的 一款 开源 web 浏览器 。<split>1
iscripts   autohoster 是 一款 基于 php 的 web 应用程序 。<split>0
dataease   v1.11 . 1   was   discovered   to   contain   a   sql   injection   vulnerability   via   the   parameter   datasourceid .<split>0
d - bus 是 一款 进程 间通信 ( ipc ) 实现 。 用于 在 应用程序 间 发送 消息 。<split>1
chromium - browser 是 一个 开放源码 的 web 浏览器 项目 , 由谷歌 启动 , 为 专有 的 谷歌 浏览器 浏览器 提供 源代码 。<split>1
在 固件 版本 为 1012 的 insteon   hub   2245 - 222 设备 上 , 从 pubnub 服务 收到 的 特别 精心制作 的 回复 可能 会 在 覆盖 任意 数据 的 全局 区域 上 导致 缓冲区 溢出 。 攻击者 应该 模拟 pubnub 并 响应 https   get 请求 来 触发 此 漏洞 。 strcpy 溢出 了 insteon _ pubnub 缓冲区 。 channel _ ad _ r , 它 的 大小 为 16 字节 。 攻击者 可以 发送 任意 长 的 “ ad _ r ” 参数 来 利用 这个 漏洞 。<split>1
mattermost   server 是 美国 mattermost 公司 的 一套 开源 的 消息传递 平台 。<split>0
malwarebytes   premium 是 美国 malwarebytes 公司 的 一套 反 恶意 间谍 软件 。 该软件 支持 删除 蠕虫 、 拨号 程序 、 木马 、 rootkit 、 间谍 软件 、 漏洞 、 僵尸 和 其他 恶意软件 等 。<split>0

np.random.shuffle(train_data):对训练数据进行随机打乱。

index = int(0.8 * float(len(train_data))):计算切分位置,前 80% 的数据作为训练集,其余 20% 作为验证(测试)集。

max_len = 0
for line in train_data:
    # Split once; "<text><split><label>".
    parts = line.strip().split('<split>')
    if len(parts) < 2:
        # Blank or malformed line (e.g. a trailing empty line): skip it
        # instead of crashing with IndexError.
        continue
    url = parts[0].split()
    max_len = max(max_len, len(url))
    urls.append(url)
    labels.append(parts[1])

获取文本最大长度,以及将文本与标签分别存于不同的list,用于模型训练。

print('文本最大长度为 {}'.format(max_len))
# Ids 0 and 1 are reserved for the unknown-word and padding tokens; each new
# token receives the next free id in order of first appearance.
vocabulary = {'UNK': 0, 'PAD': 1}
for word in (token for url in urls for token in url):
    if word not in vocabulary:
        vocabulary[word] = len(vocabulary)
pickle_writer(vocabulary, 'vocabulary.pkl')

def convert2id(inputs, seq_len=None, vocab=None):
    """Map tokenized texts to fixed-length lists of vocabulary ids.

    Args:
        inputs: iterable of token lists (one list per text).
        seq_len: output length; defaults to the module-level ``max_len``.
        vocab: token -> id dict; defaults to the module-level ``vocabulary``.

    Returns:
        One list of ``seq_len`` ints per input. Shorter texts are right-padded
        with PAD (id 1), and tokens missing from ``vocab`` map to UNK (id 0).
        Texts longer than ``seq_len`` are truncated rather than raising
        IndexError, which matters when converting unseen text at prediction time.
    """
    if seq_len is None:
        seq_len = max_len
    if vocab is None:
        vocab = vocabulary
    res = []
    for line in inputs:
        ids = [vocab.get(word, 0) for word in line[:seq_len]]
        ids.extend([1] * (seq_len - len(ids)))
        res.append(ids)
    return res


def convert2label(inputs):
    """Turn string labels into two-class one-hot vectors.

    The label ``'0'`` becomes ``[0, 1]``; any other value becomes ``[1, 0]``.
    """
    return [[0, 1] if label == '0' else [1, 0] for label in inputs]

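生成词典并保存为 vocabulary.pkl,供测试(预测)时使用;同时将 token 列表转换为对应的 id 序列(按 max_len 用 PAD 补齐),并将标签转换为 one-hot 向量(标签 '0' 对应 [0,1],其余对应 [1,0]),作为模型的 input。词典中额外加入 PAD 和 UNK 两个特殊字符,分别表示补长用的填充符(id 为 1)和未登录词(id 为 0)。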

# Samples before `index` form the training set, the rest the validation set.
train_url = convert2id(urls[:index])
train_label = convert2label(labels[:index])
dev_url = convert2id(urls[index:])
dev_label = convert2label(labels[index:])

切分训练集和测试集(验证集),index 即上面计算出的 int 类型切分位置。

# Build everything in a fresh graph: session, TextCNN model, Adam optimizer,
# TensorBoard summaries and the checkpoint saver.
with tf.Graph().as_default():
    # Device-placement options come from the allow_soft_placement /
    # log_device_placement flags.
    session_conf = tf.ConfigProto(
      allow_soft_placement=FLAGS.allow_soft_placement,
      log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        # Two-class model over id sequences of length max_len; 300-dimensional
        # embeddings for every vocabulary entry (including UNK/PAD).
        cnn = TextCNN(sequence_length=max_len,
            num_classes=2,
            vocab_size=len(vocabulary),
            embedding_size=300,
            filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
            num_filters=FLAGS.num_filters,
            l2_reg_lambda=FLAGS.l2_reg_lambda)

        # Define Training procedure
        # global_step is non-trainable; passing it to apply_gradients makes
        # each training step increment it.
        global_step = tf.Variable(0, name="global_step", trainable=False)
        optimizer = tf.train.AdamOptimizer(1e-3)
        grads_and_vars = optimizer.compute_gradients(cnn.loss)
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

        # Keep track of gradient values and sparsity (optional):
        # a histogram plus a fraction-of-zeros scalar for every variable that has a gradient.
        grad_summaries = []
        for g, v in grads_and_vars:
            if g is not None:
                grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g)
                sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                grad_summaries.append(grad_hist_summary)
                grad_summaries.append(sparsity_summary)
        grad_summaries_merged = tf.summary.merge(grad_summaries)

        # Output directory for models and summaries: runs/<unix timestamp>/ (one per run)
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
        print("Writing to {}\n".format(out_dir))

        # Summaries for loss and accuracy
        loss_summary = tf.summary.scalar("loss", cnn.loss)
        acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

        # Train Summaries (loss, accuracy and the gradient summaries)
        train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
        train_summary_dir = os.path.join(out_dir, "summaries", "train")
        train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

        # Dev summaries (loss and accuracy only)
        dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
        dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        # Only the newest FLAGS.num_checkpoints checkpoints are kept on disk.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

        # (The vocabulary was already pickled to vocabulary.pkl during data preparation.)
        # Initialize all variables
        sess.run(tf.global_variables_initializer())

初始化模型,增加summary,用于记录模型训练过程。

def train_step(x_batch, y_batch):
    """
    Run one optimisation step on a batch with dropout enabled, print the
    step's loss/accuracy and record the merged training summary.
    """
    fetches = [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy]
    feeds = {
        cnn.input_x: x_batch,
        cnn.input_y: y_batch,
        cnn.dropout_keep_prob: FLAGS.dropout_keep_prob,
    }
    _, step, summaries, loss, accuracy = sess.run(fetches, feeds)
    now = datetime.datetime.now().isoformat()
    print("{}: step {}, loss {:g}, acc {:g}".format(now, step, loss, accuracy))
    train_summary_writer.add_summary(summaries, step)

def dev_step(x_batch, y_batch, writer=None):
    """
    Evaluate one dev batch with dropout disabled and return (loss, accuracy).

    ``writer`` (optional) is a tf.summary.FileWriter that receives the dev
    summary for this batch; inside this function it shadows the module-level
    ``writer`` file handle.
    """
    feed_dict = {
      cnn.input_x: x_batch,
      cnn.input_y: y_batch,
      cnn.dropout_keep_prob: 1.0
    }
    step, summaries, loss, accuracy = sess.run(
        [global_step, dev_summary_op, cnn.loss, cnn.accuracy],
        feed_dict)
    if writer is not None:
        writer.add_summary(summaries, step)
    return loss, accuracy

# Generate batches
batches = batch_iter(
    list(zip(train_url, train_label)), FLAGS.batch_size, FLAGS.num_epochs)
current_step = 0
last_saved_step = None
# Training loop. For each batch...
for batch in batches:
    x_batch, y_batch = zip(*batch)
    train_step(x_batch, y_batch)
    current_step = tf.train.global_step(sess, global_step)
    if current_step % FLAGS.evaluate_every == 0:
        print("\nEvaluation:")
        if not dev_url:
            print("dev set is empty, skipping evaluation\n")
        else:
            # Feed the dev set batch by batch (all at once caused OOM) and weight
            # each batch's metrics by its size: the last batch is usually smaller,
            # so a plain mean of per-batch values would be biased.
            # Assumes cnn.loss / cnn.accuracy are batch means -- TODO confirm in TextCNN.
            t_cost = 0.0
            t_acc = 0.0
            count = 0
            dev_batches = batch_iter(list(zip(dev_url, dev_label)), FLAGS.batch_size, 1, shuffle=False)
            for dev_batch in dev_batches:
                x_dev_batch, y_dev_batch = zip(*dev_batch)
                loss, accuracy = dev_step(x_dev_batch, y_dev_batch, writer=dev_summary_writer)
                n = len(x_dev_batch)
                t_cost += loss * n
                t_acc += accuracy * n
                count += n
            result = "loss {:g}, acc {:g}".format(t_cost / count, t_acc / count)
            print(result + '\n')
            # Module-level dev_res.txt handle (not dev_step's parameter).
            writer.write(result + '\n')
            writer.flush()
    if current_step % FLAGS.checkpoint_every == 0:
        path = saver.save(sess, checkpoint_prefix, global_step=current_step)
        last_saved_step = current_step
        print("Saved model checkpoint to {}\n".format(path))

# Keep the final weights even if the last step is not a multiple of checkpoint_every.
if current_step and current_step != last_saved_step:
    path = saver.save(sess, checkpoint_prefix, global_step=current_step)
    print("Saved model checkpoint to {}\n".format(path))
train_summary_writer.close()
dev_summary_writer.close()
writer.close()

模型训练过程中的 feed_dict 部分值得一提:程序原先的代码是将测试集一次性全部送入模型,会导致 OOM,因此这里同样改用 batch 的方式,先累计各 batch 的 loss 和 acc,再计算平均值。

训练过程

image.png

image.png

由于时间关系,训练到接近收敛时就停止了,此时模型已基本收敛。

image.png

最终训练集准确率约为 0.91,测试集准确率约为 0.88,以此作为该任务的 baseline。

模型的计算graph如下:

image.png

模型全部代码如下:

dataPreprocessing.py



import codecs
import jieba
from tqdm import tqdm

import ast

# Per-class cap so the two classes stay balanced.
MAX_PER_CLASS = 50000

with codecs.open("aliyunSpider.txt", "r", 'UTF-8') as fin:
    # dict.fromkeys de-duplicates like set() but keeps file order, so which
    # records survive the per-class cap is the same on every run.
    lines = list(dict.fromkeys(fin.readlines()))
print(len(lines))
count0 = 0  # "系统" records written (label 1)
count1 = 0  # "应用" records written (label 0)
skipped = 0  # lines that could not be parsed
with codecs.open("train.txt", 'w', 'UTF-8') as writer:
    for line in tqdm(lines):
        try:
            # Scraped spider output: parse it as a Python literal --
            # eval() would execute arbitrary code embedded in the data.
            record = ast.literal_eval(line.strip())
            content = record["content"].strip()
            kind = record["type"]
        except Exception:
            skipped += 1
            continue
        if len(content) <= 20:
            continue
        # Each class is capped independently, so reaching the cap for one
        # class does not stop collecting the other.
        if kind == "系统" and count0 < MAX_PER_CLASS:
            count0 += 1
            label = 1
        elif kind == "应用" and count1 < MAX_PER_CLASS:
            count1 += 1
            label = 0
        else:
            continue
        # Collapse newlines/tabs first: train.py reads train.txt line by line,
        # so one sample must never span several lines.
        tokens = jieba.cut(' '.join(content.split()))
        writer.write("{}<split>{}".format(' '.join(tokens), label))
        writer.write("\n")
print("系统: {}, 应用: {}, skipped: {}".format(count0, count1, skipped))

text_cnn.py(注:原文此处误贴了与上面 dataPreprocessing.py 相同的代码,TextCNN 模型代码请参见上一篇文章)



import codecs
import jieba
from tqdm import tqdm

import ast

# Per-class cap so the two classes stay balanced.
MAX_PER_CLASS = 50000

with codecs.open("aliyunSpider.txt", "r", 'UTF-8') as fin:
    # dict.fromkeys de-duplicates like set() but keeps file order, so which
    # records survive the per-class cap is the same on every run.
    lines = list(dict.fromkeys(fin.readlines()))
print(len(lines))
count0 = 0  # "系统" records written (label 1)
count1 = 0  # "应用" records written (label 0)
skipped = 0  # lines that could not be parsed
with codecs.open("train.txt", 'w', 'UTF-8') as writer:
    for line in tqdm(lines):
        try:
            # Scraped spider output: parse it as a Python literal --
            # eval() would execute arbitrary code embedded in the data.
            record = ast.literal_eval(line.strip())
            content = record["content"].strip()
            kind = record["type"]
        except Exception:
            skipped += 1
            continue
        if len(content) <= 20:
            continue
        # Each class is capped independently, so reaching the cap for one
        # class does not stop collecting the other.
        if kind == "系统" and count0 < MAX_PER_CLASS:
            count0 += 1
            label = 1
        elif kind == "应用" and count1 < MAX_PER_CLASS:
            count1 += 1
            label = 0
        else:
            continue
        # Collapse newlines/tabs first: train.py reads train.txt line by line,
        # so one sample must never span several lines.
        tokens = jieba.cut(' '.join(content.split()))
        writer.write("{}<split>{}".format(' '.join(tokens), label))
        writer.write("\n")
print("系统: {}, 应用: {}, skipped: {}".format(count0, count1, skipped))

data_helpers.py

import numpy as np
import re
import itertools
from collections import Counter


def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py

    Replaces every character other than ASCII letters, digits and ( ) , ! ? ' `
    with a space (so Chinese text is dropped entirely). Splits English
    contractions and punctuation into separate tokens, collapses runs of
    whitespace and returns the stripped, lower-cased result.
    """
    string = re.sub(r"[^A-Za-z0-9(),!?'`]", " ", string)
    string = re.sub(r"'s", " 's", string)
    string = re.sub(r"'ve", " 've", string)
    string = re.sub(r"n't", " n't", string)
    string = re.sub(r"'re", " 're", string)
    string = re.sub(r"'d", " 'd", string)
    string = re.sub(r"'ll", " 'll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    # "(", ")" and "?" are regex metacharacters and must be escaped;
    # unescaped they make re.sub raise re.error.
    string = re.sub(r"\(", " ( ", string)
    string = re.sub(r"\)", " ) ", string)
    string = re.sub(r"\?", " ? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()


def load_data_and_labels(positive_data_file, negative_data_file, max_examples=60000):
    """
    Load positive/negative examples from two files, clean them and build labels.

    Args:
        positive_data_file: path of a UTF-8 text file, one positive example per line.
        negative_data_file: path of a UTF-8 text file, one negative example per line.
        max_examples: read at most this many lines from each file (default 60000);
            ``None`` reads every line.

    Returns:
        ``[x_text, y]``: ``x_text`` is the list of sentences passed through
        ``clean_str`` (positives first, then negatives); ``y`` is a numpy array
        of one-hot labels, ``[0, 1]`` for positive and ``[1, 0]`` for negative.
    """
    def _read_examples(path):
        # `with` guarantees the file is closed even if reading/decoding fails.
        with open(path, "r", encoding='UTF-8') as f:
            return [s.strip() for s in itertools.islice(f, max_examples)]

    positive_examples = _read_examples(positive_data_file)
    negative_examples = _read_examples(negative_data_file)
    x_text = [clean_str(sent) for sent in positive_examples + negative_examples]
    # Generate labels
    positive_labels = [[0, 1] for _ in positive_examples]
    negative_labels = [[1, 0] for _ in negative_examples]
    y = np.concatenate([positive_labels, negative_labels], 0)
    return [x_text, y]


def batch_iter(data, batch_size, num_epochs, shuffle=True):
    """
    Generates a batch iterator for a dataset.

    Args:
        data: sequence of samples, e.g. a list of (token_ids, one_hot_label) pairs.
        batch_size: maximum number of samples per batch; the last batch of
            each epoch may be smaller.
        num_epochs: number of passes over the data.
        shuffle: reshuffle the samples at the start of every epoch.

    Yields:
        numpy arrays holding consecutive slices of the (possibly shuffled)
        data. Nothing is yielded when ``data`` is empty.
    """
    try:
        data = np.array(data)
    except ValueError:
        # Ragged samples -- e.g. (ids of length max_len, 2-element label)
        # pairs -- cannot form a rectangular array (NumPy >= 1.24 raises).
        # Store each sample as a single object so indexing/shuffling still work.
        samples = data
        data = np.empty(len(samples), dtype=object)
        for i, sample in enumerate(samples):
            data[i] = sample
    data_size = len(data)
    # Integer ceiling division: 0 batches (not one empty batch) for empty data.
    num_batches_per_epoch = (data_size + batch_size - 1) // batch_size
    for epoch in range(num_epochs):
        # Shuffle the data at each epoch
        if shuffle:
            shuffle_indices = np.random.permutation(np.arange(data_size))
            shuffled_data = data[shuffle_indices]
        else:
            shuffled_data = data
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, data_size)
            yield shuffled_data[start_index:end_index]

train.py

import datetime
import os
import time
import numpy as np
import tensorflow as tf
from data_helpers import batch_iter
from text_cnn import TextCNN
import pickle
import codecs

# Plain-text log of the averaged dev-set loss/accuracy written at each evaluation.
# NOTE: inside dev_step() the name `writer` is a parameter (a summary writer) that
# shadows this handle; the training loop's `writer.write(...)` uses this file.
writer = codecs.open('dev_res.txt','w','UTF-8')

def pickle_reader(inputs):
    """Load and return the object pickled in the file at path ``inputs``.

    The file is closed even if unpickling fails.

    WARNING: ``pickle.load`` can execute arbitrary code; only read files this
    project produced itself (e.g. ``vocabulary.pkl``), never untrusted input.
    """
    with open(inputs, 'rb') as f:
        lines = pickle.load(f)
    print("Finish load {}".format(inputs))
    return lines


def pickle_writer(inputs, name):
    """Pickle ``inputs`` into the file at path ``name``, overwriting it.

    Protocol 2 is used so the file can also be read by Python 2.
    The file handle is closed even if pickling fails.
    """
    with open(name, 'wb') as output:
        pickle.dump(inputs, output, protocol=2)
    print("Finish save {}".format(name))

tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes (default: '3,4,5')")
tf.flags.DEFINE_integer("num_filters", 128, "Number of filters per filter size (default: 128)")
tf.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability (default: 0.5)")
tf.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda (default: 0.0)")

# Training parameters (help strings state the actual defaults)
tf.flags.DEFINE_integer("batch_size", 32, "Batch Size (default: 32)")
tf.flags.DEFINE_integer("num_epochs", 40, "Number of training epochs (default: 40)")
tf.flags.DEFINE_integer("evaluate_every", 2000, "Evaluate model on dev set after this many steps (default: 2000)")
tf.flags.DEFINE_integer("checkpoint_every", 1000, "Save model after this many steps (default: 1000)")
tf.flags.DEFINE_integer("num_checkpoints", 2, "Number of checkpoints to store (default: 2)")
# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")

FLAGS = tf.flags.FLAGS

try:
    # absl-backed tf.flags (TF >= 1.5): {flag name: value}. FLAGS.__flags there
    # maps names to Flag *objects*, which would print as object reprs.
    # May also list flags registered by absl/TF themselves.
    _flag_items = FLAGS.flag_values_dict().items()
except AttributeError:
    # Older tf.flags: FLAGS.__flags maps flag name -> value directly.
    _flag_items = FLAGS.__flags.items()
print("\nParameters:")
for attr, value in sorted(_flag_items):
    print("{}={}".format(attr.upper(), value))
print("")


# Data Preparation
# ==================================================
# train.txt (output of dataPreprocessing.py): one sample per line,
# "<space-separated tokens><split><label>".
with codecs.open(r'E:\TextCNN\data\VlunData\train.txt', 'r', 'UTF-8') as f:
    train_data = f.readlines()
np.random.shuffle(train_data)
urls = []
labels = []
max_len = 0
skipped = 0
for line in train_data:
    parts = line.strip().split('<split>')
    if len(parts) < 2:
        # Blank or malformed line (e.g. a trailing empty line): skip it
        # instead of crashing with IndexError.
        skipped += 1
        continue
    url = parts[0].split()
    max_len = max(max_len, len(url))
    urls.append(url)
    labels.append(parts[1])
if skipped:
    print('skipped {} lines without "<split>"'.format(skipped))
# Split point: the first 80% of the valid (shuffled) samples are used for
# training, the remaining 20% for validation.
index = int(0.8 * float(len(urls)))

print('文本最大长度为 {}'.format(max_len))
# Ids 0/1 are reserved for unknown words and padding; every other token gets
# the next free id in order of first appearance.
vocabulary = {'UNK':0,'PAD':1}
for url in urls:
    for word in url:
        vocabulary.setdefault(word, len(vocabulary))
pickle_writer(vocabulary,'vocabulary.pkl')

def convert2id(inputs, seq_len=None, vocab=None):
    """Map tokenized texts to fixed-length lists of vocabulary ids.

    Args:
        inputs: iterable of token lists (one list per text).
        seq_len: output length; defaults to the module-level ``max_len``.
        vocab: token -> id dict; defaults to the module-level ``vocabulary``.

    Returns:
        One list of ``seq_len`` ints per input. Shorter texts are right-padded
        with PAD (id 1), and tokens missing from ``vocab`` map to UNK (id 0).
        Texts longer than ``seq_len`` are truncated rather than raising
        IndexError, which matters when converting unseen text at prediction time.
    """
    if seq_len is None:
        seq_len = max_len
    if vocab is None:
        vocab = vocabulary
    res = []
    for line in inputs:
        ids = [vocab.get(word, 0) for word in line[:seq_len]]
        ids.extend([1] * (seq_len - len(ids)))
        res.append(ids)
    return res


def convert2label(inputs):
    """Turn string labels into two-class one-hot vectors.

    The label ``'0'`` becomes ``[0, 1]``; any other value becomes ``[1, 0]``.
    """
    return [[0, 1] if label == '0' else [1, 0] for label in inputs]

train_url, train_label = convert2id(urls[:index]), convert2label(labels[:index])
dev_url, dev_label = convert2id(urls[index:]), convert2label(labels[index:])
# Training
# ==================================================

with tf.Graph().as_default():
    session_conf = tf.ConfigProto(
      allow_soft_placement=FLAGS.allow_soft_placement,
      log_device_placement=FLAGS.log_device_placement)
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        cnn = TextCNN(sequence_length=max_len,
            num_classes=2,
            vocab_size=len(vocabulary),
            embedding_size=300,
            filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
            num_filters=FLAGS.num_filters,
            l2_reg_lambda=FLAGS.l2_reg_lambda)

        # Define Training procedure
        global_step = tf.Variable(0, name="global_step", trainable=False)
        optimizer = tf.train.AdamOptimizer(1e-3)
        grads_and_vars = optimizer.compute_gradients(cnn.loss)
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

        # Keep track of gradient values and sparsity (optional)
        grad_summaries = []
        for g, v in grads_and_vars:
            if g is not None:
                grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g)
                sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                grad_summaries.append(grad_hist_summary)
                grad_summaries.append(sparsity_summary)
        grad_summaries_merged = tf.summary.merge(grad_summaries)

        # Output directory for models and summaries: runs/<unix timestamp>/
        timestamp = str(int(time.time()))
        out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
        print("Writing to {}\n".format(out_dir))

        # Summaries for loss and accuracy
        loss_summary = tf.summary.scalar("loss", cnn.loss)
        acc_summary = tf.summary.scalar("accuracy", cnn.accuracy)

        # Train Summaries
        train_summary_op = tf.summary.merge([loss_summary, acc_summary, grad_summaries_merged])
        train_summary_dir = os.path.join(out_dir, "summaries", "train")
        train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

        # Dev summaries
        dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
        dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
        dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        def train_step(x_batch, y_batch):
            """
            A single training step (dropout enabled); logs loss/accuracy to
            stdout and the merged training summary to TensorBoard.
            """
            feed_dict = {
              cnn.input_x: x_batch,
              cnn.input_y: y_batch,
              cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
            }
            _, step, summaries, loss, accuracy = sess.run(
                [train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy],
                feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, accuracy))
            train_summary_writer.add_summary(summaries, step)

        def dev_step(x_batch, y_batch, writer=None):
            """
            Evaluate one dev batch with dropout disabled; return (loss, accuracy).

            ``writer`` (optional) is a tf.summary.FileWriter; inside this
            function it shadows the module-level ``writer`` file handle.
            """
            feed_dict = {
              cnn.input_x: x_batch,
              cnn.input_y: y_batch,
              cnn.dropout_keep_prob: 1.0
            }
            step, summaries, loss, accuracy = sess.run(
                [global_step, dev_summary_op, cnn.loss, cnn.accuracy],
                feed_dict)
            if writer is not None:
                writer.add_summary(summaries, step)
            return loss, accuracy

        def evaluate_dev():
            """
            Evaluate the whole dev set batch by batch (feeding it at once caused OOM).

            Returns (loss, accuracy) averaged over samples, or None when the dev
            set is empty. Each batch is weighted by its size because the last
            batch is usually smaller, so a plain mean of per-batch values is biased.
            Assumes cnn.loss / cnn.accuracy are means over the batch (TextCNN is
            defined elsewhere) -- TODO confirm.
            """
            if not dev_url:
                return None
            total_loss = 0.0
            total_acc = 0.0
            n_seen = 0
            # Order is irrelevant for evaluation, so skip shuffling.
            dev_batches = batch_iter(list(zip(dev_url, dev_label)), FLAGS.batch_size, 1, shuffle=False)
            for dev_batch in dev_batches:
                x_dev_batch, y_dev_batch = zip(*dev_batch)
                loss, accuracy = dev_step(x_dev_batch, y_dev_batch, writer=dev_summary_writer)
                n = len(x_dev_batch)
                total_loss += loss * n
                total_acc += accuracy * n
                n_seen += n
            return total_loss / n_seen, total_acc / n_seen

        # Generate batches
        batches = batch_iter(
            list(zip(train_url, train_label)), FLAGS.batch_size, FLAGS.num_epochs)
        current_step = 0
        last_saved_step = None
        # Training loop. For each batch...
        for batch in batches:
            x_batch, y_batch = zip(*batch)
            train_step(x_batch, y_batch)
            current_step = tf.train.global_step(sess, global_step)
            if current_step % FLAGS.evaluate_every == 0:
                print("\nEvaluation:")
                dev_metrics = evaluate_dev()
                if dev_metrics is None:
                    print("dev set is empty, skipping evaluation\n")
                else:
                    result = "loss {:g}, acc {:g}".format(*dev_metrics)
                    print(result + '\n')
                    # Module-level dev_res.txt handle (not dev_step's parameter).
                    writer.write(result + '\n')
                    writer.flush()
            if current_step % FLAGS.checkpoint_every == 0:
                path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                last_saved_step = current_step
                print("Saved model checkpoint to {}\n".format(path))

        # Keep the final weights even if the last step is not a multiple of checkpoint_every.
        if current_step and current_step != last_saved_step:
            path = saver.save(sess, checkpoint_prefix, global_step=current_step)
            print("Saved model checkpoint to {}\n".format(path))
        train_summary_writer.close()
        dev_summary_writer.close()
        writer.close()
`;-.          ___,
  `.`\_...._/`.-"`
    \        /      ,
    /()   () \    .' `-._
   |)  .    ()\  /   _.'
   \  -'-     ,; '. <
    ;.__     ,;|   > \
   / ,    / ,  |.-'.-'
  (_/    (_/ ,;|.<`
    \    ,     ;-`
     >   \    /
    (_,-'`> .'
         (_,'
         
蟹蟹~