verticapy.machine_learning.vertica.naive_bayes.NaiveBayes#
- class verticapy.machine_learning.vertica.naive_bayes.NaiveBayes(name: str = None, overwrite_model: bool = False, alpha: int | float | Decimal = 1.0, nbtype: Literal['auto', 'bernoulli', 'categorical', 'multinomial', 'gaussian'] = 'auto')#
Creates a NaiveBayes object using the Vertica Naive Bayes algorithm. It is a “probabilistic classifier” based on applying Bayes’ theorem with strong (naïve) independence assumptions between the features.

Parameters#
- name: str, optional
Name of the model. The model is stored in the database.
- overwrite_model: bool, optional
If set to True, training a model with the same name as an existing model overwrites the existing model.
- alpha: float, optional
A float that specifies the use of Laplace smoothing if the event model is categorical, multinomial, or Bernoulli.
- nbtype: str, optional
Naive Bayes type.
- auto:
Vertica NaiveBayes objects treat columns according to data type:
- FLOAT:
values are assumed to follow some Gaussian distribution.
- INTEGER:
values are assumed to belong to one multinomial distribution.
- CHAR/VARCHAR:
values are assumed to follow some categorical distribution. The string values stored in these columns must be no greater than 128 characters.
- BOOLEAN:
values are treated as categorical with two values.
- bernoulli:
Casts the variables to boolean.
- categorical:
Casts the variables to categorical.
- multinomial:
Casts the variables to integer.
- gaussian:
Casts the variables to float.
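To illustrate what the alpha parameter does conceptually: with Laplace smoothing, a categorical value seen c times in a class with n observations and k distinct observed values gets probability (c + alpha) / (n + alpha * k). A minimal pure-Python sketch, not verticapy code (the function name is hypothetical):

```python
from collections import Counter

def smoothed_likelihoods(values, alpha=1.0):
    """Laplace-smoothed P(value | class) estimates for one categorical
    feature within a single class. `alpha` plays the same role as the
    NaiveBayes `alpha` parameter."""
    counts = Counter(values)
    k = len(counts)   # number of distinct observed categories
    n = len(values)   # total observations in this class
    return {v: (c + alpha) / (n + alpha * k) for v, c in counts.items()}

# With alpha = 1, a value seen 3 times out of 6 (2 categories) gets
# probability (3 + 1) / (6 + 2) = 0.5
probs = smoothed_likelihoods(["a", "a", "a", "b", "b", "b"], alpha=1.0)
```

Smoothing prevents zero probabilities for rare value/class combinations, which would otherwise zero out the whole product of likelihoods.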
Attributes#
Many attributes are created during the fitting phase.
- prior_: numpy.array
The model's class probabilities.
- attributes: list of dict
List of the model's attributes. Each feature is represented by a dictionary, which differs based on the distribution.
- classes_: numpy.array
The class labels.
Note
All attributes can be accessed using the get_attributes() method.

Note
Several other attributes can be accessed by using the get_vertica_attributes() method.

Examples#
The following examples provide a basic understanding of usage. For more detailed examples, please refer to the Machine Learning section or the Examples section on the website.
Load data for machine learning#
We import verticapy:

import verticapy as vp
Hint
By assigning an alias to verticapy, we mitigate the risk of code collisions with other libraries. This precaution is necessary because verticapy uses commonly known function names like "average" and "median", which can potentially lead to naming conflicts. The use of an alias ensures that the functions from verticapy are used as intended without interfering with functions from other libraries.

For this example, we will use the iris dataset.
import verticapy.datasets as vpd

data = vpd.load_iris()
[Output: the iris vDataFrame with columns SepalLengthCm, SepalWidthCm, PetalLengthCm, PetalWidthCm (Numeric(7)) and Species (Varchar(30)). Rows: 1-100 | Columns: 5]

Note
VerticaPy offers a wide range of sample datasets that are ideal for training and testing purposes. You can explore the full list of available datasets in the Datasets section, which provides detailed information on each dataset and how to use them effectively. These datasets are invaluable resources for honing your data analysis and machine learning skills within the VerticaPy environment.
You can easily divide your dataset into training and testing subsets using the vDataFrame.train_test_split() method. This is a crucial step when preparing your data for machine learning, as it allows you to evaluate the performance of your models accurately.

data = vpd.load_iris()
train, test = data.train_test_split(test_size = 0.2)
Warning
In this case, VerticaPy utilizes seeded randomization to guarantee the reproducibility of your data split. However, please be aware that this approach may lead to reduced performance. For a more efficient data split, you can use the vDataFrame.to_db() method to save your results into tables or temporary tables. This will help enhance the overall performance of the process.

Balancing the Dataset#
In VerticaPy, balancing a dataset to address class imbalances is made straightforward through the balance() function within the preprocessing module. This function enables users to rectify skewed class distributions efficiently. By specifying the target variable and setting parameters like the method for balancing, users can effortlessly achieve a more equitable representation of classes in their dataset. Whether opting for over-sampling, under-sampling, or a combination of both, VerticaPy's balance() function streamlines the process, empowering users to enhance the performance and fairness of their machine learning models trained on imbalanced data.

To balance the dataset, use the following syntax.

from verticapy.machine_learning.vertica.preprocessing import balance

balanced_train = balance(
    name = "my_schema.train_balanced",
    input_relation = train,
    y = "Species",
    method = "hybrid",
)
Note
With this code, a table named train_balanced is created in the my_schema schema. It can then be used to train the model. In the rest of the example, we will work with the full dataset.
Hint
Balancing the dataset is a crucial step in improving the accuracy of machine learning models, particularly when faced with imbalanced class distributions. By addressing disparities in the number of instances across different classes, the model becomes more adept at learning patterns from all classes rather than being biased towards the majority class. This, in turn, enhances the model’s ability to make accurate predictions for under-represented classes. The balanced dataset ensures that the model is not dominated by the majority class and, as a result, leads to more robust and unbiased model performance. Therefore, by employing techniques such as over-sampling, under-sampling, or a combination of both during dataset preparation, practitioners can significantly contribute to achieving higher accuracy and better generalization of their machine learning models.
Model Initialization#
First we import the NaiveBayes model:

from verticapy.machine_learning.vertica import NaiveBayes
Then we can create the model:
model = NaiveBayes()
Hint
In verticapy 1.0.x and higher, you do not need to specify the model name, as the name is automatically assigned. If you need to re-use the model, you can fetch the model name from the model's attributes.

Important
The model name is crucial for the model management system and versioning. It’s highly recommended to provide a name if you plan to reuse the model later.
Model Training#
We can now fit the model:
model.fit(
    train,
    [
        "SepalLengthCm",
        "SepalWidthCm",
        "PetalLengthCm",
        "PetalWidthCm",
    ],
    "Species",
    test,
)
Important
To train a model, you can directly use the vDataFrame or the name of the relation stored in the database. The test set is optional and is only used to compute the test metrics. In verticapy, we don't work using X matrices and y vectors. Instead, we work directly with lists of predictors and the response name.

Metrics#
We can get the entire report using:
model.report()
metric        Iris-setosa  Iris-versicolor  Iris-virginica  avg_macro  avg_weighted  avg_micro
auc           1.0          0.9806           0.9917          0.9908     0.9935        [null]
prc_auc       1.0          0.8991           0.9912          0.9634     0.9833        [null]
accuracy      1.0          0.8980           0.8980          0.9320     0.9375        0.9320
log_loss      0.0026       0.0780           0.0809          0.0538     0.0502        [null]
precision     1.0          0.5556           0.9524          0.8360     0.9223        0.8980
recall        1.0          0.8333           0.8333          0.8889     0.8980        0.8980
f1_score      1.0          0.6667           0.8889          0.8519     0.9048        0.8980
mcc           1.0          0.6267           0.8014          0.8094     0.8570        0.8469
informedness  1.0          0.7403           0.7933          0.8445     0.8670        0.8469
markedness    1.0          0.5306           0.8095          0.7800     0.8492        0.8469
csi           1.0          0.5              0.8             0.7667     0.8408        0.8148
Rows: 1-11 | Columns: 7 (values rounded to 4 decimals)

Important
Most metrics are computed using a single SQL query, but some of them might require multiple SQL queries. Selecting only the necessary metrics in the report can help optimize performance. E.g. model.report(metrics = ["auc", "accuracy"]).

For classification models, we can easily modify the cutoff to observe the effect on different metrics:

model.report(cutoff = 0.2)
metric        Iris-setosa  Iris-versicolor  Iris-virginica  avg_macro  avg_weighted  avg_micro
auc           1.0          0.9806           0.9917          0.9908     0.9935        [null]
prc_auc       1.0          0.8991           0.9912          0.9634     0.9833        [null]
accuracy      1.0          0.9184           0.9388          0.9524     0.9600        0.9524
log_loss      0.0026       0.0780           0.0809          0.0538     0.0502        [null]
precision     1.0          0.6              0.9565          0.8522     0.9297        0.9038
recall        1.0          1.0              0.9167          0.9722     0.9592        0.9592
f1_score      1.0          0.75             0.9362          0.8954     0.9381        0.9307
mcc           1.0          0.7377           0.8781          0.8719     0.9082        0.8954
informedness  1.0          0.9070           0.8767          0.9279     0.9282        0.9082
markedness    1.0          0.6              0.8796          0.8265     0.8920        0.8828
csi           1.0          0.6              0.88            0.8267     0.8922        0.8704
Rows: 1-11 | Columns: 7 (values rounded to 4 decimals)

You can also use the NaiveBayes.score function to compute any classification metric. The default metric is the accuracy:

model.score(metric = "f1", average = "macro")
Out[4]: 0.851851851851852
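To make the averaging concrete: the macro average is the unweighted mean of the per-class scores, while the weighted average weights each class by its support. A pure-Python sketch, not verticapy code (the function names are illustrative):

```python
def macro_average(per_class_scores):
    """Macro averaging: unweighted mean of per-class scores,
    so every class counts equally regardless of its size."""
    return sum(per_class_scores) / len(per_class_scores)

def weighted_average(per_class_scores, class_counts):
    """Weighted averaging: mean of per-class scores weighted by
    each class's number of true instances (its support)."""
    total = sum(class_counts)
    return sum(s * n for s, n in zip(per_class_scores, class_counts)) / total

# Per-class F1 values copied by hand from the report above
f1 = [1.0, 0.6666666666666667, 0.888888888888889]
macro = macro_average(f1)  # ≈ 0.85185, the macro F1 reported above
```

Averaging the three per-class F1 values reproduces the macro F1 returned by model.score.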
Note
For multi-class scoring, verticapy allows the flexibility to use three averaging techniques: micro, macro and weighted. Please refer to this link for more details on how they are calculated.

Prediction#
Prediction is straightforward:
model.predict(
    test,
    [
        "SepalLengthCm",
        "SepalWidthCm",
        "PetalLengthCm",
        "PetalWidthCm",
    ],
    "prediction",
)
[Output: the test vDataFrame with the five original columns plus a new prediction column (Varchar(100)). Rows: 1-49 | Columns: 6]

Note
Predictions can be made automatically using the test set, in which case you don't need to specify the predictors. Alternatively, you can pass only the vDataFrame to the predict() function, but in this case, it's essential that the column names of the vDataFrame match the predictors and response name in the model.

Probabilities#
It is also easy to get the model’s probabilities:
model.predict_proba(
    test,
    [
        "SepalLengthCm",
        "SepalWidthCm",
        "PetalLengthCm",
        "PetalWidthCm",
    ],
    "prediction",
)
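Conceptually, each probability column is the class's joint score (prior times likelihoods) normalized across classes so that each row sums to 1, which is Bayes' rule with the shared evidence term cancelled. A pure-Python sketch of that normalization step, not verticapy code (the helper name is hypothetical):

```python
def normalize_scores(scores):
    """Turn per-class joint scores (prior * product of likelihoods)
    into posterior probabilities that sum to 1."""
    total = sum(scores.values())
    return {c: s / total for c, s in scores.items()}

# Illustrative joint scores for one row (made-up values)
proba = normalize_scores({"setosa": 8.0, "versicolor": 1.0, "virginica": 1.0})
# proba["setosa"] == 0.8
```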
[Output: the test vDataFrame with the prediction column plus one probability column per class: prediction_irissetosa, prediction_irisversicolor, prediction_irisvirginica. Rows: 1-49 | Columns: 9]

Note
Probabilities are added to the vDataFrame, and VerticaPy uses the corresponding probability function in SQL behind the scenes. You can use the pos_label parameter to add only the probability of the selected category.

Confusion Matrix#
You can obtain the confusion matrix.
model.confusion_matrix()
Out[5]: array([[19,  0,  0],
               [ 0,  5,  1],
               [ 0,  4, 20]])
Hint
In the context of multi-class classification, you typically work with an overall confusion matrix that summarizes the classification efficiency across all classes. However, you have the flexibility to specify a pos_label and adjust the cutoff threshold. In this case, a binary confusion matrix is computed, where the chosen class is treated as the positive class, allowing you to evaluate its efficiency as if it were a binary classification problem.

model.confusion_matrix(pos_label = "Iris-setosa", cutoff = 0.6)
Out[6]: array([[30,  0],
               [ 0, 19]])
Note
In classification, the cutoff is a threshold value used to determine class assignment based on predicted probabilities or scores from a classification model. In binary classification, if the predicted probability for a specific class is greater than or equal to the cutoff, the instance is assigned to the positive class; otherwise, it is assigned to the negative class. Adjusting the cutoff allows for trade-offs between true positives and false positives, enabling the model to be optimized for specific objectives or to consider the relative costs of different classification errors. The choice of cutoff is critical for tailoring the model's performance to meet specific needs.

Main Plots (Classification Curves)#
Classification models allow for the creation of various plots that are very helpful in understanding the model, such as the ROC Curve, PRC Curve, Cutoff Curve, Gain Curve, and more.
Most of the classification curves can be found in the Machine Learning - Classification Curve.
For example, let’s draw the model’s ROC curve.
model.roc_curve(pos_label = "Iris-setosa")
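Conceptually, an ROC curve is just a set of (false positive rate, true positive rate) points computed over a sweep of cutoffs. A pure-Python sketch, not verticapy code (the function name is hypothetical; its nbins argument plays the same role as the curves' nbins parameter):

```python
def roc_points(y_true, proba_pos, nbins=10):
    """(FPR, TPR) pairs for cutoffs 0, 1/nbins, ..., 1.
    `y_true` holds 1 for the positive class, 0 otherwise."""
    pos = sum(y_true)
    neg = len(y_true) - pos
    points = []
    for i in range(nbins + 1):
        cutoff = i / nbins
        tp = sum(1 for y, p in zip(y_true, proba_pos) if y and p >= cutoff)
        fp = sum(1 for y, p in zip(y_true, proba_pos) if not y and p >= cutoff)
        points.append((fp / neg, tp / pos))
    return points

# A perfectly separable toy example: the curve hugs the top-left corner
pts = roc_points([1, 1, 0, 0], [0.9, 0.8, 0.3, 0.1], nbins=4)
```

More bins give a finer sweep of cutoffs and a smoother curve, at proportionally higher cost.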
Important
Most of the curves have a parameter called nbins, which is essential for estimating metrics. The larger the nbins, the more precise the estimation, but it can significantly impact performance. Exercise caution when increasing this parameter excessively.

Hint
In binary classification, various curves can be easily plotted. However, in multi-class classification, it's important to select the pos_label, representing the class to be treated as positive when drawing the curve.

Other Plots#
The contour plot is another useful plot that can be produced for models with two predictors.
model.contour(pos_label = "Iris-setosa")
Important
Machine learning models with two predictors can usually benefit from their own contour plot. This visual representation aids in exploring predictions and gaining a deeper understanding of how these models perform in different scenarios. Please refer to Contour Plot for more examples.
Parameter Modification#
In order to see the parameters:
model.get_params()
Out[7]: {'alpha': 1.0, 'nbtype': 'auto'}
And to manually change some of the parameters:
model.set_params({'alpha': 0.9})
Model Register#
In order to register the model for tracking and versioning:
model.register("model_v1")
Please refer to Model Tracking and Versioning for more details on model tracking and versioning.
Model Exporting#
To MemModel
model.to_memmodel()
Note
MemModel objects serve as in-memory representations of machine learning models. They can be used for both in-database and in-memory prediction tasks. These objects can be pickled in the same way that you would pickle a scikit-learn model.

The following methods for exporting the model use MemModel, and it is recommended to use MemModel directly.

To SQL
You can get the SQL code by:
model.to_sql() Out[9]: 'CASE WHEN "SepalLengthCm" IS NULL OR "SepalWidthCm" IS NULL OR "PetalLengthCm" IS NULL OR "PetalWidthCm" IS NULL THEN NULL WHEN 0.3244726886220993 * EXP(- POWER("SepalLengthCm" - 5.41486486486486, 2) / 3.02338763420952) * 0.4491388607461844 * EXP(- POWER("SepalWidthCm" - 3.87027027027027, 2) / 1.577934098482068) * 0.19437108725076185 * EXP(- POWER("PetalLengthCm" - 7.6472972972973, 2) / 8.42532765642342) * 1.9328615063345225 * EXP(- POWER("PetalWidthCm" - 1.90135135135135, 2) / 0.085201777119588) * 0.375634517766497 >= 0.4505140594417424 * EXP(- POWER("SepalLengthCm" - 4.17594936708861, 2) / 1.568315481986386) * 0.6362690709840972 * EXP(- POWER("SepalWidthCm" - 3.92405063291139, 2) / 0.786264199935092) * 0.1921613470484011 * EXP(- POWER("PetalLengthCm" - 3.46075949367089, 2) / 8.6202142161636) * 0.10494293818212533 * EXP(- POWER("PetalWidthCm" - 3.87721518987342, 2) / 28.9030509574812) * 0.401015228426396 AND 0.3244726886220993 * EXP(- POWER("SepalLengthCm" - 5.41486486486486, 2) / 3.02338763420952) * 0.4491388607461844 * EXP(- POWER("SepalWidthCm" - 3.87027027027027, 2) / 1.577934098482068) * 0.19437108725076185 * EXP(- POWER("PetalLengthCm" - 7.6472972972973, 2) / 8.42532765642342) * 1.9328615063345225 * EXP(- POWER("PetalWidthCm" - 1.90135135135135, 2) / 0.085201777119588) * 0.375634517766497 >= 0.7292913131474902 * EXP(- POWER("SepalLengthCm" - 5.92727272727273, 2) / 0.598477801268546) * 1.2291225672892683 * EXP(- POWER("SepalWidthCm" - 2.75, 2) / 0.210697674418604) * 0.8486436275756416 * EXP(- POWER("PetalLengthCm" - 4.225, 2) / 0.441976744186038) * 2.0783710790945586 * EXP(- POWER("PetalWidthCm" - 1.31136363636364, 2) / 0.073689217758985) * 0.223350253807107 THEN \'Iris-virginica\' WHEN 0.7292913131474902 * EXP(- POWER("SepalLengthCm" - 5.92727272727273, 2) / 0.598477801268546) * 1.2291225672892683 * EXP(- POWER("SepalWidthCm" - 2.75, 2) / 0.210697674418604) * 0.8486436275756416 * EXP(- POWER("PetalLengthCm" - 4.225, 2) / 
0.441976744186038) * 2.0783710790945586 * EXP(- POWER("PetalWidthCm" - 1.31136363636364, 2) / 0.073689217758985) * 0.223350253807107 >= 0.4505140594417424 * EXP(- POWER("SepalLengthCm" - 4.17594936708861, 2) / 1.568315481986386) * 0.6362690709840972 * EXP(- POWER("SepalWidthCm" - 3.92405063291139, 2) / 0.786264199935092) * 0.1921613470484011 * EXP(- POWER("PetalLengthCm" - 3.46075949367089, 2) / 8.6202142161636) * 0.10494293818212533 * EXP(- POWER("PetalWidthCm" - 3.87721518987342, 2) / 28.9030509574812) * 0.401015228426396 THEN \'Iris-versicolor\' ELSE \'Iris-setosa\' END'
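The generated SQL above is, for each class, a product of Gaussian likelihoods multiplied by the class prior, with the predicted class chosen by chained comparisons. The same idea in pure Python — a sketch with made-up illustrative parameters, not the fitted model's (the function name is hypothetical):

```python
import math

def gaussian_nb_predict(x, classes):
    """Pick the class maximizing prior * product of Gaussian densities,
    mirroring the structure of the generated SQL.
    `classes` maps label -> (prior, [(mu, sigma), ...]) per feature."""
    def score(prior, params):
        s = prior
        for xi, (mu, sigma) in zip(x, params):
            s *= math.exp(-((xi - mu) ** 2) / (2 * sigma ** 2)) / (
                sigma * math.sqrt(2 * math.pi)
            )
        return s
    return max(classes, key=lambda c: score(*classes[c]))

# Illustrative one-feature model: class "a" centered at 0, "b" at 5
classes = {
    "a": (0.5, [(0.0, 1.0)]),
    "b": (0.5, [(5.0, 1.0)]),
}
label = gaussian_nb_predict([0.2], classes)  # "a" wins: 0.2 is near 0
```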
To Python
To obtain the prediction function in Python syntax, use the following code:
X = [[5, 2, 3, 1]]

model.to_python()(X)
Out[11]: array(['Iris-versicolor'], dtype='<U15')
Hint
The to_python() method is used to retrieve predictions, probabilities, or cluster distances. For specific details on how to use this method for different model types, refer to the relevant documentation for each model.

- __init__(name: str = None, overwrite_model: bool = False, alpha: int | float | Decimal = 1.0, nbtype: Literal['auto', 'bernoulli', 'categorical', 'multinomial', 'gaussian'] = 'auto') → None#
Must be overridden in the child class
Methods
__init__([name, overwrite_model, alpha, nbtype])
    Must be overridden in the child class
classification_report([metrics, cutoff, ...])
    Computes a classification report using multiple model evaluation metrics (auc, accuracy, f1...).
confusion_matrix([pos_label, cutoff])
    Computes the model confusion matrix.
contour([pos_label, nbins, chart])
    Draws the model's contour plot.
cutoff_curve([pos_label, nbins, show, chart])
    Draws the model Cutoff curve.
deploySQL([X, pos_label, cutoff, allSQL])
    Returns the SQL code needed to deploy the model.
does_model_exists(name[, raise_error, ...])
    Checks whether the model is stored in the Vertica database.
drop()
    Drops the model from the Vertica database.
export_models(name, path[, kind])
    Exports machine learning models.
fit(input_relation, X, y[, test_relation, ...])
    Trains the model.
get_attributes([attr_name])
    Returns the model attributes.
get_match_index(x, col_list[, str_check])
    Returns the matching index.
get_params()
    Returns the parameters of the model.
get_plotting_lib([class_name, chart, ...])
    Returns the first available library (Plotly, Matplotlib, or Highcharts) to draw a specific graphic.
get_vertica_attributes([attr_name])
    Returns the model Vertica attributes.
import_models(path[, schema, kind])
    Imports machine learning models.
lift_chart([pos_label, nbins, show, chart])
    Draws the model Lift Chart.
prc_curve([pos_label, nbins, show, chart])
    Draws the model PRC curve.
predict(vdf[, X, name, cutoff, inplace])
    Predicts using the input relation.
predict_proba(vdf[, X, name, pos_label, inplace])
    Returns the model's probabilities using the input relation.
register(registered_name[, raise_error])
    Registers the model and adds it to in-DB Model versioning environment with a status of 'under_review'.
report([metrics, cutoff, labels, nbins])
    Computes a classification report using multiple model evaluation metrics (auc, accuracy, f1...).
roc_curve([pos_label, nbins, show, chart])
    Draws the model ROC curve.
score([metric, average, pos_label, cutoff, ...])
    Computes the model score.
set_params([parameters])
    Sets the parameters of the model.
summarize()
    Summarizes the model.
to_binary(path)
    Exports the model to the Vertica Binary format.
to_memmodel()
    Converts the model to an InMemory object that can be used for different types of predictions.
to_pmml(path)
    Exports the model to PMML.
to_python([return_proba, ...])
    Returns the Python function needed for in-memory scoring without using built-in Vertica functions.
to_sql([X, return_proba, ...])
    Returns the SQL code needed to deploy the model without using built-in Vertica functions.
to_tf(path)
    Exports the model to the Frozen Graph format (TensorFlow).

Attributes