BPMDetector.swift 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. import Accelerate
  2. import AVFoundation
  3. import Foundation
  4. /// Detects BPM from audio files using energy-based onset detection with autocorrelation.
  5. struct BPMDetector {
  6. // MARK: - Configuration
  7. /// Analysis window size (samples). Larger = more frequency resolution, less time resolution.
  8. private static let fftSize = 1024
  9. /// Hop size between analysis windows.
  10. private static let hopSize = 512
  11. /// Minimum BPM to consider.
  12. private static let minBPM: Double = 60
  13. /// Maximum BPM to consider.
  14. private static let maxBPM: Double = 200
  15. // MARK: - Public API
  16. /// Analyze a track's BPM. Runs on a background thread.
  17. static func detectBPM(for track: Track) async throws -> Double {
  18. let url = track.fileURL
  19. return try await detectBPM(fileURL: url)
  20. }
  21. /// Analyze BPM from a file URL.
  22. static func detectBPM(fileURL: URL) async throws -> Double {
  23. try await Task.detached(priority: .userInitiated) {
  24. let sampleRate: Double
  25. let samples: [Float]
  26. if OGGDecoder.isOGGFile(fileURL) {
  27. let result = try OGGDecoder.readMonoSamples(url: fileURL, maxSeconds: 60)
  28. sampleRate = result.sampleRate
  29. samples = result.samples
  30. } else {
  31. let audioFile = try AVAudioFile(forReading: fileURL)
  32. sampleRate = audioFile.processingFormat.sampleRate
  33. samples = try readMonoSamples(from: audioFile, maxSeconds: 60)
  34. }
  35. guard samples.count > fftSize * 2 else {
  36. throw BPMError.insufficientAudio
  37. }
  38. // Step 1: Compute spectral flux (onset detection function)
  39. let flux = computeSpectralFlux(samples: samples)
  40. // Step 2: Normalize flux
  41. let normalizedFlux = normalize(flux)
  42. // Step 3: Autocorrelation to find periodicity
  43. let bpm = findBPMFromAutocorrelation(
  44. onsetFunction: normalizedFlux,
  45. hopRate: sampleRate / Double(hopSize)
  46. )
  47. return bpm
  48. }.value
  49. }
  50. // MARK: - Audio Reading
  51. private static func readMonoSamples(from audioFile: AVAudioFile, maxSeconds: Double) throws -> [Float] {
  52. let sampleRate = audioFile.processingFormat.sampleRate
  53. let maxFrames = AVAudioFrameCount(min(Double(audioFile.length), sampleRate * maxSeconds))
  54. guard let format = AVAudioFormat(standardFormatWithSampleRate: sampleRate, channels: 1),
  55. let buffer = AVAudioPCMBuffer(pcmFormat: format, frameCapacity: maxFrames) else {
  56. throw BPMError.formatError
  57. }
  58. audioFile.framePosition = 0
  59. try audioFile.read(into: buffer, frameCount: maxFrames)
  60. guard let channelData = buffer.floatChannelData else {
  61. throw BPMError.noAudioData
  62. }
  63. return Array(UnsafeBufferPointer(start: channelData[0], count: Int(buffer.frameLength)))
  64. }
  65. // MARK: - Spectral Flux
  66. private static func computeSpectralFlux(samples: [Float]) -> [Float] {
  67. let halfFFT = fftSize / 2
  68. let log2n = vDSP_Length(log2(Double(fftSize)))
  69. guard let fftSetup = vDSP_create_fftsetup(log2n, FFTRadix(kFFTRadix2)) else { return [] }
  70. defer { vDSP_destroy_fftsetup(fftSetup) }
  71. let numFrames = (samples.count - fftSize) / hopSize + 1
  72. guard numFrames > 1 else { return [] }
  73. var window = [Float](repeating: 0, count: fftSize)
  74. vDSP_hann_window(&window, vDSP_Length(fftSize), Int32(vDSP_HANN_NORM))
  75. var previousMagnitudes = [Float](repeating: 0, count: halfFFT)
  76. var flux = [Float]()
  77. flux.reserveCapacity(numFrames)
  78. var real = [Float](repeating: 0, count: halfFFT)
  79. var imag = [Float](repeating: 0, count: halfFFT)
  80. for frameIndex in 0..<numFrames {
  81. let offset = frameIndex * hopSize
  82. let end = offset + fftSize
  83. guard end <= samples.count else { break }
  84. // Window the frame
  85. var frame = Array(samples[offset..<end])
  86. vDSP_vmul(frame, 1, window, 1, &frame, 1, vDSP_Length(fftSize))
  87. // Pack for FFT
  88. frame.withUnsafeMutableBufferPointer { framePtr in
  89. framePtr.baseAddress!.withMemoryRebound(to: DSPComplex.self, capacity: halfFFT) { complexPtr in
  90. var splitComplex = DSPSplitComplex(realp: &real, imagp: &imag)
  91. vDSP_ctoz(complexPtr, 2, &splitComplex, 1, vDSP_Length(halfFFT))
  92. }
  93. }
  94. // FFT
  95. var splitComplex = DSPSplitComplex(realp: &real, imagp: &imag)
  96. vDSP_fft_zrip(fftSetup, &splitComplex, 1, log2n, FFTDirection(kFFTDirection_Forward))
  97. // Magnitudes
  98. var magnitudes = [Float](repeating: 0, count: halfFFT)
  99. vDSP_zvmags(&splitComplex, 1, &magnitudes, 1, vDSP_Length(halfFFT))
  100. // Spectral flux: sum of positive differences
  101. var diff = [Float](repeating: 0, count: halfFFT)
  102. vDSP_vsub(previousMagnitudes, 1, magnitudes, 1, &diff, 1, vDSP_Length(halfFFT))
  103. // Half-wave rectify (keep only positive changes)
  104. var threshold: Float = 0
  105. vDSP_vthres(diff, 1, &threshold, &diff, 1, vDSP_Length(halfFFT))
  106. var sum: Float = 0
  107. vDSP_sve(diff, 1, &sum, vDSP_Length(halfFFT))
  108. flux.append(sum)
  109. previousMagnitudes = magnitudes
  110. }
  111. return flux
  112. }
  113. // MARK: - Autocorrelation
  114. private static func findBPMFromAutocorrelation(onsetFunction: [Float], hopRate: Double) -> Double {
  115. let n = onsetFunction.count
  116. guard n > 0 else { return 120 }
  117. // Lag range in frames corresponding to BPM range
  118. let minLag = max(1, Int(hopRate * 60.0 / maxBPM))
  119. let maxLag = min(n - 1, Int(hopRate * 60.0 / minBPM))
  120. guard minLag < maxLag else { return 120 }
  121. // Compute autocorrelation for relevant lags
  122. var bestLag = minLag
  123. var bestCorrelation: Float = -.greatestFiniteMagnitude
  124. for lag in minLag...maxLag {
  125. var correlation: Float = 0
  126. let length = vDSP_Length(n - lag)
  127. onsetFunction.withUnsafeBufferPointer { buf in
  128. vDSP_dotpr(
  129. buf.baseAddress!, 1,
  130. buf.baseAddress!.advanced(by: lag), 1,
  131. &correlation,
  132. length
  133. )
  134. }
  135. // Normalize by overlap length
  136. correlation /= Float(n - lag)
  137. if correlation > bestCorrelation {
  138. bestCorrelation = correlation
  139. bestLag = lag
  140. }
  141. }
  142. // Convert lag to BPM
  143. let bpm = hopRate * 60.0 / Double(bestLag)
  144. // If BPM is very low, it might be detecting half-time — double it
  145. if bpm < 80 { return bpm * 2 }
  146. // If very high, might be double-time — halve it
  147. if bpm > 180 { return bpm / 2 }
  148. return (bpm * 10).rounded() / 10 // round to 1 decimal
  149. }
  150. // MARK: - Normalize
  151. private static func normalize(_ data: [Float]) -> [Float] {
  152. guard !data.isEmpty else { return [] }
  153. var minVal: Float = 0
  154. var maxVal: Float = 0
  155. vDSP_minv(data, 1, &minVal, vDSP_Length(data.count))
  156. vDSP_maxv(data, 1, &maxVal, vDSP_Length(data.count))
  157. let range = maxVal - minVal
  158. guard range > 0 else { return [Float](repeating: 0, count: data.count) }
  159. var result = [Float](repeating: 0, count: data.count)
  160. var negMin = -minVal
  161. vDSP_vsadd(data, 1, &negMin, &result, 1, vDSP_Length(data.count))
  162. var scale = 1.0 / range
  163. vDSP_vsmul(result, 1, &scale, &result, 1, vDSP_Length(data.count))
  164. return result
  165. }
  166. }
  167. // MARK: - Errors
  168. enum BPMError: Error, LocalizedError {
  169. case insufficientAudio
  170. case formatError
  171. case noAudioData
  172. var errorDescription: String? {
  173. switch self {
  174. case .insufficientAudio: return "Audio file is too short for BPM analysis"
  175. case .formatError: return "Unable to read audio format"
  176. case .noAudioData: return "No audio data found in file"
  177. }
  178. }
  179. }