Download: display progress, hide Generate button.

pull/12/head
Pedro Cuenca 2 years ago
parent 97935c125a
commit d1c63dafbc

@ -9,6 +9,13 @@ import Combine
import SwiftUI import SwiftUI
import CompactSlider import CompactSlider
enum PipelineState {
case downloading(Double)
case uncompressing
case loading
case ready
case failed(Error)
}
struct PromptView: View { struct PromptView: View {
@StateObject var context = DiffusionGlobals() @StateObject var context = DiffusionGlobals()
@ -24,9 +31,8 @@ struct PromptView: View {
@State private var seed = 386.0 @State private var seed = 386.0
// TODO: refactor download with similar code in Loading.swift (iOS) // TODO: refactor download with similar code in Loading.swift (iOS)
@State private var preparationPhase = "Downloading…"
@State private var downloadProgress: Double = 0
@State private var stateSubscriber: Cancellable? @State private var stateSubscriber: Cancellable?
@State private var pipelineState: PipelineState = .downloading(0)
func modelDidChange(model: ModelInfo) { func modelDidChange(model: ModelInfo) {
Task.init { Task.init {
@ -35,14 +41,11 @@ struct PromptView: View {
DispatchQueue.main.async { DispatchQueue.main.async {
switch state { switch state {
case .downloading(let progress): case .downloading(let progress):
preparationPhase = "Downloading" pipelineState = .downloading(progress)
downloadProgress = progress
case .uncompressing: case .uncompressing:
preparationPhase = "Uncompressing" pipelineState = .uncompressing
downloadProgress = 1
case .readyOnDisk: case .readyOnDisk:
preparationPhase = "Loading" pipelineState = .loading
downloadProgress = 1
default: default:
break break
} }
@ -50,9 +53,10 @@ struct PromptView: View {
} }
do { do {
context.pipeline = try await loader.prepare() context.pipeline = try await loader.prepare()
pipelineState = .ready
} catch { } catch {
// TODO: expose to user
print("Could not load model, error: \(error)") print("Could not load model, error: \(error)")
pipelineState = .failed(error)
} }
} }
} }
@ -129,17 +133,9 @@ struct PromptView: View {
Label("Random Seed", systemImage: "leaf").foregroundColor(.secondary) Label("Random Seed", systemImage: "leaf").foregroundColor(.secondary)
} }
} }
} }
Button {
// Generate image here StatusView(pipelineState: $pipelineState)
} label: {
Text("Generate")
.frame(maxWidth: .infinity)
.frame(height: 50)
}
.buttonStyle(.borderedProminent)
// StatusView()
} }
.padding() .padding()
.onAppear { .onAppear {

@ -8,13 +8,33 @@
import SwiftUI import SwiftUI
struct StatusView: View { struct StatusView: View {
var pipelineState: Binding<PipelineState>
var body: some View { var body: some View {
Text(/*@START_MENU_TOKEN@*/"Hello, World!"/*@END_MENU_TOKEN@*/) switch pipelineState.wrappedValue {
case .downloading(let progress):
ProgressView("Downloading…", value: progress*100, total: 110).padding()
case .uncompressing:
ProgressView("Uncompressing…", value: 100, total: 110).padding()
case .loading:
ProgressView("Loading…", value: 105, total: 110).padding()
case .ready:
Button {
// Generate image here
} label: {
Text("Generate")
.frame(maxWidth: .infinity)
.frame(height: 50)
}
.buttonStyle(.borderedProminent)
case .failed:
Text("Pipeline loading error")
}
} }
} }
struct StatusView_Previews: PreviewProvider { struct StatusView_Previews: PreviewProvider {
static var previews: some View { static var previews: some View {
StatusView() StatusView(pipelineState: .constant(.downloading(0.2)))
} }
} }

Loading…
Cancel
Save