[NLP] Training a Sentence Pair Classification Model

Let's work through a natural language processing example.
The code below implements the sentence pair classification model described in the previous post.

This post is based on 'Do it! 자연어 처리' (Do it! Natural Language Processing) by Gichang Lee (이기창). :)

Training a Sentence Pair Classification Model

Building a natural language inference model that checks a hypothesis against a premise

1. Setting Things Up

Installing TPU-related packages

If you chose TPU as the hardware accelerator when initializing the Colab notebook, run the following code; if you chose GPU, skip it.

code 3-0

!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

Installing dependency packages

Run code 3-1 to install the dependencies other than the TPU packages.

code 3-1

!pip install ratsnlp
▶Code output
Requirement already satisfied: ratsnlp in /usr/local/lib/python3.7/dist-packages (1.0.1)
Requirement already satisfied: pytorch-lightning==1.3.4 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.3.4)
Requirement already satisfied: torch>=1.9.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.10.0+cu111)
Requirement already satisfied: Korpora>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (0.2.0)
Requirement already satisfied: flask>=1.1.4 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.1.4)
Requirement already satisfied: flask-cors>=3.0.10 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (3.0.10)
Requirement already satisfied: transformers==4.10.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (4.10.0)
Requirement already satisfied: flask-ngrok>=0.0.25 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (0.0.25)
Requirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (1.21.5)
Requirement already satisfied: fsspec[http]>=2021.4.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (2022.2.0)
Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (4.62.3)
Requirement already satisfied: pyDeprecate==0.3.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.3.0)
Requirement already satisfied: torchmetrics>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.7.2)
Requirement already satisfied: future>=0.17.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.18.2)
Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (21.3)
Requirement already satisfied: tensorboard!=2.5.0,>=2.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (2.8.0)
Requirement already satisfied: PyYAML<=5.4.1,>=5.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (5.4.1)
Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.10.3)
Requirement already satisfied: huggingface-hub>=0.0.12 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.4.0)
Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.0.47)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (4.11.1)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (2019.12.20)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (3.6.0)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (2.23.0)
Requirement already satisfied: click<8.0,>=5.1 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (7.1.2)
Requirement already satisfied: itsdangerous<2.0,>=0.24 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (1.1.0)
Requirement already satisfied: Werkzeug<2.0,>=0.15 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (1.0.1)
Requirement already satisfied: Jinja2<3.0,>=2.10.1 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (2.11.3)
Requirement already satisfied: Six in /usr/local/lib/python3.7/dist-packages (from flask-cors>=3.0.10->ratsnlp) (1.15.0)
Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (3.8.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.12->transformers==4.10.0->ratsnlp) (3.10.0.2)
Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from Jinja2<3.0,>=2.10.1->flask>=1.1.4->ratsnlp) (2.0.1)
Requirement already satisfied: xlrd>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from Korpora>=0.2.0->ratsnlp) (2.0.1)
Requirement already satisfied: dataclasses>=0.6 in /usr/local/lib/python3.7/dist-packages (from Korpora>=0.2.0->ratsnlp) (0.6)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->pytorch-lightning==1.3.4->ratsnlp) (3.0.7)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (1.24.3)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (2.10)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (2021.10.8)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (3.0.4)
Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.35.0)
Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.17.3)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (57.4.0)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.4.6)
Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.0.0)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.8.1)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.37.1)
Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.44.0)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.3.6)
Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.6.1)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (4.2.4)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.2.8)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (4.8)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.3.1)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers==4.10.0->ratsnlp) (3.7.0)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.4.8)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.2.0)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (6.0.2)
Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (0.13.0)
Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (2.0.12)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.2.0)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (21.4.0)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.3.0)
Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.7.2)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (4.0.2)
Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.10.0->ratsnlp) (1.1.0)

Connecting Google Drive

A Colab notebook can lose everything produced so far if it sits idle for a while. To save model checkpoints and similar artifacts, connect your own Google Drive to the Colab notebook.

code 3-2

from google.colab import drive
drive.mount('/gdrive', force_remount=True)
▶Code output
Mounted at /gdrive

Model configuration

We fine-tune the kcbert-base model on the KLUE-NLI data* released by the AI company Upstage.

*klue-benchmark.com/tasks/68/data/description

code 3-3

import torch
from ratsnlp.nlpbook.classification import ClassificationTrainArguments
args = ClassificationTrainArguments(
    pretrained_model_name="beomi/kcbert-base",
    downstream_task_name="pair-classification",
    downstream_corpus_name="klue-nli",
    downstream_model_dir="/gdrive/My Drive/nlpbook/checkpoint-paircls",
    batch_size=32 if torch.cuda.is_available() else 4,
    learning_rate=5e-5,
    max_seq_length=64,
    epochs=5,
    tpu_cores=0 if torch.cuda.is_available() else 8,
    seed=7,
)

Fixing the random seed

Set the random seed.

code 3-4 fixes the random seed to the value specified in args.

code 3-4

from ratsnlp import nlpbook
nlpbook.set_seed(args)
▶Code output
set seed: 7
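
To see why fixing the seed matters, here is a minimal sketch that uses only Python's standard `random` module. The helper `sample_order` is hypothetical and not part of ratsnlp; it only illustrates that the same seed reproduces the same "random" order, which is presumably what `nlpbook.set_seed` ensures for the libraries used in training.

```python
import random


def sample_order(seed: int, n: int = 5) -> list:
    """Shuffle the indices 0..n-1 with a private generator seeded by `seed`."""
    rng = random.Random(seed)  # private RNG, so global state is untouched
    order = list(range(n))
    rng.shuffle(order)
    return order


first = sample_order(seed=7)
second = sample_order(seed=7)
print(first == second)  # True: the same seed gives the same order
```

Without a fixed seed, shuffling and weight initialization would differ between runs, which makes results hard to compare.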

Logger setup

Set up the logger that prints the various log messages.

code 3-5

nlpbook.set_logger(args)
▶Code output
INFO:ratsnlp:Training/evaluation parameters ClassificationTrainArguments(pretrained_model_name='beomi/kcbert-base', downstream_task_name='pair-classification', downstream_corpus_name='klue-nli', downstream_corpus_root_dir='/content/Korpora', downstream_model_dir='/gdrive/My Drive/nlpbook/checkpoint-paircls', max_seq_length=64, save_top_k=1, monitor='min val_loss', seed=7, overwrite_cache=False, force_download=False, test_mode=False, learning_rate=5e-05, epochs=5, batch_size=32, cpu_workers=2, fp16=False, tpu_cores=0)

2. Downloading the Corpus

Downloading the corpus

Download the KLUE-NLI data. The corpus named by downstream_corpus_name (klue-nli) is saved under downstream_corpus_root_dir (/content/Korpora in this run, as the logger output above shows).

code 3-6

nlpbook.download_downstream_dataset(args)
▶Code output
Downloading: 100%|██████████| 12.3M/12.3M [00:00<00:00, 42.3MB/s]
Downloading: 100%|██████████| 1.47M/1.47M [00:00<00:00, 35.6MB/s]

3. Preparing the Tokenizer

Preparing the tokenizer

Run code 3-7 to declare the tokenizer used by the model named in pretrained_model_name (kcbert-base).

code 3-7

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(
    args.pretrained_model_name,
    do_lower_case=False,
)
▶Code output
Downloading:   0%|          | 0.00/250k [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/49.0 [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/619 [00:00<?, ?B/s]

4. Preprocessing the Data

Building the training dataset

Running code 3-8 builds the training dataset. The KlueNLICorpus class reads the JSON-format KLUE-NLI data as sentences (premise + hypothesis) and labels (true, false, or neutral; that is, entailment, contradiction, or neutral). When ClassificationDataset asks for them, KlueNLICorpus hands these sentences and labels over to it.

code 3-8

from ratsnlp.nlpbook.paircls import KlueNLICorpus
from ratsnlp.nlpbook.classification import ClassificationDataset
corpus = KlueNLICorpus()
train_dataset = ClassificationDataset(
    args=args,
    corpus=corpus,
    tokenizer=tokenizer,
    mode="train",
)
▶Code output
INFO:ratsnlp:Creating features from dataset file at /content/Korpora/klue-nli
INFO:ratsnlp:loading train data... LOOKING AT /content/Korpora/klue-nli/klue_nli_train.json
INFO:ratsnlp:tokenize sentences, it could take a lot of time...
INFO:ratsnlp:tokenize sentences [took 15.747 s]
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence A, B: 100๋ถ„๊ฐ„ ์ž˜๊ป„ ๊ทธ๋ž˜๋„ ์†Œ๋‹‰๋ถ๋•œ์— 2์ ์ค€๋‹ค + 100๋ถ„๊ฐ„ ์žค๋‹ค.
INFO:ratsnlp:tokens: [CLS] 100 ##๋ถ„๊ฐ„ ์ž˜ ##๊ป„ ๊ทธ๋ž˜๋„ ์†Œ ##๋‹‰ ##๋ถ ##๋•œ์— 2 ##์  ##์ค€๋‹ค [SEP] 100 ##๋ถ„๊ฐ„ ์žค ##๋‹ค . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
INFO:ratsnlp:label: contradiction
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 8327, 15760, 2483, 4260, 8446, 1895, 5623, 5969, 10319, 21, 4213, 10172, 3, 8327, 15760, 2491, 4020, 17, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=1)
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence A, B: 100๋ถ„๊ฐ„ ์ž˜๊ป„ ๊ทธ๋ž˜๋„ ์†Œ๋‹‰๋ถ๋•œ์— 2์ ์ค€๋‹ค + ์†Œ๋‹‰๋ถ์ด ์ •๋ง ๋ฉ‹์žˆ์—ˆ๋‹ค.
INFO:ratsnlp:tokens: [CLS] 100 ##๋ถ„๊ฐ„ ์ž˜ ##๊ป„ ๊ทธ๋ž˜๋„ ์†Œ ##๋‹‰ ##๋ถ ##๋•œ์— 2 ##์  ##์ค€๋‹ค [SEP] ์†Œ ##๋‹‰ ##๋ถ ##์ด ์ •๋ง ๋ฉ‹ ##์žˆ ##์—ˆ๋‹ค . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
INFO:ratsnlp:label: neutral
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 8327, 15760, 2483, 4260, 8446, 1895, 5623, 5969, 10319, 21, 4213, 10172, 3, 1895, 5623, 5969, 4017, 8050, 1348, 4188, 8217, 17, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=2)
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence A, B: 100๋ถ„๊ฐ„ ์ž˜๊ป„ ๊ทธ๋ž˜๋„ ์†Œ๋‹‰๋ถ๋•œ์— 2์ ์ค€๋‹ค + 100๋ถ„๊ฐ„ ์ž๋Š”๊ฒŒ ๋” ๋‚˜์•˜์„ ๊ฒƒ ๊ฐ™๋‹ค.
INFO:ratsnlp:tokens: [CLS] 100 ##๋ถ„๊ฐ„ ์ž˜ ##๊ป„ ๊ทธ๋ž˜๋„ ์†Œ ##๋‹‰ ##๋ถ ##๋•œ์— 2 ##์  ##์ค€๋‹ค [SEP] 100 ##๋ถ„๊ฐ„ ์ž๋Š” ##๊ฒŒ ๋” ๋‚˜ ##์•˜์„ ๊ฒƒ ๊ฐ™๋‹ค . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
INFO:ratsnlp:label: neutral
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 8327, 15760, 2483, 4260, 8446, 1895, 5623, 5969, 10319, 21, 4213, 10172, 3, 8327, 15760, 15095, 4199, 832, 587, 25331, 258, 8604, 17, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=2)
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence A, B: 101๋นŒ๋”ฉ ๊ทผ์ฒ˜์— ๋‚˜๋ฆ„ ์ฆ๊ธธ๊ฑฐ๋ฆฌ๊ฐ€ ๋งŽ์Šต๋‹ˆ๋‹ค. + 101๋นŒ๋”ฉ ๊ทผ์ฒ˜์—์„œ ์ฆ๊ธธ๊ฑฐ๋ฆฌ ์ฐพ๊ธฐ๋Š” ์–ด๋ ต์Šต๋‹ˆ๋‹ค.
INFO:ratsnlp:tokens: [CLS] 10 ##1 ##๋นŒ ##๋”ฉ ๊ทผ์ฒ˜์— ๋‚˜๋ฆ„ ์ฆ ##๊ธธ ##๊ฑฐ๋ฆฌ๊ฐ€ ๋งŽ์Šต๋‹ˆ๋‹ค . [SEP] 10 ##1 ##๋นŒ ##๋”ฉ ๊ทผ์ฒ˜์— ##์„œ ์ฆ ##๊ธธ ##๊ฑฐ๋ฆฌ ์ฐพ ##๊ธฐ๋Š” ์–ด๋ ต ##์Šต๋‹ˆ๋‹ค . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
INFO:ratsnlp:label: contradiction
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 8240, 4068, 4647, 4389, 29671, 13715, 2676, 4583, 14516, 14617, 17, 3, 8240, 4068, 4647, 4389, 29671, 4072, 2676, 4583, 8181, 2851, 8189, 9775, 8046, 17, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=1)
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence A, B: 101๋นŒ๋”ฉ ๊ทผ์ฒ˜์— ๋‚˜๋ฆ„ ์ฆ๊ธธ๊ฑฐ๋ฆฌ๊ฐ€ ๋งŽ์Šต๋‹ˆ๋‹ค. + 101๋นŒ๋”ฉ ์ฃผ๋ณ€์— ์ Š์€์ด๋“ค์ด ์ฆ๊ธธ๊ฑฐ๋ฆฌ๊ฐ€ ๋งŽ์Šต๋‹ˆ๋‹ค.
INFO:ratsnlp:tokens: [CLS] 10 ##1 ##๋นŒ ##๋”ฉ ๊ทผ์ฒ˜์— ๋‚˜๋ฆ„ ์ฆ ##๊ธธ ##๊ฑฐ๋ฆฌ๊ฐ€ ๋งŽ์Šต๋‹ˆ๋‹ค . [SEP] 10 ##1 ##๋นŒ ##๋”ฉ ์ฃผ๋ณ€์— ์ Š์€์ด๋“ค์ด ์ฆ ##๊ธธ ##๊ฑฐ๋ฆฌ๊ฐ€ ๋งŽ์Šต๋‹ˆ๋‹ค . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
INFO:ratsnlp:label: neutral
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 8240, 4068, 4647, 4389, 29671, 13715, 2676, 4583, 14516, 14617, 17, 3, 8240, 4068, 4647, 4389, 12298, 22790, 2676, 4583, 14516, 14617, 17, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=2)
INFO:ratsnlp:Saving features into cached file, it could take a lot of time...
INFO:ratsnlp:Saving features into cached file /content/Korpora/klue-nli/cached_train_BertTokenizer_64_klue-nli_pair-classification [took 1.934 s]

What the ClassificationDataset class does

This class holds the KlueNLICorpus and the tokenizer declared in code 3-7.

ClassificationDataset uses the tokenizer to turn each sentence and label it receives into a form the model can learn from (ClassificationFeatures).
In other words, it tokenizes the two sentences (premise and hypothesis), converts the tokens to indices, and converts the label to an integer as well.

(entailment: 0, contradiction: 1, neutral: 2)
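
The conversion described above can be sketched in plain Python. This is an illustration only, not the ratsnlp implementation: `PairFeatures` and `build_pair_features` are hypothetical names, the special-token ids (2, 3, 0) and the token ids are copied from the first example in the code 3-8 log output, and unlike a real tokenizer the sketch raises an error instead of truncating long inputs.

```python
from dataclasses import dataclass
from typing import List

CLS_ID, SEP_ID, PAD_ID = 2, 3, 0  # ids as they appear in the code 3-8 log
LABEL_TO_ID = {"entailment": 0, "contradiction": 1, "neutral": 2}


@dataclass
class PairFeatures:  # hypothetical stand-in for ClassificationFeatures
    input_ids: List[int]
    attention_mask: List[int]
    token_type_ids: List[int]
    label: int


def build_pair_features(premise_ids: List[int], hypothesis_ids: List[int],
                        label: str, max_seq_length: int = 64) -> PairFeatures:
    first = [CLS_ID] + premise_ids + [SEP_ID]   # segment 0: [CLS] premise [SEP]
    second = hypothesis_ids + [SEP_ID]          # segment 1: hypothesis [SEP]
    real = first + second
    if len(real) > max_seq_length:
        raise ValueError("input longer than max_seq_length (no truncation in this sketch)")
    pad = max_seq_length - len(real)
    return PairFeatures(
        input_ids=real + [PAD_ID] * pad,
        attention_mask=[1] * len(real) + [0] * pad,  # 1 for real tokens, 0 for padding
        token_type_ids=[0] * len(first) + [1] * len(second) + [0] * pad,
        label=LABEL_TO_ID[label],
    )


# Token ids of the first logged example (premise, then hypothesis).
premise = [8327, 15760, 2483, 4260, 8446, 1895, 5623, 5969, 10319, 21, 4213, 10172]
hypothesis = [8327, 15760, 2491, 4020, 17]
features = build_pair_features(premise, hypothesis, "contradiction")
print(features.input_ids[:20], features.label)
```

The resulting lists have length 64, with 20 real tokens followed by padding, which matches the first `ClassificationFeatures` entry in the log.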

For the roles and detailed implementation of KlueNLICorpus and ClassificationDataset, see the link below.
(For now I link to the textbook; I plan to put my own implementation on my GitHub later.)

  • ratsgo.github.io/nlpbook/docs/pair_cls/detail

Building the training data loader

Running code 3-9 builds the data loader used during training. The training data loader draws batch_size instances at a time (args.batch_size, defined in code 3-3) from all the instances held by ClassificationDataset, sampling randomly (RandomSampler) without replacement (replacement=False). It then collates them into a batch (nlpbook.data_collator) and feeds the batch to the model.

code 3-9

from torch.utils.data import DataLoader, RandomSampler
train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    sampler=RandomSampler(train_dataset, replacement=False),
    collate_fn=nlpbook.data_collator,
    drop_last=False,
    num_workers=args.cpu_workers,
)
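
The sampling behaviour described above can be imitated without PyTorch. The sketch below is an assumption-laden illustration, not the DataLoader implementation: `random_batches` is a hypothetical helper that shuffles the indices once (so every instance is drawn exactly once per epoch, i.e. without replacement) and keeps the smaller final batch, as `drop_last=False` does.

```python
import random
from typing import Iterator, List


def random_batches(num_instances: int, batch_size: int,
                   seed: int = 7) -> Iterator[List[int]]:
    """Yield batches of indices drawn at random without replacement."""
    rng = random.Random(seed)
    indices = list(range(num_instances))
    rng.shuffle(indices)  # a permutation: no index can appear twice
    for start in range(0, num_instances, batch_size):
        yield indices[start:start + batch_size]  # last batch may be smaller


batches = list(random_batches(num_instances=10, batch_size=4))
print([len(batch) for batch in batches])  # [4, 4, 2]
```

In the real loader, each batch of indices is looked up in the dataset and passed through the collate function before reaching the model.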

Building the evaluation data loader

Running code 3-10 builds the evaluation data loader. It takes batch_size instances at a time (args.batch_size, defined in code 3-3) in order (SequentialSampler), collates them into a batch (nlpbook.data_collator), and feeds the batch to the model.

code 3-10

from torch.utils.data import SequentialSampler
val_dataset = ClassificationDataset(
    args=args,
    corpus=corpus,
    tokenizer=tokenizer,
    mode="test",
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    sampler=SequentialSampler(val_dataset),
    collate_fn=nlpbook.data_collator,
    drop_last=False,
    num_workers=args.cpu_workers,
)
▶Code output
INFO:ratsnlp:Loading features from cached file /content/Korpora/klue-nli/cached_test_BertTokenizer_64_klue-nli_pair-classification [took 0.116 s]

5. Loading the Model

Initializing the model

Run code 3-11 to initialize the model. The pretrained BERT used here is kcbert-base, because pretrained_model_name was set to beomi/kcbert-base in code 3-3. Any other model registered on the Hugging Face model hub can be used instead.

BertForSequenceClassification is a model class that attaches a document-classification task module on top of a pretrained BERT model. It is the same class used for the document classification model.

code 3-11

from transformers import BertConfig, BertForSequenceClassification
pretrained_model_config = BertConfig.from_pretrained(
    args.pretrained_model_name,
    num_labels=corpus.num_labels,
)
model = BertForSequenceClassification.from_pretrained(
    args.pretrained_model_name,
    config=pretrained_model_config,
)
▶Code output
Downloading:   0%|          | 0.00/438M [00:00<?, ?B/s]


Some weights of the model checkpoint at beomi/kcbert-base were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at beomi/kcbert-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
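
The warning above says that `classifier.weight` and `classifier.bias` are newly initialized: these belong to the task module stacked on top of BERT. As a rough sketch of what such a module computes, the plain-Python example below applies one linear layer and a softmax to a pooled sentence-pair vector. The function name `classify`, the 4-dimensional vector, and all numbers are made up for illustration; the real model uses BERT's hidden size and learned weights, and also applies dropout.

```python
import math
from typing import List


def classify(pooled: List[float], weight: List[List[float]],
             bias: List[float]) -> List[float]:
    """Linear layer followed by softmax: one probability per label."""
    logits = [sum(w * x for w, x in zip(row, pooled)) + b
              for row, b in zip(weight, bias)]
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


pooled = [0.5, -1.0, 0.25, 2.0]           # toy stand-in for BERT's pooled output
weight = [[0.1, 0.0, 0.0, 0.0],           # one row per label (3 labels)
          [0.0, 0.0, 0.0, 1.0],
          [0.0, 0.2, 0.0, 0.0]]
bias = [0.0, 0.0, 0.0]

probs = classify(pooled, weight, bias)
predicted = max(range(len(probs)), key=probs.__getitem__)
print(predicted)  # 1, i.e. contradiction under the label mapping used in this post
```

Fine-tuning adjusts both these classifier weights and the BERT weights underneath so that the highest probability lands on the correct label.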

6. Training the Model

Running code 3-12 defines the task for sentence pair classification. The model prepared in code 3-11 is wrapped in ClassificationTask. The ClassificationTask class defines the optimizer and the learning rate scheduler: it uses Adam as the optimizer and ExponentialLR as the learning rate scheduler.

Defining the task

code 3-12

from ratsnlp.nlpbook.classification import ClassificationTask
task = ClassificationTask(model, args)
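
To get a feel for what an ExponentialLR scheduler does, the sketch below computes the learning rate per epoch as `initial_lr * gamma ** epoch`. The starting rate 5e-5 and the 5 epochs come from code 3-3; the decay factor `gamma=0.9` is an assumption chosen for illustration, since the post does not state which value ClassificationTask uses. `exponential_lr` is a hypothetical helper, not a ratsnlp or PyTorch function.

```python
from typing import List


def exponential_lr(initial_lr: float, gamma: float, epochs: int) -> List[float]:
    """Learning rate for each epoch under exponential decay."""
    return [initial_lr * gamma ** epoch for epoch in range(epochs)]


schedule = exponential_lr(initial_lr=5e-5, gamma=0.9, epochs=5)  # gamma is assumed
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr = {lr:.3e}")
```

The rate shrinks by the same factor every epoch, so later epochs make smaller updates to the weights.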

Defining the trainer

Running code 3-13 defines the trainer. With the help of the PyTorch Lightning library, this trainer takes care of tedious settings such as GPU/TPU configuration, logging, and checkpointing.

code 3-13

trainer = nlpbook.get_trainer(args)
▶Code output
GPU available: True, used: True
TPU available: False, using: 0 TPU cores

Starting training

Calling the trainer's fit() function, as in code 3-14, starts training.

code 3-14

trainer.fit(
    task,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
)
▶Code output
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                          | Params
--------------------------------------------------------
0 | model | BertForSequenceClassification | 108 M 
--------------------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.683   Total estimated model params size (MB)

Training: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]

Sentence pair classification is not fundamentally different from document classification in aspects such as the structure of the task module. The only difference is whether the input consists of one document (document classification) or two sentences (sentence pair classification).
