#include "gtest/gtest.h"
#include "db/NewDB.h"  // NewDB header
#include "leveldb/env.h"
#include "leveldb/db.h"
#include "db/write_batch_internal.h"

#include <atomic>
#include <chrono>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <mutex>
#include <random>
#include <string>
#include <thread>
#include <utility>
#include <vector>
using namespace leveldb;

// Global operation counters, shared by all worker threads.
std::atomic<int> put_count{0};
std::atomic<int> delete_count{0};
std::atomic<int> create_index_count{0};
std::atomic<int> delete_index_count{0};
std::atomic<int> total_operations{0};

// Combined latency store; no live code in this file appears to use it
// (only the commented-out single-argument RecordLatency referenced it).
std::mutex latency_mutex;
std::vector<long long> latencies;

// Per-thread latency buffers, appended to by RecordLatency without locking and
// later copied into the shared vectors below by MergeLatencies.
thread_local std::vector<long long> put_latencies_local{};
thread_local std::vector<long long> delete_latencies_local{};
thread_local std::vector<long long> update_latencies_local{};

// Shared per-operation latency vectors; each is guarded by the mutex declared
// directly above it.
std::mutex put_latency_mutex;
std::vector<long long> put_latencies{};

std::mutex delete_latency_mutex;
std::vector<long long> delete_latencies{};

std::mutex update_latency_mutex;
std::vector<long long> update_latencies{};


// void RecordLatency(long long duration) {
//     std::unique_lock<std::mutex> lock(latency_mutex);
//     latencies.push_back(duration);
// }

// Buffers one latency sample in the calling thread's local vector for the given
// operation kind ("PUT", "DELETE" or "UPDATE"). Samples tagged with any other
// op_type are silently discarded. No lock is taken because the target vectors
// are thread_local.
void RecordLatency(long long duration, const std::string& op_type) {
    std::vector<long long>* target = nullptr;
    if (op_type == "PUT") {
        target = &put_latencies_local;
    } else if (op_type == "DELETE") {
        target = &delete_latencies_local;
    } else if (op_type == "UPDATE") {
        target = &update_latencies_local;
    }

    if (target != nullptr) {
        target->push_back(duration);
    }
}
// Moves the calling thread's buffered latency samples into the shared
// put/delete/update vectors, holding each vector's mutex while appending.
// The thread-local buffers are cleared afterwards, so calling this more than
// once from the same thread (e.g. several worker tasks run on one thread)
// does not add the same samples twice.
void MergeLatencies() {
    // Appends `local` to `shared` under `mu`, then empties `local`.
    auto flush = [](std::vector<long long>& local,
                    std::mutex& mu,
                    std::vector<long long>& shared) {
        {
            std::lock_guard<std::mutex> lock(mu);
            shared.insert(shared.end(), local.begin(), local.end());
        }
        // `local` is thread_local, so clearing it needs no lock.
        local.clear();
    };

    flush(put_latencies_local, put_latency_mutex, put_latencies);
    flush(delete_latencies_local, delete_latency_mutex, delete_latencies);
    flush(update_latencies_local, update_latency_mutex, update_latencies);
}


// Opens (creating it if it does not exist yet) the NewDB named `dbName` and
// stores the handle in `*db`. Returns the Status produced by NewDB::Open.
Status OpenNewDB(std::string dbName, NewDB** db) {
    Options opts{};
    opts.create_if_missing = true;
    return NewDB::Open(opts, dbName, db);
}

// Process-wide pseudo-random engine used by GenerateRandomString.
std::default_random_engine rng;

// Re-seeds the shared engine; subsequent draws follow the sequence for `seed`.
void SetGlobalSeed(unsigned seed) { rng.seed(seed); }

// Returns a string of `length` characters, each drawn uniformly from
// [0-9A-Za-z] using the shared `rng` engine (so results depend on its seed).
std::string GenerateRandomString(size_t length) {
    static const char kAlphabet[] =
        "0123456789"
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "abcdefghijklmnopqrstuvwxyz";

    // sizeof includes the terminating NUL, so the last usable index is sizeof - 2.
    constexpr int kLastIndex = static_cast<int>(sizeof(kAlphabet)) - 2;
    std::uniform_int_distribution<int> pick(0, kLastIndex);

    std::string out(length, '\0');
    for (char& ch : out) {
        ch = kAlphabet[pick(rng)];
    }
    return out;
}

// Returns a uniformly distributed integer in the closed range [min, max].
// If the bounds are given in the wrong order (min > max) they are swapped.
//
// Uses a per-thread Mersenne Twister seeded from std::random_device instead of
// std::rand(), because:
//  - std::rand() is not required to be thread-safe, yet this is called from
//    worker threads;
//  - std::rand() is never seeded in this file, so every run repeated the same sequence;
//  - `rand() % range` is biased, and `max + 1` overflows when max == INT_MAX.
int getRandomInRange(int min, int max) {
    if (min > max) {
        std::swap(min, max);
    }
    thread_local std::mt19937 engine{std::random_device{}()};
    std::uniform_int_distribution<int> dist(min, max);
    return dist(engine);
}

// Collected boolean outcomes; every access goes through results_mutex.
std::mutex results_mutex;
std::vector<bool> results;

// Appends one outcome to `results` while holding results_mutex.
void InsertResult(bool TRUEorFALSE) {
    std::lock_guard<std::mutex> guard(results_mutex);
    results.push_back(TRUEorFALSE);
}

// Database name passed to OpenNewDB by the tests in this file.
std::string testdbname{"dbtest36"};




// PUT worker: writes `num_operations` records keyed "k_<thread_id>_<i>", each
// with "name" and "email" fields. The latency of every Put_fields call
// (microseconds) is recorded under "PUT", and the thread's samples are merged
// into the shared vectors at the end.
void thread_task_put(NewDB* db, int thread_id, int num_operations) {
    for (int i = 0; i < num_operations; ++i) {
        std::string key = "k_" + std::to_string(thread_id) + "_" + std::to_string(i);
        FieldArray fields = {
            {"name", "User" + std::to_string(i)},
            {"email", "user" + std::to_string(i) + "@test.com"}
        };

        // Time only the DB call, like the DELETE and UPDATE workers do; the
        // timer previously also covered the key/field string construction.
        auto start = std::chrono::high_resolution_clock::now();
        db->Put_fields(WriteOptions(), key, fields);
        auto end = std::chrono::high_resolution_clock::now();

        // put_count was previously never incremented, unlike delete_count
        // in the DELETE worker.
        put_count++;
        total_operations++;

        long long duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
        RecordLatency(duration, "PUT");
    }
    // Merge this thread's latency samples into the shared vectors.
    MergeLatencies();
}
// DELETE worker: issues `num_operations` deletes against keys of the form
// "k_<thread_id>_<index>", which is the format thread_task_put and
// thread_task_update use. Each index is picked at random from
// [0, num_operations - 1]. The latency of every Delete call (microseconds) is
// recorded under "DELETE", and the thread's samples are merged into the shared
// vectors at the end.
//
// The previous key format "k_<n>" could never match a key written by the PUT
// worker, so every delete was a miss.
// NOTE(review): this assumes the PUT worker with the same thread_id wrote some
// of these indices; confirm against the test's thread/operation counts.
void thread_task_delete(NewDB* db, int thread_id, int num_operations) {
    for (int i = 0; i < num_operations; ++i) {
        int key_index = getRandomInRange(0, num_operations - 1);
        std::string key = "k_" + std::to_string(thread_id) + "_" + std::to_string(key_index);

        auto start = std::chrono::high_resolution_clock::now();
        db->Delete(WriteOptions(), key);
        delete_count++;

        auto end = std::chrono::high_resolution_clock::now();
        long long duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();
        RecordLatency(duration, "DELETE");
        total_operations++;
    }
    // Merge this thread's latency samples into the shared vectors.
    MergeLatencies();
}
// UPDATE worker: for each i in [0, num_operations), writes a single "email"
// field to key "k_<thread_id>_<i>" via Put_fields. The latency of every call
// (microseconds) is recorded under "UPDATE", and the thread's samples are
// merged into the shared vectors at the end.
void thread_task_update(NewDB* db, int thread_id, int num_operations) {
    using Clock = std::chrono::high_resolution_clock;
    const std::string prefix = "k_" + std::to_string(thread_id) + "_";

    for (int op = 0; op < num_operations; ++op) {
        const std::string idx = std::to_string(op);
        std::string key = prefix + idx;
        FieldArray fields = {{"email", "updated_user" + idx + "@test.com"}};

        const auto t0 = Clock::now();
        db->Put_fields(WriteOptions(), key, fields);  // overwrite with just the email field
        total_operations++;
        const auto t1 = Clock::now();

        RecordLatency(std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count(), "UPDATE");
    }

    MergeLatencies();
}


// Concurrency test: PUT, DELETE and UPDATE worker threads run in parallel
// against a single NewDB instance. Afterwards the test prints the total
// operation count, wall-clock duration, throughput, and per-operation latency
// totals and averages collected by the workers.
TEST(NewDBTest, ConcurrencyTest_mul) {
    // Start from clean statistics so the report covers only this run; the
    // globals would otherwise accumulate across --gtest_repeat iterations.
    put_count = 0;
    delete_count = 0;
    total_operations = 0;
    {
        std::lock_guard<std::mutex> lock(put_latency_mutex);
        put_latencies.clear();
    }
    {
        std::lock_guard<std::mutex> lock(delete_latency_mutex);
        delete_latencies.clear();
    }
    {
        std::lock_guard<std::mutex> lock(update_latency_mutex);
        update_latencies.clear();
    }

    NewDB* db = nullptr;
    ASSERT_TRUE(OpenNewDB(testdbname, &db).ok());

    const int thread_num_put = 10;
    const int thread_num_delete = 1;
    const int thread_num_update = 0;
    const int operations_per_thread = 10;
    const int operations_per_thread_delete = 50;
    const int operations_per_thread_update = 50;

    std::vector<std::thread> threads;
    threads.reserve(thread_num_put + thread_num_delete + thread_num_update);

    auto start = std::chrono::high_resolution_clock::now();
    // Launch PUT workers.
    for (int i = 0; i < thread_num_put; ++i) {
        threads.emplace_back(thread_task_put, db, i, operations_per_thread);
    }
    // Launch DELETE workers.
    for (int i = 0; i < thread_num_delete; ++i) {
        threads.emplace_back(thread_task_delete, db, i, operations_per_thread_delete);
    }
    // Launch UPDATE workers.
    for (int i = 0; i < thread_num_update; ++i) {
        threads.emplace_back(thread_task_update, db, i, operations_per_thread_update);
    }

    for (auto& th : threads) {
        th.join();
    }

    auto end = std::chrono::high_resolution_clock::now();
    const auto elapsed = end - start;
    long long total_duration = std::chrono::duration_cast<std::chrono::milliseconds>(elapsed).count();
    // Compute throughput from fractional seconds. A sub-millisecond run made
    // the old `total_duration / 1000.0` divisor zero, which printed inf.
    const double elapsed_seconds = std::chrono::duration<double>(elapsed).count();
    const double throughput =
        elapsed_seconds > 0.0 ? static_cast<double>(total_operations.load()) / elapsed_seconds : 0.0;

    // Returns {sum, mean} of the samples; the mean is 0 when there are none.
    auto calculate_latency = [](const std::vector<long long>& samples) {
        long long sum = 0;
        for (long long lat : samples) {
            sum += lat;
        }
        const double avg = samples.empty() ? 0.0 : static_cast<double>(sum) / samples.size();
        return std::make_pair(sum, avg);
    };

    // All workers have been joined, so the shared vectors can be read without locks.
    auto [put_sum, put_avg] = calculate_latency(put_latencies);
    auto [delete_sum, delete_avg] = calculate_latency(delete_latencies);
    auto [update_sum, update_avg] = calculate_latency(update_latencies);

    std::cout << "Total Operations: " << total_operations.load() << std::endl;
    std::cout << "Total Duration: " << total_duration << " ms" << std::endl;
    std::cout << "Throughput (ops/sec): " << throughput << std::endl;

    std::cout << "\nPUT Operations:" << std::endl;
    std::cout << "  Total Latency (ms): " << put_sum / 1000.0 << std::endl;
    std::cout << "  Average Latency (us): " << put_avg << std::endl;

    std::cout << "\nDELETE Operations:" << std::endl;
    std::cout << "  Total Latency (ms): " << delete_sum / 1000.0 << std::endl;
    std::cout << "  Average Latency (us): " << delete_avg << std::endl;

    std::cout << "\nUPDATE Operations:" << std::endl;
    std::cout << "  Total Latency (ms): " << update_sum / 1000.0 << std::endl;
    std::cout << "  Average Latency (us): " << update_avg << std::endl;

    delete db;
}


// Test entry point: seeds the file's random sources from the current time,
// then runs all registered GoogleTest cases.
int main(int argc, char** argv) {
    // Seed both random sources this file uses from one wall-clock value: the
    // shared `rng` engine (via SetGlobalSeed) and the C library generator
    // behind std::rand(), which SetGlobalSeed alone does not affect.
    const auto seed = static_cast<unsigned>(std::time(nullptr));
    SetGlobalSeed(seed);
    std::srand(seed);
    testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
}