PyEGRO ModelTesting API Reference¶
This document is the API reference for the Model Testing module of the PyEGRO package, which evaluates trained GPR and Co-Kriging models on unseen data.
ModelTester Class¶
The ModelTester class provides functionality for loading trained models and evaluating their performance on test data.
Constructor¶
Parameters¶
- model_dir (str, optional): Directory containing trained model files. Default: 'RESULT_MODEL_GPR'
- model_name (str, optional): Base name of the model file without extension. Default: None (will be inferred from directory or file path)
- model_path (str, optional): Direct path to the model file. Default: None
- logger (logging.Logger, optional): Logger object for logging messages. Default: None
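To make "base name of the model file without extension" concrete, here is a minimal sketch of deriving such a name from a file path. The helper `infer_model_name` and the example file name are hypothetical and are not part of PyEGRO; the library's actual inference logic is not described in this reference and may differ.

```python
from pathlib import Path


def infer_model_name(model_path):
    """Hypothetical helper: return the file's base name without its extension."""
    # Path.stem drops the directory part and the final suffix only.
    return Path(model_path).stem


# The directory matches the documented default; the file name is made up.
name = infer_model_name("RESULT_MODEL_GPR/my_model.pth")
print(name)
```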
Methods¶
load_model¶
Load the trained model and scalers from disk.
Returns¶
self: Returns the ModelTester instance for method chaining
load_test_data¶
Load test data from a CSV file, or generate synthetic test data when no data_path is given.
Parameters¶
- data_path (str, optional): Path to the CSV file containing test data. Default: None
- feature_cols (list of str, optional): List of feature column names. Default: None (all columns except the target)
- target_col (str, optional): Target column name. Default: 'y'
- n_samples (int, optional): Number of samples for synthetic data if no data_path is provided. Default: 100
- n_features (int, optional): Number of features for synthetic data if no data_path is provided. Default: 2
Returns¶
- Tuple of (X_test, y_test): Test features and targets as numpy arrays
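The column defaults described above (every column except target_col is treated as a feature) can be illustrated with a small standard-library sketch. This is not PyEGRO code: the helper `split_features_target` and the column names `x1` and `x2` are assumptions made for illustration, and plain lists stand in for the numpy arrays the real method returns.

```python
import csv
import io


def split_features_target(rows, feature_cols=None, target_col="y"):
    """Split CSV rows (dicts) into feature rows and target values."""
    if not rows:
        raise ValueError("no rows to split")
    if feature_cols is None:
        # Documented default: all columns except the target.
        feature_cols = [c for c in rows[0] if c != target_col]
    X = [[float(r[c]) for c in feature_cols] for r in rows]
    y = [float(r[target_col]) for r in rows]
    return X, y, feature_cols


# A tiny in-memory CSV stands in for a test-data file on disk.
csv_text = "x1,x2,y\n0.0,1.0,1.0\n0.5,0.5,1.0\n1.0,0.0,1.0\n"
rows = list(csv.DictReader(io.StringIO(csv_text)))
X_test, y_test, cols = split_features_target(rows)
print(cols, len(X_test), len(y_test))
```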
evaluate¶
Evaluate the model performance on test data.
Parameters¶
- X_test (numpy.ndarray): Test features
- y_test (numpy.ndarray): Test targets
Returns¶
self: Returns the ModelTester instance for method chaining, with its test_results property updated
save_results¶
Save test results and generate plots.
Parameters¶
output_dir(str, optional): Directory to save results. Default:'test_results'
Returns¶
self: Returns the ModelTester instance for method chaining
plot_results¶
Generate and optionally save plots of test results.
plot_results(output_dir='test_results', show_plots=True, save_plots=True, smooth=True, smooth_window=11)
Parameters¶
- output_dir (str, optional): Directory to save plots. Default: 'test_results'
- show_plots (bool, optional): Whether to display plots. Default: True
- save_plots (bool, optional): Whether to save plots to disk. Default: True
- smooth (bool, optional): Whether to apply smoothing to uncertainty plots. Default: True
- smooth_window (int, optional): Window size for smoothing filter. Default: 11
Returns¶
figures(dict): Dictionary of matplotlib figure objects
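This reference does not say which smoothing filter plot_results applies. To give a feel for what a window size of 11 means, the sketch below uses a centered moving average as a stand-in; the function `moving_average` is hypothetical and the library may use a different filter.

```python
def moving_average(values, window=11):
    """Centered moving average; the window shrinks near the edges."""
    if window < 1 or window % 2 == 0:
        raise ValueError("window must be a positive odd integer")
    half = window // 2
    smoothed = []
    for i in range(len(values)):
        lo = max(0, i - half)
        hi = min(len(values), i + half + 1)
        smoothed.append(sum(values[lo:hi]) / (hi - lo))
    return smoothed


# A noisy stand-in for a series of predictive standard deviations.
std_dev = [0.10, 0.30, 0.12, 0.28, 0.11, 0.31, 0.09, 0.29, 0.10, 0.30, 0.12, 0.27]
smooth_std = moving_average(std_dev, window=11)
print(len(smooth_std))
```

A larger window averages over more neighbouring points, so the plotted uncertainty band looks smoother but follows local changes less closely.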
Properties¶
test_results (dict): Contains test metrics and predictions after running evaluate. Keys include:
- r2: R² score
- mse: Mean squared error
- rmse: Root mean squared error
- mae: Mean absolute error
- y_test: Test targets
- y_pred: Model predictions
- std_dev: Standard deviations of predictions
- model_type: Type of model ('gpr' or 'cokriging')
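For reference, the four scalar metrics listed above have standard textbook definitions, sketched here in plain Python. The function `regression_metrics` is illustrative only and is an assumption about what the keys mean; it is not PyEGRO's implementation, which may compute them differently (for example with numpy or scikit-learn).

```python
import math


def regression_metrics(y_test, y_pred):
    """Compute r2, mse, rmse and mae from targets and predictions."""
    n = len(y_test)
    if n == 0 or n != len(y_pred):
        raise ValueError("y_test and y_pred must be non-empty and equally long")
    residuals = [t - p for t, p in zip(y_test, y_pred)]
    ss_res = sum(r * r for r in residuals)
    mean_y = sum(y_test) / n
    ss_tot = sum((t - mean_y) ** 2 for t in y_test)
    mse = ss_res / n
    return {
        # R² is undefined when all targets are identical.
        "r2": 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan"),
        "mse": mse,
        "rmse": math.sqrt(mse),
        "mae": sum(abs(r) for r in residuals) / n,
    }


metrics = regression_metrics([1.0, 2.0, 3.0, 4.0], [1.5, 2.0, 2.5, 4.0])
print(metrics)
```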
Utility Functions¶
load_and_test_model¶
A convenience function to load a model and test it in one call.
load_and_test_model(
    data_path=None,
    model_dir=None,
    model_name=None,
    model_path=None,
    output_dir='test_results',
    feature_cols=None,
    target_col='y',
    logger=None,
    show_plots=True,
    smooth=True,
    smooth_window=11
)
Parameters¶
- data_path (str, optional): Path to the CSV file containing test data. Default: None
- model_dir (str, optional): Directory containing trained model files. Default: None
- model_name (str, optional): Base name of the model file without extension. Default: None
- model_path (str, optional): Direct path to the model file. Default: None
- output_dir (str, optional): Directory to save results. Default: 'test_results'
- feature_cols (list of str, optional): List of feature column names. Default: None
- target_col (str, optional): Target column name. Default: 'y'
- logger (logging.Logger, optional): Logger object for logging messages. Default: None
- show_plots (bool, optional): Whether to display plots. Default: True
- smooth (bool, optional): Whether to apply smoothing to uncertainty plots. Default: True
- smooth_window (int, optional): Window size for smoothing filter. Default: 11
Returns¶
test_results(dict): Dictionary containing test metrics and predictions
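The convenience function presumably combines the ModelTester steps documented earlier. The sketch below shows one way that sequence could look, using only the methods and parameters listed in this reference. The function `run_model_test` is hypothetical, and the tester class is passed in as an argument because the import path for ModelTester is not stated here. How the real function forwards logger, show_plots, smooth and smooth_window is not documented, so those are left out.

```python
def run_model_test(tester_cls, data_path=None, model_dir=None, model_name=None,
                   model_path=None, output_dir="test_results",
                   feature_cols=None, target_col="y"):
    """Load a model, evaluate it on test data, save results, return metrics."""
    # Forward only the arguments that were supplied, so the constructor's
    # own defaults (such as its model_dir) stay in effect.
    supplied = {"model_dir": model_dir, "model_name": model_name,
                "model_path": model_path}
    tester = tester_cls(**{k: v for k, v in supplied.items() if v is not None})

    tester.load_model()
    # load_test_data returns a tuple rather than self, so it cannot be chained.
    X_test, y_test = tester.load_test_data(
        data_path=data_path, feature_cols=feature_cols, target_col=target_col
    )
    # evaluate and save_results both return self and can be chained.
    tester.evaluate(X_test, y_test).save_results(output_dir=output_dir)
    return tester.test_results
```

In real use, tester_cls would be PyEGRO's ModelTester class. Any object exposing the same four methods and a test_results property also works, which makes the sequence easy to exercise without trained model files.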