From 04b918ae49c9f6c760aaa6851528df670a20a6f3 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 22 Feb 2023 12:16:33 +0100 Subject: [PATCH] Allow all compute units to be selected by the user. --- Diffusion-macOS/ControlsView.swift | 59 ++++++++++++++++--------- Diffusion/ModelInfo.swift | 12 +++-- Diffusion/Pipeline/PipelineLoader.swift | 18 +++++--- Diffusion/State.swift | 22 +++++---- 4 files changed, 73 insertions(+), 38 deletions(-) diff --git a/Diffusion-macOS/ControlsView.swift b/Diffusion-macOS/ControlsView.swift index 04d26f1..c6295f0 100644 --- a/Diffusion-macOS/ControlsView.swift +++ b/Diffusion-macOS/ControlsView.swift @@ -57,7 +57,6 @@ struct ControlsView: View { @State private var disclosedSteps = false @State private var disclosedSeed = false @State private var disclosedAdvanced = false - @State private var useANE = (Settings.shared.userSelectedAttentionVariant ?? ModelInfo.defaultAttention) == .splitEinsum // TODO: refactor download with similar code in Loading.swift (iOS) @State private var stateSubscriber: Cancellable? @@ -82,13 +81,17 @@ struct ControlsView: View { mustShowSafetyCheckerDisclaimer = generation.disableSafety && !Settings.shared.safetyCheckerDisclaimerShown } - func updateANEState() { - Settings.shared.userSelectedAttentionVariant = useANE ? .splitEinsum : .original + func updateComputeUnitsState() { + Settings.shared.userSelectedComputeUnits = generation.computeUnits modelDidChange(model: Settings.shared.currentModel) } + func resetComputeUnitsState() { + generation.computeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits + } + func modelDidChange(model: ModelInfo) { - guard pipelineLoader?.model != model else { return } + guard pipelineLoader?.model != model && pipelineLoader?.computeUnits != generation.computeUnits else { return } print("Loading model \(model)") Settings.shared.currentModel = model @@ -96,7 +99,7 @@ struct ControlsView: View { pipelineLoader?.cancel() pipelineState = .downloading(0) Task.init { - let loader = PipelineLoader(model: model, variant: Settings.shared.userSelectedAttentionVariant, maxSeed: maxSeed) + let loader = PipelineLoader(model: model, computeUnits: generation.computeUnits /* Or read from settings */, maxSeed: maxSeed) self.pipelineLoader = loader stateSubscriber = loader.statePublisher.sink { state in DispatchQueue.main.async { @@ -125,8 +128,8 @@ struct ControlsView: View { } } - func isModelDownloaded(_ model: ModelInfo, variant: AttentionVariant? = nil) -> Bool { - PipelineLoader(model: model, variant: variant ?? Settings.shared.userSelectedAttentionVariant).ready + func isModelDownloaded(_ model: ModelInfo, computeUnits: ComputeUnits? = nil) -> Bool { + PipelineLoader(model: model, computeUnits: computeUnits ?? generation.computeUnits /*Settings.shared.userSelectedAttentionVariant*/).ready } func modelLabel(_ model: ModelInfo) -> Text { @@ -297,22 +300,38 @@ struct ControlsView: View { if Capabilities.hasANE { Divider() DisclosureGroup(isExpanded: $disclosedAdvanced) { - HStack { - Toggle("Use Neural Engine", isOn: $useANE).onChange(of: useANE) { value in - guard let currentModel = ModelInfo.from(modelVersion: model) else { return } - let variantDownloaded = isModelDownloaded(currentModel, variant: useANE ? .splitEinsum : .original) - if variantDownloaded { - updateANEState() - } else { - mustShowModelDownloadDisclaimer.toggle() - } + Picker(selection: $generation.computeUnits, label: Text("Use")) { + Text("CPU and GPU").tag(ComputeUnits.cpuAndGPU) + Text("CPU and Neural Engine").tag(ComputeUnits.cpuAndNeuralEngine) + Text("CPU, GPU and NE").tag(ComputeUnits.all) + } + .pickerStyle(.radioGroup) + .onChange(of: generation.computeUnits) { units in + guard let currentModel = ModelInfo.from(modelVersion: model) else { return } + let variantDownloaded = isModelDownloaded(currentModel, computeUnits: units) + if variantDownloaded { + updateComputeUnitsState() + } else { + mustShowModelDownloadDisclaimer.toggle() } - .padding(.leading, 10) - Spacer() } +// HStack { +// Toggle("Use GPU and Neural Engine", isOn: $useANE).onChange(of: useANE) { value in +// guard let currentModel = ModelInfo.from(modelVersion: model) else { return } +// let variantDownloaded = isModelDownloaded(currentModel, variant: useANE ? .splitEinsum : .original) +// if variantDownloaded { +// updateANEState() +// } else { +// mustShowModelDownloadDisclaimer.toggle() +// } +// } +// .padding(.leading, 10) +// Spacer() +// } .alert("Download Required", isPresented: $mustShowModelDownloadDisclaimer, actions: { - Button("Cancel", role: .destructive) { useANE.toggle() } - Button("Download", role: .cancel) { updateANEState() } + //TODO: FIXME + Button("Cancel", role: .destructive) { resetComputeUnitsState() } + Button("Download", role: .cancel) { updateComputeUnitsState() } }, message: { Text("This setting requires a new version of the selected model.") }) diff --git a/Diffusion/ModelInfo.swift b/Diffusion/ModelInfo.swift index eab1ee4..843a853 100644 --- a/Diffusion/ModelInfo.swift +++ b/Diffusion/ModelInfo.swift @@ -13,6 +13,10 @@ enum AttentionVariant: String { case splitEinsum } +extension AttentionVariant { + var defaultComputeUnits: MLComputeUnits { self == .original ? .cpuAndGPU : .cpuAndNeuralEngine } +} + struct ModelInfo { /// Hugging Face model Id that contains .zip archives with compiled Core ML models let modelId: String @@ -39,6 +43,7 @@ struct ModelInfo { } extension ModelInfo { + //TODO: set compute units instead and derive variant from it static var defaultAttention: AttentionVariant { guard runningOnMac else { return .splitEinsum } #if os(macOS) @@ -49,9 +54,10 @@ extension ModelInfo { #endif } - var bestAttention: AttentionVariant { - return ModelInfo.defaultAttention - } + static var defaultComputeUnits: MLComputeUnits { defaultAttention.defaultComputeUnits } + + var bestAttention: AttentionVariant { ModelInfo.defaultAttention } + var defaultComputeUnits: MLComputeUnits { bestAttention.defaultComputeUnits } func modelURL(for variant: AttentionVariant) -> URL { // Pattern: https://huggingface.co/pcuenq/coreml-stable-diffusion/resolve/main/coreml-stable-diffusion-v1-5_original_compiled.zip diff --git a/Diffusion/Pipeline/PipelineLoader.swift b/Diffusion/Pipeline/PipelineLoader.swift index 3eafff2..8d1c190 100644 --- a/Diffusion/Pipeline/PipelineLoader.swift +++ b/Diffusion/Pipeline/PipelineLoader.swift @@ -18,14 +18,14 @@ class PipelineLoader { static let models = Path.applicationSupport / "hf-diffusion-models" let model: ModelInfo - let variant: AttentionVariant + let computeUnits: ComputeUnits let maxSeed: UInt32 private var downloadSubscriber: Cancellable? - init(model: ModelInfo, variant: AttentionVariant? = nil, maxSeed: UInt32 = UInt32.max) { + init(model: ModelInfo, computeUnits: ComputeUnits? = nil, maxSeed: UInt32 = UInt32.max) { self.model = model - self.variant = variant ?? model.bestAttention + self.computeUnits = computeUnits ?? model.defaultComputeUnits self.maxSeed = maxSeed state = .undetermined setInitialState() @@ -98,11 +98,15 @@ extension PipelineLoader { return compiledPath.exists } - // TODO: measure performance on different devices, disassociate from variant - var computeUnits: MLComputeUnits { - variant == .original ? .cpuAndGPU : .cpuAndNeuralEngine + var variant: AttentionVariant { + switch computeUnits { + case .cpuOnly : return .original // Not supported yet + case .cpuAndGPU : return .original + case .cpuAndNeuralEngine: return .splitEinsum + case .all : return .splitEinsum + } } - + // TODO: maybe receive Progress to add another progress as child func prepare() async throws -> Pipeline { do { diff --git a/Diffusion/State.swift b/Diffusion/State.swift index 12381c5..c9b2dfd 100644 --- a/Diffusion/State.swift +++ b/Diffusion/State.swift @@ -9,6 +9,7 @@ import Combine import SwiftUI import StableDiffusion +import CoreML let DEFAULT_MODEL = ModelInfo.v2Base let DEFAULT_PROMPT = "Labrador in the style of Vermeer" @@ -21,6 +22,8 @@ enum GenerationState { case failed(Error) } +typealias ComputeUnits = MLComputeUnits + class GenerationContext: ObservableObject { let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler @@ -48,6 +51,8 @@ class GenerationContext: ObservableObject { @Published var seed = -1.0 @Published var guidanceScale = 7.5 @Published var disableSafety = false + + @Published var computeUnits: ComputeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits private var progressSubscriber: Cancellable? @@ -78,14 +83,14 @@ class Settings { enum Keys: String { case model case safetyCheckerDisclaimer - case variant + case computeUnits } private init() { defaults.register(defaults: [ Keys.model.rawValue: ModelInfo.v2Base.modelId, Keys.safetyCheckerDisclaimer.rawValue: false, - Keys.variant.rawValue: "- default -" + Keys.computeUnits.rawValue: -1 // Use default ]) } @@ -109,15 +114,16 @@ class Settings { } /// Returns the option selected by the user, if overridden - /// `nil` means: guess best for this {model, device} - var userSelectedAttentionVariant: AttentionVariant? { + /// `nil` means: guess best + var userSelectedComputeUnits: ComputeUnits? { set { - // Any String other than the supported ones would cause `get` to return `nil` - defaults.set(newValue?.rawValue ?? "- default -", forKey: Keys.variant.rawValue) + // Any value other than the supported ones would cause `get` to return `nil` + defaults.set(newValue?.rawValue ?? -1, forKey: Keys.computeUnits.rawValue) } get { - let current = defaults.string(forKey: Keys.variant.rawValue) - return AttentionVariant(rawValue: current ?? "") + let current = defaults.integer(forKey: Keys.computeUnits.rawValue) + guard current != -1 else { return nil } + return ComputeUnits(rawValue: current) } } }