{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "initial_id", "metadata": { "ExecuteTime": { "end_time": "2025-02-08T12:03:10.007903Z", "start_time": "2025-02-08T12:03:08.375866Z" } }, "outputs": [], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from catboost import Pool, CatBoostRegressor" ] }, { "cell_type": "markdown", "id": "698ef0c6-d381-40b3-93f8-74b6904ba9a4", "metadata": {}, "source": [ "### Can the task be solved in one quick pass?\n", "\n", "First, let's load the data and look at what we will be working with.\n", "\n", "In this competition, participants are given 4 sets of paired files (train and test):\n", "1. train_main.parquet (279 features + key)\n", "2. train_card_spending.parquet (630 features + key)\n", "3. train_mcc_operations.parquet (1640 features + key)\n", "4. train_mcc_preferences.parquet (2112 features + key)\n", "\n", "To keep things simple, let's try using only the first set (*_main.parquet*).\n", "\n", "#### Feature data (set 1 of 4):\n" ] }, { "cell_type": "code", "execution_count": 2, "id": "44cafd85-2c29-4d73-b7ac-09b6fa32907f", "metadata": {}, "outputs": [], "source": [ "train = pd.read_parquet('data/task3/train_main.parquet')\n", "test = pd.read_parquet('data/task3/test_main.parquet')" ] }, { "cell_type": "code", "execution_count": 3, "id": "e12cd977-b91b-42e1-a448-cf04d1d209e6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training data: (213345, 280)\n", "Test data: (318451, 280)\n" ] } ], "source": [ "print('Training data:', train.shape)\n", "print('Test data:', test.shape)" ] }, { "cell_type": "code", "execution_count": 4, "id": "7315a29f-f3ef-424b-8543-ee571f124d64", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " user_id app_children_cnt app_dependent_cnt app_family_cnt \\\n", "0 9 NaN NaN NaN \n", "1 11 NaN NaN NaN \n", "2 12 0.0 0.0 0.0 \n", "3 13 1.0 0.0 3.0 \n", "4 15 NaN NaN NaN \n", "\n", " app_income_app app_real_estate_ind app_vehicle_ind \\\n", "0 NaN NaN NaN \n", "1 NaN NaN NaN \n", "2 105372.960938 0.0 0.0 \n", "3 0.000000 0.0 0.0 \n", "4 NaN 0.0 0.0 \n", "\n", " avg_dep_avg_balance_12month_amt avg_dep_avg_balance_12month_amt_term \\\n", "0 NaN NaN \n", "1 NaN NaN \n", "2 315208.781250 NaN \n", "3 43187.953125 NaN \n", "4 NaN NaN \n", "\n", " avg_dep_avg_balance_12month_amt_term_savings ... \\\n", "0 4027.373535 ... \n", "1 NaN ... \n", "2 274816.375000 ... \n", "3 5277.233887 ... \n", "4 0.000000 ... \n", "\n", " savings_sum_oms_debet_3m savings_sum_oms_debet_6m \\\n", "0 34.613216 0.000000 \n", "1 6.237672 0.000000 \n", "2 0.000000 53.134129 \n", "3 0.000000 22.276114 \n", "4 0.000000 0.000000 \n", "\n", " savings_sum_oms_debet_9m savings_sum_oms_debet_12m \\\n", "0 4.310414 75.214180 \n", "1 0.000000 0.000000 \n", "2 90.025238 0.000000 \n", "3 82.070015 117.386795 \n", "4 0.000000 33.072178 \n", "\n", " savings_service_model_cd savings_pension_flg savings_deposit_flg \\\n", "0 Массовый 0 0 \n", "1 Массовый 0 0 \n", "2 Массовый 0 0 \n", "3 Массовый 0 0 \n", "4 Массовый 0 0 \n", "\n", " savings_safe_acc_flg savings_broker_flg savings_oms_flg \n", "0 1 0 0 \n", "1 1 0 0 \n", "2 1 0 0 \n", "3 1 0 0 \n", "4 1 0 0 \n", "\n", "[5 rows x 280 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.head(n = 5)" ] }, { "cell_type": "markdown", "id": "607e32ae-f33c-45f4-ae18-93338a822986", "metadata": {}, "source": [ "#### Особенности данных\n", "\n", "Сразу видно очень много пропусков. \n", "\n", "Наверное с пропусками даже придётся что-то делать, ведь где-то их 50% или больше." 
] }, { "cell_type": "code", "execution_count": 5, "id": "f0e01e42-4e99-4ab1-86e4-dabe2cd3730b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Columns with missing values:\n", " vehicle_counrty_type_nm 206739\n", "max_amt_foreign_cur_5y 201136\n", "max_amt_dep_6m 182526\n", "min_amt_term_g1y 180330\n", "max_amt_dep_act 175074\n", " ... \n", "cnt_foreign_cur_5y 26\n", "cnt_save_5y 26\n", "cnt_grow_5y 26\n", "cnt_term_g1y 26\n", "cnt_manage_5y 26\n", "Length: 167, dtype: int64\n" ] } ], "source": [ "missing_values = train.isnull().sum().sort_values(ascending = False)\n", "missing_values = missing_values[missing_values > 0]\n", "print('\\nColumns with missing values:\\n', missing_values)" ] }, { "cell_type": "markdown", "id": "19e0c8d0-60a1-47fe-93a6-a9d810471ead", "metadata": {}, "source": [ "I wonder whether the test data looks exactly the same.\n", "\n", "That is probably worth checking 🤔\n", "\n", "Some algorithms will be in an awkward spot if the test set has missing values in features that had none in the training data 👀" ] }, { "cell_type": "markdown", "id": "31e49b25-225b-4d5f-997d-923e81a7f3a3", "metadata": {}, "source": [ "#### Data types\n", "\n", "It is worth taking a high-level look at what is inside. The categorical features are of particular interest." ] }, { "cell_type": "code", "execution_count": 6, "id": "2975998a-e869-445d-ab7d-140eb7dd6c4d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "float64 266\n", "object 8\n", "int32 6\n", "Name: count, dtype: int64" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.dtypes.value_counts()" ] }, { "cell_type": "markdown", "id": "1605bdb7-a901-406a-b936-39a901be144b", "metadata": {}, "source": [ "The categorical features should be singled out and stored separately for CatBoost."
] }, { "cell_type": "code", "execution_count": 7, "id": "74f07ad6-0ef9-4368-9ad2-266ad9a0ac42", "metadata": {}, "outputs": [], "source": [ "features = train.columns\n", "\n", "categorical_features = train[features].select_dtypes(include=['object']).columns\n", "\n", "for feature in categorical_features:\n", "    train[feature] = train[feature].astype(str)\n", "\n", "categorical_features_indices = np.where(train.dtypes == 'object')[0]" ] }, { "cell_type": "markdown", "id": "d9a7e087-20a0-45e3-a94a-a1e2bc7be6d1", "metadata": {}, "source": [ "That is enough for a first attempt. The only thing still missing is the target variable.\n", "\n", "### Target variable\n", "\n" ] }, { "cell_type": "code", "execution_count": 8, "id": "25261de2-03f1-41d6-8830-a50bdc0dd061", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(213345, 2)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "target = pd.read_csv('data/task3/train_target.csv')\n", "target.shape" ] }, { "cell_type": "code", "execution_count": 9, "id": "beea57de-e043-44ce-b01e-ebd201017b84", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " user_id target\n", "0 9 0.00000\n", "1 11 0.00000\n", "2 12 219932.90625\n", "3 13 631.77002\n", "4 15 0.00000" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "target.head(5)" ] }, { "cell_type": "markdown", "id": "9e823601-0db6-485a-9f39-bb31f9ab6bb7", "metadata": {}, "source": [ "Мы решаем задачу регрессии. Из описания соревнования, нам требуется предсказать:\n", "\n", "> 50 перцентиль распределения суммарных остатков на всех накоп.ительных счетах клиента на горизонте +2 мес. от отчетной даты\n", "\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "bf91b2ba-0307-4b19-8420-f2beae248e53", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "count 2.133450e+05\n", "mean 2.210490e+05\n", "std 9.894988e+05\n", "min -7.100000e-01\n", "25% 0.000000e+00\n", "50% 3.174000e+01\n", "75% 1.000027e+05\n", "max 1.015605e+08\n", "Name: target, dtype: float64" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "target['target'].describe()" ] }, { "cell_type": "markdown", "id": "99e823e8-0f67-40df-b932-c5c4b603f6e0", "metadata": {}, "source": [ "Целевая переменная точно требует ее преобразовать. Посмотрим как она выглядит после log1p:" ] }, { "cell_type": "code", "execution_count": 11, "id": "f3ae0ce4-1295-4903-af59-b86ba61ff3ec", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10,4))\n", "\n", "plt.hist(np.log1p(target['target']), bins = 200);" ] }, { "cell_type": "markdown", "id": "a0b88862-bd0a-4493-9342-38f20407bf78", "metadata": {}, "source": [ "В распределении очень много нулей, так что стоит смотреть чуть уже:" ] }, { "cell_type": "code", "execution_count": 12, "id": "a7da6bae-7c27-4cc3-ba62-82d3a606a15f", "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10,4))\n", "\n", "plt.hist(np.log1p(target.loc[target['target'] > 0, 'target']), bins = 200);" ] }, { "cell_type": "markdown", "id": "38caffeb-d66a-4393-bfdc-9ece2ec720ac", "metadata": {}, "source": [ "Распределение выглядит очень похожим на смесь:\n", "- Есть клиенты с значениями около нуля\n", "- Есть компонента смеси с центром в районе 6, то есть ~400 (np.exp(6) - 1)\n", "- Есть компонента справа, с центров в районе 13, то есть ~440,000\n", "- И есть еще клиенты с ровно 0, которых мы убрали с графика\n", "\n", "Выглядит заманчиво и для ML, и визуализации. Но нас пока интересует только сабмит.\n", "\n", "### Catboost\n", "\n", "Начнём собирать всё что нам потребуется дя обучения Catboost-а.\n", "\n", "- Будем ли мы проверять, что порядок `user_id` полностью совпадает в train и target?\n", "- Будем ли мы сразу настраивать свою валидацию и делить данные?\n", "- Или может быть будем что-либо преобразовывать?\n", "\n", "Нет, нас интересует atboost сабмит ASAP 🤗️️️️️️ " ] }, { "cell_type": "code", "execution_count": 13, "id": "378306c2-1373-4868-96d2-0e9b54bcde9f", "metadata": {}, "outputs": [], "source": [ "train_pool = Pool(data = train, \n", " label = np.log1p(target['target']), \n", " cat_features = categorical_features_indices)" ] }, { "cell_type": "markdown", "id": "e35224b7-1329-47ce-8ec6-e69d0ae6a0a9", "metadata": {}, "source": [ "#### Обучение\n", "\n", "Главные настройки, которые нам стоит учесть:\n", "\n", "- Так как метрика соревнования это RMSLE, а мы уже логарифмировали (log1p) целевую переменную, оптимизировать мы будем RMSE\n", "- У нас много пропусков в данных, поэтому нам очень повезло что у Catboost есть настройка nan_mode" ] }, { "cell_type": "code", "execution_count": 14, "id": "ff688458-d150-42e8-99d1-b8aced7ffd61", "metadata": {}, "outputs": [], "source": [ "model = CatBoostRegressor(iterations = 100, \n", " depth = 6, \n", " learning_rate = 0.1, \n", " 
loss_function = 'RMSE', \n", " nan_mode = 'Min', \n", " random_seed = 314,\n", " verbose = 10)\n" ] }, { "cell_type": "code", "execution_count": 15, "id": "3cd7c516-80b2-438c-ac59-f14e181c09ee", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0:\tlearn: 5.1242446\ttotal: 128ms\tremaining: 12.7s\n", "10:\tlearn: 2.7903763\ttotal: 891ms\tremaining: 7.21s\n", "20:\tlearn: 2.2642536\ttotal: 1.65s\tremaining: 6.21s\n", "30:\tlearn: 2.1466213\ttotal: 2.33s\tremaining: 5.18s\n", "40:\tlearn: 2.1033289\ttotal: 3s\tremaining: 4.31s\n", "50:\tlearn: 2.0807207\ttotal: 3.63s\tremaining: 3.48s\n", "60:\tlearn: 2.0606817\ttotal: 4.26s\tremaining: 2.72s\n", "70:\tlearn: 2.0467002\ttotal: 4.91s\tremaining: 2s\n", "80:\tlearn: 2.0315319\ttotal: 5.6s\tremaining: 1.31s\n", "90:\tlearn: 2.0204322\ttotal: 6.22s\tremaining: 615ms\n", "99:\tlearn: 2.0125081\ttotal: 6.84s\tremaining: 0us\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit(train_pool)" ] }, { "cell_type": "markdown", "id": "fe974317-8a22-4a11-abf7-298245c3e3a9", "metadata": {}, "source": [ "We have successfully trained the model 🌟\n", "\n", "And really, why bother with validation when we can submit the model to the competition right away and see our result on the leaderboard? It won't be that much worse than in the training log, will it? (right?)\n", "\n", "### Preparing the submission\n", "\n", "Let's look at an example of a working baseline solution.\n", "\n", "This is exactly the format the platform expects from us:" ] }, { "cell_type": "code", "execution_count": 16, "id": "456c9f2e-e51c-4419-a402-871cfd1322d6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(318451, 2)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sample = pd.read_csv('data/task3/sample_submit_naive.csv')\n", "sample.shape" ] }, { "cell_type": "code", "execution_count": 17, "id": "da43b692-9ec7-415d-9c61-42ffaaf6cdf5", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " user_id predict\n", "0 1000008 1.004656e+06\n", "1 1000009 0.000000e+00\n", "2 1000013 5.047758e+02\n", "3 1000016 1.680799e+05\n", "4 1000017 2.222542e+02" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sample.head(5)" ] }, { "cell_type": "markdown", "id": "526becca-affc-439c-a378-15aa630705f3", "metadata": {}, "source": [ "С форматом тоже всё понятно. \n", "\n", "Важно заметить, что предсказания от нас ждут без преобразований целевой переменной, так что нужно будет сделать обратные преобразования предсказаний нашей модели.\n", "\n", "#### Использование модели \n", "\n", "Тестовые данные у нас уже есть, но их нужно подготовить для формата Catboost-а." ] }, { "cell_type": "code", "execution_count": 18, "id": "4ba55359-c12f-49af-9ddd-7cc50594f2d8", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " user_id app_children_cnt app_dependent_cnt app_family_cnt \\\n", "0 1000008 NaN NaN NaN \n", "1 1000009 0.0 NaN NaN \n", "2 1000013 0.0 0.0 0.0 \n", "3 1000016 0.0 NaN NaN \n", "4 1000017 NaN NaN NaN \n", "\n", " app_income_app app_real_estate_ind app_vehicle_ind \\\n", "0 NaN NaN NaN \n", "1 29125.394531 0.0 0.0 \n", "2 59536.816406 0.0 0.0 \n", "3 66908.468750 0.0 0.0 \n", "4 NaN NaN NaN \n", "\n", " avg_dep_avg_balance_12month_amt avg_dep_avg_balance_12month_amt_term \\\n", "0 998138.562500 2678.699219 \n", "1 0.000030 NaN \n", "2 54086.031250 NaN \n", "3 60340.105469 NaN \n", "4 0.000030 NaN \n", "\n", " avg_dep_avg_balance_12month_amt_term_savings ... \\\n", "0 1.009246e+06 ... \n", "1 NaN ... \n", "2 3.513455e+04 ... \n", "3 6.347482e+04 ... \n", "4 0.000000e+00 ... \n", "\n", " savings_sum_oms_debet_3m savings_sum_oms_debet_6m \\\n", "0 0.000000 0.000000 \n", "1 8.407050 54.111416 \n", "2 0.000000 0.000000 \n", "3 33.321732 59.461399 \n", "4 26.527195 0.000000 \n", "\n", " savings_sum_oms_debet_9m savings_sum_oms_debet_12m \\\n", "0 67.893509 0.000000 \n", "1 70.213890 82.739632 \n", "2 0.000000 56.554066 \n", "3 0.000000 0.000000 \n", "4 56.962830 59.217648 \n", "\n", " savings_service_model_cd savings_pension_flg savings_deposit_flg \\\n", "0 Массовый 0 0 \n", "1 Массовый 0 0 \n", "2 Массовый 0 0 \n", "3 Массовый 0 0 \n", "4 Массовый 0 0 \n", "\n", " savings_safe_acc_flg savings_broker_flg savings_oms_flg \n", "0 1 0 0 \n", "1 1 0 0 \n", "2 1 0 0 \n", "3 1 1 0 \n", "4 1 0 0 \n", "\n", "[5 rows x 280 columns]" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test.head(5)" ] }, { "cell_type": "code", "execution_count": 19, "id": "3a7a771e-ed7d-4aa9-bf43-cfaaec8a707a", "metadata": {}, "outputs": [], "source": [ "test_pool = Pool(data = test, \n", " cat_features = categorical_features_indices)" ] }, { "cell_type": "code", "execution_count": 20, "id": "684501d8-d52a-4cbf-8e68-fdd1f7f8e246", 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "(318451,)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_predict = model.predict(test_pool)\n", "test_predict.shape" ] }, { "cell_type": "markdown", "id": "bd893712-bddb-485a-adad-9ec65b0ed3e5", "metadata": {}, "source": [ "Обратные преобразования (не забываем -1):" ] }, { "cell_type": "code", "execution_count": 21, "id": "e19e0142-4df9-445b-9ecd-c6514ba27415", "metadata": {}, "outputs": [], "source": [ "test_full_predict = np.exp(test_predict) - 1" ] }, { "cell_type": "markdown", "id": "f3a99017-1335-4a4c-b883-d7d68c462e28", "metadata": {}, "source": [ "#### Упаковка сабмита\n", "\n", "Так как мы торопимся отправить решение, мы снова доверимся воле случая, что все 'user_id' отсортированы за нас 🌚️️️️️️\n", "\n", "И мы просто перепишем предсказания в исходном сабмит файле." ] }, { "cell_type": "code", "execution_count": 22, "id": "d70eb17b-caa6-4c05-9748-d918c3715171", "metadata": {}, "outputs": [], "source": [ "sample['predict'] = test_full_predict" ] }, { "cell_type": "code", "execution_count": 23, "id": "9f54656f-1491-4af6-aaa5-697155bf97a5", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " user_id predict\n", "0 1000008 290314.317272\n", "1 1000009 -0.030440\n", "2 1000013 11.244460\n", "3 1000016 271.858813\n", "4 1000017 12.335367" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sample.head(5)" ] }, { "cell_type": "markdown", "id": "f819060c-0fc8-46a0-8304-b58e85f657e8", "metadata": {}, "source": [ "Финишная прямая — пишем файл:\n" ] }, { "cell_type": "code", "execution_count": 24, "id": "0496c64c-87b0-4545-9e34-fb0f99f3354b", "metadata": {}, "outputs": [], "source": [ "sample.to_csv('submit_baseline_catboost.csv', index=False)" ] }, { "cell_type": "markdown", "id": "82e301c7-379c-4ec1-a861-477d2f2b86a6", "metadata": {}, "source": [ "И... результат на паблике это 4.169190539294088\t\n", "\n", "- Это лучше чем наивный сабмит с 5.848489205052006\n", "- Но это точно не метрика, показываемая в логах при обучении\n", "- И этот результат вряд ли пошёл бы в продакшен\n", "\n", "Можно ли это улучшить? " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 5 }