|
|
@ -49,25 +49,19 @@ final class ImageGenerator: ObservableObject {
|
|
|
|
|
|
|
|
|
|
|
|
@Published var generationState: GenerationState = .idle
|
|
|
|
@Published var generationState: GenerationState = .idle
|
|
|
|
@Published var generatedImages: GeneratedImages?
|
|
|
|
@Published var generatedImages: GeneratedImages?
|
|
|
|
private let sdpipeline: StableDiffusionPipeline
|
|
|
|
private var sdPipeline: StableDiffusionPipeline?
|
|
|
|
|
|
|
|
|
|
|
|
init() {
|
|
|
|
init() {
|
|
|
|
guard let path = Bundle.main.path(forResource: "CoreMLModels", ofType: nil, inDirectory: nil) else {
|
|
|
|
|
|
|
|
fatalError("Fatal error: failed to find the CoreML models.")
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
let resourceURL = URL(fileURLWithPath: path)
|
|
|
|
|
|
|
|
// TODO: move the pipeline creation to background task because it's heavy
|
|
|
|
|
|
|
|
if let pipeline = try? StableDiffusionPipeline(resourcesAt: resourceURL) {
|
|
|
|
|
|
|
|
sdpipeline = pipeline
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
fatalError("Fatal error: failed to create the Stable-Diffusion-Pipeline.")
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func setState(_ state: GenerationState) { // for actor isolation
|
|
|
|
func setState(_ state: GenerationState) { // for actor isolation
|
|
|
|
generationState = state
|
|
|
|
generationState = state
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func setPipeline(_ pipeline: StableDiffusionPipeline) { // for actor isolation
|
|
|
|
|
|
|
|
sdPipeline = pipeline
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func setGeneratedImages(_ images: GeneratedImages) { // for actor isolation
|
|
|
|
func setGeneratedImages(_ images: GeneratedImages) { // for actor isolation
|
|
|
|
generatedImages = images
|
|
|
|
generatedImages = images
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -76,12 +70,29 @@ final class ImageGenerator: ObservableObject {
|
|
|
|
guard generationState == .idle else { return }
|
|
|
|
guard generationState == .idle else { return }
|
|
|
|
Task.detached(priority: .high) {
|
|
|
|
Task.detached(priority: .high) {
|
|
|
|
await self.setState(.generating(progressStep: 0))
|
|
|
|
await self.setState(.generating(progressStep: 0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if await self.sdPipeline == nil {
|
|
|
|
|
|
|
|
guard let path = Bundle.main.path(forResource: "CoreMLModels", ofType: nil, inDirectory: nil) else {
|
|
|
|
|
|
|
|
fatalError("Fatal error: failed to find the CoreML models.")
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
let resourceURL = URL(fileURLWithPath: path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if let pipeline = try? StableDiffusionPipeline(resourcesAt: resourceURL) {
|
|
|
|
|
|
|
|
await self.setPipeline(pipeline)
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
fatalError("Fatal error: failed to create the Stable-Diffusion-Pipeline.")
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if let sdPipeline = await self.sdPipeline {
|
|
|
|
do {
|
|
|
|
do {
|
|
|
|
|
|
|
|
// API:
|
|
|
|
// generateImages(prompt: String, imageCount: Int = 1, stepCount: Int = 50, seed: Int = 0,
|
|
|
|
// generateImages(prompt: String, imageCount: Int = 1, stepCount: Int = 50, seed: Int = 0,
|
|
|
|
// disableSafety: Bool = false,
|
|
|
|
// disableSafety: Bool = false,
|
|
|
|
// progressHandler: (StableDiffusionPipeline.Progress) -> Bool = { _ in true }) throws -> [CGImage?]
|
|
|
|
// progressHandler: (StableDiffusionPipeline.Progress) -> Bool = { _ in true })
|
|
|
|
|
|
|
|
// throws -> [CGImage?]
|
|
|
|
// TODO: use the progressHandler
|
|
|
|
// TODO: use the progressHandler
|
|
|
|
let cgImages = try self.sdpipeline.generateImages(prompt: parameter.prompt,
|
|
|
|
let cgImages = try sdPipeline.generateImages(prompt: parameter.prompt,
|
|
|
|
imageCount: parameter.imageCount,
|
|
|
|
imageCount: parameter.imageCount,
|
|
|
|
stepCount: parameter.stepCount,
|
|
|
|
stepCount: parameter.stepCount,
|
|
|
|
seed: parameter.seed,
|
|
|
|
seed: parameter.seed,
|
|
|
@ -98,8 +109,10 @@ final class ImageGenerator: ObservableObject {
|
|
|
|
disableSafety: parameter.disableSafety,
|
|
|
|
disableSafety: parameter.disableSafety,
|
|
|
|
images: uiImages.map { uiImage in GeneratedImage(uiImage: uiImage) }))
|
|
|
|
images: uiImages.map { uiImage in GeneratedImage(uiImage: uiImage) }))
|
|
|
|
} catch {
|
|
|
|
} catch {
|
|
|
|
print("failed.")
|
|
|
|
print("failed to generate images.")
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
await self.setState(.idle)
|
|
|
|
await self.setState(.idle)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|