import Foundation
import AVFoundation

// MARK: - Pitch Corrector

/// Top-level pitch correction pipeline. Coordinates pitch detection, score mapping,
/// and pitch shifting to correct a recorded vocal part toward the score's target pitches.
class PitchCorrector {

    // MARK: - Configuration

    /// Tunable parameters for the correction pipeline.
    struct Configuration {
        /// Blend between original and target pitch: 0.0 (off) to 1.0 (full snap).
        var correctionStrength: Double = 0.5
        /// Safety clamp on any single correction: ±200 cents (2 semitones).
        var maxShiftCents: Double = 200.0
        /// Detection frames below this confidence are left uncorrected.
        var minConfidence: Double = 0.5
        /// Number of detection hops over which note transitions are smoothed (~30ms).
        var crossfadeHops: Int = 3

        /// Parameters forwarded to the pitch detector.
        var detection: PitchDetector.Configuration = PitchDetector.Configuration()

        /// FFT size for the phase-vocoder pitch shifter.
        var shifterFFTSize: Int = 4096
        /// Hop size (in frames) for the phase-vocoder pitch shifter.
        var shifterHopSize: Int = 1024
    }

    // MARK: - Correct from File

    /// Process a recorded WAV file and write the pitch-corrected result.
    ///
    /// Only the first channel is read; multi-channel recordings are treated as mono.
    /// The output is written next to the input as `<name>_corrected.wav`
    /// (16-bit linear PCM, mono, same sample rate).
    ///
    /// - Parameters:
    ///   - inputURL: Path to the original recording WAV
    ///   - part: The musical part this recording corresponds to
    ///   - tempo: Song tempo in BPM
    ///   - startMeasure: Which measure the recording started at
    ///   - configuration: Correction parameters
    /// - Returns: URL to the corrected WAV file
    /// - Throws: `CorrectionError.emptyFile` if the recording has no frames,
    ///   `CorrectionError.invalidFormat` if buffers/formats cannot be created,
    ///   plus any `AVAudioFile` read/write errors.
    static func correct(
        inputURL: URL,
        part: Part,
        tempo: Int,
        startMeasure: Int,
        configuration: Configuration = Configuration()
    ) throws -> URL {

        // Read input WAV
        let audioFile = try AVAudioFile(forReading: inputURL)
        let format = audioFile.processingFormat
        let sampleRate = format.sampleRate
        let frameCount = Int(audioFile.length)

        guard frameCount > 0 else { throw CorrectionError.emptyFile }

        // Buffer allocation can fail for unusual formats — throw instead of crashing.
        guard let buffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: AVAudioFrameCount(frameCount)) else {
            throw CorrectionError.invalidFormat
        }
        try audioFile.read(into: buffer)

        guard let channelData = buffer.floatChannelData?[0] else {
            throw CorrectionError.invalidFormat
        }

        // read(into:) may deliver fewer frames than the requested capacity;
        // trust frameLength rather than the file's reported length so we never
        // read past the initialized region of the buffer.
        let readFrames = Int(buffer.frameLength)
        guard readFrames > 0 else { throw CorrectionError.emptyFile }
        let inputSamples = Array(UnsafeBufferPointer(start: channelData, count: readFrames))

        // Run correction
        let correctedSamples = correct(
            inputBuffer: inputSamples,
            part: part,
            tempo: tempo,
            startMeasure: startMeasure,
            sampleRate: sampleRate,
            configuration: configuration
        )

        // Write output WAV next to the input
        let outputURL = inputURL.deletingLastPathComponent()
            .appendingPathComponent(
                inputURL.deletingPathExtension().lastPathComponent + "_corrected.wav"
            )

        // Processing format is deinterleaved Float32; AVAudioFile converts to the
        // 16-bit integer file format declared in the settings on write.
        guard let outputFormat = AVAudioFormat(standardFormatWithSampleRate: sampleRate, channels: 1) else {
            throw CorrectionError.invalidFormat
        }
        let outputFile = try AVAudioFile(forWriting: outputURL, settings: [
            AVFormatIDKey: kAudioFormatLinearPCM,
            AVSampleRateKey: sampleRate,
            AVNumberOfChannelsKey: 1,
            AVLinearPCMBitDepthKey: 16,
            AVLinearPCMIsFloatKey: false,
        ])

        guard let outputBuffer = AVAudioPCMBuffer(
            pcmFormat: outputFormat,
            frameCapacity: AVAudioFrameCount(correctedSamples.count)
        ), let outPtr = outputBuffer.floatChannelData?[0] else {
            throw CorrectionError.invalidFormat
        }
        outputBuffer.frameLength = AVAudioFrameCount(correctedSamples.count)
        correctedSamples.withUnsafeBufferPointer { src in
            if let base = src.baseAddress {
                outPtr.update(from: base, count: src.count)
            }
        }
        try outputFile.write(from: outputBuffer)

        print("[PitchCorrector] Wrote corrected audio to \(outputURL.lastPathComponent), \(correctedSamples.count) frames")
        return outputURL
    }

    // MARK: - Correct In-Memory

    /// Process audio samples in memory and return corrected samples.
    /// Used by the export/mixdown pipeline.
    ///
    /// Pipeline: build score pitch map → smooth note transitions → detect pitch →
    /// compute per-hop shift (octave-aware, strength-scaled, clamped) → resample
    /// shifts to the shifter hop rate → phase-vocoder pitch shift.
    ///
    /// Returns the input unchanged when correction is effectively off, the buffer
    /// is shorter than one detection window, or no targets/detections exist.
    static func correct(
        inputBuffer: [Float],
        part: Part,
        tempo: Int,
        startMeasure: Int,
        sampleRate: Double,
        configuration: Configuration = Configuration()
    ) -> [Float] {

        guard configuration.correctionStrength > 0.001 else {
            return inputBuffer // No correction needed
        }

        let totalFrames = inputBuffer.count
        guard totalFrames > configuration.detection.windowSize else {
            return inputBuffer
        }

        // Step 1: Build score pitch map
        let targetMap = ScorePitchMap.build(
            part: part,
            tempo: tempo,
            sampleRate: sampleRate,
            startMeasure: startMeasure,
            hopSize: configuration.detection.hopSize,
            totalFrames: totalFrames
        )

        // Smooth note transitions
        let smoothedMap = ScorePitchMap.smoothTransitions(
            targets: targetMap,
            crossfadeHops: configuration.crossfadeHops
        )

        // Step 2: Detect pitch in recording
        let detector = PitchDetector(configuration: configuration.detection)
        let detectedPitches = detector.detectPitches(in: inputBuffer, sampleRate: sampleRate)

        // Nothing to correct against / nothing detected — indexing below would
        // otherwise crash on an empty map (count - 1 == -1).
        guard !smoothedMap.isEmpty, !detectedPitches.isEmpty else {
            return inputBuffer
        }

        // Step 3: Compute per-frame shift amounts
        let detectionHops = detectedPitches.count
        var shiftCentsDetectionRate = [Double](repeating: 0, count: detectionHops)

        // Diagnostic: log first few detected vs target pitches
        var logCount = 0

        for i in 0..<detectionHops {
            let detected = detectedPitches[i]
            // Clamp to the last target if the detection ran past the score map.
            let targetIdx = min(i, smoothedMap.count - 1)
            let target = smoothedMap[targetIdx]

            // Skip correction for rests, unvoiced, or low-confidence frames
            guard !target.isRest,
                  target.frequency > 0,
                  detected.isVoiced,
                  detected.frequency > 0,
                  detected.confidence >= configuration.minConfidence else {
                shiftCentsDetectionRate[i] = 0
                continue
            }

            // Calculate shift in cents, but octave-aware:
            // Singers often sing in a different octave than the score.
            // Find the nearest octave-equivalent of the target and correct within that.
            let rawCents = 1200.0 * log2(target.frequency / detected.frequency)
            let idealShiftCents = rawCents - 1200.0 * round(rawCents / 1200.0)

            // Log first 10 voiced frames for diagnostics
            if logCount < 10 {
                let timeMs = Double(i * configuration.detection.hopSize) / sampleRate * 1000
                print("[PitchCorrector] t=\(String(format: "%.0f", timeMs))ms: detected=\(String(format: "%.1f", detected.frequency))Hz, target=\(String(format: "%.1f", target.frequency))Hz, octave-shift=\(String(format: "%.0f", idealShiftCents))cents, conf=\(String(format: "%.2f", detected.confidence))")
                logCount += 1
            }

            // Apply correction strength
            let adjustedShift = idealShiftCents * configuration.correctionStrength

            // Clamp to safety range
            let clampedShift = max(-configuration.maxShiftCents, min(configuration.maxShiftCents, adjustedShift))

            shiftCentsDetectionRate[i] = clampedShift
        }

        // Step 4: Resample shift array from detection hop rate to shifter hop rate
        let shifterHopSize = configuration.shifterHopSize
        let detectionHopSize = configuration.detection.hopSize
        let shifterHops = (totalFrames - 1) / shifterHopSize + 1
        var shiftCentsShifterRate = [Double](repeating: 0, count: shifterHops)

        for i in 0..<shifterHops {
            // Map shifter hop frame to detection hop frame
            let shifterFrame = i * shifterHopSize
            let detectionHopFloat = Double(shifterFrame) / Double(detectionHopSize)
            let detIdx0 = Int(detectionHopFloat)
            let frac = detectionHopFloat - Double(detIdx0)

            if detIdx0 + 1 < detectionHops {
                // Linear interpolation between detection frames
                shiftCentsShifterRate[i] = shiftCentsDetectionRate[detIdx0] * (1.0 - frac)
                    + shiftCentsDetectionRate[detIdx0 + 1] * frac
            } else if detIdx0 < detectionHops {
                // Past the last interpolable pair: hold the final detection value.
                shiftCentsShifterRate[i] = shiftCentsDetectionRate[detIdx0]
            }
        }

        // Step 5: Apply pitch shifting
        let shifter = PhaseVocoderShifter(
            fftSize: configuration.shifterFFTSize,
            hopSize: shifterHopSize
        )
        let corrected = shifter.process(
            input: inputBuffer,
            shiftCents: shiftCentsShifterRate,
            sampleRate: sampleRate
        )

        // Log summary (filter once instead of twice)
        let activeShifts = shiftCentsDetectionRate.filter { abs($0) > 0.5 }
        let activeFrames = activeShifts.count
        let avgShift = activeFrames > 0
            ? activeShifts.reduce(0, +) / Double(activeFrames)
            : 0
        print("[PitchCorrector] Processed \(totalFrames) frames, \(activeFrames)/\(detectionHops) frames corrected, avg shift: \(String(format: "%.1f", avgShift)) cents")

        return corrected
    }

    // MARK: - Errors

    /// Failures specific to the correction pipeline (I/O errors from
    /// `AVAudioFile` are rethrown as-is).
    enum CorrectionError: Error {
        case emptyFile
        case invalidFormat
    }
}
