Improving Sentiment Analysis with the PaddleNLP Pretrained Semantic Model ERNIE


Note

The code in this project requires a GPU environment to run.

Before 2017, NLP text processing in both industry and academia relied on sequence models such as the Recurrent Neural Network (RNN).

Figure 1: RNN diagram

The tutorial "What is paddlenlp.seq2vec? See how it handles sentiment analysis" introduces how to represent text semantics with paddlenlp.seq2vec, covering the basic BOW network, classic RNN/CNN networks, and more.

In recent years, with the development of deep learning, the number of model parameters has grown rapidly, and training them requires ever larger datasets to avoid overfitting. For most NLP tasks, however, building large-scale labeled datasets is costly and difficult, especially for syntax- and semantics-related tasks. By contrast, building large-scale unlabeled corpora is comparatively easy. Recent research shows that pretrained models (PTMs) trained on large unlabeled corpora learn general language representations, and fine-tuning them on downstream tasks yields excellent results. In addition, a pretrained model spares us from training a model from scratch.

Figure 2: An overview of pretrained models. Image source: github.com/thunlp/PLMp…

This example shows how a pretrained model, represented here by ERNIE (Enhanced Representation through Knowledge Integration), can be fine-tuned to perform a Chinese sentiment analysis task.

The AI Studio platform comes with Paddle and PaddleNLP preinstalled and updates them periodically. To update Paddle manually, see the PaddlePaddle installation guide and install the latest framework version for your environment.

Use the following command to make sure the latest PaddleNLP is installed:

In [2]

!pip install --upgrade paddlenlp==3.0.0b0
Looking in indexes: https://mirror.baidu.com/pypi/simple/, https://mirrors.aliyun.com/pypi/simple/
WARNING: Skipping page https://mirror.baidu.com/pypi/simple/paddlenlp/ because the GET request got Content-Type: application/octet-stream. The only supported Content-Types are application/vnd.pypi.simple.v1+json, application/vnd.pypi.simple.v1+html, and text/html
Collecting paddlenlp==3.0.0b0
  Downloading https://mirrors.aliyun.com/pypi/packages/da/3f/47ab4185d6b78702fc1770d3aa76c617469db1104583824136a697d53e03/paddlenlp-3.0.0b0-py3-none-any.whl (2.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.6/2.6 MB 381.4 kB/s eta 0:00:0000:0100:01
Installing collected packages: paddlenlp
  Attempting uninstall: paddlenlp
    Found existing installation: paddlenlp 2.8.1
    Uninstalling paddlenlp-2.8.1:
      Successfully uninstalled paddlenlp-2.8.1
Successfully installed paddlenlp-3.0.0b0

Loading the Dataset

We use the public Chinese sentiment analysis dataset ChnSentiCorp as an example. PaddleNLP has this dataset built in, so it can be loaded with a single call.

In [3]

import paddlenlp as ppnlp
from paddlenlp.datasets import load_dataset

train_ds, dev_ds, test_ds = load_dataset(
    "chnsenticorp", splits=["train", "dev", "test"])

print(train_ds.label_list)

for data in train_ds.data[:5]:
    print(data)
/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/_distutils_hack/__init__.py:26: UserWarning: Setuptools is replacing distutils.
  warnings.warn("Setuptools is replacing distutils.")
[2024-07-26 17:26:07,651] [ WARNING] - if you run ring_flash_attention.py, please ensure you install the paddlenlp_ops by following the instructions provided at https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md
100%|██████████| 1909/1909 [00:00<00:00, 36619.16it/s]
['0', '1']
{'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般', 'label': 1, 'qid': ''}
{'text': '15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错', 'label': 1, 'qid': ''}
{'text': '房间太小。其他的都一般。。。。。。。。。', 'label': 0, 'qid': ''}
{'text': '1.接电源没有几分钟,电源适配器热的不行. 2.摄像头用不起来. 3.机盖的钢琴漆,手不能摸,一摸一个印. 4.硬盘分区不好办.', 'label': 0, 'qid': ''}
{'text': '今天才知道这书还有第6卷,真有点郁闷:为什么同一套书有两种版本呢?当当网是不是该跟出版社商量商量,单独出个第6卷,让我们的孩子不会有所遗憾。', 'label': 1, 'qid': ''}

Each example contains a review and its label, 0 or 1: 0 denotes a negative review, 1 a positive review.

Next, the input sentences still need to be preprocessed, e.g. tokenized and mapped to vocabulary ids.

Processing Data with ppnlp.transformers.ErnieTokenizer

The pretrained ERNIE model processes Chinese text at the character level. PaddleNLP provides a built-in tokenizer for each supported pretrained model; specifying the model name loads the corresponding tokenizer.

The tokenizer converts raw input text into the input format the model can accept.

Figure 3: ERNIE model architecture diagram

In [4]

# Set the name of the model to use
MODEL_NAME = "ernie-1.0"

tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained(MODEL_NAME)
ernie_model = ppnlp.transformers.ErnieModel.from_pretrained(MODEL_NAME)
(…)enlp/models/transformers/ernie/vocab.txt: 100%|██████████| 91.6k/91.6k [00:00<00:00, 7.74MB/s]
[2024-07-26 17:26:17,644] [    INFO] - tokenizer config file saved in /home/aistudio/.paddlenlp/models/ernie-1.0/tokenizer_config.json
[2024-07-26 17:26:17,645] [    INFO] - Special tokens file saved in /home/aistudio/.paddlenlp/models/ernie-1.0/special_tokens_map.json
(…)formers/ernie/ernie_v1_chn_base.pdparams: 100%|██████████| 402M/402M [00:12<00:00, 31.5MB/s] 
[2024-07-26 17:26:30,442] [    INFO] - Loading weights file from cache at /home/aistudio/.paddlenlp/models/ernie-1.0/model_state.pdparams
[2024-07-26 17:26:30,904] [    INFO] - Loaded weights file from disk, setting weights to model.
W0726 17:26:30.910248   260 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 12.0, Runtime API Version: 11.8
W0726 17:26:30.912489   260 gpu_resources.cc:164] device: 0, cuDNN Version: 8.9.
[2024-07-26 17:26:36,480] [ WARNING] - Some weights of the model checkpoint at ernie-1.0 were not used when initializing ErnieModel: ['cls.predictions.decoder_bias', 'cls.predictions.layer_norm.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.transform.bias', 'cls.predictions.transform.weight']
- This IS expected if you are initializing ErnieModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ErnieModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[2024-07-26 17:26:36,481] [    INFO] - All the weights of ErnieModel were initialized from the model checkpoint at ernie-1.0.
If your task is similar to the task the model of the checkpoint was trained on, you can already use ErnieModel for predictions without further training.

In [5]

import paddle

# Split the raw input text into tokens
# (_tokenize is an internal method; the public tokenizer() API shown later is preferred)
tokens = tokenizer._tokenize("请输入测试样例")
print("Tokens: {}".format(tokens))

# Map each token to its token id
tokens_ids = tokenizer.convert_tokens_to_ids(tokens)
print("Tokens id: {}".format(tokens_ids))


# Add the special tokens used by the pretrained model, e.g. [CLS] and [SEP]
tokens_ids = tokenizer.build_inputs_with_special_tokens(tokens_ids)

# Convert to Paddle tensor format
tokens_pd = paddle.to_tensor([tokens_ids])
print("Tokens : {}".format(tokens_pd))

# The ids can now be fed into the ERNIE model to obtain its outputs
sequence_output, pooled_output = ernie_model(tokens_pd)
print("Token wise output: {}, Pooled output: {}".format(sequence_output.shape, pooled_output.shape))
Tokens: ['请', '输', '入', '测', '试', '样', '例']
Tokens id: [647, 789, 109, 558, 525, 314, 656]
Tokens : Tensor(shape=[1, 9], dtype=int64, place=Place(gpu:0), stop_gradient=True,
       [[1  , 647, 789, 109, 558, 525, 314, 656, 2  ]])
Token wise output: [1, 9, 768], Pooled output: [1, 768]

As the code above shows, the ERNIE model produces two output tensors.

  • sequence_output is the semantic representation of each input token, with shape (1, num_tokens, hidden_size). It is typically used for tasks such as sequence labeling and question answering.
  • pooled_output is the semantic representation of the whole sentence, with shape (1, hidden_size). It is typically used for tasks such as text classification and information retrieval.
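The relation between the two outputs can be illustrated with a small numpy sketch: in BERT-style models such as ERNIE, pooled_output is roughly a dense layer with tanh activation applied to the hidden state of the first ([CLS]) token. The weights W and b below are made-up stand-ins for the model's learned pooler parameters, not the real ones:

```python
import numpy as np

hidden_size, num_tokens = 768, 9
# Pretend this is sequence_output from the model: (batch, num_tokens, hidden_size)
sequence_output = np.random.rand(1, num_tokens, hidden_size).astype("float32")

# Hypothetical pooler parameters (learned in the real model)
W = (np.random.rand(hidden_size, hidden_size).astype("float32") - 0.5) * 0.01
b = np.zeros(hidden_size, dtype="float32")

# The pooler takes the first ([CLS]) token's hidden state ...
cls_state = sequence_output[:, 0, :]        # (1, hidden_size)
# ... and applies a dense layer with tanh activation
pooled_output = np.tanh(cls_state @ W + b)  # (1, hidden_size)

print(pooled_output.shape)  # (1, 768)
```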

NOTE:

To use the ernie-tiny pretrained model, the corresponding tokenizer should be loaded with paddlenlp.transformers.ErnieTinyTokenizer.from_pretrained('ernie-tiny')

The code above shows the data-processing steps needed to use a Transformer-style pretrained model. For convenience, PaddleNLP also provides a higher-level API that returns the data format the model needs in a single call.

In [6]

# One call tokenizes the text, maps tokens to ids, and adds the special tokens
encoded_text = tokenizer(text="请输入测试样例")
for key, value in encoded_text.items():
    print("{}:\n\t{}".format(key, value))

# Convert to Paddle tensor format
input_ids = paddle.to_tensor([encoded_text['input_ids']])
print("input_ids : {}".format(input_ids))
segment_ids = paddle.to_tensor([encoded_text['token_type_ids']])
print("token_type_ids : {}".format(segment_ids))

# The tensors can now be fed into the ERNIE model to obtain its outputs
sequence_output, pooled_output = ernie_model(input_ids, segment_ids)
print("Token wise output: {}, Pooled output: {}".format(sequence_output.shape, pooled_output.shape))
input_ids:
	[1, 647, 789, 109, 558, 525, 314, 656, 2]
token_type_ids:
	[0, 0, 0, 0, 0, 0, 0, 0, 0]
input_ids : Tensor(shape=[1, 9], dtype=int64, place=Place(gpu:0), stop_gradient=True,
       [[1  , 647, 789, 109, 558, 525, 314, 656, 2  ]])
token_type_ids : Tensor(shape=[1, 9], dtype=int64, place=Place(gpu:0), stop_gradient=True,
       [[0, 0, 0, 0, 0, 0, 0, 0, 0]])
Token wise output: [1, 9, 768], Pooled output: [1, 768]

As the code above shows, the tokenizer provides a very convenient way to generate the data format the model needs.

In the output above,

  • input_ids: the token ids of the input text.
  • segment_ids (token_type_ids): whether each token belongs to the first or the second input sentence. (Transformer-style pretrained models accept both single sentences and sentence pairs.) See the explanation of the convert_example() function in utils.py (in the file panel on the left) for details.

Depending on the tokenizer arguments and downstream preprocessing, the following fields may also be produced:

  • seq_len: the number of tokens in the input sentence.
  • input_mask: whether each token is a padding token. Because the sentences in a batch have different lengths, they are padded to a uniform length; 1 marks a real input token and 0 marks padding.
  • position_ids: the position of each token in the whole input sequence.
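To illustrate the seq_len, input_mask, and position_ids fields, here is a minimal plain-Python sketch (a hypothetical helper, not PaddleNLP code) that pads a batch of token-id sequences to a common length and derives those fields:

```python
def pad_batch(batch_token_ids, pad_id=0):
    """Pad variable-length id sequences and build the auxiliary fields."""
    max_len = max(len(ids) for ids in batch_token_ids)
    input_ids, input_mask, position_ids, seq_lens = [], [], [], []
    for ids in batch_token_ids:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)         # pad to max_len
        input_mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = padding
        position_ids.append(list(range(max_len)))        # position of each token
        seq_lens.append(len(ids))                        # original token count
    return input_ids, input_mask, position_ids, seq_lens

ids, mask, pos, lens = pad_batch([[1, 647, 2], [1, 647, 789, 109, 2]])
print(ids)   # [[1, 647, 2, 0, 0], [1, 647, 789, 109, 2]]
print(mask)  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```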

In [7]

# Single-sentence input
single_seg_input = tokenizer(text="请输入测试样例")
# Sentence-pair input
multi_seg_input = tokenizer(text="请输入测试样例1", text_pair="请输入测试样例2")

print("Single-sentence tokens (str): {}".format(tokenizer.convert_ids_to_tokens(single_seg_input['input_ids'])))
print("Single-sentence tokens (int): {}".format(single_seg_input['input_ids']))
print("Single-sentence segment ids : {}".format(single_seg_input['token_type_ids']))

print()
print("Sentence-pair tokens (str): {}".format(tokenizer.convert_ids_to_tokens(multi_seg_input['input_ids'])))
print("Sentence-pair tokens (int): {}".format(multi_seg_input['input_ids']))
print("Sentence-pair segment ids : {}".format(multi_seg_input['token_type_ids']))
Single-sentence tokens (str): ['[CLS]', '请', '输', '入', '测', '试', '样', '例', '[SEP]']
Single-sentence tokens (int): [1, 647, 789, 109, 558, 525, 314, 656, 2]
Single-sentence segment ids : [0, 0, 0, 0, 0, 0, 0, 0, 0]

Sentence-pair tokens (str): ['[CLS]', '请', '输', '入', '测', '试', '样', '例', '1', '[SEP]', '请', '输', '入', '测', '试', '样', '例', '2', '[SEP]']
Sentence-pair tokens (int): [1, 647, 789, 109, 558, 525, 314, 656, 208, 2, 647, 789, 109, 558, 525, 314, 656, 249, 2]
Sentence-pair segment ids : [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]

In [8]

# Highlight: truncating/padding to a uniform length
# (max_seq_len is deprecated in newer PaddleNLP; use max_length instead, see the warning below)
encoded_text = tokenizer(text="请输入测试样例", max_seq_len=15)

for key, value in encoded_text.items():
    print("{}:\n\t{}".format(key, value))
input_ids:
	[1, 647, 789, 109, 558, 525, 314, 656, 2]
token_type_ids:
	[0, 0, 0, 0, 0, 0, 0, 0, 0]
/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/paddlenlp/transformers/tokenizer_utils_base.py:2331: FutureWarning: The `max_seq_len` argument is deprecated and will be removed in a future version, please use `max_length` instead.
  warnings.warn(
/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/paddlenlp/transformers/tokenizer_utils_base.py:1903: UserWarning: Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
  warnings.warn(

The examples above cover the tokenizer's usage in detail.

Next, we use the tokenizer to process the ChnSentiCorp dataset.

Reading the Data

We use the paddle.io.DataLoader interface to load data asynchronously with multiple workers.

In [9]

from functools import partial
from paddlenlp.data import Stack, Tuple, Pad
from utils import convert_example, create_dataloader

# Batch size used when running the model
batch_size = 32
max_seq_length = 128

trans_func = partial(
    convert_example,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length)
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
    Stack(dtype="int64")  # label
): [data for data in fn(samples)]
train_data_loader = create_dataloader(
    train_ds,
    mode='train',
    batch_size=batch_size,
    batchify_fn=batchify_fn,
    trans_fn=trans_func)
dev_data_loader = create_dataloader(
    dev_ds,
    mode='dev',
    batch_size=batch_size,
    batchify_fn=batchify_fn,
    trans_fn=trans_func)
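The batchify_fn above pads the input and segment ids to the longest sequence in each batch and stacks the labels into an array. A plain-Python sketch of that collation (a hypothetical minimal reimplementation, not the actual Pad/Stack internals):

```python
def batchify(samples, pad_id=0, pad_type_id=0):
    """samples: list of (input_ids, token_type_ids, label) triples."""
    max_len = max(len(ids) for ids, _, _ in samples)
    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids, _, _ in samples]     # Pad
    token_type = [tt + [pad_type_id] * (max_len - len(tt)) for _, tt, _ in samples]  # Pad
    labels = [label for _, _, label in samples]                                      # Stack
    return input_ids, token_type, labels

batch = [([1, 647, 2], [0, 0, 0], 1),
         ([1, 647, 789, 109, 2], [0, 0, 0, 0, 0], 0)]
ids, tt, labels = batchify(batch)
print(ids)     # [[1, 647, 2, 0, 0], [1, 647, 789, 109, 2]]
print(labels)  # [1, 0]
```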

Loading a Pretrained Model with One Line of PaddleNLP

Sentiment analysis is essentially a text classification task, and PaddleNLP ships a built-in fine-tuning network for this downstream task for each of its pretrained models. The following uses ERNIE as an example to show how to fine-tune a pretrained model for text classification.

A single line, paddlenlp.transformers.ErnieModel(), loads the pretrained ERNIE model.

A single line, paddlenlp.transformers.ErnieForSequenceClassification(), loads the fine-tuning network that adapts the pretrained ERNIE model for text classification.

It appends a fully connected layer to the ERNIE model to perform classification.

paddlenlp.transformers.ErnieForSequenceClassification.from_pretrained() only needs the model name and the number of classes to define the network.

PaddleNLP supports not only ERNIE but also pretrained models such as BERT, RoBERTa, and Electra; see the end of this article for more.
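Conceptually, the sequence classification head is just a linear layer over pooled_output producing one logit per class. The numpy sketch below illustrates the shapes with made-up weights; it is not the actual PaddleNLP implementation. (In the real model these classifier parameters are randomly initialized before fine-tuning, which is why the load log reports 'classifier.weight'/'classifier.bias' as newly initialized.)

```python
import numpy as np

hidden_size, num_classes = 768, 2
pooled_output = np.random.rand(1, hidden_size).astype("float32")  # would come from ErnieModel

# Hypothetical classifier parameters
W = (np.random.rand(hidden_size, num_classes).astype("float32") - 0.5) * 0.01
b = np.zeros(num_classes, dtype="float32")

logits = pooled_output @ W + b                                       # (1, num_classes)
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)  # softmax
print(logits.shape)  # (1, 2)
```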

In [10]

ernie_model = ppnlp.transformers.ErnieModel.from_pretrained(MODEL_NAME)

model = ppnlp.transformers.ErnieForSequenceClassification.from_pretrained(MODEL_NAME, num_classes=len(train_ds.label_list))
[2024-07-26 17:27:06,492] [    INFO] - Loading weights file from cache at /home/aistudio/.paddlenlp/models/ernie-1.0/model_state.pdparams
[2024-07-26 17:27:06,886] [    INFO] - Loaded weights file from disk, setting weights to model.
[2024-07-26 17:27:07,400] [ WARNING] - Some weights of the model checkpoint at ernie-1.0 were not used when initializing ErnieForSequenceClassification: ['cls.predictions.decoder_bias', 'cls.predictions.layer_norm.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.transform.bias', 'cls.predictions.transform.weight']
- This IS expected if you are initializing ErnieForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ErnieForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[2024-07-26 17:27:07,402] [ WARNING] - Some weights of ErnieForSequenceClassification were not initialized from the model checkpoint at ernie-1.0 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Set the fine-tuning optimization strategy and attach an evaluation metric

For Transformer models such as ERNIE/BERT, the usual choice is a dynamic learning rate with warmup.

Figure 4: Dynamic learning rate schedule
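The `LinearDecayWithWarmup` scheduler used below ramps the learning rate linearly from 0 up to the peak value over the warmup steps, then decays it linearly back to 0. A minimal pure-Python sketch of that schedule (an illustration, not PaddleNLP's implementation):

```python
def linear_decay_with_warmup(step, base_lr, total_steps, warmup_proportion):
    # Warmup phase: learning rate rises linearly from 0 to base_lr
    warmup_steps = int(total_steps * warmup_proportion)
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    # Decay phase: learning rate falls linearly from base_lr to 0
    return base_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))

# With 100 total steps and 10% warmup, the peak is reached at step 10
print(linear_decay_with_warmup(10, 5e-5, 100, 0.1))  # 5e-05
```

Warmup stabilizes the early steps of fine-tuning, when the randomly initialized classifier head produces large gradients.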

In [11]

from paddlenlp.transformers import LinearDecayWithWarmup

# Peak learning rate during training
learning_rate = 5e-5
# Number of training epochs
epochs = 1  # 3
# Proportion of training steps used for learning rate warmup
warmup_proportion = 0.1
# Weight decay coefficient, a regularization strategy to reduce overfitting
weight_decay = 0.01

num_training_steps = len(train_data_loader) * epochs
lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps, warmup_proportion)
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ])

criterion = paddle.nn.CrossEntropyLoss()
metric = paddle.metric.Accuracy()
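The `apply_decay_param_fun` lambda above restricts weight decay to parameters whose names contain neither "bias" nor "norm", a common practice when fine-tuning Transformers. The filtering logic, isolated as a pure-Python sketch (the parameter names are illustrative):

```python
def needs_weight_decay(name):
    # Exclude bias and LayerNorm parameters from weight decay
    return not any(nd in name for nd in ["bias", "norm"])

names = [
    "encoder.layers.0.linear.weight",
    "encoder.layers.0.linear.bias",
    "encoder.layers.0.norm.weight",
]
decay_names = [n for n in names if needs_weight_decay(n)]  # only the linear weight remains
```

Biases and normalization parameters are excluded because decaying them toward zero tends to hurt optimization without providing any regularization benefit.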

Model training and evaluation

Model training typically follows these steps:

  1. Fetch a batch of data from the dataloader.
  2. Feed the batch to the model for a forward pass.
  3. Pass the forward outputs to the loss function to compute the loss, and to the metric to compute the evaluation score.
  4. Back-propagate the loss and update the parameters; repeat the steps above.

After each training epoch, the program runs one evaluation pass to measure how well the current model is doing.
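In the training loop below, `metric.update()` and `metric.accumulate()` maintain a running accuracy over all batches seen so far in the epoch. A minimal sketch of that bookkeeping (not `paddle.metric.Accuracy` itself):

```python
class RunningAccuracy:
    """Accumulates correct/total counts across batches, like a streaming accuracy metric."""
    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, labels):
        # Count how many predictions in this batch match their labels
        self.correct += sum(p == l for p, l in zip(preds, labels))
        self.total += len(labels)

    def accumulate(self):
        # Running accuracy over every batch seen so far
        return self.correct / self.total

metric = RunningAccuracy()
metric.update([1, 0, 1], [1, 1, 1])  # 2 of 3 correct
metric.update([0, 0], [0, 1])        # 1 of 2 correct
print(metric.accumulate())           # 0.6
```

This explains why the `acc` printed during training keeps changing: it is the average over all batches so far, not the accuracy of the latest batch alone.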

In [12]

# The checkpoint directory is used to save the trained model
!mkdir /home/aistudio/checkpoint
mkdir: cannot create directory '/home/aistudio/checkpoint': File exists

In [13]

import paddle.nn.functional as F
from utils import evaluate

global_step = 0
for epoch in range(1, epochs + 1):
    for step, batch in enumerate(train_data_loader, start=1):
        input_ids, segment_ids, labels = batch
        logits = model(input_ids, segment_ids)
        loss = criterion(logits, labels)
        probs = F.softmax(logits, axis=1)
        correct = metric.compute(probs, labels)
        metric.update(correct)
        acc = metric.accumulate()

        global_step += 1
        if global_step % 10 == 0:
            print("global step %d, epoch: %d, batch: %d, loss: %.5f, acc: %.5f" % (global_step, epoch, step, loss, acc))
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_grad()
    evaluate(model, criterion, metric, dev_data_loader)

model.save_pretrained('/home/aistudio/checkpoint')
tokenizer.save_pretrained('/home/aistudio/checkpoint')
/opt/conda/envs/python35-paddle120-env/lib/python3.10/site-packages/paddlenlp/transformers/tokenizer_utils_base.py:2331: FutureWarning: The `max_seq_len` argument is deprecated and will be removed in a future version, please use `max_length` instead.
  warnings.warn(
global step 10, epoch: 1, batch: 10, loss: 0.63344, acc: 0.53438
global step 20, epoch: 1, batch: 20, loss: 0.70652, acc: 0.60625
global step 30, epoch: 1, batch: 30, loss: 0.36621, acc: 0.67188
global step 40, epoch: 1, batch: 40, loss: 0.18697, acc: 0.72734
global step 50, epoch: 1, batch: 50, loss: 0.18123, acc: 0.76125
global step 60, epoch: 1, batch: 60, loss: 0.32450, acc: 0.78333
global step 70, epoch: 1, batch: 70, loss: 0.38021, acc: 0.80089
global step 80, epoch: 1, batch: 80, loss: 0.37712, acc: 0.80898
global step 90, epoch: 1, batch: 90, loss: 0.26899, acc: 0.82083
global step 100, epoch: 1, batch: 100, loss: 0.09473, acc: 0.82937
global step 110, epoch: 1, batch: 110, loss: 0.24098, acc: 0.83381
global step 120, epoch: 1, batch: 120, loss: 0.20540, acc: 0.83828
global step 130, epoch: 1, batch: 130, loss: 0.20796, acc: 0.84447
global step 140, epoch: 1, batch: 140, loss: 0.07843, acc: 0.84821
global step 150, epoch: 1, batch: 150, loss: 0.13564, acc: 0.85396
global step 160, epoch: 1, batch: 160, loss: 0.29660, acc: 0.85801
global step 170, epoch: 1, batch: 170, loss: 0.12343, acc: 0.85993
global step 180, epoch: 1, batch: 180, loss: 0.04314, acc: 0.86319
global step 190, epoch: 1, batch: 190, loss: 0.25622, acc: 0.86530
global step 200, epoch: 1, batch: 200, loss: 0.41344, acc: 0.86687
global step 210, epoch: 1, batch: 210, loss: 0.26930, acc: 0.86979
global step 220, epoch: 1, batch: 220, loss: 0.20898, acc: 0.87273
global step 230, epoch: 1, batch: 230, loss: 0.29868, acc: 0.87554
global step 240, epoch: 1, batch: 240, loss: 0.33016, acc: 0.87695
global step 250, epoch: 1, batch: 250, loss: 0.35256, acc: 0.87800
global step 260, epoch: 1, batch: 260, loss: 0.19236, acc: 0.88005
global step 270, epoch: 1, batch: 270, loss: 0.12459, acc: 0.88241
global step 280, epoch: 1, batch: 280, loss: 0.19050, acc: 0.88382
global step 290, epoch: 1, batch: 290, loss: 0.13427, acc: 0.88556
global step 300, epoch: 1, batch: 300, loss: 0.11049, acc: 0.88698
[2024-07-26 17:28:26,747] [    INFO] - Configuration saved in /home/aistudio/checkpoint/config.json
eval loss: 0.19861, accu: 0.92250
[2024-07-26 17:28:30,408] [    INFO] - Model weights saved in /home/aistudio/checkpoint/model_state.pdparams
[2024-07-26 17:28:30,410] [    INFO] - tokenizer config file saved in /home/aistudio/checkpoint/tokenizer_config.json
[2024-07-26 17:28:30,412] [    INFO] - Special tokens file saved in /home/aistudio/checkpoint/special_tokens_map.json
('/home/aistudio/checkpoint/tokenizer_config.json',
 '/home/aistudio/checkpoint/special_tokens_map.json',
 '/home/aistudio/checkpoint/added_tokens.json')

Model prediction

Once the trained model has been saved, it can be used for prediction. As in the example below, define your own input data and call the `predict()` function for one-line prediction.
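`predict()` comes from this project's `utils.py`. At its core it tokenizes each text, runs the model in batches, applies softmax to the logits, and maps the argmax index through `label_map`. A pure-Python sketch of that final post-processing step (the logit values below are made up):

```python
import math

label_map = {0: 'negative', 1: 'positive'}

def logits_to_label(logits, label_map):
    # Softmax (stabilized by subtracting the max logit), then argmax -> label
    exps = [math.exp(x - max(logits)) for x in logits]
    probs = [e / sum(exps) for e in exps]
    return label_map[probs.index(max(probs))]

print(logits_to_label([2.1, -0.8], label_map))  # negative
print(logits_to_label([-1.3, 0.7], label_map))  # positive
```

Since argmax is unaffected by softmax, the probabilities matter only if you also want a confidence score alongside the label.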

In [14]

from utils import predict

data = [
    {"text":'这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般'},
    {"text":'怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片'},
    {"text":'作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。'},
]
label_map = {0: 'negative', 1: 'positive'}

results = predict(
    model, data, tokenizer, label_map, batch_size=batch_size)
for idx, text in enumerate(data):
    print('Data: {} \t Label: {}'.format(text, results[idx]))
Data: {'text': '这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般'} 	 Label: negative
Data: {'text': '怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片'} 	 Label: negative
Data: {'text': '作为老的四星酒店,房间依然很整洁,相当不错。机场接机服务很好,可以在车上办理入住手续,节省时间。'} 	 Label: positive

More pretrained models in PaddleNLP

Besides ERNIE, PaddleNLP also supports pretrained models such as BERT, RoBERTa, and ELECTRA. The table below summarizes the pretrained models currently supported by PaddleNLP. You can use them for tasks such as question answering, sequence classification, and token classification. We also provide 22 sets of pretrained weights, including 11 Chinese language models.

| Model | Tokenizer | Supported Task | Model Name |
| --- | --- | --- | --- |
| ERNIE | ErnieTokenizer, ErnieTinyTokenizer | ErnieModel, ErnieForQuestionAnswering, ErnieForSequenceClassification, ErnieForTokenClassification | ernie-1.0, ernie-tiny, ernie-2.0-en, ernie-2.0-large-en |
| BERT | BertTokenizer | BertModel, BertForQuestionAnswering, BertForSequenceClassification, BertForTokenClassification | bert-base-uncased, bert-large-uncased, bert-base-multilingual-uncased, bert-base-cased, bert-base-chinese, bert-base-multilingual-cased, bert-large-cased, bert-wwm-chinese, bert-wwm-ext-chinese |
| RoBERTa | RobertaTokenizer | RobertaModel, RobertaForQuestionAnswering, RobertaForSequenceClassification, RobertaForTokenClassification | roberta-wwm-ext, roberta-wwm-ext-large, rbt3, rbtl3 |
| ELECTRA | ElectraTokenizer | ElectraModel, ElectraForSequenceClassification, ElectraForTokenClassification | electra-small, electra-base, electra-large, chinese-electra-small, chinese-electra-base |

Note: the Chinese pretrained models include bert-base-chinese, bert-wwm-chinese, bert-wwm-ext-chinese, ernie-1.0, ernie-tiny, roberta-wwm-ext, roberta-wwm-ext-large, rbt3, rbtl3, chinese-electra-base, and chinese-electra-small.

More pretrained models: github.com/PaddlePaddl… For how to fine-tune the pretrained models on downstream tasks, see the examples.

More PaddleNLP tutorials