You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
3.6 KiB

//
// State.swift
// Diffusion
//
// Created by Pedro Cuenca on 17/1/23.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import Combine
import SwiftUI
import StableDiffusion
/// Model used when the user has no stored model preference, or when the stored
/// id no longer resolves to a known `ModelInfo`.
let DEFAULT_MODEL: ModelInfo = ModelInfo.v2Base

/// Initial text of `GenerationContext.positivePrompt`.
let DEFAULT_PROMPT: String = "Labrador in the style of Vermeer"
/// Lifecycle of an image-generation request, published through `GenerationContext.state`.
enum GenerationState {
    /// Initial state, before any generation has been started.
    case startup

    /// A generation is in progress; carries the latest progress report from the pipeline.
    case running(StableDiffusionProgress?)

    /// A generation finished.
    /// NOTE(review): payload is presumably (prompt, image, seed, elapsed time) — TODO confirm
    /// against the code that constructs this case.
    case complete(String, CGImage?, UInt32, TimeInterval?)

    /// The generation was cancelled by the user.
    case userCanceled

    /// The generation failed with the attached error.
    case failed(Error)
}
/// Observable state for an image-generation session: the loaded pipeline, the
/// user-editable generation parameters, and the current `GenerationState`.
class GenerationContext: ObservableObject {
    /// Scheduler passed to every `pipeline.generate` call.
    let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler

    /// Pipeline used by `generate()`. Assigning it (re)subscribes to its
    /// `progressPublisher` so progress reports are mirrored into `state`.
    @Published var pipeline: Pipeline? = nil {
        didSet {
            // Drop the subscription to the previous pipeline first. Otherwise clearing
            // `pipeline` (setting it to nil) would leave the old subscription alive, and
            // it could keep overwriting `state`.
            progressSubscriber?.cancel()
            progressSubscriber = nil

            guard let pipeline = pipeline else { return }
            progressSubscriber = pipeline
                .progressPublisher
                .receive(on: DispatchQueue.main)
                // `progressSubscriber` is owned by `self`, so a strong capture here would
                // form a retain cycle (self -> subscriber -> closure -> self).
                .sink { [weak self] progress in
                    guard let progress = progress else { return }
                    self?.state = .running(progress)
                }
        }
    }

    /// Current lifecycle state of the generation.
    @Published var state: GenerationState = .startup
    @Published var positivePrompt = DEFAULT_PROMPT
    @Published var negativePrompt = ""
    // FIXME: Double to support the slider component
    @Published var steps = 25.0
    // NOTE(review): not read by `generate()` in this file.
    @Published var numImages = 1.0
    /// Negative values mean "no explicit seed": `generate()` passes `nil` to the pipeline.
    @Published var seed = -1.0
    @Published var guidanceScale = 7.5
    @Published var disableSafety = false

    /// Subscription to the current pipeline's progress publisher, if any.
    private var progressSubscriber: Cancellable?

    /// Runs the current pipeline with the parameters held by this context.
    /// - Returns: The pipeline's `GenerationResult`.
    /// - Throws: `"No pipeline"` when `pipeline` is nil, plus anything `pipeline.generate` throws.
    func generate() async throws -> GenerationResult {
        guard let pipeline = pipeline else { throw "No pipeline" }
        return try pipeline.generate(
            prompt: positivePrompt,
            negativePrompt: negativePrompt,
            scheduler: scheduler,
            numInferenceSteps: Int(steps),
            seed: resolvedSeed(),
            guidanceScale: Float(guidanceScale),
            disableSafety: disableSafety
        )
    }

    /// Calls `setCancelled()` on the current pipeline, if there is one.
    func cancelGeneration() {
        pipeline?.setCancelled()
    }

    /// Converts the slider-backed `seed` into the pipeline's optional `UInt32` seed.
    ///
    /// Negative values yield `nil`. Fractional parts are truncated, matching the previous
    /// `UInt32(_:)` conversion. Values that cannot be represented as `UInt32` (above
    /// `UInt32.max`, or `+inf`) also yield `nil`, where `UInt32(_:)` would trap.
    private func resolvedSeed() -> UInt32? {
        guard seed >= 0 else { return nil }
        return UInt32(exactly: seed.rounded(.towardZero))
    }
}
/// User preferences persisted in `UserDefaults.standard`.
class Settings {
    static let shared = Settings()

    let defaults = UserDefaults.standard

    /// `UserDefaults` keys used by this class.
    enum Keys: String {
        case model
        case safetyCheckerDisclaimer
        case variant
    }

    /// Value stored under `Keys.variant` when the user has not overridden the attention
    /// variant. Any string other than a supported `AttentionVariant` raw value makes
    /// `userSelectedAttentionVariant` read back as `nil`. Keeping the sentinel in one place
    /// ensures the registered default and the setter always agree.
    private static let defaultVariantSentinel = "- default -"

    private init() {
        // Registered defaults are what `UserDefaults` returns until a value is written.
        // The model default is derived from `DEFAULT_MODEL`, so it cannot drift from the
        // fallback that `currentModel` uses.
        defaults.register(defaults: [
            Keys.model.rawValue: DEFAULT_MODEL.modelId,
            Keys.safetyCheckerDisclaimer.rawValue: false,
            Keys.variant.rawValue: Settings.defaultVariantSentinel,
        ])
    }

    /// The model the user last selected. Falls back to `DEFAULT_MODEL` when no id is
    /// stored or `ModelInfo.from(modelId:)` does not recognise the stored id.
    var currentModel: ModelInfo {
        get {
            guard let modelId = defaults.string(forKey: Keys.model.rawValue) else { return DEFAULT_MODEL }
            return ModelInfo.from(modelId: modelId) ?? DEFAULT_MODEL
        }
        set {
            defaults.set(newValue.modelId, forKey: Keys.model.rawValue)
        }
    }

    /// Persisted flag for the safety-checker disclaimer (registered default: `false`).
    var safetyCheckerDisclaimerShown: Bool {
        get {
            return defaults.bool(forKey: Keys.safetyCheckerDisclaimer.rawValue)
        }
        set {
            defaults.set(newValue, forKey: Keys.safetyCheckerDisclaimer.rawValue)
        }
    }

    /// Returns the option selected by the user, if overridden
    /// `nil` means: guess best for this {model, device}
    var userSelectedAttentionVariant: AttentionVariant? {
        get {
            let current = defaults.string(forKey: Keys.variant.rawValue)
            return AttentionVariant(rawValue: current ?? "")
        }
        set {
            // Clearing the override stores the sentinel, which `get` maps back to `nil`.
            defaults.set(newValue?.rawValue ?? Settings.defaultVariantSentinel, forKey: Keys.variant.rawValue)
        }
    }
}