LSTM 코드 분석 [2-2편]

2025. 4. 9. 11:33·딥러닝 (Deep Learning)/[05] - 논문 리뷰
728x90
반응형

이전 글에서 LSTM 이론적인 부분을 설명했습니다.

 

 

LSTM 이란? [2-1편]

기본적인 순환신경망인 Vanilla RNN에 대해서 1편에서 설명 했습니다.하지만 현재는 Vanilla RNN 이 사용되고 있지 않습니다.어떠한 이유때문에 사용되고 있지 않는지, 그리고 더 발전된 RNN에 대해서

ai-bt.tistory.com

 

이번에는 코드를 분석해서 간단한 학습까지 진행하겠습니다.

 


1. LSTM 셀 코드

# 1. LSTM 셀 직접 구현
class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 입력 x_t와 hidden h_t-1를 concat해서 Linear 연산 (가중치 통합)
        self.W_f = nn.Linear(input_size + hidden_size, hidden_size)  # Forget Gate
        self.W_i = nn.Linear(input_size + hidden_size, hidden_size)  # Input Gate
        self.W_c = nn.Linear(input_size + hidden_size, hidden_size)  # Cell State
        self.W_o = nn.Linear(input_size + hidden_size, hidden_size)  # Output Gate

    def forward(self, x_t, h_prev, c_prev):

        # 1. input과 h_prev를 concat
        combined = torch.cat((x_t, h_prev), dim=1)  # (batch_size, input_size + hidden_size)

        # 2. 각 gate 계산
        f_t = torch.sigmoid(self.W_f(combined))   # Forget gate
        i_t = torch.sigmoid(self.W_i(combined))   # Input gate
        g_t = torch.tanh(self.W_c(combined))      # Cell state
        o_t = torch.sigmoid(self.W_o(combined))   # Output gate

        # 3. cell state 업데이트
        c_t = f_t * c_prev + i_t * g_t

        # 4. hidden state 업데이트
        h_t = o_t * torch.tanh(c_t)

        return h_t, c_t

 

1. combined = torch.cat((x_t, h_prev), dim=1)

무엇을 하는가?

  • 현재 입력값 x_t와, 직전 timestep의 hidden state h_prev를 붙인다(concatenation).
  • 왜? ➔ Gate 계산할 때 입력 정보(x)와 기존 상태 정보(h)를 같이 사용해야 하니까.

(한 문장 요약)

지금 timestep의 정보(x)와 과거 기억(h)을 합쳐서, "어떻게 업데이트할지" 계산 준비한다.

 


2. 각 Gate 계산

이제 붙인 combined을 각 Gate로 보낸다. (LSTM은 4개의 Gate가 있다.)

(1) f_t = torch.sigmoid(self.W_f(combined)) — Forget Gate

무엇을 하는가?

  • "이전 기억(c_prev) 중에서 무엇을 버리고 무엇을 남길까?"를 결정하는 Gate.
  • Sigmoid 함수를 거쳐서 0~1 사이 값을 만든다.
    • 0이면 완전히 잊고
    • 1이면 완전히 기억 유지

한 문장 요약

이전 cell memory 중 어떤 걸 "Forget(잊을지)" 결정한다.

 

(2) i_t = torch.sigmoid(self.W_i(combined)) — Input Gate

무엇을 하는가?

  • "지금 timestep 입력 x_t를 가지고, 뭘 새로 기억할까?"를 결정.
  • 역시 Sigmoid로 0~1 스케일로 만든다.
    • 0이면 입력 무시
    • 1이면 입력 강하게 반영

한 문장 요약

지금 timestep 입력을 보고 "무엇을 새로 기억할지" 결정한다.

 

 

(3) g_t = torch.tanh(self.W_c(combined)) — Candidate Cell State

무엇을 하는가?

  • 새로 추가될 정보 후보를 만든다.
  • Tanh 함수를 거쳐서 -1~1 값으로 나온다.
    • (완전히 부정하거나 완전히 긍정할 수 있음)

한 문장 요약

새롭게 기억할 후보를 만든다. (기억할 "내용" 자체)

 

 

(4) o_t = torch.sigmoid(self.W_o(combined)) — Output Gate

(무엇을 하는가?)

  • "cell state를 외부로 얼마나 드러낼까?"를 결정하는 Gate.
  • Hidden state(h_t)를 계산할 때 쓰인다.

(한 문장 요약)

cell memory를 얼마나 보여줄지 결정하는 게이트.

 


 

3. c_t = f_t * c_prev + i_t * g_t

무엇을 하는가?

  • cell state 업데이트하는 공식.

수식으로 보면:

\[
c_t = f_t \times c_{t-1} + i_t \times \tilde{c}_t
\]

​

해석하면:

\[
f_t \times c_{t-1} \quad \rightarrow \quad \text{"과거 기억 중 잊지 않고 남긴 것"}
\]
\[
i_t \times \tilde{c}_t \quad \rightarrow \quad \text{"새로 추가할 기억"}
\]

 

 

이 둘을 더해서 최신 cell state (c_t)를 만든다!

 

한 문장 요약

기존 기억은 일부 버리고, 새 기억을 추가해서 새로운 cell memory를 만든다.


2. Full LSTM 모델

# 2. Full LSTM 모델 (Seq2Seq 스타일)
class MyLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MyLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.lstm_cell = LSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)  # hidden state를 output으로 변환

    def forward(self, inputs):
        batch_size, seq_len, _ = inputs.size()

        h_t = torch.zeros(batch_size, self.hidden_size, device=inputs.device)
        c_t = torch.zeros(batch_size, self.hidden_size, device=inputs.device)

        outputs = []
        for t in range(seq_len):
            x_t = inputs[:, t, :]
            h_t, c_t = self.lstm_cell(x_t, h_t, c_t)  # timestep마다 업데이트
            output = self.fc(h_t)  # h_t를 output으로 변환
            outputs.append(output.unsqueeze(1))  # (batch, 1, output_size)

        outputs = torch.cat(outputs, dim=1)  # (batch, seq_len, output_size)
        return outputs

 

MyLSTM 클래스는
➡️ 우리가 직접 만든 LSTMCell을 가지고
➡️ 시퀀스 전체를 처리하기 위해 만든 "하나의 묶음 모델" 이다.


쉽게 말하면,

LSTMCell 시퀀스의 한 timestep만 처리하는 셀 (xₜ, hₜ₋₁, cₜ₋₁ 받아서 hₜ, cₜ 계산)
MyLSTM 여러 timestep 전체를 순차적으로 처리하는 모델 (t=0,1,2,3... 순서대로 hₜ, cₜ 업데이트)

1) __init__ (초기화)

def __init__(self, input_size, hidden_size, output_size):
    super(MyLSTM, self).__init__()
    self.hidden_size = hidden_size
    self.lstm_cell = LSTMCell(input_size, hidden_size)
    self.fc = nn.Linear(hidden_size, output_size)  # hidden state를 output으로 변환
  • input_size: 입력 차원 (ex: 단어 하나당 100차원 벡터면 100)
  • hidden_size: hidden state 차원 (메모리 용량)
  • output_size: 최종 출력 차원

단어 하나가 "100차원 벡터"라는 뜻

더보기

✨ 단어는 컴퓨터 입장에서는 "숫자"로 표현돼야 해

우리가 보는 "apple" 같은 단어는 컴퓨터가 바로 이해할 수 없어.

그래서 단어를 숫자 벡터로 변환해서 컴퓨터가 다룰 수 있게 만들어.

 

✨ 예를 들면

"apple" 이라는 단어를:

[0.25, 0.01, 0.98, 0.42, ..., 0.03]  (총 100개 숫자)
 

처럼 100개의 숫자로 표현하는 거야.

이런 걸 "100차원 벡터" 라고 해.

  • "100차원" = 숫자가 100개 있음
  • "벡터" = 이 숫자들이 하나로 묶여 있음

✨ 왜 이렇게 하냐?

  • 단어를 숫자 벡터로 바꿔야
  • LSTM, Transformer 같은 수학 계산 모델이 다룰 수 있어.

(단순히 'apple'이라는 문자열로는 계산을 못하거든.)

 

궁금한 부분...
"단어를 어떤 숫자들로 바꿀까?"

"차원 수가 크면 무조건 좋다고 할 수 있을까??"

 

✅ self.lstm_cell:

  • 직접 만든 LSTMCell을 하나 선언한다.
  • (한 timestep씩 업데이트하는 작은 셀)

✅ self.fc:

  • hidden state를 output으로 변환하는 Linear 레이어.

2) forward (실제 데이터 들어왔을 때 흐름)

def forward(self, inputs):
    batch_size, seq_len, _ = inputs.size()

    h_t = torch.zeros(batch_size, self.hidden_size, device=inputs.device)
    c_t = torch.zeros(batch_size, self.hidden_size, device=inputs.device)

 

✅ 입력 데이터는 (batch_size, seq_len, input_size) 형태.

✅ 초기 hidden state(h₀), cell state(c₀)를 전부 0으로 초기화한다. (시작은 항상 아무 기억 없는 상태)


3) 시퀀스 처리 (for loop)

outputs = []
for t in range(seq_len):
    x_t = inputs[:, t, :]
    h_t, c_t = self.lstm_cell(x_t, h_t, c_t)  # timestep마다 업데이트
    output = self.fc(h_t)  # h_t를 output으로 변환
    outputs.append(output.unsqueeze(1))  # (batch, 1, output_size)

 

✅ 시퀀스 길이만큼 매 timestep마다

  • 입력값 x_t를 꺼낸다.
  • 이전의 hidden, cell state(hₜ₋₁, cₜ₋₁)를 넣고 LSTMCell을 돌린다.
  • 나온 hidden state를 fully connected layer를 통과시켜 output으로 변환한다.
  • 모든 output을 리스트에 모은다.

4) 최종 Output 만들기

outputs = torch.cat(outputs, dim=1)  # (batch, seq_len, output_size)
return outputs

✅ 시간축으로 쭉 모은 후 반환한다.


3. 학습 및 예측 코드

1) 간단한 데이터 생성 (generate_data)

def generate_data(seq_length=5, num_samples=1000):
    X = []
    Y = []
    for _ in range(num_samples):
        start = torch.randint(0, 10, (1,)).item()
        seq = torch.arange(start, start + seq_length)
        X.append(seq[:-1].unsqueeze(-1).float())  # (n-1 steps)
        Y.append(seq[1:].unsqueeze(-1).float())   # (next n-1 steps)
    return torch.stack(X), torch.stack(Y)

설명

  • 랜덤한 시작 숫자(start)를 뽑는다.
  • start부터 5개 연속된 숫자를 만든다. (예: 2 → [2, 3, 4, 5, 6])
  • 입력(X)과 정답(Y)을 만드는 방법:
    • X: [2, 3, 4, 5] (마지막 숫자 빼고)
    • Y: [3, 4, 5, 6] (첫 숫자 빼고)

즉,

"현재 숫자"를 입력하면 "다음 숫자"를 예측하는 문제를 만든 것.

 

포인트

  • sequence 예측은 미래를 예측하는 문제이기 때문에 이렇게 하나씩 밀어서 학습 데이터를 만든다!

2) 학습 준비

input_size = 1
hidden_size = 16
output_size = 1
seq_length = 5

model = MyLSTM(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

X_train, Y_train = generate_data(seq_length=seq_length, num_samples=1000)
 
 

설명

  • 입력은 숫자 하나니까 input_size = 1
  • hidden state 크기는 16으로 설정 (hidden_size=16)
  • 출력도 숫자 하나를 예측할 거니까 output_size = 1

✅ 모델 구성:

  • 직접 만든 MyLSTM 클래스를 사용한다.
  • 손실 함수는 MSELoss (Mean Squared Error) ➔ 수치 예측이니까 적합.
  • 옵티마이저는 Adam ➔ 학습을 빠르게 안정화시킴.

3) 학습

epochs = 200
losses = []

for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, Y_train)
    loss.backward()  # 역전파
    optimizer.step()

    losses.append(loss.item())
    if (epoch+1) % 5 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

설명

  • 200 epoch 동안 학습을 진행한다.
  • 매 epoch마다:
    • 모델을 학습 모드로 설정
    • 입력(X_train)을 넣어 예측(output)을 만든다
    • 예측과 실제(Y_train) 사이의 오차(loss)를 계산
    • loss.backward()를 통해 역전파(Backpropagation) 를 수행한다
    • optimizer.step()으로 모델 파라미터를 업데이트한다.

✅ 5번마다 현재 epoch과 loss를 출력해서 학습 과정을 모니터링할 수 있다.

 

포인트

  • LSTM 구조 특성상 시간에 따라 쌓이는 정보를 학습해야 하기 때문에, epoch 수를 충분히 주는 것이 중요하다.

4) 테스트

epochs = 200
losses = []

for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, Y_train)
    loss.backward()  # 역전파
    optimizer.step()

    losses.append(loss.item())
    if (epoch+1) % 5 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

설명

  • 학습이 끝난 모델을 평가하기 위해 eval 모드로 설정.
  • torch.no_grad()로 gradient 계산을 막는다 (속도 빠르고 메모리 절약).
  • [5,6,7,8]이라는 숫자 시퀀스를 입력으로 넣고, 모델이 그 다음 숫자들을 예측하게 한다.
  • 결과 출력:
    • Input Sequence: [5, 6, 7, 8]
    • Predicted Next Sequence: 모델이 예측한 [6, 7, 8, 9] 비슷한 결과

포인트

  • 이 모델은 "n번째 숫자 → n+1번째 숫자" 관계를 배워야 하는 문제를 학습했기 때문에,
  • [5,6,7,8] 입력 시 [6,7,8,9]을 예측해야 정답이다!

 

결과

 

숫자를 잘 예측하는 것을 볼 수 있다.


 

전체코드

import torch
import torch.nn as nn
import torch.optim as optim

# 1. LSTM 셀 직접 구현
class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 입력 x_t와 hidden h_t-1를 concat해서 Linear 연산 (가중치 통합)
        self.W_f = nn.Linear(input_size + hidden_size, hidden_size)  # Forget Gate
        self.W_i = nn.Linear(input_size + hidden_size, hidden_size)  # Input Gate
        self.W_c = nn.Linear(input_size + hidden_size, hidden_size)  # Cell State
        self.W_o = nn.Linear(input_size + hidden_size, hidden_size)  # Output Gate

    def forward(self, x_t, h_prev, c_prev):

        # 1. input과 h_prev를 concat
        combined = torch.cat((x_t, h_prev), dim=1)  # (batch_size, input_size + hidden_size)

        # 2. 각 gate 계산
        f_t = torch.sigmoid(self.W_f(combined))   # Forget gate
        i_t = torch.sigmoid(self.W_i(combined))   # Input gate
        g_t = torch.tanh(self.W_c(combined))      # Cell state
        o_t = torch.sigmoid(self.W_o(combined))   # Output gate

        # 3. cell state 업데이트
        c_t = f_t * c_prev + i_t * g_t

        # 4. hidden state 업데이트
        h_t = o_t * torch.tanh(c_t)

        return h_t, c_t

# 2. Full LSTM 모델 (Seq2Seq 스타일)
class MyLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MyLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.lstm_cell = LSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)  # hidden state를 output으로 변환

    def forward(self, inputs):
        batch_size, seq_len, _ = inputs.size()

        h_t = torch.zeros(batch_size, self.hidden_size, device=inputs.device)
        c_t = torch.zeros(batch_size, self.hidden_size, device=inputs.device)

        outputs = []
        for t in range(seq_len):
            x_t = inputs[:, t, :]
            h_t, c_t = self.lstm_cell(x_t, h_t, c_t)  # timestep마다 업데이트
            output = self.fc(h_t)  # h_t를 output으로 변환
            outputs.append(output.unsqueeze(1))  # (batch, 1, output_size)

        outputs = torch.cat(outputs, dim=1)  # (batch, seq_len, output_size)
        return outputs

# 3. 간단한 데이터 생성 (RNN 때랑 비슷하게)
def generate_data(seq_length=5, num_samples=1000):
    X = []
    Y = []
    for _ in range(num_samples):
        start = torch.randint(0, 10, (1,)).item()
        seq = torch.arange(start, start + seq_length)
        X.append(seq[:-1].unsqueeze(-1).float())  # (n-1 steps)
        Y.append(seq[1:].unsqueeze(-1).float())   # (next n-1 steps)
    return torch.stack(X), torch.stack(Y)

# 4. 학습 준비
input_size = 1
hidden_size = 16
output_size = 1
seq_length = 5

model = MyLSTM(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

X_train, Y_train = generate_data(seq_length=seq_length, num_samples=1000)

# 5. 학습
epochs = 200
losses = []

for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train)
    loss = criterion(outputs, Y_train)
    loss.backward()
    optimizer.step()

    losses.append(loss.item())
    if (epoch+1) % 5 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

# 6. 테스트
model.eval()
with torch.no_grad():
    test_input = torch.tensor([[[5], [6], [7], [8]]]).float()
    prediction = model(test_input)
    print("Input Sequence: ", test_input.squeeze(-1).squeeze(0).tolist())
    print("Predicted Next Sequence: ", prediction.squeeze(-1).squeeze(0).tolist())

 

이상으로 LSTM 코드 관련 내용을 마치겠습니다.

감사합니다.

728x90
반응형
저작자표시 비영리 변경금지 (새창열림)

'딥러닝 (Deep Learning) > [05] - 논문 리뷰' 카테고리의 다른 글

[05] - Receptive Field 를 확장시킨 모델 (DeepLab v1, DeepLab v2)  (6) 2024.11.19
[04] - FC DenseNet 이란?  (1) 2024.11.15
[03] - 빠르면서도 정확한 SegNet  (0) 2024.11.14
[02] - FCN 한계점을 극복한 DeconvNet 이란??  (3) 2024.11.13
[01] - FCN32s, FCN16s, FCN8s 이란?  (0) 2024.11.12
'딥러닝 (Deep Learning)/[05] - 논문 리뷰' 카테고리의 다른 글
  • [05] - Receptive Field 를 확장시킨 모델 (DeepLab v1, DeepLab v2)
  • [04] - FC DenseNet 이란?
  • [03] - 빠르면서도 정확한 SegNet
  • [02] - FCN 한계점을 극복한 DeconvNet 이란??
AI-BT
AI-BT
인공지능 (AI)과 블록체인에 관심있는 블로그
  • AI-BT
    AI-BLACK-TIGER
    AI-BT
  • 전체
    오늘
    어제
    • 분류 전체보기 (133)
      • 딥러닝 (Deep Learning) (81)
        • [01] - 딥러닝 이란? (5)
        • [02] - 데이터 (4)
        • [03] - 모델 (17)
        • [04] - 학습 및 최적화 (14)
        • [05] - 논문 리뷰 (17)
        • [06] - 평가 및 결과 분석 (4)
        • [07] - Serving (6)
        • [08] - 프로젝트 (14)
      • 머신러닝 & 딥러닝 개념 (0)
        • 머신러닝 (0)
        • 딥러닝 (0)
      • Quant 투자 (12)
        • 경제 (9)
        • 퀀트 알고리즘 & 전략 개요 (3)
      • 딥러닝 Math (4)
      • AI Naver boost camp (22)
        • 회고 (19)
        • CV 프로젝트 가이드 (3)
      • Python (1)
      • 개발 및 IT 용어 (6)
        • IT 용어 (2)
        • VS Code (1)
      • 코인 정보 (7)
  • 인기 글

  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.3
AI-BT
LSTM 코드 분석 [2-2편]
상단으로

티스토리툴바