Scikit-learn 框架入门
本节定位
Scikit-learn 是 Python 机器学习的事实标准库。几乎所有经典 ML 任务都可以用它来完成。掌握 sklearn 的 API 模式,后续学任何算法都会非常顺畅。
学习目标
- 理解 Scikit-learn 的设计哲学与统一 API
- 掌握 Estimator、Transformer、Pipeline 三大核心概念
- 学会加载和生成数据集
- 完成从训练到预测的完整流程
- 学会保存和加载模型
一、为什么是 Scikit-learn?
1.1 sklearn 在 ML 生态中的位置
| 特点 | 说明 |
|---|---|
| 统一的 API | 所有算法用同样的 fit / predict / transform |
| 丰富的算法 | 分类、回归、聚类、降维、预处理一应俱全 |
| 优秀的文档 | 每个算法都有详细文档和示例 |
| 活跃的社区 | 全球最流行的 ML 库之一 |
| 生产就绪 | 可直接用于真实项目 |
1.2 安装
pip install scikit-learn
import sklearn
print(sklearn.__version__)
二、Scikit-learn 的设计哲学
2.1 统一 API——一招鲜吃遍天
Scikit-learn 最厉害的地方是:所有算法都遵循同一套 API 模式。不管是线性回归、决策树还是 SVM,使用方法都一样。
# 无论什么算法,代码结构都一样:
from sklearn.xxx import SomeModel
model = SomeModel(超参数) # 创建模型
model.fit(X_train, y_train) # 训练
y_pred = model.predict(X_test) # 预测
score = model.score(X_test, y_test) # 评估
看几个具体例子——注意代码结构的一致性:
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
# 方法完全一样!只是换个模型名
models = {
"决策树": DecisionTreeClassifier(),
"逻辑回归": LogisticRegression(),
"SVM": SVC(),
"KNN": KNeighborsClassifier(),
}
for name, model in models.items():
model.fit(X_train, y_train)
score = model.score(X_test, y_test)
print(f"{name}: {score:.1%}")
统一 API 的好处
你只需要学一次 fit / predict / score,就能使用 sklearn 中的所有算法。换模型就像换零件一样简单。
2.2 三大核心角色
| 角色 | 核心方法 | 做什么 | 例子 |
|---|---|---|---|
| Estimator | fit(), predict() | 从数据中学习,然后做预测 | 决策树、线性回归、SVM |
| Transformer | fit(), transform() | 从数据中学参数,然后变换数据 | 标准化、PCA、独热编码 |
| Pipeline | 串联以上两者 | 把多个步骤串成流水线 | 标准化 → PCA → 分类器 |