SKLearn Logistic Regression

Logistic Regression
A classification algorithm: it uses a regression model to decide which class a sample belongs to.
http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression


Load the module
from sklearn import linear_model

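Create the initial model
<model> = linear_model.LogisticRegression(penalty='l2', dual=False, tol=0.0001, C=1.0, fit_intercept=True, intercept_scaling=1, class_weight=None, random_state=None, solver='liblinear', max_iter=100, multi_class='ovr', verbose=0, warm_start=False, n_jobs=1)
Commonly used parameters (the defaults shown above come from an older scikit-learn release; recent releases default to solver='lbfgs'):
 C : float, optional (default=1.0)
  Inverse of the regularization strength, used to control overfitting. The larger the value, the weaker the regularization.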
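To make the effect of C concrete, here is a minimal sketch that fits the same data with three values of C and compares the size of the learned coefficients. The chosen C values, the use of the full iris dataset, and max_iter=5000 are assumptions picked for illustration only.

```python
import numpy as np
from sklearn import datasets, linear_model

iris = datasets.load_iris()
X, y = iris.data, iris.target

# Fit one model per C value and record the overall size of the coefficients.
coef_norms = {}
for C in (0.01, 1.0, 100.0):
    # max_iter is raised (an assumption) so the solver has room to converge.
    model = linear_model.LogisticRegression(C=C, max_iter=5000)
    model.fit(X, y)
    coef_norms[C] = float(np.linalg.norm(model.coef_))
    print(C, coef_norms[C])
```

A small C shrinks the coefficients toward zero (strong regularization), while a large C lets them grow to fit the training data more closely.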


Train the model
<model>.fit(input, output)


Predict the output class from the input
<model>.predict(input)
P.S. Related methods:
predict_log_proba(X) # Log of probability estimates.
predict_proba(X) # Probability estimates.
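A short sketch of how the two probability methods relate, assuming a model trained on the full iris dataset (the variable names and max_iter=1000 are illustrative choices): each row of predict_proba sums to 1, and predict_log_proba is the natural log of those values.

```python
import numpy as np
from sklearn import datasets, linear_model

iris = datasets.load_iris()
model = linear_model.LogisticRegression(max_iter=1000)
model.fit(iris.data, iris.target)

sample = iris.data[:5]                       # first five samples
proba = model.predict_proba(sample)          # one row per sample, one column per class
log_proba = model.predict_log_proba(sample)  # log of the same probabilities

row_sums = proba.sum(axis=1)                 # each row should add up to 1
print(proba.shape, row_sums)
```

The columns follow the order of model.classes_, so column 0 is the probability of the first class label.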

 

………………………….

Example with the iris dataset

Start the Python interactive shell
# python


Load the data
>>> import numpy as np
>>> from sklearn import datasets
>>> iris = datasets.load_iris()
>>> iris_X = iris.data
>>> iris_Y = iris.target
Shuffle the data
>>> np.random.seed(0)
>>> indices = np.random.permutation(len(iris_X))
Use everything except the last 10 samples for training
>>> iris_X_train = iris_X[indices[:-10]]
>>> iris_Y_train = iris_Y[indices[:-10]]
Hold out the last 10 samples for validation
>>> iris_X_test = iris_X[indices[-10:]]
>>> iris_Y_test = iris_Y[indices[-10:]]
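As an alternative to shuffling indices by hand, scikit-learn's train_test_split can do the shuffle and the split in one call. This is a swapped-in technique, not the approach used in this post, and it will not reproduce exactly the same 10 held-out samples; the variable names below are illustrative.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()

# test_size=10 (an integer) holds out exactly 10 samples;
# random_state fixes the shuffle so the split is repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=10, random_state=0)

print(X_train.shape, X_test.shape)
```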

Run training to build the model
>>> from sklearn import linear_model
>>> logistic = linear_model.LogisticRegression(C=1e5)
>>> logistic.fit(iris_X_train,iris_Y_train)
LogisticRegression(C=100000.0, class_weight=None, dual=False,
fit_intercept=True, intercept_scaling=1, max_iter=100,
multi_class='ovr', penalty='l2', random_state=None,
solver='liblinear', tol=0.0001, verbose=0)


Predict with the model
predict returns the class with the highest probability
>>> logistic.predict(iris_X_test)
array([1, 2, 1, 0, 0, 0, 2, 1, 2, 0])
>>> print(iris_Y_test)
[1 1 1 0 0 0 2 1 2 0]
One sample (the second) is still misclassified: the model predicts class 2, but the true class is 1.
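The comparison can also be done in code rather than by eye. The sketch below copies the two arrays printed above into NumPy arrays (the variable names are illustrative) and computes which positions disagree and the resulting accuracy.

```python
import numpy as np

# Values copied from the predict() output and iris_Y_test shown above.
predicted = np.array([1, 2, 1, 0, 0, 0, 2, 1, 2, 0])
actual = np.array([1, 1, 1, 0, 0, 0, 2, 1, 2, 0])

wrong = np.flatnonzero(predicted != actual)    # positions of misclassified samples
accuracy = float(np.mean(predicted == actual)) # fraction of correct predictions

print(wrong, accuracy)
```

In the interactive session, logistic.score(iris_X_test, iris_Y_test) should report the same mean accuracy directly from the fitted model.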

predict_proba shows the probability assigned to each class
>>> logistic.predict_proba(iris_X_test)
array([[ 7.84463971e-06, 9.99954368e-01, 3.77870449e-05],
[ 7.27268437e-09, 1.47166349e-01, 8.52833644e-01],
[ 2.71292961e-08, 9.99764511e-01, 2.35461414e-04],
[ 9.30834892e-01, 6.91651082e-02, 6.24975648e-25],
[ 9.67856008e-01, 3.21439919e-02, 2.01259183e-22],
[ 7.62068769e-01, 2.37931231e-01, 2.12976894e-23],
[ 3.61723169e-12, 3.29818401e-01, 6.70181598e-01],
[ 3.46987440e-06, 9.99996344e-01, 1.85665512e-07],
[ 1.40350469e-13, 7.03733483e-02, 9.29626652e-01],
[ 8.40655049e-01, 1.59344951e-01, 9.99274296e-23]])
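To connect these probabilities back to predict, the sketch below copies the matrix printed above into a NumPy array (variable names are illustrative) and takes the column with the largest value in each row. With class labels 0, 1, 2 the column index is the predicted class.

```python
import numpy as np

# Probabilities copied from the predict_proba output above.
proba = np.array([
    [7.84463971e-06, 9.99954368e-01, 3.77870449e-05],
    [7.27268437e-09, 1.47166349e-01, 8.52833644e-01],
    [2.71292961e-08, 9.99764511e-01, 2.35461414e-04],
    [9.30834892e-01, 6.91651082e-02, 6.24975648e-25],
    [9.67856008e-01, 3.21439919e-02, 2.01259183e-22],
    [7.62068769e-01, 2.37931231e-01, 2.12976894e-23],
    [3.61723169e-12, 3.29818401e-01, 6.70181598e-01],
    [3.46987440e-06, 9.99996344e-01, 1.85665512e-07],
    [1.40350469e-13, 7.03733483e-02, 9.29626652e-01],
    [8.40655049e-01, 1.59344951e-01, 9.99274296e-23],
])

predicted = proba.argmax(axis=1)  # index of the most probable class per row
print(predicted)                  # [1 2 1 0 0 0 2 1 2 0]
```

The result matches the predict output. For the misclassified second sample, the model gives class 2 a probability of about 0.85 and the true class 1 only about 0.15.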

Reference
Logistic Regression 3-class Classifier
http://scikit-learn.org/stable/auto_examples/linear_model/plot_iris_logistic.html