如何持久化scikit-learn中训练好的模型

2014-11-25

scikit-learn是一个很棒的基于python的机器学习库,易学易用。本文说一说如何持久化其训练好的模型,用高斯贝叶斯分类iris数据集作为示例。关于高斯贝叶斯,可以参考Gaussian Naive Bayes

关于iris数据集

>>> from sklearn import datasets
>>> iris = datasets.load_iris()
>>> iris.data.shape
(150, 4)
>>> iris.data[:5,:]
array([[ 5.1,  3.5,  1.4,  0.2],
       [ 4.9,  3. ,  1.4,  0.2],
       [ 4.7,  3.2,  1.3,  0.2],
       [ 4.6,  3.1,  1.5,  0.2],
       [ 5. ,  3.6,  1.4,  0.2]])
>>> iris.target
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

iris数据集中共有150个样本,每个样本有四个特征。这150个样本被分0、1、2这三类。iris.target存储了每个样本对应的类别。

训练高斯贝叶斯

用iris的第1到第148个样本(样本标号从0开始),共148个样本来训练高斯贝叶斯GaussianNB。

>>> from sklearn.naive_bayes import GaussianNB
>>> clf = GaussianNB()
>>> clf.fit(iris.data[1:149, :], iris.target[1:149])
>>> iris.data[1:149, :].shape
(148, 4) 
>>> clf.fit(iris.data[1:149, :], iris.target[1:149])
GaussianNB()

对第0个样本和第149个样本分类:

>>> clf.predict(iris.data[0, :])
array([0])
>>> clf.predict_proba(iris.data[0, :])
array([[  1.00000000e+00,   1.47235201e-18,   7.67420665e-26]])
>>> clf.predict(iris.data[149, :])
array([2])
>>> clf.predict_proba(iris.data[149, :])
array([[  3.70074861e-142,   6.07049597e-002,   9.39295040e-001]])

clf.predict_probaclf.predict的参数可以是一个样本或者多个样本。clf.predict_proba输出样本属于各个类别的可能性,这些可能性之和为1。clf.predict输出样本最有可能的类别。

第0个样本被分到类0,第149个样本被分到类2,正确。

使用pickle持久化训练好的模型

>>> pickle.dump(clf, open('bayes.pk', 'wb'))
>>> clf2 = pickle.load(open('bayes.pk', 'rb'))
>>> clf2.predict(iris.data[0, :])
array([0])
>>> clf2.predict_proba(iris.data[0, :])
array([[  1.00000000e+00,   1.47235201e-18,   7.67420665e-26]])
>>> clf2.predict(iris.data[149, :])
array([2])
>>> clf2.predict_proba(iris.data[149, :])
array([[  3.70074861e-142,   6.07049597e-002,   9.39295040e-001]])

使用joblib持久化训练好的模型

joblib是scikit-learn自带的工具,为numpy array做了优化。

>>> from sklearn.externals import joblib
>>> joblib.dump(clf, 'bayes.pkl')
['bayes.pkl', 'bayes.pkl_01.npy', 'bayes.pkl_02.npy', 'bayes.pkl_03.npy', 'bayes.pkl_04.npy', 'bayes.pkl_05.npy']
>>> clf3 = joblib.load('bayes.pkl')
>>> clf3.predict(iris.data[0, :])
array([0])
>>> clf3.predict(iris.data[149, :])
array([2])

joblib.dump方法在当前目录产生了多个文件,*.npy文件用于保存numpy array 类型的变量。

参考

Model persistence

(完)

( 完 )