colab.research.google.com/drive/1vcBfl192NDGJ2_c2u01tAa65NGSJmrw8?usp=sharing
์๋ฌธ์ ๋๋ค. ์ฝ๋ฉ์ผ๋ก ๋ณด๋๊ฑฐ ์ถ์ฒ.
(์ด์ ๊ธ)
2021/01/28 - [Hi/AI] - [TensorFlow] ์ซ์ ์ธ์ AI ์ ์ (1)
<์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ ์ค๊ณ>
# ๋ชจ๋ธ์ค๊ณ
4๊ฐ์ ์ธต์ผ๋ก ๋ง๋ค๊ฑฐ์. ์ ๋ ฅ์ธต, ์๋์ธต, ์๋์ธต, ์ถ๋ ฅ์ธต
์ ๋ ฅ์ธต ๋ด๋ฐ์ ์๋ 28 * 28์ด๋ 784์. 28 * 28๊ฐ์ ํฝ์ ๋ก ์ด๋ฃจ์ด์ง ์ซ์๋ฅผ ํ์ค๋ก ๋ฐ๊พผ๊ฑฐ์.
์ฒซ๋ฒ์งธ ์๋์ธต์ ๋ ธ๋ ์๋ 512๊ฐ, ๋๋ฒ์งธ ์๋์ธต์ ๋ ธ๋ ์๋ 256๊ฐ, ์ธ๋ฒ์งธ ๊ฒฐ๊ณผ์ธต์ ๋ ธ๋์๋ 10๊ฐ๋ก ์ค์ ํจ.
(๋ง์ง๋ง์ 0~9์ด๋ 10๊ฐ)
ํ์ฑํ ํจ์๋ ๋ ๋ฃจ(ReLU)ํจ์ ์ฌ์ฉํ๊ณ , ๋ง์ง๋ง์ ์ํํธ๋งฅ์ค(softmax)ํจ์๋ฅผ ์ฌ์ฉํ ๊ฑฐ์.
(ReLU - 0๋ณด๋ค ์์๊ฐ์ด ์ ๋ ฅ๋๋ฉด 0์ ๋ฐํ, 0๋ณด๋ค ํฐ ๊ฐ์ด ์ ๋ ฅ๋๋ฉด ๊ทธ ๊ฐ์ ๊ทธ๋๋ก ๋ฐํ.
softmax - 0~1 ์ฌ์ด์ ๊ฐ์ผ๋ก ๋ชจ๋ ์ ๊ทํ ์ํด. ๋ชจ๋ ์ถ๋ ฅ๊ฐ์ ํฉ์ด 1์ด ๋๊ฒ ๋ง๋๋ ํ์ฑํ ํจ์์.
๋ถ๋ฅ๋ฌธ์ ์์ ์ด๋ค ๋ฒ์ฃผ๋ฅผ ๊ฐ์ฅ ๋์ ํ๋ฅ ๋ก ์์ธกํ๋์ง์ ๋ํด ์ฃผ๋ก ์ฌ์ฉ๋.)
์ผ๋ผ์ค๋ ์ํ์ ๋ชจ๋ธ์ ํตํด ์ฝ๊ฒ ๊ฐ๋ฐํ ์๋ก ๋์์ค.
model = Sequential() #๋ชจ๋ธ์ ์ํ์
๋ฐฉ์
model.add(Dense(512, input_shape = (784, ))) #์ฒซ๋ฒ์งธ ์ธ์๋ ํด๋น ์๋์ธต์ ๋
ธ๋ ์, ๋๋ฒ์งธ ์ธ์๋ ์
๋ ฅํ๋ ๋ฐ์ดํฐ์ ํํ.
model.add(Activation('relu'))
model.add(Dense(256)) #๋๋ฒ์งธ๋ถํฐ๋ ์
๋ ฅ๋ฐ๋ ๋
ธ๋ ์ค์ ์ํด๋๋. ์ผ๋ผ์ค๋ฅผ ์ฌ์ฉํ๋ ์ด์ ์.
model.add(Activation('relu'))
model.add(Dense(10))
model.add(Activation('softmax'))
model.summary() # ๋ชจ๋ธ ํ์ธ
Layer ์ ๋ ์ด์ด๋ฅผ ํ์ํด์ฃผ๋ ๊ฑฐ๊ณ ,
Output Shape๋ ๋ ์ด์ด์ ๋ชจ์ต์ ๋ํ๋ด์ฃผ๋ ๊ฑฐ๊ณ ,
Param์ ๊ฐ ๋ ธ๋์ ํธํฅ์ ์ฐ๊ฒฐํ๋ ๊ฐ์ค์น์ ์๋ฅผ ๋ํ๋ด๋ ๊ฒ์.
์ฒซ๋ฒ์งธ ๋ ์ด์ด๋ 512๊ฐ์ ๋ ธ๋๋ก ์ด๋ฃจ์ด์ง.
784๊ฐ์ ์ ๋ ฅ์ธต ์์ 512๊ฐ์ ์๋์ธต์ผ๋ก ๊ฐ๊ฐ ์ฐ๊ฒฐ๋์ด ์๊ธฐ์ ๊ฐ์ค์น๋ 784 * 512 ์ด๊ณ , ์๋์ธต์ ๊ฐ ๋ ธ๋ ์๋งํผ ํธํฅ์ด ์๊ธฐ์ ํธํฅ์ 512์.
๋ฐ๋ผ์ 784 * 512 + 512 ์ด๋ฏ๋ก 401920๊ฐ์ ํ๋ผ๋ฏธํฐ๋ก ์ด๋ฃจ์ด์ ธ ์์.
๋ค์ ๋ ์ด์ด๋ ๋ง์ฐฌ๊ฐ์ง๋ก 512 * 256 + 256 = 131328
๋ง์ง๋ง ๋ ์ด์ด๋ 256 * 10 + 10 ์ด๋ฏ๋ก 2570 ์ด ๋์ด.
# ๋ชจ๋ธ ํ์ต
์ด์ ๋ชจ๋ธ์ ์ค๊ณํ์ผ๋ ๋ชจ๋ธ์ ์คํํด ๋ณด์์ผํจ.
์ฌ์ธต ์ ๊ฒฝ๋ง์ ๋ฐ์ดํฐ๋ฅผ ํ๋ ค๋ณด๋ด ์ ๋ต์ ์์ธกํ ์ ์๊ฒ ์ ๊ฒฝ๋ง์ ํ์ตํ๋ ๊ณผ์ ์ด ํ์ํจ. (๋ฅ๋ฌ๋)
์ ๊ฒฝ๋ง์ด ์์ธกํ ๊ฒฐ๊ณผ์ ์ค์ ์ ๋ต์ ๋น๊ต, ์ค์ฐจ์์ผ๋ฉด ์ ๊ฒฝ๋ง์ ๋ค์ ํ์ต.
์ฌ๋งํ๋ฉด ์ค์ฐจ๊ฐ 0์ด ๋์ค๋ ์ผ์ด ๊ฑฐ์ ์์. ๊ทธ๋์ ๋ณดํต ํ์ต์ํค๋ ํ์๋ฅผ ์ ํด์ค์ ๊ทธ๋งํผ๋ง ํ์ต์ํด.
์ ๊ฒฝ๋ง์ ์ ํ์ต์ํค๋ ค๋ฉด ์ ๊ฒฝ๋ง์ด ๋ถ๋ฅํ ๊ฐ๊ณผ ์ค์ ๊ฐ์ ์ค์ฐจ๋ถํฐ ๊ณ์ฐํด์ผํจ.
์ค์ฐจ์ค์ด๋ ๋ฐฉ๋ฒ - ๊ฒฝ์ฌ ํ๊ฐ๋ฒ.
model.compile(loss = 'categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(X_Train, Y_train, batch_size=128, epochs=10, verbose=1)
loss = ์ค์ฐจ๊ฐ
accuracy = ์ ํ๋
loss๋ Epoch๊ฐ ์ฆ๊ฐํ ์๋ก ์ค์ด๋ค๊ณ accuracy๋ ์ฆ๊ฐํ๋ ๊ฒ์ ํ์ธํ ์ ์๋ค.
compile ํจ์๋ ์ผ๋ผ์ค์์ ์ ๊ณตํ๋ ํจ์๋ก, ์ฌ์ธต ์ ๊ฒฝ๋ง์ ํ์ตํ๋ ๋ฐฉ๋ฒ์ ์ ํ๋ ๋ช ๋ น์ด์.
1. ์ค์ฐจ๊ฐ์ ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ ์๋ ค์ค์ผํจ.
10๊ฐ์ค ํ๋๋ก ๋ถ๋ฅํ๋ ๊ฒ์ด๋ ๋ค์ค ๋ถ๋ฅ ๋ฌธ์ ์ด๋ค. ๋ฐ๋ผ์ categorical_crossentropy ๋ฐฉ๋ฒ์ผ๋ก ์ค์ ํด์ค๋ค.
2. ์ค์ฐจ๋ฅผ ์ค์ด๋ ๋ฐฉ๋ฒ์ ์๋ ค์ค์ผํจ.
์ค์ฐจ๋ฅผ ์ค์ด๊ธฐ ์ํด optimizer์ adam ๋ฐฉ๋ฒ์ ์ฌ์ฉํจ.
๋ฅ๋ฌ๋์ ํตํด ์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ ํ์ต์ํฌ๋, ์ค์ฐจ๋ฅผ ์ค์ด๊ธฐ ์ํด ๊ฒฝ์ฌ ํ๊ฐ๋ฒ ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํจ. ์ด๋ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ์ ์ด๋ค ๋ฐฉ์์ผ๋ก ์ฌ์ฉํ ์ง ์ฌ๋ฌ ์๊ณ ๋ฆฌ์ฆ ์๋๋ฐ ์ด ์๊ณ ๋ฆฌ์ฆ๋ค์ ์ผ๋ผ์ค์์ ๋ชจ์ ๋์ ๊ฒ์ด ์ตํฐ๋ง์ด์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์.
์ตํฐ๋ง์ด์ ์ข ๋ฅ์๋ adam๋ฟ๋ง ์๋๋ผ ํ๋ฅ ์ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ(SGD) ๋ฑ์ด ์์.
3. ํ์ต ๊ฒฐ๊ณผ๋ฅผ ์ด๋ป๊ฒ ํ์ธํ ์ง ์๋ ค์ค์ผํจ.
์ ํ๋(Accuracy)๋ ์ค์ 6๋ง๊ฐ ๋ฐ์ดํฐ์ ์์ธก ๊ฒฐ๊ณผ์ ์ค์ ๊ฐ์ ๋น๊ตํ๊ณ ์ ๋ต ๋น์จ์ ์๋ ค์ค.
์ด์ ์ฌ์ธต ์ ๊ฒฝ๋ง์ ํ์ตํ๋ ๋ฐฉ๋ฒ์ ์ ํ์ผ๋ fit์ด๋ผ๋ ์ผ๋ผ์ค์ ํ์ต ๋ช ๋ น์ด๋ฅผ ์ฌ์ฉํด ํ์ต์ ์ํด.
1. ์ ๋ ฅํ ๋ฐ์ดํฐ๋ฅผ ์ ํด์ผํจ.
X_Train , Y_train ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํด์ ์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ์ ํ์ตํ๋ ์ด ๋๊ฐ๋ฅผ ๋ฃ์ด์ค.
2. ๋ฐฐ์น์ฌ์ด์ฆ๋ฅผ ์ ํด์ผํจ.
Batch_Size ๋ ์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ์ด ํ๋ฒ์ ํ์ตํ๋ ๋ฐ์ดํฐ ์ ์.
ํ๋ฒ์ 128๊ฐ ๋ฐ์ดํฐ๋ฅผ ํ์ต ์ํจ๊ฑฐ์.
3. ์ํฌํฌ๋ฅผ ์ ํด์ผํจ.
Epochs๋ ๋ชจ๋ ๋ฐ์ดํฐ๋ฅผ ํ๋ฒ ํ์ตํ๋ ๊ฒ์ ์๋ฏธํจ.
๋ชจ๋ ๋ฐ์ดํฐ๋ฅผ 10๋ฒ ๋ฐ๋ณต ์ํฌ๋ ค๊ณ epochs =10, verbose =1 ๋ก ์ค์ ํจ
verbose๋ ์ผ๋ผ์ค fit ํจ์์ ๊ฒฐ๊ด๊ฐ์ ์ถ๋ ฅํ๋ ๋ฐฉ๋ฒ์.
0, 1, 2 ์ค ํ๋๋ก ๊ฒฐ์ ํ ์ ์์.
- verbose = 0 : ์๋ฌด๋ฐ ํ์ X
- verbose = 1 : ์ํฌํฌ๋ณ ์งํ ์ฌํญ ๋ณด์ฌ์ค
- verbose = 3 : ์ํฌํฌ๋ณ ํ์ต ๊ฒฐ๊ณผ ๋ณด์ฌ์ค.
# ๋ชจ๋ธ ์ ํ๋ ํ์ธ
๋ชจ๋ธ์ ์ค๊ณํ๊ณ , ํ์ต์ ํ์ผ๋ฉด ๋ค์์ ์ฑ๋ฅ์ ํ์ธํด ๋ณด์์ผํ๋ค.
์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ์ด ์ผ๋ง๋ ์ ํ์ตํ๋์ง, ๊ฒ์ฆ ๋ฐ์ดํฐ๋ฅผ ์ผ๋ง๋ ์ ๋ง์ถ๋์ง ํ์ธํด๋ดค๋ค.
model.evaluate(X_Test, Y_test)
์ผ๋ผ์ค์ evakuate ํจ์๋ ๋ชจ๋ธ์ ์ ํ๋๋ฅผ ํ๊ฐํ๋ ํจ์์.
์ฒซ๋ฒ์งธ ์ธ์๋ ํ ์คํธํ ๋ฐ์ดํฐ,
๋๋ฒ์งธ ์ธ์๋ ํ ์คํธํ ๋ฐ์ดํฐ์ ์ ๋ต์.
loss(์ค์ฐจ๊ฐ)์ 0 ~ 1 ์ฌ์ด ๊ฐ์ผ๋ก 0์ ๊ฐ๊น์ธ ์๋ก ์ค์ฐจ๊ฐ ์ ์๊ฒ์.
accuracy(์ ํ๋) ๋ํ 0~1 ์ฌ์ด ๊ฐ์ผ๋ก 1์ ๊ฐ๊น์ธ ์๋ก ์ ํํ ๊ฑฐ์.
#๋ชจ๋ธ ํ์ต ๊ฒฐ๊ณผ ํ์ธ
predicted_classes = np.argmax(model.predict(X_Test), axis=1)
correct_indices = np.nonzero(predicted_classes == y_test)[0]
incorrect_indices = np.nonzero(predicted_classes != y_test)[0]
predict ๋ ๊ฒฐ๊ณผ๋ฅผ ์์ธกํ๋ ํจ์์.
X_Test ์ ๋ฐ์ดํฐ ๊ฐ์๊ฐ ๋ง๊ฐ์์ผ๋ ์์ธก๊ฐ๋ ๋ง๊ฐ๋์ด.
numpy์ argmax ํจ์๋ฅผ ์ด์ฉํด์ ์ฌ๋ฌ ๋ฐ์ดํฐ ์ค ๊ฐ์ฅ ํฐ ๊ฐ์ ์์น๋ฅผ ๋ฐํํจ.
argmax๋ฅผ ์ธ ๋๋, ์ด์์ ๊ฐ์ฅ ํฐ ๊ฐ์ ๊ตฌํ ์ง, ํ์์ ๊ฐ์ฅ ํฐ ๊ฐ์ ๊ตฌํ ์ง ์ค์ ํด์ผํจ.
์ด ๋ ๊ธฐ์ค์ ์ ํด์ฃผ๋ ๊ฒ์ด axis ์. axis=0์ด๋ฉด ์ด์์ ๊ฐ์ฅ ํฐ ์๋ฅผ ๊ณ ๋ฅด๊ณ , axis=1์ด๋ฉด ํ์์ ๊ฐ์ฅ ํฐ ์๋ฅผ ๊ณ ๋ฆ.
nonzero๋ ๋ํ์ด ๋ฐฐ์ด์์ 0์ด ์๋ ๊ฐ์ ์ฐพ๋ ๊ฒ์.
plt.figure()
for i in range(9):
plt.subplot(3,3,i+1)
correct = correct_indices[i]
plt.imshow(X_Test[correct].reshape(28,28), cmap='gray')
plt.title("Predicted {}, Class {}".format(predicted_classes[correct], y_test[correct]))
plt.tight_layout()
matplotlib ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ํตํด ๊ทธ๋ํ๋ฅผ ๊ทธ๋ฆผ.
figure()์ ๊ทธ๋ํ๋ฅผ ๊ทธ๋ฆฌ๊ฒ ๋ค๋ ๋ช ๋ น์. ๊ทธ๋ํ๋ฅผ ๊ทธ๋ฆฌ๋ ค๋ฉด ๊ทธ๋ฆด ์ค๋น๋ฅผ ํด์ผํจ. ๊ทธ ์ค๋นํ๋ ๋ช ๋ น์ด figureํจ์์.
subplot์ ๊ทธ๋ฆผ์ ์์๋ฅผ ์ ํด์ฃผ๋ ํจ์์.
์ฒซ๋ฒ์งธ ์ธ์๋ ๊ทธ๋ฆผ์ ๊ฐ๋ก ๊ฐ์,
๋๋ฒ์งธ ์ธ์๋ ๊ทธ๋ฆผ์ ์๋ก ๊ฐ์,
์ธ๋ฒ์งธ ์ธ์๋ ๊ทธ๋ฆผ์ ์์์.
imshow ํจ์๋ ์ด๋ค ์ด๋ฏธ์ง๋ฅผ ๋ณด์ฌ์ค์ง์ ๋ํ ๋ด์ฉ์ ๋ด๊ณ ์์.
ํ์ฌ ๋ฐ์ดํฐ๋ 1์ฐจ์ ๋ฐฐ์ด๋ก ๋์ด์์ผ๋ reshape๋ก 28 * 28 ํํ๋ก ๋ฐ๊พธ์ด์ค.
๊ทธ๋ฆผ์ ํ์์กฐ๋ก ๋ํ๋ด๋ ค๊ณ cmap = 'gray' ๋ก ์ค์ ํจ.
title ํจ์๋ ๊ทธ๋ฆผ์ ๋ํ ์ค๋ช ์ ๋ฃ๋ ํจ์์.
formatํจ์๋ฅผ ์ฌ์ฉํด {} ์์ ๊ฐ์ ๋ฃ์.
ํ๋ฉด์ ๊ทธ๋ฆผ์ ๋ณด์ฌ์ฃผ๋ tight_layout์ ์ฌ์ฉํจ.
์๋๋ ํ๋ฆฐ ๋ฐ์ดํฐ ๋ค์ ์ถ๋ ฅํ๋ ์ฝ๋์.
plt.figure()
for i in range(9):
plt.subplot(3,3,i+1)
incorrect = incorrect_indices[i]
plt.imshow(X_Test[incorrect].reshape(28,28), cmap='gray')
plt.title("Predicted {}, Class {}".format(predicted_classes[incorrect], y_test[incorrect]))
plt.tight_layout()
์ธ๊ณต์ง๋ฅ์ ์ฑ๋ฅ์ด ์ข์ ๋ณด์ด์ง๋ ์์.
์๋ํ๋ฉด ๋ชจ๋ธ ํ์ต์ด ์ ๋์ง ์์์์.
์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ ํ์ต์ด ์ ์ด๋ฃจ์ด์ง๊ธฐ ์ํด์ ๋ชจ๋ธ์ ํ์ต ํ์๋ฅผ ๋๋ ค์ฃผ์ด์ผํจ.
๊ทผ๋ฐ ๋ชจ๋ธ์ ํ์ต ํ์๋ง ๋ฌด์์ ๋๋ฆฐ๋ค๊ณ ์ธ๊ณต์ง๋ฅ ์ฑ๋ฅ์ด ๊พธ์คํ ๋์์ง๋๊ฑด ์๋.
๊ณผ์ ํฉ(overfitting) ๋ฌธ์ ๊ฐ ์ผ์ด๋ ์ ์๊ธฐ ๋๋ฌธ.
์ฌ๊ธฐ์ ์ค๋ฒํผํ ์ ์ธ๊ณต์ง๋ฅ์ด ํ๋ จ ๋ฐ์ดํฐ์๋ง ์ต์ ํ๋์ ํ๋ จ ๋ฐ์ดํฐ๋ง ์ ๊ตฌ๋ถํ๊ณ , ์๋ก์ด ๋ฐ์ดํฐ์ธ ๊ฒ์ฆ ๋ฐ์ดํฐ๋ฅผ ์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ์ ๋ฃ์ผ๋ฉด ๊ตฌ๋ถํ์ง ๋ชปํ๋ ํ์์ด ์ผ์ด๋ ์ ์์. ์ฆ, ์ฑ๋ฅ์ด ๋๋น ์ง๋ ๊ฑฐ์.
๊ทธ๋์ ์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ์ ํ์ต์ํฌ ๋์๋ ๋ฌด์กฐ๊ฑด ๋ง์ด๊ฐ ์๋, ์ผ๋ง๋งํผ ํ์ต ์์ผ์ผ ์ ์ผ ์ข์์ง ๊ฒฐ์ ํ ์ ์์ด์ผํ๊ณ , ์ด๊ฒ ์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ ์ค๊ณ์์ ์ค์ํ ๋ถ๋ถ์.
์ฒซ๋ฒ์งธ ์ธ๊ณต์ง๋ฅ ์ ์ ์๋ฃ