[NLP] Deploying a Document Classification Model
So, how do we actually use a model once training is finished?
This file is based on 'Do it! 자연어 처리' by 이기창 (Lee Gichang)! :)
Putting the trained model to work
In this exercise we build a web service from the trained document classification model.
The service receives a sentence and answers whether it is positive or negative. It tokenizes the sentence into model inputs, feeds them to the model to compute the probability that the sentence is positive and the probability that it is negative, and then responds after a little post-processing.
A web service is a software system built so that computers can interact with each other over a network. In this notebook we build a web service that receives a sentence sent by a remote user, produces a response saying whether that sentence is positive or negative, and delivers the response back to the remote user.
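To make the request/response idea concrete, here is a minimal sketch that uses only the Python standard library's WSGI conventions. It is not how this notebook serves the model (that is done later through Flask). The names `stub_inference_fn`, `make_app`, and `call_app` are hypothetical, and the stub "model" simply looks for one keyword.

```python
import json
from urllib.parse import parse_qs


def stub_inference_fn(sentence):
    # Hypothetical stand-in for the real model: it only checks for one keyword.
    label = "positive" if "good" in sentence.lower() else "negative"
    return {"sentence": sentence, "prediction": label}


def make_app(inference_fn):
    """Wrap an inference function in a tiny WSGI application."""
    def app(environ, start_response):
        # Read the sentence from the query string, e.g. ?sentence=...
        query = parse_qs(environ.get("QUERY_STRING", ""))
        sentence = query.get("sentence", [""])[0]
        body = json.dumps(inference_fn(sentence), ensure_ascii=False).encode("utf-8")
        start_response("200 OK", [
            ("Content-Type", "application/json; charset=utf-8"),
            ("Content-Length", str(len(body))),
        ])
        return [body]
    return app


def call_app(app, query_string):
    """Call the WSGI app in-process, without opening a network socket."""
    captured = {}

    def start_response(status, headers):
        captured["status"] = status

    chunks = app({"REQUEST_METHOD": "GET", "QUERY_STRING": query_string}, start_response)
    return captured["status"], json.loads(b"".join(chunks).decode("utf-8"))


status, payload = call_app(make_app(stub_inference_fn), "sentence=This+movie+is+good")
print(status, payload)
```

The remote user sends a sentence, the server runs an inference function on it, and the result travels back as the response. The rest of the notebook replaces the stub with a real BERT classifier.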
1. Setting Up the Environment
Installing dependency packages
Install the required packages with the pip command.
code 2-0
!pip install ratsnlp
▶Code output
Requirement already satisfied: ratsnlp in /usr/local/lib/python3.7/dist-packages (1.0.1)
Requirement already satisfied: transformers==4.10.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (4.10.0)
Requirement already satisfied: Korpora>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (0.2.0)
Requirement already satisfied: pytorch-lightning==1.3.4 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.3.4)
Requirement already satisfied: flask-ngrok>=0.0.25 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (0.0.25)
Requirement already satisfied: torch>=1.9.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.10.0+cu111)
Requirement already satisfied: flask>=1.1.4 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.1.4)
Requirement already satisfied: flask-cors>=3.0.10 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (3.0.10)
Requirement already satisfied: fsspec[http]>=2021.4.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (2022.2.0)
Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (21.3)
Requirement already satisfied: pyDeprecate==0.3.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.3.0)
Requirement already satisfied: tensorboard!=2.5.0,>=2.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (2.8.0)
Requirement already satisfied: PyYAML<=5.4.1,>=5.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (5.4.1)
Requirement already satisfied: torchmetrics>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.7.2)
Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (4.62.3)
Requirement already satisfied: future>=0.17.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.18.2)
Requirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (1.21.5)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (2.23.0)
Requirement already satisfied: huggingface-hub>=0.0.12 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.4.0)
Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.0.47)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (4.11.1)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (2019.12.20)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (3.6.0)
Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.10.3)
Requirement already satisfied: click<8.0,>=5.1 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (7.1.2)
Requirement already satisfied: itsdangerous<2.0,>=0.24 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (1.1.0)
Requirement already satisfied: Jinja2<3.0,>=2.10.1 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (2.11.3)
Requirement already satisfied: Werkzeug<2.0,>=0.15 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (1.0.1)
Requirement already satisfied: Six in /usr/local/lib/python3.7/dist-packages (from flask-cors>=3.0.10->ratsnlp) (1.15.0)
Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (3.8.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.12->transformers==4.10.0->ratsnlp) (3.10.0.2)
Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from Jinja2<3.0,>=2.10.1->flask>=1.1.4->ratsnlp) (2.0.1)
Requirement already satisfied: dataclasses>=0.6 in /usr/local/lib/python3.7/dist-packages (from Korpora>=0.2.0->ratsnlp) (0.6)
Requirement already satisfied: xlrd>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from Korpora>=0.2.0->ratsnlp) (2.0.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->pytorch-lightning==1.3.4->ratsnlp) (3.0.7)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (2021.10.8)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (3.0.4)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (2.10)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (1.24.3)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.4.6)
Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.35.0)
Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.43.0)
Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.17.3)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.37.1)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (57.4.0)
Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.0.0)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.3.6)
Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.6.1)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.8.1)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (4.2.4)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (4.8)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.2.8)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.3.1)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers==4.10.0->ratsnlp) (3.7.0)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.4.8)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.2.0)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (21.4.0)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.2.0)
Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.7.2)
Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (0.13.0)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (6.0.2)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.3.0)
Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (2.0.12)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (4.0.2)
Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.10.0->ratsnlp) (1.1.0)
Connecting Google Drive
The model checkpoint is stored on Google Drive, so connect the Colab notebook to your own Google Drive.
code 2-1
from google.colab import drive
drive.mount('/gdrive', force_remount=True)
▶Code output
Mounted at /gdrive
Inference settings
Set the various arguments, such as the model hyperparameters and the storage location.
code 2-2
from ratsnlp.nlpbook.classification import ClassificationDeployArguments
args = ClassificationDeployArguments(
    pretrained_model_name="beomi/kcbert-base",
    downstream_model_dir="/gdrive/My Drive/nlpbook/checkpoint-doccls",
    max_seq_length=128,
)
▶Code output
downstream_model_checkpoint_fpath: /gdrive/My Drive/nlpbook/checkpoint-doccls/epoch=1-val_loss=0.28.ckpt
The role of each argument is as follows.
- pretrained_model_name: the pretrained_model_name used in training_section.ipynb (the model must be registered with the Hugging Face library).
- downstream_model_dir: the location where the checkpoint of the model fine-tuned in training_section.ipynb is stored (it must contain at least one file with the ckpt extension).
- max_seq_length: the maximum input sentence length, counted in tokens. Defaults to 128 if nothing is given.
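The output above shows that a single checkpoint file, `epoch=1-val_loss=0.28.ckpt`, was resolved from `downstream_model_dir`. The notebook does not show how ratsnlp chooses it, so the following is only a sketch of one plausible rule: pick the file whose name carries the lowest `val_loss`. The helper `pick_checkpoint` and the sample file names are hypothetical.

```python
import os
import re
import tempfile


def pick_checkpoint(model_dir):
    """Return the .ckpt path with the lowest val_loss encoded in its file name."""
    pattern = re.compile(r"val_loss=(\d+(?:\.\d+)?)\.ckpt$")
    candidates = []
    for name in os.listdir(model_dir):
        match = pattern.search(name)
        if match:
            candidates.append((float(match.group(1)), name))
    if not candidates:
        raise FileNotFoundError(f"no .ckpt file with a val_loss in its name under {model_dir}")
    _, best_name = min(candidates)
    return os.path.join(model_dir, best_name)


# Demonstrate on a throwaway directory with made-up, empty checkpoint files.
with tempfile.TemporaryDirectory() as tmp_dir:
    for name in ["epoch=0-val_loss=0.31.ckpt", "epoch=1-val_loss=0.28.ckpt", "notes.txt"]:
        open(os.path.join(tmp_dir, name), "w").close()
    selected = os.path.basename(pick_checkpoint(tmp_dir))
    print(selected)
```

Whatever the library's actual rule, the requirement stated in the list holds: without at least one `.ckpt` file in the directory there is nothing to load.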
2. Loading the Tokenizer and the Model
Loading the tokenizer
Run code 2-3 to initialize the tokenizer.
code 2-3
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(
    args.pretrained_model_name,
    do_lower_case=False,
)
▶Code output
Downloading: 0%| | 0.00/250k [00:00<?, ?B/s]
Downloading: 0%| | 0.00/49.0 [00:00<?, ?B/s]
Downloading: 0%| | 0.00/619 [00:00<?, ?B/s]
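The tokenizer loaded here is later called with `max_length=args.max_seq_length`, `padding="max_length"`, and `truncation=True`. As a rough illustration of what those options mean, here is a toy encoder in plain Python. The vocabulary, the special-token ids, and the function `encode` are all made up for this sketch and do not match the real kcbert vocabulary or the `BertTokenizer` internals.

```python
def encode(tokens, vocab, max_seq_length, pad_id=0, unk_id=1, cls_id=2, sep_id=3):
    """Toy encoder: map tokens to ids, add [CLS]/[SEP], truncate, then pad."""
    ids = [vocab.get(tok, unk_id) for tok in tokens]
    ids = ids[: max_seq_length - 2]            # truncation: leave room for two special tokens
    input_ids = [cls_id] + ids + [sep_id]
    attention_mask = [1] * len(input_ids)      # 1 marks real tokens
    padding = max_seq_length - len(input_ids)
    input_ids += [pad_id] * padding            # padding up to max_seq_length
    attention_mask += [0] * padding            # 0 marks padding positions
    token_type_ids = [0] * max_seq_length      # single-sentence input: all zeros
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }


toy_vocab = {"this": 10, "movie": 11, "is": 12, "fun": 13}  # hypothetical ids
encoded = encode("this movie is really fun".split(), toy_vocab, max_seq_length=8)
print(encoded)
```

Every encoded sentence ends up with exactly `max_seq_length` positions, which is why the model can always be fed fixed-size tensors.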
Loading the checkpoint
code 2-4 reads the checkpoint of the model fine-tuned in training_section.ipynb.
code 2-4
import torch
fine_tuned_model_ckpt = torch.load(
    args.downstream_model_checkpoint_fpath,
    map_location=torch.device("cpu"),
)
Loading the BERT configuration and initializing the BERT model
code 2-5 reads the configuration of the model that corresponds to the pretrained_model_name used for fine-tuning in training_section.ipynb.
Running code 2-6 then initializes a BERT model according to that configuration.
code 2-5
from transformers import BertConfig
pretrained_model_config = BertConfig.from_pretrained(
    args.pretrained_model_name,
    num_labels=fine_tuned_model_ckpt["state_dict"]["model.classifier.bias"].shape.numel(),
)
code 2-6
from transformers import BertForSequenceClassification
model = BertForSequenceClassification(pretrained_model_config)
Injecting the checkpoint
code 2-7 loads the checkpoint (fine_tuned_model_ckpt) into the initialized BERT model.
code 2-7
model.load_state_dict({k.replace("model.", ""): v for k, v in fine_tuned_model_ckpt['state_dict'].items()})
▶Code output
<All keys matched successfully>
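The key renaming in code 2-7 is needed because the checkpoint keys start with `model.` (for example `model.classifier.bias`, as used in code 2-5), while the freshly initialized `BertForSequenceClassification` expects keys without that prefix. The sketch below shows the renaming on plain dictionaries. The keys and placeholder values are illustrative, and `strip_prefix` is a hypothetical helper that removes only a leading prefix, whereas `str.replace` removes every occurrence.

```python
def strip_prefix(state_dict, prefix="model."):
    """Remove a leading prefix from every key, leaving the rest of the key intact."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Hypothetical keys; the strings stand in for weight tensors.
lightning_style = {
    "model.bert.embeddings.word_embeddings.weight": "W_emb",
    "model.classifier.weight": "W_cls",
    "model.classifier.bias": "b_cls",
}
renamed = strip_prefix(lightning_style)
print(sorted(renamed))

# A made-up key where "model." also appears in the middle of the name.
edge_case = {"model.submodel.weight": "W"}
print(strip_prefix(edge_case))
```

For BERT-style keys, where `model.` appears only at the start, both approaches give the same result. The prefix-only version is just a more defensive variant.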
Switching to evaluation mode
Running code 2-8 then switches the model to evaluation mode, which disables techniques that are used only during training, such as dropout.
code 2-8
model.eval()
▶Code output
BertForSequenceClassification(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30000, 768, padding_idx=0)
(position_embeddings): Embedding(300, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(1): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(2): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(3): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(4): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(5): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(6): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(7): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(8): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(9): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(10): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(11): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
(dropout): Dropout(p=0.1, inplace=False)
(classifier): Linear(in_features=768, out_features=2, bias=True)
)
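The printed model contains many `Dropout(p=0.1)` layers, and evaluation mode is what keeps them from perturbing predictions. The following pure-Python sketch illustrates the idea with a toy dropout layer that has a training/evaluation switch. `TinyDropout` is a hypothetical class written for this note. It mimics the concept behind `torch.nn.Dropout`, not its implementation.

```python
import random


class TinyDropout:
    """Pure-Python inverted dropout with a train/eval switch."""

    def __init__(self, p=0.1, seed=0):
        self.p = p
        self.training = True
        self._rng = random.Random(seed)

    def eval(self):
        # Switch to evaluation mode and return self, so calls can be chained.
        self.training = False
        return self

    def __call__(self, values):
        if not self.training:
            return list(values)  # evaluation mode: pass values through unchanged
        keep = 1.0 - self.p
        # Training mode: zero each value with probability p, rescale the survivors.
        return [v / keep if self._rng.random() < keep else 0.0 for v in values]


layer = TinyDropout(p=0.5)
train_out = layer([1.0, 1.0, 1.0, 1.0])
eval_out = layer.eval()([1.0, 1.0, 1.0, 1.0])
print(train_out, eval_out)
```

In training mode the output is random, so the same sentence could receive different scores on each request. In evaluation mode the layer is an identity function and predictions are deterministic.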
3. Producing and Post-processing the Model Output
code 2-9 is the function that defines the inference process. It tokenizes the sentence to produce input_ids, attention_mask, and token_type_ids. These inputs are converted to PyTorch tensors and fed to the model. The model output (outputs.logits) is in logit form, before the softmax function is applied; applying softmax turns it into '[probability of negative, probability of positive]'.
Finally, the output is lightly post-processed so that pred becomes 'negative' when the largest predicted probability is at the negative position, and 'positive' otherwise.
code 2-9
def inference_fn(sentence):
    # Tokenize the sentence to build input_ids, attention_mask, token_type_ids
    inputs = tokenizer(
        [sentence],
        max_length=args.max_seq_length,
        padding="max_length",
        truncation=True,
    )
    with torch.no_grad():
        # Run the model; the dict comprehension converts inputs to PyTorch tensors
        outputs = model(**{k: torch.tensor(v) for k, v in inputs.items()})
        # Apply softmax to the logits
        prob = outputs.logits.softmax(dim=1)
        # Round the positive/negative probabilities to 4 decimal places
        positive_prob = round(prob[0][1].item(), 4)
        negative_prob = round(prob[0][0].item(), 4)
        # Build pred from the position of the largest predicted probability
        pred = "positive" if torch.argmax(prob) == 1 else "negative"
    return {
        'sentence': sentence,
        'prediction': pred,
        'positive_data': f"positive {positive_prob}",
        'negative_data': f"negative {negative_prob}",
        'positive_width': f"{positive_prob * 100}%",
        'negative_width': f"{negative_prob * 100}%",
    }
In code 2-9, positive_width and negative_width only control the length of the positive/negative bars on the web page, so you need not pay much attention to them.
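The softmax and post-processing steps of code 2-9 can also be followed without PyTorch. The sketch below reproduces them in plain Python for a single pair of logits ordered as [negative, positive]. The functions `softmax` and `postprocess` and the example logits are hypothetical, chosen only to illustrate the arithmetic.

```python
import math


def softmax(logits):
    """Turn raw logits into probabilities that sum to 1."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def postprocess(logits):
    """Mirror the post-processing of code 2-9 for logits ordered [negative, positive]."""
    negative_prob, positive_prob = softmax(logits)
    # A tie falls back to "negative", like argmax returning index 0.
    pred = "positive" if positive_prob > negative_prob else "negative"
    return {
        "prediction": pred,
        "positive_prob": round(positive_prob, 4),
        "negative_prob": round(negative_prob, 4),
    }


result = postprocess([-1.2, 2.3])  # made-up logits
print(result)
```

Because softmax preserves the ordering of the logits, the predicted label could be read off the logits directly. The probabilities are computed mainly so the service can report how confident the model is.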
4. Starting the Web Service
Preparing to build the web service
ngrok is a tool that makes a web service running locally in Colab securely accessible from outside. To run ngrok, sign up, log in, and then visit the ngrok site to look up your authentication token (authtoken).
For example, if your authtoken is test123, run the following.
**!mkdir /root/.ngrok2 && echo "authtoken: test123" > /root/.ngrok2/ngrok.yml**
code 2-10
!mkdir /root/.ngrok2 && echo "authtoken: (fill in your authtoken here)" > /root/.ngrok2/ngrok.yml
Starting the web service
Running code 2-11 with the inference function inference_fn defined in code 2-9 launches the web service with the help of a Python library called Flask.
code 2-11
from ratsnlp.nlpbook.classification import get_web_service_app
app = get_web_service_app(inference_fn)
app.run()
The web site looks like this.

