From 3b75e751183fa677a852d5f4b07233195525c461 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 22 Feb 2023 14:07:31 +0100 Subject: [PATCH] Allow all compute units to be selected (#30) * Allow all compute units to be selected by the user. * Remove commented code. * Simplify labels. * Remove warning * Align picker left * Apply suggestions from code review --- Diffusion-macOS/ControlsView.swift | 47 ++++++++++++++----------- Diffusion/ModelInfo.swift | 12 +++++-- Diffusion/Pipeline/PipelineLoader.swift | 20 +++++++---- Diffusion/State.swift | 22 +++++++----- 4 files changed, 63 insertions(+), 38 deletions(-) diff --git a/Diffusion-macOS/ControlsView.swift b/Diffusion-macOS/ControlsView.swift index 23a4ac6..f9a519e 100644 --- a/Diffusion-macOS/ControlsView.swift +++ b/Diffusion-macOS/ControlsView.swift @@ -57,7 +57,6 @@ struct ControlsView: View { @State private var disclosedSteps = false @State private var disclosedSeed = false @State private var disclosedAdvanced = false - @State private var useANE = (Settings.shared.userSelectedAttentionVariant ?? ModelInfo.defaultAttention) == .splitEinsum // TODO: refactor download with similar code in Loading.swift (iOS) @State private var stateSubscriber: Cancellable? @@ -82,14 +81,18 @@ struct ControlsView: View { mustShowSafetyCheckerDisclaimer = generation.disableSafety && !Settings.shared.safetyCheckerDisclaimerShown } - func updateANEState() { - Settings.shared.userSelectedAttentionVariant = useANE ? .splitEinsum : .original + func updateComputeUnitsState() { + Settings.shared.userSelectedComputeUnits = generation.computeUnits modelDidChange(model: Settings.shared.currentModel) } + func resetComputeUnitsState() { + generation.computeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits + } + func modelDidChange(model: ModelInfo) { - guard pipelineLoader?.model != model || pipelineLoader?.variant != Settings.shared.userSelectedAttentionVariant else { - print("Reusing same model \(model) with attention \(String(describing: Settings.shared.userSelectedAttentionVariant))") + guard pipelineLoader?.model != model || pipelineLoader?.computeUnits != generation.computeUnits else { + print("Reusing same model \(model) with units \(generation.computeUnits)") return } @@ -99,7 +102,7 @@ struct ControlsView: View { pipelineLoader?.cancel() pipelineState = .downloading(0) Task.init { - let loader = PipelineLoader(model: model, variant: Settings.shared.userSelectedAttentionVariant, maxSeed: maxSeed) + let loader = PipelineLoader(model: model, computeUnits: generation.computeUnits, maxSeed: maxSeed) self.pipelineLoader = loader stateSubscriber = loader.statePublisher.sink { state in DispatchQueue.main.async { @@ -128,8 +131,8 @@ struct ControlsView: View { } } - func isModelDownloaded(_ model: ModelInfo, variant: AttentionVariant? = nil) -> Bool { - PipelineLoader(model: model, variant: variant ?? Settings.shared.userSelectedAttentionVariant).ready + func isModelDownloaded(_ model: ModelInfo, computeUnits: ComputeUnits? = nil) -> Bool { + PipelineLoader(model: model, computeUnits: computeUnits ?? generation.computeUnits).ready } func modelLabel(_ model: ModelInfo) -> Text { @@ -301,21 +304,25 @@ struct ControlsView: View { Divider() DisclosureGroup(isExpanded: $disclosedAdvanced) { HStack { - Toggle("Use Neural Engine", isOn: $useANE).onChange(of: useANE) { value in - guard let currentModel = ModelInfo.from(modelVersion: model) else { return } - let variantDownloaded = isModelDownloaded(currentModel, variant: useANE ? .splitEinsum : .original) - if variantDownloaded { - updateANEState() - } else { - mustShowModelDownloadDisclaimer.toggle() - } - } - .padding(.leading, 10) + Picker(selection: $generation.computeUnits, label: Text("Use")) { + Text("GPU").tag(ComputeUnits.cpuAndGPU) + Text("Neural Engine").tag(ComputeUnits.cpuAndNeuralEngine) + Text("GPU and Neural Engine").tag(ComputeUnits.all) + }.pickerStyle(.radioGroup).padding(.leading) Spacer() } + .onChange(of: generation.computeUnits) { units in + guard let currentModel = ModelInfo.from(modelVersion: model) else { return } + let variantDownloaded = isModelDownloaded(currentModel, computeUnits: units) + if variantDownloaded { + updateComputeUnitsState() + } else { + mustShowModelDownloadDisclaimer.toggle() + } + } .alert("Download Required", isPresented: $mustShowModelDownloadDisclaimer, actions: { - Button("Cancel", role: .destructive) { useANE.toggle() } - Button("Download", role: .cancel) { updateANEState() } + Button("Cancel", role: .destructive) { resetComputeUnitsState() } + Button("Download", role: .cancel) { updateComputeUnitsState() } }, message: { Text("This setting requires a new version of the selected model.") }) diff --git a/Diffusion/ModelInfo.swift b/Diffusion/ModelInfo.swift index eab1ee4..843a853 100644 --- a/Diffusion/ModelInfo.swift +++ b/Diffusion/ModelInfo.swift @@ -13,6 +13,10 @@ enum AttentionVariant: String { case splitEinsum } +extension AttentionVariant { + var defaultComputeUnits: MLComputeUnits { self == .original ? .cpuAndGPU : .cpuAndNeuralEngine } +} + struct ModelInfo { /// Hugging Face model Id that contains .zip archives with compiled Core ML models let modelId: String @@ -39,6 +43,7 @@ struct ModelInfo { } extension ModelInfo { + //TODO: set compute units instead and derive variant from it static var defaultAttention: AttentionVariant { guard runningOnMac else { return .splitEinsum } #if os(macOS) @@ -49,9 +54,10 @@ extension ModelInfo { #endif } - var bestAttention: AttentionVariant { - return ModelInfo.defaultAttention - } + static var defaultComputeUnits: MLComputeUnits { defaultAttention.defaultComputeUnits } + + var bestAttention: AttentionVariant { ModelInfo.defaultAttention } + var defaultComputeUnits: MLComputeUnits { bestAttention.defaultComputeUnits } func modelURL(for variant: AttentionVariant) -> URL { // Pattern: https://huggingface.co/pcuenq/coreml-stable-diffusion/resolve/main/coreml-stable-diffusion-v1-5_original_compiled.zip diff --git a/Diffusion/Pipeline/PipelineLoader.swift b/Diffusion/Pipeline/PipelineLoader.swift index 3eafff2..34b51c2 100644 --- a/Diffusion/Pipeline/PipelineLoader.swift +++ b/Diffusion/Pipeline/PipelineLoader.swift @@ -18,14 +18,14 @@ class PipelineLoader { static let models = Path.applicationSupport / "hf-diffusion-models" let model: ModelInfo - let variant: AttentionVariant + let computeUnits: ComputeUnits let maxSeed: UInt32 private var downloadSubscriber: Cancellable? - init(model: ModelInfo, variant: AttentionVariant? = nil, maxSeed: UInt32 = UInt32.max) { + init(model: ModelInfo, computeUnits: ComputeUnits? = nil, maxSeed: UInt32 = UInt32.max) { self.model = model - self.variant = variant ?? model.bestAttention + self.computeUnits = computeUnits ?? model.defaultComputeUnits self.maxSeed = maxSeed state = .undetermined setInitialState() @@ -98,11 +98,17 @@ extension PipelineLoader { return compiledPath.exists } - // TODO: measure performance on different devices, disassociate from variant - var computeUnits: MLComputeUnits { - variant == .original ? .cpuAndGPU : .cpuAndNeuralEngine + var variant: AttentionVariant { + switch computeUnits { + case .cpuOnly : return .original // Not supported yet + case .cpuAndGPU : return .original + case .cpuAndNeuralEngine: return .splitEinsum + case .all : return .splitEinsum + @unknown default: + fatalError("Unknown MLComputeUnits") + } } - + // TODO: maybe receive Progress to add another progress as child func prepare() async throws -> Pipeline { do { diff --git a/Diffusion/State.swift b/Diffusion/State.swift index 12381c5..c9b2dfd 100644 --- a/Diffusion/State.swift +++ b/Diffusion/State.swift @@ -9,6 +9,7 @@ import Combine import SwiftUI import StableDiffusion +import CoreML let DEFAULT_MODEL = ModelInfo.v2Base let DEFAULT_PROMPT = "Labrador in the style of Vermeer" @@ -21,6 +22,8 @@ enum GenerationState { case failed(Error) } +typealias ComputeUnits = MLComputeUnits + class GenerationContext: ObservableObject { let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler @@ -48,6 +51,8 @@ class GenerationContext: ObservableObject { @Published var seed = -1.0 @Published var guidanceScale = 7.5 @Published var disableSafety = false + + @Published var computeUnits: ComputeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits private var progressSubscriber: Cancellable? @@ -78,14 +83,14 @@ class Settings { enum Keys: String { case model case safetyCheckerDisclaimer - case variant + case computeUnits } private init() { defaults.register(defaults: [ Keys.model.rawValue: ModelInfo.v2Base.modelId, Keys.safetyCheckerDisclaimer.rawValue: false, - Keys.variant.rawValue: "- default -" + Keys.computeUnits.rawValue: -1 // Use default ]) } @@ -109,15 +114,16 @@ class Settings { } /// Returns the option selected by the user, if overridden - /// `nil` means: guess best for this {model, device} - var userSelectedAttentionVariant: AttentionVariant? { + /// `nil` means: guess best + var userSelectedComputeUnits: ComputeUnits? { set { - // Any String other than the supported ones would cause `get` to return `nil` - defaults.set(newValue?.rawValue ?? "- default -", forKey: Keys.variant.rawValue) + // Any value other than the supported ones would cause `get` to return `nil` + defaults.set(newValue?.rawValue ?? -1, forKey: Keys.computeUnits.rawValue) } get { - let current = defaults.string(forKey: Keys.variant.rawValue) - return AttentionVariant(rawValue: current ?? "") + let current = defaults.integer(forKey: Keys.computeUnits.rawValue) + guard current != -1 else { return nil } + return ComputeUnits(rawValue: current) } } }