python - Problem with SHAP plot for multiclass problem - Stack Overflow

I have the following code:

from xgboost import XGBClassifier
import shap
from sklearn.preprocessing import LabelEncoder

# Encode the labels for multiclass classification
label_encoder = LabelEncoder()
y_enc = label_encoder.fit_transform(y)

# Train an XGBoost model
model = XGBClassifier(objective="multi:softprob", num_class=3)
model.fit(X, y_enc)

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
classes = label_encoder.inverse_transform(range(
                            len(label_encoder.classes_)))
shap.summary_plot(shap_values, X, class_names=classes, plot_type='barh')

but the plot is shown like this:

[image: the resulting plot]

Any idea what went wrong?

asked Nov 19, 2024 at 15:31 by Zidan; edited Nov 19, 2024 at 15:35 by 0stone0
  • You might like to describe how you would like the table to look and what efforts you have made to fix the problem. – Tony Williams Commented Nov 20, 2024 at 3:05
1 Answer

I made some dummy data to try out this code:

import numpy as np
import pandas as pd
from xgboost import XGBClassifier
import shap
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

# Generate a sample dataset
np.random.seed(42)

# Create random features (100 samples, 5 features)
X = pd.DataFrame(np.random.randn(100, 5), columns=['feat1', 'feat2', 'feat3', 'feat4', 'feat5'])

# Create random labels for 3 classes
y = np.random.choice(['class_0', 'class_1', 'class_2'], size=100)

# Encode the labels for multiclass classification
label_encoder = LabelEncoder()
y_enc = label_encoder.fit_transform(y)

# Train-test split (80% train, 20% test)
X_train, X_test, y_train, y_test = train_test_split(X, y_enc, test_size=0.2, random_state=42)

Your code is fine. I think there is probably something wrong with your X or y data, which I cannot check.
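One further thing that may be worth ruling out, stated here as an assumption rather than a confirmed diagnosis: depending on the installed shap version, `explainer.shap_values(X)` for a multiclass model may come back either as a list of per-class 2-D arrays or as a single 3-D array, and the plotting code may treat the two differently. The sketch below uses only NumPy and a random stand-in array (not real SHAP output) to show how the shape could be inspected and normalized. The helper names `per_class_shap` and `mean_abs_importance` are hypothetical, and the 3-D layout `(n_samples, n_features, n_classes)` is assumed.

```python
import numpy as np


def per_class_shap(shap_values):
    """Return a list of (n_samples, n_features) arrays, one per class.

    Hypothetical helper. Accepts a list of 2-D arrays, a 3-D array assumed
    to be laid out as (n_samples, n_features, n_classes), or one 2-D array.
    """
    if isinstance(shap_values, list):
        return [np.asarray(v) for v in shap_values]
    arr = np.asarray(shap_values)
    if arr.ndim == 3:
        # Assumed layout: the last axis indexes the classes.
        return [arr[:, :, k] for k in range(arr.shape[2])]
    if arr.ndim == 2:
        return [arr]
    raise ValueError(f"unexpected SHAP array with {arr.ndim} dimensions")


def mean_abs_importance(shap_values):
    """Mean |SHAP| per feature and class, shape (n_features, n_classes)."""
    per_class = per_class_shap(shap_values)
    return np.stack([np.abs(v).mean(axis=0) for v in per_class], axis=1)


# Random stand-in for SHAP output: 20 samples, 5 features, 3 classes.
rng = np.random.default_rng(0)
fake_shap = rng.normal(size=(20, 5, 3))

per_class = per_class_shap(fake_shap)
importance = mean_abs_importance(fake_shap)
print(len(per_class), per_class[0].shape, importance.shape)
# prints: 3 (20, 5) (5, 3)
```

With real data, printing `type(shap_values)` and, for an array, `shap_values.shape` shows which case applies. If it is 3-D, passing `per_class_shap(shap_values)` to `shap.summary_plot` with `plot_type='bar'` might be worth trying; this has not been verified against the asker's data.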
