#include "gtest/gtest.h"
#include "db/NewDB.h"  // NewDB 的头文件
#include "leveldb/env.h"
#include "leveldb/db.h"
#include "db/write_batch_internal.h"
#include <iostream>
#include <random>
#include <ctime>
#include <thread>
#include <mutex>
#include <chrono>
#include <atomic>
using namespace leveldb;

// Global operation counters shared by all worker threads. They are
// std::atomic, so concurrent increments need no extra locking.
// NOTE(review): in the code visible here only total_operations is ever
// incremented; put_count, delete_count, create_index_count and
// delete_index_count are only printed by the commented-out PerformanceTest.
std::atomic<int> put_count(0);
std::atomic<int> delete_count(0);
std::atomic<int> create_index_count(0);
std::atomic<int> delete_index_count(0);
std::atomic<int> total_operations(0);

// Per-operation latency samples gathered from all worker threads.
// Writers append through RecordLatency, which holds latency_mutex.
std::mutex latency_mutex;
std::vector<long long> latencies;

/// Appends one latency sample to the shared `latencies` list.
/// Safe to call from several threads at once: the append happens while
/// holding `latency_mutex`.
/// @param duration elapsed time of one operation, in the caller's unit.
void RecordLatency(long long duration) {
    const std::lock_guard<std::mutex> guard(latency_mutex);
    latencies.emplace_back(duration);
}


/// Opens the NewDB stored under `dbName`, creating it if it does not exist yet.
/// @param dbName path of the database.
/// @param db     out-parameter that receives the opened handle; callers in
///               this file release it with `delete`.
/// @return the Status produced by NewDB::Open.
Status OpenNewDB(std::string dbName, NewDB** db) {
    Options opts{};
    opts.create_if_missing = true;
    return NewDB::Open(opts, dbName, db);
}

// Process-wide random engine used by every random helper below.
// NOTE: std::default_random_engine is not safe to use from several threads at
// once, and these helpers do not lock it; call them from one thread only.
std::default_random_engine rng;

/// Reseeds the shared engine `rng`. Reseeding with the same value reproduces
/// the same sequence from GenerateRandomString and getRandomInRange.
void SetGlobalSeed(unsigned seed) {
    rng.seed(seed);
}

/// Returns a string of `length` characters drawn uniformly from [0-9A-Za-z]
/// using the shared engine `rng`. `length == 0` yields an empty string.
std::string GenerateRandomString(size_t length) {
    static const char alphanum[] =
        "0123456789"
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "abcdefghijklmnopqrstuvwxyz";

    // sizeof counts the trailing '\0', so the last valid index is sizeof - 2.
    std::uniform_int_distribution<int> dist(0, sizeof(alphanum) - 2);

    std::string str(length, '\0');
    for (char& c : str) {
        c = alphanum[dist(rng)];
    }
    return str;
}

/// Returns a uniformly distributed integer in the closed range [min, max].
///
/// Draws from the shared engine `rng` (not std::rand), so SetGlobalSeed makes
/// the results reproducible. uniform_int_distribution has no modulo bias and
/// no `max + 1` overflow at INT_MAX. Bounds passed in reverse order
/// (min > max) are swapped rather than triggering undefined behavior.
int getRandomInRange(int min, int max) {
    const int lo = (min <= max) ? min : max;
    const int hi = (min <= max) ? max : min;
    std::uniform_int_distribution<int> dist(lo, hi);
    return dist(rng);
}

// Shared list of boolean values appended by InsertResult; every append
// is serialized on results_mutex.
// NOTE(review): not read anywhere in the code visible here.
std::mutex results_mutex;
std::vector<bool> results;

/// Appends `TRUEorFALSE` to the shared `results` list while holding
/// `results_mutex`, so it may be called from several threads at once.
void InsertResult(bool TRUEorFALSE){
    std::lock_guard<std::mutex> guard(results_mutex);
    results.push_back(TRUEorFALSE);
}

// Name of the database passed to OpenNewDB by ConcurrencyTest. It is a
// relative name, so presumably it is resolved against the current working
// directory — confirm with NewDB::Open. Nothing in this file removes it, so
// data written by earlier runs presumably persists.
std::string testdbname = "dbtest36";



// Worker for the multi-threaded insert test. It writes `num_operations`
// records keyed "k_<thread_id>_<i>" and records the latency of each write.
//
// Latency samples are recorded in MICROSECONDS. ConcurrencyTest prints the
// average as "(us)" and divides the sum by 1000 to get milliseconds.
// Millisecond resolution would truncate every sub-millisecond write to 0.
void thread_task_put(NewDB* db, int thread_id, int num_operations) {
    for (int i = 0; i < num_operations; ++i) {
        std::string key = "k_" + std::to_string(thread_id) + "_" + std::to_string(i);
        FieldArray fields = {
            {"name", "User" + std::to_string(i)},
            {"email", "user" + std::to_string(i) + "@test.com"}
        };

        // Time only the DB write. steady_clock is monotonic, whereas
        // high_resolution_clock may alias system_clock and jump if the wall
        // clock is adjusted.
        const auto start = std::chrono::steady_clock::now();
        // NOTE(review): the result of Put_fields is ignored. If it returns a
        // Status, failed writes are still counted and timed below.
        db->Put_fields(WriteOptions(), key, fields);
        const auto end = std::chrono::steady_clock::now();

        total_operations++;
        RecordLatency(std::chrono::duration_cast<std::chrono::microseconds>(end - start).count());
    }
}


// // 测试吞吐量和延迟
// TEST(TestNewDB, PerformanceTest) {
//     NewDB* db;
//     ASSERT_TRUE(OpenNewDB(testdbname, &db).ok());

//     const int thread_num = 100;
//     std::vector<std::thread> threads;

//     auto test_start = std::chrono::high_resolution_clock::now();

//     for (int i = 0; i < thread_num; ++i) {
//         threads.emplace_back([=]() { thread_task_put(db, i); });
//         threads.emplace_back([=]() { thread_task_delete(db, i); });
//         threads.emplace_back([=]() { thread_task_createindex(db, i); });
//         threads.emplace_back([=]() { thread_task_deleteindex(db, i); });
//     }

//     for (auto& th : threads) {
//         if (th.joinable()) {
//             th.join();
//         }
//     }

//     auto test_end = std::chrono::high_resolution_clock::now();
//     long long total_duration = std::chrono::duration_cast<std::chrono::milliseconds>(test_end - test_start).count();

//     // 计算吞吐量
//     double throughput = static_cast<double>(total_operations) / (total_duration / 1000.0);

//     // 计算平均延迟
//     long long sum_latency = 0;
//     for (auto& lat : latencies) {
//         sum_latency += lat;
//     }
//     double avg_latency = static_cast<double>(sum_latency) / latencies.size();

//     std::cout << "Total Operations: " << total_operations.load() << std::endl;
//     std::cout << "Throughput (ops/sec): " << throughput << std::endl;
//     std::cout << "Average Latency (us): " << avg_latency << std::endl;
//     std::cout << "PUT operations: " << put_count.load() << std::endl;
//     std::cout << "DELETE operations: " << delete_count.load() << std::endl;
//     std::cout << "CREATE INDEX operations: " << create_index_count.load() << std::endl;
//     std::cout << "DELETE INDEX operations: " << delete_index_count.load() << std::endl;

//     delete db;
// }
// Concurrency test. `thread_num` threads each insert `operations_per_thread`
// records at the same time. The test then prints overall throughput and
// per-operation latency statistics.
TEST(NewDBTest, ConcurrencyTest) {
    NewDB* db = nullptr;
    ASSERT_TRUE(OpenNewDB(testdbname, &db).ok());

    const int thread_num = 100;
    const int operations_per_thread = 10;

    // The operation counter and latency list are globals. Reset them so a
    // repeated run (e.g. --gtest_repeat) reports only this run's numbers.
    total_operations = 0;
    latencies.clear();

    std::vector<std::thread> threads;
    threads.reserve(thread_num);

    // steady_clock is monotonic, so the interval cannot go negative or jump.
    const auto start = std::chrono::steady_clock::now();

    // Start the worker threads.
    for (int i = 0; i < thread_num; ++i) {
        threads.emplace_back(thread_task_put, db, i, operations_per_thread);
    }

    // Wait for every worker; after this the globals can be read without locking.
    for (auto& th : threads) {
        th.join();
    }

    const auto end = std::chrono::steady_clock::now();
    const long long total_duration =
        std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
    // Use fractional seconds for throughput. Dividing by an integer
    // millisecond count yields inf when the run takes under 1 ms.
    const double elapsed_sec = std::chrono::duration<double>(end - start).count();

    const int ops = total_operations.load();
    // Every worker iteration increments total_operations once.
    EXPECT_EQ(ops, thread_num * operations_per_thread);

    // Throughput in operations per second.
    const double throughput = elapsed_sec > 0.0 ? static_cast<double>(ops) / elapsed_sec : 0.0;

    // Average per-operation latency; guard against an empty sample list.
    long long sum_latency = 0;
    for (long long lat : latencies) {
        sum_latency += lat;
    }
    const double avg_latency =
        latencies.empty() ? 0.0 : static_cast<double>(sum_latency) / latencies.size();
    // The output labels treat each sample as microseconds; convert the sum to ms.
    const double total_latency_ms = static_cast<double>(sum_latency) / 1000.0;

    std::cout << "Total Operations: " << ops << std::endl;
    std::cout << "Total Duration: " << total_duration << std::endl;
    std::cout << "Throughput (ops/sec): " << throughput << std::endl;
    std::cout << "Average Latency (us): " << avg_latency << std::endl;
    std::cout << "Total Latency (ms): " << total_latency_ms << std::endl;
    delete db;
}


// Test-binary entry point. It seeds the shared random engine from the
// current time, then hands control to GoogleTest.
int main(int argc, char** argv) {
    const unsigned seed = static_cast<unsigned>(time(nullptr));
    SetGlobalSeed(seed);

    testing::InitGoogleTest(&argc, argv);
    const int status = RUN_ALL_TESTS();
    return status;
}