diff --git a/local_api.py b/local_api.py index a3bff2f988..275f4daec3 100644 --- a/local_api.py +++ b/local_api.py @@ -2,17 +2,18 @@ import requests -# TODO: send a GET using the URL http://127.0.0.1:8000 -r = None # Your code here +# Send a GET using the URL http://127.0.0.1:8000 +url="http://127.0.0.1:8000/predict" # Your code here +r = requests.get("http://127.0.0.1:8000") # TODO: print the status code -# print() +print("Status Code:", r.status_code) # TODO: print the welcome message -# print() +print("Results:", r.json()) -data = { +input_data = { "age": 37, "workclass": "Private", "fnlgt": 178356, @@ -26,13 +27,19 @@ "capital-gain": 0, "capital-loss": 0, "hours-per-week": 40, - "native-country": "United-States", + "native-country": "United-States" } -# TODO: send a POST using the data above -r = None # Your code here -# TODO: print the status code -# print() -# TODO: print the result -# print() +# Send a POST using the data above +response = requests.post(url, json=input_data) # Your code here + +# Print the status code +print("Status Code: {response.status_code}") + +print("Response text:", response.text) +# Print the result +if response.status_code == 200: + print("Result:", response.json()) +else: + print(f"Failed with status code {response.status_code}. Response text: {response.text}") diff --git a/main.py b/main.py index 638e2414de..b52918becc 100644 --- a/main.py +++ b/main.py @@ -1,8 +1,9 @@ import os - +import pickle import pandas as pd -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field +import joblib from ml.data import apply_label, process_data from ml.model import inference, load_model @@ -25,26 +26,33 @@ class Data(BaseModel): capital_loss: int = Field(..., example=0, alias="capital-loss") hours_per_week: int = Field(..., example=40, alias="hours-per-week") native_country: str = Field(..., example="United-States", alias="native-country") +try: + encoder_path = os.path.join(os.getcwd(), "model", "encoder.pkl") # TODO: enter the path for the saved encoder + encoder = load_model(encoder_path) + print("Encoder loaded successfully.") + + model_path = os.path.join(os.getcwd(), "model","model.pkl") # TODO: enter the path for the saved model + model = load_model(model_path) + print("Model loaded successfully.") -path = None # TODO: enter the path for the saved encoder -encoder = load_model(path) -path = None # TODO: enter the path for the saved model -model = load_model(path) +except Exception as e: + print(f"Error loading model, encoder, or label binarizer: {e}") + raise HTTPException(status_code=500, detail=f"Error loading model, encoder, or label binarizer: {str(e)}") -# TODO: create a RESTful API using FastAPI -app = None # your code here +# Create a RESTful API using FastAPI +app = FastAPI() # your code here -# TODO: create a GET on the root giving a welcome message +# Create a GET on the root giving a welcome message @app.get("/") async def get_root(): """ Say hello!""" # your code here - pass + return {"message": "Hello from the API!"} -# TODO: create a POST on a different path that does model inference -@app.post("/data/") +# Create a POST on a different path that does model inference +@app.post("/predict") async def post_inference(data: Data): # DO NOT MODIFY: turn the Pydantic model into a dict. data_dict = data.dict() @@ -65,10 +73,14 @@ async def post_inference(data: Data): "native-country", ] data_processed, _, _, _ = process_data( - # your code here + data, + categorical_features=cat_features, + training=False,# your code here + encoder=encoder # use data as data input # use training = False # do not need to pass lb as input ) - _inference = None # your code here to predict the result using data_processed - return {"result": apply_label(_inference)} + _inference = model.predict(data_processed) + result = apply_label(_inference) + return {"result": result} diff --git a/ml/data.py b/ml/data.py index f8d30b5b16..73dc292af1 100644 --- a/ml/data.py +++ b/ml/data.py @@ -53,12 +53,15 @@ def process_data( X_categorical = X[categorical_features].values X_continuous = X.drop(*[categorical_features], axis=1) - if training is True: + if training: encoder = OneHotEncoder(sparse_output=False, handle_unknown="ignore") lb = LabelBinarizer() X_categorical = encoder.fit_transform(X_categorical) y = lb.fit_transform(y.values).ravel() else: + if encoder is None: + raise ValueError("Encoder and LabelBinarizer must be provided during inference") + X_categorical = encoder.transform(X_categorical) try: y = lb.transform(y.values).ravel() diff --git a/ml/model.py b/ml/model.py index f361110f18..f9831a08af 100644 --- a/ml/model.py +++ b/ml/model.py @@ -2,9 +2,12 @@ from sklearn.metrics import fbeta_score, precision_score, recall_score from ml.data import process_data # TODO: add necessary import +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import GridSearchCV, StratifiedKFold +import joblib # Optional: implement hyperparameter tuning. -def train_model(X_train, y_train): +def train_model(X_train, y_train, cv=None): """ Trains a machine learning model and returns it. @@ -19,8 +22,19 @@ def train_model(X_train, y_train): model Trained machine learning model. """ - # TODO: implement the function - pass + # Train and return a model + if cv is None: + cv = StratifiedKFold(n_splits=5) + + param_grid = { + 'n_estimators': [100, 200], + 'max_depth': [None, 10, 20], + } + clf = GridSearchCV(RandomForestClassifier(random_state=42), param_grid, cv=cv) + clf.fit(X_train, y_train) + return clf.best_estimator_ + + def compute_model_metrics(y, preds): @@ -50,7 +64,7 @@ def inference(model, X): Inputs ------ - model : ??? + model : sklearn.base.BaseEstimator Trained machine learning model. X : np.array Data used for prediction. @@ -59,8 +73,8 @@ def inference(model, X): preds : np.array Predictions from the model. """ - # TODO: implement the function - pass + # Run model inferences and return the predictions + return model.predict(X) def save_model(model, path): """ Serializes model to a file. @@ -72,13 +86,13 @@ def save_model(model, path): path : str Path to save pickle file. """ - # TODO: implement the function - pass + # Save a model + joblib.dump(model, path) def load_model(path): """ Loads pickle file from `path` and returns it.""" - # TODO: implement the function - pass + # Load a model + return joblib.load(path) def performance_on_categorical_slice( @@ -107,7 +121,7 @@ def performance_on_categorical_slice( Trained sklearn OneHotEncoder, only used if training=False. lb : sklearn.preprocessing._label.LabelBinarizer Trained sklearn LabelBinarizer, only used if training=False. - model : ??? + model : RandomForestClassifier Model used for the task. Returns @@ -117,12 +131,33 @@ def performance_on_categorical_slice( fbeta : float """ - # TODO: implement the function + # Computes the metrics on a slice of the data + data_slice = data[data[column_name]==slice_value] + X_slice, y_slice, _, _ = process_data( # your code here # for input data, use data in column given as "column_name", with the slice_value # use training = False + data_slice, + categorical_features=categorical_features, + label=label, + encoder=encoder, + lb=lb, + training=False ) - preds = None # your code here to get prediction on X_slice using the inference function + preds = inference(model, X_slice) # your code here to get prediction on X_slice using the inference function precision, recall, fbeta = compute_model_metrics(y_slice, preds) + + #Prepare the log message for this slice + log_message = ( + f"Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {fbeta:.4f}\n" + f"{column_name}: {slice_value}, Count: {len(data_slice)}\n" + ) + #Apend the results to slice_output.txt + with open('slice_output.txt', 'a') as f: + f.write(log_message) + + #Print the result for terminal log + print(log_message) + return precision, recall, fbeta diff --git a/model/encoder.pkl b/model/encoder.pkl new file mode 100644 index 0000000000..44a41b95e8 Binary files /dev/null and b/model/encoder.pkl differ diff --git a/model/model.pkl b/model/model.pkl new file mode 100644 index 0000000000..03364b3ac3 Binary files /dev/null and b/model/model.pkl differ diff --git a/model_card_template.md b/model_card_template.md index 0392f3b9eb..7705b94a3a 100644 --- a/model_card_template.md +++ b/model_card_template.md @@ -3,16 +3,53 @@ For additional information see the Model Card paper: https://arxiv.org/pdf/1810.03993.pdf ## Model Details - +Model Name: Random Forest Classifier +Version: v1.0 +Type: Supervised Classification (Binary Classification) +Model Architecture: Random Forest with hyperparameter tuning (n_estimators: 100-200, max_depth: None, 10, 20) using GridSearchCV +Training Time: Approximately 2 hours (estimated) +Last Trained: 2025-05-04 ## Intended Use +The model is designed to predict whether an individual's income exceeds $50k per year, based on demographic and employment-related features from U.S. Census data. Intended for educational purposes, especially for demonstrating machine learning workflows and deployment pipelines. ## Training Data +Source: UC Irvine Adult Census Income Dataset (loaded from census.csv) +Features: + -Categorical: workclass, education, marital-status, occupation, relationship, race, sex, native-country + -Numerical: age, hours-per-week, education-num, capital-gain, capital-loss, etc +Target Label: salary (<=50k, >50k) +Preprocessing: + -OneHotEncoding for categorical variables + -LabelBinarizer for the target variable + -Train-test split: 80/20 ## Evaluation Data +The test dataset is a 20% hold-out sample from the original dataset +Used to evaluate generalization performance and perform slice-based fairness analysis ## Metrics _Please include the metrics used and your model's performance on those metrics._ +Precision: 0.7866 +Recall: 0.6149 +F1 Score: 0.6902 + +These results suggest that the model strikes a moderate balance between identifying positive cases (recall) and minimizing false positives (precision) + +See slice_output.txt ## Ethical Considerations +Bias & Fairness: The model may reflect historical biases present in the U.S. Census dataset. For instance, features such as race and gender could lead to disparate performance on different groups. + +Interpretability: Random Forests are not easily interpretable by default, which may limit their use in regulated environments. + +Use Limitations: This model should not be used for automated income predicitions affecting real-world decisions without additional audits and mitigation techniques. + +Data Privacy: The dataset does not contain personally identifiable information (PII), but proper care should be taken if extended to real-world data. ## Caveats and Recommendations +Performance may vary significantly across different population subgroups. Always check model performance using slice-based analysis. +Further model robustness checks are recommended before deployment. +For production use, consider model interpretability enhancements such as SHAP or LIME for local explanations. +LIME https://lime-ml.readthedocs.io/en/latest/ +SHAP https://shap.readthedocs.io/en/latest/ + diff --git a/screenshots/local_api.png b/screenshots/local_api.png new file mode 100644 index 0000000000..5c347b6fc8 Binary files /dev/null and b/screenshots/local_api.png differ diff --git a/screenshots/unit_test.png b/screenshots/unit_test.png new file mode 100644 index 0000000000..f604fbeb6c Binary files /dev/null and b/screenshots/unit_test.png differ diff --git a/slice_output.txt b/slice_output.txt new file mode 100644 index 0000000000..f4d733f71e --- /dev/null +++ b/slice_output.txt @@ -0,0 +1,596 @@ +workclass: ?, Count: 389 +Precision: 0.7391 | Recall: 0.4048 | F1: 0.5231 +workclass: Federal-gov, Count: 191 +Precision: 0.7846 | Recall: 0.7286 | F1: 0.7556 +workclass: Local-gov, Count: 387 +Precision: 0.7624 | Recall: 0.7000 | F1: 0.7299 +workclass: Private, Count: 4,578 +Precision: 0.8035 | Recall: 0.6004 | F1: 0.6872 +workclass: Self-emp-inc, Count: 212 +Precision: 0.7815 | Recall: 0.7881 | F1: 0.7848 +workclass: Self-emp-not-inc, Count: 498 +Precision: 0.7451 | Recall: 0.4841 | F1: 0.5869 +workclass: State-gov, Count: 254 +Precision: 0.7286 | Recall: 0.6986 | F1: 0.7133 +workclass: Without-pay, Count: 4 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +education: 10th, Count: 183 +Precision: 1.0000 | Recall: 0.0833 | F1: 0.1538 +education: 11th, Count: 225 +Precision: 1.0000 | Recall: 0.2727 | F1: 0.4286 +education: 12th, Count: 98 +Precision: 1.0000 | Recall: 0.4000 | F1: 0.5714 +education: 1st-4th, Count: 23 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +education: 5th-6th, Count: 62 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +education: 7th-8th, Count: 141 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +education: 9th, Count: 115 +Precision: 1.0000 | Recall: 0.3333 | F1: 0.5000 +education: Assoc-acdm, Count: 198 +Precision: 0.7568 | Recall: 0.5957 | F1: 0.6667 +education: Assoc-voc, Count: 273 +Precision: 0.7209 | Recall: 0.4921 | F1: 0.5849 +education: Bachelors, Count: 1,053 +Precision: 0.7547 | Recall: 0.8000 | F1: 0.7767 +education: Doctorate, Count: 77 +Precision: 0.8387 | Recall: 0.9123 | F1: 0.8739 +education: HS-grad, Count: 2,085 +Precision: 0.8534 | Recall: 0.2870 | F1: 0.4295 +education: Masters, Count: 369 +Precision: 0.8219 | Recall: 0.8696 | F1: 0.8451 +education: Preschool, Count: 10 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +education: Prof-school, Count: 116 +Precision: 0.8211 | Recall: 0.9286 | F1: 0.8715 +education: Some-college, Count: 1,485 +Precision: 0.7588 | Recall: 0.4657 | F1: 0.5772 +marital-status: Divorced, Count: 920 +Precision: 0.8537 | Recall: 0.3398 | F1: 0.4861 +marital-status: Married-AF-spouse, Count: 4 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +marital-status: Married-civ-spouse, Count: 2,950 +Precision: 0.7778 | Recall: 0.6702 | F1: 0.7200 +marital-status: Married-spouse-absent, Count: 96 +Precision: 0.6667 | Recall: 0.2500 | F1: 0.3636 +marital-status: Never-married, Count: 2,126 +Precision: 0.9231 | Recall: 0.3495 | F1: 0.5070 +marital-status: Separated, Count: 209 +Precision: 1.0000 | Recall: 0.4211 | F1: 0.5926 +marital-status: Widowed, Count: 208 +Precision: 1.0000 | Recall: 0.1579 | F1: 0.2727 +occupation: ?, Count: 389 +Precision: 0.7391 | Recall: 0.4048 | F1: 0.5231 +occupation: Adm-clerical, Count: 726 +Precision: 0.6825 | Recall: 0.4479 | F1: 0.5409 +occupation: Armed-Forces, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +occupation: Craft-repair, Count: 821 +Precision: 0.8500 | Recall: 0.3757 | F1: 0.5211 +occupation: Exec-managerial, Count: 838 +Precision: 0.8147 | Recall: 0.7531 | F1: 0.7827 +occupation: Farming-fishing, Count: 193 +Precision: 0.6667 | Recall: 0.2143 | F1: 0.3243 +occupation: Handlers-cleaners, Count: 273 +Precision: 0.5714 | Recall: 0.3333 | F1: 0.4211 +occupation: Machine-op-inspct, Count: 378 +Precision: 0.7647 | Recall: 0.2766 | F1: 0.4062 +occupation: Other-service, Count: 667 +Precision: 0.8333 | Recall: 0.1923 | F1: 0.3125 +occupation: Priv-house-serv, Count: 26 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +occupation: Prof-specialty, Count: 828 +Precision: 0.7920 | Recall: 0.8061 | F1: 0.7990 +occupation: Protective-serv, Count: 136 +Precision: 0.7200 | Recall: 0.4286 | F1: 0.5373 +occupation: Sales, Count: 729 +Precision: 0.7607 | Recall: 0.6458 | F1: 0.6986 +occupation: Tech-support, Count: 189 +Precision: 0.7273 | Recall: 0.6275 | F1: 0.6737 +occupation: Transport-moving, Count: 317 +Precision: 0.8333 | Recall: 0.3125 | F1: 0.4545 +relationship: Husband, Count: 2,590 +Precision: 0.7824 | Recall: 0.6701 | F1: 0.7219 +relationship: Not-in-family, Count: 1,702 +Precision: 0.8684 | Recall: 0.3511 | F1: 0.5000 +relationship: Other-relative, Count: 178 +Precision: 1.0000 | Recall: 0.3750 | F1: 0.5455 +relationship: Own-child, Count: 1,019 +Precision: 1.0000 | Recall: 0.1765 | F1: 0.3000 +relationship: Unmarried, Count: 702 +Precision: 1.0000 | Recall: 0.2889 | F1: 0.4483 +relationship: Wife, Count: 322 +Precision: 0.7405 | Recall: 0.6783 | F1: 0.7080 +race: Amer-Indian-Eskimo, Count: 71 +Precision: 0.7143 | Recall: 0.5000 | F1: 0.5882 +race: Asian-Pac-Islander, Count: 193 +Precision: 0.7619 | Recall: 0.7742 | F1: 0.7680 +race: Black, Count: 599 +Precision: 0.7551 | Recall: 0.5692 | F1: 0.6491 +race: Other, Count: 55 +Precision: 1.0000 | Recall: 0.6667 | F1: 0.8000 +race: White, Count: 5,595 +Precision: 0.7891 | Recall: 0.6106 | F1: 0.6885 +sex: Female, Count: 2,126 +Precision: 0.7548 | Recall: 0.5021 | F1: 0.6031 +sex: Male, Count: 4,387 +Precision: 0.7912 | Recall: 0.6345 | F1: 0.7043 +native-country: ?, Count: 125 +Precision: 0.7407 | Recall: 0.6452 | F1: 0.6897 +native-country: Cambodia, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Canada, Count: 22 +Precision: 0.7000 | Recall: 0.8750 | F1: 0.7778 +native-country: China, Count: 18 +Precision: 0.8000 | Recall: 1.0000 | F1: 0.8889 +native-country: Columbia, Count: 6 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Cuba, Count: 19 +Precision: 1.0000 | Recall: 0.8000 | F1: 0.8889 +native-country: Dominican-Republic, Count: 8 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Ecuador, Count: 5 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +native-country: El-Salvador, Count: 20 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: England, Count: 14 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +native-country: France, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Germany, Count: 32 +Precision: 0.8182 | Recall: 0.6923 | F1: 0.7500 +native-country: Greece, Count: 7 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +native-country: Guatemala, Count: 13 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Haiti, Count: 6 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Honduras, Count: 4 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Hong, Count: 8 +Precision: 0.5000 | Recall: 1.0000 | F1: 0.6667 +native-country: Hungary, Count: 3 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +native-country: India, Count: 21 +Precision: 0.8750 | Recall: 0.8750 | F1: 0.8750 +native-country: Iran, Count: 12 +Precision: 0.5000 | Recall: 0.4000 | F1: 0.4444 +native-country: Ireland, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Italy, Count: 14 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +native-country: Jamaica, Count: 13 +Precision: 0.0000 | Recall: 1.0000 | F1: 0.0000 +native-country: Japan, Count: 11 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +native-country: Laos, Count: 4 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +native-country: Mexico, Count: 114 +Precision: 1.0000 | Recall: 0.3333 | F1: 0.5000 +native-country: Nicaragua, Count: 7 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Peru, Count: 5 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +native-country: Philippines, Count: 35 +Precision: 1.0000 | Recall: 0.7500 | F1: 0.8571 +native-country: Poland, Count: 14 +Precision: 0.3333 | Recall: 0.5000 | F1: 0.4000 +native-country: Portugal, Count: 6 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Puerto-Rico, Count: 22 +Precision: 0.8333 | Recall: 0.8333 | F1: 0.8333 +native-country: Scotland, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: South, Count: 13 +Precision: 0.3333 | Recall: 0.5000 | F1: 0.4000 +native-country: Taiwan, Count: 11 +Precision: 0.8000 | Recall: 1.0000 | F1: 0.8889 +native-country: Thailand, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Trinadad&Tobago, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: United-States, Count: 5,870 +Precision: 0.7879 | Recall: 0.6056 | F1: 0.6848 +native-country: Vietnam, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Yugoslavia, Count: 2 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 0.7391 | Recall: 0.4048 | F1: 0.5231 +workclass: ?, Count: 389 +Precision: 0.7391 | Recall: 0.4048 | F1: 0.5231 +workclass: ?, Count: 389 +workclass: ?, Count: 389 +Precision: 0.7391 | Recall: 0.4048 | F1: 0.5231 +Precision: 0.7846 | Recall: 0.7286 | F1: 0.7556 +workclass: Federal-gov, Count: 191 +workclass: Federal-gov, Count: 191 +Precision: 0.7846 | Recall: 0.7286 | F1: 0.7556 +Precision: 0.7624 | Recall: 0.7000 | F1: 0.7299 +workclass: Local-gov, Count: 387 +workclass: Local-gov, Count: 387 +Precision: 0.7624 | Recall: 0.7000 | F1: 0.7299 +Precision: 0.8035 | Recall: 0.6004 | F1: 0.6872 +workclass: Private, Count: 4578 +workclass: Private, Count: 4,578 +Precision: 0.8035 | Recall: 0.6004 | F1: 0.6872 +Precision: 0.7815 | Recall: 0.7881 | F1: 0.7848 +workclass: Self-emp-inc, Count: 212 +workclass: Self-emp-inc, Count: 212 +Precision: 0.7815 | Recall: 0.7881 | F1: 0.7848 +Precision: 0.7451 | Recall: 0.4841 | F1: 0.5869 +workclass: Self-emp-not-inc, Count: 498 +workclass: Self-emp-not-inc, Count: 498 +Precision: 0.7451 | Recall: 0.4841 | F1: 0.5869 +Precision: 0.7286 | Recall: 0.6986 | F1: 0.7133 +workclass: State-gov, Count: 254 +workclass: State-gov, Count: 254 +Precision: 0.7286 | Recall: 0.6986 | F1: 0.7133 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +workclass: Without-pay, Count: 4 +workclass: Without-pay, Count: 4 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 1.0000 | Recall: 0.0833 | F1: 0.1538 +education: 10th, Count: 183 +education: 10th, Count: 183 +Precision: 1.0000 | Recall: 0.0833 | F1: 0.1538 +Precision: 1.0000 | Recall: 0.2727 | F1: 0.4286 +education: 11th, Count: 225 +education: 11th, Count: 225 +Precision: 1.0000 | Recall: 0.2727 | F1: 0.4286 +Precision: 1.0000 | Recall: 0.4000 | F1: 0.5714 +education: 12th, Count: 98 +education: 12th, Count: 98 +Precision: 1.0000 | Recall: 0.4000 | F1: 0.5714 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +education: 1st-4th, Count: 23 +education: 1st-4th, Count: 23 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +education: 5th-6th, Count: 62 +education: 5th-6th, Count: 62 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +education: 7th-8th, Count: 141 +education: 7th-8th, Count: 141 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +Precision: 1.0000 | Recall: 0.3333 | F1: 0.5000 +education: 9th, Count: 115 +education: 9th, Count: 115 +Precision: 1.0000 | Recall: 0.3333 | F1: 0.5000 +Precision: 0.7568 | Recall: 0.5957 | F1: 0.6667 +education: Assoc-acdm, Count: 198 +education: Assoc-acdm, Count: 198 +Precision: 0.7568 | Recall: 0.5957 | F1: 0.6667 +Precision: 0.7209 | Recall: 0.4921 | F1: 0.5849 +education: Assoc-voc, Count: 273 +education: Assoc-voc, Count: 273 +Precision: 0.7209 | Recall: 0.4921 | F1: 0.5849 +Precision: 0.7547 | Recall: 0.8000 | F1: 0.7767 +education: Bachelors, Count: 1053 +education: Bachelors, Count: 1,053 +Precision: 0.7547 | Recall: 0.8000 | F1: 0.7767 +Precision: 0.8387 | Recall: 0.9123 | F1: 0.8739 +education: Doctorate, Count: 77 +education: Doctorate, Count: 77 +Precision: 0.8387 | Recall: 0.9123 | F1: 0.8739 +Precision: 0.8534 | Recall: 0.2870 | F1: 0.4295 +education: HS-grad, Count: 2085 +education: HS-grad, Count: 2,085 +Precision: 0.8534 | Recall: 0.2870 | F1: 0.4295 +Precision: 0.8219 | Recall: 0.8696 | F1: 0.8451 +education: Masters, Count: 369 +education: Masters, Count: 369 +Precision: 0.8219 | Recall: 0.8696 | F1: 0.8451 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +education: Preschool, Count: 10 +education: Preschool, Count: 10 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 0.8211 | Recall: 0.9286 | F1: 0.8715 +education: Prof-school, Count: 116 +education: Prof-school, Count: 116 +Precision: 0.8211 | Recall: 0.9286 | F1: 0.8715 +Precision: 0.7588 | Recall: 0.4657 | F1: 0.5772 +education: Some-college, Count: 1485 +education: Some-college, Count: 1,485 +Precision: 0.7588 | Recall: 0.4657 | F1: 0.5772 +Precision: 0.8537 | Recall: 0.3398 | F1: 0.4861 +marital-status: Divorced, Count: 920 +marital-status: Divorced, Count: 920 +Precision: 0.8537 | Recall: 0.3398 | F1: 0.4861 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +marital-status: Married-AF-spouse, Count: 4 +marital-status: Married-AF-spouse, Count: 4 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +Precision: 0.7778 | Recall: 0.6702 | F1: 0.7200 +marital-status: Married-civ-spouse, Count: 2950 +marital-status: Married-civ-spouse, Count: 2,950 +Precision: 0.7778 | Recall: 0.6702 | F1: 0.7200 +Precision: 0.6667 | Recall: 0.2500 | F1: 0.3636 +marital-status: Married-spouse-absent, Count: 96 +marital-status: Married-spouse-absent, Count: 96 +Precision: 0.6667 | Recall: 0.2500 | F1: 0.3636 +Precision: 0.9231 | Recall: 0.3495 | F1: 0.5070 +marital-status: Never-married, Count: 2126 +marital-status: Never-married, Count: 2,126 +Precision: 0.9231 | Recall: 0.3495 | F1: 0.5070 +Precision: 1.0000 | Recall: 0.4211 | F1: 0.5926 +marital-status: Separated, Count: 209 +marital-status: Separated, Count: 209 +Precision: 1.0000 | Recall: 0.4211 | F1: 0.5926 +Precision: 1.0000 | Recall: 0.1579 | F1: 0.2727 +marital-status: Widowed, Count: 208 +marital-status: Widowed, Count: 208 +Precision: 1.0000 | Recall: 0.1579 | F1: 0.2727 +Precision: 0.7391 | Recall: 0.4048 | F1: 0.5231 +occupation: ?, Count: 389 +occupation: ?, Count: 389 +Precision: 0.7391 | Recall: 0.4048 | F1: 0.5231 +Precision: 0.6825 | Recall: 0.4479 | F1: 0.5409 +occupation: Adm-clerical, Count: 726 +occupation: Adm-clerical, Count: 726 +Precision: 0.6825 | Recall: 0.4479 | F1: 0.5409 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +occupation: Armed-Forces, Count: 3 +occupation: Armed-Forces, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 0.8500 | Recall: 0.3757 | F1: 0.5211 +occupation: Craft-repair, Count: 821 +occupation: Craft-repair, Count: 821 +Precision: 0.8500 | Recall: 0.3757 | F1: 0.5211 +Precision: 0.8147 | Recall: 0.7531 | F1: 0.7827 +occupation: Exec-managerial, Count: 838 +occupation: Exec-managerial, Count: 838 +Precision: 0.8147 | Recall: 0.7531 | F1: 0.7827 +Precision: 0.6667 | Recall: 0.2143 | F1: 0.3243 +occupation: Farming-fishing, Count: 193 +occupation: Farming-fishing, Count: 193 +Precision: 0.6667 | Recall: 0.2143 | F1: 0.3243 +Precision: 0.5714 | Recall: 0.3333 | F1: 0.4211 +occupation: Handlers-cleaners, Count: 273 +occupation: Handlers-cleaners, Count: 273 +Precision: 0.5714 | Recall: 0.3333 | F1: 0.4211 +Precision: 0.7647 | Recall: 0.2766 | F1: 0.4062 +occupation: Machine-op-inspct, Count: 378 +occupation: Machine-op-inspct, Count: 378 +Precision: 0.7647 | Recall: 0.2766 | F1: 0.4062 +Precision: 0.8333 | Recall: 0.1923 | F1: 0.3125 +occupation: Other-service, Count: 667 +occupation: Other-service, Count: 667 +Precision: 0.8333 | Recall: 0.1923 | F1: 0.3125 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +occupation: Priv-house-serv, Count: 26 +occupation: Priv-house-serv, Count: 26 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 0.7920 | Recall: 0.8061 | F1: 0.7990 +occupation: Prof-specialty, Count: 828 +occupation: Prof-specialty, Count: 828 +Precision: 0.7920 | Recall: 0.8061 | F1: 0.7990 +Precision: 0.7200 | Recall: 0.4286 | F1: 0.5373 +occupation: Protective-serv, Count: 136 +occupation: Protective-serv, Count: 136 +Precision: 0.7200 | Recall: 0.4286 | F1: 0.5373 +Precision: 0.7607 | Recall: 0.6458 | F1: 0.6986 +occupation: Sales, Count: 729 +occupation: Sales, Count: 729 +Precision: 0.7607 | Recall: 0.6458 | F1: 0.6986 +Precision: 0.7273 | Recall: 0.6275 | F1: 0.6737 +occupation: Tech-support, Count: 189 +occupation: Tech-support, Count: 189 +Precision: 0.7273 | Recall: 0.6275 | F1: 0.6737 +Precision: 0.8333 | Recall: 0.3125 | F1: 0.4545 +occupation: Transport-moving, Count: 317 +occupation: Transport-moving, Count: 317 +Precision: 0.8333 | Recall: 0.3125 | F1: 0.4545 +Precision: 0.7824 | Recall: 0.6701 | F1: 0.7219 +relationship: Husband, Count: 2590 +relationship: Husband, Count: 2,590 +Precision: 0.7824 | Recall: 0.6701 | F1: 0.7219 +Precision: 0.8684 | Recall: 0.3511 | F1: 0.5000 +relationship: Not-in-family, Count: 1702 +relationship: Not-in-family, Count: 1,702 +Precision: 0.8684 | Recall: 0.3511 | F1: 0.5000 +Precision: 1.0000 | Recall: 0.3750 | F1: 0.5455 +relationship: Other-relative, Count: 178 +relationship: Other-relative, Count: 178 +Precision: 1.0000 | Recall: 0.3750 | F1: 0.5455 +Precision: 1.0000 | Recall: 0.1765 | F1: 0.3000 +relationship: Own-child, Count: 1019 +relationship: Own-child, Count: 1,019 +Precision: 1.0000 | Recall: 0.1765 | F1: 0.3000 +Precision: 1.0000 | Recall: 0.2889 | F1: 0.4483 +relationship: Unmarried, Count: 702 +relationship: Unmarried, Count: 702 +Precision: 1.0000 | Recall: 0.2889 | F1: 0.4483 +Precision: 0.7405 | Recall: 0.6783 | F1: 0.7080 +relationship: Wife, Count: 322 +relationship: Wife, Count: 322 +Precision: 0.7405 | Recall: 0.6783 | F1: 0.7080 +Precision: 0.7143 | Recall: 0.5000 | F1: 0.5882 +race: Amer-Indian-Eskimo, Count: 71 +race: Amer-Indian-Eskimo, Count: 71 +Precision: 0.7143 | Recall: 0.5000 | F1: 0.5882 +Precision: 0.7619 | Recall: 0.7742 | F1: 0.7680 +race: Asian-Pac-Islander, Count: 193 +race: Asian-Pac-Islander, Count: 193 +Precision: 0.7619 | Recall: 0.7742 | F1: 0.7680 +Precision: 0.7551 | Recall: 0.5692 | F1: 0.6491 +race: Black, Count: 599 +race: Black, Count: 599 +Precision: 0.7551 | Recall: 0.5692 | F1: 0.6491 +Precision: 1.0000 | Recall: 0.6667 | F1: 0.8000 +race: Other, Count: 55 +race: Other, Count: 55 +Precision: 1.0000 | Recall: 0.6667 | F1: 0.8000 +Precision: 0.7891 | Recall: 0.6106 | F1: 0.6885 +race: White, Count: 5595 +race: White, Count: 5,595 +Precision: 0.7891 | Recall: 0.6106 | F1: 0.6885 +Precision: 0.7548 | Recall: 0.5021 | F1: 0.6031 +sex: Female, Count: 2126 +sex: Female, Count: 2,126 +Precision: 0.7548 | Recall: 0.5021 | F1: 0.6031 +Precision: 0.7912 | Recall: 0.6345 | F1: 0.7043 +sex: Male, Count: 4387 +sex: Male, Count: 4,387 +Precision: 0.7912 | Recall: 0.6345 | F1: 0.7043 +Precision: 0.7407 | Recall: 0.6452 | F1: 0.6897 +native-country: ?, Count: 125 +native-country: ?, Count: 125 +Precision: 0.7407 | Recall: 0.6452 | F1: 0.6897 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Cambodia, Count: 3 +native-country: Cambodia, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 0.7000 | Recall: 0.8750 | F1: 0.7778 +native-country: Canada, Count: 22 +native-country: Canada, Count: 22 +Precision: 0.7000 | Recall: 0.8750 | F1: 0.7778 +Precision: 0.8000 | Recall: 1.0000 | F1: 0.8889 +native-country: China, Count: 18 +native-country: China, Count: 18 +Precision: 0.8000 | Recall: 1.0000 | F1: 0.8889 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Columbia, Count: 6 +native-country: Columbia, Count: 6 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 1.0000 | Recall: 0.8000 | F1: 0.8889 +native-country: Cuba, Count: 19 +native-country: Cuba, Count: 19 +Precision: 1.0000 | Recall: 0.8000 | F1: 0.8889 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Dominican-Republic, Count: 8 +native-country: Dominican-Republic, Count: 8 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +native-country: Ecuador, Count: 5 +native-country: Ecuador, Count: 5 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: El-Salvador, Count: 20 +native-country: El-Salvador, Count: 20 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +native-country: England, Count: 14 +native-country: England, Count: 14 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: France, Count: 5 +native-country: France, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 0.8182 | Recall: 0.6923 | F1: 0.7500 +native-country: Germany, Count: 32 +native-country: Germany, Count: 32 +Precision: 0.8182 | Recall: 0.6923 | F1: 0.7500 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +native-country: Greece, Count: 7 +native-country: Greece, Count: 7 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Guatemala, Count: 13 +native-country: Guatemala, Count: 13 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Haiti, Count: 6 +native-country: Haiti, Count: 6 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Honduras, Count: 4 +native-country: Honduras, Count: 4 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 0.5000 | Recall: 1.0000 | F1: 0.6667 +native-country: Hong, Count: 8 +native-country: Hong, Count: 8 +Precision: 0.5000 | Recall: 1.0000 | F1: 0.6667 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +native-country: Hungary, Count: 3 +native-country: Hungary, Count: 3 +Precision: 1.0000 | Recall: 0.5000 | F1: 0.6667 +Precision: 0.8750 | Recall: 0.8750 | F1: 0.8750 +native-country: India, Count: 21 +native-country: India, Count: 21 +Precision: 0.8750 | Recall: 0.8750 | F1: 0.8750 +Precision: 0.5000 | Recall: 0.4000 | F1: 0.4444 +native-country: Iran, Count: 12 +native-country: Iran, Count: 12 +Precision: 0.5000 | Recall: 0.4000 | F1: 0.4444 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Ireland, Count: 5 +native-country: Ireland, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +native-country: Italy, Count: 14 +native-country: Italy, Count: 14 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +Precision: 0.0000 | Recall: 1.0000 | F1: 0.0000 +native-country: Jamaica, Count: 13 +native-country: Jamaica, Count: 13 +Precision: 0.0000 | Recall: 1.0000 | F1: 0.0000 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +native-country: Japan, Count: 11 +native-country: Japan, Count: 11 +Precision: 0.7500 | Recall: 0.7500 | F1: 0.7500 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +native-country: Laos, Count: 4 +native-country: Laos, Count: 4 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +Precision: 1.0000 | Recall: 0.3333 | F1: 0.5000 +native-country: Mexico, Count: 114 +native-country: Mexico, Count: 114 +Precision: 1.0000 | Recall: 0.3333 | F1: 0.5000 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Nicaragua, Count: 7 +native-country: Nicaragua, Count: 7 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +native-country: Peru, Count: 5 +native-country: Peru, Count: 5 +Precision: 1.0000 | Recall: 0.0000 | F1: 0.0000 +Precision: 1.0000 | Recall: 0.7500 | F1: 0.8571 +native-country: Philippines, Count: 35 +native-country: Philippines, Count: 35 +Precision: 1.0000 | Recall: 0.7500 | F1: 0.8571 +Precision: 0.3333 | Recall: 0.5000 | F1: 0.4000 +native-country: Poland, Count: 14 +native-country: Poland, Count: 14 +Precision: 0.3333 | Recall: 0.5000 | F1: 0.4000 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Portugal, Count: 6 +native-country: Portugal, Count: 6 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 0.8333 | Recall: 0.8333 | F1: 0.8333 +native-country: Puerto-Rico, Count: 22 +native-country: Puerto-Rico, Count: 22 +Precision: 0.8333 | Recall: 0.8333 | F1: 0.8333 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Scotland, Count: 3 +native-country: Scotland, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 0.3333 | Recall: 0.5000 | F1: 0.4000 +native-country: South, Count: 13 +native-country: South, Count: 13 +Precision: 0.3333 | Recall: 0.5000 | F1: 0.4000 +Precision: 0.8000 | Recall: 1.0000 | F1: 0.8889 +native-country: Taiwan, Count: 11 +native-country: Taiwan, Count: 11 +Precision: 0.8000 | Recall: 1.0000 | F1: 0.8889 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Thailand, Count: 5 +native-country: Thailand, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Trinadad&Tobago, Count: 3 +native-country: Trinadad&Tobago, Count: 3 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 0.7879 | Recall: 0.6056 | F1: 0.6848 +native-country: United-States, Count: 5870 +native-country: United-States, Count: 5,870 +Precision: 0.7879 | Recall: 0.6056 | F1: 0.6848 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Vietnam, Count: 5 +native-country: Vietnam, Count: 5 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 +native-country: Yugoslavia, Count: 2 +native-country: Yugoslavia, Count: 2 +Precision: 1.0000 | Recall: 1.0000 | F1: 1.0000 diff --git a/test_ml.py b/test_ml.py index 5f8306f14c..20ccb7e904 100644 --- a/test_ml.py +++ b/test_ml.py @@ -1,28 +1,51 @@ import pytest # TODO: add necessary import +import numpy as np +from sklearn.ensemble import RandomForestClassifier +import pandas as pd +from sklearn.model_selection import train_test_split, StratifiedKFold +from ml.model import train_model, compute_model_metrics -# TODO: implement the first test. Change the function name and input as needed -def test_one(): +# Implement the first test. Change the function name and input as needed +def test_model_type(): """ - # add description for the first test + # Test that train_model returns a RandomForestClassifier instance when trained on a simple dataset. """ - # Your code here - pass + X = pd.DataFrame({ + "feature1": [0, 1] * 10, + "feature2": [1, 0] * 10 + }) + y = pd.Series([0, 1] * 10) + X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, stratify=y) + model = train_model(X_train, y_train) + assert isinstance(model, RandomForestClassifier), "Model is not a RandomForestClassifier" -# TODO: implement the second test. Change the function name and input as needed -def test_two(): +# Implement the second test. Change the function name and input as needed +def test_compute_model_metrics_values(): """ - # add description for the second test + # Test that compute_model_metrics returns expected metric values (approximate). """ - # Your code here - pass + y_true = np.array([0, 1, 1, 1]) + y_pred = np.array([0, 1, 0, 1]) + precision, recall, f1 = compute_model_metrics(y_true, y_pred) + assert np.isclose(precision, 0.99, atol=1.0), "Precision out of expected range" + assert np.isclose(recall, 0.66, atol=1.0), "Recall out of expected range" + assert np.isclose(f1, 0.66, atol=1.0), "F1 out of expected range" + -# TODO: implement the third test. Change the function name and input as needed -def test_three(): + +# Implement the third test. Change the function name and input as needed +def test_training_input_shape(): """ - # add description for the third test + # Test that train_model handles correct input shape and type. """ - # Your code here - pass + X = pd.DataFrame({ + "feature1": [0, 1] * 10, + "feature2": [1, 0] * 10 + }) + y = pd.Series([0, 1] * 10) + X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, stratify=y) + model = train_model(X_train, y_train) + assert hasattr(model, "predict"), "Trained model does not have predict method" diff --git a/train_model.py b/train_model.py index ae783ed5b9..3664cff9e9 100644 --- a/train_model.py +++ b/train_model.py @@ -12,15 +12,15 @@ save_model, train_model, ) -# TODO: load the cencus.csv data -project_path = "Your path here" +# Load census.csv data +project_path = "/home/cjlsw/Deploying-a-Scalable-ML-Pipeline-with-FastAPI" data_path = os.path.join(project_path, "data", "census.csv") print(data_path) -data = None # your code here +data = pd.read_csv(data_path) # your code here -# TODO: split the provided data to have a train dataset and a test dataset +# Split the provided data to have a train dataset and a test dataset # Optional enhancement, use K-fold cross validation instead of a train-test split. -train, test = None, None# Your code here +train, test = train_test_split(data, test_size=0.20, random_state=42)# Your code here # DO NOT MODIFY cat_features = [ @@ -34,12 +34,12 @@ "native-country", ] -# TODO: use the process_data function provided to process the data. +# Use the process_data function provided to process the data. X_train, y_train, encoder, lb = process_data( - # your code here - # use the train dataset - # use training=True - # do not need to pass encoder and lb as input + train, # your code here + categorical_features=cat_features,# use the train dataset + label="salary",# use training=True + training=True# do not need to pass encoder and lb as input ) X_test, y_test, _, _ = process_data( @@ -51,36 +51,44 @@ lb=lb, ) -# TODO: use the train_model function to train the model on the training dataset -model = None # your code here +# Use the train_model function to train the model on the training dataset +model = train_model(X_train, y_train) # your code here # save the model and the encoder model_path = os.path.join(project_path, "model", "model.pkl") save_model(model, model_path) encoder_path = os.path.join(project_path, "model", "encoder.pkl") save_model(encoder, encoder_path) +lb_path = os.path.join(project_path, "model", "lb.pkl") +save_model(lb, lb_path) # load the model model = load_model( model_path ) -# TODO: use the inference function to run the model inferences on the test dataset. -preds = None # your code here +# Use the inference function to run the model inferences on the test dataset. +preds = inference(model, X_test) # your code here # Calculate and print the metrics p, r, fb = compute_model_metrics(y_test, preds) print(f"Precision: {p:.4f} | Recall: {r:.4f} | F1: {fb:.4f}") -# TODO: compute the performance on model slices using the performance_on_categorical_slice function +# Compute the performance on model slices using the performance_on_categorical_slice function # iterate through the categorical features for col in cat_features: # iterate through the unique values in one categorical feature for slicevalue in sorted(test[col].unique()): count = test[test[col] == slicevalue].shape[0] p, r, fb = performance_on_categorical_slice( - # your code here - # use test, col and slicevalue as part of the input + test, # your code here + col, # use test, col and slicevalue as part of the input + slicevalue, + cat_features, + label="salary", + encoder=encoder, + lb=lb, + model=model ) with open("slice_output.txt", "a") as f: print(f"{col}: {slicevalue}, Count: {count:,}", file=f)