Allow all compute units to be selected (#30)

* Allow all compute units to be selected by the user.

* Remove commented code.

* Simplify labels.

* Remove warning

* Align picker left

* Apply suggestions from code review
main
Pedro Cuenca 1 year ago committed by GitHub
parent d69e554b46
commit 3b75e75118
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -57,7 +57,6 @@ struct ControlsView: View {
@State private var disclosedSteps = false @State private var disclosedSteps = false
@State private var disclosedSeed = false @State private var disclosedSeed = false
@State private var disclosedAdvanced = false @State private var disclosedAdvanced = false
@State private var useANE = (Settings.shared.userSelectedAttentionVariant ?? ModelInfo.defaultAttention) == .splitEinsum
// TODO: refactor download with similar code in Loading.swift (iOS) // TODO: refactor download with similar code in Loading.swift (iOS)
@State private var stateSubscriber: Cancellable? @State private var stateSubscriber: Cancellable?
@ -82,14 +81,18 @@ struct ControlsView: View {
mustShowSafetyCheckerDisclaimer = generation.disableSafety && !Settings.shared.safetyCheckerDisclaimerShown mustShowSafetyCheckerDisclaimer = generation.disableSafety && !Settings.shared.safetyCheckerDisclaimerShown
} }
func updateANEState() { func updateComputeUnitsState() {
Settings.shared.userSelectedAttentionVariant = useANE ? .splitEinsum : .original Settings.shared.userSelectedComputeUnits = generation.computeUnits
modelDidChange(model: Settings.shared.currentModel) modelDidChange(model: Settings.shared.currentModel)
} }
func resetComputeUnitsState() {
generation.computeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits
}
func modelDidChange(model: ModelInfo) { func modelDidChange(model: ModelInfo) {
guard pipelineLoader?.model != model || pipelineLoader?.variant != Settings.shared.userSelectedAttentionVariant else { guard pipelineLoader?.model != model || pipelineLoader?.computeUnits != generation.computeUnits else {
print("Reusing same model \(model) with attention \(String(describing: Settings.shared.userSelectedAttentionVariant))") print("Reusing same model \(model) with units \(generation.computeUnits)")
return return
} }
@ -99,7 +102,7 @@ struct ControlsView: View {
pipelineLoader?.cancel() pipelineLoader?.cancel()
pipelineState = .downloading(0) pipelineState = .downloading(0)
Task.init { Task.init {
let loader = PipelineLoader(model: model, variant: Settings.shared.userSelectedAttentionVariant, maxSeed: maxSeed) let loader = PipelineLoader(model: model, computeUnits: generation.computeUnits, maxSeed: maxSeed)
self.pipelineLoader = loader self.pipelineLoader = loader
stateSubscriber = loader.statePublisher.sink { state in stateSubscriber = loader.statePublisher.sink { state in
DispatchQueue.main.async { DispatchQueue.main.async {
@ -128,8 +131,8 @@ struct ControlsView: View {
} }
} }
func isModelDownloaded(_ model: ModelInfo, variant: AttentionVariant? = nil) -> Bool { func isModelDownloaded(_ model: ModelInfo, computeUnits: ComputeUnits? = nil) -> Bool {
PipelineLoader(model: model, variant: variant ?? Settings.shared.userSelectedAttentionVariant).ready PipelineLoader(model: model, computeUnits: computeUnits ?? generation.computeUnits).ready
} }
func modelLabel(_ model: ModelInfo) -> Text { func modelLabel(_ model: ModelInfo) -> Text {
@ -301,21 +304,25 @@ struct ControlsView: View {
Divider() Divider()
DisclosureGroup(isExpanded: $disclosedAdvanced) { DisclosureGroup(isExpanded: $disclosedAdvanced) {
HStack { HStack {
Toggle("Use Neural Engine", isOn: $useANE).onChange(of: useANE) { value in Picker(selection: $generation.computeUnits, label: Text("Use")) {
guard let currentModel = ModelInfo.from(modelVersion: model) else { return } Text("GPU").tag(ComputeUnits.cpuAndGPU)
let variantDownloaded = isModelDownloaded(currentModel, variant: useANE ? .splitEinsum : .original) Text("Neural Engine").tag(ComputeUnits.cpuAndNeuralEngine)
if variantDownloaded { Text("GPU and Neural Engine").tag(ComputeUnits.all)
updateANEState() }.pickerStyle(.radioGroup).padding(.leading)
} else {
mustShowModelDownloadDisclaimer.toggle()
}
}
.padding(.leading, 10)
Spacer() Spacer()
} }
.onChange(of: generation.computeUnits) { units in
guard let currentModel = ModelInfo.from(modelVersion: model) else { return }
let variantDownloaded = isModelDownloaded(currentModel, computeUnits: units)
if variantDownloaded {
updateComputeUnitsState()
} else {
mustShowModelDownloadDisclaimer.toggle()
}
}
.alert("Download Required", isPresented: $mustShowModelDownloadDisclaimer, actions: { .alert("Download Required", isPresented: $mustShowModelDownloadDisclaimer, actions: {
Button("Cancel", role: .destructive) { useANE.toggle() } Button("Cancel", role: .destructive) { resetComputeUnitsState() }
Button("Download", role: .cancel) { updateANEState() } Button("Download", role: .cancel) { updateComputeUnitsState() }
}, message: { }, message: {
Text("This setting requires a new version of the selected model.") Text("This setting requires a new version of the selected model.")
}) })

@ -13,6 +13,10 @@ enum AttentionVariant: String {
case splitEinsum case splitEinsum
} }
extension AttentionVariant {
var defaultComputeUnits: MLComputeUnits { self == .original ? .cpuAndGPU : .cpuAndNeuralEngine }
}
struct ModelInfo { struct ModelInfo {
/// Hugging Face model Id that contains .zip archives with compiled Core ML models /// Hugging Face model Id that contains .zip archives with compiled Core ML models
let modelId: String let modelId: String
@ -39,6 +43,7 @@ struct ModelInfo {
} }
extension ModelInfo { extension ModelInfo {
//TODO: set compute units instead and derive variant from it
static var defaultAttention: AttentionVariant { static var defaultAttention: AttentionVariant {
guard runningOnMac else { return .splitEinsum } guard runningOnMac else { return .splitEinsum }
#if os(macOS) #if os(macOS)
@ -49,9 +54,10 @@ extension ModelInfo {
#endif #endif
} }
var bestAttention: AttentionVariant { static var defaultComputeUnits: MLComputeUnits { defaultAttention.defaultComputeUnits }
return ModelInfo.defaultAttention
} var bestAttention: AttentionVariant { ModelInfo.defaultAttention }
var defaultComputeUnits: MLComputeUnits { bestAttention.defaultComputeUnits }
func modelURL(for variant: AttentionVariant) -> URL { func modelURL(for variant: AttentionVariant) -> URL {
// Pattern: https://huggingface.co/pcuenq/coreml-stable-diffusion/resolve/main/coreml-stable-diffusion-v1-5_original_compiled.zip // Pattern: https://huggingface.co/pcuenq/coreml-stable-diffusion/resolve/main/coreml-stable-diffusion-v1-5_original_compiled.zip

@ -18,14 +18,14 @@ class PipelineLoader {
static let models = Path.applicationSupport / "hf-diffusion-models" static let models = Path.applicationSupport / "hf-diffusion-models"
let model: ModelInfo let model: ModelInfo
let variant: AttentionVariant let computeUnits: ComputeUnits
let maxSeed: UInt32 let maxSeed: UInt32
private var downloadSubscriber: Cancellable? private var downloadSubscriber: Cancellable?
init(model: ModelInfo, variant: AttentionVariant? = nil, maxSeed: UInt32 = UInt32.max) { init(model: ModelInfo, computeUnits: ComputeUnits? = nil, maxSeed: UInt32 = UInt32.max) {
self.model = model self.model = model
self.variant = variant ?? model.bestAttention self.computeUnits = computeUnits ?? model.defaultComputeUnits
self.maxSeed = maxSeed self.maxSeed = maxSeed
state = .undetermined state = .undetermined
setInitialState() setInitialState()
@ -98,11 +98,17 @@ extension PipelineLoader {
return compiledPath.exists return compiledPath.exists
} }
// TODO: measure performance on different devices, disassociate from variant var variant: AttentionVariant {
var computeUnits: MLComputeUnits { switch computeUnits {
variant == .original ? .cpuAndGPU : .cpuAndNeuralEngine case .cpuOnly : return .original // Not supported yet
case .cpuAndGPU : return .original
case .cpuAndNeuralEngine: return .splitEinsum
case .all : return .splitEinsum
@unknown default:
fatalError("Unknown MLComputeUnits")
}
} }
// TODO: maybe receive Progress to add another progress as child // TODO: maybe receive Progress to add another progress as child
func prepare() async throws -> Pipeline { func prepare() async throws -> Pipeline {
do { do {

@ -9,6 +9,7 @@
import Combine import Combine
import SwiftUI import SwiftUI
import StableDiffusion import StableDiffusion
import CoreML
let DEFAULT_MODEL = ModelInfo.v2Base let DEFAULT_MODEL = ModelInfo.v2Base
let DEFAULT_PROMPT = "Labrador in the style of Vermeer" let DEFAULT_PROMPT = "Labrador in the style of Vermeer"
@ -21,6 +22,8 @@ enum GenerationState {
case failed(Error) case failed(Error)
} }
typealias ComputeUnits = MLComputeUnits
class GenerationContext: ObservableObject { class GenerationContext: ObservableObject {
let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler
@ -48,6 +51,8 @@ class GenerationContext: ObservableObject {
@Published var seed = -1.0 @Published var seed = -1.0
@Published var guidanceScale = 7.5 @Published var guidanceScale = 7.5
@Published var disableSafety = false @Published var disableSafety = false
@Published var computeUnits: ComputeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits
private var progressSubscriber: Cancellable? private var progressSubscriber: Cancellable?
@ -78,14 +83,14 @@ class Settings {
enum Keys: String { enum Keys: String {
case model case model
case safetyCheckerDisclaimer case safetyCheckerDisclaimer
case variant case computeUnits
} }
private init() { private init() {
defaults.register(defaults: [ defaults.register(defaults: [
Keys.model.rawValue: ModelInfo.v2Base.modelId, Keys.model.rawValue: ModelInfo.v2Base.modelId,
Keys.safetyCheckerDisclaimer.rawValue: false, Keys.safetyCheckerDisclaimer.rawValue: false,
Keys.variant.rawValue: "- default -" Keys.computeUnits.rawValue: -1 // Use default
]) ])
} }
@ -109,15 +114,16 @@ class Settings {
} }
/// Returns the option selected by the user, if overridden /// Returns the option selected by the user, if overridden
/// `nil` means: guess best for this {model, device} /// `nil` means: guess best
var userSelectedAttentionVariant: AttentionVariant? { var userSelectedComputeUnits: ComputeUnits? {
set { set {
// Any String other than the supported ones would cause `get` to return `nil` // Any value other than the supported ones would cause `get` to return `nil`
defaults.set(newValue?.rawValue ?? "- default -", forKey: Keys.variant.rawValue) defaults.set(newValue?.rawValue ?? -1, forKey: Keys.computeUnits.rawValue)
} }
get { get {
let current = defaults.string(forKey: Keys.variant.rawValue) let current = defaults.integer(forKey: Keys.computeUnits.rawValue)
return AttentionVariant(rawValue: current ?? "") guard current != -1 else { return nil }
return ComputeUnits(rawValue: current)
} }
} }
} }

Loading…
Cancel
Save