fashion mnist的一个baseline (MobileNet 95%)
By 苏剑林 | 2017-08-27 | 81098位读者 |浅尝 #
昨天简单试了一下在fashion mnist的gan模型,发现还能work,当然那个尝试也没什么技术水平,就是把原来的脚本改一下路径跑了就完事。今天回到fashion mnist本身的主要任务——10分类,用Keras测了一下一些模型在上面的分类效果,最后得到了94.5%左右的准确率,加上随机翻转的数据扩增能做到95%。
首先随便手写了一些模型的组合,测试发现准确率都不大好,看来对于这个数据集来说,自己构思模型是比较困难的了,于是想着用现成的模型结构。一说到现成的cnn模型,基本上我们都会想到VGG、ResNet、inception、Xception等,但这些模型为解决imagenet的1000分类问题而设计,用到这个入门级别的数据集上似乎过于庞大了,而且也容易过拟合。后来突然想起,Keras好像自带了个叫MobileNet的模型,查看了一下模型权重,发现参数量不大,但是容量应该还是可以的,故选用MobileNet做实验。
深究 #
对于MobileNet就不多做介绍了,网上有不少文章讲解。简单来讲,它跟Xception的思想是一样的,在把大部分的卷积换成了depthwise卷积,而这个depthwise卷积有点类似矩阵的SVD分解,它把本来很大的卷积核矩阵分解为两个小矩阵,最后参数变少了,效果还更好了。更新的类似的工作还有ShuffleNet,不过目前还没有Keras版,作罢~
实验很简单,加载MobileNet模型,默认加载imagenet的预训练权重(不见得imagenet的权重会对这个数据集有啥帮助,但这确实有助于加快收敛并提高精度,看来不少视觉特征都是通用型的),然后接一个10分类器进行分类,放开所有权重进行训练。要注意的是:
1、MobileNet的原始设计是224*224的输入,而fashion mnist的图像只有28*28,差别比较大。虽然说直接输入也不会报错,这里还是把图像放大了两倍,变成了56*56,以免丢失细节信息,当然可以放得更大,但效果没有明显提升,浪费计算量;
2、MobileNet必须要三通道图像输入,为了迎合这一点,把图像复制三次即可。
整个代码如下:
import numpy as np
import mnist_reader
from tqdm import tqdm
from scipy import misc
import tensorflow as tf
np.random.seed(2017)
tf.set_random_seed(2017)
X_train, y_train = mnist_reader.load_mnist('../data/fashion', kind='train')
X_test, y_test = mnist_reader.load_mnist('../data/fashion', kind='t10k')
height,width = 56,56
from keras.applications.mobilenet import MobileNet
from keras.layers import Input,Dense,Dropout,Lambda
from keras.models import Model
from keras import backend as K
input_image = Input(shape=(height,width))
input_image_ = Lambda(lambda x: K.repeat_elements(K.expand_dims(x,3),3,3))(input_image)
base_model = MobileNet(input_tensor=input_image_, include_top=False, pooling='avg')
output = Dropout(0.5)(base_model.output)
predict = Dense(10, activation='softmax')(output)
model = Model(inputs=input_image, outputs=predict)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()
X_train = X_train.reshape((-1,28,28))
X_train = np.array([misc.imresize(x, (height,width)).astype(float) for x in tqdm(iter(X_train))])/255.
X_test = X_test.reshape((-1,28,28))
X_test = np.array([misc.imresize(x, (height,width)).astype(float) for x in tqdm(iter(X_test))])/255.
model.fit(X_train, y_train, batch_size=64, epochs=50, validation_data=(X_test, y_test))
代码很简单很清晰,就不注释了~
经过多次测试,基本上在20个epoch内能达到94.5%以上的准确率(尽管我们设置了random seed,但由于cudnn的存在,依然不能保证重复运行的结果都一样)。后面的epoch不稳定,有过拟合嫌疑。
细抠 #
感觉不做数据扩增的前提下,94.5%以上的准确率应该是可以满意了。再测了一下做数据扩增的,但想了一下,对于这个数据集,似乎没有什么适合的数据扩增手段,唯一想到的就是随机左右翻转了,加上去,看效果:
import numpy as np
import mnist_reader
from tqdm import tqdm
from scipy import misc
import tensorflow as tf
np.random.seed(2017)
tf.set_random_seed(2017)
X_train, y_train = mnist_reader.load_mnist('../data/fashion', kind='train')
X_test, y_test = mnist_reader.load_mnist('../data/fashion', kind='t10k')
height,width = 56,56
from keras.applications.mobilenet import MobileNet
from keras.layers import Input,Dense,Dropout,Lambda
from keras.models import Model
from keras import backend as K
input_image = Input(shape=(height,width))
input_image_ = Lambda(lambda x: K.repeat_elements(K.expand_dims(x,3),3,3))(input_image)
base_model = MobileNet(input_tensor=input_image_, include_top=False, pooling='avg')
output = Dropout(0.5)(base_model.output)
predict = Dense(10, activation='softmax')(output)
model = Model(inputs=input_image, outputs=predict)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()
X_train = X_train.reshape((-1,28,28))
X_train = np.array([misc.imresize(x, (height,width)).astype(float) for x in tqdm(iter(X_train))])/255.
X_test = X_test.reshape((-1,28,28))
X_test = np.array([misc.imresize(x, (height,width)).astype(float) for x in tqdm(iter(X_test))])/255.
def random_reverse(x):
if np.random.random() > 0.5:
return x[:,::-1]
else:
return x
def data_generator(X,Y,batch_size=100):
while True:
idxs = np.random.permutation(len(X))
X = X[idxs]
Y = Y[idxs]
p,q = [],[]
for i in range(len(X)):
p.append(random_reverse(X[i]))
q.append(Y[i])
if len(p) == batch_size:
yield np.array(p),np.array(q)
p,q = [],[]
if p:
yield np.array(p),np.array(q)
p,q = [],[]
model.fit_generator(data_generator(X_train,y_train), steps_per_epoch=600, epochs=50, validation_data=data_generator(X_test,y_test), validation_steps=100)
果然,数据扩增还是有一定帮助的,我跑了两次,一次有95.04%,一次是94.91%,也就是说50个epoch内可以达到了95%左右的准确率。要注意不是所有数据扩增都能提升,我尝试多加了random mask,发现效果还下降了。所以说数据扩增必须适应数据集,尤其是适应测试集,说白了,我觉得虽然数据扩增是对训练集做的,但其本质就是引入测试集的先验知识。
道远 #
看来fashion mnist还真的挺有难度,不像mnist随便写单个Dense层就可以得到90%多的准确率,因此用它来作为CNN算法基准确实有相当的代表性了。mnist的测试准确率普遍可以达到99%以上了,而就目前我所能浏览到的数据来看,在fashion mnist上最多也只有96%左右(还是没公开源代码的),离99%还有很大距离,看来哪怕是对于这个数据集,还是任重而道远呀。
不知道哪个模型能率先到达99%的准确率呢~
转载到请包括本文地址:https://kexue.fm/archives/4556
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Aug. 27, 2017). 《fashion mnist的一个baseline (MobileNet 95%) 》[Blog post]. Retrieved from https://kexue.fm/archives/4556
@online{kexuefm-4556,
title={fashion mnist的一个baseline (MobileNet 95%)},
author={苏剑林},
year={2017},
month={Aug},
url={\url{https://kexue.fm/archives/4556}},
}
December 7th, 2017
博主有没有试过直接原图送入?
December 13th, 2017
您好,我的环境是keras'2.0.8',tensorflow'1.3.0',会报这样的错误If imagenet weights are being loaded, input must have a static square shape (one of (128,128), (160,160), (192,192), or (224, 224)),需要固定图片的shape。
解决方法是,把图片的shape改为错误提示里面的shape。然后在MobileNet()里面指定input_shape为以上大小。希望可以对大家有帮助
December 15th, 2017
请问一次能po具体的解决code 被error弄一头雾水 谢谢你
December 15th, 2017
我这边还遇到了一个问题,就是我在kaggle load的dataset 我只是reshape X——train(60000,28,28,1). 这样MobileNet就不让我input因为我要input一个three dimension的 请问有什么很好的解决办法吗? 谢谢!
March 15th, 2018
请问您看到的96%准确度的数据在哪里?谢谢
https://github.com/zalandoresearch/fashion-mnist/
May 13th, 2018
[...]
December 12th, 2018
98.7%的话有人跑过吗?
没关注了,自己留意https://github.com/zalandoresearch/fashion-mnist/ 就好~
September 6th, 2019
开始的时候把28*28的图像resize成56*56的,为什么不resize为32*32呢?因为输入大图的结果会更好么?
原则上更大会更好些,因为MobileNet的原始输入是224x224的。
October 23rd, 2019
[...]| Classifier | Preprocessing | Fashion test accuracy | MNIST test accuracy | Submitter| Code || — | — | — | — | — |— ||2 Conv+pooling | None | 0.876 | – | Kas[...]