KeyDetector.swift 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. import Accelerate
  2. import AVFoundation
  3. import Foundation
  /// Detects the musical key of an audio file using chromagram analysis
  /// with Krumhansl-Kessler key profiles.
  ///
  /// Pipeline: decode up to 30 seconds of mono audio, compute one 12-bin chroma
  /// vector per FFT frame, average those vectors over time, then pick the
  /// major or minor key whose profile has the highest Pearson correlation with
  /// the averaged chroma.
  struct KeyDetector {
      // MARK: - Key Profiles (Krumhansl-Kessler)

      /// Major key profile weights for each pitch class (C, C#, D, ..., B).
      /// Index 0 is the tonic; the chroma vector is rotated to test other roots.
      private static let majorProfile: [Double] = [
          6.35, 2.23, 3.48, 2.33, 4.38, 4.09,
          2.52, 5.19, 2.39, 3.66, 2.29, 2.88
      ]

      /// Minor key profile weights for each pitch class, tonic at index 0.
      private static let minorProfile: [Double] = [
          6.33, 2.68, 3.52, 5.38, 2.60, 3.53,
          2.54, 4.75, 3.98, 2.69, 3.34, 3.17
      ]

      /// Note names for Camelot-compatible display, indexed by pitch class (C = 0).
      private static let noteNames = ["C", "C#", "D", "Eb", "E", "F", "F#", "G", "Ab", "A", "Bb", "B"]

      /// Camelot wheel codes for DJ-friendly key display, indexed by root pitch class.
      private static let camelotMajor = ["8B", "3B", "10B", "5B", "12B", "7B", "2B", "9B", "4B", "11B", "6B", "1B"]
      private static let camelotMinor = ["5A", "12A", "7A", "2A", "9A", "4A", "11A", "6A", "1A", "8A", "3A", "10A"]

      // MARK: - Configuration

      /// Number of samples per analysis frame (must be a power of two for the radix-2 FFT).
      private static let fftSize = 4096
      /// Number of samples between the starts of consecutive frames (50% overlap).
      private static let hopSize = 2048
      /// Tuning reference used to convert frequencies to semitones.
      private static let referenceFrequency: Double = 440.0 // A4

      // MARK: - Result

      /// The outcome of a key detection run.
      struct KeyResult {
          let key: String // e.g., "C Major" or "A Minor"
          let camelotCode: String // e.g., "8B" or "8A"
          let confidence: Double // 0.0 to 1.0
          let rootNote: Int // pitch class index 0-11
          let isMinor: Bool

          /// Compact key name: the note name, followed by "m" for minor keys.
          var shortKey: String {
              let note = KeyDetector.noteNames[rootNote]
              return "\(note)\(isMinor ? "m" : "")"
          }
      }

      // MARK: - Public API

      /// Detects the key of the audio file at `track.fileURL`.
      ///
      /// - Throws: Whatever `detectKey(fileURL:)` throws.
      static func detectKey(for track: Track) async throws -> KeyResult {
          try await detectKey(fileURL: track.fileURL)
      }

      /// Detects the key of the audio file at `fileURL`.
      ///
      /// At most the first 30 seconds of audio are analyzed. The work runs in a
      /// detached task at `.userInitiated` priority.
      ///
      /// - Parameter fileURL: Location of the audio file to analyze.
      /// - Returns: The best-matching key and its correlation-based confidence.
      /// - Throws: `KeyDetectionError.insufficientAudio` when the decoded audio
      ///   has no more than `fftSize * 2` samples, plus any error raised while
      ///   opening or decoding the file.
      static func detectKey(fileURL: URL) async throws -> KeyResult {
          try await Task.detached(priority: .userInitiated) {
              let sampleRate: Double
              let samples: [Float]
              if OGGDecoder.isOGGFile(fileURL) {
                  let result = try OGGDecoder.readMonoSamples(url: fileURL, maxSeconds: 30)
                  sampleRate = result.sampleRate
                  samples = result.samples
              } else {
                  let audioFile = try AVAudioFile(forReading: fileURL)
                  sampleRate = audioFile.processingFormat.sampleRate
                  samples = try readMonoSamples(from: audioFile, maxSeconds: 30)
              }
              guard samples.count > fftSize * 2 else {
                  throw KeyDetectionError.insufficientAudio
              }
              // Build chromagram
              let chromagram = computeChromagram(samples: samples, sampleRate: sampleRate)
              // Average across time
              let avgChroma = averageChromagram(chromagram)
              // Match against key profiles
              return matchKeyProfile(chroma: avgChroma)
          }.value
      }

      // MARK: - Audio Reading

      /// Reads up to `maxSeconds` of audio from the start of `audioFile` as mono samples.
      ///
      /// The file is read in its own processing format, because
      /// `AVAudioFile.read(into:frameCount:)` requires the buffer format to match
      /// it. Multichannel audio is then downmixed by averaging the channels.
      ///
      /// - Parameters:
      ///   - audioFile: The file to read; its frame position is reset to 0.
      ///   - maxSeconds: Upper bound on the duration to read, in seconds.
      /// - Returns: Mono samples; empty if the read produced no frames.
      /// - Throws: `KeyDetectionError.noAudioData` when there is nothing to read
      ///   or the buffer exposes no float channel data,
      ///   `KeyDetectionError.formatError` when the buffer cannot be created,
      ///   and any error thrown by the read itself.
      private static func readMonoSamples(from audioFile: AVAudioFile, maxSeconds: Double) throws -> [Float] {
          let format = audioFile.processingFormat
          let frameLimit = min(Double(audioFile.length), format.sampleRate * maxSeconds)
          // Reject empty files and non-positive durations before the unsigned conversion.
          guard frameLimit >= 1 else {
              throw KeyDetectionError.noAudioData
          }
          let maxFrames = AVAudioFrameCount(frameLimit)

          guard let buffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: maxFrames) else {
              throw KeyDetectionError.formatError
          }
          audioFile.framePosition = 0
          try audioFile.read(into: buffer, frameCount: maxFrames)

          guard let channels = buffer.floatChannelData else {
              throw KeyDetectionError.noAudioData
          }
          let frameCount = Int(buffer.frameLength)
          let channelCount = Int(format.channelCount)
          guard frameCount > 0, channelCount > 0 else { return [] }

          // Mono input needs no mixing: copy the single channel as-is.
          if channelCount == 1 {
              return Array(UnsafeBufferPointer(start: channels[0], count: frameCount))
          }

          // Downmix: sum every channel into `mono`, then scale by 1 / channelCount.
          // `buffer.stride` is the spacing between consecutive samples of one channel.
          let channelStride = vDSP_Stride(buffer.stride)
          let length = vDSP_Length(frameCount)
          var mono = [Float](repeating: 0, count: frameCount)
          mono.withUnsafeMutableBufferPointer { monoPtr in
              guard let out = monoPtr.baseAddress else { return }
              for channel in 0..<channelCount {
                  vDSP_vadd(out, 1, channels[channel], channelStride, out, 1, length)
              }
              var scale = 1 / Float(channelCount)
              vDSP_vsmul(out, 1, &scale, out, 1, length)
          }
          return mono
      }

      // MARK: - Chromagram Computation

      /// Computes one 12-bin chroma vector per analysis frame.
      ///
      /// Each frame of `fftSize` samples is Hann-windowed and transformed with a
      /// real FFT. The squared magnitude of every bin that `buildChromaMap`
      /// assigns to a pitch class is added to that pitch class.
      ///
      /// - Parameters:
      ///   - samples: Mono audio samples.
      ///   - sampleRate: Sample rate of `samples`, in Hz.
      /// - Returns: Chroma vectors in frame order; empty when `samples` is
      ///   shorter than one frame or the FFT setup cannot be created.
      private static func computeChromagram(samples: [Float], sampleRate: Double) -> [[Double]] {
          // Fewer samples than one frame would make the frame count negative.
          guard samples.count >= fftSize else { return [] }

          let halfFFT = fftSize / 2
          let log2n = vDSP_Length(log2(Double(fftSize)))
          guard let fftSetup = vDSP_create_fftsetup(log2n, FFTRadix(kFFTRadix2)) else { return [] }
          defer { vDSP_destroy_fftsetup(fftSetup) }

          let numFrames = (samples.count - fftSize) / hopSize + 1
          var chromagram = [[Double]]()
          chromagram.reserveCapacity(numFrames)

          var window = [Float](repeating: 0, count: fftSize)
          vDSP_hann_window(&window, vDSP_Length(fftSize), Int32(vDSP_HANN_NORM))

          // Pre-compute frequency-to-chroma mapping
          let chromaMap = buildChromaMap(fftSize: fftSize, sampleRate: sampleRate)

          // Scratch buffers are allocated once and reused for every frame.
          var frame = [Float](repeating: 0, count: fftSize)
          var real = [Float](repeating: 0, count: halfFFT)
          var imag = [Float](repeating: 0, count: halfFFT)
          var powers = [Float](repeating: 0, count: halfFFT)

          let frameLength = vDSP_Length(fftSize)
          let halfLength = vDSP_Length(halfFFT)
          let hop = hopSize

          // DSPSplitComplex stores raw pointers, so `real` and `imag` must stay
          // pinned for as long as the split-complex value is in use.
          real.withUnsafeMutableBufferPointer { realPtr in
              imag.withUnsafeMutableBufferPointer { imagPtr in
                  guard let realBase = realPtr.baseAddress,
                        let imagBase = imagPtr.baseAddress else { return }
                  var splitComplex = DSPSplitComplex(realp: realBase, imagp: imagBase)

                  for frameIndex in 0..<numFrames {
                      // `numFrames` guarantees offset + fftSize <= samples.count.
                      let offset = frameIndex * hop

                      // Apply the window while copying the frame out of `samples`.
                      samples.withUnsafeBufferPointer { samplePtr in
                          guard let base = samplePtr.baseAddress else { return }
                          vDSP_vmul(base + offset, 1, window, 1, &frame, 1, frameLength)
                      }

                      // Pack the real signal into split-complex form for the real FFT.
                      frame.withUnsafeBufferPointer { framePtr in
                          guard let frameBase = framePtr.baseAddress else { return }
                          frameBase.withMemoryRebound(to: DSPComplex.self, capacity: halfFFT) { complexPtr in
                              vDSP_ctoz(complexPtr, 2, &splitComplex, 1, halfLength)
                          }
                      }

                      // FFT
                      vDSP_fft_zrip(fftSetup, &splitComplex, 1, log2n, FFTDirection(kFFTDirection_Forward))

                      // Squared magnitudes (power) of each bin
                      vDSP_zvmags(&splitComplex, 1, &powers, 1, halfLength)

                      // Map to 12 chroma bins. Bin 0 is skipped: in the packed
                      // output it holds the DC and Nyquist terms.
                      var chroma = [Double](repeating: 0, count: 12)
                      for bin in 1..<halfFFT where chromaMap[bin] >= 0 {
                          chroma[chromaMap[bin]] += Double(powers[bin])
                      }
                      chromagram.append(chroma)
                  }
              }
          }
          return chromagram
      }

      /// Pre-compute which FFT bin maps to which chroma pitch class.
      ///
      /// - Parameters:
      ///   - fftSize: FFT length in samples.
      ///   - sampleRate: Sample rate in Hz.
      /// - Returns: An array of `fftSize / 2` entries holding a pitch class
      ///   (0 = C ... 11 = B), or -1 for bins outside 30 Hz...5000 Hz.
      private static func buildChromaMap(fftSize: Int, sampleRate: Double) -> [Int] {
          let halfFFT = fftSize / 2
          var map = [Int](repeating: -1, count: halfFFT)
          for bin in 1..<halfFFT {
              let frequency = Double(bin) * sampleRate / Double(fftSize)
              // Only consider musically relevant frequencies (30 Hz to 5000 Hz)
              guard frequency >= 30 && frequency <= 5000 else { continue }
              // Convert frequency to the nearest semitone relative to A4
              let semitones = 12.0 * log2(frequency / referenceFrequency)
              // `%` can be negative in Swift, so add 12 before the final modulo.
              let pitchClass = ((Int(round(semitones)) % 12) + 12 + 9) % 12 // A = 9, so shift to C = 0
              map[bin] = pitchClass
          }
          return map
      }

      // MARK: - Average Chromagram

      /// Averages the chroma vectors across all frames.
      ///
      /// - Parameter chromagram: One 12-element chroma vector per frame.
      /// - Returns: The element-wise mean, or twelve zeros when there are no frames.
      private static func averageChromagram(_ chromagram: [[Double]]) -> [Double] {
          guard !chromagram.isEmpty else { return [Double](repeating: 0, count: 12) }
          var avg = [Double](repeating: 0, count: 12)
          for frame in chromagram {
              for i in 0..<12 {
                  avg[i] += frame[i]
              }
          }
          let count = Double(chromagram.count)
          for i in 0..<12 {
              avg[i] /= count
          }
          return avg
      }

      // MARK: - Key Profile Matching

      /// Picks the key whose profile correlates best with `chroma`.
      ///
      /// All 12 roots are tested against the major and then the minor profile.
      /// The first candidate with the strictly highest correlation wins. The
      /// confidence is that correlation mapped from [-1, 1] to [0, 1].
      ///
      /// - Parameter chroma: A 12-element chroma vector (C = index 0).
      /// - Returns: The best-matching key.
      private static func matchKeyProfile(chroma: [Double]) -> KeyResult {
          var bestCorrelation = -Double.greatestFiniteMagnitude
          var bestRoot = 0
          var bestIsMinor = false
          for root in 0..<12 {
              // Rotate chroma so 'root' aligns with index 0
              let rotated = rotateChroma(chroma, by: root)
              // Correlate with major profile
              let majorCorr = pearsonCorrelation(rotated, majorProfile)
              if majorCorr > bestCorrelation {
                  bestCorrelation = majorCorr
                  bestRoot = root
                  bestIsMinor = false
              }
              // Correlate with minor profile
              let minorCorr = pearsonCorrelation(rotated, minorProfile)
              if minorCorr > bestCorrelation {
                  bestCorrelation = minorCorr
                  bestRoot = root
                  bestIsMinor = true
              }
          }
          let confidence = max(0, min(1, (bestCorrelation + 1) / 2))
          let keyName = "\(noteNames[bestRoot]) \(bestIsMinor ? "Minor" : "Major")"
          let camelot = bestIsMinor ? camelotMinor[bestRoot] : camelotMajor[bestRoot]
          return KeyResult(
              key: keyName,
              camelotCode: camelot,
              confidence: confidence,
              rootNote: bestRoot,
              isMinor: bestIsMinor
          )
      }

      /// Returns `chroma` rotated left by `amount`, so that element `amount`
      /// of the input becomes element 0 of the result.
      private static func rotateChroma(_ chroma: [Double], by amount: Int) -> [Double] {
          let n = chroma.count
          return (0..<n).map { chroma[($0 + amount) % n] }
      }

      /// Pearson correlation coefficient of `a` and `b`.
      ///
      /// - Returns: The coefficient, or 0 when the denominator is not positive
      ///   (for example when either input is constant).
      private static func pearsonCorrelation(_ a: [Double], _ b: [Double]) -> Double {
          let n = Double(a.count)
          let sumA = a.reduce(0, +)
          let sumB = b.reduce(0, +)
          let sumAB = zip(a, b).map(*).reduce(0, +)
          let sumA2 = a.map { $0 * $0 }.reduce(0, +)
          let sumB2 = b.map { $0 * $0 }.reduce(0, +)
          let numerator = n * sumAB - sumA * sumB
          let denominator = sqrt((n * sumA2 - sumA * sumA) * (n * sumB2 - sumB * sumB))
          guard denominator > 0 else { return 0 }
          return numerator / denominator
      }
  }
  208. // MARK: - Errors
  /// Failures reported by key detection, each with a user-readable description.
  enum KeyDetectionError: Error, LocalizedError {
      /// The audio does not contain enough samples to analyze.
      case insufficientAudio
      /// The audio format or buffer could not be set up for reading.
      case formatError
      /// No sample data was available after reading.
      case noAudioData

      /// Human-readable message for this error.
      var errorDescription: String? {
          let message: String
          switch self {
          case .insufficientAudio:
              message = "Audio file is too short for key detection"
          case .formatError:
              message = "Unable to read audio format"
          case .noAudioData:
              message = "No audio data found"
          }
          return message
      }
  }