nn.Sequential - DonghoonPark12/ssd.pytorch GitHub Wiki

nn.Sequential은 여러 모듈을 담는 컨테이너로, 입력을 등록된 순서대로 각 모듈에 차례로 통과시킨다. 순차적으로 처리해야 하는 모듈들이 있을 때 아래와 같이 사용한다.

import torch.nn as nn

# Example layer sizes, used only to make this snippet runnable;
# substitute the dimensions of your own data.
D_in, H, D_out = 784, 100, 10

# nn.Sequential feeds its input through the modules in the order listed:
# Linear(D_in -> H), then ReLU, then Linear(H -> D_out).
# `import torch.nn as nn` binds only the name `nn` (not `torch`), so the
# layers are referenced through `nn`.
model = nn.Sequential(
    nn.Linear(D_in, H),
    nn.ReLU(),
    nn.Linear(H, D_out),
)

아래와 같이 여러 레이어를 conv_block1, conv_block2처럼 하나의 블록으로 묶어 사용하기도 한다.

class MyCNNClassifier(nn.Module):
    """Small CNN classifier: two Conv-BN-ReLU blocks followed by an MLP head.

    Args:
        in_c: number of channels of the input image.
        n_classes: number of output scores produced by the final Linear layer.
        img_size: height and width of the (square) input image. Defaults to
            28, the size this example was written for.
    """

    def __init__(self, in_c, n_classes, img_size=28):
        super().__init__()
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_c, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        # kernel_size=3 / stride=1 / padding=1 keeps H and W unchanged, so the
        # flattened feature size is (output channels of conv_block2) * H * W.
        # conv_block2 emits 64 channels, so the input size must use 64, not 32.
        self.decoder = nn.Sequential(
            nn.Linear(64 * img_size * img_size, 1024),
            nn.Sigmoid(),
            nn.Linear(1024, n_classes)
        )

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, n_classes).

        x must be (batch, in_c, img_size, img_size) so that the flattened
        features match the decoder's first Linear layer.
        """
        x = self.conv_block1(x)
        x = self.conv_block2(x)

        x = x.view(x.size(0), -1)  # flatten to (batch, 64 * img_size * img_size)

        x = self.decoder(x)

        return x

conv_block1과 conv_block2는 모두 Conv2d → BatchNorm2d → ReLU로 구성되어 구조가 매우 유사하다는 점에 주목하자. 아래와 같은 헬퍼 함수를 정의해 두면

def conv_block(in_f, out_f, *args, **kwargs):
    """Build a Conv2d -> BatchNorm2d -> ReLU stage as one nn.Sequential.

    Args:
        in_f: input channels of the convolution.
        out_f: output channels of the convolution (also the BatchNorm width).
        *args, **kwargs: forwarded unchanged to nn.Conv2d after the channel
            counts (e.g. kernel_size, stride, padding).
    """
    layers = [
        nn.Conv2d(in_f, out_f, *args, **kwargs),
        nn.BatchNorm2d(out_f),  # normalizes the out_f conv feature maps
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)

아래처럼 더 간결하게 표현할 수도 있다.

class MyCNNClassifier(nn.Module):
    """CNN classifier built from two ``conv_block`` stages and an MLP head.

    Args:
        in_c: number of channels of the input image.
        n_classes: number of output scores produced by the final Linear layer.
        img_size: height and width of the (square) input image. Defaults to
            28, the size this example was written for.
    """

    def __init__(self, in_c, n_classes, img_size=28):
        super().__init__()
        self.conv_block1 = conv_block(in_c, 32, kernel_size=3, padding=1)

        self.conv_block2 = conv_block(32, 64, kernel_size=3, padding=1)

        # kernel_size=3 with padding=1 (Conv2d's stride defaults to 1) keeps
        # H and W unchanged, so the flattened size is 64 * img_size * img_size;
        # 64 is conv_block2's output channel count.
        self.decoder = nn.Sequential(
            nn.Linear(64 * img_size * img_size, 1024),
            nn.Sigmoid(),
            nn.Linear(1024, n_classes)
        )

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, n_classes).

        x must be (batch, in_c, img_size, img_size) so that the flattened
        features match the decoder's first Linear layer.
        """
        x = self.conv_block1(x)
        x = self.conv_block2(x)

        x = x.view(x.size(0), -1)  # flatten to (batch, 64 * img_size * img_size)

        x = self.decoder(x)

        return x