본문 바로가기
AI/딥러닝

[Ch.02] 01 - CNN 모델 만들기 part 1

by AI-BT 2023. 1. 4.
728x90
반응형

챕터 1 에서는 딥러닝이 무엇인가

그리고 신경망 모델을 통해서 간단한 학습 코드를 했습니다.

챕터 2 에서는 유명한 CNN 모델에 대해서 이야기 할까 합니다.

실제 수많은 이미지를 분류하는 모델로써 현재 딥러닝 발전에 많은 영향을 미쳤습니다.

 

1. CNN 은 무엇인가?

만약 4K UHD (840만개 픽셀) 이미지를 인공신경망으로 처리하려면 

층마다 840만개 가중치가 필요합니다.

하지만 컴퓨터 계산하기에는 무리가 있고 이를 해결하고자 합성곱 (Convolution)을 만들게 됩니다.

 

합성곱 (Convolution) 을 간단하게 이야기 하면,

작은 필터를 이용해 이미지로 부터 특징을 추출해내는 방법 입니다.

 

예를 들어 아래 이미지를 학습한다고 했을 때,

이미지는 RGB 3개의 채널로 이루어져 있는데, 

여기서 특징으로 생각나는것은 "빨간색" 부분 이라고 가정하면

840만개 모두 학습하는 것이 아닌 필터로 빨간색 부분에 추출하여 가중치의 수량을 줄여줍니다.

이미지를 전체를 한칸씩 움직이면서 특직을 뽑아낸다고 생각하시면 됩니다.

필터 (혹은 커널)

이것이 CNN 핵심입니다. 즉 합성곱(Convolution) 을 사용하는 신경망 입니다.

학습할 가중치도 줄어들고, 특징도 파악할 수 있으니 더욱 좋은 모델이라고 할 수 있습니다.

유명한 CNN 모델로는 VGG, ResNet, Inception 등이 있습니다.


2. 데이터 확인 및 전처리

챕터 1에서 했던 코드와 같습니다.

다른 것은 데이터셋이 CIFAR10 10개의 클래스로 이루어진 데이터셋 입니다.

 

<데이터 확인>

# 1. 데이터 확인 및 전처리
import matplotlib.pyplot as plt
from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import ToTensor

# 데이터 셋 불러오기
train_data = CIFAR10(root='./', train=True, download=True, transform=ToTensor)
test_data = CIFAR10(root='./', train=False, download=True, transform=ToTensor)

# 데이터셋 확인
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.imshow(train_data.data[i])
plt.show()

출력 결과

데이터셋 확인

 

<데이터 전처리>

학습을 진행하기 위해서 충분한 데이터가 확보하기 위해서,

기존의 이미지를 변형해서 데이터 증강을 하기도 합니다. 

이러한 과정을 데이터 전처리라고 합니다.

예를 들어 좌우를 바꾸던거 위 아래에 조금 검은색으로 칠하던가 이미지 색을 변형하던가 등

원하는 모델 방향으로 데이터 전처리를 진행 합니다.

 

지금 아래 코드에서는 크롭핑(이미지 일부를 도려내는 기법), 좌우 대칭을 쓰겠습니다.

import torchvision.transforms as T
from torchvision.transforms import Compose
from torchvision.transforms import RandomHorizontalFlip, RandomCrop

# 데이터 전처리 함수
transforms = Compose([
    T.ToPILImage(),
    RandomCrop((32,32), padding=4), # 랜덤으로 이미지 일부 제거 후 패딩
    RandomHorizontalFlip(p=0.5), # y축 기준으로 대칭
])

# 데이터 셋 불러오기
train_data = CIFAR10(root='./', train=True, download=True, transform=transforms)
test_data = CIFAR10(root='./', train=False, download=True, transform=transforms)

# 데이터셋 확인
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.imshow(transforms(train_data.data[i]))
plt.show()

출력 결과

일부 이미지가 도려지고, 좌우가 바뀐 모습을 볼 수 있습니다.


3. 이미지 정규화

컬러 이미지 1 장을 데이터로 표현하려면 R,G,B 을 담담하는 3장의 이미지 데이터가 필요 합니다.

이미지별로 각각의 특징적인 색이 있기 때문에 빨강, 녹색, 파랑 한쪽에 색상으로 치우쳐 있을 확률이 높습니다.

데이터의 분포가 한쪽으로 너무 치우쳐 있으면 학습에 안 좋은 영향을 줍니다.

그래서 학습 전에 데이터의 편향을 계산해서 정규분포를 따르도록 하는게 좋습니다.

평군이 0, 표준편차가 1인 정규분포를 따르도록 하는 것을 정규화 라고 합니다.

from torchvision.transforms import Normalize

# <데이터 정규화>
transforms = Compose([
    T.ToPILImage(),
    RandomCrop((32,32), padding=4), # 랜덤으로 이미지 일부 제거 후 패딩
    RandomHorizontalFlip(p=0.5), # y축 기준으로 대칭
    T.ToTensor(),

    Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
    T.ToPILImage()
])

# 데이터 셋 불러오기
train_data = CIFAR10(root='./', train=True, download=True, transform=transforms)
test_data = CIFAR10(root='./', train=False, download=True, transform=transforms)

# 데이터셋 확인
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.imshow(transforms(train_data.data[i]))
plt.show()

출력 결과

정규화 이미지

각각의 이미지를 정규화해야 한쪽으로 치우치는 학습을 하지 않도록

정규화 작업을 했습니다.

 

다음 포스팅에서 이제 CNN 모델을 가지고 학습을 진행 하도록 하겠습니다.

728x90
반응형

댓글