 from sklearn.preprocessing import label_binarize
 from sklearn.utils.fixes import np_version
 from sklearn.utils.validation import check_random_state
-from sklearn.utils.testing import (assert_allclose, assert_array_equal,
-                                   assert_no_warnings, assert_equal,
-                                   assert_raises, assert_warns_message,
-                                   ignore_warnings, assert_not_equal,
-                                   assert_raise_message)
-from sklearn.metrics import (accuracy_score, average_precision_score,
-                             brier_score_loss, cohen_kappa_score,
-                             jaccard_similarity_score, precision_score,
-                             recall_score, roc_auc_score)
+from sklearn.utils.testing import assert_allclose, assert_array_equal
+from sklearn.utils.testing import assert_no_warnings, assert_raises
+from sklearn.utils.testing import assert_warns_message, ignore_warnings
+from sklearn.utils.testing import assert_raise_message
+from sklearn.metrics import accuracy_score, average_precision_score
+from sklearn.metrics import brier_score_loss, cohen_kappa_score
+from sklearn.metrics import jaccard_similarity_score, precision_score
+from sklearn.metrics import recall_score, roc_auc_score
 
 from imblearn.metrics import sensitivity_specificity_support
 from imblearn.metrics import sensitivity_score
 from imblearn.metrics import make_index_balanced_accuracy
 from imblearn.metrics import classification_report_imbalanced
 
+from pytest import approx
+
 RND_SEED = 42
 R_TOL = 1e-2
 
@@ -113,11 +114,11 @@ def test_sensitivity_specificity_score_binary():
 
 def test_sensitivity_specificity_f_binary_single_class():
     # Such a case may occur with non-stratified cross-validation
-    assert_equal(1., sensitivity_score([1, 1], [1, 1]))
-    assert_equal(0., specificity_score([1, 1], [1, 1]))
+    assert sensitivity_score([1, 1], [1, 1]) == 1.
+    assert specificity_score([1, 1], [1, 1]) == 0.
 
-    assert_equal(0., sensitivity_score([-1, -1], [-1, -1]))
-    assert_equal(0., specificity_score([-1, -1], [-1, -1]))
+    assert sensitivity_score([-1, -1], [-1, -1]) == 0.
+    assert specificity_score([-1, -1], [-1, -1]) == 0.
 
 
 @ignore_warnings
@@ -166,9 +167,8 @@ def test_sensitivity_specificity_ignored_labels():
                     rtol=R_TOL)
 
     # ensure the above were meaningful tests:
-    for average in ['macro', 'weighted', 'micro']:
-        assert_not_equal(
-            specificity_13(average=average), specificity_all(average=average))
+    for each in ['macro', 'weighted', 'micro']:
+        assert specificity_13(average=each) != specificity_all(average=each)
 
 
 def test_sensitivity_specificity_error_multilabels():
@@ -333,15 +333,15 @@ def test_classification_report_imbalanced_multiclass():
         y_pred,
         labels=np.arange(len(iris.target_names)),
         target_names=iris.target_names)
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report
     # print classification report with label detection
     expected_report = ('pre rec spe f1 geo iba sup 0 0.83 0.79 0.92 0.81 '
                        '0.86 0.74 24 1 0.33 0.10 0.86 0.15 0.44 0.19 31 2 '
                        '0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total 0.51 '
                        '0.53 0.80 0.47 0.62 0.41 75')
 
     report = classification_report_imbalanced(y_true, y_pred)
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report
 
 
 def test_classification_report_imbalanced_multiclass_with_digits():
@@ -361,14 +361,14 @@ def test_classification_report_imbalanced_multiclass_with_digits():
         labels=np.arange(len(iris.target_names)),
         target_names=iris.target_names,
         digits=5)
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report
     # print classification report with label detection
     expected_report = ('pre rec spe f1 geo iba sup 0 0.83 0.79 0.92 0.81 '
                        '0.86 0.74 24 1 0.33 0.10 0.86 0.15 0.44 0.19 31 2 '
                        '0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total 0.51 '
                        '0.53 0.80 0.47 0.62 0.41 75')
     report = classification_report_imbalanced(y_true, y_pred)
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report
 
 
 def test_classification_report_imbalanced_multiclass_with_string_label():
@@ -382,15 +382,15 @@ def test_classification_report_imbalanced_multiclass_with_string_label():
                        '0.19 31 red 0.42 0.90 0.55 0.57 0.63 0.37 20 '
                        'avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75')
     report = classification_report_imbalanced(y_true, y_pred)
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report
 
     expected_report = ('pre rec spe f1 geo iba sup a 0.83 0.79 0.92 0.81 '
                        '0.86 0.74 24 b 0.33 0.10 0.86 0.15 0.44 0.19 31 '
                        'c 0.42 0.90 0.55 0.57 0.63 0.37 20 avg / total '
                        '0.51 0.53 0.80 0.47 0.62 0.41 75')
     report = classification_report_imbalanced(
         y_true, y_pred, target_names=["a", "b", "c"])
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report
 
 
 def test_classification_report_imbalanced_multiclass_with_unicode_label():
@@ -411,7 +411,7 @@ def test_classification_report_imbalanced_multiclass_with_unicode_label():
                              classification_report_imbalanced, y_true, y_pred)
     else:
         report = classification_report_imbalanced(y_true, y_pred)
-        assert_equal(_format_report(report), expected_report)
+        assert _format_report(report) == expected_report
 
 
 def test_classification_report_imbalanced_multiclass_with_long_string_label():
@@ -427,7 +427,7 @@ def test_classification_report_imbalanced_multiclass_with_long_string_label():
                        '0.37 20 avg / total 0.51 0.53 0.80 0.47 0.62 0.41 75')
 
     report = classification_report_imbalanced(y_true, y_pred)
-    assert_equal(_format_report(report), expected_report)
+    assert _format_report(report) == expected_report
 
 
 def test_iba_sklearn_metrics():
@@ -436,22 +436,22 @@ def test_iba_sklearn_metrics():
     acc = make_index_balanced_accuracy(alpha=0.5, squared=True)(
         accuracy_score)
     score = acc(y_true, y_pred)
-    assert_equal(score, 0.54756)
+    assert score == approx(0.54756)
 
     jss = make_index_balanced_accuracy(alpha=0.5, squared=True)(
         jaccard_similarity_score)
     score = jss(y_true, y_pred)
-    assert_equal(score, 0.54756)
+    assert score == approx(0.54756)
 
     pre = make_index_balanced_accuracy(alpha=0.5, squared=True)(
         precision_score)
     score = pre(y_true, y_pred)
-    assert_equal(score, 0.65025)
+    assert score == approx(0.65025)
 
     rec = make_index_balanced_accuracy(alpha=0.5, squared=True)(
         recall_score)
     score = rec(y_true, y_pred)
-    assert_equal(score, 0.41616000000000009)
+    assert score == approx(0.41616000000000009)
 
 
 def test_iba_error_y_score_prob():