#include "FractalGenerators.h"
#include "../mesh/MeshGenerator.h"

// Builds a voxelised 3D Mandelbrot-style set as a mesh of small cubes.
//
// A resolution^3 grid is sampled over [-threshold, threshold) on every axis
// (the grid spans threshold * 2 and is centred on the origin). Each sample c
// is iterated from z = (0, 0, 0) with
//     z' = (zx^2 - zy^2 - zz^2 + cx,  2*zx*zy + cy,  2*zx*zz + cz)
// and is kept when |z| is still below 2 after maxIterations steps.
//
// resolution    - number of samples per axis; nothing is sampled when <= 0.
// threshold     - half-extent of the sampled cube. It is not the escape
//                 radius, which is fixed at 2.
// maxIterations - iteration budget per sample; with a value <= 0 every
//                 sample is kept.
// Returns one cube per kept sample, merged into a single mesh with normals
// recomputed. The colour assigned to the intermediate sample vertices is not
// transferred to the cubes; only the sample position is used.
Mesh FractalGenerators::mandelbrot3D(int resolution, float threshold, int maxIterations) {
    const float extent = threshold * 2.0f;

    // Maps a grid index in [0, resolution) to a coordinate in [-extent/2, extent/2).
    const auto toCoord = [resolution, extent](int index) -> float {
        return static_cast<float>(index) / resolution * extent - extent / 2;
    };

    // True when the orbit of 0 under the map above is still inside the
    // radius-2 ball once the iteration budget has been used up.
    const auto neverEscapes = [maxIterations](const Vec3& c) -> bool {
        Vec3 orbit(0, 0, 0);
        int steps = 0;

        while (orbit.length() < 2.0f && steps < maxIterations) {
            const float xx = orbit.x * orbit.x;
            const float yy = orbit.y * orbit.y;
            const float zz = orbit.z * orbit.z;

            Vec3 next;
            next.x = xx - yy - zz + c.x;
            next.y = 2.0f * orbit.x * orbit.y + c.y;
            next.z = 2.0f * orbit.x * orbit.z + c.z;

            orbit = next;
            ++steps;
        }

        return steps >= maxIterations;
    };

    // Pass 1: collect the grid samples that belong to the set.
    Mesh samples;
    for (int ix = 0; ix < resolution; ++ix) {
        const float px = toCoord(ix);
        for (int iy = 0; iy < resolution; ++iy) {
            const float py = toCoord(iy);
            for (int iz = 0; iz < resolution; ++iz) {
                const float pz = toCoord(iz);

                if (!neverEscapes(Vec3(px, py, pz))) {
                    continue;
                }

                Vertex sample;
                sample.position = Vec3(px, py, pz);
                sample.color = Vec3(0.8f, 0.3f, 0.9f);
                samples.addVertex(sample);
            }
        }
    }

    // Pass 2: place a cube at every kept sample. The cubes are 80% of the
    // grid spacing, so neighbouring voxels stay visually separated.
    Mesh voxels;
    const float voxelSize = extent / resolution * 0.8f;
    for (size_t n = 0; n < samples.getVertexCount(); ++n) {
        Mesh voxel = MeshGenerator::createCube(voxelSize);
        voxel.transform(samples.getVertex(n).position, Vec3(1, 1, 1), Vec3(0, 0, 0));
        voxels.merge(voxel);
    }

    voxels.computeNormals();
    return voxels;
}

// Builds a voxelised 3D Julia-style set for the constant c as a mesh of
// small cubes.
//
// A resolution^3 grid is sampled over [-2, 2) on every axis. Each sample is
// used as the starting point z0 and iterated with
//     z' = (zx^2 - zy^2 - zz^2 + cx,  2*zx*zy + cy,  2*zx*zz + cz)
// where the complex constant is embedded in the z = 0 plane:
// (cx, cy, cz) = (c.real(), c.imag(), 0). A sample is kept when |z| is still
// below 2 after maxIterations steps.
//
// c             - Julia constant; selects which set is generated.
// resolution    - number of samples per axis; an empty mesh is returned
//                 when <= 0.
// maxIterations - iteration budget per sample; with a value <= 0 every
//                 sample is kept.
// Returns one cube per kept sample, merged into a single mesh with normals
// recomputed.
Mesh FractalGenerators::julia3D(const std::complex<float>& c, int resolution, int maxIterations) {
    Mesh result;
    if (resolution <= 0) {
        // Nothing to sample; also avoids dividing by zero below.
        return result;
    }

    // Sampling window width: 4 units per axis, centred on the origin.
    const float range = 4.0f;
    // Compare squared magnitudes so no square root is needed per step.
    const float escapeRadiusSq = 4.0f;

    const float cx = c.real();
    const float cy = c.imag();
    const float cz = 0.0f;

    // Cubes are 80% of the grid spacing so neighbouring voxels stay separated.
    const float cubeSize = range / resolution * 0.8f;

    for (int x = 0; x < resolution; x++) {
        const float px = static_cast<float>(x) / resolution * range - range / 2;
        for (int y = 0; y < resolution; y++) {
            const float py = static_cast<float>(y) / resolution * range - range / 2;
            for (int z = 0; z < resolution; z++) {
                const float pz = static_cast<float>(z) / resolution * range - range / 2;

                // Julia iteration: the orbit starts at the sample point and
                // the additive constant is the same for every sample.
                float zx = px;
                float zy = py;
                float zz = pz;
                int iter = 0;

                while (zx * zx + zy * zy + zz * zz < escapeRadiusSq && iter < maxIterations) {
                    const float nx = zx * zx - zy * zy - zz * zz + cx;
                    const float ny = 2.0f * zx * zy + cy;
                    const float nz = 2.0f * zx * zz + cz;

                    zx = nx;
                    zy = ny;
                    zz = nz;
                    iter++;
                }

                // Keep the sample if its orbit never left the escape ball.
                if (iter >= maxIterations) {
                    Vec3 center(px, py, pz);
                    Mesh cube = MeshGenerator::createCube(cubeSize);
                    cube.transform(center, Vec3(1, 1, 1), Vec3(0, 0, 0));
                    result.merge(cube);
                }
            }
        }
    }

    result.computeNormals();
    return result;
}

// Builds a Sierpinski-style fractal by repeatedly refining the faces of a
// tetrahedron.
//
// Starting from MeshGenerator::createTetrahedron(size), every pass replaces
// each triangle with the three corner triangles formed by its vertices and
// its edge midpoints; the central triangle is left out. Vertices are not
// shared between the emitted triangles.
//
// iterations - number of refinement passes; the base tetrahedron is returned
//              unchanged when <= 0. The triangle count triples every pass.
// size       - forwarded to MeshGenerator::createTetrahedron.
// Returns the refined mesh. Normals are not recomputed here.
Mesh FractalGenerators::sierpinskiTetrahedron(int iterations, float size) {
    Mesh current = MeshGenerator::createTetrahedron(size);

    for (int pass = 0; pass < iterations; ++pass) {
        Mesh refined;

        // Appends three fresh vertices and the triangle connecting them.
        const auto emitTriangle = [&refined](const Vec3& a, const Vec3& b, const Vec3& c) {
            const unsigned int base = refined.getVertexCount();
            refined.addVertex(Vertex(a));
            refined.addVertex(Vertex(b));
            refined.addVertex(Vertex(c));
            refined.addTriangle(base, base + 1, base + 2);
        };

        const size_t faceCount = current.getTriangleCount();
        for (size_t f = 0; f < faceCount; ++f) {
            const Triangle& face = current.getTriangle(f);
            const Vec3& p0 = current.getVertex(face.v0).position;
            const Vec3& p1 = current.getVertex(face.v1).position;
            const Vec3& p2 = current.getVertex(face.v2).position;

            // Edge midpoints.
            const Vec3 mid01 = (p0 + p1) * 0.5f;
            const Vec3 mid12 = (p1 + p2) * 0.5f;
            const Vec3 mid20 = (p2 + p0) * 0.5f;

            // One triangle per original corner, keeping the original winding.
            emitTriangle(p0, mid01, mid20);
            emitTriangle(mid01, p1, mid12);
            emitTriangle(mid20, mid12, p2);
        }

        current = refined;
    }

    return current;
}

// Placeholder for a Menger sponge: currently produces a subdivided cube and
// does not remove any sub-cubes.
//
// iterations - number of times MeshGenerator::subdivide(mesh, 1) is applied;
//              the plain cube is returned when <= 0.
// size       - forwarded to MeshGenerator::createCube.
// Returns the cube mesh after the subdivision calls.
Mesh FractalGenerators::mengerSponge(int iterations, float size) {
    Mesh mesh = MeshGenerator::createCube(size);

    // Simplified: just return subdivided cube.
    // NOTE(review): this relies on MeshGenerator::subdivide modifying `mesh`
    // in place. If it returns the subdivided mesh instead, the result is
    // discarded here and the plain cube is returned - confirm against the
    // MeshGenerator declaration.
    for (int iter = 0; iter < iterations; iter++) {
        MeshGenerator::subdivide(mesh, 1);
    }

    return mesh;
}

// Builds the outline of a Koch snowflake in the XZ plane (y = 0) as a flat
// ribbon mesh.
//
// iterations - number of Koch refinement passes. Every pass replaces each
//              edge with four, so the outline has 3 * 4^iterations edges.
//              The starting triangle is used as-is when <= 0.
// size       - circumradius of the starting equilateral triangle.
// height     - half-width of the ribbon: every outline edge becomes a quad
//              extending this far to both sides of the edge, within the XZ
//              plane.
// Returns a mesh with one quad (two triangles, four unshared vertices) per
// outline edge, with normals computed.
Mesh FractalGenerators::kochSnowflake3D(int iterations, float size, float height) {
    Mesh mesh;

    // Local constant instead of M_PI, which is not part of standard C++ and
    // is unavailable on some toolchains unless extra macros are defined.
    constexpr double kPi = 3.14159265358979323846;

    // Start with an equilateral triangle. Increasing the angle walks the
    // corners counter-clockwise in the (x, z) plane.
    std::vector<Vec3> points;
    points.reserve(3);
    for (int i = 0; i < 3; i++) {
        const float angle = static_cast<float>(2.0f * kPi * i / 3.0f + kPi / 2.0f);
        points.push_back(Vec3(size * std::cos(angle), 0, size * std::sin(angle)));
    }

    // Apply Koch subdivision iterations.
    for (int iter = 0; iter < iterations; iter++) {
        std::vector<Vec3> newPoints;
        newPoints.reserve(points.size() * 4);

        for (size_t i = 0; i < points.size(); i++) {
            Vec3 p1 = points[i];
            Vec3 p2 = points[(i + 1) % points.size()];

            // Split the edge into thirds.
            Vec3 delta = p2 - p1;
            Vec3 a = p1 + delta * (1.0f / 3.0f);
            Vec3 b = p1 + delta * (2.0f / 3.0f);

            // Apex of the equilateral bump erected on the middle third.
            // The outline is traversed counter-clockwise, so the right-hand
            // normal (dz, 0, -dx) points away from the interior and the bump
            // grows outward. The left-hand normal would carve the bump into
            // the shape and produce the anti-snowflake instead.
            Vec3 mid = (a + b) * 0.5f;
            Vec3 outward(delta.z, 0, -delta.x);
            outward.normalize();
            Vec3 peak = mid + outward * (delta.length() / 3.0f * std::sqrt(3.0f) / 2.0f);

            // p2 is omitted: it is emitted as p1 of the next edge.
            newPoints.push_back(p1);
            newPoints.push_back(a);
            newPoints.push_back(peak);
            newPoints.push_back(b);
        }

        points = newPoints;
    }

    // Create mesh from points with thickness.
    for (size_t i = 0; i < points.size(); i++) {
        Vec3 p1 = points[i];
        Vec3 p2 = points[(i + 1) % points.size()];

        // In-plane normal of the edge, scaled to the ribbon half-width.
        Vec3 dir = p2 - p1;
        Vec3 perp(-dir.z, 0, dir.x);
        perp.normalize();
        perp = perp * height;

        // Create a quad for each segment.
        unsigned int idx = mesh.getVertexCount();
        mesh.addVertex(Vertex(p1 - perp));
        mesh.addVertex(Vertex(p1 + perp));
        mesh.addVertex(Vertex(p2 + perp));
        mesh.addVertex(Vertex(p2 - perp));

        mesh.addTriangle(idx, idx + 1, idx + 2);
        mesh.addTriangle(idx, idx + 2, idx + 3);
    }

    mesh.computeNormals();
    return mesh;
}

// Placeholder for a 3D dragon curve: no curve is generated yet.
//
// iterations - currently unused.
// size       - passed as both extent arguments of MeshGenerator::createPlane.
// Returns the mesh produced by MeshGenerator::createPlane(size, size, 16, 16).
Mesh FractalGenerators::dragoncurve3D(int iterations, float size) {
    static_cast<void>(iterations);  // not used by the placeholder

    const int segmentsPerSide = 16;
    Mesh placeholder = MeshGenerator::createPlane(size, size, segmentsPerSide, segmentsPerSide);
    return placeholder;
}
