import Foundation
import Accelerate

// MARK: - Pitch Shifting Engine Protocol

/// Abstraction over pitch-shifting implementations so that callers can swap engines
/// (for example a RubberBand-based shifter later) without changing call sites.
///
/// The conformer visible in this file is `PhaseVocoderShifter`, a windowed overlap-add
/// resampler. NOTE(review): an earlier revision of this comment called the default
/// `ResamplingShifter`; no type with that name is visible here — confirm which name is current.
protocol PitchShiftingEngine {
    /// Shifts audio by an amount, in cents, that may vary from one analysis frame to the next.
    /// - Parameters:
    ///   - input: Mono audio samples.
    ///   - shiftCents: One value per hop frame (positive = shift up, negative = shift down).
    ///   - sampleRate: Audio sample rate.
    /// - Returns: Pitch-shifted audio (same length as `input`).
    func process(input: [Float], shiftCents: [Double], sampleRate: Double) -> [Float]
}

// MARK: - Resampling Pitch Shifter

/// Shifts pitch using windowed overlap-add resampling.
///
/// Despite the type name (kept so existing callers keep compiling), no phase vocoder or FFT
/// runs here. Each Hann-weighted window re-reads the input around the window's centre at a
/// rate of `2^(cents / 1200)` using linear interpolation, and the weighted windows are
/// overlap-added and then normalized by the accumulated window weight.
/// Simple, and intended for small corrections (±200 cents / 2 semitones).
class PhaseVocoderShifter: PitchShiftingEngine {

    /// Window length in samples (named `fftSize` for API compatibility).
    let fftSize: Int
    /// Distance in samples between the starts of consecutive windows.
    let hopSize: Int
    /// Hann window of `fftSize` samples applied to every frame.
    private let window: [Float]

    /// Accumulated window weights at or below this value count as "no window coverage".
    private static let minimumWeight: Float = 1e-6
    /// Shifts smaller than this many cents are treated as no shift at all.
    private static let passThroughCents = 0.5

    /// Creates a shifter and precomputes its Hann window.
    /// - Parameters:
    ///   - fftSize: Window length in samples.
    ///   - hopSize: Step in samples between consecutive windows.
    init(fftSize: Int = 4096, hopSize: Int = 1024) {
        self.fftSize = fftSize
        self.hopSize = hopSize

        // Hann window. Its overall scale is irrelevant: `process` divides by the summed
        // window weight, so any constant factor cancels.
        var w = [Float](repeating: 0, count: fftSize)
        vDSP_hann_window(&w, vDSP_Length(fftSize), Int32(vDSP_HANN_NORM))
        self.window = w
    }

    /// Shifts `input` by a per-hop amount of cents.
    ///
    /// - Parameters:
    ///   - input: Mono audio samples.
    ///   - shiftCents: One value per hop (positive = up, negative = down). When there are fewer
    ///     values than hops the last value is reused; an empty array or a non-finite value
    ///     means "no shift" for that hop.
    ///   - sampleRate: Not used by this implementation; present to satisfy the protocol.
    /// - Returns: Audio of the same length as `input`. The input is returned unchanged when it
    ///   is not longer than one window, or when `fftSize` / `hopSize` is not positive.
    func process(input: [Float], shiftCents: [Double], sampleRate: Double) -> [Float] {
        let totalFrames = input.count
        // A non-positive hop would otherwise make the hop arithmetic trap.
        guard fftSize > 0, hopSize > 0, totalFrames > fftSize else { return input }

        // Overlap-add accumulators: weighted samples and the matching window weights.
        var output = [Float](repeating: 0, count: totalFrames)
        var weightSum = [Float](repeating: 0, count: totalFrames)
        let halfWindow = Double(fftSize) / 2.0

        for (hop, offset) in stride(from: 0, to: totalFrames, by: hopSize).enumerated() {
            // The final window may run past the end of the buffer; clip it so the tail of the
            // signal is still covered instead of being left silent.
            let frameLength = min(fftSize, totalFrames - offset)
            let cents = shiftAmount(forHop: hop, in: shiftCents)

            if abs(cents) < PhaseVocoderShifter.passThroughCents {
                // No shift needed — pass through with windowing.
                for i in 0..<frameLength {
                    let weight = window[i]
                    output[offset + i] += input[offset + i] * weight
                    weightSum[offset + i] += weight
                }
            } else {
                // Resampling ratio: to shift pitch UP, read the input FASTER (ratio > 1).
                let ratio = pow(2.0, cents / 1200.0)
                // The window centre stays anchored; reading is stretched/compressed around it
                // to keep edge artifacts small.
                let windowCenter = Double(offset) + halfWindow

                for i in 0..<frameLength {
                    let sourcePosition = windowCenter + (Double(i) - halfWindow) * ratio
                    let weight = window[i]
                    output[offset + i] += interpolatedSample(of: input, at: sourcePosition) * weight
                    weightSum[offset + i] += weight
                }
            }

            // Once a window reaches the end of the buffer, every sample is covered.
            if offset + fftSize >= totalFrames { break }
        }

        // Each sample was weighted by the window exactly once, so the matching normalizer is
        // the summed window weight (not the summed squared weight). With it, an all-zero shift
        // reproduces the input. Samples no window reached keep the unprocessed input.
        for i in 0..<totalFrames {
            if weightSum[i] > PhaseVocoderShifter.minimumWeight {
                output[i] /= weightSum[i]
            } else {
                output[i] = input[i]
            }
        }

        return output
    }

    /// Returns the shift in cents for `hop`, reusing the last entry past the end of
    /// `shiftCents` and mapping "no data" (empty array, NaN, ±infinity) to zero.
    private func shiftAmount(forHop hop: Int, in shiftCents: [Double]) -> Double {
        guard !shiftCents.isEmpty else { return 0.0 }
        let requested = shiftCents[min(hop, shiftCents.count - 1)]
        return requested.isFinite ? requested : 0.0
    }

    /// Linearly interpolates `samples` at a fractional `position`.
    /// Positions outside the buffer (including NaN and ±infinity) yield silence; the final
    /// sample is returned as-is because it has no right-hand neighbour to blend with.
    private func interpolatedSample(of samples: [Float], at position: Double) -> Float {
        // Checked before the Int conversion, which traps on NaN, infinity and huge values.
        guard position >= 0, position < Double(samples.count) else { return 0 }

        let index = Int(position.rounded(.down))
        guard index + 1 < samples.count else { return samples[index] }

        let fraction = Float(position - Double(index))
        return samples[index] * (1.0 - fraction) + samples[index + 1] * fraction
    }
}
