[NLP] Deploying a Document Classification Model


So, how do we actually use a model once it has finished training?

This notebook is based on 'Do it! 자연어 처리' by Gichang Lee (이기창)! :)

Deploying the Trained Model

In this exercise we build a web service on top of the trained document classification model.

The service takes a sentence and answers whether it is positive or negative. It tokenizes the sentence, turns the tokens into model inputs, and feeds them to the model to compute [probability that the sentence is negative, probability that the sentence is positive]. A small post-processing step then turns that result into the response.

A web service is a software system designed to let computers interact over a network. The one built in this note receives a sentence from a remote user, decides whether the sentence is positive or negative, and sends that answer back to the remote user.

1. ํ™˜๊ฒฝ ์„ค์ •ํ•˜๊ธฐ

์˜์กด์„ฑ ํŒจํ‚ค์ง€ ์„ค์น˜

pip ๋ช…๋ น์–ด๋ฅผ ํ†ตํ•ด ์˜์กด์„ฑ ์žˆ๋Š” ํŒจํ‚ค์ง€๋ฅผ ์„ค์น˜ํ•œ๋‹ค.

code 2-0

!pip install ratsnlp
▶Code output
Requirement already satisfied: ratsnlp in /usr/local/lib/python3.7/dist-packages (1.0.1)
Requirement already satisfied: transformers==4.10.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (4.10.0)
Requirement already satisfied: Korpora>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (0.2.0)
Requirement already satisfied: pytorch-lightning==1.3.4 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.3.4)
Requirement already satisfied: flask-ngrok>=0.0.25 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (0.0.25)
Requirement already satisfied: torch>=1.9.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.10.0+cu111)
Requirement already satisfied: flask>=1.1.4 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.1.4)
Requirement already satisfied: flask-cors>=3.0.10 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (3.0.10)
Requirement already satisfied: fsspec[http]>=2021.4.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (2022.2.0)
Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (21.3)
Requirement already satisfied: pyDeprecate==0.3.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.3.0)
Requirement already satisfied: tensorboard!=2.5.0,>=2.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (2.8.0)
Requirement already satisfied: PyYAML<=5.4.1,>=5.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (5.4.1)
Requirement already satisfied: torchmetrics>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.7.2)
Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (4.62.3)
Requirement already satisfied: future>=0.17.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.18.2)
Requirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (1.21.5)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (2.23.0)
Requirement already satisfied: huggingface-hub>=0.0.12 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.4.0)
Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.0.47)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (4.11.1)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (2019.12.20)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (3.6.0)
Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.10.3)
Requirement already satisfied: click<8.0,>=5.1 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (7.1.2)
Requirement already satisfied: itsdangerous<2.0,>=0.24 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (1.1.0)
Requirement already satisfied: Jinja2<3.0,>=2.10.1 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (2.11.3)
Requirement already satisfied: Werkzeug<2.0,>=0.15 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (1.0.1)
Requirement already satisfied: Six in /usr/local/lib/python3.7/dist-packages (from flask-cors>=3.0.10->ratsnlp) (1.15.0)
Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (3.8.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.12->transformers==4.10.0->ratsnlp) (3.10.0.2)
Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from Jinja2<3.0,>=2.10.1->flask>=1.1.4->ratsnlp) (2.0.1)
Requirement already satisfied: dataclasses>=0.6 in /usr/local/lib/python3.7/dist-packages (from Korpora>=0.2.0->ratsnlp) (0.6)
Requirement already satisfied: xlrd>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from Korpora>=0.2.0->ratsnlp) (2.0.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->pytorch-lightning==1.3.4->ratsnlp) (3.0.7)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (2021.10.8)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (3.0.4)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (2.10)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (1.24.3)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.4.6)
Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.35.0)
Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.43.0)
Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.17.3)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.37.1)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (57.4.0)
Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.0.0)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.3.6)
Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.6.1)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.8.1)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (4.2.4)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (4.8)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.2.8)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.3.1)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers==4.10.0->ratsnlp) (3.7.0)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.4.8)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.2.0)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (21.4.0)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.2.0)
Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.7.2)
Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (0.13.0)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (6.0.2)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.3.0)
Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (2.0.12)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (4.0.2)
Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.10.0->ratsnlp) (1.1.0)

Mounting Google Drive

The model checkpoint was saved to Google Drive, so connect the Colab notebook to your own Google Drive.

code 2-1

from google.colab import drive
drive.mount('/gdrive', force_remount=True)
▶Code output
Mounted at /gdrive

Inference settings

Set the various arguments (model hyperparameters, storage locations, and so on).

code 2-2

from ratsnlp.nlpbook.classification import ClassificationDeployArguments
args = ClassificationDeployArguments(
    pretrained_model_name="beomi/kcbert-base",
    downstream_model_dir="/gdrive/My Drive/nlpbook/checkpoint-doccls",
    max_seq_length=128,
)
▶Code output
downstream_model_checkpoint_fpath: /gdrive/My Drive/nlpbook/checkpoint-doccls/epoch=1-val_loss=0.28.ckpt

The role of each argument is as follows.

  • pretrained_model_name: the pretrained_model_name used in training_section.ipynb (the model must be registered with the Hugging Face library).
  • downstream_model_dir: the directory holding the checkpoint of the model fine-tuned in training_section.ipynb (it must contain at least one file with the .ckpt extension).
  • max_seq_length: the maximum length of the input sentence, counted in tokens. Defaults to 128 when omitted.
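
The output of code 2-2 shows that one checkpoint path was resolved from downstream_model_dir. How ratsnlp chooses among several .ckpt files is not shown in this note, so the following is only a sketch of one plausible approach. It assumes file names follow the `epoch=N-val_loss=X.ckpt` pattern seen above and picks the file with the lowest val_loss. `pick_checkpoint` and `demo_pick_checkpoint` are hypothetical helpers, not part of ratsnlp.

```python
import re
import tempfile
from pathlib import Path

# Hypothetical helper: NOT the ratsnlp implementation.
# Assumes checkpoint names look like "epoch=1-val_loss=0.28.ckpt".
VAL_LOSS_PATTERN = re.compile(r"val_loss=(\d+(?:\.\d+)?)")


def pick_checkpoint(model_dir):
    """Return the path of the .ckpt file whose name carries the lowest val_loss."""
    candidates = []
    for path in Path(model_dir).glob("*.ckpt"):
        match = VAL_LOSS_PATTERN.search(path.name)
        if match:
            candidates.append((float(match.group(1)), path))
    if not candidates:
        raise FileNotFoundError(f"no usable .ckpt file found in {model_dir}")
    # Compare on val_loss only, so Path objects are never compared directly
    best_loss, best_path = min(candidates, key=lambda pair: pair[0])
    return str(best_path)


def demo_pick_checkpoint():
    """Create two empty dummy checkpoints and return the chosen file name."""
    with tempfile.TemporaryDirectory() as tmp:
        for name in ("epoch=0-val_loss=0.31.ckpt", "epoch=1-val_loss=0.28.ckpt"):
            (Path(tmp) / name).touch()
        return Path(pick_checkpoint(tmp)).name


print(demo_pick_checkpoint())
```

If your checkpoint directory uses a different naming scheme, the regular expression would need to change accordingly.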

2. Loading the Tokenizer and the Model

Loading the tokenizer

Run code 2-3 to initialize the tokenizer.

code 2-3

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(
    args.pretrained_model_name,
    do_lower_case=False,
)
▶Code output
Downloading:   0%|          | 0.00/250k [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/49.0 [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/619 [00:00<?, ?B/s]
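
When the tokenizer is later called with padding="max_length" and truncation=True (code 2-9), every sentence ends up as fixed-length lists. The sketch below imitates that step in plain Python to show how input_ids, attention_mask, and token_type_ids relate to each other. It is a simplification: `pad_and_truncate` is a hypothetical function, the token ids are made up, a pad id of 0 is assumed (matching padding_idx=0 in the embedding layer), and unlike the real tokenizer it does not preserve the trailing special token when truncating.

```python
def pad_and_truncate(token_ids, max_seq_length, pad_id=0):
    """Simplified illustration of padding="max_length" with truncation=True."""
    # Truncate: keep at most max_seq_length ids
    input_ids = list(token_ids[:max_seq_length])
    # Real tokens are marked with 1 in the attention mask
    attention_mask = [1] * len(input_ids)
    # Pad up to the fixed length; padded positions are marked with 0
    n_pad = max_seq_length - len(input_ids)
    input_ids += [pad_id] * n_pad
    attention_mask += [0] * n_pad
    # A single-sentence input puts every position in segment 0
    token_type_ids = [0] * max_seq_length
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }


# Made-up ids standing in for a tokenized sentence
features = pad_and_truncate([2, 10, 11, 3], max_seq_length=6)
print(features["input_ids"])
print(features["attention_mask"])
```

The attention mask lets the model ignore the padded positions, which is why padding to a fixed length does not change the prediction.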

Loading the checkpoint

code 2-4 reads in the checkpoint of the model fine-tuned in training_section.ipynb.

code 2-4

import torch
fine_tuned_model_ckpt = torch.load(
    args.downstream_model_checkpoint_fpath,
    map_location=torch.device("cpu"),
)

Loading the BERT config and initializing the BERT model

code 2-5 reads the configuration of the model named by the pretrained_model_name that was used for fine-tuning in training_section.ipynb. The number of labels is taken from the size of the classifier bias stored in the checkpoint.

Running code 2-6 then initializes a BERT model with that configuration.

code 2-5

from transformers import BertConfig
pretrained_model_config = BertConfig.from_pretrained(
    args.pretrained_model_name,
    num_labels=fine_tuned_model_ckpt["state_dict"]["model.classifier.bias"].shape.numel(),
)

code 2-6

from transformers import BertForSequenceClassification
model = BertForSequenceClassification(pretrained_model_config)

Injecting the checkpoint

code 2-7 loads the checkpoint (fine_tuned_model_ckpt) into the initialized BERT model.

code 2-7

model.load_state_dict({k.replace("model.",""): v for k, v in fine_tuned_model_ckpt['state_dict'].items()})
▶Code output
<All keys matched successfully>
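
The dictionary comprehension in code 2-7 exists because the checkpoint keys carry a "model." prefix (as in "model.classifier.bias" in code 2-5), while BertForSequenceClassification expects the keys without it. The sketch below shows the renaming on a plain dictionary. `strip_prefix` is a hypothetical helper, and the string values merely stand in for tensors. It removes the prefix only at the start of a key, a slightly stricter variant of the str.replace call used above.

```python
def strip_prefix(state_dict, prefix="model."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Placeholder values stand in for the real weight tensors
fake_state_dict = {
    "model.bert.pooler.dense.weight": "tensor A",
    "model.classifier.bias": "tensor B",
}
renamed = strip_prefix(fake_state_dict)
print(sorted(renamed))
```

With the real checkpoint, the renamed dictionary is what gets passed to model.load_state_dict.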

Switching to evaluation mode

Running code 2-8 next switches the model to evaluation mode. This disables techniques that are used only during training, such as dropout.

code 2-8

model.eval()
▶Code output
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30000, 768, padding_idx=0)
      (position_embeddings): Embedding(300, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (1): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (2): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (3): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (4): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (5): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (6): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (7): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (8): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (9): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (10): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (11): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
)
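
The module listing above contains many Dropout(p=0.1) layers, which is why model.eval() matters for inference. The sketch below is a toy, pure-Python illustration of the idea: in training mode each value is randomly zeroed or rescaled by 1/(1-p), and in evaluation mode the input passes through unchanged. `ToyDropout` is a hypothetical class written for this note, not the PyTorch implementation.

```python
import random


class ToyDropout:
    """Toy inverted dropout with a train/eval switch (illustration only)."""

    def __init__(self, p=0.1, seed=0):
        self.p = p
        self.training = True
        self._rng = random.Random(seed)

    def eval(self):
        # Same idea as model.eval(): turn off training-only behaviour
        self.training = False
        return self

    def __call__(self, values):
        if not self.training:
            # Evaluation mode: identity
            return list(values)
        keep = 1.0 - self.p
        # Training mode: keep a value with probability `keep` and rescale it
        return [v / keep if self._rng.random() < keep else 0.0 for v in values]


layer = ToyDropout(p=0.5)
print(layer([1.0, 1.0, 1.0, 1.0]))   # random zeros and rescaled values
layer.eval()
print(layer([1.0, 1.0, 1.0, 1.0]))   # unchanged input
```

Without the switch to evaluation mode, the same sentence could yield slightly different probabilities on every request.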

3. Producing the Model Output and Post-processing It

code 2-9 is the function that defines the inference process. It tokenizes the sentence to build input_ids, attention_mask, and token_type_ids, converts these inputs to PyTorch tensors, and feeds them to the model. The model output (outputs.logits) is in logit form, that is, before the softmax function is applied, so softmax is applied to turn it into '[probability of negative, probability of positive]'.

Finally, the output is lightly post-processed to build pred: the sentence is labelled '부정 (negative)' when the largest predicted probability sits at the negative position, and '긍정 (positive)' otherwise.

code 2-9

def inference_fn(sentence):
  # Tokenize the sentence to build input_ids, attention_mask, and token_type_ids
  inputs = tokenizer(
      [sentence],
      max_length=args.max_seq_length,
      padding="max_length",
      truncation=True,
  )
  with torch.no_grad():
    # Run the model; the dict comprehension converts each input to a PyTorch tensor
    outputs = model(**{k: torch.tensor(v) for k, v in inputs.items()})

    # Apply softmax to the logits
    prob = outputs.logits.softmax(dim=1)

    # Round the positive/negative probabilities to 4 decimal places
    positive_prob = round(prob[0][1].item(), 4)
    negative_prob = round(prob[0][0].item(), 4)

    # Build pred from the position of the largest predicted probability
    # (긍정 = positive, 부정 = negative; these labels are shown on the web page)
    pred = "긍정 (positive)" if torch.argmax(prob) == 1 else "부정 (negative)"
  return {
      'sentence' : sentence,
      'prediction': pred,
      'positive_data': f"긍정 {positive_prob}",
      'negative_data': f"부정 {negative_prob}",
      'positive_width': f"{positive_prob * 100}%",
      'negative_width': f"{negative_prob * 100}%",
  }

In code 2-9, positive_width and negative_width only control the lengths of the positive/negative bars on the web page, so you do not need to pay much attention to them.
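
The softmax and post-processing steps of code 2-9 can be checked by hand on a pair of made-up logits. The sketch below uses only the standard library. `softmax` and `postprocess` are hypothetical names, the logits are invented for illustration, and index 0 is taken as negative and index 1 as positive, as in code 2-9.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def postprocess(logits):
    """Turn [negative_logit, positive_logit] into a label and rounded probabilities."""
    negative_prob, positive_prob = softmax(logits)
    # A tie falls to "negative", mirroring argmax returning index 0
    pred = "positive" if positive_prob > negative_prob else "negative"
    return {
        "prediction": pred,
        "positive_prob": round(positive_prob, 4),
        "negative_prob": round(negative_prob, 4),
    }


# Made-up logits: the positive logit is larger, so the label is "positive"
print(postprocess([-1.0, 2.0]))
```

Because softmax preserves the ordering of the logits, the predicted label would be the same with or without it; the softmax is there to obtain probabilities for display.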

4. Starting the Web Service

Preparing the web service

ngrok is a tool that makes a web service running locally in Colab securely reachable from outside. To run ngrok, sign up, log in, and then look up your authtoken on the ngrok site.

For example, if your authtoken is test123, the command looks like this:

!mkdir /root/.ngrok2 && echo "authtoken: test123" > /root/.ngrok2/ngrok.yml

code 2-10

!mkdir /root/.ngrok2 && echo "authtoken: (fill in your token here)" > /root/.ngrok2/ngrok.yml

Starting the web service

With the inference function inference_fn defined in code 2-9, running code 2-11 starts the web service with the help of the Python library Flask.

code 2-11

from ratsnlp.nlpbook.classification import get_web_service_app
app = get_web_service_app(inference_fn)
app.run()
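
The internals of get_web_service_app are not shown in this note. To give a rough idea of what wrapping an inference function in a web endpoint involves, here is a minimal sketch that uses only the standard library's WSGI conventions instead of Flask. Everything in it is an assumption made for illustration: `make_app`, `call_app`, the `sentence` query parameter, and `fake_inference_fn`, which stands in for the real inference_fn.

```python
import json
from urllib.parse import parse_qs


def make_app(inference_fn):
    """Wrap inference_fn in a minimal WSGI app (hypothetical, not the ratsnlp one)."""
    def app(environ, start_response):
        # Read ?sentence=... from the query string
        query = parse_qs(environ.get("QUERY_STRING", ""))
        sentence = query.get("sentence", [""])[0]
        # Run inference and serialize the result as JSON
        body = json.dumps(inference_fn(sentence), ensure_ascii=False).encode("utf-8")
        start_response("200 OK", [
            ("Content-Type", "application/json; charset=utf-8"),
            ("Content-Length", str(len(body))),
        ])
        return [body]
    return app


def fake_inference_fn(sentence):
    """Stand-in for inference_fn from code 2-9; always returns a fixed label."""
    return {"sentence": sentence, "prediction": "positive"}


def call_app(app, query_string):
    """Invoke the WSGI app directly, without opening a network socket."""
    captured = {}

    def start_response(status, headers):
        captured["status"] = status

    chunks = app({"QUERY_STRING": query_string, "REQUEST_METHOD": "GET"}, start_response)
    return captured["status"], json.loads(b"".join(chunks).decode("utf-8"))


print(call_app(make_app(fake_inference_fn), "sentence=good+movie"))
```

To serve such an app locally, it could be passed to wsgiref.simple_server.make_server from the standard library. The Flask app returned by get_web_service_app additionally renders an HTML page, which this sketch leaves out.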

The resulting web page looks like this.

model_inference
