diff --git a/Diffusion-macOS/Diffusion_macOS.entitlements b/Diffusion-macOS/Diffusion_macOS.entitlements
index f2ef3ae..625af03 100644
--- a/Diffusion-macOS/Diffusion_macOS.entitlements
+++ b/Diffusion-macOS/Diffusion_macOS.entitlements
@@ -2,9 +2,11 @@
- com.apple.security.app-sandbox
-
- com.apple.security.files.user-selected.read-only
-
+ com.apple.security.app-sandbox
+
+ com.apple.security.files.user-selected.read-only
+
+ com.apple.security.network.client
+
diff --git a/Diffusion-macOS/PromptView.swift b/Diffusion-macOS/PromptView.swift
index 0de2107..02a5d73 100644
--- a/Diffusion-macOS/PromptView.swift
+++ b/Diffusion-macOS/PromptView.swift
@@ -5,22 +5,58 @@
// Created by Cyril Zakka on 1/12/23.
//
+import Combine
import SwiftUI
import CompactSlider
struct PromptView: View {
-
+ @StateObject var context = DiffusionGlobals()
+
static let models = ModelInfo.MODELS
static let modelNames = models.map { $0.modelVersion }
- @State private var model = ModelInfo.v2Base.modelVersion
+ @State private var model = Settings.shared.currentModel.modelVersion
@State private var positivePrompt = ""
@State private var negativePrompt = ""
@State private var steps = 50.0
@State private var numImages = 1.0
@State private var seed = 386.0
+ // TODO: refactor download with similar code in Loading.swift (iOS)
+ @State private var preparationPhase = "Downloading…"
+ @State private var downloadProgress: Double = 0
+ @State private var stateSubscriber: Cancellable?
+
+ func modelDidChange(model: ModelInfo) {
+ Task.init {
+ let loader = PipelineLoader(model: model)
+ stateSubscriber = loader.statePublisher.sink { state in
+ DispatchQueue.main.async {
+ switch state {
+ case .downloading(let progress):
+ preparationPhase = "Downloading"
+ downloadProgress = progress
+ case .uncompressing:
+ preparationPhase = "Uncompressing"
+ downloadProgress = 1
+ case .readyOnDisk:
+ preparationPhase = "Loading"
+ downloadProgress = 1
+ default:
+ break
+ }
+ }
+ }
+ do {
+ context.pipeline = try await loader.prepare()
+ } catch {
+ // TODO: expose to user
+ print("Could not load model, error: \(error)")
+ }
+ }
+ }
+
var body: some View {
VStack(alignment: .leading) {
@@ -37,6 +73,9 @@ struct PromptView: View {
Text($0)
}
}
+ .onChange(of: model) { theModel in
+ print("Model changed to \(theModel) model: \(model)")
+ }
} label: {
Label("Model", systemImage: "cpu").foregroundColor(.secondary)
}
@@ -58,7 +97,7 @@ struct PromptView: View {
}
Divider()
-
+
DisclosureGroup {
CompactSlider(value: $steps, in: 0...250, step: 5) {
Text("Steps")
@@ -68,7 +107,7 @@ struct PromptView: View {
} label: {
Label("Step count", systemImage: "square.3.layers.3d.down.left").foregroundColor(.secondary)
}
-
+
Divider()
DisclosureGroup() {
CompactSlider(value: $numImages, in: 0...10, step: 1) {
@@ -101,7 +140,11 @@ struct PromptView: View {
}
.buttonStyle(.borderedProminent)
// StatusView()
- }.padding()
+ }
+ .padding()
+ .onAppear {
+ modelDidChange(model: ModelInfo.from(modelVersion: model) ?? ModelInfo.v2Base)
+ }
}
}
diff --git a/Diffusion.xcodeproj/project.pbxproj b/Diffusion.xcodeproj/project.pbxproj
index 7c136a3..d91a674 100644
--- a/Diffusion.xcodeproj/project.pbxproj
+++ b/Diffusion.xcodeproj/project.pbxproj
@@ -21,6 +21,8 @@
EBDD7DB42973200200C1C4B2 /* Utils.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB22973200200C1C4B2 /* Utils.swift */; };
EBDD7DB52973201800C1C4B2 /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */; };
EBDD7DB62973206600C1C4B2 /* Downloader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5929426E06003A2A5B /* Downloader.swift */; };
+ EBDD7DB82976AAFE00C1C4B2 /* State.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB72976AAFE00C1C4B2 /* State.swift */; };
+ EBDD7DB92976AAFE00C1C4B2 /* State.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB72976AAFE00C1C4B2 /* State.swift */; };
EBE3FF4C295E1EFE00E921AA /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */; };
EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755C8293E37DD00806B32 /* DiffusionApp.swift */; };
EBE755CB293E37DD00806B32 /* TextToImage.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755CA293E37DD00806B32 /* TextToImage.swift */; };
@@ -63,6 +65,7 @@
EBB5BA5229425BEE003A2A5B /* PipelineLoader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PipelineLoader.swift; sourceTree = ""; };
EBB5BA5929426E06003A2A5B /* Downloader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Downloader.swift; sourceTree = ""; };
EBDD7DB22973200200C1C4B2 /* Utils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Utils.swift; sourceTree = ""; };
+ EBDD7DB72976AAFE00C1C4B2 /* State.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = State.swift; sourceTree = ""; };
EBE3FF4A295DFE2400E921AA /* common.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; name = common.xcconfig; path = config/common.xcconfig; sourceTree = ""; };
EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelInfo.swift; sourceTree = ""; };
EBE4438729488DCA00CDA605 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; };
@@ -175,6 +178,7 @@
EBE755C8293E37DD00806B32 /* DiffusionApp.swift */,
EBDD7DB22973200200C1C4B2 /* Utils.swift */,
EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */,
+ EBDD7DB72976AAFE00C1C4B2 /* State.swift */,
EBE755CC293E37DD00806B32 /* Assets.xcassets */,
EBE755CE293E37DD00806B32 /* Diffusion.entitlements */,
EBE755CF293E37DD00806B32 /* Preview Content */,
@@ -435,6 +439,7 @@
EBB5BA5A29426E06003A2A5B /* Downloader.swift in Sources */,
EBE3FF4C295E1EFE00E921AA /* ModelInfo.swift in Sources */,
EBE756092941178600806B32 /* Loading.swift in Sources */,
+ EBDD7DB82976AAFE00C1C4B2 /* State.swift in Sources */,
EBB5BA5329425BEE003A2A5B /* PipelineLoader.swift in Sources */,
EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */,
EBDD7DB32973200200C1C4B2 /* Utils.swift in Sources */,
@@ -465,6 +470,7 @@
EBDD7DAB29731F7500C1C4B2 /* PipelineLoader.swift in Sources */,
EBDD7DAA29731F6C00C1C4B2 /* Pipeline.swift in Sources */,
F15520262971093300DC009B /* ContentView.swift in Sources */,
+ EBDD7DB92976AAFE00C1C4B2 /* State.swift in Sources */,
EBDD7DB42973200200C1C4B2 /* Utils.swift in Sources */,
F1552031297109C300DC009B /* PromptView.swift in Sources */,
EBDD7DB62973206600C1C4B2 /* Downloader.swift in Sources */,
diff --git a/Diffusion/ModelInfo.swift b/Diffusion/ModelInfo.swift
index 0171892..6bd20ee 100644
--- a/Diffusion/ModelInfo.swift
+++ b/Diffusion/ModelInfo.swift
@@ -85,4 +85,12 @@ extension ModelInfo {
ModelInfo.v2Base,
ModelInfo.v21Base
]
+
+ static func from(modelVersion: String) -> ModelInfo? {
+ ModelInfo.MODELS.first(where: {$0.modelVersion == modelVersion})
+ }
+
+ static func from(modelId: String) -> ModelInfo? {
+ ModelInfo.MODELS.first(where: {$0.modelId == modelId})
+ }
}
diff --git a/Diffusion/State.swift b/Diffusion/State.swift
new file mode 100644
index 0000000..b0a389b
--- /dev/null
+++ b/Diffusion/State.swift
@@ -0,0 +1,40 @@
+//
+// State.swift
+// Diffusion
+//
+// Created by Pedro Cuenca on 17/1/23.
+//
+
+import SwiftUI
+
+let DEFAULT_MODEL = ModelInfo.v2Base
+
+class DiffusionGlobals: ObservableObject {
+ @Published var pipeline: Pipeline? = nil
+}
+
+class Settings {
+ static let shared = Settings()
+
+ let defaults = UserDefaults.standard
+
+ enum Keys: String {
+ case model
+ }
+
+ private init() {
+ defaults.register(defaults: [
+ Keys.model.rawValue: ModelInfo.v2Base.modelId
+ ])
+ }
+
+ var currentModel: ModelInfo {
+ set {
+ defaults.set(newValue.modelId, forKey: Keys.model.rawValue)
+ }
+ get {
+ guard let modelId = defaults.string(forKey: Keys.model.rawValue) else { return DEFAULT_MODEL }
+ return ModelInfo.from(modelId: modelId) ?? DEFAULT_MODEL
+ }
+ }
+}
diff --git a/Diffusion/Views/Loading.swift b/Diffusion/Views/Loading.swift
index 58e997c..d077f40 100644
--- a/Diffusion/Views/Loading.swift
+++ b/Diffusion/Views/Loading.swift
@@ -11,10 +11,6 @@ import Combine
let model = ModelInfo.v2Base
-class DiffusionGlobals: ObservableObject {
- @Published var pipeline: Pipeline? = nil
-}
-
struct LoadingView: View {
@StateObject var context = DiffusionGlobals()