﻿//#include "example.h"
//#include "shared/logger/logger.h"
//#include "database/sql_func.h"
//#include "database/connect.h"
//#include <iostream>
#include "utils.h"
#include "xgboost/learner.h"
#include "xgboost/c_api.h"

//#include <xgboost/data.h>
//#include <xgboost/logging.h>

#include <iostream>
#include <QtCore>

//#define log_error_m   alog::logger().error  (__FILE__, __func__, __LINE__, "XGBoost")
//#define log_warn_m    alog::logger().warn   (__FILE__, __func__, __LINE__, "XGBoost")
//#define log_info_m    alog::logger().info   (__FILE__, __func__, __LINE__, "XGBoost")
//#define log_verbose_m alog::logger().verbose(__FILE__, __func__, __LINE__, "XGBoost")
//#define log_debug_m   alog::logger().debug  (__FILE__, __func__, __LINE__, "XGBoost")
//#define log_debug2_m  alog::logger().debug2 (__FILE__, __func__, __LINE__, "XGBoost")

#define CHECK_RESULT \
    if (res)         \
        return 1;    \

using namespace std;
using namespace task;

int main(int argc, char *argv[])
{
    QMap<QString, QString> xgbOptions;
    // Параметр n_estimators
    QVector<qint16> estim = /*{500};*/{ 50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600 };
    // Параметр max_depth
    QVector<qint16> max_depth = /*{10};*/{ 3, 5, 7, 9, 11 };
    // Параметр learning_rate
    QVector<double> learning_rate = /*{0.1};*/{ 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5 };

    float _thresholdStep = 0.01;
    int _iterCount = 10;
    double _diffLimit = 0.025;

    QFile reportFile("report.csv");
    //reportFile.open(QIODevice::WriteOnly | QIODevice::Text);
    reportFile.open(QIODevice::WriteOnly);
    QTextStream filestream(&reportFile);
    filestream.setCodec("UTF8");

    // ЭКМП
    xgbOptions.insert("silent",           "false"           );
    xgbOptions.insert("objective",        "binary:logistic" );
    xgbOptions.insert("nthread",          "6"               );
    xgbOptions.insert("gamma",            "0.01"            );
    xgbOptions.insert("scale_pos_weight", "0.5"             );
    xgbOptions.insert("min_child_weight", "2"               );
    xgbOptions.insert("max_delta_step",   "0"               );
    xgbOptions.insert("subsample",        "0.8"             );
    xgbOptions.insert("colsample_bytree", "0.7"             );
    xgbOptions.insert("base_score",       "0.5"             );
    xgbOptions.insert("seed",             "0"               );
    xgbOptions.insert("missing",          "None"            );

    // МЭЭ
//    xgbOptions.insert("silent",            "false"           );
//    xgbOptions.insert("objective",         "binary:logistic" );
//    xgbOptions.insert("nthread",           "6"               );
//    xgbOptions.insert("gamma",             "0.01"            );
//    xgbOptions.insert("min_child_weight",  "1"               );
//    xgbOptions.insert("max_delta_step",    "0"               );
//    xgbOptions.insert("subsample",         "0.85"            );
//    xgbOptions.insert("colsample_bytree",  "0.75"            );
//    xgbOptions.insert("base_score",        "0.5"             );
//    xgbOptions.insert("seed",              "0"               );
//    xgbOptions.insert("missing",           "None"            );

    QFile fileT("/home/egorov_vn/input_data/2ecba520-b5a5-48f9-9b32-dbcfc4145be1_ekmp_train.dump");
    fileT.open(QIODevice::ReadOnly | QIODevice::Text);
    QVector<QVector<float>> dataListT;

    bool first = true;
    while (!fileT.atEnd())
    {
        QString line = QString(fileT.readLine());
        if (first)
        {
            first = false;
            continue;
        }
        QVector<QString> vec = line.split(";").toVector();
        QVector<float> fvec;
        for(QString &item : vec)
        {
            fvec.push_back(item.toFloat());
        }
        dataListT.push_back(fvec);
    }
    fileT.close();

    QFile fileA("/home/egorov_vn/input_data/2ecba520-b5a5-48f9-9b32-dbcfc4145be1_ekmp_apply.dump");
    fileA.open(QIODevice::ReadOnly | QIODevice::Text);
    QVector<QVector<float>> dataListA;

    first = true;
    while (!fileA.atEnd())
    {
        QString line = QString(fileA.readLine());
        if (first)
        {
            first = false;
            continue;
        }
        QVector<QString> vec = line.split(";").toVector();
        QVector<float> fvec;
        for(QString &item : vec)
        {
            fvec.push_back(item.toFloat());
        }
        dataListA.push_back(fvec);
    }
    fileA.close();

    // Матрица входных данных
    Ret2DArray trainDataT;
    qint32 rowsT = dataListT.count();
    qint32 colsT = dataListT[0].count() - 1;
    trainDataT.alloc(rowsT, colsT);

    for (int i = 0; i < rowsT; ++i)
    {
        float* r1 = (float*)dataListT[i].constData();
        float* r2 = trainDataT.ptr() + i * trainDataT.columns();
        for (int j = 0; j < colsT; ++j)
            *r2++ = *r1++;
    }

    // Матрица меток
    Ret2DArray labelsDataT;
    labelsDataT.alloc(rowsT, 1);

    for (int i = 0; i < rowsT; ++i)
    {
        float lbl = dataListT.at(i)[colsT];
        labelsDataT.ptr()[i] = lbl;
    }

    // Матрица входных данных (проверка)
    Ret2DArray trainDataA;
    qint32 rowsA = dataListA.count();
    qint32 colsA = dataListA[0].count() - 1;
    trainDataA.alloc(rowsA, colsA);

    for (int i = 0; i < rowsA; ++i)
    {
        float* r1 = (float*)dataListA[i].constData();
        float* r2 = trainDataA.ptr() + i * trainDataA.columns();
        for (int j = 0; j < colsA; ++j)
            *r2++ = *r1++;
    }

    // Матрица меток  (проверка)
    Ret2DArray labelsDataA;
    labelsDataA.alloc(rowsA, 1);

    for (int i = 0; i < rowsA; ++i)
    {
        float lbl = dataListA.at(i)[colsA];
        labelsDataA.ptr()[i] = lbl;
    }

    DMatrixHandle hTrainT[1];

    int res = XGDMatrixCreateFromMat(trainDataT.ptr(), trainDataT.rows(), trainDataT.columns(), NAN, &hTrainT[0]);
    CHECK_RESULT

    res = XGDMatrixSetFloatInfo(hTrainT[0], "label", labelsDataT.ptr(), labelsDataT.rows());
    CHECK_RESULT

    DMatrixHandle hTrainA[1];
    res = XGDMatrixCreateFromMat(trainDataA.ptr(), trainDataA.rows(), trainDataA.columns(), NAN, &hTrainA[0]);
    CHECK_RESULT

    // create the booster and load some parameters
    BoosterHandle hBooster;

    double maxDiff = 0;
    int p1 = 0, p2 = 0, p3 = 0;

    // n-estimators
    for (int est = 0; est < estim.size(); est++)
    {
        // max_depth
        for (int depth = 0; depth < max_depth.size(); depth++)
        {
            // learning_rate
            for (int rate = 0; rate < learning_rate.size(); rate++)
            {
                filestream << QString::fromUtf8(u8"estimators: ") << estim[est] << ", "
                           << QString::fromUtf8(u8"max_depth:  ") << max_depth[depth] << ", "
                           << QString::fromUtf8(u8"learning_rate: ") << learning_rate[rate] << "; "
                           << QString::fromUtf8(u8"F-мера") << ";"
                           << QString::fromUtf8(u8"Итерация") << ";"
                           << QString::fromUtf8(u8"Порог") << ";"
                           << QString::fromUtf8(u8"Прирост") << ";"
                           << QString::fromUtf8(u8"МинВероят") << ";"
                           << QString::fromUtf8(u8"МаксВероят") << ";"
                           << QString::fromUtf8(u8"РазницаДиап") << ";"
                           << endl;

                res = XGBoosterCreate(hTrainT, 1, &hBooster);
                CHECK_RESULT

                res = XGBoosterSetParam(hBooster, "n_estimators", std::to_string(estim[est]).c_str());
                CHECK_RESULT
                res = XGBoosterSetParam(hBooster, "max_depth", std::to_string(max_depth[depth]).c_str());
                CHECK_RESULT
                res = XGBoosterSetParam(hBooster, "learning_rate", std::to_string(learning_rate[rate]).c_str());
                CHECK_RESULT

                float minValue, maxValue;

                for (QString &key : xgbOptions.keys())
                {
                    QString name = key;
                    QString value  = xgbOptions.value(key);
                    res = XGBoosterSetParam(hBooster, name.toStdString().c_str(), value.toStdString().c_str());
                    CHECK_RESULT
                }


                // Площадь предудущей фигуры для каждой итерации
                double squareOnePrev = 0;
                // Площадь предудущей фигуры для нескольких итераций
                double squareManyPrev = 0;

                for (int i = 0; i <= 200; i++)
                {
                    XGBoosterUpdateOneIter(hBooster, i, hTrainT[0]);

                    // Разница между соседними итерациями
                    double diffOne = 0;
                    // Разница между несколькими итерациями
                    double diffMany = 0;

                    double f1MeasureMax = 0;
                    double f1ThresholdMax = 0;

                    bst_ulong out_len;
                    const float* f;
                    //res = XGBoosterPredict(hBooster, hTrain[0], 0, 0, &out_len, &f);
                    res = XGBoosterPredict(hBooster, hTrainA[0], 0, 0, &out_len, &f);
                    CHECK_RESULT

                    double TP = 0;
                    double TN = 0;
                    double FP = 0;
                    double FN = 0;
                    double threshold = 0;

                    // Обнуление значения текущей площади
                    double squareCurrent = 0;

                    minValue = 1;
                    maxValue = 0;

                    for (threshold = _thresholdStep; threshold <= 1.0; threshold += _thresholdStep)
                    {
                        TP = 0;
                        TN = 0;
                        FP = 0;
                        FN = 0;

                        bool predict = false;
                        bool fact = false;

                        for (int j = 0; j < int(out_len); ++j)
                        {

                            // Прогноз XGBoost
                            float predictScore = f[j];

                            if (predictScore < minValue )
                                minValue = predictScore;

                            if (predictScore > maxValue )
                                maxValue = predictScore;

                            if (maxValue - minValue > maxDiff)
                            {
                                maxDiff = maxValue - minValue;
                                p1 = est; p2 = depth; p3 = rate;
                            }

                            if (predictScore <= threshold)
                                predict = false;
                            else
                                predict = true;

                            // Фактическое значение
                            float expertScore = labelsDataA.ptr()[j];
                            if (expertScore > 0)
                                fact = true;
                            else
                                fact = false;

                            if (predict && fact)
                                ++TP;
                            else if (predict && !fact)
                                ++FP;
                            else if (!predict && fact)
                                ++FN;
                            else if (!predict && !fact)
                                ++TN;

                        }

                        double precision = 0;
                        if ((TP + FP) > 0)
                            precision = TP / ( TP + FP );

                        double recall = 0;
                        if ((TP + FN) > 0)
                            recall = TP / ( TP + FN );

                        double f1Measure = 0;

                        if ((precision + recall) < 0 || (precision + recall) > 0)
                            f1Measure = 2 * ( (precision * recall) / (precision + recall) );

                        // Если найдено максимальное значение Ф-Меры, то необходимо
                        // сохранить значение Ф-Меры и значение порога, на котором
                        // данное максимальное значение найдено.
                        if (f1Measure > f1MeasureMax)
                        {
                            f1MeasureMax = f1Measure;
                            f1ThresholdMax = threshold;
                        }

                        // Общая площадь фигуры
                        squareCurrent += _thresholdStep * f1Measure;
                    }

                    filestream << ";";

                    filestream << f1MeasureMax << ";"
                               << i << ";"
                               << f1ThresholdMax  << ";";

                    // Для первой итерации значение предыдущего заполняется первым вычисленным значением
                    // и расчёт разницы не прозводится
                    if (i == 0)
                    {
                        squareOnePrev = squareCurrent;
                        squareManyPrev = squareCurrent;
                        filestream << "0;";
                        filestream << minValue << ";" << maxValue << ";";
                        filestream << "0";
                        filestream << endl;
                        continue;
                    }
                    // Разница текущей и предыдущей площадей
                    diffOne = squareCurrent - squareOnePrev;

                    // Разница между текущей площадью и предыдущей с разницей в _iterCount итераций
                    if (i % _iterCount == 0)
                    {
                        diffMany = squareCurrent - squareManyPrev;
                    }

                    filestream << QString::number(diffOne, 'f', 5) << ";";

                    squareOnePrev = squareCurrent;

                    if (diffOne < 0.001)
                    {
                        filestream << minValue << ";" << maxValue << ";";
                        filestream << maxValue - minValue;
                        filestream << endl;

                        break;
                    }

//                    if (i % _iterCount == 0)
//                    {
//                        squareManyPrev = squareCurrent;

//                        if (diffMany < _diffLimit)
//                        {
//                            filestream << minValue << ";" << maxValue;
//                            filestream << endl;

//                            break;
//                        }
//                    }

                    filestream << minValue << ";" << maxValue << ";";
                    filestream << maxValue - minValue;
                    filestream << endl;
                }

                filestream.flush();

                XGBoosterFree(hBooster);
            }
        }
    }

    filestream << QString::fromUtf8(u8"n-estimators: ") << p1
               << QString::fromUtf8(". max_depth: ") << p2
               << QString::fromUtf8(". learning_rate: ") << p3
               << endl;

    XGDMatrixFree(hTrainT[0]);
    XGDMatrixFree(hTrainA[0]);
}
