# Source code for model_eval

# -*- coding: utf-8 -*-

# LIBTwinSVM: A Library for Twin Support Vector Machines
# Developers: Mir, A. and Mahdi Rahbar
# License: GNU General Public License v3.0

"""
This module contains code for saving, loading, and evaluating pre-trained
models.
"""

from PyQt5.QtCore import QObject, pyqtSlot, pyqtSignal
from sklearn.metrics import accuracy_score
from joblib import dump, load
from libtsvm.estimators import BaseTSVM
from libtsvm.mc_scheme import (OneVsAllClassifier, OneVsOneClassifier,
                               mc_clf_no_params)
from libtsvm.misc import time_fmt
from datetime import datetime
from os.path import join
from numpy import savetxt


def save_model(validator, params, output_file):
    """
    Runs the validator's evaluation method and persists its estimator.

    Parameters
    ----------
    validator : object
        An evaluation method. It must expose ``choose_validator()`` and an
        ``estimator`` attribute.
    params : dict
        Hyper-parameters of the estimator.
    output_file : str
        The full path and filename of the saved model.
    """
    # Select the configured evaluation routine and run it with the given
    # hyper-parameters. Presumably this fits ``validator.estimator`` as a
    # side effect (confirm against the validator implementation).
    validator.choose_validator()(params)
    # Serialize the (now evaluated) estimator to disk with joblib.
    dump(validator.estimator, output_file)
def load_model(model_path):
    """
    It loads a pre-trained TSVM-based estimator.

    Parameters
    ----------
    model_path : str
        The path at which the model is stored.

    Returns
    -------
    object
        A pre-trained estimator.
    dict
        Model information with the keys ``'model_name'``, ``'kernel'``,
        ``'rect_kernel'``, ``'no_params'`` (a string), ``'h_params'`` and
        ``'clf_type'`` (``'Binary'`` or ``'Multi-class'``).

    Raises
    ------
    ValueError
        If the loaded object is neither a :class:`BaseTSVM` estimator nor a
        :class:`OneVsAllClassifier` / :class:`OneVsOneClassifier`.
    """
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only load model files that come from a trusted source.
    pre_trained_clf = load(model_path)

    if isinstance(pre_trained_clf, BaseTSVM):
        # A binary TSVM carries its own kernel settings and weights.
        base_estimator = pre_trained_clf
        # Sizes of w1 and w2 plus 2 (presumably the two bias terms).
        model_no_params = str(pre_trained_clf.w1.shape[0] +
                              pre_trained_clf.w2.shape[0] + 2)
        model_clf_type = 'Binary'
    elif isinstance(pre_trained_clf,
                    (OneVsAllClassifier, OneVsOneClassifier)):
        # Multi-class wrappers keep the kernel settings and hyper-parameters
        # on their wrapped ``estimator``; the parameter count is derived from
        # the fitted binary sub-classifiers in ``bin_clf_``.
        base_estimator = pre_trained_clf.estimator
        model_no_params = str(mc_clf_no_params(pre_trained_clf.bin_clf_))
        model_clf_type = 'Multi-class'
    else:
        raise ValueError("An unsupported estimator is loaded!")

    return pre_trained_clf, {'model_name': pre_trained_clf.clf_name,
                             'kernel': base_estimator.kernel,
                             'rect_kernel': base_estimator.rect_kernel,
                             'no_params': model_no_params,
                             'h_params': base_estimator.get_params(),
                             'clf_type': model_clf_type}
class ModelThread(QObject):
    """
    Evaluates a pre-trained model in a thread.

    Parameters
    ----------
    usr_in : object
        An instance of :class:`UserInput` class which holds the user input.
    """

    # Emitted with (accuracy string, formatted elapsed-time string).
    sig_update_model_eval = pyqtSignal(str, str)

    def __init__(self, usr_in):
        super(ModelThread, self).__init__()
        self.usr_in = usr_in

    @pyqtSlot()
    def eval_model(self):
        """
        It evaluates a pre-trained model on test samples.

        Predicts labels for ``usr_in.X_train`` with
        ``usr_in.pre_trained_model`` and scores them against
        ``usr_in.y_train``. It then emits :attr:`sig_update_model_eval` with
        the accuracy (as a percentage string) and the elapsed time. If
        ``usr_in.save_pred`` is truthy, it also writes the predicted labels
        to a text file in ``usr_in.save_pred_path``.
        """
        start_t = datetime.now()
        # NOTE(review): the dataset loaded in the UI appears to be stored in
        # X_train/y_train and is used as the evaluation set here; confirm
        # against UserInput.
        pred = self.usr_in.pre_trained_model.predict(self.usr_in.X_train)
        test_acc = accuracy_score(self.usr_in.y_train, pred) * 100
        elapsed_t = datetime.now() - start_t

        # Use total_seconds() rather than .seconds: .seconds drops the days
        # component, so runs of 24 h or more would wrap around. Clamp at 0 in
        # case the wall clock was adjusted backwards during the run.
        elapsed_s = max(0, int(elapsed_t.total_seconds()))
        self.sig_update_model_eval.emit("%.2f%%" % test_acc,
                                        time_fmt(elapsed_s))

        if self.usr_in.save_pred:
            self._save_predictions(pred)

    def _save_predictions(self, pred):
        """Write predicted labels to a timestamped file in save_pred_path."""
        f_name = 'test_labels_model_%s_%s_%s_%s.txt' % (
            self.usr_in.pre_trained_model.clf_name,
            self.usr_in.kernel_type,
            self.usr_in.data_filename,
            datetime.now().strftime('%Y-%m-%d %H-%M'))
        # Labels are written as integers, one per line.
        savetxt(join(self.usr_in.save_pred_path, f_name), pred, fmt='%d')