#include "VertexAnimator.h"

#include <algorithm>
#include <cmath>
#include <utility>

VertexAnimator::VertexAnimator() : looping(false), duration(1.0f) {}

/// Captures the current vertex positions of @p mesh as a keyframe at @p time.
///
/// The keyframe list is kept sorted by time. The new frame goes after any existing
/// frames with a time <= @p time, so frames sharing a timestamp keep their insertion
/// order. (The previous std::sort was not stable, so it could reorder such frames.)
/// Afterwards, `duration` is set to the time of the last keyframe.
///
/// @param time  Timestamp for the new keyframe.
/// @param mesh  Mesh whose vertex positions are snapshotted.
void VertexAnimator::addKeyFrame(float time, const Mesh& mesh) {
    KeyFrame kf(time);
    const size_t vertexCount = mesh.getVertexCount();
    kf.positions.reserve(vertexCount);
    for (size_t i = 0; i < vertexCount; i++) {
        kf.positions.push_back(mesh.getVertex(i).position);
    }

    // O(n) sorted insertion instead of re-sorting the whole vector on every call.
    const auto insertAt = std::upper_bound(
        keyFrames.begin(), keyFrames.end(), time,
        [](float t, const KeyFrame& other) { return t < other.time; });
    keyFrames.insert(insertAt, std::move(kf));

    // The list is non-empty here because we just inserted a frame.
    duration = keyFrames.back().time;
}

/// Blends two positions by factor @p t by delegating to Vec3::lerp.
Vec3 VertexAnimator::interpolate(const Vec3& a, const Vec3& b, float t) {
    Vec3 blended = Vec3::lerp(a, b, t);
    return blended;
}

int VertexAnimator::findKeyFrameIndex(float time) {
    if (keyFrames.empty()) return -1;

    if (looping) {
        time = std::fmod(time, duration);
    } else {
        time = std::clamp(time, 0.0f, duration);
    }

    for (size_t i = 0; i < keyFrames.size() - 1; i++) {
        if (time >= keyFrames[i].time && time < keyFrames[i + 1].time) {
            return i;
        }
    }

    return keyFrames.size() - 2;
}

/// Poses @p mesh by interpolating between the two keyframes surrounding @p time,
/// then recomputes normals.
///
/// Does nothing if findKeyFrameIndex() reports no usable segment. Only vertices
/// present in the mesh and in both keyframes are updated.
void VertexAnimator::applyAnimation(Mesh& mesh, float time) {
    int idx = findKeyFrameIndex(time);
    if (idx < 0 || idx >= (int)keyFrames.size() - 1) return;

    const KeyFrame& kf1 = keyFrames[idx];
    const KeyFrame& kf2 = keyFrames[idx + 1];

    // The segment was chosen using the wrapped time. Measure localTime with the same
    // wrap, or looping playback would extrapolate past kf2 instead of repeating.
    float t = time;
    if (looping && duration > 0.0f) {
        t = std::fmod(t, duration);
        if (t < 0.0f) t += duration;
    }

    // If two keyframes share a timestamp, the span is zero; snap to the later
    // keyframe instead of dividing 0/0 into NaN.
    const float span = kf2.time - kf1.time;
    float localTime = (span > 0.0f) ? (t - kf1.time) / span : 1.0f;
    // Clamping makes times outside the keyframe range hold the end pose
    // instead of extrapolating.
    localTime = std::clamp(localTime, 0.0f, 1.0f);

    // kf2 may hold a different number of positions than kf1; never index past either.
    const size_t keyedCount = std::min(kf1.positions.size(), kf2.positions.size());
    for (size_t i = 0; i < mesh.getVertexCount() && i < keyedCount; i++) {
        mesh.getVertex(i).position = interpolate(kf1.positions[i], kf2.positions[i], localTime);
    }

    mesh.computeNormals();
}

/// Adds a sine wave along x to every vertex's y coordinate, then recomputes normals.
///
/// Each vertex moves by amplitude * sin(x * frequency + time). The offset is added
/// to the current position, so repeated calls accumulate displacement.
void VertexAnimator::animateWave(Mesh& mesh, float time, float amplitude, float frequency) {
    const size_t vertexCount = mesh.getVertexCount();
    for (size_t idx = 0; idx != vertexCount; ++idx) {
        auto& pos = mesh.getVertex(idx).position;
        const float phase = pos.x * frequency + time;
        pos.y += amplitude * std::sin(phase);
    }
    mesh.computeNormals();
}

/// Rotates every vertex about the y axis, in the x/z plane, then recomputes normals.
///
/// The rotation angle is (amount * time) * y, so it grows with height. Rotation is
/// applied to the current positions, so repeated calls compound.
void VertexAnimator::animateTwist(Mesh& mesh, float time, float amount) {
    const float twistRate = amount * time;
    const size_t vertexCount = mesh.getVertexCount();
    for (size_t idx = 0; idx < vertexCount; ++idx) {
        auto& pos = mesh.getVertex(idx).position;
        const float theta = twistRate * pos.y;
        const float cosT = std::cos(theta);
        const float sinT = std::sin(theta);
        // Read both old coordinates before writing either.
        const float rotatedX = pos.x * cosT - pos.z * sinT;
        const float rotatedZ = pos.x * sinT + pos.z * cosT;
        pos.x = rotatedX;
        pos.z = rotatedZ;
    }
    mesh.computeNormals();
}

/// Rotates every vertex in the y/z plane by an angle proportional to its x coordinate,
/// then recomputes normals.
///
/// The bend strength oscillates as amount * sin(time). Rotation is applied to the
/// current positions, so repeated calls compound.
void VertexAnimator::animateBend(Mesh& mesh, float time, float amount) {
    const float bendStrength = amount * std::sin(time);
    const size_t vertexCount = mesh.getVertexCount();
    for (size_t idx = 0; idx < vertexCount; ++idx) {
        auto& pos = mesh.getVertex(idx).position;
        const float phi = bendStrength * pos.x;
        const float cosP = std::cos(phi);
        const float sinP = std::sin(phi);
        // Read both old coordinates before writing either.
        const float bentY = pos.y * cosP - pos.z * sinP;
        const float bentZ = pos.y * sinP + pos.z * cosP;
        pos.y = bentY;
        pos.z = bentZ;
    }
    mesh.computeNormals();
}

/// Uniformly scales every vertex position by 1 + scale * sin(2*pi*time),
/// then recomputes normals.
///
/// NOTE(review): scaling is applied to the current positions, so calling this every
/// frame compounds the factor instead of pulsing around a rest pose. Consider scaling
/// from stored base positions if that is the intended behavior.
void VertexAnimator::animatePulse(Mesh& mesh, float time, float scale) {
    // M_PI is a POSIX extension, not standard C++; MSVC only defines it when
    // _USE_MATH_DEFINES is set. It is also a double, which silently promoted this
    // float computation to double. Use a local float constant instead.
    constexpr float kTwoPi = 6.28318530717958647692f;
    const float pulse = 1.0f + scale * std::sin(time * kTwoPi);
    for (size_t i = 0; i < mesh.getVertexCount(); i++) {
        auto& position = mesh.getVertex(i).position;
        position = position * pulse;
    }
    mesh.computeNormals();
}

/// Adds a cheap pseudo-noise ripple to each vertex's y coordinate, then recomputes normals.
///
/// This is not real noise. The offset is sin(10x + time) * cos(10z + time) * intensity,
/// a smooth interference pattern. It is added to the current position, so repeated
/// calls accumulate.
void VertexAnimator::animateNoise(Mesh& mesh, float time, float intensity) {
    const size_t vertexCount = mesh.getVertexCount();
    for (size_t idx = 0; idx < vertexCount; ++idx) {
        auto& pos = mesh.getVertex(idx).position;
        const float ripple = std::sin(pos.x * 10 + time) * std::cos(pos.z * 10 + time);
        pos.y += ripple * intensity;
    }
    mesh.computeNormals();
}
