[NLP] Deploying a Sentence-Pair Classification Model


So, how do we actually use a model once training is done?

Note up front: this post is based on 'Do it! 자연어 처리' by 이기창. :)

Deploying the trained model

This exercise walks through running inference with a trained sentence-pair classification model. Figure 1 below shows a concept diagram of the web service built in this exercise.

pair_classification_map
Figure 1. Sentence-pair classification web service

The web service takes a premise sentence and a hypothesis sentence and responds with a verdict. The premise and the hypothesis are each tokenized and indexed to build the model input, which is then fed to the model to compute

[probability that the hypothesis is true given the premise, probability that the hypothesis is false given the premise, probability that the hypothesis is neutral given the premise].

The response is returned after a small amount of post-processing.
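To make the "three probabilities" step concrete, here is a minimal sketch of softmax in plain Python. The logit values are made up for illustration, and the helper name `softmax` is my own; the actual service uses PyTorch's built-in softmax later in this post.

```python
import math

def softmax(logits):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    peak = max(logits)
    # Subtracting the maximum keeps exp() from overflowing on large logits.
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits in the order [entailment, contradiction, neutral].
example_logits = [2.0, 0.5, -1.0]
probs = softmax(example_logits)
print([round(p, 2) for p in probs])
```

The largest logit always ends up with the largest probability, so picking the predicted class from the probabilities gives the same answer as picking it from the logits.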

Building a web service that checks a hypothesis against a premise

1. Setting up the environment

Installing dependencies

Install the required packages with pip.

code 4-0

!pip install ratsnlp
▶Code output
Requirement already satisfied: ratsnlp in /usr/local/lib/python3.7/dist-packages (1.0.1)
Requirement already satisfied: transformers==4.10.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (4.10.0)
Requirement already satisfied: Korpora>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (0.2.0)
Requirement already satisfied: flask>=1.1.4 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.1.4)
Requirement already satisfied: flask-cors>=3.0.10 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (3.0.10)
Requirement already satisfied: torch>=1.9.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.10.0+cu111)
Requirement already satisfied: flask-ngrok>=0.0.25 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (0.0.25)
Requirement already satisfied: pytorch-lightning==1.3.4 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.3.4)
Requirement already satisfied: future>=0.17.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.18.2)
Requirement already satisfied: tensorboard!=2.5.0,>=2.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (2.8.0)
Requirement already satisfied: fsspec[http]>=2021.4.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (2022.2.0)
Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (21.3)
Requirement already satisfied: torchmetrics>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.7.2)
Requirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (1.21.5)
Requirement already satisfied: pyDeprecate==0.3.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.3.0)
Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (4.63.0)
Requirement already satisfied: PyYAML<=5.4.1,>=5.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (5.4.1)
Requirement already satisfied: huggingface-hub>=0.0.12 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.4.0)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (3.6.0)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (4.11.2)
Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.10.3)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (2.23.0)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (2019.12.20)
Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.0.47)
Requirement already satisfied: Jinja2<3.0,>=2.10.1 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (2.11.3)
Requirement already satisfied: Werkzeug<2.0,>=0.15 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (1.0.1)
Requirement already satisfied: itsdangerous<2.0,>=0.24 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (1.1.0)
Requirement already satisfied: click<8.0,>=5.1 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (7.1.2)
Requirement already satisfied: Six in /usr/local/lib/python3.7/dist-packages (from flask-cors>=3.0.10->ratsnlp) (1.15.0)
Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (3.8.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.12->transformers==4.10.0->ratsnlp) (3.10.0.2)
Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from Jinja2<3.0,>=2.10.1->flask>=1.1.4->ratsnlp) (2.0.1)
Requirement already satisfied: dataclasses>=0.6 in /usr/local/lib/python3.7/dist-packages (from Korpora>=0.2.0->ratsnlp) (0.6)
Requirement already satisfied: xlrd>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from Korpora>=0.2.0->ratsnlp) (2.0.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->pytorch-lightning==1.3.4->ratsnlp) (3.0.7)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (2.10)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (1.24.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (2021.10.8)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (3.0.4)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.4.6)
Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.6.1)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (57.4.0)
Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.44.0)
Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.35.0)
Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.17.3)
Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.0.0)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.8.1)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.3.6)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.37.1)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (4.8)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.2.8)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (4.2.4)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.3.1)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers==4.10.0->ratsnlp) (3.7.0)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.4.8)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.2.0)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (21.4.0)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.3.0)
Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.7.2)
Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (2.0.12)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.2.0)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (4.0.2)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (6.0.2)
Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (0.13.0)
Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.10.0->ratsnlp) (1.1.0)

Mounting Google Drive

The checkpoint of the trained model was saved to Google Drive, so run code 4-1 to connect the Colab notebook to your own Google Drive.

code 4-1

from google.colab import drive
drive.mount('/gdrive', force_remount=True)
▶Code output
Mounted at /gdrive

Inference settings

Configure the inference settings. pretrained_model_name, max_seq_length, and downstream_model_dir must all be set to exactly the values used during training.

code 4-2

from ratsnlp.nlpbook.classification import ClassificationDeployArguments
args = ClassificationDeployArguments(
    pretrained_model_name="beomi/kcbert-base",
    downstream_model_dir="/gdrive/My Drive/nlpbook/checkpoint-paircls",
    max_seq_length=64,
)
▶Code output
downstream_model_checkpoint_fpath: /gdrive/My Drive/nlpbook/checkpoint-paircls/epoch=1-val_loss=0.82.ckpt

2. Loading the tokenizer and the model

Loading the tokenizer

Run code 4-3 to initialize the tokenizer.

code 4-3

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(
    args.pretrained_model_name,
    do_lower_case=False,
)

Loading the checkpoint

code 4-4 reads the checkpoint of the model fine-tuned in pair_classification_train.ipynb.

code 4-4

import torch
fine_tuned_model_ckpt = torch.load(
    args.downstream_model_checkpoint_fpath,
    map_location=torch.device("cpu"),
)

Loading the BERT config and initializing the BERT model

code 4-5 reads the configuration of the model named by pretrained_model_name, the one used for fine-tuning in pair_classification_train.ipynb, and sets num_labels from the size of the classifier bias stored in the checkpoint. Running code 4-6 then initializes a BERT model with that configuration.

code 4-5

from transformers import BertConfig
pretrained_model_config = BertConfig.from_pretrained(
    args.pretrained_model_name,
    num_labels=fine_tuned_model_ckpt['state_dict']['model.classifier.bias'].shape.numel(),
)

code 4-6

from transformers import BertForSequenceClassification
model = BertForSequenceClassification(pretrained_model_config)

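Injecting the checkpoint

code 4-7 loads the checkpoint from code 4-4 into the initialized BERT model. The dictionary comprehension removes "model." from each key so that the names match those of BertForSequenceClassification.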

code 4-7

model.load_state_dict({k.replace("model.",""): v for k, v in fine_tuned_model_ckpt['state_dict'].items()})
▶Code output
<All keys matched successfully>
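The key renaming in code 4-7 can be tried on a toy dictionary. The sketch below uses a made-up miniature state dict (real values would be tensors) and a hypothetical helper `strip_prefix`. One difference worth noting: `str.replace` removes every occurrence of "model." in a key, while this helper only removes a leading one, which I would consider the safer variant if a key ever contained the substring twice.

```python
def strip_prefix(state_dict, prefix="model."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    cleaned = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]  # drop only the leading prefix
        cleaned[key] = value
    return cleaned

# Hypothetical miniature checkpoint; lists stand in for tensors.
toy_state_dict = {
    "model.bert.pooler.dense.bias": [0.0],
    "model.classifier.weight": [[0.1, 0.2]],
    "model.classifier.bias": [0.0, 0.0, 0.0],
}
print(list(strip_prefix(toy_state_dict)))
```

For the key names that appear in this post both approaches should give the same result, since "model." occurs only once at the start.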

Switching to evaluation mode

Running code 4-8 next switches the model to evaluation mode, which disables techniques that are used only during training, such as dropout.

code 4-8

model.eval()
▶Code output
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30000, 768, padding_idx=0)
      (position_embeddings): Embedding(300, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (1): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (2): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (3): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (4): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (5): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (6): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (7): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (8): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (9): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (10): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (11): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=3, bias=True)
)
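Why evaluation mode matters for the Dropout(p=0.1) layers listed above can be shown with a toy stand-in. The class below is a plain-Python illustration that I wrote for this purpose, not PyTorch's implementation. It assumes the usual "inverted dropout" behaviour: in training mode each value is zeroed with probability p and the survivors are scaled by 1/(1-p), while in evaluation mode the input passes through unchanged.

```python
import random

class ToyDropout:
    """Plain-Python stand-in for a dropout layer (illustrative only)."""

    def __init__(self, p=0.1, seed=0):
        self.p = p
        self.training = True          # layers start in training mode
        self._rng = random.Random(seed)

    def eval(self):
        """Switch to evaluation mode, mirroring the model.eval() call."""
        self.training = False
        return self

    def __call__(self, values):
        if not self.training:
            return list(values)       # evaluation: identity
        keep = 1.0 - self.p
        # Training: drop each value with probability p, rescale the rest.
        return [v / keep if self._rng.random() < keep else 0.0 for v in values]

layer = ToyDropout(p=0.5)
print("train:", layer([1.0] * 8))
layer.eval()
print("eval: ", layer([1.0] * 8))
```

Without the switch to evaluation mode, the same premise and hypothesis could produce slightly different probabilities on every request.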

3. Producing model outputs and post-processing them

code 4-9 defines the inference function. It takes a premise and a hypothesis, tokenizes and indexes them, and builds input_ids, attention_mask, and token_type_ids. These input values are converted to PyTorch tensors and then fed to the model.
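The three inputs are easier to picture with a toy encoder. This is a simplified sketch, not what BertTokenizer actually does: it splits on whitespace instead of using WordPiece, uses a tiny made-up vocabulary, and raises an error instead of truncating. What it does keep is the layout of a BERT sentence pair: [CLS] premise [SEP] hypothesis [SEP], followed by padding.

```python
def encode_pair(premise, hypothesis, vocab, max_length):
    """Toy sentence-pair encoder: [CLS] premise [SEP] hypothesis [SEP] + padding."""
    unk = vocab["[UNK]"]
    first = [vocab.get(tok, unk) for tok in premise.split()]
    second = [vocab.get(tok, unk) for tok in hypothesis.split()]

    input_ids = [vocab["[CLS]"]] + first + [vocab["[SEP]"]] + second + [vocab["[SEP]"]]
    # Segment 0 covers [CLS] + premise + [SEP]; segment 1 covers hypothesis + [SEP].
    token_type_ids = [0] * (len(first) + 2) + [1] * (len(second) + 1)
    if len(input_ids) > max_length:
        raise ValueError("pair exceeds max_length; truncation is omitted in this sketch")

    pad = max_length - len(input_ids)
    return {
        "input_ids": input_ids + [vocab["[PAD]"]] * pad,
        "attention_mask": [1] * len(input_ids) + [0] * pad,  # 0 marks padding
        "token_type_ids": token_type_ids + [0] * pad,
    }

# Hypothetical vocabulary, for illustration only.
toy_vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
             "it": 4, "rains": 5, "is": 6, "wet": 7}
encoded = encode_pair("it rains", "it is wet", toy_vocab, max_length=10)
for name, values in encoded.items():
    print(name, values)
```

All three lists have length max_length, which is why the real function can turn them into tensors directly.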

Inference function

code 4-9

def inference_fn(premise, hypothesis):
  # Build the model inputs from the premise and the hypothesis
  inputs = tokenizer(
      [(premise, hypothesis)],
      max_length=args.max_seq_length,
      padding="max_length",
      truncation=True,
  )
  with torch.no_grad():
    # Run the model; the dict comprehension converts inputs to PyTorch tensors
    outputs = model(**{k: torch.tensor(v) for k, v in inputs.items()})

    # Apply softmax to the logits
    prob = outputs.logits.softmax(dim=1)

    # Round each probability to two decimal places
    entailment_prob = round(prob[0][0].item(), 2)
    contradiction_prob = round(prob[0][1].item(), 2)
    neutral_prob = round(prob[0][2].item(), 2)

    # Build pred from the position of the highest predicted probability
    if torch.argmax(prob) == 0:
      pred = "True (entailment)"
    elif torch.argmax(prob) == 1:
      pred = "False (contradiction)"
    else:
      pred = "Neutral (neutral)"

  return {
      'premise': premise,
      'hypothesis': hypothesis,
      'prediction': pred,
      'entailment_data': f"True {entailment_prob}",
      'contradiction_data': f"False {contradiction_prob}",
      'neutral_data': f"Neutral {neutral_prob}",
      'entailment_width': f"{entailment_prob * 100}%",
      'contradiction_width': f"{contradiction_prob * 100}%",
      'neutral_width': f"{neutral_prob * 100}%"
  }

The **model output (outputs.logits)** consists of logits, that is, scores before the softmax function is applied. Applying softmax converts the model output into probabilities. A little post-processing then sets pred to the true (entailment) label when the highest predicted probability is at the true position (0), to the false (contradiction) label when it is at the false position (1), and to the neutral label when it is at the neutral position (2).

In code 4-9, entailment_width, contradiction_width, and neutral_width only control the lengths of the true, false, and neutral bars on the web page, so they can be safely ignored.
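The post-processing can be tried without a model. The sketch below is my own plain-Python variant with a hypothetical helper `postprocess` and made-up probabilities. It differs from code 4-9 in one detail: the width strings are formatted as whole percentages, because multiplying a rounded float by 100 can leave a long decimal tail in the string.

```python
LABELS = ["entailment", "contradiction", "neutral"]  # index order used in code 4-9

def postprocess(probs):
    """Turn three probabilities into a label, rounded scores, and bar widths."""
    best = max(range(len(probs)), key=lambda i: probs[i])  # argmax
    rounded = [round(p, 2) for p in probs]
    return {
        "prediction": LABELS[best],
        "scores": dict(zip(LABELS, rounded)),
        # ":.0f" keeps floating-point artifacts out of the width string.
        "widths": {label: f"{p * 100:.0f}%" for label, p in zip(LABELS, rounded)},
    }

result = postprocess([0.07, 0.90, 0.03])
print(result["prediction"])
print(result["widths"])
```

Whether whole percentages are precise enough depends on how the front end draws the bars; that part of the web page is not shown in this post.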

4. Starting the web service

Preparing the web service

ngrok is a tool that makes a web service running locally in Colab securely reachable from outside. To run ngrok, you need to sign up, log in, and then look up your authentication token (authtoken) in your ngrok dashboard.

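For example, if your authtoken is test123, the command looks like this. Note that if /root/.ngrok2 already exists, mkdir reports "File exists" and, because of &&, the echo step is skipped; using mkdir -p avoids this when the cell is run again.

**!mkdir /root/.ngrok2 && echo "authtoken: test123" > /root/.ngrok2/ngrok.yml**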

code 4-10

!mkdir /root/.ngrok2 && echo "authtoken: (fill in here)" > /root/.ngrok2/ngrok.yml
▶Code output
mkdir: cannot create directory ‘/root/.ngrok2’: File exists

Starting the web service

With the inference function inference_fn defined in code 4-9, running code 4-11 launches the web service. The app is built with Flask, a Python web framework.

code 4-11

from ratsnlp.nlpbook.paircls import get_web_service_app
app = get_web_service_app(inference_fn)
app.run()
▶Code output
 * Serving Flask app "ratsnlp.nlpbook.paircls.deploy" (lazy loading)
 * Environment: production
   WARNING: This is a development server. Do not use it in a production deployment.
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)


 * Running on http://0163-35-238-180-140.ngrok.io
 * Traffic stats available on http://127.0.0.1:4040


127.0.0.1 - - [04/Mar/2022 09:14:48] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [04/Mar/2022 09:14:49] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [04/Mar/2022 09:14:49] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [04/Mar/2022 09:15:01] "POST /api HTTP/1.1" 200 -
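The log above shows a POST request to /api. How get_web_service_app wires that route to inference_fn is not shown in this post, so the following is only a rough sketch of the idea. It uses a bare WSGI callable from the standard library style instead of Flask, and the JSON request format with "premise" and "hypothesis" fields is my assumption, not ratsnlp's documented API.

```python
import io
import json

def make_wsgi_app(inference_fn):
    """Wrap an inference function in a minimal WSGI app (illustrative only)."""
    def app(environ, start_response):
        if environ.get("REQUEST_METHOD") != "POST" or environ.get("PATH_INFO") != "/api":
            start_response("404 Not Found", [("Content-Type", "text/plain")])
            return [b"not found"]
        length = int(environ.get("CONTENT_LENGTH") or 0)
        payload = json.loads(environ["wsgi.input"].read(length) or b"{}")
        # Assumed request fields: "premise" and "hypothesis".
        result = inference_fn(payload.get("premise", ""), payload.get("hypothesis", ""))
        body = json.dumps(result, ensure_ascii=False).encode("utf-8")
        start_response("200 OK", [("Content-Type", "application/json; charset=utf-8"),
                                  ("Content-Length", str(len(body)))])
        return [body]
    return app

def fake_inference_fn(premise, hypothesis):
    """Stand-in for the real inference_fn so the sketch runs without a model."""
    return {"premise": premise, "hypothesis": hypothesis, "prediction": "neutral"}

def call_app(app, method, path, payload=None):
    """Call the WSGI app directly, without opening a network socket."""
    raw = json.dumps(payload).encode("utf-8") if payload is not None else b""
    environ = {"REQUEST_METHOD": method, "PATH_INFO": path,
               "CONTENT_LENGTH": str(len(raw)), "wsgi.input": io.BytesIO(raw)}
    captured = {}
    def start_response(status, headers):
        captured["status"] = status
    body = b"".join(app(environ, start_response))
    return captured["status"], body

status, body = call_app(make_wsgi_app(fake_inference_fn), "POST", "/api",
                        {"premise": "It rains.", "hypothesis": "It is wet."})
print(status, json.loads(body))
```

The point of the design is that the web layer only needs a function taking two strings and returning a dictionary, which is exactly the shape of inference_fn.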

The website looks like this.

pair_classification
