PHPでNaive Bayesを使ってみる

今月号のWEB+DB PRESS

WEB+DB PRESS Vol.49

WEB+DB PRESS Vol.49


はてなブックマークのリニューアルに際しての特集記事があったり、レコメンドエンジンの解説記事があったりと非常に読み応えがあっていつもの3割増でおすすめ。
で、ブックマークのカテゴリ自動判定システムで使われているアルゴリズムはComplement Naive Bayesで、このアルゴリズムの元となっているアルゴリズムはNaive Bayes(単純ベイズ分類機)と呼ばれるもの。
Perlでは、記事でも紹介されている通り、Algorithm::NaiveBayesというライブラリがCPANにあるので利用するとアルゴリズムが比較的簡単に利用することができる。
このアルゴリズムを使ってみたいと思ったのだけど、あいにくPHPでは似た形で利用できるライブラリがすぐに見つからなかったので、突貫でこのPerlのライブラリを移植してみた。
Perl版だと、スコアを計算する方法を"frequency", "discrete", "gaussian"の3通りから選べたり、学習させた結果を保管できるのだけど、このたびのものは無し。
記事に記載のサンプルに倣って試してみる。
PHP実装もPerlのインターフェースに併せている。addInstanceメソッドの第一引数に学習対象となる文書の単語の出現数をarrayで与え、その文書が所属するカテゴリを第2引数に与える。
trainメソッドで学習を実行して、predictメソッドで分類を推定する文書中の単語とその出現数を与えると、カテゴリに所属する確率を推測してくれる。

<?php
$bayes = new NaiveBayes();
$bayes->addInstance(array("はてな" => 5, "京都" => 2), array("it"));
$bayes->addInstance(array("引っ越し" => 1, "" => 1), array("life"));
$bayes->train();
$resp = $bayes->predict(array("はてな" => 1, "引っ越し" => 1, "京都" => 1));

print_r($resp);
?>

上記のソースを実行すると、

Array
(
[it] => 0.825130233192
[life] => 0.564942561923
)

Perl版と同じ結果になるので大丈夫かな。。。と。
この場合だと、ITというカテゴリに所属する確率が高いと判定されたということになる。
ソースコードは以下参照。おそらく逐次直します。変なところもあるだろうし。きっと。

<?php
/**
 * Naivebayes.php 
 *
 * This package was ported from Perl's Algorithm::NaiveBayes (frequency model only)
 * http://search.cpan.org/~kwilliams/Algorithm-NaiveBayes-0.04/lib/Algorithm/NaiveBayes.pm
 * 
 * @category  Algorithm 
 * @package   Naivebayes
 * @author    hideack
 * @license   http://www.php.net/license/3_01.txt The PHP License, version 3.01
 * @version   0.1 
 */
class Naivebayes{

  	private $modeltype;
	private $instances;
	private $trainingdata;
	private $model;

	public function __construct(){
		$this->trainingdata = array(
			"attributes" => array(),
			"labels"     => array(),
		);
		$this->instances = 0;
		$this->modeltype = "";	// Perl版では切り替え可能
	}

	public function addInstance($attributes, $label){
		$this->instances++;

		foreach($attributes as $keyword => $count){
			if(isset($this->trainingdata['attributes'][$keyword])){
				$this->trainingdata['attributes'][$keyword] += $count;
			}
			else{
				$this->trainingdata['attributes'][$keyword] = $count;
			}
		}

		foreach($label as $labelword){
			if(isset($this->trainingdata['labels'][$labelword]['count'])){
				$this->trainingdata['labels'][$labelword]['count']++;
			}
			else{
				$this->trainingdata['labels'][$labelword]['count'] = 1;
			}
			
			foreach($attributes as $keyword => $count){
				if(isset($this->trainingdata[$keyword])){
					$this->trainingdata['labels'][$labelword]['attributes'][$keyword] += $count;
				}
				else{
					$this->trainingdata['labels'][$labelword]['attributes'][$keyword] = $count;
				}
			}
		}
	}

	public function train(){
		$m = array();
		$labels = $this->trainingdata['labels'];

		$m['attributes'] = $this->trainingdata['attributes'];
		$vocab_size = count($m['attributes']);

		foreach($labels as $label => $info){

			$m['prior_probs'][$label] = log($info['count'] / $this->instances);
			
			$label_tokens = 0;
			foreach($info['attributes'] as $word => $count){
				$label_tokens += $count;
			}
			
			$m['smoother'][$label] = -log($label_tokens + $vocab_size);
			$denominator = log($label_tokens + $vocab_size);

			foreach($info['attributes'] as $attribute => $count){
				$m['probs'][$label][$attribute] = log($count + 1) - $denominator;
			}
		}

		$this->model = $m;
	}

	public function predict($newattrs){
		$scores = $this->model['prior_probs'];

		foreach($newattrs as $feature => $value){
			foreach($this->model['probs'] as $label => $attribute){
				$tmpscore = 0.0;
				
				if($attribute[$feature] == 0.0){
					$tmpscore = $this->model['smoother'][$label];
				}
				else{
					$tmpscore = $attribute[$feature];
				}

				$scores[$label] += $tmpscore * $value;
			}
		}

		$scores = $this->rescale($scores);

		return $scores;
	}

	public function labels(){
		$labels = array();
		
		foreach($this->trainingdata['labels'] as $label => $value){
			$labels[] = $label;
		}

		return $labels; 
	}

	public function doPurge(){
           // 未実装...
	}

	private function rescale($scores){
		$total = 0;
		$max  = max($scores);
		$rescalescore = $scores;

		foreach($rescalescore as $key => $val){
			$val = exp($val - $max);
			$total += pow($val, 2);
			
			$rescalescore[$key] = $val;
		}

		$total = sqrt($total);

		foreach($rescalescore as $key => $val){
			$rescalescore[$key] /= $total; 
		}

		return $rescalescore;
	}
}
?>