Skip to content

Commit ba5c1d7

Browse files
author
Guillaume Lemaitre
committed
Remove nonzero occurence in ENN
Conflicts: imblearn/under_sampling/nearmiss.py
1 parent e7a7de3 commit ba5c1d7

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

imblearn/under_sampling/edited_nearest_neighbours.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def _sample(self, X, y):
124124

125125
# If we need to offer support for the indices
126126
if self.return_indices:
127-
idx_under = np.nonzero(y == self.min_c_)[0]
127+
idx_under = np.flatnonzero(y == self.min_c_)
128128

129129
# Create a k-NN to fit the whole data
130130
nn_obj = NearestNeighbors(n_neighbors=self.size_ngh + 1,
@@ -162,12 +162,12 @@ def _sample(self, X, y):
162162
raise NotImplementedError
163163

164164
# Get the samples which agree all together
165-
sel_x = np.squeeze(sub_samples_x[np.nonzero(nnhood_bool), :])
166-
sel_y = sub_samples_y[np.nonzero(nnhood_bool)]
165+
sel_x = sub_samples_x[np.flatnonzero(nnhood_bool), :]
166+
sel_y = sub_samples_y[np.flatnonzero(nnhood_bool)]
167167

168168
# If we need to offer support for the indices selected
169169
if self.return_indices:
170-
idx_tmp = np.nonzero(y == key)[0][np.nonzero(nnhood_bool)]
170+
idx_tmp = np.flatnonzero(y == key)[np.flatnonzero(nnhood_bool)]
171171
idx_under = np.concatenate((idx_under, idx_tmp), axis=0)
172172

173173
X_resampled = np.concatenate((X_resampled, sel_x), axis=0)

imblearn/under_sampling/nearmiss.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def _selection_dist_based(self, X, y, dist_vec, num_samples, key,
166166
sel_idx = sorted_idx[:num_samples]
167167

168168
return (X[y == key][sel_idx], y[y == key][sel_idx],
169-
np.nonzero(y == key)[0][sel_idx])
169+
np.flatnonzero(y == key)[sel_idx])
170170

171171
def _sample(self, X, y):
172172
"""Resample the dataset.
@@ -195,9 +195,9 @@ def _sample(self, X, y):
195195

196196
# Assign the parameter of the element of this class
197197
# Check that the version asked is implemented
198-
if not (self.version == 1 or self.version == 2 or self.version == 3):
199-
raise ValueError('UnbalancedData.NearMiss: there is only 3 '
200-
'versions available with parameter version=1/2/3')
198+
if self.version not in [1, 2, 3]:
199+
raise ValueError("Parameter 'version' must be 1, 2 or 3, "
200+
"got {0}".format(self.version))
201201

202202
# Start with the minority class
203203
X_min = X[y == self.min_c_]
@@ -215,7 +215,7 @@ def _sample(self, X, y):
215215

216216
# If we need to offer support for the indices
217217
if self.return_indices:
218-
idx_under = np.nonzero(y == self.min_c_)[0]
218+
idx_under = np.flatnonzero(y == self.min_c_)
219219

220220
# For each element of the current class, find the set of NN
221221
# of the minority class

0 commit comments

Comments
 (0)