#pragma once
#include "../math/Vec3.h"
#include <cmath>
#include <algorithm>

/**
 * Signed distance functions (SDFs) for common primitives, plus boolean and
 * domain operators for combining them.
 *
 * Conventions used throughout:
 *  - A negative result means the query point is inside the shape.
 *  - Primitives other than capsule() and plane() are centred at the origin of
 *    the query point's frame; translate/rotate the point before calling.
 *  - Functions documented as a "bound" return a conservative estimate that is
 *    not the exact Euclidean distance (still usable for sphere tracing).
 */
class SDFPrimitives {
public:
    // ------------------------------------------------------------------
    // Basic primitives
    // ------------------------------------------------------------------

    /// Exact distance to a sphere of the given radius centred at the origin.
    static float sphere(const Vec3& p, float radius) {
        return p.length() - radius;
    }

    /// Exact distance to an axis-aligned box centred at the origin.
    /// @param b Half-extents of the box along x, y and z.
    static float box(const Vec3& p, const Vec3& b) {
        Vec3 q = p.abs() - b;
        // First term is zero inside the box; second term is zero outside it.
        return Vec3::max(q, Vec3(0, 0, 0)).length() + std::min(std::max(q.x, std::max(q.y, q.z)), 0.0f);
    }

    /// Exact distance to a torus lying in the XZ plane, symmetric about the Y axis.
    /// @param majorRadius Distance from the origin to the centre line of the tube.
    /// @param minorRadius Radius of the tube.
    static float torus(const Vec3& p, float majorRadius, float minorRadius) {
        float q = std::sqrt(p.x * p.x + p.z * p.z) - majorRadius;
        return std::sqrt(q * q + p.y * p.y) - minorRadius;
    }

    /// Exact distance to a capped cylinder whose axis is the Y axis.
    /// @param radius Cylinder radius in the XZ plane.
    /// @param height Half-height: the cylinder spans y in [-height, height].
    static float cylinder(const Vec3& p, float radius, float height) {
        const float dx = std::sqrt(p.x * p.x + p.z * p.z) - radius;
        const float dy = std::abs(p.y) - height;
        const float outX = std::max(dx, 0.0f);
        const float outY = std::max(dy, 0.0f);
        return std::min(std::max(dx, dy), 0.0f) + std::sqrt(outX * outX + outY * outY);
    }

    /// Bound on the distance to a solid cone with its apex at the origin,
    /// opening towards -Y and cut off at y = -height.
    /// @param angle Half-angle of the cone in radians, measured from the -Y
    ///              axis (meaningful for 0 < angle < pi/2).
    /// @param height Distance from the apex to the base plane.
    static float cone(const Vec3& p, float angle, float height) {
        float q = std::sqrt(p.x * p.x + p.z * p.z);
        // Max of the infinite-cone distance and the |y| slab; the slab's upper
        // half is already excluded by the cone term, leaving y in [-height, 0].
        return std::max(std::cos(angle) * q + std::sin(angle) * p.y, std::abs(p.y) - height);
    }

    /// Exact distance to a capsule: the segment a-b swept by a sphere of `radius`.
    /// A degenerate segment (a == b) behaves as a sphere centred at a.
    static float capsule(const Vec3& p, const Vec3& a, const Vec3& b, float radius) {
        Vec3 pa = p - a;
        Vec3 ba = b - a;
        const float lengthSq = ba.dot(ba);
        // Parameter of the closest point on the segment, clamped to its ends.
        // Guard the division so a zero-length segment does not yield NaN.
        const float h = lengthSq > 0.0f ? std::clamp(pa.dot(ba) / lengthSq, 0.0f, 1.0f) : 0.0f;
        return (pa - ba * h).length() - radius;
    }

    /// Signed distance to the plane { x : dot(x, n) + d = 0 }.
    /// `n` must be unit length for the result to be a true distance;
    /// positive values lie on the side `n` points towards.
    static float plane(const Vec3& p, const Vec3& n, float d) {
        return p.dot(n) + d;
    }

    // ------------------------------------------------------------------
    // Polyhedra
    // ------------------------------------------------------------------

    /// Exact distance to a regular octahedron whose vertices lie at distance
    /// `s` from the origin along each coordinate axis (|x| + |y| + |z| = s).
    static float octahedron(const Vec3& p, float s) {
        Vec3 pa = p.abs();
        float m = pa.x + pa.y + pa.z - s;
        // Pick the axis the point is "closest to" and rotate it into a
        // canonical frame; otherwise the nearest feature is the face itself.
        Vec3 q;
        if (3.0f * pa.x < m) q = Vec3(pa.x, pa.y, pa.z);
        else if (3.0f * pa.y < m) q = Vec3(pa.y, pa.z, pa.x);
        else if (3.0f * pa.z < m) q = Vec3(pa.z, pa.x, pa.y);
        else return m * 0.57735027f; // 1/sqrt(3): distance to the face plane
        float k = std::clamp(0.5f * (q.z - q.y + s), 0.0f, s);
        return Vec3(q.x, q.y - s + k, q.z - k).length();
    }

    /// Bound on the distance to a regular dodecahedron (12 faces) with inradius `r`.
    /// Face normals are (0, +-1, +-phi) and its cyclic permutations; folding the
    /// point into the positive octant via abs() covers all sign variants.
    static float dodecahedron(const Vec3& p, float r) {
        // Components of normalize(0, 1, phi).
        constexpr float a = 0.52573111f; // 1   / sqrt(1 + phi^2)
        constexpr float b = 0.85065081f; // phi / sqrt(1 + phi^2)

        Vec3 pa = p.abs();
        return std::max({
            pa.dot(Vec3(a, b, 0.0f)),
            pa.dot(Vec3(0.0f, a, b)),
            pa.dot(Vec3(b, 0.0f, a))
        }) - r;
    }

    /// Bound on the distance to a regular icosahedron (20 faces) with inradius `r`.
    /// Face normals are (+-1, +-1, +-1) and (0, +-1/phi, +-phi) with its cyclic
    /// permutations; abs() folds the point into the positive octant.
    static float icosahedron(const Vec3& p, float r) {
        constexpr float c = 0.57735027f; // 1 / sqrt(3): normalize(1, 1, 1)
        constexpr float s = 0.35682209f; // (1/phi) / sqrt(3)
        constexpr float t = 0.93417236f; // phi / sqrt(3)

        Vec3 pa = p.abs();
        return std::max({
            pa.dot(Vec3(c, c, c)),
            pa.dot(Vec3(0.0f, s, t)),
            pa.dot(Vec3(s, t, 0.0f)),
            pa.dot(Vec3(t, 0.0f, s))
        }) - r;
    }

    // ------------------------------------------------------------------
    // Boolean operations
    // ------------------------------------------------------------------

    /// Union of two shapes.
    static float opUnion(float d1, float d2) {
        return std::min(d1, d2);
    }

    /// Removes shape d1 from shape d2 (d2 minus d1).
    static float opSubtraction(float d1, float d2) {
        return std::max(-d1, d2);
    }

    /// Intersection of two shapes.
    static float opIntersection(float d1, float d2) {
        return std::max(d1, d2);
    }

    /// Union with a smooth blend of width `k`. k <= 0 gives a hard union
    /// (avoids the 0/0 NaN the blend formula produces when d1 == d2).
    static float opSmoothUnion(float d1, float d2, float k) {
        if (k <= 0.0f) return opUnion(d1, d2);
        const float h = std::clamp(0.5f + 0.5f * (d2 - d1) / k, 0.0f, 1.0f);
        return d2 * (1.0f - h) + d1 * h - k * h * (1.0f - h);
    }

    /// Smooth version of opSubtraction (d2 minus d1) with blend width `k`.
    /// k <= 0 gives a hard subtraction.
    static float opSmoothSubtraction(float d1, float d2, float k) {
        if (k <= 0.0f) return opSubtraction(d1, d2);
        const float h = std::clamp(0.5f - 0.5f * (d2 + d1) / k, 0.0f, 1.0f);
        return d2 * (1.0f - h) + (-d1) * h + k * h * (1.0f - h);
    }

    /// Intersection with a smooth blend of width `k`. k <= 0 gives a hard intersection.
    static float opSmoothIntersection(float d1, float d2, float k) {
        if (k <= 0.0f) return opIntersection(d1, d2);
        const float h = std::clamp(0.5f - 0.5f * (d2 - d1) / k, 0.0f, 1.0f);
        return d2 * (1.0f - h) + d1 * h + k * h * (1.0f - h);
    }

    // ------------------------------------------------------------------
    // Domain operations (apply to the query point before evaluating a shape)
    // ------------------------------------------------------------------

    /// Infinite repetition: folds each coordinate into [-c/2, c/2) around the
    /// nearest cell centre (cell centres at integer multiples of c). Works for
    /// negative coordinates too. A period of 0 on an axis disables repetition
    /// on that axis.
    static Vec3 opRepeat(const Vec3& p, const Vec3& c) {
        return Vec3(
            repeatAxis(p.x, c.x),
            repeatAxis(p.y, c.y),
            repeatAxis(p.z, c.z)
        );
    }

    /// Twists space about the Y axis: the XZ coordinates are rotated by
    /// k * p.y radians. Not distance-preserving, so the resulting field is
    /// only approximate.
    static Vec3 opTwist(const Vec3& p, float k) {
        float c = std::cos(k * p.y);
        float s = std::sin(k * p.y);
        return Vec3(c * p.x - s * p.z, p.y, s * p.x + c * p.z);
    }

    /// Bends space: the XY coordinates are rotated by k * p.x radians.
    /// Not distance-preserving, so the resulting field is only approximate.
    static Vec3 opBend(const Vec3& p, float k) {
        float c = std::cos(k * p.x);
        float s = std::sin(k * p.x);
        return Vec3(c * p.x - s * p.y, s * p.x + c * p.y, p.z);
    }

private:
    /// Folds one coordinate into the cell of width `period` centred on the
    /// nearest multiple of `period`. Uses floor-based modulo (like GLSL mod)
    /// rather than std::fmod, whose result takes the sign of the dividend and
    /// would misplace cells on the negative side of the axis.
    static float repeatAxis(float v, float period) {
        if (period == 0.0f) return v;
        return v - period * std::floor(v / period + 0.5f);
    }
};
