渚  直樹

渚 直樹

1641468732

Rustを使用して単純ベイズ分類器を実装する

私はRustスキルを向上させ、あなたもあなたのスキルを磨く手助けをしたいと思っています。そこで、Rustプログラミング言語に関する一連の記事を書くことにしました。

Rustを使って実際にものを構築することで、その過程でのさまざまな技術的概念について学びます。この記事では、Rustを使用して単純ベイズ分類器を実装する方法を学習します。

この記事では、なじみのない用語や概念に遭遇する可能性があります。落胆しないでください。時間があればこれらを調べてください。ただし、それでも、この記事の主なアイデアが失われることはありません。

単純ベイズ分類器とは何ですか?

単純ベイズ分類器は、ベイズの定理に基づく機械学習アルゴリズムです。ベイズの定理は、いくつかのデータDが与えられた場合に、仮説Hの確率を更新する方法を提供します。

数学的に表現すると、次のようになります。

ここで、P(H | D)= Dが与えられたHの確率。

我々は、AC場合cは、より多くのデータをumulate、我々は更新することができる| P(D H)に応じています。

単純ベイズモデルは、データポイントがデータセットに存在するか存在しないかは、そのセット(ソース)にすでに存在するデータから独立しているという大きな仮定に基づいています。つまり、各データは他のデータポイントに関する情報を伝達しません。

この仮定が正しいとは期待していません–それは弱いです。しかし、それでも有用であり、非常にうまく機能する効率的な分類子を作成できます(ソース)。

ナイーブベイズの説明はそこに残しておきます。もっと多くのことが言えますが、この記事の要点はRustを練習することです。

アルゴリズムについて詳しく知りたい場合は、次のリソースをご覧ください。

単純ベイズ分類器の標準的なアプリケーションは、スパム分類器です。それが私たちが構築するものです。ここですべてのコードを見つけることができます:https//github.com/josht-jpg/shaking-off-the-rust

まず、Cargoを使用して新しいライブラリを作成します。

cargo new naive_bayes --lib
cd naive_bayes

それでは、詳しく見ていきましょう。

Rustでのトークン化

私たちの分類器は、入力としてメッセージを受け取り、スパムまたは非スパムの分類を返します。

与えられたメッセージを処理するために、それをトークン化する必要があります。トークン化された表現は、順序と繰り返しのエントリが無視される小文字の単語のセットになります。Rustのstd::collections::HashSet構造は、これを実現するための優れた方法です。

プリフォームトークン化に書き込む関数では、正規表現クレートを使用する必要があります。次の依存関係をCargo.tomlファイルに含めるようにしてください。

[dependencies]
regex = "^1.5.4"

そして、これがtokenize機能です:

// lib.rs

// We'll need HashMap later
use std::collections::{HashMap, HashSet};

extern crate regex;
use regex::Regex;

pub fn tokenize(lower_case_text: &str) -> HashSet<&str> {
    Regex::new(r"[a-z0-9']+")
        .unwrap()
        .find_iter(lower_case_text)
        .map(|mat| mat.as_str())
        .collect()
}

この関数は、正規表現を使用して、すべての数値と小文字を照合します。異なるタイプの記号(多くの場合、空白または句読点)に遭遇するたびに、入力を分割し、最後の分割以降に検出されたすべての数字と文字をグループ化します(正規表現の詳細については、Rustを参照してください)。つまり、入力テキスト内の単語を識別して分離しています。

いくつかの便利な構造

structメッセージを表すためにを使用すると便利です。これにstructは、メッセージのテキストの文字列スライスと、メッセージがスパムであるかどうかを示すブール値が含まれます。

pub struct Message<'a> {
    pub text: &'a str,
    pub is_spam: bool,
}

'a寿命パラメータ注釈です。生涯に慣れておらず、生涯について知りたい場合は、Rustプログラミング言語の本のセクション10.3を読むことをお勧めします。

Astructは、分類子を表すのにも役立ちます。を作成する前にstruct、ラプラシアン平滑化について少し説明する必要があります。

ラプラススムージングとは何ですか?

トレーニングデータでは、fubarという単語が一部の非スパムメッセージに表示されているが、どのスパムメッセージにも表示されていないと仮定します。次に、単純ベイズ分類器は、fubarソース)という単語を含むすべてのメッセージにスパムの確率0を割り当てます。

オンラインデートでの私の成功について話しているのでない限り、イベントがまだ発生していないという理由だけで、イベントに確率0を割り当てるのは賢明ではありません。

LaplaceSmoothingと入力します。これは追加のテクニックです

各トークン(ソース)の観測数に。これを数学的に見てみましょう。ラプラススムージングがないと、スパムメッセージに単語wが表示される確率は次のようになります。

ラプラススムージングでは、次のようになります。

分類器に戻るstruct

pub struct NaiveBayesClassifier {
    pub alpha: f64,
    pub tokens: HashSet<String>,
    pub token_ham_counts: HashMap<String, i32>,
    pub token_spam_counts: HashMap<String, i32>,
    pub spam_messages_count: i32,
    pub ham_messages_count: i32,
}

の実装ブロックはNaiveBayesClassifiertrainメソッドとメソッドを中心にしていpredictます。

分類器をトレーニングする方法

このtrainメソッドは、Messagesのスライスを取り込んで、それぞれをループしMessage、次のようにします。

  • メッセージがスパムであるかどうかを確認し、spam_messages_countそれにham_messages_count応じて更新します。このためのヘルパー関数increment_message_classifications_countを作成します。
  • メッセージの内容をtokenize関数でトークン化します。
  • メッセージ内の各トークンをループし、次のことを行います。
  • トークンをに挿入してtokens HashSetから、token_spam_countsまたはを更新しtoken_ham_countsます。このためのヘルパー関数increment_token_countを作成します。

これが私たちのtrainメソッドの擬似コードです。以下の私の実装を見る前に、疑似コードをRustに変換してみてください。私にあなたの実装を送ることを躊躇しないでください、私はそれを見たいです!

implementation block for NaiveBayesClassifier {

	train(self, messages) {
		for each message in messages {
			self.increment_message_classifications_count(message)
			
			lowercase_text = to_lowercase(message.text)
			for each token in tokenize(lowercase_text) {
				self.tokens.insert(tokens)
				self.increment_token_count(token, message.is_spam)
			}			
		}
	}

	increment_message_classifications_count(self, message) {
		if message.is_spam {
			self.spam_messages_count = self.spam_messages_count + 1
		} else {
			self.ham_messages_count = self.ham_messages_count + 1
		}
	}

	increment_token_count(&mut self, token, is_spam) {
		if token is not a key of self.token_spam_counts {
			insert record with key=token and value=0 into self.token_spam_counts
		}

		if token is not a key of self.token_ham_counts {
			insert record with key=token and value=0 into self.token_ham_counts
		}

		if is_spam {
			self.token_spam_counts[token] = self.token_spam_counts[token] + 1
		} else {
			self.token_ham_counts[token] = self.token_ham_counts[token] + 1
		}
	}

}

そして、これがRustの実装です。

impl NaiveBayesClassifier {
    pub fn train(&mut self, messages: &[Message]) {
        for message in messages.iter() {
            self.increment_message_classifications_count(message);
            for token in tokenize(&message.text.to_lowercase()) {
                self.tokens.insert(token.to_string());
                self.increment_token_count(token, message.is_spam)
            }
        }
    }

    fn increment_message_classifications_count(&mut self, message: &Message) {
        if message.is_spam {
            self.spam_messages_count += 1;
        } else {
            self.ham_messages_count += 1;
        }
    }

    fn increment_token_count(&mut self, token: &str, is_spam: bool) {
        if !self.token_spam_counts.contains_key(token) {
            self.token_spam_counts.insert(token.to_string(), 0);
        }

        if !self.token_ham_counts.contains_key(token) {
            self.token_ham_counts.insert(token.to_string(), 0);
        }

        if is_spam {
            self.increment_spam_count(token);
        } else {
            self.increment_ham_count(token);
        }
    }

    fn increment_spam_count(&mut self, token: &str) {
        *self.token_spam_counts.get_mut(token).unwrap() += 1;
    }

    fn increment_ham_count(&mut self, token: &str) {
        *self.token_ham_counts.get_mut(token).unwrap() += 1;
    }
}

aの値をインクリメントするのHashMapはかなり面倒であることに注意してください。初心者のRustプログラマーは、何を理解するのが難しいでしょう。

*self.token_spam_counts.get_mut(token).unwrap() += 1

やっています。

コードをより明確にするために、関数increment_spam_countincrement_ham_count関数を作成しました。しかし、私はそれに満足していません–それでも面倒な感じがします。より良いアプローチの提案があれば、私に連絡してください。

分類器で予測する方法

このpredictメソッドは文字列スライスを取得し、モデルで計算されたスパムの確率を返します。

2つのヘルパー関数probabilities_of_messageを作成probabilites_of_tokenし、の手間のかかる作業を行いますpredict

probabilities_of_messageP(メッセージ|スパム)P(メッセージ|ハム)を返します。

probabilities_of_tokenP(Token | Spam)P(Token | ham)を返します。

入力メッセージがスパムである確率を計算するには、スパムメッセージで発生する各単語の確率を乗算する必要があります。

確率は0から1までの浮動小数点数であるため、多くの確率を掛け合わせるとアンダーフローソース)が発生する可能性があります。これは、操作の結果、コンピューターが正確に格納できる数よりも少ない数になる場合です(ここここを参照)。したがって、対数と指数を使用して、タスクを一連の加算に変換します。

実数abの場合これを行うことができます

もう一度、predictメソッドの擬似コードから始めます。

implementation block for NaiveBayesCalssifier {
	/*...*/

	predict(self, text) {
		lower_case_text = to_lowercase(text)
		message_tokens = tokenize(text)
		(prob_if_spam, prob_if_ham) = self.probabilities_of_message(message_tokens)
		return prob_if_spam / (prob_if_spam + prob_if_ham)
	}
	
	probabilities_of_message(self, message_tokens) {
		log_prob_if_spam = 0
		log_prob_if_ham = 0

		for each token in self.tokens {
			(prob_if_spam, prob_if_ham) = self.probabilites_of_token(token)

			if message_tokens contains token {
				log_prob_if_spam = log_prob_if_spam + ln(prob_if_spam)
				log_prob_if_ham = log_prob_if_ham + ln(prob_if_ham)
			} else {
				log_prob_if_spam = log_prob_if_spam + ln(1 - prob_if_spam)
				log_prob_if_ham = log_prob_if_ham + ln(1 - prob_if_ham)
			}
		}

		prob_if_spam = exp(log_prob_if_spam)
		prob_if_ham = exp(log_prob_if_ham)

		return (prob_if_spam, prob_if_ham)
	}

	probabilites_of_token(self, token) {
		prob_of_token_spam = (self.token_spam_counts[token] + self.alpha) 
						/ (self.spam_messages_count + 2 * self.alpha)
        
		prob_of_token_ham = (self.token_ham_counts[token] + self.alpha) 
						/ (self.ham_messages_count + 2 * self.alpha)

		return (prob_of_token_spam, prob_of_token_ham)
	}
	
	
}

そして、これがRustコードです。

impl NaiveBayesClassifier {

		/*...*/

	pub fn predict(&self, text: &str) -> f64 {
        let lower_case_text = text.to_lowercase();
        let message_tokens = tokenize(&lower_case_text);
        let (prob_if_spam, prob_if_ham) = self.probabilities_of_message(message_tokens);

        return prob_if_spam / (prob_if_spam + prob_if_ham);
    }

    fn probabilities_of_message(&self, message_tokens: HashSet<&str>) -> (f64, f64) {
        let mut log_prob_if_spam = 0.;
        let mut log_prob_if_ham = 0.;

        for token in self.tokens.iter() {
            let (prob_if_spam, prob_if_ham) = self.probabilites_of_token(&token);

            if message_tokens.contains(token.as_str()) {
                log_prob_if_spam += prob_if_spam.ln();
                log_prob_if_ham += prob_if_ham.ln();
            } else {
                log_prob_if_spam += (1. - prob_if_spam).ln();
                log_prob_if_ham += (1. - prob_if_ham).ln();
            }
        }

        let prob_if_spam = log_prob_if_spam.exp();
        let prob_if_ham = log_prob_if_ham.exp();

        return (prob_if_spam, prob_if_ham);
    }

    fn probabilites_of_token(&self, token: &str) -> (f64, f64) {
        let prob_of_token_spam = (self.token_spam_counts[token] as f64 + self.alpha)
            / (self.spam_messages_count as f64 + 2. * self.alpha);

        let prob_of_token_ham = (self.token_ham_counts[token] as f64 + self.alpha)
            / (self.ham_messages_count as f64 + 2. * self.alpha);

        return (prob_of_token_spam, prob_of_token_ham);
    }
}

分類器をテストする方法

モデルをテストしてみましょう。以下のテストでは、Naive Bayesを手動で実行し、モデルで同じ結果が得られることを確認します。

テストのロジックを確認する価値があると思うかもしれません。あるいは、コードをlib.rsファイルの最後に貼り付けて、コードが機能することを確認したい場合もあります。

// ...lib.rs

pub fn new_classifier(alpha: f64) -> NaiveBayesClassifier {
    return NaiveBayesClassifier {
        alpha,
        tokens: HashSet::new(),
        token_ham_counts: HashMap::new(),
        token_spam_counts: HashMap::new(),
        spam_messages_count: 0,
        ham_messages_count: 0,
    };
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn naive_bayes() {
        let train_messages = [
            Message {
                text: "Free Bitcoin viagra XXX christmas deals 😻😻😻",
                is_spam: true,
            },
            Message {
                text: "My dear Granddaughter, please explain Bitcoin over Christmas dinner",
                is_spam: false,
            },
            Message {
                text: "Here in my garage...",
                is_spam: true,
            },
        ];

        let alpha = 1.;
        let num_spam_messages = 2.;
        let num_ham_messages = 1.;

        let mut model = new_classifier(alpha);
        model.train(&train_messages);

        let mut expected_tokens: HashSet<String> = HashSet::new();
        for message in train_messages.iter() {
            for token in tokenize(&message.text.to_lowercase()) {
                expected_tokens.insert(token.to_string());
            }
        }

        let input_text = "Bitcoin crypto academy Christmas deals";

        let probs_if_spam = [
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "Free"  (not present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "Bitcoin"  (present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "viagra"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "XXX"  (not present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "christmas"  (present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "deals"  (present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "my"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "dear"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "granddaughter"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "please"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "explain"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "over"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "dinner"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "here"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "in"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "garage"  (not present)
        ];

        let probs_if_ham = [
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "Free"  (not present)
            (1. + alpha) / (num_ham_messages + 2. * alpha),      // "Bitcoin"  (present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "viagra"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "XXX"  (not present)
            (1. + alpha) / (num_ham_messages + 2. * alpha),      // "christmas"  (present)
            (0. + alpha) / (num_ham_messages + 2. * alpha),      // "deals"  (present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "my"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "dear"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "granddaughter"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "please"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "explain"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "over"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "dinner"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "here"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "in"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "garage"  (not present)
        ];

        let p_if_spam_log: f64 = probs_if_spam.iter().map(|p| p.ln()).sum();
        let p_if_spam = p_if_spam_log.exp();

        let p_if_ham_log: f64 = probs_if_ham.iter().map(|p| p.ln()).sum();
        let p_if_ham = p_if_ham_log.exp();

        // P(message | spam) / (P(messge | spam) + P(message | ham)) rounds to 0.97
        assert!((model.predict(input_text) - p_if_spam / (p_if_spam + p_if_ham)).abs() < 0.000001);
    }
}

次に実行しcargo testます。それがうまくいけば、Rustに単純ベイズ分類器を実装したことになります。

友達と一緒にコーディングしてくれてありがとう。ご不明な点やご提案がございましたら、お気軽にお問い合わせください。 

リンク:https//www.freecodecamp.org/news/implement-naive-bayes-with-rust/

#rust 

What is GEEK

Buddha Community

Rustを使用して単純ベイズ分類器を実装する
渚  直樹

渚 直樹

1641468732

Rustを使用して単純ベイズ分類器を実装する

私はRustスキルを向上させ、あなたもあなたのスキルを磨く手助けをしたいと思っています。そこで、Rustプログラミング言語に関する一連の記事を書くことにしました。

Rustを使って実際にものを構築することで、その過程でのさまざまな技術的概念について学びます。この記事では、Rustを使用して単純ベイズ分類器を実装する方法を学習します。

この記事では、なじみのない用語や概念に遭遇する可能性があります。落胆しないでください。時間があればこれらを調べてください。ただし、それでも、この記事の主なアイデアが失われることはありません。

単純ベイズ分類器とは何ですか?

単純ベイズ分類器は、ベイズの定理に基づく機械学習アルゴリズムです。ベイズの定理は、いくつかのデータDが与えられた場合に、仮説Hの確率を更新する方法を提供します。

数学的に表現すると、次のようになります。

ここで、P(H | D)= Dが与えられたHの確率。

我々は、AC場合cは、より多くのデータをumulate、我々は更新することができる| P(D H)に応じています。

単純ベイズモデルは、データポイントがデータセットに存在するか存在しないかは、そのセット(ソース)にすでに存在するデータから独立しているという大きな仮定に基づいています。つまり、各データは他のデータポイントに関する情報を伝達しません。

この仮定が正しいとは期待していません–それは弱いです。しかし、それでも有用であり、非常にうまく機能する効率的な分類子を作成できます(ソース)。

ナイーブベイズの説明はそこに残しておきます。もっと多くのことが言えますが、この記事の要点はRustを練習することです。

アルゴリズムについて詳しく知りたい場合は、次のリソースをご覧ください。

単純ベイズ分類器の標準的なアプリケーションは、スパム分類器です。それが私たちが構築するものです。ここですべてのコードを見つけることができます:https//github.com/josht-jpg/shaking-off-the-rust

まず、Cargoを使用して新しいライブラリを作成します。

cargo new naive_bayes --lib
cd naive_bayes

それでは、詳しく見ていきましょう。

Rustでのトークン化

私たちの分類器は、入力としてメッセージを受け取り、スパムまたは非スパムの分類を返します。

与えられたメッセージを処理するために、それをトークン化する必要があります。トークン化された表現は、順序と繰り返しのエントリが無視される小文字の単語のセットになります。Rustのstd::collections::HashSet構造は、これを実現するための優れた方法です。

プリフォームトークン化に書き込む関数では、正規表現クレートを使用する必要があります。次の依存関係をCargo.tomlファイルに含めるようにしてください。

[dependencies]
regex = "^1.5.4"

そして、これがtokenize機能です:

// lib.rs

// We'll need HashMap later
use std::collections::{HashMap, HashSet};

extern crate regex;
use regex::Regex;

pub fn tokenize(lower_case_text: &str) -> HashSet<&str> {
    Regex::new(r"[a-z0-9']+")
        .unwrap()
        .find_iter(lower_case_text)
        .map(|mat| mat.as_str())
        .collect()
}

この関数は、正規表現を使用して、すべての数値と小文字を照合します。異なるタイプの記号(多くの場合、空白または句読点)に遭遇するたびに、入力を分割し、最後の分割以降に検出されたすべての数字と文字をグループ化します(正規表現の詳細については、Rustを参照してください)。つまり、入力テキスト内の単語を識別して分離しています。

いくつかの便利な構造

structメッセージを表すためにを使用すると便利です。これにstructは、メッセージのテキストの文字列スライスと、メッセージがスパムであるかどうかを示すブール値が含まれます。

pub struct Message<'a> {
    pub text: &'a str,
    pub is_spam: bool,
}

'a寿命パラメータ注釈です。生涯に慣れておらず、生涯について知りたい場合は、Rustプログラミング言語の本のセクション10.3を読むことをお勧めします。

Astructは、分類子を表すのにも役立ちます。を作成する前にstruct、ラプラシアン平滑化について少し説明する必要があります。

ラプラススムージングとは何ですか?

トレーニングデータでは、fubarという単語が一部の非スパムメッセージに表示されているが、どのスパムメッセージにも表示されていないと仮定します。次に、単純ベイズ分類器は、fubarソース)という単語を含むすべてのメッセージにスパムの確率0を割り当てます。

オンラインデートでの私の成功について話しているのでない限り、イベントがまだ発生していないという理由だけで、イベントに確率0を割り当てるのは賢明ではありません。

LaplaceSmoothingと入力します。これは追加のテクニックです

各トークン(ソース)の観測数に。これを数学的に見てみましょう。ラプラススムージングがないと、スパムメッセージに単語wが表示される確率は次のようになります。

ラプラススムージングでは、次のようになります。

分類器に戻るstruct

pub struct NaiveBayesClassifier {
    pub alpha: f64,
    pub tokens: HashSet<String>,
    pub token_ham_counts: HashMap<String, i32>,
    pub token_spam_counts: HashMap<String, i32>,
    pub spam_messages_count: i32,
    pub ham_messages_count: i32,
}

の実装ブロックはNaiveBayesClassifiertrainメソッドとメソッドを中心にしていpredictます。

分類器をトレーニングする方法

このtrainメソッドは、Messagesのスライスを取り込んで、それぞれをループしMessage、次のようにします。

  • メッセージがスパムであるかどうかを確認し、spam_messages_countそれにham_messages_count応じて更新します。このためのヘルパー関数increment_message_classifications_countを作成します。
  • メッセージの内容をtokenize関数でトークン化します。
  • メッセージ内の各トークンをループし、次のことを行います。
  • トークンをに挿入してtokens HashSetから、token_spam_countsまたはを更新しtoken_ham_countsます。このためのヘルパー関数increment_token_countを作成します。

これが私たちのtrainメソッドの擬似コードです。以下の私の実装を見る前に、疑似コードをRustに変換してみてください。私にあなたの実装を送ることを躊躇しないでください、私はそれを見たいです!

implementation block for NaiveBayesClassifier {

	train(self, messages) {
		for each message in messages {
			self.increment_message_classifications_count(message)
			
			lowercase_text = to_lowercase(message.text)
			for each token in tokenize(lowercase_text) {
				self.tokens.insert(tokens)
				self.increment_token_count(token, message.is_spam)
			}			
		}
	}

	increment_message_classifications_count(self, message) {
		if message.is_spam {
			self.spam_messages_count = self.spam_messages_count + 1
		} else {
			self.ham_messages_count = self.ham_messages_count + 1
		}
	}

	increment_token_count(&mut self, token, is_spam) {
		if token is not a key of self.token_spam_counts {
			insert record with key=token and value=0 into self.token_spam_counts
		}

		if token is not a key of self.token_ham_counts {
			insert record with key=token and value=0 into self.token_ham_counts
		}

		if is_spam {
			self.token_spam_counts[token] = self.token_spam_counts[token] + 1
		} else {
			self.token_ham_counts[token] = self.token_ham_counts[token] + 1
		}
	}

}

そして、これがRustの実装です。

impl NaiveBayesClassifier {
    pub fn train(&mut self, messages: &[Message]) {
        for message in messages.iter() {
            self.increment_message_classifications_count(message);
            for token in tokenize(&message.text.to_lowercase()) {
                self.tokens.insert(token.to_string());
                self.increment_token_count(token, message.is_spam)
            }
        }
    }

    fn increment_message_classifications_count(&mut self, message: &Message) {
        if message.is_spam {
            self.spam_messages_count += 1;
        } else {
            self.ham_messages_count += 1;
        }
    }

    fn increment_token_count(&mut self, token: &str, is_spam: bool) {
        if !self.token_spam_counts.contains_key(token) {
            self.token_spam_counts.insert(token.to_string(), 0);
        }

        if !self.token_ham_counts.contains_key(token) {
            self.token_ham_counts.insert(token.to_string(), 0);
        }

        if is_spam {
            self.increment_spam_count(token);
        } else {
            self.increment_ham_count(token);
        }
    }

    fn increment_spam_count(&mut self, token: &str) {
        *self.token_spam_counts.get_mut(token).unwrap() += 1;
    }

    fn increment_ham_count(&mut self, token: &str) {
        *self.token_ham_counts.get_mut(token).unwrap() += 1;
    }
}

aの値をインクリメントするのHashMapはかなり面倒であることに注意してください。初心者のRustプログラマーは、何を理解するのが難しいでしょう。

*self.token_spam_counts.get_mut(token).unwrap() += 1

やっています。

コードをより明確にするために、関数increment_spam_countincrement_ham_count関数を作成しました。しかし、私はそれに満足していません–それでも面倒な感じがします。より良いアプローチの提案があれば、私に連絡してください。

分類器で予測する方法

このpredictメソッドは文字列スライスを取得し、モデルで計算されたスパムの確率を返します。

2つのヘルパー関数probabilities_of_messageを作成probabilites_of_tokenし、の手間のかかる作業を行いますpredict

probabilities_of_messageP(メッセージ|スパム)P(メッセージ|ハム)を返します。

probabilities_of_tokenP(Token | Spam)P(Token | ham)を返します。

入力メッセージがスパムである確率を計算するには、スパムメッセージで発生する各単語の確率を乗算する必要があります。

確率は0から1までの浮動小数点数であるため、多くの確率を掛け合わせるとアンダーフローソース)が発生する可能性があります。これは、操作の結果、コンピューターが正確に格納できる数よりも少ない数になる場合です(ここここを参照)。したがって、対数と指数を使用して、タスクを一連の加算に変換します。

実数abの場合これを行うことができます

もう一度、predictメソッドの擬似コードから始めます。

implementation block for NaiveBayesCalssifier {
	/*...*/

	predict(self, text) {
		lower_case_text = to_lowercase(text)
		message_tokens = tokenize(text)
		(prob_if_spam, prob_if_ham) = self.probabilities_of_message(message_tokens)
		return prob_if_spam / (prob_if_spam + prob_if_ham)
	}
	
	probabilities_of_message(self, message_tokens) {
		log_prob_if_spam = 0
		log_prob_if_ham = 0

		for each token in self.tokens {
			(prob_if_spam, prob_if_ham) = self.probabilites_of_token(token)

			if message_tokens contains token {
				log_prob_if_spam = log_prob_if_spam + ln(prob_if_spam)
				log_prob_if_ham = log_prob_if_ham + ln(prob_if_ham)
			} else {
				log_prob_if_spam = log_prob_if_spam + ln(1 - prob_if_spam)
				log_prob_if_ham = log_prob_if_ham + ln(1 - prob_if_ham)
			}
		}

		prob_if_spam = exp(log_prob_if_spam)
		prob_if_ham = exp(log_prob_if_ham)

		return (prob_if_spam, prob_if_ham)
	}

	probabilites_of_token(self, token) {
		prob_of_token_spam = (self.token_spam_counts[token] + self.alpha) 
						/ (self.spam_messages_count + 2 * self.alpha)
        
		prob_of_token_ham = (self.token_ham_counts[token] + self.alpha) 
						/ (self.ham_messages_count + 2 * self.alpha)

		return (prob_of_token_spam, prob_of_token_ham)
	}
	
	
}

そして、これがRustコードです。

impl NaiveBayesClassifier {

		/*...*/

	pub fn predict(&self, text: &str) -> f64 {
        let lower_case_text = text.to_lowercase();
        let message_tokens = tokenize(&lower_case_text);
        let (prob_if_spam, prob_if_ham) = self.probabilities_of_message(message_tokens);

        return prob_if_spam / (prob_if_spam + prob_if_ham);
    }

    fn probabilities_of_message(&self, message_tokens: HashSet<&str>) -> (f64, f64) {
        let mut log_prob_if_spam = 0.;
        let mut log_prob_if_ham = 0.;

        for token in self.tokens.iter() {
            let (prob_if_spam, prob_if_ham) = self.probabilites_of_token(&token);

            if message_tokens.contains(token.as_str()) {
                log_prob_if_spam += prob_if_spam.ln();
                log_prob_if_ham += prob_if_ham.ln();
            } else {
                log_prob_if_spam += (1. - prob_if_spam).ln();
                log_prob_if_ham += (1. - prob_if_ham).ln();
            }
        }

        let prob_if_spam = log_prob_if_spam.exp();
        let prob_if_ham = log_prob_if_ham.exp();

        return (prob_if_spam, prob_if_ham);
    }

    fn probabilites_of_token(&self, token: &str) -> (f64, f64) {
        let prob_of_token_spam = (self.token_spam_counts[token] as f64 + self.alpha)
            / (self.spam_messages_count as f64 + 2. * self.alpha);

        let prob_of_token_ham = (self.token_ham_counts[token] as f64 + self.alpha)
            / (self.ham_messages_count as f64 + 2. * self.alpha);

        return (prob_of_token_spam, prob_of_token_ham);
    }
}

分類器をテストする方法

モデルをテストしてみましょう。以下のテストでは、Naive Bayesを手動で実行し、モデルで同じ結果が得られることを確認します。

テストのロジックを確認する価値があると思うかもしれません。あるいは、コードをlib.rsファイルの最後に貼り付けて、コードが機能することを確認したい場合もあります。

// ...lib.rs

pub fn new_classifier(alpha: f64) -> NaiveBayesClassifier {
    return NaiveBayesClassifier {
        alpha,
        tokens: HashSet::new(),
        token_ham_counts: HashMap::new(),
        token_spam_counts: HashMap::new(),
        spam_messages_count: 0,
        ham_messages_count: 0,
    };
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn naive_bayes() {
        let train_messages = [
            Message {
                text: "Free Bitcoin viagra XXX christmas deals 😻😻😻",
                is_spam: true,
            },
            Message {
                text: "My dear Granddaughter, please explain Bitcoin over Christmas dinner",
                is_spam: false,
            },
            Message {
                text: "Here in my garage...",
                is_spam: true,
            },
        ];

        let alpha = 1.;
        let num_spam_messages = 2.;
        let num_ham_messages = 1.;

        let mut model = new_classifier(alpha);
        model.train(&train_messages);

        let mut expected_tokens: HashSet<String> = HashSet::new();
        for message in train_messages.iter() {
            for token in tokenize(&message.text.to_lowercase()) {
                expected_tokens.insert(token.to_string());
            }
        }

        let input_text = "Bitcoin crypto academy Christmas deals";

        let probs_if_spam = [
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "Free"  (not present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "Bitcoin"  (present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "viagra"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "XXX"  (not present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "christmas"  (present)
            (1. + alpha) / (num_spam_messages + 2. * alpha),      // "deals"  (present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "my"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "dear"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "granddaughter"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "please"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "explain"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "over"  (not present)
            1. - (0. + alpha) / (num_spam_messages + 2. * alpha), // "dinner"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "here"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "in"  (not present)
            1. - (1. + alpha) / (num_spam_messages + 2. * alpha), // "garage"  (not present)
        ];

        let probs_if_ham = [
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "Free"  (not present)
            (1. + alpha) / (num_ham_messages + 2. * alpha),      // "Bitcoin"  (present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "viagra"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "XXX"  (not present)
            (1. + alpha) / (num_ham_messages + 2. * alpha),      // "christmas"  (present)
            (0. + alpha) / (num_ham_messages + 2. * alpha),      // "deals"  (present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "my"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "dear"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "granddaughter"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "please"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "explain"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "over"  (not present)
            1. - (1. + alpha) / (num_ham_messages + 2. * alpha), // "dinner"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "here"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "in"  (not present)
            1. - (0. + alpha) / (num_ham_messages + 2. * alpha), // "garage"  (not present)
        ];

        let p_if_spam_log: f64 = probs_if_spam.iter().map(|p| p.ln()).sum();
        let p_if_spam = p_if_spam_log.exp();

        let p_if_ham_log: f64 = probs_if_ham.iter().map(|p| p.ln()).sum();
        let p_if_ham = p_if_ham_log.exp();

        // P(message | spam) / (P(messge | spam) + P(message | ham)) rounds to 0.97
        assert!((model.predict(input_text) - p_if_spam / (p_if_spam + p_if_ham)).abs() < 0.000001);
    }
}

次に実行しcargo testます。それがうまくいけば、Rustに単純ベイズ分類器を実装したことになります。

友達と一緒にコーディングしてくれてありがとう。ご不明な点やご提案がございましたら、お気軽にお問い合わせください。 

リンク:https//www.freecodecamp.org/news/implement-naive-bayes-with-rust/

#rust