import Foundation
import Accelerate

// MARK: - Pitch Detector (YIN Algorithm)

/// Detects the fundamental frequency of a monophonic audio signal using the YIN algorithm.
/// Optimized for vocal pitch detection in the range C2 (65 Hz) to C6 (1047 Hz).
///
/// The detector is a value type holding only an immutable `Configuration`.
struct PitchDetector {

    // MARK: - Configuration

    /// Tunable parameters of the detector.
    struct Configuration {
        /// Maximum number of samples analysed per window (~42.7 ms at 48 kHz).
        var windowSize: Int = 2048
        /// Distance in samples between consecutive windows in batch mode (~10.7 ms at 48 kHz).
        var hopSize: Int = 512
        /// YIN aperiodicity threshold: the first CMNDF dip below this value is taken as the period.
        var threshold: Double = 0.15
        /// Lowest detectable frequency in Hz (C2).
        var minFrequency: Double = 65.0
        /// Highest detectable frequency in Hz (~C6).
        var maxFrequency: Double = 1100.0
        /// Windows whose RMS level (dB re. full scale 1.0) is below this are reported as unvoiced.
        var silenceThresholdDB: Double = -50.0
    }

    // MARK: - Result

    /// Outcome of analysing one window.
    struct Result {
        /// Detected frequency in Hz (0 if unvoiced).
        let frequency: Double
        /// 0.0 to 1.0 (higher = more periodic).
        let confidence: Double
        /// `true` when a pitch inside the configured range was found.
        let isVoiced: Bool

        /// Canonical "no pitch" result.
        static let unvoiced = Result(frequency: 0, confidence: 0, isVoiced: false)
    }

    let configuration: Configuration

    init(configuration: Configuration = Configuration()) {
        self.configuration = configuration
    }

    // MARK: - Single Window Detection

    /// Detect pitch in a single window of audio samples.
    ///
    /// - Parameters:
    ///   - buffer: Pointer to at least `min(frameCount, configuration.windowSize)` readable samples.
    ///   - frameCount: Number of samples available at `buffer`.
    ///   - sampleRate: Sample rate of the signal in Hz.
    /// - Returns: The detected pitch, or `.unvoiced` when the window is silent, too short for the
    ///   configured frequency range, aperiodic, or when the inputs are degenerate
    ///   (non-positive frame count, sample rate or frequency bounds).
    func detectPitch(in buffer: UnsafePointer<Float>, frameCount: Int, sampleRate: Double) -> Result {
        let W = min(frameCount, configuration.windowSize)

        // Reject inputs that would otherwise trap in the Int / vDSP_Length conversions below.
        guard W > 0,
              sampleRate.isFinite, sampleRate > 0,
              configuration.minFrequency > 0,
              configuration.maxFrequency > 0 else {
            return .unvoiced
        }

        // Silence gate: skip detection if RMS is below threshold.
        var rms: Float = 0
        vDSP_rmsqv(buffer, 1, &rms, vDSP_Length(W))
        let rmsDB = 20.0 * log10(max(Double(rms), 1e-10))
        if rmsDB < configuration.silenceThresholdDB {
            return .unvoiced
        }

        // Lag bounds from the frequency range. The periods are clamped in the Double domain
        // before converting so that extreme configurations cannot overflow Int.
        let halfW = W / 2
        let shortestPeriod = sampleRate / configuration.maxFrequency
        let longestPeriod = sampleRate / configuration.minFrequency
        let minLag = max(2, Int(min(shortestPeriod, Double(halfW))))
        let maxLag = Int(min(longestPeriod, Double(halfW)))
        guard minLag < maxLag else { return .unvoiced }

        // Steps 1 + 2: difference function and its cumulative-mean normalization.
        let cmndf = normalizedDifference(of: buffer, integrationLength: halfW, maxLag: maxLag)

        // Steps 3 + 4: pick the period candidate (with octave correction).
        guard let bestTau = selectLag(in: cmndf, minLag: minLag, maxLag: maxLag) else {
            return .unvoiced
        }

        // Step 5: parabolic interpolation for sub-sample accuracy.
        let interpolatedTau = parabolicInterpolation(cmndf: cmndf, tau: bestTau, maxLag: maxLag)

        let frequency = sampleRate / interpolatedTau
        let confidence = Double(1.0 - cmndf[bestTau])

        // Sanity check: interpolation may push the estimate marginally outside the range.
        guard frequency >= configuration.minFrequency && frequency <= configuration.maxFrequency else {
            return .unvoiced
        }

        return Result(
            frequency: frequency,
            confidence: min(max(confidence, 0), 1),
            isVoiced: true
        )
    }

    // MARK: - Batch Detection

    /// Detect pitch across an entire audio buffer, returning one result per hop.
    ///
    /// Windows start at offsets `0, hopSize, 2 * hopSize, ...` and only full windows of
    /// `configuration.windowSize` samples are analysed.
    ///
    /// - Returns: One `Result` per full window; empty when the buffer is shorter than one
    ///   window or when `hopSize` / `windowSize` is not positive.
    func detectPitches(in buffer: [Float], sampleRate: Double) -> [Result] {
        let hopSize = configuration.hopSize
        let windowSize = configuration.windowSize

        // A non-positive hop would never advance (or walk backwards out of the buffer).
        guard hopSize > 0, windowSize > 0, buffer.count >= windowSize else { return [] }

        let lastOffset = buffer.count - windowSize
        return buffer.withUnsafeBufferPointer { samples -> [Result] in
            guard let base = samples.baseAddress else { return [] }

            var results: [Result] = []
            results.reserveCapacity(lastOffset / hopSize + 1)
            for offset in stride(from: 0, through: lastOffset, by: hopSize) {
                results.append(detectPitch(in: base + offset, frameCount: windowSize, sampleRate: sampleRate))
            }
            return results
        }
    }

    // MARK: - Private Helpers

    /// Computes the cumulative-mean-normalized difference function d'(tau) for `tau` in `0...maxLag`.
    ///
    /// d(tau)  = sum_{j=0}^{integrationLength-1} (x[j] - x[j+tau])^2
    /// d'(tau) = d(tau) / ((1/tau) * sum_{j=1}^{tau} d(j)),  with d'(0) = 1
    ///
    /// Requires `maxLag >= 1` and `integrationLength + maxLag` readable samples at `buffer`.
    private func normalizedDifference(
        of buffer: UnsafePointer<Float>,
        integrationLength: Int,
        maxLag: Int
    ) -> [Float] {
        let length = vDSP_Length(integrationLength)
        // One scratch buffer reused for every lag instead of allocating per iteration.
        var scratch = [Float](repeating: 0, count: integrationLength)
        var cmndf = [Float](repeating: 1.0, count: maxLag + 1)
        var runningSum: Float = 0

        // The running sum must start at tau = 1 (not minLag) for correct normalization.
        for tau in 1...maxLag {
            var difference: Float = 0
            vDSP_vsub(buffer + tau, 1, buffer, 1, &scratch, 1, length)
            vDSP_dotpr(scratch, 1, scratch, 1, &difference, length)

            runningSum += difference
            cmndf[tau] = runningSum > 0 ? difference * Float(tau) / runningSum : 1.0
        }
        return cmndf
    }

    /// Chooses the period candidate from the CMNDF, or `nil` when the window looks aperiodic.
    ///
    /// Takes the bottom of the first valley below `configuration.threshold`; if there is none,
    /// falls back to the global minimum provided it is not above 0.5. Finally prefers half the
    /// lag (one octave up) when the CMNDF there is comparably low.
    private func selectLag(in cmndf: [Float], minLag: Int, maxLag: Int) -> Int? {
        let threshold = Float(configuration.threshold)
        var bestTau = -1

        if let firstDip = (minLag..<maxLag).first(where: { cmndf[$0] < threshold }) {
            // Walk down to the bottom of this valley.
            var localMin = firstDip
            while localMin + 1 <= maxLag && cmndf[localMin + 1] < cmndf[localMin] {
                localMin += 1
            }
            bestTau = localMin
        } else {
            // No dip below threshold: use the global minimum as a fallback.
            var globalMinVal = Float.greatestFiniteMagnitude
            for tau in minLag...maxLag where cmndf[tau] < globalMinVal {
                globalMinVal = cmndf[tau]
                bestTau = tau
            }
            if globalMinVal > 0.5 {
                return nil
            }
        }

        guard bestTau > 0 else { return nil }

        // Sub-octave check: prefer half-lag (octave up) if it is also a good minimum.
        // This guards against locking onto twice the true period.
        let halfTau = bestTau / 2
        if halfTau >= minLag && halfTau <= maxLag {
            let halfVal = cmndf[halfTau]
            let bestVal = cmndf[bestTau]
            if halfVal < 0.5 && halfVal < bestVal * 1.5 {
                bestTau = halfTau
            }
        }

        return bestTau
    }

    /// Parabolic interpolation around the best tau for sub-sample accuracy.
    ///
    /// Fits a parabola through `cmndf[tau - 1]`, `cmndf[tau]`, `cmndf[tau + 1]` and returns the
    /// abscissa of its vertex: `tau + (a - c) / (2 * (a - 2b + c))`. Returns `tau` unchanged at
    /// the array edges or when the three points do not form a proper (convex) minimum.
    private func parabolicInterpolation(cmndf: [Float], tau: Int, maxLag: Int) -> Double {
        guard tau > 0 && tau < maxLag else {
            return Double(tau)
        }

        let alpha = Double(cmndf[tau - 1])
        let beta = Double(cmndf[tau])
        let gamma = Double(cmndf[tau + 1])

        // Positive curvature means the vertex is a minimum; flat or concave fits carry no
        // usable refinement.
        let curvature = alpha - 2.0 * beta + gamma
        guard curvature > 1e-10 else {
            return Double(tau)
        }

        // The vertex moves toward the lower neighbour. Clamp to the sampled neighbourhood,
        // since `tau` is not guaranteed to be a local minimum after the octave check.
        let offset = 0.5 * (alpha - gamma) / curvature
        return Double(tau) + min(max(offset, -1.0), 1.0)
    }
}
