diff --git a/pattern/vector/svm/libsvmutil.py b/pattern/vector/svm/libsvmutil.py index 29110b00..9ce8e37f 100644 --- a/pattern/vector/svm/libsvmutil.py +++ b/pattern/vector/svm/libsvmutil.py @@ -14,6 +14,8 @@ from .libsvm import * from .libsvm import __all__ as svm_all +import numpy as np + __all__ = ['evaluations', 'svm_load_model', 'svm_predict', 'svm_read_problem', 'svm_save_model', 'svm_train'] + svm_all @@ -77,16 +79,13 @@ def evaluations(ty, pv): if len(ty) != len(pv): raise ValueError("len(ty) must equal to len(pv)") total_correct = total_error = 0 - sumv = sumy = sumvv = sumyy = sumvy = 0 - for v, y in zip(pv, ty): - if y == v: - total_correct += 1 - total_error += (v - y) * (v - y) - sumv += v - sumy += y - sumvv += v * v - sumyy += y * y - sumvy += v * y + sumv = np.sum(pv) + sumy = np.sum(ty) + sumvv = np.dot(pv,pv) + sumyy = np.dot(ty,ty) + sumvy = np.dot(pv,ty) + total_correct = np.sum([1 for v, y in zip(pv, ty) if y == v]) + total_error = np.sum(np.subtract(pv, ty)**2) l = len(ty) ACC = 100.0 * total_correct / l MSE = total_error / l