# 1) Separación de los datos: Conjunto de entrenamiento y prueba.
# Importamos train_test_split para dividir los datos en conjuntos de entrenamiento y prueba
from sklearn.model_selection import train_test_split
X_train_desafio, X_test_desafio, y_train_desafio, y_test_desafio = train_test_split(X_desafio,y_desafio, stratify = y_desafio, random_state=5)
# 2) Modelo base con el DummyClassifier y tasa de acierto con el método score.
# Cargamos DummyClassifier, un modelo de referencia muy simple
from sklearn.dummy import DummyClassifier
# Instanciamos el modelo Dummy
dummy_desafio = DummyClassifier()
# Entrenamos el modelo Dummy con los datos de entrenamiento
dummy_desafio.fit(X_train_desafio, y_train_desafio)
# Evaluamos el modelo Dummy con los datos de prueba
dummy_desafio.score(X_test_desafio, y_test_desafio)
# Resultado: 0.7964
# 3) Árbol de decisión.
from sklearn.tree import DecisionTreeClassifier
# Instanciamos el modelo con una semilla fija para asegurar reproducibilidad
modelo_arbol_desafio = DecisionTreeClassifier(max_depth=4, random_state=5)
# Entrenamos el modelo con los datos de entrenamiento
modelo_arbol_desafio.fit(X_train_desafio, y_train_desafio)
# Evaluamos el modelo con el conjunto de prueba
modelo_arbol_desafio.score(X_test_desafio,y_test_desafio)
# Resultado: 0.8464
# Visualización del árbol de decisión entrenado
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
# Lista con los nombres de las columnas que se usarán en el gráfico para mejor interpretación
valores_columnas_desafio = ['pais_Alemania',
'pais_España',
'pais_Francia',
'sexo_biologico_Mujer',
'tiene_tarjeta_credito_1',
'miembro_activo_1',
'score_credito',
'edad',
'años_de_cliente',
'saldo',
'servicios_adquiridos',
'salario_estimado']
# Graficamos el árbol
plt.figure(figsize=(24,7))
plot_tree(modelo_arbol_desafio, filled=True,class_names=['no','si'],fontsize=7,feature_names=valores_columnas_desafio);
# Evaluamos nuevamente el modelo pero ahora en entrenamiento
round(modelo_arbol_desafio.score(X_train_desafio,y_train_desafio),4)
# Resultado: 0.8509