[NLP] Putting a Sentence-Pair Classification Model into Service
So how do we actually use a model that has finished training?
This post is based on Gichang Lee's book 'Do it! 자연어 처리' (Do it! Natural Language Processing). Just letting you know up front! :)
Deploying the trained model
In this exercise we run inference with the trained sentence-pair classification model. Figure 1 shows the overall design of the web service we build.

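Figure 1. Sentence-pair classification web service
The web service takes a premise sentence and a hypothesis sentence and returns an answer. It tokenizes and indexes the premise and hypothesis, turns them into model inputs, and feeds them to the model to compute
[probability that the hypothesis is true given the premise (entailment), probability that it is false given the premise (contradiction), probability that it is neutral given the premise (neutral)]
It then applies a little post-processing and returns the response.
Building a web service that checks a hypothesis against a premise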
1. Setting up the environment
Installing dependencies
Install the required packages with pip.
code 4-0
!pip install ratsnlp
▶Code output
Requirement already satisfied: ratsnlp in /usr/local/lib/python3.7/dist-packages (1.0.1)
Requirement already satisfied: transformers==4.10.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (4.10.0)
Requirement already satisfied: Korpora>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (0.2.0)
Requirement already satisfied: flask>=1.1.4 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.1.4)
Requirement already satisfied: flask-cors>=3.0.10 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (3.0.10)
Requirement already satisfied: torch>=1.9.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.10.0+cu111)
Requirement already satisfied: flask-ngrok>=0.0.25 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (0.0.25)
Requirement already satisfied: pytorch-lightning==1.3.4 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.3.4)
Requirement already satisfied: future>=0.17.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.18.2)
Requirement already satisfied: tensorboard!=2.5.0,>=2.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (2.8.0)
Requirement already satisfied: fsspec[http]>=2021.4.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (2022.2.0)
Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (21.3)
Requirement already satisfied: torchmetrics>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.7.2)
Requirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (1.21.5)
Requirement already satisfied: pyDeprecate==0.3.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.3.0)
Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (4.63.0)
Requirement already satisfied: PyYAML<=5.4.1,>=5.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (5.4.1)
Requirement already satisfied: huggingface-hub>=0.0.12 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.4.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (3.6.0)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (4.11.2)
Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.10.3)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (2.23.0)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (2019.12.20)
Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.0.47)
Requirement already satisfied: Jinja2<3.0,>=2.10.1 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (2.11.3)
Requirement already satisfied: Werkzeug<2.0,>=0.15 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (1.0.1)
Requirement already satisfied: itsdangerous<2.0,>=0.24 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (1.1.0)
Requirement already satisfied: click<8.0,>=5.1 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (7.1.2)
Requirement already satisfied: Six in /usr/local/lib/python3.7/dist-packages (from flask-cors>=3.0.10->ratsnlp) (1.15.0)
Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (3.8.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.12->transformers==4.10.0->ratsnlp) (3.10.0.2)
Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from Jinja2<3.0,>=2.10.1->flask>=1.1.4->ratsnlp) (2.0.1)
Requirement already satisfied: dataclasses>=0.6 in /usr/local/lib/python3.7/dist-packages (from Korpora>=0.2.0->ratsnlp) (0.6)
Requirement already satisfied: xlrd>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from Korpora>=0.2.0->ratsnlp) (2.0.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->pytorch-lightning==1.3.4->ratsnlp) (3.0.7)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (2.10)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (1.24.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (2021.10.8)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (3.0.4)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.4.6)
Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.6.1)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (57.4.0)
Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.44.0)
Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.35.0)
Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.17.3)
Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.0.0)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.8.1)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.3.6)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.37.1)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (4.8)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.2.8)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (4.2.4)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.3.1)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers==4.10.0->ratsnlp) (3.7.0)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.4.8)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.2.0)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (21.4.0)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.3.0)
Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.7.2)
Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (2.0.12)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.2.0)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (4.0.2)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (6.0.2)
Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (0.13.0)
Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.10.0->ratsnlp) (1.1.0)
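The pip log shows ratsnlp pinning transformers==4.10.0 and pytorch-lightning==1.3.4. If something breaks later, it helps to confirm which versions are actually installed. Below is a minimal standard-library sketch; the helper name installed_versions is made up for illustration. importlib.metadata needs Python 3.8+, so on Colab's Python 3.7 it falls back to the importlib-metadata backport, which the log above shows is installed.

```python
try:
    from importlib import metadata  # Python 3.8+
except ImportError:  # Python 3.7: use the importlib-metadata backport
    import importlib_metadata as metadata


def installed_versions(packages):
    """Return {package name: installed version, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print(installed_versions(["ratsnlp", "transformers", "pytorch-lightning", "torch"]))
```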
Connecting Google Drive
The trained model's checkpoint was saved to Google Drive, so run code 4-1 to connect the Colab notebook to your Google Drive.
code 4-1
from google.colab import drive
drive.mount('/gdrive', force_remount=True)
▶Code output
Mounted at /gdrive
Inference settings
Set up the inference arguments. pretrained_model_name, max_seq_length, and downstream_model_dir must all match the values used during training.
code 4-2
from ratsnlp.nlpbook.classification import ClassificationDeployArguments
args = ClassificationDeployArguments(
    pretrained_model_name="beomi/kcbert-base",
    downstream_model_dir="/gdrive/My Drive/nlpbook/checkpoint-paircls",
    max_seq_length=64,
)
▶Code output
downstream_model_checkpoint_fpath: /gdrive/My Drive/nlpbook/checkpoint-paircls/epoch=1-val_loss=0.82.ckpt
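The output shows that ClassificationDeployArguments resolved downstream_model_dir to one file, epoch=1-val_loss=0.82.ckpt. My understanding is that ratsnlp picks the checkpoint with the lowest validation loss encoded in the filename, but treat that as an assumption. The sketch below reproduces that rule with a hypothetical pick_best_checkpoint helper, which is handy for checking which file will be loaded when the directory holds several checkpoints.

```python
import re
from pathlib import Path

# Matches the val_loss part of names like "epoch=1-val_loss=0.82.ckpt"
_VAL_LOSS = re.compile(r"val_loss=(\d+(?:\.\d+)?)")


def pick_best_checkpoint(fnames):
    """Return the checkpoint path whose filename has the lowest val_loss.

    Names without a val_loss field are ignored (hypothetical helper, not ratsnlp's code).
    """
    scored = []
    for name in fnames:
        match = _VAL_LOSS.search(Path(name).stem)
        if match:
            scored.append((float(match.group(1)), name))
    if not scored:
        raise FileNotFoundError("no checkpoint with a val_loss in its filename")
    return min(scored)[1]
```

In Colab you could call pick_best_checkpoint(str(p) for p in Path(args.downstream_model_dir).glob("*.ckpt")) and compare the result with the downstream_model_checkpoint_fpath printed above.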
2. Loading the tokenizer and model
Loading the tokenizer
Run code 4-3 to initialize the tokenizer.
code 4-3
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(
    args.pretrained_model_name,
    do_lower_case=False,
)
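Why do_lower_case=False? The source does not say, but one plausible reason: in Hugging Face's BERT basic tokenizer, lowercasing also triggers accent stripping (unless strip_accents is set to False), which NFD-normalizes text and drops combining marks. NFD splits every Hangul syllable into its jamo, so Korean tokens would no longer match a syllable-based vocabulary. The standard-library sketch below only mimics that stripping step (bert_style_strip_accents is a hypothetical name) to show the effect.

```python
import unicodedata


def bert_style_strip_accents(text):
    """Mimic BERT-style accent stripping: NFD-decompose, then drop combining marks (Mn)."""
    decomposed = unicodedata.normalize("NFD", text)
    return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")


original = "참"
stripped = bert_style_strip_accents(original)
print(len(original), len(stripped))  # 1 3: the syllable became three jamo
print(bert_style_strip_accents("café"))  # cafe: a real accent is removed
```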
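Loading the checkpoint
code 4-4 reads in the checkpoint of the model fine-tuned in pair_classification_train.ipynb. map_location=torch.device("cpu") loads the tensors onto the CPU, so this works even in a runtime without a GPU.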
code 4-4
import torch
fine_tuned_model_ckpt = torch.load(
    args.downstream_model_checkpoint_fpath,
    map_location=torch.device("cpu"),
)
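Loading the BERT config and initializing the BERT model
code 4-5 reads the configuration of the model that corresponds to the pretrained_model_name used for fine-tuning in pair_classification_train.ipynb, and code 4-6 initializes a BERT model with that configuration. num_labels is taken from the size of the classifier bias stored in the checkpoint (3 here: entailment, contradiction, neutral), so the classification head matches the fine-tuned weights.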
code 4-5
from transformers import BertConfig
pretrained_model_config = BertConfig.from_pretrained(
    args.pretrained_model_name,
    num_labels=fine_tuned_model_ckpt['state_dict']['model.classifier.bias'].shape.numel(),
)
code 4-6
from transformers import BertForSequenceClassification
model = BertForSequenceClassification(pretrained_model_config)
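Injecting the checkpoint
code 4-7 loads the checkpoint from code 4-4 into the initialized BERT model. The keys in the checkpoint's state_dict carry a "model." prefix, which is removed so they match the parameter names of BertForSequenceClassification.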
code 4-7
model.load_state_dict({k.replace("model.",""): v for k, v in fine_tuned_model_ckpt['state_dict'].items()})
▶Code output
<All keys matched successfully>
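code 4-7 uses k.replace("model.", ""), which removes every occurrence of "model." in a key, not just the leading one. For BERT's parameter names that is harmless, but a stricter variant only strips a leading prefix. A small sketch (strip_prefix is a hypothetical helper name):

```python
def strip_prefix(state_dict, prefix="model."):
    """Remove a leading prefix from each key; keys without it are kept unchanged."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
```

Usage would be model.load_state_dict(strip_prefix(fine_tuned_model_ckpt['state_dict'])). A key such as "submodel.x" shows the difference: str.replace would turn it into "subx", while strip_prefix leaves it alone.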
Switching to evaluation mode
Next, run code 4-8 to switch the model to evaluation mode. This disables techniques that are used only during training, such as dropout.
code 4-8
model.eval()
▶Code output
BertForSequenceClassification(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30000, 768, padding_idx=0)
(position_embeddings): Embedding(300, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(1): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(2): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(3): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(4): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(5): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(6): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(7): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(8): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(9): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(10): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(11): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
(dropout): Dropout(p=0.1, inplace=False)
(classifier): Linear(in_features=768, out_features=3, bias=True)
)
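The printed model contains several Dropout(p=0.1) layers. PyTorch's dropout is "inverted" dropout: in training mode it zeroes each element with probability p and scales the survivors by 1/(1-p), and in eval mode it is the identity. The plain-Python sketch below (not PyTorch code) illustrates why model.eval() matters for inference: without it, predictions would be random.

```python
import random


def dropout(values, p, training, rng=random):
    """Inverted dropout on a list of floats.

    Training: zero each value with probability p and scale survivors by 1/(1-p).
    Eval: return the input unchanged.
    """
    if not training or p == 0.0:
        return list(values)
    keep = 1.0 - p
    return [v / keep if rng.random() >= p else 0.0 for v in values]


xs = [1.0, 2.0, 3.0]
print(dropout(xs, 0.1, training=False))  # [1.0, 2.0, 3.0]
print(dropout(xs, 0.5, training=True, rng=random.Random(0)))  # random zeros, survivors doubled
```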
3. Computing and post-processing the model output
code 4-9 defines the inference function. It takes a premise and a hypothesis, tokenizes and indexes them, and builds input_ids, attention_mask, and token_type_ids. These inputs are converted to PyTorch tensors and then fed to the model.
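For a sentence pair, BERT-style inputs follow the layout [CLS] premise [SEP] hypothesis [SEP]: token_type_ids mark the first segment with 0 and the second with 1, and attention_mask is 1 for real tokens and 0 for padding. The sketch below builds these three lists from already-indexed token IDs. encode_pair is a hypothetical helper, the IDs in the test are made up, and its truncation only approximates the tokenizer's truncation=True behaviour (trim the longer sequence one token at a time).

```python
def encode_pair(premise_ids, hypothesis_ids, max_length, cls_id, sep_id, pad_id=0):
    """Build input_ids, attention_mask, token_type_ids for a sentence pair."""
    if max_length < 3:
        raise ValueError("max_length must leave room for [CLS] and two [SEP] tokens")
    premise_ids, hypothesis_ids = list(premise_ids), list(hypothesis_ids)
    # Trim the longer sequence until both fit alongside the three special tokens
    while len(premise_ids) + len(hypothesis_ids) + 3 > max_length:
        if len(premise_ids) > len(hypothesis_ids):
            premise_ids.pop()
        else:
            hypothesis_ids.pop()
    input_ids = [cls_id] + premise_ids + [sep_id] + hypothesis_ids + [sep_id]
    # Segment 0 covers [CLS] premise [SEP]; segment 1 covers hypothesis [SEP]
    token_type_ids = [0] * (len(premise_ids) + 2) + [1] * (len(hypothesis_ids) + 1)
    padding = max_length - len(input_ids)
    return {
        "input_ids": input_ids + [pad_id] * padding,
        "attention_mask": [1] * len(input_ids) + [0] * padding,
        "token_type_ids": token_type_ids + [0] * padding,
    }
```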
code 4-9
def inference_fn(premise, hypothesis):
    # Turn the premise and hypothesis into model inputs
    inputs = tokenizer(
        [(premise, hypothesis)],
        max_length=args.max_seq_length,
        padding="max_length",
        truncation=True,
    )
    with torch.no_grad():
        # Run the model; the dict comprehension converts each value in inputs to a PyTorch tensor
        outputs = model(**{k: torch.tensor(v) for k, v in inputs.items()})
        # Apply softmax to the logits
        prob = outputs.logits.softmax(dim=1)
        # Round the probabilities to two decimal places
        entailment_prob = round(prob[0][0].item(), 2)
        contradiction_prob = round(prob[0][1].item(), 2)
        neutral_prob = round(prob[0][2].item(), 2)
        # Build pred from the position of the highest probability
        if torch.argmax(prob) == 0:
            pred = "참 (entailment)"
        elif torch.argmax(prob) == 1:
            pred = "거짓 (contradiction)"
        else:
            pred = "중립 (neutral)"
    return {
        'premise': premise,
        'hypothesis': hypothesis,
        'prediction': pred,
        'entailment_data': f"참 {entailment_prob}",
        'contradiction_data': f"거짓 {contradiction_prob}",
        'neutral_data': f"중립 {neutral_prob}",
        'entailment_width': f"{entailment_prob * 100}%",
        'contradiction_width': f"{contradiction_prob * 100}%",
        'neutral_width': f"{neutral_prob * 100}%"
    }
The model output (outputs.logits) holds logits, the values before softmax is applied. Applying softmax turns them into probabilities. A little post-processing then sets pred: if the highest probability is at the entailment position (0), the result is '참 (entailment)'; at the contradiction position (1), '거짓 (contradiction)'; and at the neutral position (2), '중립 (neutral)'.
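The post-processing can be checked without a model. The plain-Python sketch below applies softmax to three logits (subtracting the maximum first for numerical stability), rounds to two decimals, and picks the label at the argmax, using the same index order as code 4-9 (0 entailment, 1 contradiction, 2 neutral). postprocess is a hypothetical name and the labels are English for brevity.

```python
import math

LABELS = ["entailment", "contradiction", "neutral"]  # same index order as code 4-9


def postprocess(logits):
    """Turn three logits into (predicted label, probabilities rounded to 2 decimals)."""
    shift = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)  # argmax, first index on ties
    return LABELS[best], [round(p, 2) for p in probs]


print(postprocess([2.0, 0.5, -1.0]))
```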
In code 4-9, entailment_width, contradiction_width, and neutral_width only set the lengths of the entailment, contradiction, and neutral bars on the web page, so you don't need to worry much about them.
4. Starting the web service
Preparing the web service
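One small caveat with those width strings: multiplying a rounded probability by 100 can expose binary floating-point noise, because 0.57 cannot be stored exactly. Browsers accept such values as CSS widths, so this is cosmetic, but formatting with a fixed precision keeps the strings clean. bar_width is a hypothetical helper shown as an alternative, not the book's code.

```python
def bar_width(prob):
    """Format a probability in [0, 1] as a whole-number CSS percentage."""
    return f"{prob * 100:.0f}%"


print(f"{0.57 * 100}%")  # the f-string style in code 4-9 prints 56.99999999999999%
print(bar_width(0.57))   # 57%
```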
ngrok is a tool that lets a web service running locally in Colab be reached securely from outside. To use ngrok, sign up, log in, and look up your authentication token (authtoken) on the authtoken page of the ngrok dashboard.
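For example, if your authtoken is test123, you would run:
!mkdir -p /root/.ngrok2 && echo "authtoken: test123" > /root/.ngrok2/ngrok.yml
code 4-10
!mkdir -p /root/.ngrok2 && echo "authtoken: (fill in your authtoken here)" > /root/.ngrok2/ngrok.yml
▶Code output
mkdir: cannot create directory ‘/root/.ngrok2’: File exists
This output comes from running the command without -p in a session where the directory already existed. Because && runs echo only if mkdir succeeds, that run did not write the authtoken at all. mkdir -p succeeds even when the directory exists, so the token is always written.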
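The same config file can also be written from Python. With mkdir(parents=True, exist_ok=True), rerunning the cell never fails with the "File exists" error seen above, and the file is overwritten just as the shell redirect > does. write_ngrok_config is a hypothetical helper; the /root/.ngrok2/ngrok.yml path and the "authtoken: ..." line are taken from code 4-10.

```python
from pathlib import Path


def write_ngrok_config(authtoken, config_dir=Path("/root/.ngrok2")):
    """Write ngrok.yml with the given authtoken, creating the directory if it is missing."""
    config_dir = Path(config_dir)
    config_dir.mkdir(parents=True, exist_ok=True)  # no error if it already exists
    config_path = config_dir / "ngrok.yml"
    config_path.write_text(f"authtoken: {authtoken}\n")  # overwrite, like `>` in the shell
    return config_path
```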
Starting the web service
With the inference function inference_fn defined in code 4-9, run code 4-11 to bring up the web service. It is an app built with Python's Flask.
code 4-11
from ratsnlp.nlpbook.paircls import get_web_service_app
app = get_web_service_app(inference_fn)
app.run()
▶Code output
 * Serving Flask app "ratsnlp.nlpbook.paircls.deploy" (lazy loading)
 * Environment: production
   WARNING: This is a development server. Do not use it in a production deployment.
   Use a production WSGI server instead.
 * Debug mode: off
 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
 * Running on http://0163-35-238-180-140.ngrok.io
 * Traffic stats available on http://127.0.0.1:4040
127.0.0.1 - - [04/Mar/2022 09:14:48] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [04/Mar/2022 09:14:49] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [04/Mar/2022 09:14:49] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [04/Mar/2022 09:15:01] "POST /api HTTP/1.1" 200 -
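The log shows the page being loaded (GET /) and a form submission hitting POST /api. ratsnlp's get_web_service_app wraps inference_fn in a Flask app; its actual code is not reproduced here. The sketch below is a hypothetical standard-library WSGI equivalent of the /api route only, assuming the request body is JSON with premise and hypothesis fields, to show how an inference function plugs into a web endpoint.

```python
import json


def make_api_app(inference_fn):
    """Return a WSGI app where POST /api calls inference_fn(premise, hypothesis).

    The JSON field names "premise" and "hypothesis" are assumptions for this sketch.
    """
    def app(environ, start_response):
        if environ.get("PATH_INFO") != "/api" or environ.get("REQUEST_METHOD") != "POST":
            start_response("404 Not Found", [("Content-Type", "text/plain")])
            return [b"not found"]
        length = int(environ.get("CONTENT_LENGTH") or 0)
        payload = json.loads(environ["wsgi.input"].read(length) or b"{}")
        result = inference_fn(payload.get("premise", ""), payload.get("hypothesis", ""))
        body = json.dumps(result, ensure_ascii=False).encode("utf-8")
        start_response("200 OK", [
            ("Content-Type", "application/json; charset=utf-8"),
            ("Content-Length", str(len(body))),
        ])
        return [body]

    return app
```

Locally, wsgiref.simple_server.make_server("127.0.0.1", 5000, make_api_app(inference_fn)).serve_forever() would serve it; the book's version uses Flask plus ngrok instead.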
The resulting web page looks like this:

