Allow all compute units to be selected (#30)

* Allow all compute units to be selected by the user.

* Remove commented code.

* Simplify labels.

* Remove warning

* Align picker left

* Apply suggestions from code review
main
Pedro Cuenca 2 years ago committed by GitHub
parent d69e554b46
commit 3b75e75118
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -57,7 +57,6 @@ struct ControlsView: View {
@State private var disclosedSteps = false
@State private var disclosedSeed = false
@State private var disclosedAdvanced = false
@State private var useANE = (Settings.shared.userSelectedAttentionVariant ?? ModelInfo.defaultAttention) == .splitEinsum
// TODO: refactor download with similar code in Loading.swift (iOS)
@State private var stateSubscriber: Cancellable?
@ -82,14 +81,18 @@ struct ControlsView: View {
mustShowSafetyCheckerDisclaimer = generation.disableSafety && !Settings.shared.safetyCheckerDisclaimerShown
}
func updateANEState() {
Settings.shared.userSelectedAttentionVariant = useANE ? .splitEinsum : .original
func updateComputeUnitsState() {
Settings.shared.userSelectedComputeUnits = generation.computeUnits
modelDidChange(model: Settings.shared.currentModel)
}
func resetComputeUnitsState() {
generation.computeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits
}
func modelDidChange(model: ModelInfo) {
guard pipelineLoader?.model != model || pipelineLoader?.variant != Settings.shared.userSelectedAttentionVariant else {
print("Reusing same model \(model) with attention \(String(describing: Settings.shared.userSelectedAttentionVariant))")
guard pipelineLoader?.model != model || pipelineLoader?.computeUnits != generation.computeUnits else {
print("Reusing same model \(model) with units \(generation.computeUnits)")
return
}
@ -99,7 +102,7 @@ struct ControlsView: View {
pipelineLoader?.cancel()
pipelineState = .downloading(0)
Task.init {
let loader = PipelineLoader(model: model, variant: Settings.shared.userSelectedAttentionVariant, maxSeed: maxSeed)
let loader = PipelineLoader(model: model, computeUnits: generation.computeUnits, maxSeed: maxSeed)
self.pipelineLoader = loader
stateSubscriber = loader.statePublisher.sink { state in
DispatchQueue.main.async {
@ -128,8 +131,8 @@ struct ControlsView: View {
}
}
func isModelDownloaded(_ model: ModelInfo, variant: AttentionVariant? = nil) -> Bool {
PipelineLoader(model: model, variant: variant ?? Settings.shared.userSelectedAttentionVariant).ready
func isModelDownloaded(_ model: ModelInfo, computeUnits: ComputeUnits? = nil) -> Bool {
PipelineLoader(model: model, computeUnits: computeUnits ?? generation.computeUnits).ready
}
func modelLabel(_ model: ModelInfo) -> Text {
@ -301,21 +304,25 @@ struct ControlsView: View {
Divider()
DisclosureGroup(isExpanded: $disclosedAdvanced) {
HStack {
Toggle("Use Neural Engine", isOn: $useANE).onChange(of: useANE) { value in
Picker(selection: $generation.computeUnits, label: Text("Use")) {
Text("GPU").tag(ComputeUnits.cpuAndGPU)
Text("Neural Engine").tag(ComputeUnits.cpuAndNeuralEngine)
Text("GPU and Neural Engine").tag(ComputeUnits.all)
}.pickerStyle(.radioGroup).padding(.leading)
Spacer()
}
.onChange(of: generation.computeUnits) { units in
guard let currentModel = ModelInfo.from(modelVersion: model) else { return }
let variantDownloaded = isModelDownloaded(currentModel, variant: useANE ? .splitEinsum : .original)
let variantDownloaded = isModelDownloaded(currentModel, computeUnits: units)
if variantDownloaded {
updateANEState()
updateComputeUnitsState()
} else {
mustShowModelDownloadDisclaimer.toggle()
}
}
.padding(.leading, 10)
Spacer()
}
.alert("Download Required", isPresented: $mustShowModelDownloadDisclaimer, actions: {
Button("Cancel", role: .destructive) { useANE.toggle() }
Button("Download", role: .cancel) { updateANEState() }
Button("Cancel", role: .destructive) { resetComputeUnitsState() }
Button("Download", role: .cancel) { updateComputeUnitsState() }
}, message: {
Text("This setting requires a new version of the selected model.")
})

@ -13,6 +13,10 @@ enum AttentionVariant: String {
case splitEinsum
}
extension AttentionVariant {
var defaultComputeUnits: MLComputeUnits { self == .original ? .cpuAndGPU : .cpuAndNeuralEngine }
}
struct ModelInfo {
/// Hugging Face model Id that contains .zip archives with compiled Core ML models
let modelId: String
@ -39,6 +43,7 @@ struct ModelInfo {
}
extension ModelInfo {
//TODO: set compute units instead and derive variant from it
static var defaultAttention: AttentionVariant {
guard runningOnMac else { return .splitEinsum }
#if os(macOS)
@ -49,9 +54,10 @@ extension ModelInfo {
#endif
}
var bestAttention: AttentionVariant {
return ModelInfo.defaultAttention
}
static var defaultComputeUnits: MLComputeUnits { defaultAttention.defaultComputeUnits }
var bestAttention: AttentionVariant { ModelInfo.defaultAttention }
var defaultComputeUnits: MLComputeUnits { bestAttention.defaultComputeUnits }
func modelURL(for variant: AttentionVariant) -> URL {
// Pattern: https://huggingface.co/pcuenq/coreml-stable-diffusion/resolve/main/coreml-stable-diffusion-v1-5_original_compiled.zip

@ -18,14 +18,14 @@ class PipelineLoader {
static let models = Path.applicationSupport / "hf-diffusion-models"
let model: ModelInfo
let variant: AttentionVariant
let computeUnits: ComputeUnits
let maxSeed: UInt32
private var downloadSubscriber: Cancellable?
init(model: ModelInfo, variant: AttentionVariant? = nil, maxSeed: UInt32 = UInt32.max) {
init(model: ModelInfo, computeUnits: ComputeUnits? = nil, maxSeed: UInt32 = UInt32.max) {
self.model = model
self.variant = variant ?? model.bestAttention
self.computeUnits = computeUnits ?? model.defaultComputeUnits
self.maxSeed = maxSeed
state = .undetermined
setInitialState()
@ -98,9 +98,15 @@ extension PipelineLoader {
return compiledPath.exists
}
// TODO: measure performance on different devices, disassociate from variant
var computeUnits: MLComputeUnits {
variant == .original ? .cpuAndGPU : .cpuAndNeuralEngine
var variant: AttentionVariant {
switch computeUnits {
case .cpuOnly : return .original // Not supported yet
case .cpuAndGPU : return .original
case .cpuAndNeuralEngine: return .splitEinsum
case .all : return .splitEinsum
@unknown default:
fatalError("Unknown MLComputeUnits")
}
}
// TODO: maybe receive Progress to add another progress as child

@ -9,6 +9,7 @@
import Combine
import SwiftUI
import StableDiffusion
import CoreML
let DEFAULT_MODEL = ModelInfo.v2Base
let DEFAULT_PROMPT = "Labrador in the style of Vermeer"
@ -21,6 +22,8 @@ enum GenerationState {
case failed(Error)
}
typealias ComputeUnits = MLComputeUnits
class GenerationContext: ObservableObject {
let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler
@ -49,6 +52,8 @@ class GenerationContext: ObservableObject {
@Published var guidanceScale = 7.5
@Published var disableSafety = false
@Published var computeUnits: ComputeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits
private var progressSubscriber: Cancellable?
func generate() async throws -> GenerationResult {
@ -78,14 +83,14 @@ class Settings {
enum Keys: String {
case model
case safetyCheckerDisclaimer
case variant
case computeUnits
}
private init() {
defaults.register(defaults: [
Keys.model.rawValue: ModelInfo.v2Base.modelId,
Keys.safetyCheckerDisclaimer.rawValue: false,
Keys.variant.rawValue: "- default -"
Keys.computeUnits.rawValue: -1 // Use default
])
}
@ -109,15 +114,16 @@ class Settings {
}
/// Returns the option selected by the user, if overridden
/// `nil` means: guess best for this {model, device}
var userSelectedAttentionVariant: AttentionVariant? {
/// `nil` means: guess best
var userSelectedComputeUnits: ComputeUnits? {
set {
// Any String other than the supported ones would cause `get` to return `nil`
defaults.set(newValue?.rawValue ?? "- default -", forKey: Keys.variant.rawValue)
// Any value other than the supported ones would cause `get` to return `nil`
defaults.set(newValue?.rawValue ?? -1, forKey: Keys.computeUnits.rawValue)
}
get {
let current = defaults.string(forKey: Keys.variant.rawValue)
return AttentionVariant(rawValue: current ?? "")
let current = defaults.integer(forKey: Keys.computeUnits.rawValue)
guard current != -1 else { return nil }
return ComputeUnits(rawValue: current)
}
}
}

Loading…
Cancel
Save