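[NLP] Training a Document Classification Model
Let's work through a natural language processing example.
The following implements the document classification model described in the previous post.
This post is based on 'Do it! 자연어 처리' (Do it! Natural Language Processing) by Lee Kichang (이기창). :)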
Training the document classification model
1. Configuring the environment
Installing TPU-related packages
If you selected TPU as the hardware accelerator when initializing the Colab notebook, run the following code.
It installs the TPU-related libraries.
(Note: TPU training is generally less stable than GPU training in areas such as library support, so using a GPU is recommended where possible.)
code 1-0
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
▶Code output
Collecting torch-xla==1.9
Using cached https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl (149.9 MB)
Requirement already satisfied: cloud-tpu-client==0.10 in /usr/local/lib/python3.7/dist-packages (0.10)
Requirement already satisfied: google-api-python-client==1.8.0 in /usr/local/lib/python3.7/dist-packages (from cloud-tpu-client==0.10) (1.8.0)
Requirement already satisfied: oauth2client in /usr/local/lib/python3.7/dist-packages (from cloud-tpu-client==0.10) (4.1.3)
Requirement already satisfied: google-api-core<2dev,>=1.13.0 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.26.3)
Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.35.0)
Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.0.1)
Requirement already satisfied: six<2dev,>=1.6.1 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.15.0)
Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (0.0.4)
Requirement already satisfied: httplib2<1dev,>=0.9.2 in /usr/local/lib/python3.7/dist-packages (from google-api-python-client==1.8.0->cloud-tpu-client==0.10) (0.17.4)
Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.55.0)
Requirement already satisfied: pytz in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2018.9)
Requirement already satisfied: setuptools>=40.3.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (57.4.0)
Requirement already satisfied: protobuf>=3.12.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.17.3)
Requirement already satisfied: packaging>=14.3 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (21.3)
Requirement already satisfied: requests<3.0.0dev,>=2.18.0 in /usr/local/lib/python3.7/dist-packages (from google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2.23.0)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (0.2.8)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (4.8)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (4.2.4)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=14.3->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.0.7)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.4.1->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (0.4.8)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (3.0.4)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (1.24.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2021.10.8)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core<2dev,>=1.13.0->google-api-python-client==1.8.0->cloud-tpu-client==0.10) (2.10)
Installing dependency packages
The following code installs the other packages this tutorial depends on, apart from the TPU ones.
The exclamation mark (!) at the start of the command tells Colab to run it as a shell command rather than as Python.
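Outside a notebook there is no `!` prefix. The closest plain-Python analogue is to launch the command as a child process. The sketch below is only an illustration of that idea; the helper name `run_command` is hypothetical, and it runs a harmless Python one-liner instead of pip so that it works anywhere.

```python
import subprocess
import sys


def run_command(argv):
    """Run a command as a child process and return its captured stdout."""
    result = subprocess.run(argv, capture_output=True, text=True, check=True)
    return result.stdout


# Ask the current interpreter to print a marker, similar in spirit to `!python -c ...`.
output = run_command([sys.executable, "-c", "print('hello from a subprocess')"])
print(output.strip())
```

Passing the command as a list avoids shell quoting issues; `check=True` raises an error if the command exits with a non-zero status.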
code 1-1
!pip install ratsnlp
▶Code output
Requirement already satisfied: ratsnlp in /usr/local/lib/python3.7/dist-packages (1.0.1)
Requirement already satisfied: torch>=1.9.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.10.0+cu111)
Requirement already satisfied: transformers==4.10.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (4.10.0)
Requirement already satisfied: flask-ngrok>=0.0.25 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (0.0.25)
Requirement already satisfied: Korpora>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (0.2.0)
Requirement already satisfied: flask-cors>=3.0.10 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (3.0.10)
Requirement already satisfied: pytorch-lightning==1.3.4 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.3.4)
Requirement already satisfied: flask>=1.1.4 in /usr/local/lib/python3.7/dist-packages (from ratsnlp) (1.1.4)
Requirement already satisfied: fsspec[http]>=2021.4.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (2022.2.0)
Requirement already satisfied: torchmetrics>=0.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.7.2)
Requirement already satisfied: PyYAML<=5.4.1,>=5.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (5.4.1)
Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (4.62.3)
Requirement already satisfied: pyDeprecate==0.3.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.3.0)
Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (21.3)
Requirement already satisfied: future>=0.17.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (0.18.2)
Requirement already satisfied: numpy>=1.17.2 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (1.21.5)
Requirement already satisfied: tensorboard!=2.5.0,>=2.2.0 in /usr/local/lib/python3.7/dist-packages (from pytorch-lightning==1.3.4->ratsnlp) (2.8.0)
Requirement already satisfied: sacremoses in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.0.47)
Requirement already satisfied: huggingface-hub>=0.0.12 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.4.0)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (2019.12.20)
Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (4.11.1)
Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (3.6.0)
Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (0.10.3)
Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers==4.10.0->ratsnlp) (2.23.0)
Requirement already satisfied: Jinja2<3.0,>=2.10.1 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (2.11.3)
Requirement already satisfied: Werkzeug<2.0,>=0.15 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (1.0.1)
Requirement already satisfied: click<8.0,>=5.1 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (7.1.2)
Requirement already satisfied: itsdangerous<2.0,>=0.24 in /usr/local/lib/python3.7/dist-packages (from flask>=1.1.4->ratsnlp) (1.1.0)
Requirement already satisfied: Six in /usr/local/lib/python3.7/dist-packages (from flask-cors>=3.0.10->ratsnlp) (1.15.0)
Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (3.8.1)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub>=0.0.12->transformers==4.10.0->ratsnlp) (3.10.0.2)
Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from Jinja2<3.0,>=2.10.1->flask>=1.1.4->ratsnlp) (2.0.1)
Requirement already satisfied: xlrd>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from Korpora>=0.2.0->ratsnlp) (2.0.1)
Requirement already satisfied: dataclasses>=0.6 in /usr/local/lib/python3.7/dist-packages (from Korpora>=0.2.0->ratsnlp) (0.6)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->pytorch-lightning==1.3.4->ratsnlp) (3.0.7)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (2021.10.8)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (3.0.4)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers==4.10.0->ratsnlp) (1.24.3)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.4.6)
Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.0.0)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (57.4.0)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.37.1)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.3.6)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.8.1)
Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.6.1)
Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.17.3)
Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.35.0)
Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.44.0)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.2.8)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (4.8)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (4.2.4)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (1.3.1)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers==4.10.0->ratsnlp) (3.7.0)
Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (0.4.8)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard!=2.5.0,>=2.2.0->pytorch-lightning==1.3.4->ratsnlp) (3.2.0)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.3.0)
Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.2.0)
Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (0.13.0)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (6.0.2)
Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (4.0.2)
Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (1.7.2)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (21.4.0)
Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->fsspec[http]>=2021.4.0->pytorch-lightning==1.3.4->ratsnlp) (2.0.12)
Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers==4.10.0->ratsnlp) (1.1.0)
Connecting Google Drive
A Colab notebook that sits idle for a while can lose everything produced up to that point. To keep artifacts such as model checkpoints, connect your own Google Drive to the Colab notebook.
code 1-2
from google.colab import drive
drive.mount('/gdrive', force_remount=True)
▶Code output
Mounted at /gdrive
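Once the drive is mounted, it can help to make sure the checkpoint directory exists before training starts. The following is a minimal sketch under that assumption; `ensure_checkpoint_dir` is a hypothetical helper, and a temporary directory stands in for the mounted drive so the snippet runs outside Colab as well.

```python
import os
import tempfile


def ensure_checkpoint_dir(path):
    """Create the directory (and any missing parents) and return its absolute path."""
    os.makedirs(path, exist_ok=True)
    return os.path.abspath(path)


# In Colab the real location would be "/gdrive/My Drive/nlpbook/checkpoint-doccls".
# Here a temporary directory plays the role of the mount point.
base = tempfile.mkdtemp()
ckpt_dir = ensure_checkpoint_dir(
    os.path.join(base, "My Drive", "nlpbook", "checkpoint-doccls")
)
print(os.path.isdir(ckpt_dir))
```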
Configuring the model
We fine-tune the kcbert-base model on the NSMC data.
code 1-3
import torch
from ratsnlp.nlpbook.classification import ClassificationTrainArguments

args = ClassificationTrainArguments(
    pretrained_model_name="beomi/kcbert-base",
    downstream_corpus_name="nsmc",
    downstream_model_dir="/gdrive/My Drive/nlpbook/checkpoint-doccls",
    # 32 on a GPU runtime, 4 per core on a TPU runtime
    batch_size=32 if torch.cuda.is_available() else 4,
    learning_rate=5e-5,
    max_seq_length=128,
    epochs=3,
    # 0 on a GPU runtime, 8 on a TPU runtime
    tpu_cores=0 if torch.cuda.is_available() else 8,
    seed=7,
)
For reference, each argument of ClassificationTrainArguments has the following role.
pretrained_model_name
Name of the pretrained language model. (The model must be registered on the Hugging Face model hub.)
downstream_corpus_name
Name of the downstream dataset.
downstream_corpus_root_dir
Location to download the downstream data to. If omitted, a default location is used (the text gives /root/Korpora; the logger output in this post shows /content/Korpora).
downstream_model_dir
Location where checkpoints of the fine-tuned model are saved. Setting it to /gdrive/My Drive/nlpbook/checkpoint-doccls saves them to the nlpbook/checkpoint-doccls directory under [My Drive] in your Google Drive.
batch_size
Batch size. 32 if a GPU is the hardware accelerator (torch.cuda.is_available() == True), 4 if a TPU (torch.cuda.is_available() == False). In Colab a TPU usually comes with 8 cores, and batch_size is applied per core, which is why it is set this way.
learning_rate
Learning rate (step size). How much the model is updated in a single step.
max_seq_length
Maximum input length in tokens. Longer sentences are truncated to max_seq_length, and shorter ones are padded with the special token [PAD] until they reach max_seq_length.
epochs
Number of training epochs. With 3, the entire training set is passed over three times.
tpu_cores
Number of TPU cores. 0 if a GPU is the hardware accelerator, 8 if a TPU.
seed
Random seed (integer). Passing None leaves the random seed unfixed.
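The batch_size choice above can be checked with a little arithmetic. The sketch below assumes a single GPU on the GPU runtime and treats batch_size as a per-device value, as described for the TPU case; `effective_batch_size` is a hypothetical helper, not part of ratsnlp.

```python
def effective_batch_size(per_device_batch_size, num_devices):
    """Number of instances processed per optimization step across all devices."""
    # Treat "0 TPU cores" (the GPU setting) as one device. This is an assumption.
    return per_device_batch_size * max(num_devices, 1)


gpu_total = effective_batch_size(per_device_batch_size=32, num_devices=1)
tpu_total = effective_batch_size(per_device_batch_size=4, num_devices=8)
print(gpu_total, tpu_total)
```

Under these assumptions both runtimes process the same number of instances per step, which is the likely reason for pairing 32 with the GPU and 4 with the 8-core TPU.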
Fixing the random seed
Set the random seed.
code 1-4 fixes the seed to the value specified in args.
code 1-4
from ratsnlp import nlpbook
nlpbook.set_seed(args)
▶Code output
set seed: 7
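This post does not show what nlpbook.set_seed does internally. The sketch below is only an assumption about the usual approach: seed every random number generator in use so that repeated runs draw the same numbers. `set_seed_sketch` is a hypothetical name, and NumPy and PyTorch are seeded only if they happen to be installed.

```python
import random


def set_seed_sketch(seed):
    """Seed the common random number generators; do nothing when seed is None."""
    if seed is None:
        return
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy is optional for this sketch
    try:
        import torch
        torch.manual_seed(seed)
    except ImportError:
        pass  # PyTorch is optional for this sketch


set_seed_sketch(7)
first = [random.random() for _ in range(3)]
set_seed_sketch(7)
second = [random.random() for _ in range(3)]
print(first == second)
```

Reseeding with the same value reproduces the same sequence, which is what makes experiments repeatable.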
Setting up the logger
Set up the logger that prints the various log messages.
code 1-5
nlpbook.set_logger(args)
▶Code output
INFO:ratsnlp:Training/evaluation parameters ClassificationTrainArguments(pretrained_model_name='beomi/kcbert-base', downstream_task_name='document-classification', downstream_corpus_name='nsmc', downstream_corpus_root_dir='/content/Korpora', downstream_model_dir='/gdrive/My Drive/nlpbook/checkpoint-doccls', max_seq_length=128, save_top_k=1, monitor='min val_loss', seed=7, overwrite_cache=False, force_download=False, test_mode=False, learning_rate=5e-05, epochs=3, batch_size=32, cpu_workers=2, fp16=False, tpu_cores=0)
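The "INFO:ratsnlp:..." prefix in the output matches the standard library's default "level:name:message" layout. As a rough illustration (not the actual nlpbook.set_logger code), a logger producing lines in that shape could be set up as follows; the logger name "demo" and the helper `make_logger` are hypothetical.

```python
import io
import logging


def make_logger(name, stream):
    """Return an INFO-level logger that writes 'LEVEL:name:message' lines to stream."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
    logger.addHandler(handler)
    logger.propagate = False  # keep the demo output out of the root logger
    return logger


buffer = io.StringIO()
demo_logger = make_logger("demo", buffer)
demo_logger.info("Training/evaluation parameters %s", {"epochs": 3, "seed": 7})
print(buffer.getvalue().strip())
```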
2. Downloading the corpus
Downloading the corpus
Download the NSMC data. The download tool is Korpora*, an open-source Python package; it stores the corpus matching corpus_name (nsmc) under root_dir (the text gives /root/Korpora; the run logged in this post used /content/Korpora).
- Korpora*: github.com/ko-nlp/korpora
code 1-6
from Korpora import Korpora

Korpora.fetch(
    corpus_name=args.downstream_corpus_name,
    root_dir=args.downstream_corpus_root_dir,
    force_download=True,
)
▶Code output
[nsmc] download ratings_train.txt: 14.6MB [00:00, 75.6MB/s]
[nsmc] download ratings_test.txt: 4.90MB [00:00, 33.9MB/s]
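To get a feel for what a corpus reader does with the downloaded files, here is a small sketch. It assumes the NSMC files are tab-separated text with a header row of id, document, and label; the two sample rows are made-up placeholders, not real corpus entries, and `read_labeled_sentences` is a hypothetical helper rather than ratsnlp code.

```python
import csv
import io

# Made-up sample in the assumed layout: id <TAB> document <TAB> label
SAMPLE = "id\tdocument\tlabel\n1\tgreat movie\t1\n2\tnot my thing\t0\n"


def read_labeled_sentences(handle):
    """Read (sentence, label) pairs from a tab-separated file-like object."""
    reader = csv.DictReader(handle, delimiter="\t", quoting=csv.QUOTE_NONE)
    return [(row["document"], int(row["label"])) for row in reader]


examples = read_labeled_sentences(io.StringIO(SAMPLE))
print(examples)
```

With a real file, the same function could be called on `open(path, encoding="utf-8")` instead of the in-memory sample.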
3. Preparing the tokenizer
Preparing the tokenizer
The basic unit of data in this project is a sentence in text form. Tokenization is the process of splitting a sentence into a sequence of tokens. The model used in this exercise takes as input the token sequence obtained by segmenting a natural-language sentence.
Run code 1-7 to declare the tokenizer used by the kcbert-base model.
A tokenizer is a program that performs tokenization.
code 1-7
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained(
    args.pretrained_model_name,
    do_lower_case=False,
)
▶Code output
Downloading: 0%| | 0.00/250k [00:00<?, ?B/s]
Downloading: 0%| | 0.00/49.0 [00:00<?, ?B/s]
Downloading: 0%| | 0.00/619 [00:00<?, ?B/s]
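BERT-style tokenizers split words into subword pieces, marking continuation pieces with "##". The sketch below illustrates the greedy longest-match-first idea behind that splitting with a tiny hypothetical English vocabulary. It is a simplification for intuition only: the real BertTokenizer also handles punctuation, normalization, and a vocabulary of tens of thousands of entries.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Split one word into pieces by repeatedly taking the longest vocabulary match."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]  # no piece fits, so the whole word becomes unknown
        pieces.append(match)
        start = end
    return pieces


def tokenize(sentence, vocab):
    """Whitespace-split the sentence, split each word, and add [CLS] and [SEP]."""
    tokens = ["[CLS]"]
    for word in sentence.split():
        tokens.extend(wordpiece(word, vocab))
    tokens.append("[SEP]")
    return tokens


# Hypothetical toy vocabulary, for illustration only.
TOY_VOCAB = {"play", "##ing", "##ed", "movie", "##s"}
print(tokenize("playing movies", TOY_VOCAB))
```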
4. Preprocessing the data
To train a deep learning model, the training data must be fed to the model continuously, one batch at a time. In PyTorch, this is the job of the data loader (dataloader).
The data loader draws a batch-size worth of instances from those held by the dataset and assembles them into a batch in a fixed format (data type, data length, and so on).
Building the training dataset
code 1-8 creates a ClassificationDataset.
The main role of ClassificationDataset is to hold all the instances and hand them over when the data loader builds a batch.
code 1-8
from ratsnlp.nlpbook.classification import NsmcCorpus, ClassificationDataset

corpus = NsmcCorpus()
train_dataset = ClassificationDataset(
    args=args,
    corpus=corpus,
    tokenizer=tokenizer,
    mode="train",
)
▶Code output
INFO:ratsnlp:Creating features from dataset file at /content/Korpora/nsmc
INFO:ratsnlp:loading train data... LOOKING AT /content/Korpora/nsmc/ratings_train.txt
INFO:ratsnlp:tokenize sentences, it could take a lot of time...
INFO:ratsnlp:tokenize sentences [took 42.255 s]
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence: 아 더빙.. 진짜 짜증나네요 목소리
INFO:ratsnlp:tokens: [CLS] 아 더 ##빙 . . 진짜 짜증나네 ##요 목소리
INFO:ratsnlp:label: 0
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 2170, 832, 5045, 17, 17, 7992, 29734, 4040, 10720, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=0)
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence: ํ ...ํฌ์คํฐ๋ณด๊ณ ์ด๋ฉ์ํ์ค....์ค๋ฒ์ฐ๊ธฐ์กฐ์ฐจ ๊ฐ๋ณ์ง ์๊ตฌ๋
INFO:ratsnlp:tokens: [CLS] ํ . . . ํฌ ##์คํฐ ##๋ณด๊ณ ์ด๋ฉ ##์ํ ##์ค . . . . ์ค๋ฒ ##์ฐ๊ธฐ ##์กฐ์ฐจ ๊ฐ๋ณ ##์ง ์ ##๊ตฌ๋ [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
INFO:ratsnlp:label: 1
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 3521, 17, 17, 17, 3294, 13069, 8190, 10635, 13796, 4006, 17, 17, 17, 17, 17613, 19625, 9790, 17775, 4102, 2175, 8030, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=1)
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence: 너무재밓었다그래서보는것을추천한다
INFO:ratsnlp:tokens: [CLS] 너무 ##재 ##밓 ##었다 ##그래 ##서 ##보는 ##것을 ##추 ##천 ##한다
INFO:ratsnlp:label: 0
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 8069, 4089, 7847, 8217, 9791, 4072, 9136, 8750, 4142, 4244, 8008, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=0)
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence: 교도소 이야기구먼 ..솔직히 재미는 없다..평점 조정
INFO:ratsnlp:tokens: [CLS] 교도소 이야기 ##구먼 . . 솔직히 재미 ##는 없다 . . 평 ##점 조정
INFO:ratsnlp:label: 0
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 12164, 9089, 9828, 17, 17, 8876, 10827, 4008, 8131, 17, 17, 3288, 4213, 16612, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=0)
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence: 사이몬페그의 익살스런 연기가 돋보였던 영화!스파이더맨에서 늙어보이기만 했던 커스틴 던스트가 너무나도 이뻐보였다
INFO:ratsnlp:tokens: [CLS] 사이 ##몬 ##페 ##그 ##의 익 ##살 ##스런 연기 ##가 돋 ##보 ##였던 영화 ! 스파이 ##더 ##맨 ##에서 늙어 ##보이 ##기만 했던 커 ##스 ##틴 던 ##스트 ##가 너무나도 이뻐 ##보 ##였다
INFO:ratsnlp:label: 1
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 8538, 4880, 4335, 4313, 4042, 2452, 4471, 10670, 11219, 4009, 870, 4010, 13043, 9376, 5, 24034, 4356, 4617, 7971, 22878, 11980, 9235, 10129, 3010, 4103, 4713, 834, 8795, 4009, 22110, 23997, 4010, 9827, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=1)
INFO:ratsnlp:Saving features into cached file, it could take a lot of time...
INFO:ratsnlp:Saving features into cached file /content/Korpora/nsmc/cached_train_BertTokenizer_128_nsmc_document-classification [took 23.322 s]
What the ClassificationDataset class does
This class holds NsmcCorpus and the **tokenizer** declared earlier.
NsmcCorpus reads the NSMC data, stored in CSV-style files, as sentences and labels *(for example, whether a movie review is positive or negative)*.
When ClassificationDataset asks for them, NsmcCorpus provides these sentences and labels.
ClassificationDataset uses the tokenizer to turn each sentence-label pair into a form the model can learn from (ClassificationFeatures).
The ClassificationFeatures data type carries four pieces of information.
- **input_ids**: the token sequence converted to indices.
- **attention_mask**: whether each token is a padding token (0) or not (1).
- **token_type_ids**: segment information.
- **label**: the label converted to an integer.
The data type of each ClassificationFeatures field is as follows.
- input_ids: List[int]
- attention_mask: List[int]
- token_type_ids: List[int]
- label: int
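The four fields can be pictured with a small stand-in. The sketch below is not the ratsnlp implementation; `FeaturesSketch` and `build_features` are hypothetical names. The special-token ids (2 for [CLS], 3 for [SEP], 0 for [PAD]) are read off the log output in this post, and max_seq_length is shortened to 8 to keep the result readable.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FeaturesSketch:
    input_ids: List[int]
    attention_mask: List[int]
    token_type_ids: List[int]
    label: int


def build_features(token_ids, label, max_seq_length, cls_id=2, sep_id=3, pad_id=0):
    """Truncate or pad token ids to max_seq_length and build the matching mask."""
    body = token_ids[: max_seq_length - 2]  # leave room for [CLS] and [SEP]
    input_ids = [cls_id] + body + [sep_id]
    attention_mask = [1] * len(input_ids)   # 1 marks real tokens
    padding = max_seq_length - len(input_ids)
    input_ids += [pad_id] * padding
    attention_mask += [0] * padding         # 0 marks padding
    token_type_ids = [0] * max_seq_length   # single-sentence input: one segment
    return FeaturesSketch(input_ids, attention_mask, token_type_ids, label)


features = build_features([352, 192], label=1, max_seq_length=8)
print(features)
```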
Building the training data loader
code 1-9 creates the data loader used for training. The data loader draws a batch-size worth (the batch_size of the args defined in code 1-3) of instances from all those held by the ClassificationDataset and processes them into a batch (nlpbook.data_collator).
code 1-9
from torch.utils.data import DataLoader, RandomSampler

train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    sampler=RandomSampler(train_dataset, replacement=False),
    collate_fn=nlpbook.data_collator,
    drop_last=False,
    num_workers=args.cpu_workers,
)
Two arguments in this code deserve a closer look: sampler and collate_fn.
The former defines the sampling scheme.
When building a batch, the data loader created here randomly draws (RandomSampler) batch_size instances without replacement (replacement=False) from all the instances held by the ClassificationDataset.
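Sampling without replacement means every instance is used exactly once per pass over the data. A minimal pure-Python sketch of that behavior, using a hypothetical helper `random_batches` rather than PyTorch's own sampler, might look like this:

```python
import random


def random_batches(num_instances, batch_size, seed=None):
    """Shuffle instance indices once, then cut them into consecutive batches."""
    rng = random.Random(seed)
    order = list(range(num_instances))
    rng.shuffle(order)  # each index appears exactly once: no replacement
    return [order[i:i + batch_size] for i in range(0, num_instances, batch_size)]


batches = random_batches(num_instances=10, batch_size=4, seed=7)
print(batches)
```

The last batch is smaller when the number of instances is not a multiple of the batch size, which corresponds to keeping the remainder with drop_last=False.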
collate_fn is the function that turns the drawn instances into a batch. When a batch contains several instances, nlpbook.data_collator gathers them by kind (input_ids, attention_mask, and so on) and converts them into tensors, the data type PyTorch requires.
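The "gather by kind" step can be sketched without PyTorch. The function below is a hypothetical stand-in, not nlpbook.data_collator itself: it regroups a list of per-instance dictionaries into one dictionary of lists. The final step of converting each list to a tensor (for example with torch.tensor) is left out to keep the sketch dependency-free.

```python
def collate_sketch(instances):
    """Regroup a list of per-instance dicts into a dict of per-field lists."""
    keys = instances[0].keys()
    return {key: [instance[key] for instance in instances] for key in keys}


batch = collate_sketch([
    {"input_ids": [2, 352, 3, 0], "attention_mask": [1, 1, 1, 0], "label": 1},
    {"input_ids": [2, 40, 17, 3], "attention_mask": [1, 1, 1, 1], "label": 0},
])
print(batch["input_ids"])
print(batch["label"])
```

Because every instance was already padded to the same length, each field stacks into a rectangular list that maps directly onto a 2-D tensor.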
Building the evaluation data loader
Unlike the training data loader, the evaluation data loader uses **SequentialSampler**. SequentialSampler draws instances in order, batch_size at a time. Random batch composition is preferable during training, but evaluation uses the entire evaluation set, so there is no reason to randomize and SequentialSampler is used instead.
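The reasoning can be checked with a toy example: a metric computed over the whole evaluation set does not depend on the order in which batches arrive. The labels and predictions below are made-up values for illustration, and the helper names are hypothetical.

```python
import random


def chunk(indices, batch_size):
    """Cut a list of indices into consecutive batches."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


def accuracy_over_batches(predictions, labels, index_batches):
    """Accumulate correct counts batch by batch and return overall accuracy."""
    correct = 0
    total = 0
    for batch in index_batches:
        for i in batch:
            correct += int(predictions[i] == labels[i])
            total += 1
    return correct / total


labels = [1, 0, 0, 1, 1, 0, 1]       # made-up gold labels
predictions = [1, 0, 1, 1, 0, 0, 1]  # made-up model outputs

sequential = chunk(list(range(len(labels))), 3)
shuffled_order = list(range(len(labels)))
random.Random(0).shuffle(shuffled_order)
shuffled = chunk(shuffled_order, 3)

acc_seq = accuracy_over_batches(predictions, labels, sequential)
acc_rand = accuracy_over_batches(predictions, labels, shuffled)
print(acc_seq == acc_rand)
```

Since both orders visit every instance exactly once, sequential order gives the same result with less machinery and reproducible batches.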
code 1-10 builds the evaluation data loader.
code 1-10
from torch.utils.data import SequentialSampler

val_dataset = ClassificationDataset(
    args=args,
    corpus=corpus,
    tokenizer=tokenizer,
    mode="test",
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    sampler=SequentialSampler(val_dataset),
    collate_fn=nlpbook.data_collator,
    drop_last=False,
    num_workers=args.cpu_workers,
)
▶Code output
INFO:ratsnlp:Creating features from dataset file at /content/Korpora/nsmc
INFO:ratsnlp:loading test data... LOOKING AT /content/Korpora/nsmc/ratings_test.txt
INFO:ratsnlp:tokenize sentences, it could take a lot of time...
INFO:ratsnlp:tokenize sentences [took 14.198 s]
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence: 굳 ㅋ
INFO:ratsnlp:tokens: [CLS] 굳 ㅋ
INFO:ratsnlp:label: 1
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 352, 192, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=1)
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence: GDNTOPCLASSINTHECLUB
INFO:ratsnlp:tokens
INFO:ratsnlp:label: 0
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 40, 4452, 4581, 25144, 4579, 4881, 4450, 4580, 4985, 4985, 4506, 4581, 4850, 5121, 4451, 4881, 4450, 5167, 4756, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=0)
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence: 뭐야 이 평점들은.... 나쁘진 않지만 10점 짜리는 더더욱 아니잖아
INFO:ratsnlp:tokens: [CLS] 뭐야 이 평 ##점 ##들은 . . . . 나쁘 ##진 않지만 10 ##점 짜리 ##는 더더욱 아니잖아
INFO:ratsnlp:label: 0
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 10691, 2451, 3288, 4213, 7977, 17, 17, 17, 17, 10476, 4153, 15426, 8240, 4213, 21394, 4008, 15616, 13439, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=0)
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence: 지루하지는 않은데 완전 막장임... 돈주고 보기에는....
INFO:ratsnlp:tokens: [CLS] 지 ##루 ##하지는 않은데 완전 막장 ##임 . . . 돈주고 보기에 ##는
INFO:ratsnlp:label: 0
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 2688, 4532, 16036, 20879, 8357, 15971, 4252, 17, 17, 17, 13900, 25253, 4008, 17, 17, 17, 17, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=0)
INFO:ratsnlp:*** Example ***
INFO:ratsnlp:sentence: 3D만 아니었어도 별 다섯 개 줬을텐데.. 왜 3D로 나와서 제 심기를 불편하게 하죠??
INFO:ratsnlp:tokens: [CLS] 3 ##D ##만 아니었 ##어도 별 다섯 개 줬 ##을텐데 . . 왜 3 ##D ##로 나와서 제 심 ##기를 불편 ##하게 하죠
INFO:ratsnlp:label: 0
INFO:ratsnlp:features: ClassificationFeatures(input_ids=[2, 22, 4452, 4049, 18851, 8194, 1558, 23887, 220, 2648, 9243, 17, 17, 2332, 22, 4452, 4091, 10045, 2545, 2015, 8313, 10588, 8007, 18566, 32, 32, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], attention_mask=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], token_type_ids=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], label=0)
INFO:ratsnlp:Saving features into cached file, it could take a lot of time...
INFO:ratsnlp:Saving features into cached file /content/Korpora/nsmc/cached_test_BertTokenizer_128_nsmc_document-classification [took 7.800 s]
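The logged features show how each review is padded to a fixed length. The following is a minimal sketch of that padding step, not the ratsnlp implementation: `pad_features` is a hypothetical helper, and the padding id 0 and length 128 are read off the log output and the cache file name above.

```python
# Values read off the log above: padding id 0 and a feature length of 128.
PAD_ID = 0
MAX_SEQ_LENGTH = 128


def pad_features(token_ids, max_seq_length=MAX_SEQ_LENGTH, pad_id=PAD_ID):
    """Hypothetical helper: pad token ids and build the matching masks."""
    if len(token_ids) > max_seq_length:
        raise ValueError("sequence longer than max_seq_length; truncate it first")
    n_pad = max_seq_length - len(token_ids)
    input_ids = token_ids + [pad_id] * n_pad
    # 1 marks a real token, 0 marks padding that the model should ignore.
    attention_mask = [1] * len(token_ids) + [0] * n_pad
    # Single-sentence input, so every position belongs to segment 0.
    token_type_ids = [0] * max_seq_length
    return input_ids, attention_mask, token_type_ids


# Token ids of the first logged example (it starts with id 2 and ends with id 3).
input_ids, attention_mask, token_type_ids = pad_features([2, 352, 192, 3])
print(len(input_ids), sum(attention_mask))
```

The number of ones in `attention_mask` equals the number of real tokens, which is why the first logged example has four ones followed by zeros.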
5. Loading the model
Model initialization
In code 1-3 (1. Configuring settings - model environment settings), pretrained_model_name was set to beomi/kcbert-base, so kcbert-base serves as the pretrained BERT.
In the model initialization code, BertForSequenceClassification is a model class that attaches a document classification task module on top of a pretrained BERT model. The class is included in the transformers library provided by Hugging Face.
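To make "a task module on top of BERT" concrete, here is a rough sketch of such a module in plain PyTorch: dropout followed by a linear layer that maps the pooled sentence vector to one score per label. `SketchClassificationHead` is a hypothetical name, and the hidden size of 768 and the two labels are assumptions (a base-sized BERT and the 0/1 labels seen in the log), not values taken from the library source.

```python
import torch
from torch import nn


class SketchClassificationHead(nn.Module):
    """Hypothetical stand-in for the task module placed on top of BERT."""

    def __init__(self, hidden_size=768, num_labels=2, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Maps one pooled sentence vector to one score (logit) per label.
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, pooled_output):
        return self.classifier(self.dropout(pooled_output))


head = SketchClassificationHead()
head.eval()  # disable dropout for a deterministic forward pass
fake_pooled = torch.zeros(4, 768)  # stands in for BERT's pooled output, batch of 4
with torch.no_grad():
    logits = head(fake_pooled)
print(logits.shape)
```

Only this small head is specific to classification; the encoder underneath is the pretrained part.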
code 1-11
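from transformers import BertConfig, BertForSequenceClassification

# Load the pretrained configuration and set the number of output labels.
pretrained_model_config = BertConfig.from_pretrained(
    args.pretrained_model_name,
    num_labels=corpus.num_labels,
)
# Load the pretrained BERT weights; the classification head is newly initialized.
model = BertForSequenceClassification.from_pretrained(
    args.pretrained_model_name,
    config=pretrained_model_config,
)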
▶Code output
Downloading: 0%| | 0.00/438M [00:00<?, ?B/s]
Some weights of the model checkpoint at beomi/kcbert-base were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at beomi/kcbert-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
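The two warnings above can be read as a comparison of parameter names between the checkpoint and the new model. The sketch below illustrates that idea with plain Python sets. It is not how transformers implements loading, and the key lists are a small illustrative subset: the `cls.*` and `classifier.*` names come from the log, while the shared encoder key is an assumed example.

```python
# Illustrative parameter names; a real checkpoint holds many more entries.
checkpoint_keys = {
    "bert.embeddings.word_embeddings.weight",  # encoder weight (assumed example)
    "cls.predictions.bias",                    # pretraining head (from the log)
    "cls.seq_relationship.weight",             # pretraining head (from the log)
}
model_keys = {
    "bert.embeddings.word_embeddings.weight",
    "classifier.weight",                       # classification head (from the log)
    "classifier.bias",
}

loaded = sorted(checkpoint_keys & model_keys)             # reused from pretraining
unused = sorted(checkpoint_keys - model_keys)             # "were not used"
newly_initialized = sorted(model_keys - checkpoint_keys)  # "newly initialized"

print("loaded:", loaded)
print("unused:", unused)
print("newly initialized:", newly_initialized)
```

Under this reading, both messages are expected here: the pretraining heads are dropped, and the classification head has no pretrained weights, which is why the log asks for the model to be trained on the downstream task.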
6. Training the model
We define a task by subclassing the LightningModule class provided by PyTorch Lightning*. The task defines the model, the optimizer, the training procedure, and so on.
- PyTorch Lightning*: github.com/PyTorchLightning/pytorch-lightning
Task definition
The training settings (args) created in code 1-3 (1. Configuring settings - model environment settings) and the model (model) prepared in code 1-11 (5. Loading the model - Model initialization) are injected into ClassificationTask. ClassificationTask defines the optimizer and the learning rate scheduler: Adam as the optimizer and ExponentialLR as the learning rate scheduler.
Code 1-12 defines the task for document classification.
code 1-12
from ratsnlp.nlpbook.classification import ClassificationTask
task = ClassificationTask(model, args)
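As a rough picture of what such a task bundles, the sketch below groups a model, an optimizer, a scheduler, and a training step in one object. It deliberately avoids pytorch-lightning, so `SketchTask` is a hypothetical stand-in and not the ratsnlp ClassificationTask; the learning rate of 5e-5 and the toy linear model are assumptions for illustration.

```python
import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR


class SketchTask:
    """Hypothetical, framework-free stand-in for a Lightning-style task."""

    def __init__(self, model, learning_rate=5e-5, gamma=0.9):
        self.model = model
        self.learning_rate = learning_rate  # assumed value
        self.gamma = gamma

    def configure_optimizers(self):
        # Adam updates the weights; ExponentialLR decays the learning rate.
        optimizer = Adam(self.model.parameters(), lr=self.learning_rate)
        scheduler = ExponentialLR(optimizer, gamma=self.gamma)
        return optimizer, scheduler

    def training_step(self, inputs, labels):
        # One forward pass and the classification loss for a single batch.
        logits = self.model(inputs)
        return nn.functional.cross_entropy(logits, labels)


toy_model = nn.Linear(8, 2)  # toy model standing in for BERT
task_sketch = SketchTask(toy_model)
optimizer, scheduler = task_sketch.configure_optimizers()

loss = task_sketch.training_step(torch.randn(4, 8), torch.tensor([0, 1, 0, 1]))
loss.backward()
optimizer.step()
scheduler.step()  # learning rate is multiplied by gamma
print(optimizer.param_groups[0]["lr"])
```

Keeping these pieces in one object is what lets a trainer drive training without knowing anything about the specific model.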
To borrow an analogy from Gichang Lee, the author of 'Do it! 자연어 처리':
Training a model is like walking down a mountain ridge one step at a time while blindfolded. The learning rate corresponds to the stride: how far you move with each step down.
Gradually lowering the learning rate as training progresses lets the search become finer and can produce a somewhat better model. This is the job of the learning rate scheduler. ExponentialLR schedules the learning rate of the current epoch* as 'learning rate of the previous epoch x gamma'. In this example gamma is set to 0.9.
- Epoch*: the number of passes over the entire dataset. An epoch count of 3 means the data is used for training 3 times.
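The schedule described above reduces to simple arithmetic: after n epochs the learning rate is the starting rate times gamma to the power n. The sketch below works this out for gamma = 0.9; the starting rate of 5e-5 is an assumption for illustration, and `exponential_lr` is a hypothetical helper rather than part of any library.

```python
def exponential_lr(base_lr, gamma, epoch):
    """Learning rate after `epoch` decays: base_lr * gamma ** epoch."""
    return base_lr * gamma ** epoch


BASE_LR = 5e-5  # assumed starting learning rate
GAMMA = 0.9     # value used in this example

schedule = [exponential_lr(BASE_LR, GAMMA, epoch) for epoch in range(4)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr = {lr:.3e}")
```

Each epoch shrinks the stride by 10 percent, so the steps get steadily smaller without ever reaching zero.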
Trainer definition
Code 1-13 defines the trainer. With the help of the PyTorch Lightning library, this trainer takes care of tedious setup such as GPU/TPU configuration, logging, and checkpointing.
code 1-13
trainer = nlpbook.get_trainer(args)
▶Code output
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/distributed.py:69: UserWarning: Checkpoint directory /gdrive/My Drive/nlpbook/checkpoint-doccls exists and is not empty.
warnings.warn(*args, **kwargs)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Starting training
Calling the trainer's fit() function, as in code 1-14, starts training. Training time depends on your Colab environment, but it can take quite a while. Note that closing the browser while training is in progress interrupts the Colab session, including model training.
code 1-14
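# Train on train_dataloader and run validation on val_dataloader.
# Note: newer pytorch-lightning releases name the first keyword `train_dataloaders`.
trainer.fit(
    task,
    train_dataloader=train_dataloader,
    val_dataloaders=val_dataloader,
)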
▶Code output
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
--------------------------------------------------------
0 | model | BertForSequenceClassification | 108 M
--------------------------------------------------------
108 M Trainable params
0 Non-trainable params
108 M Total params
435.680 Total estimated model params size (MB)
Training: 114it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
Validating: 0it [00:00, ?it/s]
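The "Training" and "Validating" progress bars above correspond to two alternating phases. The sketch below shows that alternation as a plain PyTorch loop on toy data. It is a simplified illustration under stated assumptions (random features, a linear model, three epochs), not what pytorch-lightning runs internally.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy data: 64 random feature vectors with a label derived from the first feature.
features = torch.randn(64, 8)
labels = (features[:, 0] > 0).long()
train_loader = DataLoader(
    TensorDataset(features[:48], labels[:48]), batch_size=16, shuffle=True
)
val_loader = DataLoader(TensorDataset(features[48:], labels[48:]), batch_size=16)

model = nn.Linear(8, 2)  # toy model standing in for BERT
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

history = []
for epoch in range(3):
    # Training phase: update the weights batch by batch.
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()

    # Validation phase: measure accuracy without updating the weights.
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in val_loader:
            correct += (model(x).argmax(dim=-1) == y).sum().item()
    history.append(correct / len(val_loader.dataset))

print(history)
```

Validation runs under `torch.no_grad()` because it only measures the model, so no gradients are needed.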
