You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

386 lines
18 KiB

//
// PromptView.swift
// Diffusion-macOS
//
// Created by Cyril Zakka on 1/12/23.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import Combine
import SwiftUI
import CompactSlider
/// Stages of pipeline preparation as tracked by the controls UI.
enum PipelineState {
    /// Download in progress; the payload is the progress value reported by the loader.
    case downloading(Double)
    /// Intermediate and terminal stages that carry no payload.
    case uncompressing, loading, ready
    /// Preparation stopped with the associated error.
    case failed(Error)
}
/// Disclosure group style that imitates the native look while also letting a click on the label
/// expand or collapse the group.
/// Candidate for removal (by adding gestures to every label instead) if it causes UI glitches.
struct LabelToggleDisclosureGroupStyle: DisclosureGroupStyle {
    func makeBody(configuration: Configuration) -> some View {
        VStack {
            HStack {
                // Chevron button: points down while expanded, right while collapsed.
                Button(action: { toggleExpansion(of: configuration) }) {
                    Image(systemName: configuration.isExpanded ? "chevron.down" : "chevron.right")
                        .frame(width: 8, height: 8)
                }
                .buttonStyle(.plain)
                .font(.footnote)
                .fontWeight(.semibold)
                .foregroundColor(.gray)

                // The label reacts to taps exactly like the chevron does.
                configuration.label
                    .onTapGesture { toggleExpansion(of: configuration) }

                Spacer()
            }

            if configuration.isExpanded {
                configuration.content
            }
        }
    }

    /// Flips the expanded flag of `configuration` inside an animation transaction.
    private func toggleExpansion(of configuration: Configuration) {
        withAnimation {
            configuration.isExpanded.toggle()
        }
    }
}
/// Sidebar with the generation options: model, prompts, guidance scale, step count, seed and
/// (when `Capabilities.hasANE` is true) compute units.
///
/// The view also starts (re)loading the pipeline when the model or the compute units change, and
/// mirrors the loader's progress into `pipelineState`, which is handed to `StatusView`.
struct ControlsView: View {
    @EnvironmentObject var generation: GenerationContext

    static let models = ModelInfo.MODELS

    /// Model version currently selected in the picker (or, transiently, the "reveal" pseudo option).
    @State private var model = Settings.shared.currentModel.modelVersion

    // Expansion state of each disclosure group.
    @State private var disclosedModel = true
    @State private var disclosedPrompt = true
    @State private var disclosedGuidance = false
    @State private var disclosedSteps = false
    @State private var disclosedSeed = false
    @State private var disclosedAdvanced = false

    // TODO: refactor download with similar code in Loading.swift (iOS)
    @State private var stateSubscriber: Cancellable?
    @State private var pipelineState: PipelineState = .downloading(0)
    @State private var pipelineLoader: PipelineLoader? = nil
    /// Handle to the load started by the latest `modelDidChange(model:)`, kept so that a newer
    /// request can cancel it and its late results are not applied.
    @State private var loadTask: Task<Void, Never>? = nil

    // TODO: make this computed, and observable, and easy to read
    @State private var mustShowSafetyCheckerDisclaimer = false
    @State private var mustShowModelDownloadDisclaimer = false      // When changing advanced settings

    // Presentation state of each help sheet/popover.
    @State private var showModelsHelp = false
    @State private var showPromptsHelp = false
    @State private var showGuidanceHelp = false
    @State private var showStepsHelp = false
    @State private var showSeedHelp = false
    @State private var showAdvancedHelp = false

    // Reasonable range for the slider
    let maxSeed: UInt32 = 1000

    // MARK: - State updates

    /// Shows the safety checker disclaimer while safety is disabled and the disclaimer has not been
    /// recorded as shown in `Settings`.
    func updateSafetyCheckerState() {
        mustShowSafetyCheckerDisclaimer = generation.disableSafety && !Settings.shared.safetyCheckerDisclaimerShown
    }

    /// Persists the compute units currently selected and reloads the current model with them.
    func updateComputeUnitsState() {
        Settings.shared.userSelectedComputeUnits = generation.computeUnits
        modelDidChange(model: Settings.shared.currentModel)
    }

    /// Restores the compute units to the persisted user selection, or to the default if there is none.
    func resetComputeUnitsState() {
        generation.computeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits
    }

    /// Starts loading `model` with the current compute units, unless the active loader already
    /// matches both.
    ///
    /// Any load in flight is cancelled first. Results of a cancelled load (pipeline, `.ready` or
    /// `.failed`) are discarded so they cannot overwrite the state of the load that replaced it.
    ///
    /// - Parameter model: The model to make current and load.
    func modelDidChange(model: ModelInfo) {
        guard pipelineLoader?.model != model || pipelineLoader?.computeUnits != generation.computeUnits else {
            print("Reusing same model \(model) with units \(generation.computeUnits)")
            return
        }

        print("Loading model \(model)")
        Settings.shared.currentModel = model

        pipelineLoader?.cancel()
        loadTask?.cancel()
        pipelineState = .downloading(0)

        // Run on the main actor so that every write to view state and to `generation` happens in
        // one place, consistent with the main-queue hop used by the state subscriber below.
        loadTask = Task { @MainActor in
            let loader = PipelineLoader(model: model, computeUnits: generation.computeUnits, maxSeed: maxSeed)
            pipelineLoader = loader
            stateSubscriber = loader.statePublisher.sink { state in
                DispatchQueue.main.async {
                    switch state {
                    case .downloading(let progress):
                        print("\(loader.model.modelVersion): \(progress)")
                        pipelineState = .downloading(progress)
                    case .uncompressing:
                        pipelineState = .uncompressing
                    case .readyOnDisk:
                        pipelineState = .loading
                    case .failed(let error):
                        pipelineState = .failed(error)
                    default:
                        // Remaining loader states do not change what the UI shows.
                        break
                    }
                }
            }

            do {
                let pipeline = try await loader.prepare()
                // A newer request superseded this one while it was loading: drop the result.
                guard !Task.isCancelled else { return }
                generation.pipeline = pipeline
                pipelineState = .ready
            } catch {
                // Errors from a superseded load are not worth surfacing.
                guard !Task.isCancelled else { return }
                print("Could not load model, error: \(error)")
                pipelineState = .failed(error)
            }
        }
    }

    /// Reports whether a loader built for `model` considers itself ready.
    ///
    /// - Parameters:
    ///   - model: The model to check.
    ///   - computeUnits: Compute units of the variant to check; defaults to the current selection.
    func isModelDownloaded(_ model: ModelInfo, computeUnits: ComputeUnits? = nil) -> Bool {
        PipelineLoader(model: model, computeUnits: computeUnits ?? generation.computeUnits).ready
    }

    /// Builds the picker label for `model`: a colored prefix followed by the model version.
    func modelLabel(_ model: ModelInfo) -> Text {
        let downloaded = isModelDownloaded(model)
        // NOTE(review): both branches are empty strings, so the prefix (and its color) renders
        // nothing. Presumably status glyphs were intended here - confirm against the upstream file.
        let prefix = downloaded ? "" : "" //" "
        return Text(prefix).foregroundColor(downloaded ? .accentColor : .secondary) + Text(model.modelVersion)
    }

    // MARK: - Shared pieces

    /// Plain-styled "info" button that toggles `isPresented`. Callers attach the sheet or popover.
    private func infoButton(_ isPresented: Binding<Bool>) -> some View {
        Button {
            isPresented.wrappedValue.toggle()
        } label: {
            Image(systemName: "info.circle")
        }
        .buttonStyle(.plain)
    }

    // MARK: - Sections

    /// Model picker, including the pseudo option that reveals the compiled model in Finder.
    private var modelSection: some View {
        DisclosureGroup(isExpanded: $disclosedModel) {
            let revealOption = "-- reveal --"
            Picker("", selection: $model) {
                ForEach(Self.models, id: \.modelVersion) {
                    modelLabel($0)
                }
                Text("Reveal in Finder…").tag(revealOption)
            }
            .onChange(of: model) { selection in
                guard selection != revealOption else {
                    // Not a real model: reveal the files, then put the selection back.
                    let selected = pipelineLoader?.compiledPath
                    NSWorkspace.shared.selectFile(selected?.string, inFileViewerRootedAtPath: PipelineLoader.models.string)
                    model = Settings.shared.currentModel.modelVersion
                    return
                }
                guard let model = ModelInfo.from(modelVersion: selection) else { return }
                modelDidChange(model: model)
            }
        } label: {
            HStack {
                Label("Model from Hub", systemImage: "cpu").foregroundColor(.secondary)
                Spacer()
                if disclosedModel {
                    // Unlike the other sections, this help is shown as a sheet rather than a popover.
                    infoButton($showModelsHelp)
                        .sheet(isPresented: $showModelsHelp) {
                            modelsHelp($showModelsHelp)
                        }
                }
            }.foregroundColor(.secondary)
        }
    }

    /// Positive and negative prompt fields.
    private var promptSection: some View {
        DisclosureGroup(isExpanded: $disclosedPrompt) {
            Group {
                TextField("Positive prompt", text: $generation.positivePrompt,
                          axis: .vertical).lineLimit(5)
                    .textFieldStyle(.squareBorder)
                    .listRowInsets(EdgeInsets(top: 0, leading: -20, bottom: 0, trailing: 20))
                TextField("Negative prompt", text: $generation.negativePrompt,
                          axis: .vertical).lineLimit(5)
                    .textFieldStyle(.squareBorder)
            }.padding(.leading, 10)
        } label: {
            HStack {
                Label("Prompts", systemImage: "text.quote").foregroundColor(.secondary)
                Spacer()
                if disclosedPrompt {
                    // Or maybe use .sheet instead
                    infoButton($showPromptsHelp)
                        .popover(isPresented: $showPromptsHelp, arrowEdge: .trailing) {
                            promptsHelp($showPromptsHelp)
                        }
                }
            }.foregroundColor(.secondary)
        }
    }

    /// Guidance scale slider; the collapsed header shows the current value instead of the help button.
    private var guidanceSection: some View {
        let guidanceScaleValue = generation.guidanceScale.formatted("%.1f")
        return DisclosureGroup(isExpanded: $disclosedGuidance) {
            CompactSlider(value: $generation.guidanceScale, in: 0...20, step: 0.5) {
                Text("Guidance Scale")
                Spacer()
                Text(guidanceScaleValue)
            }.padding(.leading, 10)
        } label: {
            HStack {
                Label("Guidance Scale", systemImage: "scalemass").foregroundColor(.secondary)
                Spacer()
                if disclosedGuidance {
                    // Or maybe use .sheet instead
                    infoButton($showGuidanceHelp)
                        .popover(isPresented: $showGuidanceHelp, arrowEdge: .trailing) {
                            guidanceHelp($showGuidanceHelp)
                        }
                } else {
                    Text(guidanceScaleValue)
                }
            }.foregroundColor(.secondary)
        }
    }

    /// Step count slider; the collapsed header shows the current value instead of the help button.
    private var stepsSection: some View {
        DisclosureGroup(isExpanded: $disclosedSteps) {
            CompactSlider(value: $generation.steps, in: 0...150, step: 5) {
                Text("Steps")
                Spacer()
                Text("\(Int(generation.steps))")
            }.padding(.leading, 10)
        } label: {
            HStack {
                Label("Step count", systemImage: "square.3.layers.3d.down.left").foregroundColor(.secondary)
                Spacer()
                if disclosedSteps {
                    infoButton($showStepsHelp)
                        .popover(isPresented: $showStepsHelp, arrowEdge: .trailing) {
                            stepsHelp($showStepsHelp)
                        }
                } else {
                    Text("\(Int(generation.steps))")
                }
            }.foregroundColor(.secondary)
        }
    }

    /// Seed slider; negative values are labelled as a random seed.
    private var seedSection: some View {
        DisclosureGroup(isExpanded: $disclosedSeed) {
            let sliderLabel = generation.seed < 0 ? "Random Seed" : "Seed"
            CompactSlider(value: $generation.seed, in: -1...Double(maxSeed), step: 1) {
                Text(sliderLabel)
                Spacer()
                Text("\(Int(generation.seed))")
            }.padding(.leading, 10)
        } label: {
            HStack {
                Label("Seed", systemImage: "leaf").foregroundColor(.secondary)
                Spacer()
                if disclosedSeed {
                    infoButton($showSeedHelp)
                        .popover(isPresented: $showSeedHelp, arrowEdge: .trailing) {
                            seedHelp($showSeedHelp)
                        }
                } else {
                    Text("\(Int(generation.seed))")
                }
            }.foregroundColor(.secondary)
        }
    }

    /// Compute units selection. Asks for confirmation when the chosen variant is not downloaded yet.
    private var advancedSection: some View {
        DisclosureGroup(isExpanded: $disclosedAdvanced) {
            HStack {
                Picker(selection: $generation.computeUnits, label: Text("Use")) {
                    Text("GPU").tag(ComputeUnits.cpuAndGPU)
                    Text("Neural Engine").tag(ComputeUnits.cpuAndNeuralEngine)
                    Text("GPU and Neural Engine").tag(ComputeUnits.all)
                }.pickerStyle(.radioGroup).padding(.leading)
                Spacer()
            }
            .onChange(of: generation.computeUnits) { units in
                guard let currentModel = ModelInfo.from(modelVersion: model) else { return }
                if isModelDownloaded(currentModel, computeUnits: units) {
                    updateComputeUnitsState()
                } else {
                    // Present explicitly; a toggle would hide the alert if the flag were already set.
                    mustShowModelDownloadDisclaimer = true
                }
            }
            .alert("Download Required", isPresented: $mustShowModelDownloadDisclaimer, actions: {
                Button("Cancel", role: .destructive) { resetComputeUnitsState() }
                Button("Download", role: .cancel) { updateComputeUnitsState() }
            }, message: {
                Text("This setting requires a new version of the selected model.")
            })
        } label: {
            HStack {
                Label("Advanced", systemImage: "terminal").foregroundColor(.secondary)
                Spacer()
                if disclosedAdvanced {
                    infoButton($showAdvancedHelp)
                        .popover(isPresented: $showAdvancedHelp, arrowEdge: .trailing) {
                            advancedHelp($showAdvancedHelp)
                        }
                }
            }.foregroundColor(.secondary)
        }
    }

    /// Safety checker toggle together with the disclaimer popover shown when it is first disabled.
    private var safetyCheckerToggle: some View {
        Toggle("Disable Safety Checker", isOn: $generation.disableSafety)
            .onChange(of: generation.disableSafety) { _ in
                updateSafetyCheckerState()
            }
            .popover(isPresented: $mustShowSafetyCheckerDisclaimer) {
                VStack {
                    Text("You have disabled the safety checker").font(.title).padding(.top)
                    Text("""
                    Please, ensure that you abide \
                    by the conditions of the Stable Diffusion license and do not expose \
                    unfiltered results to the public.
                    """)
                    .lineLimit(nil)
                    .padding(.all, 5)
                    Button {
                        Settings.shared.safetyCheckerDisclaimerShown = true
                        updateSafetyCheckerState()
                    } label: {
                        Text("I Accept").frame(maxWidth: 200)
                    }
                    .padding(.bottom)
                }
                .frame(minWidth: 400, idealWidth: 400, maxWidth: 400)
                .fixedSize()
            }
    }

    // MARK: - Body

    var body: some View {
        VStack(alignment: .leading) {
            Label("Generation Options", systemImage: "gearshape.2")
                .font(.headline)
                .fontWeight(.bold)
            Divider()

            ScrollView {
                Group {
                    modelSection
                    Divider()
                    promptSection
                    Divider()
                    guidanceSection
                    stepsSection
                    seedSection
                    if Capabilities.hasANE {
                        Divider()
                        advancedSection
                    }
                }
            }
            .disclosureGroupStyle(LabelToggleDisclosureGroupStyle())

            safetyCheckerToggle

            Divider()

            StatusView(pipelineState: $pipelineState)
        }
        .padding()
        .onAppear {
            print(PipelineLoader.models)
            // Kick off loading of the persisted selection, falling back to the v2 base model.
            modelDidChange(model: ModelInfo.from(modelVersion: model) ?? ModelInfo.v2Base)
        }
    }
}