Introduction to train_test_split

train_test_split is one of the most frequently used data-splitting functions in scikit-learn. It randomly splits a dataset into a training set and a test set.

Function signature

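from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y,                     # features and labels (any number of equal-length arrays)
    test_size=None,           # test fraction or count (becomes 0.25 if both sizes are None)
    train_size=None,          # training fraction or count (complement of test_size if None)
    random_state=None,        # random seed
    shuffle=True,             # shuffle before splitting (default True)
    stratify=None             # labels to stratify by, e.g. stratify=y
)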
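The following sketch shows how the size defaults resolve. When neither size is given, test_size falls back to 0.25. A fractional test size is rounded up, and the training set receives the remainder. The 10-sample array is made up for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data (made up): 10 samples with 2 features each, alternating labels
X = np.arange(20).reshape(10, 2)
y = np.arange(10) % 2

# Neither size given: test_size falls back to 0.25 -> ceil(0.25 * 10) = 3
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)
print(len(X_tr), len(X_te))  # 7 3

# Only train_size given: the test set receives the remaining samples
X_tr, X_te, y_tr, y_te = train_test_split(X, y, train_size=0.6, random_state=0)
print(len(X_tr), len(X_te))  # 6 4
```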

Parameter details

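| Parameter | Description | Example |
|---|---|---|
| X | Feature data | [[1, 2], [3, 4], ...] |
| y | Label data | [0, 1, 0, 1, ...] |
| test_size | Test-set size: a float in (0, 1) is a fraction, an int is a sample count | 0.2 (20%) or 10 (10 samples) |
| train_size | Training-set size, same convention as test_size | 0.8 (80%) or 50 (50 samples) |
| random_state | Random seed that makes the split reproducible | An integer such as 42 |
| shuffle | Whether to shuffle the data before splitting | True (default) or False |
| stratify | Stratified sampling: keeps class proportions the same in both sets | y (stratify by the labels) |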
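The float-versus-int convention for test_size can be checked directly. On 50 samples, a fraction of 0.2 and an absolute count of 10 should produce the same test-set size. The zero-filled feature array is a placeholder.

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.zeros((50, 3))  # 50 placeholder samples with 3 features
y = np.arange(50) % 2

# float -> fraction of the samples; int -> absolute number of samples
_, X_te_frac, _, _ = train_test_split(X, y, test_size=0.2, random_state=0)
_, X_te_abs, _, _ = train_test_split(X, y, test_size=10, random_state=0)
print(len(X_te_frac), len(X_te_abs))  # 10 10
```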

Basic usage examples

1. The simplest split

from sklearn.model_selection import train_test_split
import numpy as np

# Example data: 5 samples, 2 features each
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = np.array([0, 1, 0, 1, 0])

# Split (test set defaults to 25%, rounded up: ceil(0.25 * 5) = 2 samples)
X_train, X_test, y_train, y_test = train_test_split(X, y)

print(f"X_train: {X_train.shape}")  # (3, 2): 3 samples
print(f"X_test: {X_test.shape}")    # (2, 2): 2 samples
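Shuffling reorders the rows but keeps each row of X paired with its label in y. The toy data below is constructed (an assumption for illustration) so that each label equals its row's original index, which makes the pairing easy to verify.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Row i of X is [2i, 2i+1] and its label is i, so pairs are easy to check
X = np.arange(10).reshape(5, 2)
y = np.arange(5)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# First column // 2 recovers the original row index, which must still match the label
print(np.array_equal(X_train[:, 0] // 2, y_train))  # True
print(np.array_equal(X_test[:, 0] // 2, y_test))    # True
```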

2. Specifying the test-set fraction

# Test set is 30% of the data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3
)

# Test set 20%, training set 80%
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, train_size=0.8
)
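When both sizes are given as fractions, they do not have to add up to 1. If the sum is below 1, the leftover samples are simply left out of both sets. If the sum exceeds 1, scikit-learn raises a ValueError. A sketch on 20 made-up samples:

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(40).reshape(20, 2)  # 20 toy samples
y = np.arange(20) % 2

# 0.2 + 0.5 < 1: 4 test + 10 training samples; the other 6 are left out
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.2, train_size=0.5, random_state=0
)
print(len(X_tr), len(X_te))  # 10 4

# 0.6 + 0.6 > 1: rejected with a ValueError
try:
    train_test_split(X, y, test_size=0.6, train_size=0.6)
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```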

3. Fixing the random seed (reproducible results)

# Same data + same seed: the same split on every run
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)
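Reproducibility is easy to confirm. Two calls with the same data and the same integer seed should return element-for-element identical arrays. The data here is a made-up toy array.

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(20).reshape(10, 2)  # toy data
y = np.arange(10) % 2

# Same seed, same data -> identical partitions
first = train_test_split(X, y, test_size=0.3, random_state=42)
second = train_test_split(X, y, test_size=0.3, random_state=42)
same = all(np.array_equal(a, b) for a, b in zip(first, second))
print(same)  # True
```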

4. Stratified sampling (imbalanced datasets)

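# With imbalanced classes, stratify keeps the class proportions in the training and
# test sets as close to the full dataset as the set sizes allow
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)

# Check the effect of stratification
print("Class proportions in the full data:")
print(np.bincount(y) / len(y))

print("Class proportions in the training set:")
print(np.bincount(y_train) / len(y_train))

print("Class proportions in the test set:")
print(np.bincount(y_test) / len(y_test))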
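Stratification has a prerequisite: every class needs at least two samples, so that each class can appear in both sets. A class with a single sample makes the call fail with a ValueError. The labels below are made up to trigger that case.

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(10).reshape(5, 2)
y = np.array([0, 0, 0, 0, 1])  # class 1 has only one sample

try:
    train_test_split(X, y, test_size=0.4, stratify=y, random_state=0)
    failed = False
except ValueError as err:
    failed = True
    print("ValueError:", err)
```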

Complete worked example (Iris dataset)

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# 1. Load the data
iris = load_iris()
X, y = iris.data, iris.target

print(f"Full data shape: {X.shape}")  # (150, 4)
print(f"Samples per class: {np.bincount(y)}")  # [50 50 50]

# 2. Split the dataset (stratified)
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.3,      # 30% for testing
    random_state=42,    # fixed random seed
    stratify=y          # stratified sampling keeps the class proportions
)

print(f"\nTraining set shape: {X_train.shape}")  # (105, 4)
print(f"Test set shape: {X_test.shape}")          # (45, 4)

print(f"\nTraining samples per class: {np.bincount(y_train)}")  # [35 35 35]
print(f"Test samples per class: {np.bincount(y_test)}")          # [15 15 15]

# 3. Preprocessing (fit on the training set only)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# 4. Train the model
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train_scaled, y_train)

# 5. Predict and evaluate
y_pred = knn.predict(X_test_scaled)
accuracy = accuracy_score(y_test, y_pred)
print(f"\nModel accuracy: {accuracy:.4f}")  # typically high on Iris; exact value depends on the split
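The original example calls the scaler and the classifier by hand. An alternative, not used in the original, is scikit-learn's `make_pipeline`, which bundles both steps. The pipeline then fits the scaler only on the data passed to `fit()`, so test rows cannot leak into preprocessing. This is a sketch of that variant, with the same split settings as above.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# fit() scales using statistics of X_train only; score() reuses them on X_test
model = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=3))
model.fit(X_train, y_train)
accuracy = model.score(X_test, y_test)  # mean accuracy on the test set
print(f"Test accuracy: {accuracy:.4f}")
```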

Frequently asked questions

1. How should I choose test_size and train_size?

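# A common ratio when there is plenty of data
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2  # 80% training, 20% test
)

# With less data, a smaller test fraction leaves more samples for training
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1  # 90% training, 10% test
)

# Or give an absolute number (the dataset must have more than 50 samples)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=50  # test set of exactly 50 samples
)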
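An integer test_size must be smaller than the number of samples. On the 5-sample toy data from the basic examples, test_size=50 is rejected with a ValueError.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# The 5-sample toy data from the basic examples
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
y = np.array([0, 1, 0, 1, 0])

try:
    train_test_split(X, y, test_size=50)
    too_large = False
except ValueError as err:
    too_large = True
    print("ValueError:", err)
```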

2. When do I need stratify?

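import numpy as np
from sklearn.model_selection import train_test_split

# Imbalanced data (binary classification: 90% class 0, 10% class 1)
y = np.array([0] * 900 + [1] * 100)  # 900 zeros, 100 ones
X = np.arange(len(y)).reshape(-1, 1)  # placeholder features, one row per label

# ❌ Without stratify: the share of 1s in the test set drifts away from 10%,
#    and with very few minority samples a class can be missing entirely
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# ✅ With stratify: the share of 1s is 10% in both the training and the test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)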
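To see the difference, the sketch below repeats both splits over a few seeds and records the share of class 1 in each test set. Only the stratified split is pinned at 0.10. The placeholder features are an assumption, as above.

```python
import numpy as np
from sklearn.model_selection import train_test_split

y = np.array([0] * 900 + [1] * 100)
X = np.arange(len(y)).reshape(-1, 1)  # placeholder features

plain, strat = [], []
for seed in range(5):
    _, _, _, y_te_plain = train_test_split(X, y, test_size=0.2, random_state=seed)
    _, _, _, y_te_strat = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=seed
    )
    # Mean of 0/1 labels = fraction of class 1 in the test set
    plain.append(y_te_plain.mean())
    strat.append(y_te_strat.mean())

print("unstratified:", np.round(plain, 3))
print("stratified:  ", np.round(strat, 3))
```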

3. When should I use shuffle=False?

# Time-series data is usually not shuffled: the order is kept, so the last 20% becomes the test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, shuffle=False
)
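With shuffle=False, the first rows go to the training set and the last rows to the test set, in their original order. Stratification is not supported in this mode: passing stratify together with shuffle=False raises a ValueError. The ten "time-ordered" samples below are made up.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Ten samples in time order (index 0 = oldest)
X = np.arange(10).reshape(-1, 1)
y = np.arange(10)

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, shuffle=False)
print(y_tr)  # [0 1 2 3 4 5 6 7]
print(y_te)  # [8 9]

# stratify cannot be combined with shuffle=False
try:
    train_test_split(X, y % 2, test_size=0.2, shuffle=False, stratify=y % 2)
    strat_rejected = False
except ValueError:
    strat_rejected = True
print(strat_rejected)  # True
```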

Common mistakes

❌ Mistake 1: standardizing before splitting

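# Wrong: the scaler's mean and standard deviation are computed on all rows, including
# the future test set, so test information leaks into training (data leakage)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)  # parameters estimated from the full dataset
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y)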

✅ Correct: split first, then standardize

# Correct
X_train, X_test, y_train, y_test = train_test_split(X, y)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # fit on the training set only
X_test_scaled = scaler.transform(X_test)        # reuse the training-set parameters
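The leak is visible in the fitted statistics. A scaler fitted on all rows learns a mean that includes the test rows, while a scaler fitted on the training split learns exactly the training-set column means. The sketch below uses synthetic normally distributed features, made up for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=2.0, size=(100, 3))  # synthetic features
y = rng.integers(0, 2, size=100)                   # synthetic labels

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

leaky = StandardScaler().fit(X)        # statistics include the test rows
clean = StandardScaler().fit(X_train)  # statistics from training rows only

print("mean fitted on all data:  ", leaky.mean_.round(3))
print("mean fitted on train only:", clean.mean_.round(3))
print(np.allclose(clean.mean_, X_train.mean(axis=0)))  # True
```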

❌ Mistake 2: forgetting to fix random_state

# The split changes on every run, so results cannot be reproduced
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

✅ Correct: fix random_state

# Reproducible results
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

Return values

# The return order is fixed: a (train, test) pair for each input, in input order
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# With several feature arrays
X1_train, X1_test, X2_train, X2_test, y_train, y_test = train_test_split(
    X1, X2, y, test_size=0.2
)
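Passing n inputs returns 2n outputs, and plain Python lists come back as lists. The sketch below uses shuffle=False so that the output is deterministic. The three parallel lists are made up for illustration.

```python
from sklearn.model_selection import train_test_split

# Three parallel inputs of length 5 (made-up values)
names = ["a", "b", "c", "d", "e"]
scores = [10, 20, 30, 40, 50]
labels = [0, 1, 0, 1, 0]

# 3 inputs -> 6 outputs: (train, test) for each input, in input order
parts = train_test_split(names, scores, labels, test_size=0.4, shuffle=False)
print(len(parts))  # 6
names_tr, names_te, scores_tr, scores_te, labels_tr, labels_te = parts
print(names_tr, names_te)    # ['a', 'b', 'c'] ['d', 'e']
print(scores_tr, scores_te)  # [10, 20, 30] [40, 50]
```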

Summary

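| Point | Notes |
|---|---|
| Purpose | Split data into a training set and a test set |
| Typical ratio | 70-80% training, 20-30% test |
| Key parameters | test_size, random_state, stratify |
| Caveats | Split before preprocessing; fix the random seed for reproducibility |
| Stratified sampling | Use stratify=y for classification, especially when classes are imbalanced |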