PHPでNaive Bayesを使ってみる
今月号のWEB+DB PRESS。
- 作者: arton,桑田誠,角田直行,和田卓人,伊藤直也,西田圭介,岡野原大輔,縣俊貴,大塚知洋,nanto_vi,徳永拓之,山本陽平,田中洋一郎,下岡秀幸,ミック,武者晶紀,高林哲,小飼弾,はまちや2,WEB+DB PRESS編集部
- 出版社/メーカー: 技術評論社
- 発売日: 2009/02/23
- メディア: 大型本
- 購入: 8人 クリック: 326回
- この商品を含むブログ (47件) を見る
はてなブックマークのリニューアルに際しての特集記事があったり、レコメンドエンジンの解説記事があったりと非常に読み応えがあっていつもの3割増でおすすめ。
で、ブックマークのカテゴリ自動判定システムで使われているアルゴリズムはComplement Naive Bayesで、このアルゴリズムの元となっているアルゴリズムはNaive Bayes(単純ベイズ分類機)と呼ばれるもの。
Perlでは、記事でも紹介されている通り、Algorithm::NaiveBayesというライブラリがCPANにあるので利用するとアルゴリズムが比較的簡単に利用することができる。
このアルゴリズムを使ってみたいと思ったのだけど、あいにくPHPでは似た形で利用できるライブラリがすぐに見つからなかったので、突貫でこのPerlのライブラリを移植してみた。
Perl版だと、スコアを計算する方法を"frequency", "discrete", "gaussian"の3通りから選べたり、学習させた結果を保管できるのだけど、このたびのものは無し。
記事に記載のサンプルに倣って試してみる。
PHP実装もPerlのインターフェースに併せている。addInstanceメソッドの第一引数に学習対象となる文書の単語の出現数をarrayで与え、その文書が所属するカテゴリを第2引数に与える。
trainメソッドで学習を実行して、predictメソッドで分類を推定する文書中の単語とその出現数を与えると、カテゴリに所属する確率を推測してくれる。
<?php $bayes = new NaiveBayes(); $bayes->addInstance(array("はてな" => 5, "京都" => 2), array("it")); $bayes->addInstance(array("引っ越し" => 1, "春" => 1), array("life")); $bayes->train(); $resp = $bayes->predict(array("はてな" => 1, "引っ越し" => 1, "京都" => 1)); print_r($resp); ?>
上記のソースを実行すると、
Array
(
[it] => 0.825130233192
[life] => 0.564942561923
)
とPerl版と同じ結果になるので大丈夫かな。。。と。
この場合だと、ITというカテゴリに所属する確率が高いと判定されたということになる。
ソースコードは以下参照。おそらく逐次直します。変なところもあるだろうし。きっと。
<?php /** * Naivebayes.php * * This package was ported from Perl's Algorithm::NaiveBayes (frequency model only) * http://search.cpan.org/~kwilliams/Algorithm-NaiveBayes-0.04/lib/Algorithm/NaiveBayes.pm * * @category Algorithm * @package Naivebayes * @author hideack * @license http://www.php.net/license/3_01.txt The PHP License, version 3.01 * @version 0.1 */ class Naivebayes{ private $modeltype; private $instances; private $trainingdata; private $model; public function __construct(){ $this->trainingdata = array( "attributes" => array(), "labels" => array(), ); $this->instances = 0; $this->modeltype = ""; // Perl版では切り替え可能 } public function addInstance($attributes, $label){ $this->instances++; foreach($attributes as $keyword => $count){ if(isset($this->trainingdata['attributes'][$keyword])){ $this->trainingdata['attributes'][$keyword] += $count; } else{ $this->trainingdata['attributes'][$keyword] = $count; } } foreach($label as $labelword){ if(isset($this->trainingdata['labels'][$labelword]['count'])){ $this->trainingdata['labels'][$labelword]['count']++; } else{ $this->trainingdata['labels'][$labelword]['count'] = 1; } foreach($attributes as $keyword => $count){ if(isset($this->trainingdata[$keyword])){ $this->trainingdata['labels'][$labelword]['attributes'][$keyword] += $count; } else{ $this->trainingdata['labels'][$labelword]['attributes'][$keyword] = $count; } } } } public function train(){ $m = array(); $labels = $this->trainingdata['labels']; $m['attributes'] = $this->trainingdata['attributes']; $vocab_size = count($m['attributes']); foreach($labels as $label => $info){ $m['prior_probs'][$label] = log($info['count'] / $this->instances); $label_tokens = 0; foreach($info['attributes'] as $word => $count){ $label_tokens += $count; } $m['smoother'][$label] = -log($label_tokens + $vocab_size); $denominator = log($label_tokens + $vocab_size); foreach($info['attributes'] as $attribute => $count){ $m['probs'][$label][$attribute] = log($count + 1) - $denominator; } } $this->model = $m; } public function predict($newattrs){ $scores = $this->model['prior_probs']; foreach($newattrs as $feature => $value){ foreach($this->model['probs'] as $label => $attribute){ $tmpscore = 0.0; if($attribute[$feature] == 0.0){ $tmpscore = $this->model['smoother'][$label]; } else{ $tmpscore = $attribute[$feature]; } $scores[$label] += $tmpscore * $value; } } $scores = $this->rescale($scores); return $scores; } public function labels(){ $labels = array(); foreach($this->trainingdata['labels'] as $label => $value){ $labels[] = $label; } return $labels; } public function doPurge(){ // 未実装... } private function rescale($scores){ $total = 0; $max = max($scores); $rescalescore = $scores; foreach($rescalescore as $key => $val){ $val = exp($val - $max); $total += pow($val, 2); $rescalescore[$key] = $val; } $total = sqrt($total); foreach($rescalescore as $key => $val){ $rescalescore[$key] /= $total; } return $rescalescore; } } ?>