macOS: download model on startup.

pull/12/head
Pedro Cuenca 2 years ago
parent a51e3d8cc5
commit 97935c125a

@ -6,5 +6,7 @@
<true/> <true/>
<key>com.apple.security.files.user-selected.read-only</key> <key>com.apple.security.files.user-selected.read-only</key>
<true/> <true/>
<key>com.apple.security.network.client</key>
<true/>
</dict> </dict>
</plist> </plist>

@ -5,22 +5,58 @@
// Created by Cyril Zakka on 1/12/23. // Created by Cyril Zakka on 1/12/23.
// //
import Combine
import SwiftUI import SwiftUI
import CompactSlider import CompactSlider
struct PromptView: View { struct PromptView: View {
@StateObject var context = DiffusionGlobals()
static let models = ModelInfo.MODELS static let models = ModelInfo.MODELS
static let modelNames = models.map { $0.modelVersion } static let modelNames = models.map { $0.modelVersion }
@State private var model = ModelInfo.v2Base.modelVersion @State private var model = Settings.shared.currentModel.modelVersion
@State private var positivePrompt = "" @State private var positivePrompt = ""
@State private var negativePrompt = "" @State private var negativePrompt = ""
@State private var steps = 50.0 @State private var steps = 50.0
@State private var numImages = 1.0 @State private var numImages = 1.0
@State private var seed = 386.0 @State private var seed = 386.0
// TODO: refactor download with similar code in Loading.swift (iOS)
@State private var preparationPhase = "Downloading…"
@State private var downloadProgress: Double = 0
@State private var stateSubscriber: Cancellable?
func modelDidChange(model: ModelInfo) {
Task.init {
let loader = PipelineLoader(model: model)
stateSubscriber = loader.statePublisher.sink { state in
DispatchQueue.main.async {
switch state {
case .downloading(let progress):
preparationPhase = "Downloading"
downloadProgress = progress
case .uncompressing:
preparationPhase = "Uncompressing"
downloadProgress = 1
case .readyOnDisk:
preparationPhase = "Loading"
downloadProgress = 1
default:
break
}
}
}
do {
context.pipeline = try await loader.prepare()
} catch {
// TODO: expose to user
print("Could not load model, error: \(error)")
}
}
}
var body: some View { var body: some View {
VStack(alignment: .leading) { VStack(alignment: .leading) {
@ -37,6 +73,9 @@ struct PromptView: View {
Text($0) Text($0)
} }
} }
.onChange(of: model) { theModel in
print("Model changed to \(theModel) model: \(model)")
}
} label: { } label: {
Label("Model", systemImage: "cpu").foregroundColor(.secondary) Label("Model", systemImage: "cpu").foregroundColor(.secondary)
} }
@ -101,7 +140,11 @@ struct PromptView: View {
} }
.buttonStyle(.borderedProminent) .buttonStyle(.borderedProminent)
// StatusView() // StatusView()
}.padding() }
.padding()
.onAppear {
modelDidChange(model: ModelInfo.from(modelVersion: model) ?? ModelInfo.v2Base)
}
} }
} }

@ -21,6 +21,8 @@
EBDD7DB42973200200C1C4B2 /* Utils.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB22973200200C1C4B2 /* Utils.swift */; }; EBDD7DB42973200200C1C4B2 /* Utils.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB22973200200C1C4B2 /* Utils.swift */; };
EBDD7DB52973201800C1C4B2 /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */; }; EBDD7DB52973201800C1C4B2 /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */; };
EBDD7DB62973206600C1C4B2 /* Downloader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5929426E06003A2A5B /* Downloader.swift */; }; EBDD7DB62973206600C1C4B2 /* Downloader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5929426E06003A2A5B /* Downloader.swift */; };
EBDD7DB82976AAFE00C1C4B2 /* State.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB72976AAFE00C1C4B2 /* State.swift */; };
EBDD7DB92976AAFE00C1C4B2 /* State.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB72976AAFE00C1C4B2 /* State.swift */; };
EBE3FF4C295E1EFE00E921AA /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */; }; EBE3FF4C295E1EFE00E921AA /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */; };
EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755C8293E37DD00806B32 /* DiffusionApp.swift */; }; EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755C8293E37DD00806B32 /* DiffusionApp.swift */; };
EBE755CB293E37DD00806B32 /* TextToImage.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755CA293E37DD00806B32 /* TextToImage.swift */; }; EBE755CB293E37DD00806B32 /* TextToImage.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755CA293E37DD00806B32 /* TextToImage.swift */; };
@ -63,6 +65,7 @@
EBB5BA5229425BEE003A2A5B /* PipelineLoader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PipelineLoader.swift; sourceTree = "<group>"; }; EBB5BA5229425BEE003A2A5B /* PipelineLoader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PipelineLoader.swift; sourceTree = "<group>"; };
EBB5BA5929426E06003A2A5B /* Downloader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Downloader.swift; sourceTree = "<group>"; }; EBB5BA5929426E06003A2A5B /* Downloader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Downloader.swift; sourceTree = "<group>"; };
EBDD7DB22973200200C1C4B2 /* Utils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Utils.swift; sourceTree = "<group>"; }; EBDD7DB22973200200C1C4B2 /* Utils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Utils.swift; sourceTree = "<group>"; };
EBDD7DB72976AAFE00C1C4B2 /* State.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = State.swift; sourceTree = "<group>"; };
EBE3FF4A295DFE2400E921AA /* common.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; name = common.xcconfig; path = config/common.xcconfig; sourceTree = "<group>"; }; EBE3FF4A295DFE2400E921AA /* common.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; name = common.xcconfig; path = config/common.xcconfig; sourceTree = "<group>"; };
EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelInfo.swift; sourceTree = "<group>"; }; EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelInfo.swift; sourceTree = "<group>"; };
EBE4438729488DCA00CDA605 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; }; EBE4438729488DCA00CDA605 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
@ -175,6 +178,7 @@
EBE755C8293E37DD00806B32 /* DiffusionApp.swift */, EBE755C8293E37DD00806B32 /* DiffusionApp.swift */,
EBDD7DB22973200200C1C4B2 /* Utils.swift */, EBDD7DB22973200200C1C4B2 /* Utils.swift */,
EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */, EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */,
EBDD7DB72976AAFE00C1C4B2 /* State.swift */,
EBE755CC293E37DD00806B32 /* Assets.xcassets */, EBE755CC293E37DD00806B32 /* Assets.xcassets */,
EBE755CE293E37DD00806B32 /* Diffusion.entitlements */, EBE755CE293E37DD00806B32 /* Diffusion.entitlements */,
EBE755CF293E37DD00806B32 /* Preview Content */, EBE755CF293E37DD00806B32 /* Preview Content */,
@ -435,6 +439,7 @@
EBB5BA5A29426E06003A2A5B /* Downloader.swift in Sources */, EBB5BA5A29426E06003A2A5B /* Downloader.swift in Sources */,
EBE3FF4C295E1EFE00E921AA /* ModelInfo.swift in Sources */, EBE3FF4C295E1EFE00E921AA /* ModelInfo.swift in Sources */,
EBE756092941178600806B32 /* Loading.swift in Sources */, EBE756092941178600806B32 /* Loading.swift in Sources */,
EBDD7DB82976AAFE00C1C4B2 /* State.swift in Sources */,
EBB5BA5329425BEE003A2A5B /* PipelineLoader.swift in Sources */, EBB5BA5329425BEE003A2A5B /* PipelineLoader.swift in Sources */,
EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */, EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */,
EBDD7DB32973200200C1C4B2 /* Utils.swift in Sources */, EBDD7DB32973200200C1C4B2 /* Utils.swift in Sources */,
@ -465,6 +470,7 @@
EBDD7DAB29731F7500C1C4B2 /* PipelineLoader.swift in Sources */, EBDD7DAB29731F7500C1C4B2 /* PipelineLoader.swift in Sources */,
EBDD7DAA29731F6C00C1C4B2 /* Pipeline.swift in Sources */, EBDD7DAA29731F6C00C1C4B2 /* Pipeline.swift in Sources */,
F15520262971093300DC009B /* ContentView.swift in Sources */, F15520262971093300DC009B /* ContentView.swift in Sources */,
EBDD7DB92976AAFE00C1C4B2 /* State.swift in Sources */,
EBDD7DB42973200200C1C4B2 /* Utils.swift in Sources */, EBDD7DB42973200200C1C4B2 /* Utils.swift in Sources */,
F1552031297109C300DC009B /* PromptView.swift in Sources */, F1552031297109C300DC009B /* PromptView.swift in Sources */,
EBDD7DB62973206600C1C4B2 /* Downloader.swift in Sources */, EBDD7DB62973206600C1C4B2 /* Downloader.swift in Sources */,

@ -85,4 +85,12 @@ extension ModelInfo {
ModelInfo.v2Base, ModelInfo.v2Base,
ModelInfo.v21Base ModelInfo.v21Base
] ]
static func from(modelVersion: String) -> ModelInfo? {
ModelInfo.MODELS.first(where: {$0.modelVersion == modelVersion})
}
static func from(modelId: String) -> ModelInfo? {
ModelInfo.MODELS.first(where: {$0.modelId == modelId})
}
} }

@ -0,0 +1,40 @@
//
// State.swift
// Diffusion
//
// Created by Pedro Cuenca on 17/1/23.
//
import SwiftUI
let DEFAULT_MODEL = ModelInfo.v2Base
class DiffusionGlobals: ObservableObject {
@Published var pipeline: Pipeline? = nil
}
class Settings {
static let shared = Settings()
let defaults = UserDefaults.standard
enum Keys: String {
case model
}
private init() {
defaults.register(defaults: [
Keys.model.rawValue: ModelInfo.v2Base.modelId
])
}
var currentModel: ModelInfo {
set {
defaults.set(newValue.modelId, forKey: Keys.model.rawValue)
}
get {
guard let modelId = defaults.string(forKey: Keys.model.rawValue) else { return DEFAULT_MODEL }
return ModelInfo.from(modelId: modelId) ?? DEFAULT_MODEL
}
}
}

@ -11,10 +11,6 @@ import Combine
let model = ModelInfo.v2Base let model = ModelInfo.v2Base
class DiffusionGlobals: ObservableObject {
@Published var pipeline: Pipeline? = nil
}
struct LoadingView: View { struct LoadingView: View {
@StateObject var context = DiffusionGlobals() @StateObject var context = DiffusionGlobals()

Loading…
Cancel
Save