Eksploracja Danych: Laboratorium 7

Piszemy własny klasyfikator Naive Bayes (wariant gaussowski – każda cecha w obrębie klasy modelowana rozkładem normalnym) w języku Python.

import numpy as np
 
class NaiveBayesClassifier:
    """Gaussian Naive Bayes classifier.

    Within each class, every feature is modelled independently as a normal
    distribution whose mean and variance are estimated from the training data.
    Prediction picks the class with the highest log posterior
    (log prior + sum of per-feature Gaussian log densities).
    """

    def __init__(self, var_smoothing=1e-9):
        """
        Parameters:
            -- var_smoothing: float, default 1e-9
                Fraction of the largest per-feature variance (computed over
                the whole training set) that is added to every class variance.
                This keeps variances strictly positive, so a feature that is
                constant within a class does not cause log(0) or a division
                by zero in predict().
        """
        self.var_smoothing = var_smoothing
        self.classes = None     # distinct labels, sorted (output of np.unique)
        self.priors = None      # shape (n_classes,): relative class frequencies
        self.means = None       # shape (n_classes, n_features)
        self.variances = None   # shape (n_classes, n_features), smoothed

    def fit(self, X, y):
        """
        Fit the Naive Bayes classifier to the training data.

        Parameters:
            -- X: array-like, shape (n_samples, n_features)
                Training data features
            -- y: array-like, shape (n_samples,)
                Training data labels (any type accepted by np.unique)

        Returns:
            -- self: the fitted classifier (allows chaining)
        """
        # Coerce inputs: with a plain Python list, `y == c` would compare the
        # whole list to a scalar and yield a single False instead of a mask.
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)

        n_samples, n_features = X.shape
        self.classes = np.unique(y)
        n_classes = len(self.classes)

        self.priors = np.zeros(n_classes)
        self.means = np.zeros((n_classes, n_features))
        self.variances = np.zeros((n_classes, n_features))

        for i, c in enumerate(self.classes):
            X_c = X[y == c]
            self.priors[i] = len(X_c) / n_samples
            self.means[i] = X_c.mean(axis=0)
            self.variances[i] = X_c.var(axis=0)

        # Variance smoothing (same idea as sklearn's GaussianNB): scale by the
        # data's own spread so the epsilon is negligible relative to real
        # variances; fall back to the raw factor if every feature is constant.
        epsilon = self.var_smoothing * np.var(X, axis=0).max()
        if epsilon == 0:
            epsilon = self.var_smoothing
        self.variances += epsilon
        return self

    def predict(self, X):
        """
        Predict the class labels for new data.

        Parameters:
            -- X: array-like, shape (n_samples, n_features)
                Test data features

        Returns:
            -- y_pred: numpy array, shape (n_samples,)
                Predicted class labels (values taken from the training labels)
        """
        X = np.asarray(X, dtype=float)
        n_samples = X.shape[0]
        n_classes = self.means.shape[0]

        log_posteriors = np.empty((n_samples, n_classes))

        for i in range(n_classes):
            var = self.variances[i]
            # Normalisation term of the Gaussian density; it does not depend
            # on the sample, so it is a single scalar per class.
            log_norm = -0.5 * np.sum(np.log(2 * np.pi * var))
            # Squared-distance term, summed over features (naive independence).
            log_dist = -0.5 * np.sum((X - self.means[i]) ** 2 / var, axis=1)
            log_posteriors[:, i] = np.log(self.priors[i]) + log_norm + log_dist

        # argmax yields a column index; map it back to the actual class label.
        return self.classes[np.argmax(log_posteriors, axis=1)]
med/lab_7ed.txt · Last modified: 2023/04/20 12:22 by pszwed
CC Attribution-Share Alike 4.0 International
Driven by DokuWiki Recent changes RSS feed Valid CSS Valid XHTML 1.0