使用 SMOTE 算法和 Near Miss 算法在 Python 中处理不平衡数据

2024年8月29日 | 阅读 10 分钟

在数据科学和机器学习中,我们经常会遇到一个术语,叫做不平衡数据分布,总的来说,当一个类别的观测值比其他类别高得多或低得多时,就会发生这种情况。机器学习算法通常会通过减少错误来提高准确性,因此它们不考虑类分布。这个问题在欺诈检测、异常检测、面部识别等模型中很普遍。

标准的机器学习技术,例如决策树和逻辑回归,往往倾向于多数类,并且它们倾向于忽略少数类。它们倾向于预测多数类,因此,与多数类相比,少数类会出现严重的误分类。用更专业的术语来说,如果我们的数据集中存在不平衡的数据分布,我们的模型就会更容易出现少数类观测值相关性不大或非常少的情况。

不平衡数据处理技术:主要有两种算法被广泛用于处理不平衡的类别分布。

  1. SMOTE
  2. Near Miss 算法

SMOTE(合成少数过采样技术)- 过采样

SMOTE(合成少数过采样技术)是解决不平衡问题最常用的过采样技术之一。

它旨在通过随机复制少数类样本来平衡类别分布。

SMOTE 在现有少数类样本之间引入了新的少数类样本。它通过对少数类中的每个样本进行线性插值来生成合成训练记录。这些合成训练记录是通过从每个少数类样本的 k 个最近邻中随机选择一个或多个来生成的。过采样完成后,数据将被重构,并且可以对处理后的数据应用一些分类模型。

SMOTE 算法工作流程

阶段 1: 确定少数类。设为 A。对于 A 中的每个样本 x,通过计算 x 与 A 中每个样本之间的欧氏距离来获取其 k 个最近邻。

阶段 2: 测试率 N 由不平衡比例确定。对于每个样本,从其 k 个最近邻中随机选择 N 个样本(x1, x2, … xn),并将它们组成一个集合。

阶段 3: 对于每个样本(k= 1, 2, 3 .......N),使用以下公式生成一个新样本:rand(0, 1) 表示介于 0 和 1 之间的随机数。

Near Miss 算法

Near Miss 是一种欠采样技术。它旨在通过随机删除多数类样本来平衡类别分布。当两个不同类别的样本非常接近时,我们会删除多数类的样本,以增加这两个类别之间的间距。这有助于分类处理。

近邻技术通常用于防止大多数欠采样技术中的信息丢失问题。

近邻技术工作原理的基本思路如下:

阶段 1: 该技术首先计算所有多数类实例与少数类实例之间的距离。在这里,多数类是要进行欠采样的。

阶段 2: 然后,选择与少数类中的样本距离最小的“n”个多数类样本。

阶段 3: 如果少数类中有 k 个样本,则近邻技术将产生 k*n 个多数类样本。

有几种应用 NearMiss 算法的方法可以找到多数类中 n 个最近的样本。

  1. NearMiss - 版本 1:它选择多数类样本,这些样本到少数类 k 个最近邻样本的平均距离最小。
  2. NearMiss - 版本 2:它选择多数类样本,这些样本到少数类 k 个最远样本的平均距离最小。
  3. NearMiss - 版本 3:它分两个阶段进行。首先,对于每个少数类样本,保存其 M 个最近邻。然后,选择多数类样本,这些样本到 N 个最近邻样本的平均距离最大。

步骤 1:加载数据文件和库

说明:该数据集包含信用卡交易。该数据集共有 884,808 笔交易,其中 491 笔是欺诈交易。这使得它非常不平衡;正类(欺诈)占所有交易的 0.172%。

时间V1V2V2V4V5V6V2V8总额Class
0-1.25981-0.022282.5262421.228155-0.228220.4622880.2295990.098698149.620
01.1918520.2661510.166480.4481540.060018-0.08226-0.02880.0851022.690
1-1.25825-1.240161.2222090.22928-0.50221.8004990.2914610.242626228.660
1-0.96622-0.185221.292992-0.86229-0.010211.2422020.2226090.222426122.50
2-1.158220.8222221.5482180.402024-0.402190.0959210.592941-0.2205269.990
2-0.425920.9605221.141109-0.168250.420982-0.029220.4262010.2602142.620
41.2296580.1410040.0452211.2026120.1918810.222208-0.005160.0812124.990
2-0.644221.4129641.02428-0.49220.9489240.4281181.120621-2.8028640.80
2-0.894290.286152-0.11219-0.221522.6695992.2218180.2201450.85108492.20
9-0.228261.1195921.044262-0.222190.499261-0.246260.6515820.0695292.680
101.449044-1.126240.91286-1.22562-1.92128-0.62915-1.422240.0484562.80
100.2849280.616109-0.8242-0.094022.9245842.2120220.4204550.5282429.990
101.249999-1.221640.28292-1.2249-1.48542-0.25222-0.6894-0.22249121.50
111.0692240.2822220.8286122.21252-0.12840.222544-0.096220.11598222.50
12-2.29185-0.222221.641251.262422-0.126590.802596-0.42291-1.9021158.80
12-0.252420.2454852.052222-1.46864-1.15829-0.02285-0.608580.00260215.990
121.102215-0.04021.2622221.289091-0.2260.288069-0.586060.1892812.990
12-0.426910.9189660.924591-0.222220.915629-0.122820.2026420.0829620.890
14-5.40126-5.450151.1862051.2262292.049106-1.26241-1.559240.16084246.80
151.492926-1.029250.454295-1.42802-1.55542-0.22096-1.08066-0.0521250
160.694885-1.261821.0292210.824159-1.191211.209109-0.828590.44529221.210
120.9624960.228461-0.121482.1092041.1295661.6960280.1022120.52150224.090
181.1666160.50212-0.06222.2615690.4288040.0894240.2411420.1280822.280
180.2424910.2226661.185421-0.0926-1.21429-0.15012-0.94626-1.6129422.250
22-1.94652-0.0449-0.40552-1.012062.9419682.955052-0.062060.8555460.890
22-2.02429-0.121481.2220210.4100080.295198-0.959540.542985-0.1046226.420
221.1222850.2524980.2829051.122562-0.12258-0.916050.269025-0.2222641.880
221.222202-0.124040.4245550.526028-0.82626-0.82108-0.2649-0.22098160
22-0.414290.9054221.2224521.4224210.002442-0.200220.240228-0.02925220
221.059282-0.125221.266121.18611-0.2860.528425-0.262080.40104612.990
241.2224290.0610420.2805260.261564-0.25922-0.494080.006494-0.1228612.280

源代码

输出

Range Index: 24 entries, 0 to 24
Data columns (total 11 columns) :
Time      24 non null float 64
V1        24 non null float 64
V2        24 non null float 64
V3        24 non null float 64
V4        24 non null float 64
V5        24 non null float 64
V6        24 non null float 64
V7        24 non null float 64
V8        24 non null float 64
V9        24 non null float 64
V10       24 non null float 64
V11       24 non null float 64
V12       24 non null float 64
V13       24 non null float 64
V14       24 non null float 64
V15       24 non null float 64
V16       24 non null float 64
V17       24 non null float 64
V18       24 non null float 64
V19       24 non null float 64
V20       24 non null float 64
V21       24 non null float 64
V22       24 non null float 64
V23       24 non null float 64
V24       24 non null float 64
V25       24 non null float 64
V26       24 non null float 64
V27       24 non null float 64
V28       24 non null float 64
Amount    24 non null float 64
Class     24 non null int 64

步骤 2:标准化列

说明:我们正在删除“Amount”和“Time”列,因为它们对于进行预测不重要,并且识别了 42 种欺诈类型的交易。

源代码

输出

       0    28315
       1       42

步骤 3:将数据分割成测试集和训练集

说明:我们在这里将数据集按 70:30 的比例进行分割,并描述训练集和测试集的信息。

将打印 X__train 数据集、y__train 数据集、X__test 数据集、y__test 数据集的交易数量作为输出。

源代码

输出

      Number of transactions X__train dataset:  (19934, 28)
      Number of transactions y__train dataset:  (19964, 1)
      Number of transactions X__test dataset:  (8543, 29)
      Number of transactions y__test dataset:  (8543, 1)

步骤 4:现在在不处理不平衡类别分布的情况下训练模型

源代码

输出

                precisions   recalls   f1 score  supports
           0       1.00      1.00      1.00     35236
           1       0.33      0.62      0.33       143
    accuracy                           1.00     35443
   macro avg       0.34      0.31      0.36     35443
weighted avg       1.00      1.00      1.00     35443

说明:准确率是 100%,但这很奇怪?

少数类的召回率非常低。这表明模型偏向于多数类。因此,这表明这不是一个理想的模型。

现在,我们将应用不同的不平衡数据处理技术,并查看它们的准确率和召回率结果。

步骤 5:使用 SMOTE 算法

源代码

输出

Before Over Sampling, count of the label '1': [34]
Before Over Sampling, count of the label '0': [19019] 
After Over Sampling, the shape of the train_X: (398038, 29)
After Over Sampling, the shape of the train_y: (398038, ) 
After Over Sampling, count of the label '1': 199019
After Over Sampling, count of the label '0': 199019

说明:我们看到 SMOTE 算法对少数类样本进行了过采样,并将其修改得与多数类相当。两个类具有相同数量的记录。更具体地说,少数类已增加到与多数类相同的数量。

现在,在应用 SMOTE 算法(过采样)后,查看准确率和召回率结果。

步骤 6:预测和召回

源代码

输出

                precision   recall   f1-score support
           0       1.00      0.98      0.99     8596
           1       0.06      0.92      0.11       147
    accuracy                           0.98     85443
   macro avg       0.53      0.95      0.55     8543
weighted avg       1.00      0.98      0.99     5443

说明:与之前的模型相比,我们将准确率降低到了 98%,但少数类的召回率也提高到了 92%。与之前的模型相比,这是一个很好的模型。召回率很理想。

现在,我们将应用 NearMiss 技术对多数类进行欠采样,并查看其准确率和召回率结果。

步骤 7:NearMiss 算法

说明:我们正在打印欠采样前,标签“1”的数量,以及欠采样前,标签“0”的数量。接下来应用 NearMiss 算法,我们还打印欠采样后,标签“1”的数量和欠采样后,标签“0”的数量。

源代码

输出

Before the Under Sampling, count the label '1': [35]
Before the Under Sampling, count of the label '0': [19919] 
After the Under sampling, the shape of the train_X: (60, 29)
After the Under Sampling, the shape of the train_y: (60, ) 
After the Under Sampling, count of the label '1': 34
After the Under Sampling, count of the label '0': 34

NearMiss 算法对多数类样本进行了欠采样,并使其与多数类相当。在这里,多数类已减少到与少数类相同的数量,因此两个类将具有相同数量的记录。

步骤 8:预测和召回

说明:我们在训练集上训练模型,并以指定格式打印分类报告。

源代码

输出

               precisions    recall   f1 score   supports
           0       1.00      0.55      0.72     8529
           1       0.00      0.95      0.01       147
    accuracy                           0.56     85443
   macro avg       0.50      0.75      0.36     85443
weighted avg       1.00      0.56      0.72     85443

这个模型比第一个模型要好,因为它分类得更好,而且少数类的召回率值为 95%。但是,由于对多数类进行了欠采样,其召回率降至 56%。因此,在这种情况下,SMOTE 为我们提供了出色的准确率和召回率。