1. Agumentation
이미지 분류 모델을 학습할 때, 데이터가 많지 않거나 특정한 패턴에 대해 과적합이 발생할 우려가 있다.
이런 상황을 개선하기 위해 데이터 증강(Augmentation) 기법을 사용하는데, 이 과정에서 이미지의 크기나 시각적 왜곡을 가하여 다양한 형태의 학습 데이터를 만들 수 있다.
class AlbumentationsTransform:
def __init__(self, is_train: bool = True):
# 공통 변환 설정
common_transforms = [
A.Resize(224, 224), # 224x224 resize
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 정규화
ToTensorV2() # pyTorch 텐서 변환
]
if is_train:
self.transform = A.Compose(
[
A.HorizontalFlip(p=0.5), # 90% 확률로 이미지를 수평 뒤집기
A.Rotate(limit=30), # 최대 30도 회전
A.RandomBrightnessContrast(brightness_limit=(-0.2, -0.2), contrast_limit=0, p=1), # 20% 어둡게
A.GaussianBlur(blur_limit=(3, 5), p=0.6), # 약간의 블러 추가
A.CoarseDropout(max_holes=8, max_height=16, max_width=16, p=0.5),
A.GridDistortion(always_apply=False, p=1, num_steps=1, distort_limit=(-0.03, 0.05), interpolation=2, border_mode=0, value=(255, 255, 255), mask_value=None)
] + common_transforms
)
else:
self.transform = A.Compose(common_transforms)
def __call__(self, image) -> torch.Tensor:
# 이미지가 NumPy 배열인지 확인
if not isinstance(image, np.ndarray):
raise TypeError("Image should be a NumPy array (OpenCV format).")
# 이미지에 변환 적용 및 결과 반환
transformed = self.transform(image=image)
return transformed['image'] # 변환된 이미지의 텐서를 반환
ResNet101 모델 사용
ResNet101은 101개의 레이어를 가진 깊은 신경망으로, 이미지 분류 작업에서 뛰어난 성능을 보여주는 모델 중 하나이다. Residual Connection을 통해서 깊은 네트워크에서 발생할 수 있는 Gradient Vanishing 문제를 효과적으로 해결할 수 있으며, 여러 컴퓨터 비전 작업에서 우수한 성능을 보이고 있다.
Albumentations 라이브러리를 사용한 데이터 증강
이미지 데이터 증강을 위한 여러 라이브러리가 있지만, 그중 Albumentations는 간결한 코드와 효율적인 이미지 변환 기능을 제공하여 매우 유용했다. 특히, Albumentations를 사용하면 이미지를 Resize 할 때 Torchvision 대비 시각적으로 훨씬 자연스러운 결과물을 얻을 수 있다.
학습 시 데이터 증강 기법
- HorizontalFlip(p=0.5): 50% 확률로 이미지를 수평으로 뒤집습니다. 이는 이미지의 방향성에 의존하는 모델을 방지
- Rotate(limit=30): 이미지를 최대 30도까지 회전시켜, 다양한 각도의 이미지를 학습
- RandomBrightnessContrast: 이미지의 밝기를 20% 낮추는 방식으로 자연스러운 밝기 변화
- GaussianBlur: 약간의 흐림 효과를 주어 이미지를 더욱 다양하게 변형합니다.
- CoarseDropout: 이미지의 일부분을 무작위로 제거하여 모델이 이미지의 일부 정보가 사라져도 잘 동작하게 만듬.
- GridDistortion: 그리드 왜곡을 적용해 이미지의 픽셀을 약간 왜곡하여 다양한 시각적 특성을 학습
Insight
A.RandomBrightnessContrast(brightness_limit=(-0.2, -0.2), contrast_limit=0, p=1), # 20% 어둡게
데이터 증강(Augmentation)에서 가장 중요한 부분은 위의 코드이다.
동일한 학습 환경에서 RandomBrightnessContrast 변환을 적용한 경우와 적용하지 않은 경우를 비교했을 때,
10 epoch 기준으로 **정확도(accuracy)**에서 약 3% 이상의 차이가 발생했다.
이 결과는 RandomBrightnessContrast가 모델의 일반화 성능에 큰 영향을 미친다는 것을 보여준다.
2. Loss Class
최근 ImageNet Sketch 데이터셋을 활용한 모델 학습에서, Cross Entropy Loss를 대체해 Focal Loss를 적용한 결과 성능이 크게 개선되었다. 특히, 일반적인 Cross Entropy Loss 대비 Focal Loss를 적용했을 때 약 5% 이상의 성능 차이가 발생하며, 스케치 이미지의 특징을 더욱 잘 포착하는 것을 확인 했다.
Focal Loss는 일반적으로 불균형한 데이터셋에서 자주 사용되는 손실 함수 이다. Cross Entropy Loss는 모든 클래스에 대해 동일한 가중치를 부여하지만, Focal Loss는 학습이 어려운 클래스 또는 잘못 분류된 클래스에 더 큰 가중치를 부여하여 이러한 문제를 완화 한다.
Focal Loss는 수식
- 는 클래스별 가중치를 의미하며, 학습이 어려운 클래스에 더 큰 가중치를 줄 수 있다.
- gamma는 어려운 샘플에 대해 가중치를 높여주는 하이퍼파라미터로, 1−pt가 작을수록 즉, 모델이 샘플을 잘못 예측할수록 큰 영향을 준다.
class FocalLoss(nn.Module):
def __init__(self, alpha, gamma, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
# Compute the cross entropy loss
ce_loss = F.cross_entropy(outputs, targets, reduction='none')
# softmax
probs = torch.exp(-ce_loss)
# Apply the focal loss modification
focal_loss = self.alpha * (1 - probs) ** self.gamma * ce_loss
# Return loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
위 코드를 통해 Focal Loss를 적용한 결과, reduction='mean' 방식을 사용했을 때 일반 Cross Entropy Loss보다 약 5% 이상의 성능 차이를 확인할 수 있었습니다. 특히, 스케치 이미지의 특징을 더 잘 잡아내는 결과를 얻었으며, 이는 Focal Loss가 잘못 예측된 샘플에 대해 더 큰 가중치를 부여하기 때문에 발생한 결과로 판단됩니다.
반면, reduction='sum' 방식을 사용한 경우에는 평균을 취하는 방식보다 덜 효과적이었으며, 이는 각 샘플의 손실 값이 큰 편차를 보이는 경우 학습에 부정적인 영향을 미칠 수 있기 때문이다.
Insight
ImageNet Sketch와 같은 데이터셋에서는 Cross Entropy Loss보다는 Focal Loss를 사용하는 것이 더욱 유리한 것으로 나타났다. 특히, 모델이 어려운 샘플을 학습할 때 더 효과적으로 학습할 수 있도록 돕기 때문에 스케치 이미지처럼 복잡한 데이터셋에서 유의미한 성능 향상을 기대할 수 있다.
3. model 선언
ResNet101을 활용한 스케치 이미지 학습 모델 선정
최근 ImageNet Sketch 데이터셋을 학습하는 과정에서, 다양한 모델 중 ResNet101을 선택하여 성능을 최적화했습니다. 스케치 이미지와 같이 선과 모양이 중요한 데이터셋에서는 CNN 계열의 모델이 효과적이라는 점을 고려해, ResNet101을 학습 모델로 사용했다.
# 학습에 사용할 Model을 선언.
model_selector = ModelSelector(
model_type='timm',
num_classes=num_classes,
model_name='resnet101',
pretrained=True
)
model = model_selector.get_model()
model.to(device)
위 코드에서는 ResNet101을 선택하였으며, timm 라이브러리를 통해 모델을 불러왔다. 사전 학습된 모델(pretrained=True)을 사용하여 초기 가중치를 설정함으로써, 스케치 이미지와 같은 특이한 데이터셋에도 더 나은 성능을 기대할 수 있다.
CNN 계열 모델을 선택한 이유
스케치 이미지는 선(Line)과 모양(Shape)이 주요한 특징이다. 이러한 특징은 컨볼루션 레이어에서 지역적 패턴(Local Pattern)을 잘 포착하는 CNN 계열 모델이 매우 적합하다. CNN(Convolutional Neural Network)은 이미지 내에서 공간적 관계를 유지하면서도 다양한 필터를 통해 이미지의 엣지나 선 같은 저수준의 특징을 잘 잡아내기 때문에, 스케치 이미지의 세부 구조를 학습하기에 최적이다.
ResNet101 선택 이유
101개의 레이어로 구성된 이 모델은 더 깊은 레이어를 통해 복잡한 패턴을 더욱 잘 학습할 수 있는 능력이 있다. 이 과정에서 Residual Block이 사용되는데, 이는 모델이 깊어질수록 발생할 수 있는 기울기 소실 문제(Vanishing Gradient Problem)를 완화시켜준다. 이러한 특성 덕분에, 깊은 레이어에서도 성능 저하 없이 복잡한 특징을 학습할 수 있게 됩니다.
특히, 다음과 같은 이유로 ResNet101을 선택했다.
- 더 깊은 레이어에서의 학습 가능성: ResNet의 Residual Connection은 학습이 어려운 스케치 이미지와 같은 데이터셋에서 더 나은 결과를 제공
- 다른 ResNet 모델과의 비교: ResNet50이나 ResNet18 등의 얕은 모델에 비해, ResNet101은 더 복잡하고 다양한 패턴을 효과적으로 학습합니다. 다른 깊이의 모델들과 비교했을 때, ResNet101이 특히 스케치 이미지에서 더 나은 성능을 보여줬다.
- 사전 학습된 모델 사용: 사전 학습된 ResNet101을 사용함으로써, 일반적인 이미지에 대한 기본적인 시각적 특징을 이미 학습한 상태에서 스케치 이미지로 추가 학습을 할 수 있어 더 빠른 학습과 높은 성능을 기대할 수 있다.
4. optimizer
optimizer = optim.Adam(
model.parameters(),
lr=0.001
)
Adam Optimizer를 선택한 이유
Adam(Adaptive Moment Estimation)은 1차 모멘트(Gradient's Mean)와 2차 모멘트(Gradient's Variance*를 동시에 추적하여, 각 매개변수마다 적응형 학습률을 적용합니다. 이를 통해 학습이 느린 부분은 학습률을 높이고, 빠른 부분은 학습률을 줄여서 전반적인 학습 속도와 안정성을 크게 향상시킵니다.
Adam을 선택한 이유
- 빠른 수렴 속도: Adam은 SGD보다 빠르게 최적값에 수렴하는 경향이 있다. 이는 복잡한 이미지 패턴을 학습하는 과정에서 매우 중요한데, 특히 스케치 이미지와 같은 데이터셋에서는 빠르게 패턴을 학습하고 반복해서 미세 조정하는 것이 필요
- 적응형 학습률: Adam은 매개변수별로 서로 다른 학습률을 적용하므로, 모든 파라미터가 동일한 속도로 학습되는 것을 방지할 수 있다. 이는 학습 중 특정 파라미터가 너무 빠르게 또는 너무 느리게 업데이트되는 문제를 해결할 수 있어 효율적인 학습을 가능하게 함.
- 모멘트 추정: Adam은 모멘트 기반의 접근을 사용하여 기울기의 변화 패턴을 반영한다. 이를 통해 스케치 이미지에서 발생할 수 있는 노이즈나 복잡한 선 패턴에 대한 기울기 변화를 보다 정교하게 학습할 수 있다.
5. 스케줄러
scheduler_gamma = 0.1 # 학습률을 현재의 10%로 감소
steps_per_epoch = len(train_loader) # 한 epoch당 step 수 계산
# 10 epoch 마다 학습률을 감소시키는 스케줄러 선언
epochs_per_lr_decay = 10
scheduler_step_size = steps_per_epoch * epochs_per_lr_decay
scheduler = optim.lr_scheduler.StepLR(
optimizer,
step_size=scheduler_step_size,
gamma=scheduler_gamma
)
위의 코드는 학습률(lr)을 매 10 epochs마다 감소시키도록 설정되어 있습니다.
구체적으로, 학습률은 학습이 진행되면서 현재 학습률의 10%**로 감소된다.
이를 StepLR 스케줄러가 관리하는데, 학습이 진행됨에 따라 너무 빠르게 이동하지 않도록 학습 속도를 조절하는 것이다.
쉽게 설명하면,
이미지 수 = 15,000, Batch_size = 64 일때, 1 epoch에 약 234번의 배치가 처리된다.
step_size = 234 * 10 = 2340 이 되고, 10 epoch 마다 10% 감소된다.
학습률이 10% 감소한다는 표현은 현재 학습률의 10%가 남는다는 뜻이다.
즉, 학습률을 10%로 줄인다는 것은 **현재 학습률의 10%**가 남게 되는 것이므로, 학습률이 90% 감소한다고 생각하면 된다.
학습률이 0.001 × 0.1 = 0.0001로 줄어드는 것을 의미한다.
Insight
처음에는 큰 학습률로 빠르게 모델을 학습시키는 것이 효과적이지만, 학습이 진행됨에 따라 더 작은 학습률로 미세 조정하는 것이 필요하다. 학습률을 일정 주기마다 현재의 10%로 줄이는 것은 학습 후반부에 모델이 더 세밀하게 학습할 수 있도록 도와준다.
- 초기 학습: 높은 학습률로 빠르게 큰 변화들을 학습.
- 후반 학습: 작은 학습률로 세부적인 조정, 작은 패턴이나 복잡한 구조를 학습.
다음은 다른 학습 기법에 대해서 블로그 하겠습니다.
감사합니다.
'AI > 딥러닝 프로젝트' 카테고리의 다른 글
[02 Object Detection] - MMdetection 설치 및 기본사용법 (6) | 2024.10.26 |
---|---|
[01 - ImageNet Sketch 데이터] - Augmentation [2] (2) | 2024.09.26 |
[01 - ImageNet Sketch 데이터] - Grad Cam [3] (2) | 2024.09.26 |
StratifiedKFold란? (1) | 2024.09.16 |
[01 - ImageNet Sketch 데이터] - Torchvision 와 Albumentations Trasnformer 비교 (0) | 2024.09.13 |
댓글