Question

I'd like to pass a sparse precomputed Gram matrix to sklearn.svm.SVC.fit. Here's some working code:

import numpy as np
from sklearn import svm
X = np.array([[0, 0], [1, 1]])
y = [0, 1]
clf = svm.SVC(kernel='precomputed')
gram = np.dot(X, X.T)
clf.fit(gram, y) 

But if I have:

from scipy.sparse import csr_matrix
sparse_gram = csr_matrix(gram)
clf.fit(sparse_gram, y)

I get:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python2.7/dist-packages/sklearn/svm/base.py", line 191, in fit
    fit(X, y, sample_weight, solver_type, kernel)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/svm/base.py", line 235, in _dense_fit
    max_iter=self.max_iter)
TypeError: Argument 'X' has incorrect type (expected numpy.ndarray, got csr_matrix)

The fact that I ended up in the _dense_fit function (see where it says line 235 above) makes me think I need to do something special to tell fit to use a sparse matrix. But I'm not sure how to do that.

Update: I just checked the code for the fit function (https://sourcegraph.com/github.com/scikit-learn/scikit-learn/symbols/python/sklearn/svm/base/BaseLibSVM/fit) and now I'm even more confused:

    self._sparse = sp.isspmatrix(X) and not self._pairwise

    if self._sparse and self._pairwise:
        raise ValueError("Sparse precomputed kernels are not supported. "
                         "Using sparse data and dense kernels is possible "
                         "by not using the ``sparse`` parameter")

So I guess as it says, "Sparse precomputed kernels are not supported" and that's indeed what I wanted to do, so I'm probably out of luck. (Is it a bug that I didn't actually see that error though?)


Solution

> so I'm probably out of luck.

Yep. Sorry about that.
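
One possible workaround, assuming the Gram matrix is small enough to hold in memory as a dense array, is to convert it with scipy's toarray() before calling fit. This is only a sketch built on the toy data from the question; it gives up the memory savings of the sparse format, so it may not suit large problems.

```python
import numpy as np
from scipy.sparse import csr_matrix
from sklearn import svm

# Toy data from the question.
X = np.array([[0, 0], [1, 1]])
y = [0, 1]
gram = np.dot(X, X.T)
sparse_gram = csr_matrix(gram)

clf = svm.SVC(kernel='precomputed')

# Densify the sparse Gram matrix; SVC with kernel='precomputed'
# expects a dense (n_samples, n_samples) array at fit time.
dense_gram = sparse_gram.toarray()
clf.fit(dense_gram, y)

# At predict time the input is the kernel between the test points and
# the training points; here the training points are reused.
pred = clf.predict(dense_gram)
```

The same conversion is needed for any kernel matrix passed to predict, since it goes through the same dense code path.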

> Is it a bug that I didn't actually see that error though?

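Yes it is: the released code sets

    self._sparse = sp.isspmatrix(X) and not self._pairwise

and then checks

    self._sparse and self._pairwise

to raise the exception. That condition can never be true, because self._sparse is only true when self._pairwise is false. The ValueError is therefore never raised, and a sparse precomputed kernel falls through to _dense_fit, which is where your TypeError comes from. I just pushed a patch to fix this; thanks for the report.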
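
To see why the check can never fire, the two expressions can be evaluated for every combination of inputs. The sketch below uses plain booleans in place of sp.isspmatrix(X) and self._pairwise; the function name check_would_raise is hypothetical and only mirrors the logic quoted above.

```python
from itertools import product


def check_would_raise(is_sparse_input, pairwise):
    """Mirror the quoted logic with plain booleans (hypothetical helper)."""
    # Stand-in for: self._sparse = sp.isspmatrix(X) and not self._pairwise
    sparse = is_sparse_input and not pairwise
    # Stand-in for: if self._sparse and self._pairwise: raise ValueError(...)
    return sparse and pairwise


# Try all four combinations of (is_sparse_input, pairwise).
results = [check_would_raise(s, p) for s, p in product([False, True], repeat=2)]

# No combination triggers the error branch.
assert not any(results)
```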

License: CC-BY-SA with attribution
Not affiliated with StackOverflow