OpenCV SVM - eiichiromomma/CVMLAB GitHub Wiki

(OpenCV) Support Vector Machine

CvSVMによるサポートベクタマシン(SVM: Support Vector Machine)による分類器について

Classificationの他にRegression(SVR)に対応。 ClassificationでMulticlassに対応しているのがありがたい。

注意点

  • CvSVMでload,saveメソッドを使うにはml.libのリビルドが必要。(ver.1.0)
  • NU_SVCは学習データの内容が各クラスまんべんなく入っていないとエラーになる

基本的な使い方

  1. データ作り
  2. 学習
  3. テスト

と、他の分類器と同様なのでおおまかな使い方はletter_recog.cppを追うと分かる。

パラメータについてはopencv.jp - OpenCV:機械学習 サポートベクターマシン(SVM)リファレンス マニュアル - を参照。

letter_recog.cppにSVMを追加

main

svmのオプションを付ける

    // Entry point: parses command-line flags, then dispatches to the chosen
    // classifier builder (Random Trees by default; -boost/-mlp/-svm select others).
    // Returns 0 always; prints usage when arguments are malformed or a builder fails.
    int main( int argc, char *argv[] )
    {
        char* filename_to_save = 0;
        char* filename_to_load = 0;
        char default_data_filename[] = "./letter-recognition.data";
        char* data_filename = default_data_filename;
        int method = 0;   // 0=rtrees, 1=boost, 2=mlp, 3=svm

        int i;
        for( i = 1; i < argc; i++ )
        {
            // Flags taking a value check i+1 < argc so a trailing flag with a
            // missing argument falls through to the usage message instead of
            // reading argv[argc] (NULL) and dereferencing it later.
            if( strcmp(argv[i],"-data") == 0 && i+1 < argc ) // flag "-data letter_recognition.xml"
            {
                i++;
                data_filename = argv[i];
            }
            else if( strcmp(argv[i],"-save") == 0 && i+1 < argc ) // flag "-save filename.xml"
            {
                i++;
                filename_to_save = argv[i];
            }
            else if( strcmp(argv[i],"-load") == 0 && i+1 < argc ) // flag "-load filename.xml"
            {
                i++;
                filename_to_load = argv[i];
            }
            else if( strcmp(argv[i],"-boost") == 0 )
            {
                method = 1;
            }
            else if( strcmp(argv[i],"-mlp") == 0 )
            {
                method = 2;
            }
            else if( strcmp(argv[i],"-svm") == 0 )
            {
                method = 3;
            }
            else
                break;  // unknown flag or missing value -> show usage below
        }

        // Print usage if parsing stopped early, or if the selected builder failed.
        if( i < argc ||
            (method == 0 ?
            build_rtrees_classifier( data_filename, filename_to_save, filename_to_load ) :
            method == 1 ?
            build_boost_classifier( data_filename, filename_to_save, filename_to_load ) :
            method == 2 ?
            build_mlp_classifier( data_filename, filename_to_save, filename_to_load ) :
            method == 3 ?
            build_svm_classifier( data_filename, filename_to_save, filename_to_load ) :
            -1) < 0)
        {
            printf("This is letter recognition sample.\n"
                    "The usage: letter_recog [-data <path to letter-recognition.data>] \\\n"
                    "  [-save <output XML file for the classifier>] \\\n"
                    "  [-load <XML file with the pre-trained classifier>] \\\n"
                    "  [-boost|-mlp|-svm] # to use boost/mlp/svm classifier instead of default Random Trees\n" );
        }
        return 0;
    }

build_svm_classifier

build_rtrees_classifierを少し修正しただけ。 パラメータの設定はいい加減。

全データの20%を学習に使ってみた。

    // Trains (or loads) a multiclass C-SVC SVM with an RBF kernel on the UCI
    // letter-recognition dataset, evaluates it, and optionally saves it to XML.
    // Modeled on build_rtrees_classifier; uses the first 20% of samples for
    // training and the rest for testing. Returns 0 on success, -1 on failure.
    static
    int build_svm_classifier( char* data_filename,
        char* filename_to_save, char* filename_to_load )
    {
        CvMat* data = 0;
        CvMat* responses = 0;
        CvMat* sample_idx = 0;
        // C_SVC penalty parameter
        float svm_C = 1000;
        // RBF kernel parameter
        float svm_gamma = 0.1f;

        int ok = read_num_class_data( data_filename, 16, &data, &responses );
        int nsamples_all = 0, ntrain_samples = 0;
        int i = 0;
        double train_hr = 0, test_hr = 0;
        CvSVM svm;
        // Termination criteria for the SVM optimizer
        CvTermCriteria criteria = cvTermCriteria(CV_TERMCRIT_EPS,100,0.001);

        if( !ok )
        {
            printf( "Could not read the database %s\n", data_filename );
            return -1;
        }

        printf( "The database %s is loaded.\n", data_filename );
        nsamples_all = data->rows;
        ntrain_samples = (int)(nsamples_all*0.2);

        // Create or load the SVM classifier
        if( filename_to_load )
        {
            // load classifier from the specified file
            svm.load( filename_to_load );
            ntrain_samples = 0;
            if( svm.get_support_vector_count() == 0 )
            {
                printf( "Could not read the classifier %s\n", filename_to_load );
                return -1;
            }
            // Report the classifier file that was loaded (not the data file).
            printf( "The classifier %s is loaded.\n", filename_to_load );
        }
        else
        {
            // create classifier by using <data> and <responses>
            printf( "Training the classifier ...");

            // create sample_idx: mark the first ntrain_samples rows as training
            sample_idx = cvCreateMat( 1, nsamples_all, CV_8UC1 );
            {
                CvMat mat;
                cvGetCols( sample_idx, &mat, 0, ntrain_samples );
                cvSet( &mat, cvRealScalar(1) );

                cvGetCols( sample_idx, &mat, ntrain_samples, nsamples_all );
                cvSetZero( &mat );
            }

            // train classifier
            // Parameters unused by this svm_type/kernel may be 0,
            // and the class weights may be NULL.
            svm.train( data, responses, 0, sample_idx,
              CvSVMParams(CvSVM::C_SVC,CvSVM::RBF,0,svm_gamma,0,svm_C,0,0,NULL,
              criteria));
            printf( "\n");
        }

        // compute prediction error on train and test data
        for( i = 0; i < nsamples_all; i++ )
        {
            double r;
            CvMat sample;
            cvGetRow( data, &sample, i );

            r = svm.predict( &sample );
            // compare prediction with the ground-truth response
            printf("predict: %c, responses: %c, %s\n", (unsigned char)r, (unsigned char)responses->data.fl[i],
              fabs((double)r - responses->data.fl[i]) <= FLT_EPSILON?"Good!":"Bad!");
            r = fabs((double)r - responses->data.fl[i]) <= FLT_EPSILON ? 1 : 0;

            if( i < ntrain_samples )
                train_hr += r;
            else
                test_hr += r;
        }

        // Guard against division by zero: ntrain_samples is 0 when a classifier
        // was loaded from file, and the test set may be empty in theory.
        if( nsamples_all > ntrain_samples )
            test_hr /= (double)(nsamples_all-ntrain_samples);
        if( ntrain_samples > 0 )
            train_hr /= (double)ntrain_samples;
        printf( "Gamma=%.5f, C=%.5f\n", svm_gamma, svm_C);
        if( filename_to_load ){
          printf( "Recognition rate: test = %.1f%%\n", test_hr*100. );
        }else{
          printf( "Recognition rate: train = %.1f%%, test = %.1f%%\n",
                  train_hr*100., test_hr*100. );
        }

        printf( "Number of Support Vector: %d\n", svm.get_support_vector_count() );
        // Save SVM classifier to file if needed
        if( filename_to_save )
            svm.save( filename_to_save );

        cvReleaseMat( &sample_idx );
        cvReleaseMat( &data );
        cvReleaseMat( &responses );

        return 0;
    }

実行結果

やけにお目出度い(?)結果になった。

predict: Y, responses: Y, Good!
predict: V, responses: V, Good!
predict: S, responses: S, Good!
predict: M, responses: M, Good!
predict: O, responses: O, Good!
predict: L, responses: L, Good!
predict: D, responses: D, Good!
predict: P, responses: P, Good!
predict: W, responses: W, Good!
predict: O, responses: O, Good!
predict: E, responses: E, Good!
predict: J, responses: J, Good!
predict: T, responses: T, Good!
predict: D, responses: D, Good!
predict: C, responses: C, Good!
predict: T, responses: T, Good!
predict: S, responses: S, Good!
predict: A, responses: A, Good!
Gamma=0.10000, C=1000.00000
Recognition rate: test = 93.6%
Number of Support Vector: 3383
続行するには何かキーを押してください . . .

xmlファイル

  <?xml version="1.0"?>
  <opencv_storage>
  <my_svm type_id="opencv-ml-svm">
    <svm_type>C_SVC</svm_type>
    <kernel><type>RBF</type>
      <gamma>0.1000000014901161</gamma></kernel>
    <C>1000.</C>
    <term_criteria><epsilon>1.0000000474974513e-003</epsilon>
      <iterations>2147483647</iterations></term_criteria>
    <var_all>16</var_all>
    <var_count>16</var_count>
    <class_count>26</class_count>
    <class_labels type_id="opencv-matrix">
      <rows>1</rows>
      <cols>26</cols>
      <dt>i</dt>
      <data>
        65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
        87 88 89 90</data></class_labels>
    <sv_total>3383</sv_total>
    <support_vectors>
      <_>
        2. 1. 4. 2. 1. 8. 1. 2. 2. 7. 2. 8. 2. 5. 2. 7.</_>
      <_>
        3. 7. 5. 5. 3. 10. 4. 1. 2. 8. 3. 9. 2. 4. 2. 7.</_>
省略

参考文献

共立出版の「サポートベクターマシン入門」しか日本語の書籍がなかったが、Ohmshaの「サポートベクターマシン」小野田崇著が具体例も豊富で非常に分かりやすい。 パラメータの決定については交差検定を用いている。

⚠️ **GitHub.com Fallback** ⚠️