tf.train API


머신러닝의 자세한 이론적 내용은 이 튜토리얼의 범위를 벗어납니다. 그러나 TensorFlow는 loss function을 최소화하기 위해 각 variable을 천천히 변경하는 optimizer를 제공합니다. 가장 간단한 optimizer는 gradient desent 방법입니다. 해당 variable에 대한 loss의 derivative의 magnitude에 따라 각 variable을 수정합니다. 일반적으로 수동으로 symbolic derivative를 계산하는 것은 지루하고 오류가 발생하기 쉽습니다. 결과적으로, TensorFlow는 tf.gradients 함수를 사용하여 model의 description만 제공된 derivative를 자동으로 생성할 수 있습니다. 단순화를 위해 일반적으로 optimizer를 수행합니다. 

optimizer = tf.train.GradientDescentOptimizer(0.01)
train
= optimizer.minimize(loss)
sess.run(init) # reset values to incorrect defaults.
for i in range(1000):
  sess
.run(train, {x:[1,2,3,4], y:[0,-1,-2,-3]})

print(sess.run([W, b]))

최종 model 매개변수가 생성됩니다.

[array([-0.9999969], dtype=float32), array([ 0.99999082], dtype=float32)]

이제 실제 머신러닝을 해봤습니다. 이런 간단한 linear regression을 수행한다면 TensorFlow 핵심 코드가 많이 필요하지는 않지만, 좀 더 복잡한 model과  data를 입력하는 method는 더 많은 코드가 필요합니다. 따라서 TensorFlow는 일반적인 패턴, 구조 및 기능에 대해 더 높은 level의 추상화를 제공합니다. 우리는 다음 section에서 이러한 추상화를 사용하는 방법을 배웁니다.


complete program


완성된 training 가능한 linear regression model은 다음과 같습니다.

import numpy as np
import tensorflow as tf

# Model parameters
W
= tf.Variable([.3], tf.float32)
b
= tf.Variable([-.3], tf.float32)
# Model input and output
x
= tf.placeholder(tf.float32)
linear_model
= W * x + b
y
= tf.placeholder(tf.float32)
# loss
loss
= tf.reduce_sum(tf.square(linear_model - y)) # sum of the squares
# optimizer
optimizer
= tf.train.GradientDescentOptimizer(0.01)
train
= optimizer.minimize(loss)
# training data
x_train
= [1,2,3,4]
y_train
= [0,-1,-2,-3]
# training loop
init
= tf.global_variables_initializer()
sess
= tf.Session()
sess
.run(init) # reset values to wrong
for i in range(1000):
  sess
.run(train, {x:x_train, y:y_train})

# evaluate training accuracy
curr_W
, curr_b, curr_loss  = sess.run([W, b, loss], {x:x_train, y:y_train})
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))

실행하면 다음과 같은 결과를 보여줍니다.

W: [-0.9999969] b: [ 0.99999082] loss: 5.69997e-11

좀 더 복잡한 이 프로그램은 여전히 TensorBorad에서 시각화 할 수 있습니다.

TensorBoard final model visualization

출처





'머신러닝 > TensorFlow' 카테고리의 다른 글

tf.contrib.learn  (0) 2017.03.08
TensorFlow Core tutorial  (0) 2017.02.28
설치하기  (2) 2017.02.21
TensorFlow란?  (0) 2017.02.21

+ Recent posts