본문 바로가기
Tensorflow

[ANN] 09. Back Propagation of Neural Network

by 청양호박이 2020. 5. 16.

이제 ANN(인공신경망)에서 알아보아야 할 많은 부분을 정리해보았습니다. 순전파(Forward Propagation)부터 Loss Function을 통한 오차를 구하고, 오차를 최소화하기위해 수치미분을 통해 기울기를 구하고, 그것으로부터 Global Min 혹은 Local Min을 찾아 최소화하는 GDA(Gradient Descent Algorithm)을 통해 학습의 기본까지 확인했습니다.

 

이제 그 마지막 단계인 역전파(Back Propagation)를 통해서, 도출된 오류를 뒤의 Layer로 계속 전달하여 변수인 weight, bias를 자동으로 학습하는 내용에 대해서 알아보겠습니다. 오류 역전파는 이제 그 종착지라고 생각하면 마음이 편해질 것 입니다. 

 

  • 인공신경망(ANN)의 개요 및 퍼셉트론과의 차이점
  • 인공신경망(ANN)의 활성화함수(AF, Activation Function)
  • 인공신경망(ANN)의 출력층 설계 및 순전파(Forward Propagation)
  • 인공신경망(ANN)의 손실함수(Loss) 구현
  • 수치미분 및 편미분, 그리고 GDA
  • 인공신경망(ANN)의 역전파(Back Propagation) 개념
  • 인공신경망(ANN)의 역전파(Back Propagation) 구현

 

 

1. 역전파(Back Propagation) 란?


지금까지 계속 ANN에 대해서 이야기하면서 학습은 Loss Function의 수치미분을 통해서 무언가를 한다고 말했었습니다. 결국 이 미분결과를 모델의 Layer구성의 흐름의 반대방향... 즉 순전파의 반대방향으로 전달하며 학습하는 방법이 바로 역전파 입니다. 

 

신경망의 Layer 내부의 각 node들은 아래와 같이 구성되어 있고, 이 node의 그림은 그래프(Graph)로 그려지기 때문에 그 흐름으로 설명하고 이해하는것이 도움이 됩니다.

 

자 그럼 본격적으로 역전파에 대해서 알아보겠습니다. 쉬운 설명을 위해서 아래의 합과 곱으로 구성된 인공신경망처럼 생긴 Graph를 구성해 보겠습니다.

 

이 Graph는....

Y = ((X1 * X2) + X3) * X4 

 

에 해당하는 함수 f(x) 입니다. 또한 그래프의 좌측에서 우측으로 진행되는 순전파(Forward Propagation)입니다.

 

여기서 역전파의 개념은 X1의 변경이 최종 결과인 output Y에 미치는 영향을 말하는 것이고, 이는 돌려 돌려 표현을 한다면... 결국 dY/dX1 즉 output인 Y를 X1으로 미분한 결과라고 할 수 있습니다. 

 

최종적으로 input에 해당하는 X1, X2, X3, X4에 대해서 output인 Y에 미치는 영향을 다 파악하는 것이 역전파라고 할 수 있습니다. 아니... 그냥 다 각각 미분하면 되지 뭐하러 역전파니 계속 뒤로 뭘 전달한다느니 이런 말을 사용할까요??

 

결국 input layer에 해당하는 지점을 기준으로 미분을 하려고 하면 복잡한 Graph에서는 미분의 Chain Rule을 알아야 합니다. 이것이 역전파라는 용어가 성립될 수 있는 기반이 되기 때문이지요. 

 

[Chain Rule]

합성함수를 예를들자면,

f(x) = (x + y)^2

 

라는 함수가 있다고 하고 이를 미분한다고 합시다. f(x)를 Z라고 하고 Z를 x에 대해서 미분한다고 하면, dZ/dx를 위해서는 아래의 단계를 거치게 됩니다.

 

(1) x + y를 t라는 변수로 치환 : 결국 f(x) = t^2이 되고, dZ/dt를 우선 수행 (Z를 t에 대해서 미분) : 그 결과는 2t가 나옴

(2) t를 가지고 x에 대해서 미분 : 결국 t = x + y에서 dt/dx를 수행 : 그 결과는 1이 나옴

(3) chain rule에 의거해서 

     dZ/dx = dZ/dt * dt/dx = 2t * 1 = 2(x + y)

     가 결과적으로 나오게 됩니다. 

 

동일한지 확인해보기 위해서, f(x)를 전개한다음에 미분을 해보겠습니다.

 

f(x) = x^2 + 2xy + y^2

dZ/dx = 2x + 2y = 2(x + y)

 

이렇게 도출되기 때문에 정확하게 일치합니다.

 

그럼 한가지 의문이 들 수 있습니다. f(x)가 주어지면 그냥 전개한다음에 미분하면되지... 뭐하러 chain rule이라는 어려워 보이는 개념까지 도입해서 힘들게하냐...!!

 

그럼 위의 합과 곱의 구성그래프를 볼까요?? 

 

Y = ((X1 * X2) + X3) * X4 

 

이 경우에서, dY/dx1을 구하라고 한다면 어떨까요?? 아니면 더 복잡 무시한 그래프가 나온다면 어떨까요?? 쉽지않을 것입니다. 이러한 이유로 chain rule을 통해서 미분을 하게 되는 것입니다.

 

왜 사용하는지는 알았고... 그럼 Graph와 도데체 어떻게 연결되는지, 위의 예제를 통해서 알아보겠습니다.

 

f(x) = (x + y)^2

 

이 함수를 그래프로 표현하면 아래와 같습니다.

 

가장 오른쪽에 위치한 노드의 input값은 (x + y)와 2 입니다. 그럼 뒤쪽은 생각하지말고 해당 노드만 생각해 본다면, output인 f(x)에 대해서 input으로 미분한다면...

 

여기서 (x+y)입력을 t로 치환하면,

f(x) = t^2

df(x)/d(x+y) = df(x)/dt = 2t 

 

가 됩니다. 그럼 그 전의 노드도 이동해서... 동일한 방법으로 해당 노드만 보고 미분을 진행해 보겠습니다.

 

(x+y)가 t로 치환된 부분은 동일합니다. 그렇다면 input은 x와 y이고 output은 t가 됩니다. 그럼 output인 t에 대해서 input으로 미분한다면...

dt/dx = 1 

 

이 됩니다. 결국 이 과정을 합친것이 바로 Chain Rule과 동일한 절차가 되며, 각 단계를 거칠때마다의 결과를 서로 곱해주면 되는 것입니다.

df(x)/dt * dt/dx = 2t = 2(x + y)

 

결국 위에서 확인한 내용과 완벽하게 일치하게 됩니다.

 

이렇게 최종 결과의 Loss에 대해서 보정값을 각 Layer에 단계적으로 Chain Rule에 의해 내려보내는 것이 바로 역전파(Back Propagation)의 개념입니다.

 

참 멀게도 돌아왔습니다... ㅠㅠ

 

 

2. 역전파(Back Propagation) 기본


아래의 그림은 위에 주구장창 이야기 했던 역전파에 대해서 한개의 node내 1개의 연산을 기준으로 동작하는 원리를 간단하게 표현한 것 입니다.

순전파를 통해서 x -> y가 도출이 되고 Loss Function을 통해서 현재 미분한 값이 나오고, 학습할 수치가 정해집니다. 이 수치는 다시 역전파를 통해서 dL(x)/dy -> dL(x)/dy * dy/dx로 이동하는 로직이 되는 것 입니다.

 

그럼 좀더 세부적으로 python코드를 통해서 알아보기 위해서 (+), (*) 연산자가 들어가는 graph를 예로 들어보겠습니다.

 

우선 위에서 알아본 방법으로 손으로 계산을 해보도록 하겠습니다.

 

[순전파(Forward Propagation)]

(1) 첫번째 node : input은 x, y 이며, output은 x + y로 이는 t로 치환

(2) 두번째 node : input은 t, z 이며, output은 t * z 임

 

[Loss Function]

해당 부분은 따로 target도 없고, MSE를 적용할지 CEE를 적용할지 모르기 때문에 Loss Function을 통해서 도출한 수치미분 값을 1로 정의하겠습니다.

dL(x)/dout = 1

 

[역전파(Back Propagation)

(1) 첫번째 node : output은 t * z이고 이를 t로 미분하며, dout/dt = z

(2) 두번째 node : output은 t 즉 x + y이고 이를 x로 미분하며, dt/dx = 1

(3) 최종 input layer에 대한 결과는,

     df(x)/x = df(x)/dout * dout/dt * dt/dx = 1 * z * 1 = z

     df(x)/y =  df(x)/dout * dout/dt * dt/dy = 1 * z * 1 = z

     df(x)/z =  df(x)/dout * dout/dz = 1 * t = x + y

 

아주 깔끔하게 정리가 되었습니다. 이게 바로 역전파의 모든 것 입니다.

 

 

3. 역전파(Back Propagation) Tip


위에 연산자의 예를 (+), (*)로 들은 이유는 앞으로 알아볼 활성화함수(AF, Activation Function)이 대부분 이 연산자의 조합으로 구성이 되어있기 때문입니다. 

 

그래서 이 두가지가 동작하는 로직을 살펴보면 아래와 같습니다. 

(1) + 연산자는 역전파 시, 들어온 입력에 대해서 양쪽으로 입력값을 동일하게 전파합니다.

(2) * 연산자는 역전파 시, 들어온 입력에 대해서 상대방의 크기만큼 곱한 값으로 전파합니다.

 

 

4. 역전파(Back Propagation) 구현


지금까지 알아본 내용으로, 해당 노드들을 python으로 구현해 보겠습니다. 동작로직은 크게 어려운 부분이 없긴해서... 간단하게 끝나지 않을까 싶습니다.

 

각 연산자에 대해서 class로 구성하고, 각 class에는 순전파와 역전파에 해당하는 함수를 구현하면 끝입니다. 

(예시에 맞게 모든 입력은 2개로 제한하겠습니다.)

 

[Add Node]

import numpy as np

class AddNode:
    def __init__(self):
        self.x = None
        self.y = None

    def forward(self, x, y):
        self.x = x
        self.y = y
        return x + y

    def backward(self, dz):
        dx = dy = dz
        return dx, dy


add_node = AddNode()
t = add_node.forward(100, 200)
dx, dy = add_node.backward(1)

print(t)
print(dx, dy)
================================

forward : 300
backward : dx - 1, dy - 1

공통적으로 class생성 시, 2개 입력을 위한 변수를 생성자를 통해 생성하고, forward함수의 경우는 input을 받아 해당 변수에 할당 및 합의 결과를 리턴 해줍니다. 

 

Backward함수의 경우는 미분값이 input으로 들어오면 양쪽으로 input값을 동일하게 전달 및 출력해 줍니다.

 

[Mul Node]

class MulNode:
    def __init__(self):
        self.x = None
        self.y = None

    def forward(self, x, y):
        self.x = x
        self.y = y
        return x * y

    def backward(self, dz):
        dx = dz * self.y
        dy = dz * self.x
        return dx, dy


mul_node = MulNode()
out = mul_node.forward(300, 4)
dt, dz = mul_node.backward(1)

print(out)
print(dt, dz)
===============================

forward : 1200
backward : dt - 4, dz - 300

forward함수의 경우는 input을 받아 해당 변수에 할당 및 곱의 결과를 리턴 해줍니다. 

 

Backward함수의 경우는 미분값이 input으로 들어오면 양쪽으로 input값 * 상대편 변수의 값을 곱해서 전달 및 출력해 줍니다.

 

[Graph Model 구성]

 

# input
x = 100
y = 200
z = 4

# node 구성
add_node = AddNode()
mul_node = MulNode()

# 순전파(forward propagation)
t = add_node.forward(x, y)
out = mul_node.forward(t, z)

#역전파(backward propagation)
dout = 1
dt, dz = mul_node.backward(dout)
dx, dy = add_node.backward(dt)

 

[Test 및 결과 확인]

#결과확인
print(t, out)
print(dx, dy, dz)
====================

forward : t - 300, out - 1200
backward : dx - 4, dy - 4, dz - 300

 

0. 마치며


역전파를 한번에 이해하기는 쉽지가 않습니다. 여러가지 예시를 통해서 동작하는 로직을 생각하고, 판단하고를 반복해야 조금씩 이해가 가지 않을까요?? 적어도 저는 그랬습니다. 

 

그래서 다음 글에는 몇가지 Graph를 예시로 더 들어서 역전파의 로직을 좀더 예로 확인해 보는 시간을 갖도록 하겠습니다.

 

- Ayotera Lab -

댓글