TensorFlow Tutorial

PrettyTensor

之前的轮子用prettytensor再造一遍

mark

前戏

引入各种包和MNIST

1
2
3
4
5
6
7
8
9
%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn.metrics import confusion_matrix
import time
from datetime import timedelta
import math
import prettytensor as pt
1
2
3
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data set from 'data/MNIST/' with one-hot encoded labels.
data = input_data.read_data_sets('data/MNIST/', one_hot=True)
# Integer class of every test image: the index of the largest entry in
# each one-hot label row.
data.test.cls = np.argmax(data.test.labels, axis=1)

定义图片相关大小

1
2
3
4
5
# Width and height of each (square) image, in pixels.
img_size = 28
# Length of one image once flattened into a 1-D vector.
img_size_flat = img_size * img_size
# 2-D shape used to reshape a flat image back into a picture for plotting.
img_shape = (img_size, img_size)
# Size of the last dimension of the 4-D image tensor fed to the network.
num_channels = 1
# Number of output classes (width of the one-hot label vectors).
num_classes = 10
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def plot_images(images, cls_true, cls_pred=None):
    """Draw nine images in a 3x3 grid, captioned with their classes.

    Args:
        images: Exactly nine images; each one is reshaped to the
            module-level ``img_shape`` before being drawn.
        cls_true: The nine true class labels, in the same order as
            ``images``.
        cls_pred: Optional predicted class labels. When supplied, each
            caption shows both the true and the predicted class.
    """
    assert len(images) == len(cls_true) == 9

    figure, grid = plt.subplots(3, 3)
    figure.subplots_adjust(hspace=0.3, wspace=0.3)

    for idx, axis in enumerate(grid.flat):
        axis.imshow(images[idx].reshape(img_shape), cmap='binary')

        if cls_pred is not None:
            caption = "True: {0}, pred: {1}".format(cls_true[idx], cls_pred[idx])
        else:
            caption = "True: {0}".format(cls_true[idx])
        axis.set_xlabel(caption)

        # Tick marks carry no information for an image, so hide them.
        axis.set_xticks([])
        axis.set_yticks([])

    plt.show()
1
2
3
# Sanity check: plot the first nine test images with their true classes.
images = data.test.images[0:9]
cls_true = data.test.cls[0:9]
plot_images(images=images, cls_true=cls_true)

mark

PrettyTensor 运算

定义变量

placeholder 是输入数据的占位符:运行时通过 feed_dict 把原始数据喂进去,训练过程不会修改它。

Variable 是训练过程中会被优化器不断更新的那些变量(例如各层的权重)。

1
2
3
4
# Flattened input images: img_size_flat floats per row; None leaves the
# number of rows (the batch size) open.
x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='x')
# The same data reshaped to 4-D [-1, img_size, img_size, num_channels];
# -1 lets the first dimension be inferred.
x_image = tf.reshape(x, [-1, img_size, img_size, num_channels])
# True labels, num_classes floats per row (fed with the one-hot labels).
y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')
# Integer true class: index of the largest entry in each label row.
y_true_cls = tf.argmax(y_true, axis=1)

开始骚操作

pt.defaults_scope(activation_fn=tf.nn.relu) 的意思:在这个 with 块里创建的每一层,都会把 activation_fn=tf.nn.relu 当作默认参数。

1
2
3
4
5
6
7
8
9
10
# Wrap the 4-D image tensor so PrettyTensor layer methods can be chained.
x_pretty = pt.wrap(x_image)

# Every layer built inside this scope gets tf.nn.relu as its default
# activation function.
with pt.defaults_scope(activation_fn=tf.nn.relu):
    # conv -> pool -> conv -> pool -> flatten -> fully connected -> softmax.
    # The named layers are looked up again later to fetch their weights.
    y_pred, loss = (
        x_pretty
        .conv2d(kernel=5, depth=16, name='layer_conv1')
        .max_pool(kernel=2, stride=2)
        .conv2d(kernel=5, depth=36, name='layer_conv2')
        .max_pool(kernel=2, stride=2)
        .flatten()
        .fully_connected(size=128, name='layer_fc1')
        .softmax_classifier(num_classes=num_classes, labels=y_true)
    )
1
2
3
4
def get_weights_variable(layer_name):
    """Fetch the variable named 'weights' from the scope ``layer_name``.

    The scope is opened with ``reuse=True`` so that ``tf.get_variable``
    is asked for an existing variable instead of a new one.

    Args:
        layer_name: Name of the variable scope (the layer) to look in.

    Returns:
        Whatever ``tf.get_variable('weights')`` yields inside that scope.
    """
    with tf.variable_scope(layer_name, reuse=True):
        return tf.get_variable('weights')
1
2
3
4
# Look up the weight variables of the two named convolutional layers so
# they can be inspected and plotted later.
weights_conv1 = get_weights_variable(layer_name='layer_conv1')
weights_conv2 = get_weights_variable(layer_name='layer_conv2')
# Print the variables to confirm their names and shapes.
print(weights_conv1)
print(weights_conv2)
<tf.Variable 'layer_conv1/weights:0' shape=(5, 5, 1, 16) dtype=float32_ref>
<tf.Variable 'layer_conv2/weights:0' shape=(5, 5, 16, 36) dtype=float32_ref>
1
2
3
4
# Training op: one Adam step (learning rate 1e-5) that minimises `loss`.
optimizer = tf.train.AdamOptimizer(learning_rate=1e-5).minimize(loss)
# Predicted class: index of the largest entry in each row of y_pred.
y_pred_cls = tf.argmax(y_pred, axis=1)
# Element-wise comparison of predicted and true classes.
correct_prediction = tf.equal(y_pred_cls, y_true_cls)
# Fraction of correct predictions: the mean of the comparison cast to float.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

跑两步

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Create the session and run the initialiser for all global variables.
session = tf.Session()
session.run(tf.global_variables_initializer())
# Number of training examples requested per optimisation step.
train_batch_size = 64
# Running count of optimisation iterations performed so far; optimize()
# reads and updates it.
total_iterations = 0
def optimize(num_iterations):
    """Run ``num_iterations`` further training steps and report progress.

    Iteration numbering continues from the module-level
    ``total_iterations`` counter, which is advanced by ``num_iterations``
    once the loop has finished. On every iteration whose zero-based index
    is a multiple of 100, the accuracy on the batch that was just trained
    on is printed. The elapsed wall-clock time is printed at the end.

    Args:
        num_iterations: How many optimisation steps to perform.
    """
    global total_iterations

    start_time = time.time()

    # The message is a single literal: splitting it with a backslash inside
    # the quotes would embed the continuation line's indentation in the
    # printed text.
    msg = "Optimization Iteration: {0:>6}, Training Accuracy: {1:>6.1%}"

    for i in range(total_iterations, total_iterations + num_iterations):
        # Fetch the next batch of training images and their labels.
        x_batch, y_true_batch = data.train.next_batch(train_batch_size)
        feed_dict_train = {x: x_batch,
                           y_true: y_true_batch}

        session.run(optimizer, feed_dict=feed_dict_train)

        if i % 100 == 0:
            # Accuracy is measured on the training batch, not on test data.
            acc = session.run(accuracy, feed_dict=feed_dict_train)
            # Report a one-based iteration number.
            print(msg.format(i + 1, acc))

    total_iterations += num_iterations

    end_time = time.time()
    time_dif = end_time - start_time
    print("Time usage: " + str(timedelta(seconds=int(round(time_dif)))))
1
2
3
4
5
6
7
8
def plot_example_errors(cls_pred, correct):
    """Plot the first nine mis-classified test images.

    Args:
        cls_pred: Predicted class for every image in the test set.
        correct: Boolean per test image, True where the prediction
            matched the true class.

    Note:
        ``plot_images`` asserts that it receives exactly nine images, so
        this needs at least nine mis-classified test images.
    """
    # Element-wise negation. Unlike ``correct == False`` this also works
    # when ``correct`` is a plain sequence rather than an ndarray.
    incorrect = np.logical_not(correct)

    # Keep only the mis-classified images and their predicted/true classes.
    images = data.test.images[incorrect]
    cls_pred = np.asarray(cls_pred)[incorrect]
    cls_true = data.test.cls[incorrect]

    plot_images(images=images[0:9],
                cls_true=cls_true[0:9],
                cls_pred=cls_pred[0:9])
1
2
3
4
5
6
7
8
9
10
11
12
13
def plot_confusion_matrix(cls_pred):
    """Print the test-set confusion matrix and show it as an image.

    Args:
        cls_pred: Predicted class for every image in the test set,
            compared against ``data.test.cls``.
    """
    matrix = confusion_matrix(y_true=data.test.cls, y_pred=cls_pred)

    # Text form first, then the same matrix as a colour-coded plot.
    print(matrix)
    plt.matshow(matrix)
    plt.colorbar()

    # One tick per class on both axes, labelled with the class number.
    positions = np.arange(num_classes)
    plt.xticks(positions, range(num_classes))
    plt.xlabel('Predicted')
    plt.yticks(positions, range(num_classes))
    plt.ylabel('True')

    plt.show()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Split the test-set into smaller batches of this size.
test_batch_size = 256


def print_test_accuracy(show_example_errors=False,
                        show_confusion_matrix=False):
    """Classify the whole test set in batches and print the accuracy.

    Args:
        show_example_errors: If true, also plot some mis-classified
            test images.
        show_confusion_matrix: If true, also print and plot the
            confusion matrix.
    """
    num_test = len(data.test.images)

    # Predicted class per test image, filled in batch by batch.
    # The builtin ``int`` is used because the ``np.int`` alias no longer
    # exists in current NumPy releases.
    cls_pred = np.zeros(shape=num_test, dtype=int)

    # Walk the test set in steps of test_batch_size; the last batch may be
    # shorter, hence the min().
    for i in range(0, num_test, test_batch_size):
        j = min(i + test_batch_size, num_test)

        images = data.test.images[i:j, :]
        labels = data.test.labels[i:j, :]
        feed_dict = {x: images,
                     y_true: labels}

        cls_pred[i:j] = session.run(y_pred_cls, feed_dict=feed_dict)

    cls_true = data.test.cls

    # Boolean per test image: was the prediction right?
    correct = (cls_true == cls_pred)
    correct_sum = correct.sum()
    acc = float(correct_sum) / num_test

    msg = "Accuracy on Test-Set: {0:.1%} ({1} / {2})"
    print(msg.format(acc, correct_sum, num_test))

    if show_example_errors:
        print("Example errors:")
        plot_example_errors(cls_pred=cls_pred, correct=correct)

    if show_confusion_matrix:
        print("Confusion Matrix:")
        plot_confusion_matrix(cls_pred=cls_pred)
# Baseline: test accuracy before any optimisation step has been run.
print_test_accuracy()
Accuracy on Test-Set: 15.1% (1511 / 10000)
  • 跑一步
1
2
# Run a single optimisation step, then re-evaluate on the test set.
optimize(num_iterations=1)
print_test_accuracy()
Optimization Iteration:      1,                    Training Accuracy:  28.1%
Time usage: 0:00:00
Accuracy on Test-Set: 15.5% (1549 / 10000)
  • 跑100步
1
2
# 99 more steps on top of the single step above: 100 in total.
optimize(num_iterations=99)
print_test_accuracy()
Time usage: 0:00:05
Accuracy on Test-Set: 47.6% (4756 / 10000)
  • 跑 1000步
1
2
# 900 more steps (1000 in total), then also plot some mis-classified images.
optimize(num_iterations=900)
print_test_accuracy(show_example_errors=True)
Optimization Iteration:    101,                    Training Accuracy:  40.6%
Optimization Iteration:    201,                    Training Accuracy:  60.9%
Optimization Iteration:    301,                    Training Accuracy:  71.9%
Optimization Iteration:    401,                    Training Accuracy:  71.9%
Optimization Iteration:    501,                    Training Accuracy:  79.7%
Optimization Iteration:    601,                    Training Accuracy:  84.4%
Optimization Iteration:    701,                    Training Accuracy:  84.4%
Optimization Iteration:    801,                    Training Accuracy:  81.2%
Optimization Iteration:    901,                    Training Accuracy:  82.8%
Time usage: 0:00:37
Accuracy on Test-Set: 88.7% (8871 / 10000)
Example errors:

mark

  • 跑一万步
1
2
3
# 9000 more steps (10000 in total), then show example errors and the
# confusion matrix.
optimize(num_iterations=9000) # We performed 1000 iterations above.
print_test_accuracy(show_example_errors=True,
show_confusion_matrix=True)
Optimization Iteration:   1001,                    Training Accuracy:  92.2%
Optimization Iteration:   6201,                    Training Accuracy: 100.0%
Optimization Iteration:   6301,                    Training Accuracy:  96.9%
Optimization Iteration:   6401,                    Training Accuracy:  96.9%
Optimization Iteration:   6501,                    Training Accuracy:  96.9%
Optimization Iteration:   9801,                    Training Accuracy: 100.0%
Optimization Iteration:   9901,                    Training Accuracy:  96.9%
Time usage: 0:06:18
Accuracy on Test-Set: 96.9% (9691 / 10000)
Example errors:

mark

Confusion Matrix:
[[ 970    0    1    1    0    0    2    2    4    0]
 [   0 1117    2    1    1    0    3    0   11    0]
 [   5    2  993    6    2    0    2    7   15    0]
 [   2    1    5  977    0    8    0    8    8    1]
 [   1    1    3    0  957    0    3    1    2   14]
 [   3    1    0   11    0  866    4    2    4    1]
 [   7    3    3    1    5    5  931    0    3    0]
 [   0    7   17    3    1    1    0  992    2    5]
 [   5    0    2    8    5    1    3    4  943    3]
 [   9    8    1   10   11    3    0   14    8  945]]

mark

卷积层分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def plot_conv_weights(weights, input_channel=0):
    """Plot the filters of a convolutional layer for one input channel.

    The weights are fetched from the session and indexed as
    ``w[:, :, input_channel, i]``, i.e. the filter index is the fourth
    axis. All filters share one colour scale so they can be compared.

    Args:
        weights: The weight variable to evaluate and plot.
        input_channel: Which input channel of each filter to draw.
    """
    w = session.run(weights)

    # Common colour limits across every filter.
    w_min = np.min(w)
    w_max = np.max(w)

    num_filters = w.shape[3]

    # Smallest square grid that holds all filters. int() keeps the value
    # usable as a subplot count where math.ceil returns a float.
    num_grids = int(math.ceil(math.sqrt(num_filters)))

    # squeeze=False always yields a 2-D array of axes, so ``axes.flat``
    # also works for a 1x1 grid (a layer with a single filter).
    fig, axes = plt.subplots(num_grids, num_grids, squeeze=False)

    for i, ax in enumerate(axes.flat):
        # Grid cells beyond the last filter stay empty.
        if i < num_filters:
            img = w[:, :, input_channel, i]
            ax.imshow(img, vmin=w_min, vmax=w_max,
                      interpolation='nearest', cmap='seismic')

        # Hide tick marks on every cell, used or not.
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()
1
# Filters of the first convolutional layer (input channel 0 by default).
plot_conv_weights(weights=weights_conv1)

mark

1
# Filters of the second convolutional layer, as applied to input channel 0.
plot_conv_weights(weights=weights_conv2, input_channel=0)

mark

1
# The same layer's filters, as applied to input channel 1.
plot_conv_weights(weights=weights_conv2, input_channel=1)

mark

1
# Finished with the graph: close the session.
session.close()
作者

mmmwhy

发布于

2018-06-19

更新于

2021-05-31

许可协议

评论