이전 글에서 LSTM 이론적인 부분을 설명했습니다.
LSTM 이란? [2-1편]
기본적인 순환신경망인 Vanilla RNN에 대해서 1편에서 설명 했습니다.하지만 현재는 Vanilla RNN 이 사용되고 있지 않습니다.어떠한 이유때문에 사용되고 있지 않는지, 그리고 더 발전된 RNN에 대해서
ai-bt.tistory.com
이번에는 코드를 분석해서 간단한 학습까지 진행하겠습니다.
1. LSTM 셀 코드
# 1. LSTM 셀 직접 구현
class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
# 입력 x_t와 hidden h_t-1를 concat해서 Linear 연산 (가중치 통합)
self.W_f = nn.Linear(input_size + hidden_size, hidden_size) # Forget Gate
self.W_i = nn.Linear(input_size + hidden_size, hidden_size) # Input Gate
self.W_c = nn.Linear(input_size + hidden_size, hidden_size) # Cell State
self.W_o = nn.Linear(input_size + hidden_size, hidden_size) # Output Gate
def forward(self, x_t, h_prev, c_prev):
# 1. input과 h_prev를 concat
combined = torch.cat((x_t, h_prev), dim=1) # (batch_size, input_size + hidden_size)
# 2. 각 gate 계산
f_t = torch.sigmoid(self.W_f(combined)) # Forget gate
i_t = torch.sigmoid(self.W_i(combined)) # Input gate
g_t = torch.tanh(self.W_c(combined)) # Cell state
o_t = torch.sigmoid(self.W_o(combined)) # Output gate
# 3. cell state 업데이트
c_t = f_t * c_prev + i_t * g_t
# 4. hidden state 업데이트
h_t = o_t * torch.tanh(c_t)
return h_t, c_t
1. combined = torch.cat((x_t, h_prev), dim=1)
무엇을 하는가?
- 현재 입력값 x_t와, 직전 timestep의 hidden state h_prev를 붙인다(concatenation).
- 왜? ➔ Gate 계산할 때 입력 정보(x)와 기존 상태 정보(h)를 같이 사용해야 하니까.
(한 문장 요약)
지금 timestep의 정보(x)와 과거 기억(h)을 합쳐서, "어떻게 업데이트할지" 계산 준비한다.
2. 각 Gate 계산
이제 붙인 combined을 각 Gate로 보낸다. (LSTM은 4개의 Gate가 있다.)
(1) f_t = torch.sigmoid(self.W_f(combined)) — Forget Gate
무엇을 하는가?
- "이전 기억(c_prev) 중에서 무엇을 버리고 무엇을 남길까?"를 결정하는 Gate.
- Sigmoid 함수를 거쳐서 0~1 사이 값을 만든다.
- 0이면 완전히 잊고
- 1이면 완전히 기억 유지
한 문장 요약
이전 cell memory 중 어떤 걸 "Forget(잊을지)" 결정한다.
(2) i_t = torch.sigmoid(self.W_i(combined)) — Input Gate
무엇을 하는가?
- "지금 timestep 입력 x_t를 가지고, 뭘 새로 기억할까?"를 결정.
- 역시 Sigmoid로 0~1 스케일로 만든다.
- 0이면 입력 무시
- 1이면 입력 강하게 반영
한 문장 요약
지금 timestep 입력을 보고 "무엇을 새로 기억할지" 결정한다.
(3) g_t = torch.tanh(self.W_c(combined)) — Candidate Cell State
무엇을 하는가?
- 새로 추가될 정보 후보를 만든다.
- Tanh 함수를 거쳐서 -1~1 값으로 나온다.
- (완전히 부정하거나 완전히 긍정할 수 있음)
한 문장 요약
새롭게 기억할 후보를 만든다. (기억할 "내용" 자체)
(4) o_t = torch.sigmoid(self.W_o(combined)) — Output Gate
(무엇을 하는가?)
- "cell state를 외부로 얼마나 드러낼까?"를 결정하는 Gate.
- Hidden state(h_t)를 계산할 때 쓰인다.
(한 문장 요약)
cell memory를 얼마나 보여줄지 결정하는 게이트.
3. c_t = f_t * c_prev + i_t * g_t
무엇을 하는가?
- cell state 업데이트하는 공식.
수식으로 보면:
\[
c_t = f_t \times c_{t-1} + i_t \times \tilde{c}_t
\]
해석하면:
\[
f_t \times c_{t-1} \quad \rightarrow \quad \text{"과거 기억 중 잊지 않고 남긴 것"}
\]
\[
i_t \times \tilde{c}_t \quad \rightarrow \quad \text{"새로 추가할 기억"}
\]
이 둘을 더해서 최신 cell state (c_t)를 만든다!
한 문장 요약
기존 기억은 일부 버리고, 새 기억을 추가해서 새로운 cell memory를 만든다.
2. Full LSTM 모델
# 2. Full LSTM 모델 (Seq2Seq 스타일)
class MyLSTM(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MyLSTM, self).__init__()
self.hidden_size = hidden_size
self.lstm_cell = LSTMCell(input_size, hidden_size)
self.fc = nn.Linear(hidden_size, output_size) # hidden state를 output으로 변환
def forward(self, inputs):
batch_size, seq_len, _ = inputs.size()
h_t = torch.zeros(batch_size, self.hidden_size, device=inputs.device)
c_t = torch.zeros(batch_size, self.hidden_size, device=inputs.device)
outputs = []
for t in range(seq_len):
x_t = inputs[:, t, :]
h_t, c_t = self.lstm_cell(x_t, h_t, c_t) # timestep마다 업데이트
output = self.fc(h_t) # h_t를 output으로 변환
outputs.append(output.unsqueeze(1)) # (batch, 1, output_size)
outputs = torch.cat(outputs, dim=1) # (batch, seq_len, output_size)
return outputs
MyLSTM 클래스는
➡️ 우리가 직접 만든 LSTMCell을 가지고
➡️ 시퀀스 전체를 처리하기 위해 만든 "하나의 묶음 모델" 이다.
쉽게 말하면,
LSTMCell | 시퀀스의 한 timestep만 처리하는 셀 (xₜ, hₜ₋₁, cₜ₋₁ 받아서 hₜ, cₜ 계산) |
MyLSTM | 여러 timestep 전체를 순차적으로 처리하는 모델 (t=0,1,2,3... 순서대로 hₜ, cₜ 업데이트) |
1) __init__ (초기화)
def __init__(self, input_size, hidden_size, output_size):
super(MyLSTM, self).__init__()
self.hidden_size = hidden_size
self.lstm_cell = LSTMCell(input_size, hidden_size)
self.fc = nn.Linear(hidden_size, output_size) # hidden state를 output으로 변환
- input_size: 입력 차원 (ex: 단어 하나당 100차원 벡터면 100)
- hidden_size: hidden state 차원 (메모리 용량)
- output_size: 최종 출력 차원
단어 하나가 "100차원 벡터"라는 뜻
✨ 단어는 컴퓨터 입장에서는 "숫자"로 표현돼야 해
우리가 보는 "apple" 같은 단어는 컴퓨터가 바로 이해할 수 없어.
그래서 단어를 숫자 벡터로 변환해서 컴퓨터가 다룰 수 있게 만들어.
✨ 예를 들면
"apple" 이라는 단어를:
[0.25, 0.01, 0.98, 0.42, ..., 0.03] (총 100개 숫자)
처럼 100개의 숫자로 표현하는 거야.
이런 걸 "100차원 벡터" 라고 해.
- "100차원" = 숫자가 100개 있음
- "벡터" = 이 숫자들이 하나로 묶여 있음
✨ 왜 이렇게 하냐?
- 단어를 숫자 벡터로 바꿔야
- LSTM, Transformer 같은 수학 계산 모델이 다룰 수 있어.
(단순히 'apple'이라는 문자열로는 계산을 못하거든.)
궁금한 부분...
"단어를 어떤 숫자들로 바꿀까?"
"차원 수가 크면 무조건 좋다고 할 수 있을까??"
✅ self.lstm_cell:
- 직접 만든 LSTMCell을 하나 선언한다.
- (한 timestep씩 업데이트하는 작은 셀)
✅ self.fc:
- hidden state를 output으로 변환하는 Linear 레이어.
2) forward (실제 데이터 들어왔을 때 흐름)
def forward(self, inputs):
batch_size, seq_len, _ = inputs.size()
h_t = torch.zeros(batch_size, self.hidden_size, device=inputs.device)
c_t = torch.zeros(batch_size, self.hidden_size, device=inputs.device)
✅ 입력 데이터는 (batch_size, seq_len, input_size) 형태.
✅ 초기 hidden state(h₀), cell state(c₀)를 전부 0으로 초기화한다. (시작은 항상 아무 기억 없는 상태)
3) 시퀀스 처리 (for loop)
outputs = []
for t in range(seq_len):
x_t = inputs[:, t, :]
h_t, c_t = self.lstm_cell(x_t, h_t, c_t) # timestep마다 업데이트
output = self.fc(h_t) # h_t를 output으로 변환
outputs.append(output.unsqueeze(1)) # (batch, 1, output_size)
✅ 시퀀스 길이만큼 매 timestep마다
- 입력값 x_t를 꺼낸다.
- 이전의 hidden, cell state(hₜ₋₁, cₜ₋₁)를 넣고 LSTMCell을 돌린다.
- 나온 hidden state를 fully connected layer를 통과시켜 output으로 변환한다.
- 모든 output을 리스트에 모은다.
4) 최종 Output 만들기
outputs = torch.cat(outputs, dim=1) # (batch, seq_len, output_size)
return outputs
✅ 시간축으로 쭉 모은 후 반환한다.
3. 학습 및 예측 코드
1) 간단한 데이터 생성 (generate_data)
def generate_data(seq_length=5, num_samples=1000):
X = []
Y = []
for _ in range(num_samples):
start = torch.randint(0, 10, (1,)).item()
seq = torch.arange(start, start + seq_length)
X.append(seq[:-1].unsqueeze(-1).float()) # (n-1 steps)
Y.append(seq[1:].unsqueeze(-1).float()) # (next n-1 steps)
return torch.stack(X), torch.stack(Y)
설명
- 랜덤한 시작 숫자(start)를 뽑는다.
- start부터 5개 연속된 숫자를 만든다. (예: 2 → [2, 3, 4, 5, 6])
- 입력(X)과 정답(Y)을 만드는 방법:
- X: [2, 3, 4, 5] (마지막 숫자 빼고)
- Y: [3, 4, 5, 6] (첫 숫자 빼고)
즉,
"현재 숫자"를 입력하면 "다음 숫자"를 예측하는 문제를 만든 것.
포인트
- sequence 예측은 미래를 예측하는 문제이기 때문에 이렇게 하나씩 밀어서 학습 데이터를 만든다!
2) 학습 준비
input_size = 1
hidden_size = 16
output_size = 1
seq_length = 5
model = MyLSTM(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
X_train, Y_train = generate_data(seq_length=seq_length, num_samples=1000)
설명
- 입력은 숫자 하나니까 input_size = 1
- hidden state 크기는 16으로 설정 (hidden_size=16)
- 출력도 숫자 하나를 예측할 거니까 output_size = 1
✅ 모델 구성:
- 직접 만든 MyLSTM 클래스를 사용한다.
- 손실 함수는 MSELoss (Mean Squared Error) ➔ 수치 예측이니까 적합.
- 옵티마이저는 Adam ➔ 학습을 빠르게 안정화시킴.
3) 학습
epochs = 200
losses = []
for epoch in range(epochs):
model.train()
optimizer.zero_grad()
outputs = model(X_train)
loss = criterion(outputs, Y_train)
loss.backward() # 역전파
optimizer.step()
losses.append(loss.item())
if (epoch+1) % 5 == 0:
print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")
설명
- 200 epoch 동안 학습을 진행한다.
- 매 epoch마다:
- 모델을 학습 모드로 설정
- 입력(X_train)을 넣어 예측(output)을 만든다
- 예측과 실제(Y_train) 사이의 오차(loss)를 계산
- loss.backward()를 통해 역전파(Backpropagation) 를 수행한다
- optimizer.step()으로 모델 파라미터를 업데이트한다.
✅ 5번마다 현재 epoch과 loss를 출력해서 학습 과정을 모니터링할 수 있다.
포인트
- LSTM 구조 특성상 시간에 따라 쌓이는 정보를 학습해야 하기 때문에, epoch 수를 충분히 주는 것이 중요하다.
4) 테스트
epochs = 200
losses = []
for epoch in range(epochs):
model.train()
optimizer.zero_grad()
outputs = model(X_train)
loss = criterion(outputs, Y_train)
loss.backward() # 역전파
optimizer.step()
losses.append(loss.item())
if (epoch+1) % 5 == 0:
print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")
설명
- 학습이 끝난 모델을 평가하기 위해 eval 모드로 설정.
- torch.no_grad()로 gradient 계산을 막는다 (속도 빠르고 메모리 절약).
- [5,6,7,8]이라는 숫자 시퀀스를 입력으로 넣고, 모델이 그 다음 숫자들을 예측하게 한다.
- 결과 출력:
- Input Sequence: [5, 6, 7, 8]
- Predicted Next Sequence: 모델이 예측한 [6, 7, 8, 9] 비슷한 결과
포인트
- 이 모델은 "n번째 숫자 → n+1번째 숫자" 관계를 배워야 하는 문제를 학습했기 때문에,
- [5,6,7,8] 입력 시 [6,7,8,9]을 예측해야 정답이다!
결과
숫자를 잘 예측하는 것을 볼 수 있다.
전체코드
import torch
import torch.nn as nn
import torch.optim as optim
# 1. LSTM 셀 직접 구현
class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
# 입력 x_t와 hidden h_t-1를 concat해서 Linear 연산 (가중치 통합)
self.W_f = nn.Linear(input_size + hidden_size, hidden_size) # Forget Gate
self.W_i = nn.Linear(input_size + hidden_size, hidden_size) # Input Gate
self.W_c = nn.Linear(input_size + hidden_size, hidden_size) # Cell State
self.W_o = nn.Linear(input_size + hidden_size, hidden_size) # Output Gate
def forward(self, x_t, h_prev, c_prev):
# 1. input과 h_prev를 concat
combined = torch.cat((x_t, h_prev), dim=1) # (batch_size, input_size + hidden_size)
# 2. 각 gate 계산
f_t = torch.sigmoid(self.W_f(combined)) # Forget gate
i_t = torch.sigmoid(self.W_i(combined)) # Input gate
g_t = torch.tanh(self.W_c(combined)) # Cell state
o_t = torch.sigmoid(self.W_o(combined)) # Output gate
# 3. cell state 업데이트
c_t = f_t * c_prev + i_t * g_t
# 4. hidden state 업데이트
h_t = o_t * torch.tanh(c_t)
return h_t, c_t
# 2. Full LSTM 모델 (Seq2Seq 스타일)
class MyLSTM(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MyLSTM, self).__init__()
self.hidden_size = hidden_size
self.lstm_cell = LSTMCell(input_size, hidden_size)
self.fc = nn.Linear(hidden_size, output_size) # hidden state를 output으로 변환
def forward(self, inputs):
batch_size, seq_len, _ = inputs.size()
h_t = torch.zeros(batch_size, self.hidden_size, device=inputs.device)
c_t = torch.zeros(batch_size, self.hidden_size, device=inputs.device)
outputs = []
for t in range(seq_len):
x_t = inputs[:, t, :]
h_t, c_t = self.lstm_cell(x_t, h_t, c_t) # timestep마다 업데이트
output = self.fc(h_t) # h_t를 output으로 변환
outputs.append(output.unsqueeze(1)) # (batch, 1, output_size)
outputs = torch.cat(outputs, dim=1) # (batch, seq_len, output_size)
return outputs
# 3. 간단한 데이터 생성 (RNN 때랑 비슷하게)
def generate_data(seq_length=5, num_samples=1000):
X = []
Y = []
for _ in range(num_samples):
start = torch.randint(0, 10, (1,)).item()
seq = torch.arange(start, start + seq_length)
X.append(seq[:-1].unsqueeze(-1).float()) # (n-1 steps)
Y.append(seq[1:].unsqueeze(-1).float()) # (next n-1 steps)
return torch.stack(X), torch.stack(Y)
# 4. 학습 준비
input_size = 1
hidden_size = 16
output_size = 1
seq_length = 5
model = MyLSTM(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
X_train, Y_train = generate_data(seq_length=seq_length, num_samples=1000)
# 5. 학습
epochs = 200
losses = []
for epoch in range(epochs):
model.train()
optimizer.zero_grad()
outputs = model(X_train)
loss = criterion(outputs, Y_train)
loss.backward()
optimizer.step()
losses.append(loss.item())
if (epoch+1) % 5 == 0:
print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")
# 6. 테스트
model.eval()
with torch.no_grad():
test_input = torch.tensor([[[5], [6], [7], [8]]]).float()
prediction = model(test_input)
print("Input Sequence: ", test_input.squeeze(-1).squeeze(0).tolist())
print("Predicted Next Sequence: ", prediction.squeeze(-1).squeeze(0).tolist())
이상으로 LSTM 코드 관련 내용을 마치겠습니다.
감사합니다.
'딥러닝 (Deep Learning) > [05] - 논문 리뷰' 카테고리의 다른 글
[05] - Receptive Field 를 확장시킨 모델 (DeepLab v1, DeepLab v2) (6) | 2024.11.19 |
---|---|
[04] - FC DenseNet 이란? (1) | 2024.11.15 |
[03] - 빠르면서도 정확한 SegNet (0) | 2024.11.14 |
[02] - FCN 한계점을 극복한 DeconvNet 이란?? (3) | 2024.11.13 |
[01] - FCN32s, FCN16s, FCN8s 이란? (0) | 2024.11.12 |