306 lines
9.6 KiB
Plaintext
306 lines
9.6 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# KNN Model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 1. Datensatz laden\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>Vertragsnummer</th>\n",
|
|
" <th>Jahr</th>\n",
|
|
" <th>Vertragsalter</th>\n",
|
|
" <th>Status</th>\n",
|
|
" <th>Marktzins</th>\n",
|
|
" <th>Tarifzins</th>\n",
|
|
" <th>Guthaben</th>\n",
|
|
" <th>Sparbeitrag</th>\n",
|
|
" <th>Beruf</th>\n",
|
|
" <th>Alter</th>\n",
|
|
" <th>p</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>10000</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>gekuendigt</td>\n",
|
|
" <td>8.7</td>\n",
|
|
" <td>8.7</td>\n",
|
|
" <td>3933.85</td>\n",
|
|
" <td>3619.0</td>\n",
|
|
" <td>angestellt</td>\n",
|
|
" <td>48</td>\n",
|
|
" <td>0.154465</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>9999</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>aktiv</td>\n",
|
|
" <td>8.7</td>\n",
|
|
" <td>8.7</td>\n",
|
|
" <td>523.93</td>\n",
|
|
" <td>482.0</td>\n",
|
|
" <td>angestellt</td>\n",
|
|
" <td>28</td>\n",
|
|
" <td>0.154465</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>9998</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>aktiv</td>\n",
|
|
" <td>8.7</td>\n",
|
|
" <td>8.7</td>\n",
|
|
" <td>7281.81</td>\n",
|
|
" <td>6699.0</td>\n",
|
|
" <td>selbstaendig</td>\n",
|
|
" <td>22</td>\n",
|
|
" <td>0.231475</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>9997</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>aktiv</td>\n",
|
|
" <td>8.7</td>\n",
|
|
" <td>8.7</td>\n",
|
|
" <td>8905.79</td>\n",
|
|
" <td>8193.0</td>\n",
|
|
" <td>angestellt</td>\n",
|
|
" <td>26</td>\n",
|
|
" <td>0.154465</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>9996</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>1</td>\n",
|
|
" <td>aktiv</td>\n",
|
|
" <td>8.7</td>\n",
|
|
" <td>8.7</td>\n",
|
|
" <td>8566.65</td>\n",
|
|
" <td>7881.0</td>\n",
|
|
" <td>angestellt</td>\n",
|
|
" <td>30</td>\n",
|
|
" <td>0.154465</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" Vertragsnummer Jahr Vertragsalter Status Marktzins Tarifzins \\\n",
|
|
"0 10000 1 1 gekuendigt 8.7 8.7 \n",
|
|
"1 9999 1 1 aktiv 8.7 8.7 \n",
|
|
"2 9998 1 1 aktiv 8.7 8.7 \n",
|
|
"3 9997 1 1 aktiv 8.7 8.7 \n",
|
|
"4 9996 1 1 aktiv 8.7 8.7 \n",
|
|
"\n",
|
|
" Guthaben Sparbeitrag Beruf Alter p \n",
|
|
"0 3933.85 3619.0 angestellt 48 0.154465 \n",
|
|
"1 523.93 482.0 angestellt 28 0.154465 \n",
|
|
"2 7281.81 6699.0 selbstaendig 22 0.231475 \n",
|
|
"3 8905.79 8193.0 angestellt 26 0.154465 \n",
|
|
"4 8566.65 7881.0 angestellt 30 0.154465 "
|
|
]
|
|
},
|
|
"execution_count": 37,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"\n",
|
|
"data = pd.read_csv('../3_lernfall/VertragsdatenABCAG.csv', sep=',', decimal='.')\n",
|
|
"data.head()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 2. Datenvorverarbeitung"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([[1.5611106971092041, -0.9658806326457345, 0.7729934357656066,\n",
|
|
" 0.19511754814402793, -0.47513581897659896, 1.389983329807085,\n",
|
|
" 0.5164540311483795, 0, 0],\n",
|
|
" [0.7340234916233305, 1.7968467730396942, 0.42492374363524343,\n",
|
|
" 1.422200684147216, 3.3738011771456753, 1.3529959365231317,\n",
|
|
" -0.16122757685231498, 0, 0],\n",
|
|
" [-1.74723812483429, -0.5712052889763876, 0.8165021472819023,\n",
|
|
" 1.422200684147216, -0.3993293859693223, -0.0075120485158652085,\n",
|
|
" 0.008192825147858662, 0, 0],\n",
|
|
" [0.3204798888803937, 0.6128207420316534, -0.09718079456030153,\n",
|
|
" -0.5988774222109761, 0.998529481910825, 1.1900420434702434,\n",
|
|
" -0.7541989838529227, 0, 1],\n",
|
|
" [-0.5066073166054799, 0.2181453983623064, -1.315424717016573,\n",
|
|
" 0.23120822861471013, -0.037511869651474364, -0.19977670605792403,\n",
|
|
" 1.0247152371489006, 0, 0]], dtype=object)"
|
|
]
|
|
},
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"from sklearn.preprocessing import StandardScaler\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"\n",
|
|
"# Stichprobe der Daten\n",
|
|
"data = data.sample(frac=0.3, random_state=11)\n",
|
|
"\n",
|
|
"# One hot encoding für die Spalten \"Beruf\" und \"Status\"\n",
|
|
"data = pd.get_dummies(data, columns=['Beruf', 'Status'], drop_first=True)\n",
|
|
"\n",
|
|
"# Zielvariable und Merkmale definieren\n",
|
|
"y = data['Status_gekuendigt']\n",
|
|
"X = data.drop(['Status_gekuendigt', 'Status_aktiv', 'Status_verstorben', 'Vertragsnummer', 'p'], axis=1)\n",
|
|
"\n",
|
|
"# Nur numerische Spalten skalieren, Dummy-Variablen nicht skalieren\n",
|
|
"num_cols = X.select_dtypes(include=['int64', 'float64']).columns\n",
|
|
"dummy_cols = X.select_dtypes(include=['bool']).columns\n",
|
|
"\n",
|
|
"scaler = StandardScaler()\n",
|
|
"X_num_scaled = scaler.fit_transform(X[num_cols])\n",
|
|
"X_scaled = X.copy()\n",
|
|
"X_scaled[num_cols] = X_num_scaled\n",
|
|
"X_scaled = X_scaled.values\n",
|
|
"\n",
|
|
"# Dummy-Variablen in 0 und 1 umwandeln\n",
|
|
"for col in dummy_cols:\n",
|
|
" X_scaled[:, X.columns.get_loc(col)] = X[col].astype(int)\n",
|
|
"\n",
|
|
"# Train-Test-Split\n",
|
|
"X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=22)\n",
|
|
"\n",
|
|
"# Ersten 5 Zeilen der Trainingsdaten anzeigen\n",
|
|
"X_train[:5]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 3. KNN Modell implementieren und MSE berechnen"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from sklearn.neighbors import KNeighborsClassifier\n",
|
|
"from sklearn.metrics import mean_squared_error\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"# Werte für k definieren\n",
|
|
"k_values = list(range(1, 101, 1))\n",
|
|
"train_mse = []\n",
|
|
"test_mse = []\n",
|
|
"\n",
|
|
"for k in k_values:\n",
|
|
" model = KNeighborsClassifier(n_neighbors=k)\n",
|
|
" model.fit(X_train, y_train)\n",
|
|
" y_train_pred = model.predict_proba(X_train)\n",
|
|
" y_test_pred = model.predict_proba(X_test)\n",
|
|
" train_mse.append(mean_squared_error(y_train, y_train_pred[:, 1]))\n",
|
|
" test_mse.append(mean_squared_error(y_test, y_test_pred[:, 1]))\n",
|
|
"\n",
|
|
"# mean var (p*(1-p)).mean()\n",
|
|
"# oder: np.var(y - p)\n",
|
|
"\n",
|
|
"# Ergebnisse visualisieren\n",
|
|
"plt.figure(figsize=(10, 6))\n",
|
|
"plt.plot(k_values, train_mse, label='Train MSE', marker='o')\n",
|
|
"plt.plot(k_values, test_mse, label='Test MSE', marker='o')\n",
|
|
"\n",
|
|
"# vertikale Linie für den besten k-Wert einzeichnen\n",
|
|
"min_test_mse = min(test_mse)\n",
|
|
"best_k = k_values[test_mse.index(min_test_mse)]\n",
|
|
"plt.axvline(x=best_k, color='red', linestyle='--', label=f'Best k = {best_k}')\n",
|
|
"\n",
|
|
"plt.title('MSE für verschiedene Werte von k')\n",
|
|
"plt.xlabel('Wert von k')\n",
|
|
"plt.ylabel('Mean Squared Error')\n",
|
|
"plt.xscale('log')\n",
|
|
"plt.legend()\n",
|
|
"plt.grid()\n",
|
|
"plt.show()"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "hft_ml",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.10"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|