You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

130 lines
3.8 KiB

//
// State.swift
// Diffusion
//
// Created by Pedro Cuenca on 17/1/23.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import Combine
import SwiftUI
import StableDiffusion
import CoreML
/// Model used when no valid model preference is stored (see `Settings.currentModel`).
let DEFAULT_MODEL: ModelInfo = .v2Base
/// Prompt pre-filled into `GenerationContext.positivePrompt`.
let DEFAULT_PROMPT: String = "Labrador in the style of Vermeer"
/// Lifecycle of an image-generation run, as exposed through `GenerationContext.state`.
enum GenerationState {
    /// Initial state, before any generation has been observed.
    case startup

    /// A run is in progress; carries the most recent progress report, when available.
    case running(StableDiffusionProgress?)

    /// A run finished. Payload is a `String`, an optional image, a `UInt32` and an optional
    /// duration — presumably prompt, result image, seed and elapsed time (TODO confirm
    /// against the code that produces this case).
    case complete(String, CGImage?, UInt32, TimeInterval?)

    /// The run was cancelled (presumably via `GenerationContext.cancelGeneration()`).
    case userCanceled

    /// The run failed with the attached error.
    case failed(Error)
}
/// Local shorthand for Core ML's `MLComputeUnits`; used for the compute-units preference
/// stored by `Settings` and held by `GenerationContext`.
typealias ComputeUnits = MLComputeUnits
/// Observable model holding the user-editable generation parameters, the current
/// `GenerationState`, and the pipeline used to run generations.
class GenerationContext: ObservableObject {
    /// Scheduler passed to every `pipeline.generate` call.
    let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler

    /// The loaded pipeline, if any.
    ///
    /// Every assignment first drops the subscription to the previous pipeline. For a
    /// non-nil value it then subscribes to that pipeline's `progressPublisher` (delivered on
    /// the main queue), so that `state` becomes `.running(progress)` for each non-nil report.
    @Published var pipeline: Pipeline? = nil {
        didSet {
            // Without this, a replaced or cleared pipeline would keep pushing progress
            // into `state` through its old, still-live subscription.
            progressSubscriber?.cancel()
            progressSubscriber = nil

            guard let pipeline = pipeline else { return }
            progressSubscriber = pipeline
                .progressPublisher
                .receive(on: DispatchQueue.main)
                // `self` owns `progressSubscriber`, which owns this closure: capture `self`
                // weakly so the context is not kept alive by its own subscription.
                .sink { [weak self] progress in
                    guard let self = self, let progress = progress else { return }
                    self.state = .running(progress)
                }
        }
    }

    /// Current generation lifecycle state; starts at `.startup`.
    @Published var state: GenerationState = .startup

    /// Text prompt describing the desired image.
    @Published var positivePrompt = DEFAULT_PROMPT
    /// Negative prompt; empty by default.
    @Published var negativePrompt = ""

    // FIXME: Double to support the slider component
    /// Number of inference steps; converted with `Int(_:)` when generating.
    @Published var steps = 25.0
    /// Number of images requested (not read by `generate()` in this type).
    @Published var numImages = 1.0
    /// Seed; negative values (the `-1` default) pass `nil` to the pipeline, presumably
    /// meaning "pick a seed" — TODO confirm against the pipeline.
    @Published var seed = -1.0
    /// Guidance scale; converted to `Float` when generating.
    @Published var guidanceScale = 7.5
    /// Forwarded to the pipeline as `disableSafety`.
    @Published var disableSafety = false
    /// Compute units: the user's saved override, or `ModelInfo.defaultComputeUnits`.
    @Published var computeUnits: ComputeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits

    /// Live subscription to the current pipeline's progress publisher, if any.
    private var progressSubscriber: Cancellable?

    /// Runs the current pipeline with this context's parameters.
    ///
    /// Note that the underlying `pipeline.generate` call is synchronous (it is not awaited).
    /// - Returns: The pipeline's `GenerationResult`.
    /// - Throws: `"No pipeline"` when no pipeline is loaded, or whatever `pipeline.generate` throws.
    func generate() async throws -> GenerationResult {
        guard let pipeline = pipeline else { throw "No pipeline" }
        return try pipeline.generate(
            prompt: positivePrompt,
            negativePrompt: negativePrompt,
            scheduler: scheduler,
            numInferenceSteps: Int(steps),
            seed: GenerationContext.pipelineSeed(from: seed),
            guidanceScale: Float(guidanceScale),
            disableSafety: disableSafety
        )
    }

    /// Asks the current pipeline, if any, to cancel.
    func cancelGeneration() {
        pipeline?.setCancelled()
    }

    /// Converts the slider-backed `seed` into the pipeline's seed.
    ///
    /// Negative values and NaN yield `nil`, as before. In-range values are truncated toward
    /// zero, matching the previous `UInt32(_:)` conversion. Values that do not fit in
    /// `UInt32` (including +infinity) also yield `nil` instead of trapping.
    private static func pipelineSeed(from value: Double) -> UInt32? {
        guard value >= 0 else { return nil }
        return UInt32(exactly: value.rounded(.towardZero))
    }
}
/// App preferences persisted in `UserDefaults.standard`, accessed through `Settings.shared`.
class Settings {
    static let shared = Settings()

    let defaults = UserDefaults.standard

    /// `UserDefaults` keys used by this type (the raw value is the stored key).
    enum Keys: String {
        case model
        case safetyCheckerDisclaimer
        case computeUnits
    }

    /// Stored compute-units value meaning "no user override; guess best".
    private static let noComputeUnitsOverride = -1

    /// Compute-unit options accepted when reading the stored preference back.
    /// `init?(rawValue:)` on this Objective-C-imported enum is not relied on to reject
    /// unknown raw values, so the known cases are checked explicitly.
    private static let supportedComputeUnits: [ComputeUnits] = [.cpuOnly, .cpuAndGPU, .all, .cpuAndNeuralEngine]

    private init() {
        defaults.register(defaults: [
            // Same fallback as the `currentModel` getter, so the two cannot drift apart.
            Keys.model.rawValue: DEFAULT_MODEL.modelId,
            Keys.safetyCheckerDisclaimer.rawValue: false,
            Keys.computeUnits.rawValue: Settings.noComputeUnitsOverride
        ])
    }

    /// The selected model. Falls back to `DEFAULT_MODEL` when the stored id is missing or
    /// not recognised by `ModelInfo.from(modelId:)`.
    var currentModel: ModelInfo {
        set {
            defaults.set(newValue.modelId, forKey: Keys.model.rawValue)
        }
        get {
            guard let modelId = defaults.string(forKey: Keys.model.rawValue) else { return DEFAULT_MODEL }
            return ModelInfo.from(modelId: modelId) ?? DEFAULT_MODEL
        }
    }

    /// Whether the safety-checker disclaimer has been shown; `false` until set.
    var safetyCheckerDisclaimerShown: Bool {
        set {
            defaults.set(newValue, forKey: Keys.safetyCheckerDisclaimer.rawValue)
        }
        get {
            return defaults.bool(forKey: Keys.safetyCheckerDisclaimer.rawValue)
        }
    }

    /// Returns the option selected by the user, if overridden.
    /// `nil` means: guess best.
    var userSelectedComputeUnits: ComputeUnits? {
        set {
            // `nil` is persisted as the sentinel; any unsupported stored value also reads back as `nil`.
            defaults.set(newValue?.rawValue ?? Settings.noComputeUnitsOverride, forKey: Keys.computeUnits.rawValue)
        }
        get {
            let current = defaults.integer(forKey: Keys.computeUnits.rawValue)
            guard current != Settings.noComputeUnitsOverride,
                  let units = ComputeUnits(rawValue: current),
                  Settings.supportedComputeUnits.contains(units)
            else { return nil }
            return units
        }
    }
}