(pytorch) data class - beyondnlp/nlp GitHub Wiki

  • Inherit from the Dataset class (NSMC data)

    • Implement __len__ and __getitem__.
    • __len__ returns the total number of documents.
    • __getitem__ provides index access, i.e. the [] operator.
  • NSMC sample

#id     document        label
9976970 아 더빙.. 진짜 짜증나네요 목소리        0
3819312 흠...포스터보고 초딩영화줄....오버연기조차 가볍지 않구나        1
10265843        너무재밓었다그래서보는것을추천한다      0
9045019 교도소 이야기구먼 ..솔직히 재미는 없다..평점 조정       0
6483659 사이몬페그의 익살스런 연기가 돋보였던 영화!스파이더맨에서 늙어보이기만 했던 커스틴 던스트가 너무나도 이뻐보였다 1
5403919 막 걸음마 뗀 3세부터 초등학교 1학년생인 8살용영화.ㅋㅋㅋ...별반개도 아까움.     0
7797314 원작의 긴장감을 제대로 살려내지못했다.  0
9443947 별 반개도 아깝다 욕나온다 이응경 길용우 연기생활이몇년인지..정말 발로해도 그것보단 낫겟다 납치.감금만반복반복..이드라마는 가족도없다 연기못하는사람만모엿네     0
7156791 액션이 없는데도 재미 있는 몇안되는 영화 1
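The file is plain tab-separated text whose header line starts with #. Before handing parsing over to the Dataset class below, the per-line logic can be sketched in isolation (parse_nsmc_line is a hypothetical helper name, not part of the source code):

```python
def parse_nsmc_line(line):
    """Split one NSMC line into (id, document, label); return None for the # header."""
    if line.startswith("#"):
        return None
    doc_id, document, label = line.rstrip("\n").split("\t")
    return doc_id, document, int(label)
```

The same skip-header / split-on-tab / cast-label steps appear inside `__init__` below.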

  • Inheriting from the Dataset class
class MovieDataSet(torch.utils.data.Dataset):
  • vocab : vocabulary object that converts text into token ids (via encode_as_ids)
  • infile : path to the train/test file
    def __init__(self, vocab, infile):
        self.vocab = vocab
        self.labels = []
        self.sentences = []

        with open(infile, "r") as f:
            for line in f:
                if line.startswith("#"):  # skip the header line
                    continue
                term = line.rstrip().split("\t")
                ids = vocab.encode_as_ids(term[1])  # document column -> token ids
                label = [int(term[2])]              # label column

                self.sentences.append(ids)
                self.labels.append(label)
  • __len__ : returns the total number of examples
    def __len__(self):
        assert len(self.labels) == len(self.sentences)
        return len(self.labels)
  • __getitem__ : provides index access
  • implements the "[" "]" operator
    def __getitem__(self, item):
        label = torch.tensor(self.labels[item])
        sent = torch.tensor(self.sentences[item])
        return (label, sent)
  • movie_collate_fn : assembles individual examples into a batch (passed to the DataLoader as collate_fn)
def movie_collate_fn(inputs):
    # each element of inputs is the (label, sentence) pair returned by __getitem__
    labels, enc_inputs = list(zip(*inputs))

    # pad variable-length sentences to the longest one in the batch
    enc_inputs = torch.nn.utils.rnn.pad_sequence(enc_inputs, batch_first=True, padding_value=0)

    batch = [
        torch.stack(labels, dim=0),
        enc_inputs,
    ]
    return batch
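pad_sequence with batch_first=True and padding_value=0 right-pads every sequence to the length of the longest one in the batch. The same behavior in plain Python, as a sketch of what the call does (not the torch implementation):

```python
def pad_batch(sequences, padding_value=0):
    """Right-pad a list of variable-length lists to the length of the longest one."""
    max_len = max(len(seq) for seq in sequences)
    return [list(seq) + [padding_value] * (max_len - len(seq)) for seq in sequences]
```

So `pad_batch([[1, 2, 3], [4]])` yields `[[1, 2, 3], [4, 0, 0]]` — the reason every sentence in a batch ends up the same length even though reviews vary in length.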
  • Instantiate the datasets
train_dataset = MovieDataSet( vocab, train_file )
test_dataset  = MovieDataSet( vocab, test_file )
  • Attach them to a DataLoader
train_loader  = \
torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=movie_collate_fn)

test_loader   = \
torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=movie_collate_fn)
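What the DataLoader does with batch_size and collate_fn can be mimicked without torch: slice the dataset into chunks of batch_size and hand each chunk of (label, sentence) pairs to the collate function. simple_loader below is a hypothetical stand-in to illustrate the mechanism, not the torch API (it also skips shuffling):

```python
def simple_loader(dataset, batch_size, collate_fn):
    """Yield collated batches of consecutive examples (no shuffling)."""
    for start in range(0, len(dataset), batch_size):
        chunk = [dataset[i] for i in range(start, min(start + batch_size, len(dataset)))]
        yield collate_fn(chunk)
```

The real DataLoader adds shuffling, multi-process loading, and pinned memory on top of this basic chunk-and-collate loop.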
  • Using the train_loader
def train_epoch(config, epoch, model, criterion, optimizer, train_loader):

    for i, value in enumerate(train_loader):
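One way the body of this loop might look, assuming the model maps the padded enc_inputs batch to per-class logits — the forward signature and the unused config parameter are assumptions here, since the source leaves the body out:

```python
import torch

def train_epoch(config, epoch, model, criterion, optimizer, train_loader):
    model.train()
    total_loss = 0.0
    for i, (labels, enc_inputs) in enumerate(train_loader):
        optimizer.zero_grad()
        logits = model(enc_inputs)                    # assumed: (batch, seq) -> (batch, n_classes)
        loss = criterion(logits, labels.squeeze(-1))  # labels arrive with shape (batch, 1)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / (i + 1)
```

Because movie_collate_fn returns [stacked labels, padded enc_inputs], each loader item unpacks directly into (labels, enc_inputs); squeeze(-1) flattens the (batch, 1) label tensor into the (batch,) shape CrossEntropyLoss expects.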