Grad-CAM을 활용한 이미지 분류 모델 시각화
이번 포스팅에서는 Grad-CAM을 활용하여 이미지 분류 모델의 시각화를 구현한 사례를 소개합니다. Grad-CAM(Gradient-weighted Class Activation Mapping)은 신경망이 이미지를 분류할 때 주목한 부분을 시각화할 수 있는 방법입니다. 특히 복잡한 CNN(Convolutional Neural Network) 모델에서, 모델이 어떤 부분을 중점적으로 보았는지 확인할 수 있어 모델 해석에 큰 도움이 됩니다.
Grad-CAM을 적용한 전체 코드 개요
아래 코드에서는 ResNet101 모델을 사용하여 이미지 분류를 진행한 후, Grad-CAM을 통해 모델이 주목한 이미지 영역을 시각화하는 과정을 설명합니다. Augmentation이 적용된 데이터를 사용해 Grad-CAM을 출력하고, 각 이미지에 대한 원본 이미지, Grad-CAM 결과 이미지, 오버레이된 이미지를 시각화합니다.
1. 필요한 라이브러리 및 클래스 정의
Grad-CAM을 사용하려면 PyTorch, Albumentations, TorchCAM, 그리고 OpenCV와 같은 라이브러리들이 필요합니다. 먼저 기본적인 데이터 로딩, Augmentation 적용, 모델 선택 등을 위한 클래스를 정의합니다.
Custom Dataset 및 Albumentations Transform
CustomDataset 클래스는 이미지 경로와 타겟 정보를 포함하는 DataFrame을 통해 데이터를 로딩하는 역할을 합니다. 여기서는 Albumentations를 사용하여 다양한 이미지 변환을 적용합니다.
class CustomDataset(Dataset):
def __init__(self, root_dir: str, info_df: pd.DataFrame, transform: Callable, is_inference: bool = False):
self.root_dir = root_dir
self.transform = transform
self.is_inference = is_inference
self.image_paths = info_df['image_path'].tolist()
if not self.is_inference:
self.targets = info_df['target'].tolist()
def __len__(self) -> int:
return len(self.image_paths)
def __getitem__(self, index: int) -> Union[Tuple[torch.Tensor, int], torch.Tensor]:
img_path = os.path.join(self.root_dir, self.image_paths[index])
image = cv2.imread(img_path, cv2.IMREAD_COLOR)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = self.transform(image)
if self.is_inference:
return image
else:
target = self.targets[index]
return image, target
Albumentations Transform
AlbumentationsTransform 클래스는 학습 또는 테스트에 사용할 변환(Transform)을 정의합니다. 학습 시에는 데이터 증강(Augmentation)을 적용하고, 테스트 시에는 리사이즈와 정규화만 수행합니다.
class AlbumentationsTransform:
def __init__(self, is_train: bool = True):
common_transforms = [
A.Resize(224, 224),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
ToTensorV2()
]
if is_train:
self.transform = A.Compose(
[
A.OneOf([A.RandomCrop(64, 64), A.Resize(224, 224)], p=0.5),
A.HorizontalFlip(p=0.5),
A.Rotate(limit=30),
A.RandomBrightnessContrast(brightness_limit=(-0.2, -0.2), contrast_limit=0, p=1),
A.Sharpen(alpha=(0.2, 0.5), lightness=(0.7, 1.3), p=1),
A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.5),
] + common_transforms
)
else:
self.transform = A.Compose(common_transforms)
def __call__(self, image) -> torch.Tensor:
transformed = self.transform(image=image)
return transformed['image']
2. Grad-CAM을 통한 모델 시각화 함수
Grad-CAM을 적용하기 위해 visualize_gradcam 함수를 정의합니다. 이 함수는 Grad-CAM을 사용하여 모델이 주목한 영역을 시각화하고, 결과 이미지를 출력합니다.
def visualize_gradcam(model: torch.nn.Module, device: torch.device, dataloader: DataLoader, target_layer: str, num_images: int, random_indices: list = None, num_classes: int = 500):
cam_extractor = GradCAM(model, target_layer)
model.eval()
total_images = len(dataloader.dataset)
if random_indices is None:
random_indices = random.sample(range(total_images), num_images)
else:
random_indices = random_indices[:num_images]
fig, axes = plt.subplots(num_images, 3, figsize=(18, 6 * num_images))
current_index = 0
selected_indices = set(random_indices)
for inputs in dataloader:
inputs = inputs.to(device)
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
batch_size = inputs.size(0)
for j in range(batch_size):
if current_index in selected_indices:
class_idx = preds[j].item()
if class_idx >= num_classes:
class_idx = random.randint(0, num_classes - 1)
cam = cam_extractor(class_idx=class_idx, scores=outputs)[0]
cam = cam.mean(dim=0).cpu().numpy()
cam = cv2.resize(cam, (inputs[j].shape[2], inputs[j].shape[1]))
cam = (cam - cam.min()) / (cam.max() - cam.min())
cam = np.uint8(255 * cam)
cam = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
cam = cv2.cvtColor(cam, cv2.COLOR_BGR2RGB)
input_image = inputs[j].cpu().numpy().transpose((1, 2, 0))
input_image = (input_image - input_image.min()) / (input_image.max() - input_image.min())
input_image = (input_image * 255).astype(np.uint8)
img_idx = list(selected_indices).index(current_index)
ax_row = axes[img_idx] if num_images > 1 else axes
ax_row[0].imshow(input_image)
ax_row[0].set_title(f"Original Image {current_index}")
ax_row[0].axis('off')
ax_row[1].imshow(cam)
ax_row[1].set_title("Grad-CAM Image")
ax_row[1].axis('off')
overlay = cv2.addWeighted(input_image, 0.5, cam, 0.5, 0)
ax_row[2].imshow(overlay)
ax_row[2].set_title("Overlay Image")
ax_row[2].axis('off')
current_index += 1
if current_index >= total_images:
break
plt.tight_layout()
plt.show()
3. Grad-CAM 시각화 실행
위에서 정의한 visualize_gradcam 함수를 통해 모델이 예측에 사용한 이미지의 주요 영역을 시각화합니다. 레이어를 선택하여 각 단계에서 모델이 주목한 부분을 확인할 수 있습니다.
layers_to_visualize = [
'layer1.0.conv1', # 초기 레이어
'layer2.1.conv2', # 중간 레이어
'layer3.2.conv3', # 더 깊은 레이어
'layer4.2.act2' # 마지막 레이어
]
for layer in layers_to_visualize:
print(f"Visualizing Grad-CAM for layer: {layer}")
visualize_gradcam(model.model, device, test_loader, target_layer=layer, num_images=3)
Insight
Grad-CAM을 활용한 시각화는 CNN 모델이 이미지를 분류할 때 주목하는 영역을 직관적으로 파악할 수 있어 모델 해석에 큰 도움을 준다. 특히, Augmentation이 적용된 데이터를 활용한 Grad-CAM 시각화는 모델이 다양한 데이터 변형에도 어떻게 반응하는지 확인할 수 있어 학습 성능을 개선하는 데 중요한 인사이트를 제공한다.
Grad-CAM을 통해 모델이 예측할 때 집중하는 부분을 명확하게 파악할 수 있었고, 이를 바탕으로 모델의 약점과 개선할 방향을 찾아낼 수 있었다. 예를 들어, 모델이 주목해야 할 핵심 영역이 아닌 다른 부분에 집중한다면, 데이터 증강 기법이나 아키텍처 수정이 필요할 수 있음을 시사한다. 이러한 피드백을 통해 모델을 보다 정확하고 효율적으로 발전시킬 수 있으며, 궁극적으로 더 높은 성능을 가진 딥러닝 모델로 개선할 수 있다.
Grad-CAM은 단순한 시각화 도구를 넘어, 모델 개선과 최적화를 위한 전략적 방향을 제시하는 중요한 도구로 활용 될 수 있다는 점이 인상깊었다.
감사합니다.
'AI > 딥러닝 프로젝트' 카테고리의 다른 글
[02 Object Detection] - MMdetection 설치 및 기본사용법 (6) | 2024.10.26 |
---|---|
[01 - ImageNet Sketch 데이터] - Augmentation [2] (2) | 2024.09.26 |
[01 - ImageNet Sketch 데이터] - Training (resnet101) [1] (16) | 2024.09.24 |
StratifiedKFold란? (1) | 2024.09.16 |
[01 - ImageNet Sketch 데이터] - Torchvision 와 Albumentations Trasnformer 비교 (0) | 2024.09.13 |
댓글