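train_test_split is one of the most frequently used data-splitting functions in scikit-learn (module sklearn.model_selection). It randomly splits one or more arrays into a training set and a test set.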
Function signature
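from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X, y,              # any number of arrays with the same number of rows (*arrays)
    test_size=None,    # if both test_size and train_size are None, 0.25 is used
    train_size=None,   # None means "the complement of test_size"
    random_state=None,
    shuffle=True,
    stratify=None
)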
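The defaults above follow a simple arithmetic rule. The sketch below is a simplified re-implementation written only for illustration. The helper `split_sizes` is hypothetical and is not part of scikit-learn. It assumes the rounding used by current scikit-learn versions: a float test fraction is rounded up, a float train fraction is rounded down, integers are absolute counts, and an omitted size becomes the complement of the other.

```python
import math

def split_sizes(n_samples, test_size=None, train_size=None):
    """Hypothetical helper: simplified sketch of how train/test sizes are derived."""
    if test_size is None and train_size is None:
        test_size = 0.25  # default test fraction when neither size is given
    n_test = n_train = None
    # Floats are fractions of n_samples, ints are absolute counts
    if test_size is not None:
        n_test = math.ceil(test_size * n_samples) if isinstance(test_size, float) else test_size
    if train_size is not None:
        n_train = math.floor(train_size * n_samples) if isinstance(train_size, float) else train_size
    # Whichever size was omitted becomes the complement of the other
    if n_train is None:
        n_train = n_samples - n_test
    elif n_test is None:
        n_test = n_samples - n_train
    if n_train <= 0 or n_test <= 0 or n_train + n_test > n_samples:
        raise ValueError("invalid combination of test_size / train_size")
    return n_train, n_test

print(split_sizes(5))                                  # (3, 2)
print(split_sizes(150, test_size=0.3))                 # (105, 45)
print(split_sizes(100, test_size=10))                  # (90, 10)
print(split_sizes(10, test_size=0.2, train_size=0.5))  # (5, 2): 3 samples unused
```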
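Parameter details
| Parameter | Description | Example |
|---|---|---|
| X | Feature data | [[1,2], [3,4], ...] |
| y | Label data | [0,1,0,1,...] |
| test_size | Test-set size: a float in (0.0, 1.0) is a fraction, an int is an absolute number of samples | 0.2 (20%) or 10 (10 samples) |
| train_size | Training-set size, same convention; if None, the complement of test_size | 0.8 (80%) or 50 (50 samples) |
| random_state | Random seed for the shuffling; fixing it makes the split reproducible | 42, or any other integer |
| shuffle | Whether to shuffle the data before splitting | True (default) or False |
| stratify | Stratified sampling: each subset keeps the class proportions of this array; requires shuffle=True | y (stratify by label) |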
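Two of the parameters interact: stratify only works together with shuffle=True. The check below assumes a made-up 10-sample array with two balanced classes. Combining stratify with shuffle=False raises a ValueError. The same call with the default shuffle=True succeeds.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data: 10 samples, classes 0 and 1 alternating
X = np.arange(20).reshape(10, 2)
y = np.array([0, 1] * 5)

# Stratification needs shuffling; shuffle=False plus stratify is rejected
try:
    train_test_split(X, y, test_size=0.2, shuffle=False, stratify=y)
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", err)

# With the default shuffle=True the stratified split works
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)
print(np.bincount(y_test))  # [1 1]: one test sample per class
```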
Basic usage examples
1. The simplest split
from sklearn.model_selection import train_test_split
import numpy as np
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = np.array([0, 1, 0, 1, 0])
# With no size arguments, 25% of the samples (rounded up) go to the test set
X_train, X_test, y_train, y_test = train_test_split(X, y)
print(f"X_train: {X_train.shape}")  # X_train: (3, 2)
print(f"X_test: {X_test.shape}")    # X_test: (2, 2)
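The split shuffles the rows, but it applies the same permutation to X and y, so each feature row keeps its label. The check below uses toy data in which each label is, by construction, ten times the first feature. Any misaligned row would therefore be easy to spot.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data: label = 10 * first feature, so alignment is easy to verify
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = X[:, 0] * 10

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

print(X_train.shape, X_test.shape)                   # (3, 2) (2, 2)
# Rows were shuffled, but X and y were shuffled together
print(np.array_equal(y_train, X_train[:, 0] * 10))  # True
print(np.array_equal(y_test, X_test[:, 0] * 10))    # True
```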
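2. Specifying the test-set size
# 30% of the samples for testing, the remaining 70% for training
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3
)
# Both sizes can be given explicitly; as fractions they must not sum to more than 1.0
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, train_size=0.8
)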
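When both sizes are given and they add up to less than the whole dataset, the leftover samples end up in neither set. Fractions that add up to more than 1.0 are rejected. The sketch below assumes toy data in which the labels double as sample IDs, which makes the unused samples easy to count.

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(10).reshape(-1, 1)  # toy data: 10 samples
y = np.arange(10)                 # labels double as sample IDs here

# 50% train + 20% test = 70%: the remaining 3 samples go to neither set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=0.5, test_size=0.2, random_state=0
)
unused = set(range(10)) - set(y_train.tolist()) - set(y_test.tolist())
print(len(y_train), len(y_test), len(unused))  # 5 2 3

# Fractions that add up to more than 1.0 raise a ValueError
try:
    train_test_split(X, y, train_size=0.8, test_size=0.3)
    too_large_rejected = False
except ValueError:
    too_large_rejected = True
print(too_large_rejected)  # True
```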
3. Fixing the random seed (reproducible results)
# The same integer random_state always yields the same split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)
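A quick way to confirm reproducibility is to call the function twice with the same integer seed and compare all four returned arrays. The arrays below are made-up toy data. With random_state=None, the shuffle instead draws from NumPy's global random state, so repeated calls are not guaranteed to match.

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(20).reshape(10, 2)  # toy data
y = np.arange(10)

# Two calls with the same integer seed produce exactly the same four arrays
first = train_test_split(X, y, test_size=0.3, random_state=42)
second = train_test_split(X, y, test_size=0.3, random_state=42)
same = all(np.array_equal(a, b) for a, b in zip(first, second))
print(same)  # True
```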
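4. Stratified sampling (for imbalanced datasets)
# stratify=y keeps the class proportions of y in both subsets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)
print("Class proportions in the original data:")
print(np.bincount(y) / len(y))
print("Class proportions in the training set:")
print(np.bincount(y_train) / len(y_train))
print("Class proportions in the test set:")
print(np.bincount(y_test) / len(y_test))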
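Conceptually, stratification means splitting each class separately, using the same fraction for every class. The sketch below illustrates that idea with plain NumPy. The helper `stratified_split_indices` is hypothetical and is a simplified stand-in, not scikit-learn's actual algorithm. The class counts are made up so that the per-class arithmetic comes out exact.

```python
import numpy as np

def stratified_split_indices(y, test_size, seed=0):
    """Hypothetical helper: per-class random split (simplified, not sklearn's code)."""
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for cls in np.unique(y):
        idx = np.flatnonzero(y == cls)              # positions of this class
        rng.shuffle(idx)                            # shuffle within the class
        n_test = int(round(test_size * len(idx)))   # same fraction from every class
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return np.array(sorted(train_idx)), np.array(sorted(test_idx))

# Made-up imbalanced labels: 60% / 30% / 10%
y = np.array([0] * 60 + [1] * 30 + [2] * 10)
train_idx, test_idx = stratified_split_indices(y, test_size=0.3)
print(np.bincount(y[train_idx]))  # [42 21  7]
print(np.bincount(y[test_idx]))   # [18  9  3]
```

Both subsets keep the 60/30/10 ratio. The real implementation also has to handle fractions that do not divide evenly, which this sketch simply rounds.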
Complete example (Iris dataset)
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
iris = load_iris()
X, y = iris.data, iris.target
print(f"Original data shape: {X.shape}")   # (150, 4)
print(f"Samples per class: {np.bincount(y)}")   # [50 50 50]
# 70/30 split, stratified so that every class keeps its share
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.3,
    random_state=42,
    stratify=y
)
print(f"\nTraining set shape: {X_train.shape}")   # (105, 4)
print(f"Test set shape: {X_test.shape}")         # (45, 4)
print(f"\nSamples per class (training): {np.bincount(y_train)}")   # [35 35 35]
print(f"Samples per class (test): {np.bincount(y_test)}")         # [15 15 15]
# Fit the scaler on the training set only, then apply it to both sets
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# k-nearest-neighbors classifier with k = 3
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train_scaled, y_train)
y_pred = knn.predict(X_test_scaled)
accuracy = accuracy_score(y_test, y_pred)
print(f"\nModel accuracy: {accuracy:.4f}")
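Frequently asked questions
1. How should test_size and train_size be chosen?
# Common default: 80% training, 20% test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2
)
# Large datasets: a 10% test set is often already large enough
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1
)
# An integer is an absolute count: exactly 50 test samples
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=50
)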
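The practical difference between a float and an int test_size shows up as the dataset grows. A fraction scales with the number of samples, while an integer stays fixed. The sketch below assumes placeholder all-zero arrays of 100 and 1000 rows, since only the row count matters here.

```python
import numpy as np
from sklearn.model_selection import train_test_split

results = {}
for n in (100, 1000):
    X = np.zeros((n, 1))  # placeholder features; only the row count matters
    y = np.zeros(n)
    _, X_test_frac, _, _ = train_test_split(X, y, test_size=0.2, random_state=0)
    _, X_test_abs, _, _ = train_test_split(X, y, test_size=50, random_state=0)
    # A float fraction scales with n, an int count does not
    results[n] = (len(X_test_frac), len(X_test_abs))
print(results)  # {100: (20, 50), 1000: (200, 50)}
```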
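2. When is stratify needed?
# Imbalanced labels: 900 samples of class 0, 100 of class 1
y = np.array([0] * 900 + [1] * 100)
X = np.arange(len(y)).reshape(-1, 1)  # placeholder features, one row per label
# Without stratify: the share of class 1 in each subset can drift away from 10%
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
# With stratify=y: both subsets keep the 90% / 10% class ratio
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)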
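The effect of stratify can be measured directly by counting the minority class in each subset. This uses the same made-up 900/100 labels. The unstratified share depends on the random draw, so no particular value is claimed for it. With stratify=y, the 200-sample test set contains exactly 10% class-1 samples.

```python
import numpy as np
from sklearn.model_selection import train_test_split

y = np.array([0] * 900 + [1] * 100)   # made-up imbalanced labels
X = np.arange(len(y)).reshape(-1, 1)  # placeholder features

_, _, _, y_test_plain = train_test_split(X, y, test_size=0.2, random_state=42)
_, _, y_train_strat, y_test_strat = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)
print("class-1 share, no stratify:", y_test_plain.mean())  # close to 0.1, varies with the draw
print("class-1 share, stratified:", y_test_strat.mean())   # 0.1
print(y_train_strat.sum(), y_test_strat.sum())              # 80 20
```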
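3. When should shuffle=False be used?
# When row order matters, e.g. time series: no shuffling, the last 20% of rows become the test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=False
)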
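With shuffle=False, the split is a plain cut: the first rows go to training and the last rows go to testing. The toy series below assumes that the row index is the time step. The test set then consists only of the most recent observations, which avoids training on the future.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy time series: 10 observations, row index = time step
X = np.arange(10).reshape(-1, 1)
y = np.arange(10)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)
print(y_train)  # [0 1 2 3 4 5 6 7]
print(y_test)   # [8 9]
```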
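Common mistakes
❌ Mistake 1: standardizing before splitting
# The scaler's mean and standard deviation are computed on ALL rows, test set included,
# so information about the test set leaks into training
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y)
✅ Correct: split first, then standardize
X_train, X_test, y_train, y_test = train_test_split(X, y)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # learn statistics from the training set only
X_test_scaled = scaler.transform(X_test)        # reuse those statistics on the test set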
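A common way to make "split first, then preprocess" hard to get wrong is to bundle the scaler and the model in a scikit-learn Pipeline. This is an alternative that the text above does not cover. Calling fit on the pipeline fits the scaler on the training data only, and score reuses those statistics on the test data. The sketch assumes the Iris data and the same k-NN settings as the full example.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# fit() fits the scaler on X_train only; score() transforms X_test with those statistics
pipe = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=3))
pipe.fit(X_train, y_train)
accuracy = pipe.score(X_test, y_test)

scaler = pipe.named_steps["standardscaler"]
print(np.allclose(scaler.mean_, X_train.mean(axis=0)))  # True: statistics come from the training set
print(f"accuracy: {accuracy:.4f}")
```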
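❌ Mistake 2: not fixing random_state
# A different random split on every run, so results cannot be reproduced or compared
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
✅ Correct: fix random_state
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)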
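Return values
The outputs come back in the same order as the inputs. Each input array contributes a train part followed by a test part.
# 2 inputs -> 4 outputs
X_train, X_test, y_train, y_test = train_test_split(X, y, ...)
# 3 inputs -> 6 outputs (X1, X2: any arrays with the same number of rows as y)
X1_train, X1_test, X2_train, X2_test, y_train, y_test = train_test_split(
    X1, X2, y, test_size=0.2
)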
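With several inputs, the same row permutation is applied to every array, so the rows stay aligned across all outputs. The inputs below are hypothetical: a numeric feature array, an array of string IDs, and labels derived from the numeric column, so that alignment can be checked.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical inputs sharing one row order
X1 = np.arange(10).reshape(-1, 1)               # e.g. numeric features
X2 = np.array([f"row{i}" for i in range(10)])   # e.g. sample IDs
y = np.arange(10) % 2                           # labels derived from X1

parts = train_test_split(X1, X2, y, test_size=0.2, random_state=0)
print(len(parts))  # 6
X1_train, X1_test, X2_train, X2_test, y_train, y_test = parts

# Every input was permuted identically, so rows still match
print(all(s == f"row{v}" for s, v in zip(X2_test, X1_test.ravel())))  # True
print(np.array_equal(y_test, X1_test.ravel() % 2))                    # True
```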
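Summary
| Point | Details |
|---|---|
| Purpose | Split data into a training set and a test set |
| Typical ratio | 70-80% training, 20-30% test |
| Key parameters | test_size, random_state, stratify |
| Caveats | Split first, then preprocess; fix the random seed for reproducibility |
| Stratified sampling | Use stratify=y for classification, especially when classes are imbalanced |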