diff --git a/Diffusion-macOS/Diffusion_macOS.entitlements b/Diffusion-macOS/Diffusion_macOS.entitlements index f2ef3ae..625af03 100644 --- a/Diffusion-macOS/Diffusion_macOS.entitlements +++ b/Diffusion-macOS/Diffusion_macOS.entitlements @@ -2,9 +2,11 @@ - com.apple.security.app-sandbox - - com.apple.security.files.user-selected.read-only - + com.apple.security.app-sandbox + + com.apple.security.files.user-selected.read-only + + com.apple.security.network.client + diff --git a/Diffusion-macOS/PromptView.swift b/Diffusion-macOS/PromptView.swift index 0de2107..02a5d73 100644 --- a/Diffusion-macOS/PromptView.swift +++ b/Diffusion-macOS/PromptView.swift @@ -5,22 +5,58 @@ // Created by Cyril Zakka on 1/12/23. // +import Combine import SwiftUI import CompactSlider struct PromptView: View { - + @StateObject var context = DiffusionGlobals() + static let models = ModelInfo.MODELS static let modelNames = models.map { $0.modelVersion } - @State private var model = ModelInfo.v2Base.modelVersion + @State private var model = Settings.shared.currentModel.modelVersion @State private var positivePrompt = "" @State private var negativePrompt = "" @State private var steps = 50.0 @State private var numImages = 1.0 @State private var seed = 386.0 + // TODO: refactor download with similar code in Loading.swift (iOS) + @State private var preparationPhase = "Downloading…" + @State private var downloadProgress: Double = 0 + @State private var stateSubscriber: Cancellable? + + func modelDidChange(model: ModelInfo) { + Task.init { + let loader = PipelineLoader(model: model) + stateSubscriber = loader.statePublisher.sink { state in + DispatchQueue.main.async { + switch state { + case .downloading(let progress): + preparationPhase = "Downloading" + downloadProgress = progress + case .uncompressing: + preparationPhase = "Uncompressing" + downloadProgress = 1 + case .readyOnDisk: + preparationPhase = "Loading" + downloadProgress = 1 + default: + break + } + } + } + do { + context.pipeline = try await loader.prepare() + } catch { + // TODO: expose to user + print("Could not load model, error: \(error)") + } + } + } + var body: some View { VStack(alignment: .leading) { @@ -37,6 +73,9 @@ struct PromptView: View { Text($0) } } + .onChange(of: model) { theModel in + print("Model changed to \(theModel) model: \(model)") + } } label: { Label("Model", systemImage: "cpu").foregroundColor(.secondary) } @@ -58,7 +97,7 @@ struct PromptView: View { } Divider() - + DisclosureGroup { CompactSlider(value: $steps, in: 0...250, step: 5) { Text("Steps") @@ -68,7 +107,7 @@ struct PromptView: View { } label: { Label("Step count", systemImage: "square.3.layers.3d.down.left").foregroundColor(.secondary) } - + Divider() DisclosureGroup() { CompactSlider(value: $numImages, in: 0...10, step: 1) { @@ -101,7 +140,11 @@ struct PromptView: View { } .buttonStyle(.borderedProminent) // StatusView() - }.padding() + } + .padding() + .onAppear { + modelDidChange(model: ModelInfo.from(modelVersion: model) ?? ModelInfo.v2Base) + } } } diff --git a/Diffusion.xcodeproj/project.pbxproj b/Diffusion.xcodeproj/project.pbxproj index 7c136a3..d91a674 100644 --- a/Diffusion.xcodeproj/project.pbxproj +++ b/Diffusion.xcodeproj/project.pbxproj @@ -21,6 +21,8 @@ EBDD7DB42973200200C1C4B2 /* Utils.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB22973200200C1C4B2 /* Utils.swift */; }; EBDD7DB52973201800C1C4B2 /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */; }; EBDD7DB62973206600C1C4B2 /* Downloader.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBB5BA5929426E06003A2A5B /* Downloader.swift */; }; + EBDD7DB82976AAFE00C1C4B2 /* State.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB72976AAFE00C1C4B2 /* State.swift */; }; + EBDD7DB92976AAFE00C1C4B2 /* State.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBDD7DB72976AAFE00C1C4B2 /* State.swift */; }; EBE3FF4C295E1EFE00E921AA /* ModelInfo.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */; }; EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755C8293E37DD00806B32 /* DiffusionApp.swift */; }; EBE755CB293E37DD00806B32 /* TextToImage.swift in Sources */ = {isa = PBXBuildFile; fileRef = EBE755CA293E37DD00806B32 /* TextToImage.swift */; }; @@ -63,6 +65,7 @@ EBB5BA5229425BEE003A2A5B /* PipelineLoader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PipelineLoader.swift; sourceTree = ""; }; EBB5BA5929426E06003A2A5B /* Downloader.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Downloader.swift; sourceTree = ""; }; EBDD7DB22973200200C1C4B2 /* Utils.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Utils.swift; sourceTree = ""; }; + EBDD7DB72976AAFE00C1C4B2 /* State.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = State.swift; sourceTree = ""; }; EBE3FF4A295DFE2400E921AA /* common.xcconfig */ = {isa = PBXFileReference; lastKnownFileType = text.xcconfig; name = common.xcconfig; path = config/common.xcconfig; sourceTree = ""; }; EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ModelInfo.swift; sourceTree = ""; }; EBE4438729488DCA00CDA605 /* README.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = ""; }; @@ -175,6 +178,7 @@ EBE755C8293E37DD00806B32 /* DiffusionApp.swift */, EBDD7DB22973200200C1C4B2 /* Utils.swift */, EBE3FF4B295E1EFE00E921AA /* ModelInfo.swift */, + EBDD7DB72976AAFE00C1C4B2 /* State.swift */, EBE755CC293E37DD00806B32 /* Assets.xcassets */, EBE755CE293E37DD00806B32 /* Diffusion.entitlements */, EBE755CF293E37DD00806B32 /* Preview Content */, @@ -435,6 +439,7 @@ EBB5BA5A29426E06003A2A5B /* Downloader.swift in Sources */, EBE3FF4C295E1EFE00E921AA /* ModelInfo.swift in Sources */, EBE756092941178600806B32 /* Loading.swift in Sources */, + EBDD7DB82976AAFE00C1C4B2 /* State.swift in Sources */, EBB5BA5329425BEE003A2A5B /* PipelineLoader.swift in Sources */, EBE755C9293E37DD00806B32 /* DiffusionApp.swift in Sources */, EBDD7DB32973200200C1C4B2 /* Utils.swift in Sources */, @@ -465,6 +470,7 @@ EBDD7DAB29731F7500C1C4B2 /* PipelineLoader.swift in Sources */, EBDD7DAA29731F6C00C1C4B2 /* Pipeline.swift in Sources */, F15520262971093300DC009B /* ContentView.swift in Sources */, + EBDD7DB92976AAFE00C1C4B2 /* State.swift in Sources */, EBDD7DB42973200200C1C4B2 /* Utils.swift in Sources */, F1552031297109C300DC009B /* PromptView.swift in Sources */, EBDD7DB62973206600C1C4B2 /* Downloader.swift in Sources */, diff --git a/Diffusion/ModelInfo.swift b/Diffusion/ModelInfo.swift index 0171892..6bd20ee 100644 --- a/Diffusion/ModelInfo.swift +++ b/Diffusion/ModelInfo.swift @@ -85,4 +85,12 @@ extension ModelInfo { ModelInfo.v2Base, ModelInfo.v21Base ] + + static func from(modelVersion: String) -> ModelInfo? { + ModelInfo.MODELS.first(where: {$0.modelVersion == modelVersion}) + } + + static func from(modelId: String) -> ModelInfo? { + ModelInfo.MODELS.first(where: {$0.modelId == modelId}) + } } diff --git a/Diffusion/State.swift b/Diffusion/State.swift new file mode 100644 index 0000000..b0a389b --- /dev/null +++ b/Diffusion/State.swift @@ -0,0 +1,40 @@ +// +// State.swift +// Diffusion +// +// Created by Pedro Cuenca on 17/1/23. +// + +import SwiftUI + +let DEFAULT_MODEL = ModelInfo.v2Base + +class DiffusionGlobals: ObservableObject { + @Published var pipeline: Pipeline? = nil +} + +class Settings { + static let shared = Settings() + + let defaults = UserDefaults.standard + + enum Keys: String { + case model + } + + private init() { + defaults.register(defaults: [ + Keys.model.rawValue: ModelInfo.v2Base.modelId + ]) + } + + var currentModel: ModelInfo { + set { + defaults.set(newValue.modelId, forKey: Keys.model.rawValue) + } + get { + guard let modelId = defaults.string(forKey: Keys.model.rawValue) else { return DEFAULT_MODEL } + return ModelInfo.from(modelId: modelId) ?? DEFAULT_MODEL + } + } +} diff --git a/Diffusion/Views/Loading.swift b/Diffusion/Views/Loading.swift index 58e997c..d077f40 100644 --- a/Diffusion/Views/Loading.swift +++ b/Diffusion/Views/Loading.swift @@ -11,10 +11,6 @@ import Combine let model = ModelInfo.v2Base -class DiffusionGlobals: ObservableObject { - @Published var pipeline: Pipeline? = nil -} - struct LoadingView: View { @StateObject var context = DiffusionGlobals()