굴러가는 분석가의 일상

[딥러닝 모델] CNN im2col 이해하기 본문

Computer Vision

[딥러닝 모델] CNN im2col 이해하기

G3LU 2024. 2. 23. 04:39

※ 본 게시물에서는 합성곱 연산을 효율적으로 수행하기 위한 im2col에 대해 알아보도록 하겠습니다. 

💡 im2col 이해하기

CNN은 3차원의 데이터 (주로 이미지)를 학습시켜 특징을 추출하는데 특화되어 있는 신경망입니다. 이에 Spatial 영역의 정보를 잃지 않기 위해 합성곱 연산을 여러 개의 for문과 같이 사용하게 됩니다. 이에 수만 건의 데이터를 처리해야하는 CNN에 적합하지도 않고 효율적이지도 않습니다. 또한 Numpy를 통해서 원소에 접근할 때 for 문을 사용하면, 성능이 떨어지는 단점도 있습니다. 

 

그렇다면, 이러한 문제점을 가진 합성곱 연산은 어떻게 해결해야 할까요? 이에 대한 해결책이 바로 본 게시물의 주제인 im2col입니다. 

 

 

im2col의 동작 원리

 

im2col은 다차원의 데이터를 행렬로 변환하여 행렬 연산을 하도록 해주는 함수이며, 입력 데이터를 필터링(가중치 계산)하기 좋게 전개하는 함수입니다. 위의 그림과 같이 3차원 입력 데이터에 im2col 함수를 적용하면 2차원 행렬로 바꿔주는 것을 보실 수 있습니다. 

 

조금 더 자세하게 알아보기 위해 예시를 들어보도록 하겠습니다. 예를 들어, 3X3 입력 데이터, 2X2 필터, 스트라이드 값이 1인 데이터가 있다고 아래와 같이 가정을 해보겠습니다.

 

im2col를 적용하게 된다면 아래의 사진과 같이 입력 데이터와 필터의 단일 곱셉-누산을 통해 flatten 된 것을 확인할 수 있습니다.

 

이제 flatten된 (4,4) 입력 데이터와 (4,1) 필터의 행렬 내적으로 아래와 같이 표현할 수 있습니다. 

합성곱 신경망(CNN)은 매 Layer에서 4차원 데이터를 처리하기 때문에 2차원 데이터를 N(데이터 개수) x OH(데이터 높이) x OW(데이터 폭) x FN (필터 개수) 와 같은 4차원으로 변형시켜줘야 합니다. 이는 reshape이라고 합니다. 

 

 

조금 더 직관적인 설명이 필요하시다면, 아래의 그림을 참조 부탁드립니다.