from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
# 数据
X, y = load_iris(return_X_y=True)
# 切训练 / 测试
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# Pipeline
pipe = Pipeline([
('scaler', StandardScaler()),
('model', LogisticRegression(max_iter=200))
])
# 训练
pipe.fit(X_train, y_train) # X_train → scaler.fit_transform → model.fit
# 预测
print(pipe.predict(X_test)) # X_test → scaler.transform → model.predict
print(pipe.score(X_test, y_test))
from sklearn.feature_selection import SelectKBest, f_classif
pipe = Pipeline([
('scaler', StandardScaler()),
('select', SelectKBest(f_classif, k=2)),
('model', LogisticRegression())
])
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
preprocess = ColumnTransformer([
('num', StandardScaler(), ['age', 'income']),
('cat', OneHotEncoder(), ['gender'])
])
pipe = Pipeline([
('preprocess', preprocess),
('model', LogisticRegression())
])
from sklearn.pipeline import make_pipeline
# 标准语法
pipe_long = Pipeline([("scaler", MinMaxScaler()), ("svm", SVC(C=100))])
# 缩写语法
pipe_short = make_pipeline(MinMaxScaler(), SVC(C=100))
pipe = make_pipeline(StandardScaler(), LogisticRegression())
param_grid = {'logisticregression__C': [0.01, 0.1, 1, 10, 100]}
X_train, X_test, y_train, y_test = train_test_split(
cancer.data, cancer.target, random_state=4)
grid = GridSearchCV(pipe, param_grid, cv=5)
grid.fit(X_train, y_train)
print("Best estimator:\n{}".format(grid.best_estimator_))
Best estimator:
Pipeline(steps=[
('standardscaler', StandardScaler(copy=True, with_mean=True, with_std=True)),
('logisticregression', LogisticRegression(C=0.1, class_weight=None,
dual=False, fit_intercept=True, intercept_scaling=1, max_iter=100,
multi_class='ovr', n_jobs=1, penalty='l2', random_state=None,
solver='liblinear', tol=0.0001, verbose=0, warm_start=False))])