//
//  Qwen3CoreML.swift
//  Qwen3 CoreML Example
//
//  Swift example for Qwen3 CoreML integration:
//  Qwen3-0.6B with stateful KV-cache and Int4 quantization.
//
//  Requirements:
//  - iOS 18.0+ / macOS 15.0+ (MLTensor, MLState, Apple Neural Engine support)
//  - 400-500 MB RAM for both models
//  - swift-transformers package
//
//  Usage:
//      let qwen3 = Qwen3CoreML()
//      try await qwen3.loadModels()
//      for await chunk in qwen3.generate(userMessage: "Hello, world!") {
//          print(chunk, terminator: "")
//      }
//

import Foundation
import CoreML
import Tokenizers

/// Qwen3-0.6B CoreML model wrapper with stateful KV-cache
@MainActor
public final class Qwen3CoreML {

    // MARK: - Configuration

    public struct Config {
        public static let maxContextLength = 1024
        public static let maxTokens = 512
        public static let temperature: Float = 0.7
        public static let topK = 40
        public static let topP: Float = 0.9

        // Model names (looked up in the app bundle, then in Application Support)
        public static let prefillModelName = "Qwen3-0.6B-Prefill-Int4"
        public static let decodeModelName = "Qwen3-0.6B-Decode-Int4"
        public static let tokenizerModelId = "Qwen/Qwen3-0.6B"
    }

    // MARK: - State

    private var prefillModel: MLModel?
    private var decodeModel: MLModel?
    private var tokenizer: Tokenizer?
    private var decodeState: MLState?

    private(set) var isModelsLoaded = false
    private(set) var isGenerating = false

    // Qwen3 special tokens
    private let eosTokenIds: Set<Int> = [151643, 151645] // <|endoftext|>, <|im_end|>
    private let bosTokenId = 151643

    // Performance tracking
    private(set) var tokensPerSecond: Double = 0
    private(set) var currentPosition = 0

    // MARK: - Initialization

    public init() {
        print("🤖 Qwen3CoreML initialized")
    }

    // MARK: - Model Loading

    /// Load the Prefill and Decode CoreML models and the tokenizer
    public func loadModels() async throws {
        guard !isModelsLoaded else {
            print("🤖 Qwen3: Models already loaded")
            return
        }

        print("🤖 Qwen3: Loading CoreML models and tokenizer...")

        do {
            // Load the Prefill model (processes the whole prompt in one pass)
            prefillModel = try await loadModel(named: Config.prefillModelName)
            print("✅ Prefill model loaded")

            // Load the Decode model and create its KV-cache state
            let decode = try await loadModel(named: Config.decodeModelName)
            decodeModel = decode
            decodeState = decode.makeState()
            print("✅ Decode model loaded")

            // Load the tokenizer via the Tokenizers framework
            tokenizer = try await AutoTokenizer.from(pretrained: Config.tokenizerModelId)
            print("✅ Tokenizer loaded")

            isModelsLoaded = true
            print("🎉 Qwen3 models loaded successfully")
        } catch {
            print("❌ Failed to load Qwen3 models: \(error.localizedDescription)")
            throw Qwen3Error.modelLoadingFailed(error.localizedDescription)
        }
    }
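
    // For reference, when the models are not shipped in the app bundle,
    // loadModel(named:) below falls back to this Application Support layout
    // (directory names per the lookup code; the .mlpackage files themselves
    // must come from your own CoreML conversion pipeline):
    //
    //     ~/Library/Application Support/Qwen3CoreML/Models/
    //     ├── Qwen3-0.6B-Prefill-Int4.mlpackage
    //     └── Qwen3-0.6B-Decode-Int4.mlpackage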
    /// Load a single CoreML model from the app bundle or Application Support
    private func loadModel(named modelName: String) async throws -> MLModel {
        let config = MLModelConfiguration()
        config.computeUnits = .cpuAndNeuralEngine // Use the ANE when available

        // Try the main bundle first, then the local Application Support directory
        var modelURL: URL?
        if let url = Bundle.main.url(forResource: modelName, withExtension: "mlpackage") {
            modelURL = url
        } else if let appSupport = FileManager.default.urls(
            for: .applicationSupportDirectory, in: .userDomainMask
        ).first {
            let modelPath = appSupport
                .appendingPathComponent("Qwen3CoreML")
                .appendingPathComponent("Models")
                .appendingPathComponent("\(modelName).mlpackage")
            if FileManager.default.fileExists(atPath: modelPath.path) {
                modelURL = modelPath
            }
        }

        guard let modelURL = modelURL else {
            throw Qwen3Error.modelNotFound(modelName)
        }

        // Compile and load the model
        let compiledURL = try await MLModel.compileModel(at: modelURL)
        return try MLModel(contentsOf: compiledURL, configuration: config)
    }

    // MARK: - Text Generation

    /// Generate a text response for a user message (streaming)
    public func generate(
        userMessage: String,
        systemPrompt: String = "You are a helpful assistant.",
        maxTokens: Int = Config.maxTokens,
        temperature: Float = Config.temperature,
        enableThinking: Bool = false
    ) -> AsyncStream<String> {
        AsyncStream { continuation in
            Task {
                await generateInternal(
                    userMessage: userMessage,
                    systemPrompt: systemPrompt,
                    maxTokens: maxTokens,
                    temperature: temperature,
                    enableThinking: enableThinking,
                    continuation: continuation
                )
            }
        }
    }

    /// Generate a text response for a user message (non-streaming)
    public func generateSync(
        userMessage: String,
        systemPrompt: String = "You are a helpful assistant.",
        maxTokens: Int = Config.maxTokens,
        temperature: Float = Config.temperature,
        enableThinking: Bool = false
    ) async throws -> String {
        guard isModelsLoaded else {
            throw Qwen3Error.modelNotLoaded
        }

        var result = ""
        for await chunk in generate(
            userMessage: userMessage,
            systemPrompt: systemPrompt,
            maxTokens: maxTokens,
            temperature: temperature,
            enableThinking: enableThinking
        ) {
            result += chunk
        }
        return result
    }

    /// Reset the conversation and the KV-cache state
    public func resetConversation() {
        decodeState = decodeModel?.makeState()
        currentPosition = 0
        print("🔄 Qwen3 conversation reset")
    }
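
    // Sketch of a typical streaming call site (hypothetical prompt; the stream
    // finishes when an EOS token is produced or maxTokens is reached):
    //
    //     let qwen3 = Qwen3CoreML()
    //     try await qwen3.loadModels()
    //     for await chunk in qwen3.generate(userMessage: "What is a KV cache?") {
    //         print(chunk, terminator: "")
    //     }
    //     qwen3.resetConversation() // clear the MLState-backed KV cache before a new topic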
    // MARK: - Private Generation

    private func generateInternal(
        userMessage: String,
        systemPrompt: String,
        maxTokens: Int,
        temperature: Float,
        enableThinking: Bool,
        continuation: AsyncStream<String>.Continuation
    ) async {
        guard isModelsLoaded,
              let prefillModel = prefillModel,
              let decodeModel = decodeModel,
              let tokenizer = tokenizer,
              var decodeState = decodeState else {
            continuation.finish()
            return
        }

        isGenerating = true
        let startTime = Date()
        defer {
            isGenerating = false
            continuation.finish()
        }

        do {
            // Format the chat prompt
            let chatPrompt = formatChatPrompt(
                userMessage: userMessage,
                systemPrompt: systemPrompt,
                enableThinking: enableThinking
            )

            // Tokenize the prompt
            var inputTokens = tokenizer.encode(text: chatPrompt)

            // Truncate if the prompt would overflow the context window
            if inputTokens.count + maxTokens > Config.maxContextLength {
                print("⚠️ Prompt too long, truncating...")
                inputTokens = Array(inputTokens.suffix(Config.maxContextLength - maxTokens))
                // Re-add the BOS token if truncation removed it
                if inputTokens.first != bosTokenId {
                    inputTokens = [bosTokenId] + inputTokens
                }
            }

            // Process the prompt tokens with the Prefill model
            try await processTokens(inputTokens, model: prefillModel)

            // Generate new tokens with the Decode model
            var generatedTokens: [Int] = []
            var isInThinkingBlock = false

            for _ in 0..<maxTokens {
                let nextToken = try await generateNextToken(
                    temperature: temperature,
                    decodeModel: decodeModel,
                    decodeState: &decodeState
                )

                // Stop on end-of-sequence tokens
                if eosTokenIds.contains(nextToken) { break }
                generatedTokens.append(nextToken)

                // Track <think>...</think> blocks
                if nextToken == 151667 { // <think>
                    isInThinkingBlock = true
                } else if nextToken == 151668 { // </think>
                    isInThinkingBlock = false
                    if !enableThinking { continue }
                }

                // Decode the token to text
                let tokenText = tokenizer.decode(tokens: [nextToken])

                // Stream the token unless it belongs to a suppressed thinking block
                if !isInThinkingBlock || enableThinking {
                    continuation.yield(tokenText)
                }
            }

            // Calculate performance
            let elapsed = Date().timeIntervalSince(startTime)
            tokensPerSecond = Double(generatedTokens.count) / elapsed
            print("📊 Generation: \(generatedTokens.count) tokens in \(String(format: "%.2f", elapsed))s (\(String(format: "%.1f", tokensPerSecond)) tok/s)")
        } catch {
            print("❌ Generation failed: \(error.localizedDescription)")
            // Note: no rethrow here; the continuation is finished in the defer block
        }
    }

    /// Process prompt tokens using the Prefill model
    private func processTokens(_ tokens: [Int], model: MLModel) async throws {
        let seqLen = tokens.count

        // Create a causal mask covering all prompt tokens
        let causalMask = createCausalMask(seqLen: seqLen, totalLen: seqLen)

        let inputIdsTensor = MLTensor(
            shape: [1, seqLen],
            scalars: tokens.map { Int32($0) },
            scalarType: Int32.self
        )

        let inputs = try MLDictionaryFeatureProvider(dictionary: [
            "inputIds": MLFeatureValue(tensor: inputIdsTensor),
            "causalMask": MLFeatureValue(tensor: causalMask)
        ])

        // Run prefill inference
        _ = try await model.prediction(from: inputs)

        currentPosition = seqLen
    }

    /// Generate the next token using the Decode model
    private func generateNextToken(
        temperature: Float,
        decodeModel: MLModel,
        decodeState: inout MLState
    ) async throws -> Int {
        // Current position as input
        let positionTensor = MLTensor(
            shape: [1, 1],
            scalars: [Int32(currentPosition)],
            scalarType: Int32.self
        )

        // A dummy input ID is required; actual logit generation uses the past KV cache
        let dummyInputTensor = MLTensor(
            shape: [1, 1],
            scalars: [Int32(0)], // Ignored by the decode model
            scalarType: Int32.self
        )

        let inputs = try MLDictionaryFeatureProvider(dictionary: [
            "inputIds": MLFeatureValue(tensor: dummyInputTensor),
            "positionIds": MLFeatureValue(tensor: positionTensor)
        ])

        let output = try await decodeModel.prediction(from: inputs, using: decodeState)

        guard let logitsTensor = output.featureValue(for: "logits")?.tensorValue(of: Float16.self) else {
            throw Qwen3Error.inferenceError("No logits in model output")
        }

        // Sample from the logits
        let nextToken = sampleToken(from: logitsTensor, temperature: temperature)

        // Advance the position for the next step
        currentPosition += 1

        return nextToken
    }

    /// Sample the next token from the logits (temperature scaling + top-K)
    private func sampleToken(from logitsTensor: MLTensor, temperature: Float) -> Int {
        // Extract logits for the last token: [1, 1, vocab_size] -> [vocab_size]
        let vocabSize = logitsTensor.shape[2]
        var logitsArray = [Float](repeating: 0, count: vocabSize)
        logitsTensor.withUnsafeBufferPointer(of: Float16.self) { buffer in
            for i in 0..<vocabSize {
                logitsArray[i] = Float(buffer[i])
            }
        }

        // Greedy decoding when temperature is zero
        guard temperature > 0 else {
            return logitsArray.indices.max { logitsArray[$0] < logitsArray[$1] } ?? 0
        }

        // Apply temperature scaling
        for i in 0..<vocabSize {
            logitsArray[i] /= temperature
        }

        // Keep only the top-K candidates
        let topKIndices = Array(
            logitsArray.indices.sorted { logitsArray[$0] > logitsArray[$1] }.prefix(Config.topK)
        )

        // Softmax over the top-K logits (shifted by the max for numerical stability)
        let maxLogit = logitsArray[topKIndices[0]]
        let expValues = topKIndices.map { exp(logitsArray[$0] - maxLogit) }
        let sumExp = expValues.reduce(0, +)

        // Sample from the resulting distribution
        var randomValue = Float.random(in: 0..<1) * sumExp
        for (index, expValue) in zip(topKIndices, expValues) {
            randomValue -= expValue
            if randomValue <= 0 { return index }
        }
        return topKIndices[0]
    }

    /// Create a causal attention mask (0 where attention is allowed, -inf elsewhere)
    private func createCausalMask(seqLen: Int, totalLen: Int) -> MLTensor {
        var maskData = [Float16](repeating: Float16(-Float.infinity), count: seqLen * totalLen)
        for i in 0..<seqLen {
            // Token i may attend to positions 0...i
            for j in 0...i {
                maskData[i * totalLen + j] = 0
            }
        }
        return MLTensor(
            shape: [1, 1, seqLen, totalLen],
            scalars: maskData,
            scalarType: Float16.self
        )
    }

    /// Format the chat prompt using the Qwen3 chat template
    private func formatChatPrompt(
        userMessage: String,
        systemPrompt: String,
        enableThinking: Bool
    ) -> String {
        let chatTemplate = "<|im_start|>system\n\(systemPrompt)<|im_end|>\n<|im_start|>user\n\(userMessage)<|im_end|>\n<|im_start|>assistant\n"
        if enableThinking {
            return chatTemplate
        } else {
            return chatTemplate + "/no_think\n"
        }
    }
}
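
// For reference, with the default system prompt and enableThinking == false,
// formatChatPrompt(userMessage:systemPrompt:enableThinking:) above produces:
//
//     <|im_start|>system
//     You are a helpful assistant.<|im_end|>
//     <|im_start|>user
//     <user message><|im_end|>
//     <|im_start|>assistant
//     /no_think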
// MARK: - Errors

public enum Qwen3Error: LocalizedError {
    case modelNotFound(String)
    case modelNotLoaded
    case modelLoadingFailed(String)
    case inferenceError(String)
    case tokenizationError

    public var errorDescription: String? {
        switch self {
        case .modelNotFound(let modelName):
            return "Model '\(modelName)' not found. Place it in the app bundle or in ~/Library/Application Support/Qwen3CoreML/Models/."
        case .modelNotLoaded:
            return "Models are not loaded. Call loadModels() first."
        case .modelLoadingFailed(let message):
            return "Model loading failed: \(message)"
        case .inferenceError(let message):
            return "Inference error: \(message)"
        case .tokenizationError:
            return "Tokenization error"
        }
    }
}

// MARK: - Helper Methods

/// Utility methods for text processing
extension Qwen3CoreML {

    /// Correct text using Qwen3 (compatible with LLMRunner.correct())
    public func correct(text: String) async throws -> String {
        return try await generateSync(
            userMessage: """
            Please correct the following text by fixing punctuation, capitalization, and grammatical errors. Keep the original language. Only output the corrected text, nothing else.

            Text: \(text)

            Corrected:
            """,
            systemPrompt: "You are a professional proofreader and text editor.",
            maxTokens: 256,
            temperature: 0.1 // Low temperature for consistent corrections
        ).trimmingCharacters(in: .whitespacesAndNewlines)
    }
}
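
// MARK: - Example

/// Minimal end-to-end usage sketch (not part of the library surface; assumes
/// the two .mlpackage files are available and that the tokenizer can be
/// fetched from the Hugging Face Hub on first run).
@MainActor
func runQwen3Example() async {
    let qwen3 = Qwen3CoreML()
    do {
        try await qwen3.loadModels()

        // Stream a reply token by token
        for await chunk in qwen3.generate(userMessage: "Write a haiku about autumn.") {
            print(chunk, terminator: "")
        }
        print("\n📊 \(String(format: "%.1f", qwen3.tokensPerSecond)) tok/s")

        // One-shot helper built on generateSync(...)
        let corrected = try await qwen3.correct(text: "teh weather are nice today")
        print("Corrected: \(corrected)")
    } catch {
        print("❌ Example failed: \(error.localizedDescription)")
    }
}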