You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
91 lines
2.9 KiB
91 lines
2.9 KiB
2 years ago
|
//
|
||
|
// StatusView.swift
|
||
|
// Diffusion-macOS
|
||
|
//
|
||
|
// Created by Cyril Zakka on 1/12/23.
|
||
|
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
|
||
|
//
|
||
|
|
||
|
import SwiftUI
|
||
|
|
||
|
/// Status area for the diffusion pipeline.
///
/// While the pipeline is being prepared it shows a progress bar for the current
/// preparation phase. Once the pipeline is `.ready` it shows the "Generate" button
/// followed by a one-line summary of the current/last generation.
struct StatusView: View {
    @EnvironmentObject var generation: GenerationContext
    var pipelineState: Binding<PipelineState>

    /// Starts a new generation unless one is already running.
    ///
    /// `.running` is assigned synchronously, *before* the asynchronous work is scheduled,
    /// so a second tap arriving before the task starts is rejected by the guard instead of
    /// launching a duplicate generation.
    func submit() {
        if case .running = generation.state { return }
        generation.state = .running(nil)

        // Pin the continuation to the main actor: `generation.state` drives the UI and
        // must be published from the main thread.
        Task { @MainActor in
            let image: CGImage?
            let interval: TimeInterval?
            // A nil result from `generate()` is reported as "no image, no timing".
            (image, interval) = await generation.generate() ?? (nil, nil)
            generation.state = .complete(generation.positivePrompt, image, interval)
        }
    }

    /// One-line description of the current generation state.
    ///
    /// - `.startup`: nothing is shown.
    /// - `.running`: "Preparing model…" until step information is available, then a percentage.
    /// - `.complete`: elapsed time, or a safety-checker notice when no image was produced.
    @ViewBuilder
    func generationStatusView() -> some View {
        switch generation.state {
        case .startup:
            EmptyView()

        case .running(let progress):
            if let progress = progress, progress.stepCount > 0 {
                // `step` is reported zero-based, so add one to display completed steps.
                let step = Int(progress.step) + 1
                let fraction = Double(step) / Double(progress.stepCount)
                HStack {
                    Text("Generating \(Int(round(100*fraction)))%")
                    Spacer()
                }
            } else {
                // The first time it takes a little bit before generation starts
                HStack {
                    Text("Preparing model…")
                    Spacer()
                }
            }

        case .complete(_, let image, let interval):
            if image == nil {
                HStack {
                    Text("Safety checker triggered, please try a different prompt or seed")
                    Spacer()
                }
            } else {
                HStack {
                    Text(String(format: "Time: %.1fs", interval ?? 0))
                    Spacer()
                }
                .frame(maxHeight: 25)
            }
        }
    }

    var body: some View {
        // Preparation progress is drawn on a single 0...110 scale: downloading fills
        // 0...100, uncompressing sits at 100 and loading at 105, so the bar keeps
        // advancing across phases.
        switch pipelineState.wrappedValue {
        case .downloading(let progress):
            ProgressView("Downloading…", value: progress*100, total: 110).padding()
        case .uncompressing:
            ProgressView("Uncompressing…", value: 100, total: 110).padding()
        case .loading:
            ProgressView("Loading…", value: 105, total: 110).padding()
        case .ready:
            VStack {
                Button {
                    submit()
                } label: {
                    Text("Generate")
                        .frame(maxWidth: .infinity)
                        .frame(height: 50)
                }
                .buttonStyle(.borderedProminent)

                generationStatusView()
            }
        case .failed:
            Text("Pipeline loading error")
        }
    }
}
|
||
|
|
||
|
/// Xcode preview of `StatusView` partway through the model download (20%).
struct StatusView_Previews: PreviewProvider {
    static var previews: some View {
        let downloadState = PipelineState.downloading(0.2)
        return StatusView(pipelineState: Binding.constant(downloadState))
    }
}
|