import pandas as pd

definitions = pd.read_json('data/variables_descriptions.json')
definitions.head()
   name             description
0  age              Age of the individual in years
1  gender           Biological sex (0 = Female, 1 = Male)
2  education_years  Total years of formal education completed
3  income_level     Socioeconomic status on an ordinal scale (1 = ...
4  smoker           Indicates whether the individual has a history...
In this notebook, we will focus on predicting lung cancer risk based on patient data. Our approach includes:
1. Exploratory Data Analysis (EDA): examining data distributions and correlations among features.
2. Feature Engineering & Selection: handling multicollinearity using the Variance Inflation Factor (VIF).
3. Supervised Modeling: training Logistic Regression and Lasso models.
4. Model Evaluation & Interpretation: comparing models using precision, recall, and F1-score, and explaining model predictions using SHAP values.
1. Data Preparation
First, we split our dataset into training and testing sets. We use a stratified split on the target variable lung_cancer_risk to ensure both sets have a representative proportion of the positive class.
Code
from sklearn.model_selection import train_test_split

target_label = 'lung_cancer_risk'
y = df_lung[target_label].reset_index(drop=True)
X = df_lung.drop(target_label, axis=1).reset_index(drop=True)

# stratify on the target so both sets keep the positive-class ratio (test size assumed)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=seed
)
2. Exploratory Data Analysis (EDA)
We visualize the correlations between all features. Highly correlated features can introduce multicollinearity, which might destabilize linear models like Logistic Regression.
Code
from statsmodels_utils import plot_correlation_heatmap

plot_correlation_heatmap(X, 'Correlation Heatmap all features')
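Beyond eyeballing the heatmap, the highly correlated pairs can be extracted programmatically. The sketch below is a minimal, self-contained version using synthetic stand-ins for the smoking-related columns (the real df_lung is not reproduced here, and the threshold of 0.8 is an assumption):

```python
import numpy as np
import pandas as pd

def high_corr_pairs(X: pd.DataFrame, threshold: float = 0.8):
    """Return feature pairs whose absolute Pearson correlation exceeds threshold."""
    corr = X.corr().abs()
    cols = corr.columns
    return [
        (cols[i], cols[j], round(corr.iloc[i, j], 2))
        for i in range(len(cols))
        for j in range(i + 1, len(cols))
        if corr.iloc[i, j] > threshold
    ]

# Synthetic data: pack_years is built to be nearly redundant with smoking_years
rng = np.random.default_rng(0)
years = rng.uniform(0, 40, 500)
df = pd.DataFrame({
    "smoking_years": years,
    "pack_years": years + rng.normal(0, 3, 500),
    "cigarettes_per_day": rng.uniform(0, 30, 500),
    "age": rng.uniform(20, 80, 500),
})
print(high_corr_pairs(df))
```

On this toy data only the (smoking_years, pack_years) pair clears the threshold, which is the kind of overlap the VIF step below quantifies more rigorously.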
Let’s take a closer look at the relationships between smoking-related variables, as they are typically highly correlated.
As seen in the correlation heatmap and pair plot, there are strong correlations among certain features (like pack_years and smoking_years). This indicates a potential multicollinearity issue, which we can quantify using the Variance Inflation Factor (VIF). We will recursively calculate VIF and drop features that exceed a threshold to ensure model stability.
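The recursive drop loop that produces the log below can be sketched in a self-contained way. This is a simplified stand-in for the notebook's helper (VIF computed directly via least squares rather than statsmodels, synthetic data, and a threshold of 5 are all assumptions):

```python
import numpy as np
import pandas as pd

def vif(X: pd.DataFrame) -> pd.Series:
    """VIF for each column: regress it on the remaining columns plus an intercept."""
    scores = {}
    for col in X.columns:
        y = X[col].to_numpy()
        others = X.drop(columns=col).to_numpy()
        A = np.column_stack([np.ones(len(X)), others])
        beta, *_ = np.linalg.lstsq(A, y, rcond=None)
        r2 = 1 - (y - A @ beta).var() / y.var()
        scores[col] = 1.0 / max(1 - r2, 1e-12)
    return pd.Series(scores)

def drop_high_vif(X: pd.DataFrame, threshold: float = 5.0) -> pd.DataFrame:
    """Iteratively drop the worst-VIF feature until all scores fall below threshold."""
    X = X.copy()
    while True:
        scores = vif(X)
        worst = scores.idxmax()
        if scores[worst] < threshold:
            return X
        print(f"[DROPPED] {worst!r} (VIF {scores[worst]:.2f})")
        X = X.drop(columns=worst)

# Toy demo: column "c" is an almost exact linear combination of "a" and "b"
rng = np.random.default_rng(0)
a, b = rng.normal(size=300), rng.normal(size=300)
df = pd.DataFrame({"a": a, "b": b,
                   "c": a + b + rng.normal(0, 0.1, 300),
                   "d": rng.normal(size=300)})
kept = drop_high_vif(df)
print("Kept:", kept.columns.tolist())
```

Dropping one feature per iteration matters: removing "c" alone also collapses the inflated scores of "a" and "b", which is exactly the "watch these scores drop in the next loop" behavior in the log.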
Starting VIF Multicollinearity Check...
----------------------------------------
[DROPPED] 'pack_years'
* VIF Score: 35.66 (Overlap: 97.2%)
* Primary Suspects (Highest pairwise correlations with this feature):
- fev1_x10 (0.95)
- crp_level (0.92)
- oxygen_saturation (0.91)
* Current Runner-Ups (Watch these scores drop in the next loop!):
- cigarettes_per_day: 14.60
- smoking_years: 13.52
- smoker: 12.77
------------------------------
[DROPPED] 'cigarettes_per_day'
* VIF Score: 10.40 (Overlap: 90.4%)
* Primary Suspects (Highest pairwise correlations with this feature):
- smoker (0.81)
- fev1_x10 (0.79)
- crp_level (0.79)
* Current Runner-Ups (Watch these scores drop in the next loop!):
- smoker: 10.22
- smoking_years: 9.84
- fev1_x10: 7.19
------------------------------
[DROPPED] 'fev1_x10'
* VIF Score: 6.33 (Overlap: 84.2%)
* Primary Suspects (Highest pairwise correlations with this feature):
- crp_level (0.88)
- oxygen_saturation (0.86)
- smoking_years (0.77)
* Current Runner-Ups (Watch these scores drop in the next loop!):
- crp_level: 5.36
- smoking_years: 5.03
- oxygen_saturation: 4.73
------------------------------
----------------------------------------
VIF Check Complete!
Features dropped: 3
4. Modeling: Logistic Regression
We will start by training a Logistic Regression model (Logit) from statsmodels. To select the most relevant features, we employ backward elimination, iteratively removing the features whose p-values indicate they do not contribute significantly to the model.
Code
# fit a statsmodels Logit and select features by backward elimination
from statsmodels_utils import stepwise_selection, backward_elimination
import statsmodels.api as sm

logit_model = sm.Logit
final_model_logit, features = backward_elimination(
    logit_model, X_train_clean, y_train, disp=False, warn_convergence=False
)
print(f"Optimal features: {features}")
final_model_logit.summary()
Possibly complete quasi-separation: A fraction 0.67 of observations can be perfectly predicted. This might indicate that there is complete quasi-separation. In this case some parameters will not be identified.
The impressive Recall score gives us confidence that the model is performing well, particularly in identifying true positive cases. In medical diagnoses such as lung cancer prediction, minimizing false negatives (i.e., missing a cancer diagnosis) is critical, making Recall a primary metric.
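To make the precision/recall trade-off concrete, here is a toy comparison on hypothetical predictions (not the notebook's actual results): model A misses one cancer case, model B catches them all at the cost of one false alarm.

```python
from sklearn.metrics import precision_score, recall_score, f1_score

y_true  = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
model_a = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]  # one false negative (a missed diagnosis)
model_b = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]  # one false positive (a false alarm)

for name, pred in [("A", model_a), ("B", model_b)]:
    print(name,
          f"precision={precision_score(y_true, pred):.2f}",
          f"recall={recall_score(y_true, pred):.2f}",
          f"f1={f1_score(y_true, pred):.2f}")
```

Model A has perfect precision but recall 0.75; model B trades precision 0.80 for recall 1.00. In this clinical setting, model B's profile is preferable, which is why Recall is the headline metric here.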
Overall, we have reached a good model that generalizes well to the unseen test data. Next, we will compare it against a Lasso (L1 Regularization) model to evaluate differences in metrics and feature selection.
Code
from sklearn.utils import compute_sample_weight
from sklearn.linear_model import Lasso

alpha = 1.0
weights = compute_sample_weight(class_weight='balanced', y=y_train)
clf = Lasso(alpha=alpha, random_state=seed)
clf.fit(X_train, y_train, sample_weight=weights)

clf_selected_features = X_train.columns[clf.coef_ != 0]
y_pred_lasso = clf.predict(X_test[X_train.columns])
y_pred_lasso = (y_pred_lasso > 0.5).astype(int)

print("Lasso selected features: ", clf_selected_features.tolist())
print("Stepwise selected features: ", features)
print("Common features: ", [f for f in features if f in clf_selected_features])
print("In Lasso but not in Logit: ", clf_selected_features.difference(features).tolist())
print("In Logit but not in Lasso: ", list(set(features).difference(clf_selected_features)))
Lasso is highly aggressive with its L1 penalty and selected only 3 features: smoking_years, cigarettes_per_day, and pack_years. In contrast, the stepwise selection retained 17 features and correctly avoided pack_years (which we had previously dropped due to high VIF).
The Lasso model performed significantly worse in detecting positive cases. It skipped several cases of lung cancer, resulting in a much lower Recall.
In a life-threatening context, achieving a high Recall score is crucial, and sacrificing a small amount of precision is an acceptable trade-off. Lasso did not outperform the Logit model (with stepwise selection) on Recall or F1-score, so we will proceed with our Logistic Regression model.
6. Saving Model Artifacts
We will serialize and save the coefficients of our chosen Logistic Regression model, along with the list of selected features. This will allow us to reproduce the scoring and deploy the model for future predictions.
Code
# save the coefficients for later predictions
import json
import os

out_dir = '../frontend/public'
ols_filename = 'ols_weights'
coeffs = final_model_logit.params.tolist()

os.makedirs(out_dir, exist_ok=True)
with open(f"{out_dir}/{ols_filename}.json", 'w') as f:
    json.dump(coeffs, f, indent=4)
print(f"Model coefficients saved to '{out_dir}/{ols_filename}.json'")

# save the features and their order for reference
features_filename = f"{ols_filename}_features.json"
with open(f"{out_dir}/{features_filename}", 'w') as f:
    # add the const as the first feature, matching the coefficient order
    json.dump(['const', *features], f, indent=4)
print(f"Features saved to '{out_dir}/{features_filename}'")
Model coefficients saved to '../frontend/public/ols_weights.json'
Features saved to '../frontend/public/ols_weights_features.json'
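Reproducing the scoring from these artifacts only requires the logistic link: the prediction is the sigmoid of the intercept plus the dot product of the feature vector with the remaining coefficients. A minimal sketch, using hypothetical weights rather than the saved ones:

```python
import numpy as np

def score(features: np.ndarray, coeffs: np.ndarray) -> np.ndarray:
    """Logistic score: sigmoid of the linear predictor.

    coeffs[0] is the intercept, matching the 'const' entry saved first
    in the features file; coeffs[1:] follow the saved feature order.
    """
    z = coeffs[0] + features @ coeffs[1:]
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical weights: intercept -2.0, then two feature coefficients
coeffs = np.array([-2.0, 0.8, 1.5])
patients = np.array([
    [0.0, 0.0],  # baseline case -> sigmoid(-2.0)
    [1.0, 1.0],  # both risk features present -> sigmoid(0.3)
])
print(score(patients, coeffs))
```

This is why the feature order is serialized alongside the coefficients: the frontend must assemble each patient's feature vector in exactly the saved order before taking the dot product.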
7. Model Interpretability with SHAP
To understand how our Logistic Regression model makes its predictions, we will analyze the impact of individual features using SHAP (SHapley Additive exPlanations). This allows for both global and local interpretability.
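For a linear model with (approximately) independent features, SHAP values have a closed form: the attribution of feature j on sample i is the coefficient times that feature's deviation from its mean. This is the assumption behind SHAP's linear explainer; the numpy-only sketch below (synthetic data, hypothetical coefficients) illustrates it without the shap library:

```python
import numpy as np

def linear_shap(X: np.ndarray, coef: np.ndarray) -> np.ndarray:
    """SHAP values for a linear model with independent features:
    phi[i, j] = coef[j] * (X[i, j] - mean_j), i.e. each feature's
    contribution relative to the average prediction (the base value)."""
    return coef * (X - X.mean(axis=0))

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
coef = np.array([2.0, -1.0, 0.0])  # third feature has no effect
phi = linear_shap(X, coef)

# Additivity check: SHAP values sum to prediction minus base value
# (here on the linear-predictor scale, i.e. log-odds for a Logit)
pred = X @ coef
print(np.allclose(phi.sum(axis=1), pred - pred.mean()))
```

Note the attributions here live on the log-odds scale of the Logit; a zero coefficient yields identically zero SHAP values, which is why eliminated features never appear in the plots below.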
Code
from statsmodels_utils import plot_coef

features_impact_image = f'{plot_folder}/features_impact.jpg'
plot_coef(final_model_logit, plot_file=features_impact_image)
Plot saved successfully to: ../frontend/public/plots/features_impact.jpg
Global Explanations: SHAP Summary
The SHAP summary plot gives us a bird’s-eye view of feature importance and the direction of the relationship between feature values and the prediction.
Plot saved successfully to: ../frontend/public/plots/shap_summary.jpg
As seen in the plots above, the top three most influential features (crp_level, oxygen_saturation, and xray_abnormal) are all derived from clinical laboratory tests and imaging. This reinforces the clinical validity of the model, as these objective medical measurements strongly drive the predictions.
The SHAP bar plot aggregates the mean absolute SHAP values, providing a clear ranking of overall feature importance across the entire dataset.
Code
from statsmodels_utils import plot_shap_bar

shap_feature_importance_image = f'{plot_folder}/shap_feature_importance.jpg'
plot_shap_bar(final_model_logit, with_constant, plot_file=shap_feature_importance_image)
Plot saved successfully to: ../frontend/public/plots/shap_feature_importance.jpg
Local Explanations: Single Prediction
We can also use SHAP to explain individual predictions. Let’s look at a specific patient (e.g., index 99) to see how their specific feature values pushed the model’s prediction probability above or below the base value.
Code
from statsmodels_utils import plot_shap_waterfall_single

index_to_plot = 99
predicted_probability = final_model_logit.predict(with_constant.iloc[[index_to_plot]])
label = f"This case {'has' if predicted_probability.iloc[0] > 0.5 else 'does not have'} cancer."
single_explanation_image = f'{plot_folder}/single_explanation.jpg'
plot_shap_waterfall_single(
    final_model_logit,
    with_constant,
    row_index=index_to_plot,
    title=f"{label} Single explanation",
    plot_file=single_explanation_image,
)
Plot saved successfully to: ../frontend/public/plots/single_explanation.jpg
For the patient at index 99, we can see exactly how the model arrived at its positive prediction. Notably, having an abnormal X-ray (xray_abnormal = 1) is a primary driver, increasing the log-odds of the prediction. Other clinical features and medical history also significantly push the risk higher.
Feature Dependence
SHAP dependence plots show the marginal effect of a feature on the prediction. Let’s visualize how the model’s output depends on specific variables like age.
Code
from statsmodels_utils import plot_shap_dependence

shap_dependence_image = f'{plot_folder}/shap_dependence.jpg'
plot_shap_dependence(final_model_logit, with_constant, 'age', plot_file=shap_dependence_image)

Plot saved successfully to: ../frontend/public/plots/shap_dependence.jpg
Code
from statsmodels_utils import plot_top_shap_dependence

plot_shap_dependence_image = f'{plot_folder}/shap_dependence_top.jpg'
plot_top_shap_dependence(final_model_logit, with_constant, top_n=3, plot_file=plot_shap_dependence_image)