macOS: download model on startup.

pull/12/head
Pedro Cuenca 1 year ago
parent a51e3d8cc5
commit 97935c125a

@ -2,9 +2,11 @@
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>com.apple.security.app-sandbox</key>
<true/>
<key>com.apple.security.files.user-selected.read-only</key>
<true/>
<key>com.apple.security.app-sandbox</key>
<true/>
<key>com.apple.security.files.user-selected.read-only</key>
<true/>
<key>com.apple.security.network.client</key>
<true/>
</dict>
</plist>

@ -5,22 +5,58 @@
// Created by Cyril Zakka on 1/12/23.
//
import Combine
import SwiftUI
import CompactSlider
struct PromptView: View {
@StateObject var context = DiffusionGlobals()
static let models = ModelInfo.MODELS
static let modelNames = models.map { $0.modelVersion }
@State private var model = ModelInfo.v2Base.modelVersion
@State private var model = Settings.shared.currentModel.modelVersion
@State private var positivePrompt = ""
@State private var negativePrompt = ""
@State private var steps = 50.0
@State private var numImages = 1.0
@State private var seed = 386.0
// TODO: refactor download with similar code in Loading.swift (iOS)
@State private var preparationPhase = "Downloading…"
@State private var downloadProgress: Double = 0
@State private var stateSubscriber: Cancellable?
func modelDidChange(model: ModelInfo) {
Task.init {
let loader = PipelineLoader(model: model)
stateSubscriber = loader.statePublisher.sink { state in
DispatchQueue.main.async {
switch state {
case .downloading(let progress):
preparationPhase = "Downloading"
downloadProgress = progress
case .uncompressing:
preparationPhase = "Uncompressing"
downloadProgress = 1
case .readyOnDisk:
preparationPhase = "Loading"
downloadProgress = 1
default:
break
}
}
}
do {
context.pipeline = try await loader.prepare()
} catch {
// TODO: expose to user
print("Could not load model, error: \(error)")
}
}
}
var body: some View {
VStack(alignment: .leading) {
@ -37,6 +73,9 @@ struct PromptView: View {
Text($0)
}
}
.onChange(of: model) { theModel in
print("Model changed to \(theModel) model: \(model)")
}
} label: {
Label("Model", systemImage: "cpu").foregroundColor(.secondary)
}
@ -58,7 +97,7 @@ struct PromptView: View {
}
Divider()
DisclosureGroup {
CompactSlider(value: $steps, in: 0...250, step: 5) {
Text("Steps")
@ -68,7 +107,7 @@ struct PromptView: View {
} label: {
Label("Step count", systemImage: "square.3.layers.3d.down.left").foregroundColor(.secondary)
}
Divider()
DisclosureGroup() {
CompactSlider(value: $numImages, in: 0...10, step: 1) {
@ -101,7 +140,11 @@ struct PromptView: View {
}
.buttonStyle(.borderedProminent)
// StatusView()
}.padding()
}
.padding()
.onAppear {
modelDidChange(model: ModelInfo.from(modelVersion: model) ?? ModelInfo.v2Base)
}
}
}

@ -21,6 +21,8 @@
EBDD7DB42973200200C1C4B2 /* Utils.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB22973200200C1C4B2 /* Utils.swift */; };
EBDD7DB52973201800C1C4B2 /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */; };
EBDD7DB62973206600C1C4B2 /* Downloader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5929426E06003A2A5B /* Downloader.swift */; };
EBDD7DB82976AAFE00C1C4B2 /* State.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB72976AAFE00C1C4B2 /* State.swift */; };
EBDD7DB92976AAFE00C1C4B2 /* State.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB72976AAFE00C1C4B2 /* State.swift */; };
EBE3FF4C295E1EFE00E921AA /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */; };
EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755C8293E37DD00806B32 /* DiffusionApp.swift */; };
EBE755CB293E37DD00806B32 /* TextToImage.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755CA293E37DD00806B32 /* TextToImage.swift */; };
@ -63,6 +65,7 @@
EBB5BA5229425BEE003A2A5B /* PipelineLoader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PipelineLoader.swift; sourceTree = "<group>"; };
EBB5BA5929426E06003A2A5B /* Downloader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Downloader.swift; sourceTree = "<group>"; };
EBDD7DB22973200200C1C4B2 /* Utils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Utils.swift; sourceTree = "<group>"; };
EBDD7DB72976AAFE00C1C4B2 /* State.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = State.swift; sourceTree = "<group>"; };
EBE3FF4A295DFE2400E921AA /* common.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; name = common.xcconfig; path = config/common.xcconfig; sourceTree = "<group>"; };
EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelInfo.swift; sourceTree = "<group>"; };
EBE4438729488DCA00CDA605 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
@ -175,6 +178,7 @@
EBE755C8293E37DD00806B32 /* DiffusionApp.swift */,
EBDD7DB22973200200C1C4B2 /* Utils.swift */,
EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */,
EBDD7DB72976AAFE00C1C4B2 /* State.swift */,
EBE755CC293E37DD00806B32 /* Assets.xcassets */,
EBE755CE293E37DD00806B32 /* Diffusion.entitlements */,
EBE755CF293E37DD00806B32 /* Preview Content */,
@ -435,6 +439,7 @@
EBB5BA5A29426E06003A2A5B /* Downloader.swift in Sources */,
EBE3FF4C295E1EFE00E921AA /* ModelInfo.swift in Sources */,
EBE756092941178600806B32 /* Loading.swift in Sources */,
EBDD7DB82976AAFE00C1C4B2 /* State.swift in Sources */,
EBB5BA5329425BEE003A2A5B /* PipelineLoader.swift in Sources */,
EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */,
EBDD7DB32973200200C1C4B2 /* Utils.swift in Sources */,
@ -465,6 +470,7 @@
EBDD7DAB29731F7500C1C4B2 /* PipelineLoader.swift in Sources */,
EBDD7DAA29731F6C00C1C4B2 /* Pipeline.swift in Sources */,
F15520262971093300DC009B /* ContentView.swift in Sources */,
EBDD7DB92976AAFE00C1C4B2 /* State.swift in Sources */,
EBDD7DB42973200200C1C4B2 /* Utils.swift in Sources */,
F1552031297109C300DC009B /* PromptView.swift in Sources */,
EBDD7DB62973206600C1C4B2 /* Downloader.swift in Sources */,

@ -85,4 +85,12 @@ extension ModelInfo {
ModelInfo.v2Base,
ModelInfo.v21Base
]
static func from(modelVersion: String) -> ModelInfo? {
ModelInfo.MODELS.first(where: {$0.modelVersion == modelVersion})
}
static func from(modelId: String) -> ModelInfo? {
ModelInfo.MODELS.first(where: {$0.modelId == modelId})
}
}

@ -0,0 +1,40 @@
//
// State.swift
// Diffusion
//
// Created by Pedro Cuenca on 17/1/23.
//
import SwiftUI
let DEFAULT_MODEL = ModelInfo.v2Base
class DiffusionGlobals: ObservableObject {
@Published var pipeline: Pipeline? = nil
}
class Settings {
static let shared = Settings()
let defaults = UserDefaults.standard
enum Keys: String {
case model
}
private init() {
defaults.register(defaults: [
Keys.model.rawValue: ModelInfo.v2Base.modelId
])
}
var currentModel: ModelInfo {
set {
defaults.set(newValue.modelId, forKey: Keys.model.rawValue)
}
get {
guard let modelId = defaults.string(forKey: Keys.model.rawValue) else { return DEFAULT_MODEL }
return ModelInfo.from(modelId: modelId) ?? DEFAULT_MODEL
}
}
}

@ -11,10 +11,6 @@ import Combine
let model = ModelInfo.v2Base
class DiffusionGlobals: ObservableObject {
@Published var pipeline: Pipeline? = nil
}
struct LoadingView: View {
@StateObject var context = DiffusionGlobals()

Loading…
Cancel
Save