코딩일기

parlai에 새로운 task 추가하기(feat. Adding a new task to parlai, facebook, blenderbot2, blenderbot1) 본문

Code/기타

parlai에 새로운 task 추가하기(feat. Adding a new task to parlai, facebook, blenderbot2, blenderbot1)

daje 2022. 7. 26. 15:25
728x90
반응형

안녕하십니까 다제입니다. 

 

요즘 open source를 디버깅하는 연습을 진행하고 있습니다. 

 

그중에서도 facebook의 parlai 프로젝트를 선택하여 진행을 하고 있는데요 

 

생각보다 쉽지 않고, 코드가 길고 복잡하게 패키징 되어 있어서 상당히 난항을 겪고 있지만 하나씩 풀어가보는 중입니다. 

 

오늘은 새로운 데이터셋을 추가하는 방법에 대해서 먼저 알아보려고 합니다. 

 

parlai는 다양한 데이터셋을 다운 받을 수 있도록 parlai api를 통해서 세팅을 해 놓았습니다. 

 

그래서 바로 다운을 받아서 실험하고 테스트 해볼 수 있는데요. 제가 테스트 하고 싶은데 데이터가 없는 경우가 생길 수 있습니다. 

 

이럴때 어떻게 추가해야하는지 공부를 진행하였고, 많은 분들이 궁금해하실 수 있기에 이렇게 포스팅을 진행하게 되었습니다. 

 

바로 코드를 먼저 공유드리고, 그 다음 설명 드리도록 하겠습니다. 

 

1. Adding Task 

 

새로운 .py 파일을 생성하고 아래 코드를 기재하고 실행합니다. 

import parlai.core.build_data as build_data
from parlai.core.build_data import DownloadableFile
import os

RESOURCES = [
    DownloadableFile(
        'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json',
        'train-v1.1.json',
        '981b29407e0affa3b1b156f72073b945',
        zipped=False,
    ),
    DownloadableFile(
        'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json',
        'dev-v1.1.json',
        '3e85deb501d4e538b6bc56f786231552',
        zipped=False,
    ),
]


def build(opt):
    dpath = os.path.join(opt['datapath'], 'SQuAD')
    version = None

    if not build_data.built(dpath, version_string=version):
        print('[building data: ' + dpath + ']')
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # Download the data.
        for downloadable_file in RESOURCES[:2]:
            downloadable_file.download_file(dpath)

        # Mark the data as built.
        build_data.mark_done(dpath, version_string=version)

 

그리고 터미널에서 아래 코드를 실행하면 데이터가 다운받아지게 되고, 샘플로 몇개의 데이터를 터미널에 출력이 됩니다. 

parlai display_data --task SQuAD

 

2. Explaination

import parlai.core.build_data as build_data
from parlai.core.build_data import DownloadableFile
import os

RESOURCES = [
    DownloadableFile(
        'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json',
        'train-v1.1.json',
        '981b29407e0affa3b1b156f72073b945',
        zipped=False,
    ),
    DownloadableFile(
        'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json',
        'dev-v1.1.json',
        '3e85deb501d4e538b6bc56f786231552',
        zipped=False,
    ),
]

DownloadableFile은 parlai에서 제공하는 함수인데, 인자 값으로 "다운받을 주소", "버전명", "file checksum(SHA256)", "압축여부" 를 받습니다. 

 

다운 받을 주소는 깃허브에 올려놓은 것도 되고, 구글 드라이브에 올려놓은 것도 됩니다. 만약, 구글 드라이브에 올려 놓은 경우, 추가적인 인자를 넣어주어야 하니, parlai.core.build_data에 가셔서 확인하셔야 합니다. 

 

파일은 .json, .txt 모두 가능하며, 모든 데이터들은 line by line으로 구성되어 있습니다. 

 

file checksum이라는 것을 처음 보시는 분이 계실 수 있는데요. 이것은 다운받은 파일이 내가 원하는 파일과 맞는지 확인하는 과정입니다. 

다운 받아올 파일을 미리 해당 링크에 가져서 SHA256으로 변경된 값을 받아오셔야 합니다. 여기서 주의해야할 점은 그냥 파일을 통채로 넣어서 받은 값을 사용해야한다는 것입니다. 

 

압축을 풀어야하는 경우라면 zipped=True로 변경해주시면 됩니다. 

 

이 외는 코드를 하나하나 실행보시면 쉽게 아실 수 있는 내용이라고 생각이 됩니다. 

그런 다음 데이터를 다운 받고 싶다면, display_data를 활용하여 다운 받을 수 있습니다. 

 

이렇게 복잡한 코드를 처음 보시는 분들은 어려울 수 있어서 이렇게 정리를 진행해보았습니다. 

 

감사합니다. 

728x90
반응형