#include "VoxelGenerator.h"

#include <algorithm>
#include <cmath>
#include <vector>

#include "../mesh/MeshGenerator.h"

VoxelGenerator::VoxelGenerator(int sizeX, int sizeY, int sizeZ)
    : sizeX(sizeX), sizeY(sizeY), sizeZ(sizeZ) {
    // Allocate one cleared cell per grid position.
    //
    // The cell count is formed in size_t so large grids do not overflow int.
    // If any dimension is non-positive, isValidCoord rejects every
    // coordinate, so no storage is needed. Allocating nothing also avoids a
    // negative int product wrapping to an enormous unsigned size.
    const bool degenerate = sizeX <= 0 || sizeY <= 0 || sizeZ <= 0;
    const size_t cellCount = degenerate
        ? 0
        : static_cast<size_t>(sizeX) * static_cast<size_t>(sizeY) *
              static_cast<size_t>(sizeZ);
    voxels.resize(cellCount, false);
}

int VoxelGenerator::getIndex(int x, int y, int z) const {
    // Flat storage index: x varies fastest, then y, then z.
    // No bounds checking; callers validate with isValidCoord first.
    return (z * sizeY + y) * sizeX + x;
}

bool VoxelGenerator::isValidCoord(int x, int y, int z) const {
    return x >= 0 && x < sizeX && y >= 0 && y < sizeY && z >= 0 && z < sizeZ;
}

void VoxelGenerator::setVoxel(int x, int y, int z, bool value) {
    // Writes outside the grid are silently ignored.
    if (!isValidCoord(x, y, z)) {
        return;
    }
    voxels[getIndex(x, y, z)] = value;
}

bool VoxelGenerator::getVoxel(int x, int y, int z) const {
    // Cells outside the grid always read as empty.
    return isValidCoord(x, y, z) && voxels[getIndex(x, y, z)];
}

void VoxelGenerator::fillSphere(const Vec3& center, float radius) {
    for (int x = 0; x < sizeX; x++) {
        for (int y = 0; y < sizeY; y++) {
            for (int z = 0; z < sizeZ; z++) {
                Vec3 pos(x, y, z);
                if ((pos - center).length() <= radius) {
                    setVoxel(x, y, z, true);
                }
            }
        }
    }
}

void VoxelGenerator::fillBox(const Vec3& min, const Vec3& max) {
    for (int x = (int)min.x; x <= (int)max.x; x++) {
        for (int y = (int)min.y; y <= (int)max.y; y++) {
            for (int z = (int)min.z; z <= (int)max.z; z++) {
                setVoxel(x, y, z, true);
            }
        }
    }
}

void VoxelGenerator::fillFromMesh(const Mesh& mesh, float voxelSize) {
    if (mesh.getVertexCount() == 0) {
        return;
    }

    // Find mesh bounding box
    Vec3 minBounds = mesh.getVertex(0).position;
    Vec3 maxBounds = mesh.getVertex(0).position;

    for (size_t i = 1; i < mesh.getVertexCount(); i++) {
        const Vec3& pos = mesh.getVertex(i).position;
        minBounds.x = std::min(minBounds.x, pos.x);
        minBounds.y = std::min(minBounds.y, pos.y);
        minBounds.z = std::min(minBounds.z, pos.z);
        maxBounds.x = std::max(maxBounds.x, pos.x);
        maxBounds.y = std::max(maxBounds.y, pos.y);
        maxBounds.z = std::max(maxBounds.z, pos.z);
    }

    // Calculate offset to center mesh in voxel grid
    Vec3 meshSize = maxBounds - minBounds;
    Vec3 gridCenter(sizeX / 2.0f, sizeY / 2.0f, sizeZ / 2.0f);
    Vec3 meshCenter = (minBounds + maxBounds) * 0.5f;
    Vec3 offset = gridCenter - meshCenter / voxelSize;

    // Fill voxels for each vertex
    for (size_t i = 0; i < mesh.getVertexCount(); i++) {
        Vec3 pos = mesh.getVertex(i).position / voxelSize + offset;
        int x = (int)pos.x;
        int y = (int)pos.y;
        int z = (int)pos.z;
        setVoxel(x, y, z, true);
    }

    // Fill voxels along edges between vertices to ensure connectivity
    for (size_t t = 0; t < mesh.getTriangleCount(); t++) {
        Triangle tri = mesh.getTriangle(t);

        Vec3 v0 = mesh.getVertex(tri.v0).position / voxelSize + offset;
        Vec3 v1 = mesh.getVertex(tri.v1).position / voxelSize + offset;
        Vec3 v2 = mesh.getVertex(tri.v2).position / voxelSize + offset;

        // Rasterize edges
        auto rasterizeLine = [this](Vec3 from, Vec3 to) {
            float dist = (to - from).length();
            int steps = (int)dist + 1;
            for (int i = 0; i <= steps; i++) {
                float t = i / (float)steps;
                Vec3 pos = Vec3::lerp(from, to, t);
                setVoxel((int)pos.x, (int)pos.y, (int)pos.z, true);
            }
        };

        rasterizeLine(v0, v1);
        rasterizeLine(v1, v2);
        rasterizeLine(v2, v0);
    }
}

Mesh VoxelGenerator::generateMesh(float voxelSize) {
    // Builds a mesh containing one cube per set voxel. Each cube is
    // translated to (x, y, z) * voxelSize. Voxels are visited x-outer /
    // z-inner, and cubes are merged in that order. This function does no
    // face culling of its own; shared faces between neighbouring voxels are
    // handled (or not) by Mesh::merge.
    Mesh mesh;

    // Every voxel uses the same cube, so build it once and copy it per
    // voxel instead of regenerating it inside the triple loop.
    const Mesh cube = MeshGenerator::createCube(voxelSize);

    for (int x = 0; x < sizeX; x++) {
        for (int y = 0; y < sizeY; y++) {
            for (int z = 0; z < sizeZ; z++) {
                if (!getVoxel(x, y, z)) {
                    continue;
                }
                Mesh voxelMesh = cube;
                Vec3 pos(x * voxelSize, y * voxelSize, z * voxelSize);
                // Presumably (translation, scale, rotation): a pure
                // translation to the voxel's position. TODO confirm against
                // Mesh::transform.
                voxelMesh.transform(pos, Vec3(1, 1, 1), Vec3(0, 0, 0));
                mesh.merge(voxelMesh);
            }
        }
    }

    return mesh;
}

void VoxelGenerator::relaxMesh(Mesh& mesh, int iterations, float factor) {
    // Laplacian-style smoothing. Each pass moves every vertex toward the
    // average of its edge-connected neighbours by `factor` (via Vec3::lerp).
    // All new positions are computed from the previous pass before any are
    // written. Isolated vertices stay put. Normals are recomputed after the
    // last pass.
    const size_t vertexCount = mesh.getVertexCount();
    const size_t triangleCount = mesh.getTriangleCount();

    if (vertexCount == 0 || triangleCount == 0 || iterations <= 0) {
        return;
    }

    // Build the adjacency list once: the passes below only rewrite vertex
    // positions, so the topology never changes between iterations.
    std::vector<std::vector<size_t>> neighbors(vertexCount);
    for (size_t t = 0; t < triangleCount; t++) {
        const Triangle tri = mesh.getTriangle(t);
        const size_t a = static_cast<size_t>(tri.v0);
        const size_t b = static_cast<size_t>(tri.v1);
        const size_t c = static_cast<size_t>(tri.v2);

        // Skip triangles referencing missing vertices instead of indexing
        // past the end of `neighbors`.
        if (a >= vertexCount || b >= vertexCount || c >= vertexCount) {
            continue;
        }

        // Add bidirectional edges.
        neighbors[a].push_back(b);
        neighbors[a].push_back(c);
        neighbors[b].push_back(a);
        neighbors[b].push_back(c);
        neighbors[c].push_back(a);
        neighbors[c].push_back(b);
    }

    // Edges shared by several triangles were added more than once; dedupe.
    for (auto& adjacent : neighbors) {
        std::sort(adjacent.begin(), adjacent.end());
        adjacent.erase(std::unique(adjacent.begin(), adjacent.end()),
                       adjacent.end());
    }

    // Every entry is overwritten on each pass, so one buffer suffices.
    std::vector<Vec3> newPositions(vertexCount);

    for (int iter = 0; iter < iterations; iter++) {
        // Compute new positions from the current (unmodified) ones.
        for (size_t i = 0; i < vertexCount; i++) {
            const std::vector<size_t>& adjacent = neighbors[i];
            if (adjacent.empty()) {
                newPositions[i] = mesh.getVertex(i).position;
                continue;
            }

            Vec3 avg(0, 0, 0);
            for (size_t neighbor : adjacent) {
                avg = avg + mesh.getVertex(neighbor).position;
            }
            avg = avg * (1.0f / adjacent.size());

            newPositions[i] = Vec3::lerp(mesh.getVertex(i).position, avg, factor);
        }

        // Apply new positions.
        for (size_t i = 0; i < vertexCount; i++) {
            mesh.getVertex(i).position = newPositions[i];
        }
    }

    // The smoothing passes read only positions, so the normals need
    // recomputing only once, after the final pass.
    mesh.computeNormals();
}
