You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

124 lines
4.0 KiB

//
// ModelInfo.swift
// Diffusion
//
// Created by Pedro Cuenca on 29/12/22.
// See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE
//
import CoreML
/// Attention variants that a compiled model archive can be published in.
///
/// The raw value of each case is its own name (`"original"`, `"splitEinsum"`).
enum AttentionVariant: String {
    case original, splitEinsum
}
/// Describes a model that is distributed as .zip archives of compiled Core ML models.
struct ModelInfo {
    /// Hugging Face model Id of the repo that hosts the .zip archives with compiled Core ML models.
    let modelId: String

    /// Free-form string used for presentation only, e.g. "2.1-base".
    let modelVersion: String

    /// Archive suffix for the ORIGINAL attention variant, typically "original_compiled".
    let originalAttentionSuffix: String

    /// Archive suffix for the SPLIT_EINSUM attention variant, typically "split_einsum_compiled".
    let splitAttentionSuffix: String

    /// Whether the archive ships the VAE Encoder (needed for image to image tasks). Not yet in use.
    let supportsEncoder: Bool

    /// Creates a model description.
    ///
    /// - Parameters:
    ///   - modelId: Hugging Face model Id hosting the archives.
    ///   - modelVersion: Presentation string for the model.
    ///   - originalAttentionSuffix: Archive suffix for the original attention variant.
    ///   - splitAttentionSuffix: Archive suffix for the split-einsum attention variant.
    ///   - supportsEncoder: Whether the archive contains the VAE Encoder.
    init(
        modelId: String,
        modelVersion: String,
        originalAttentionSuffix: String = "original_compiled",
        splitAttentionSuffix: String = "split_einsum_compiled",
        supportsEncoder: Bool = false
    ) {
        self.supportsEncoder = supportsEncoder
        self.splitAttentionSuffix = splitAttentionSuffix
        self.originalAttentionSuffix = originalAttentionSuffix
        self.modelVersion = modelVersion
        self.modelId = modelId
    }
}
extension ModelInfo {
    /// Attention variant preferred on the current platform, independent of any particular model.
    ///
    /// Returns `.splitEinsum` when `runningOnMac` is false or when not compiled for macOS.
    /// On macOS it returns `.original` if `Capabilities.hasANE` is false or if
    /// `Capabilities.performanceCores` is at least 8, and `.splitEinsum` otherwise.
    static var defaultAttention: AttentionVariant {
        guard runningOnMac else { return .splitEinsum }
        #if os(macOS)
        guard Capabilities.hasANE else { return .original }
        return Capabilities.performanceCores >= 8 ? .original : .splitEinsum
        #else
        return .splitEinsum
        #endif
    }

    /// Attention variant to use for this model. Currently identical to `defaultAttention` for every model.
    var bestAttention: AttentionVariant {
        return ModelInfo.defaultAttention
    }

    /// Builds the download URL of the archive for the given attention variant.
    ///
    /// The archive name is the last "/"-separated component of `modelId`, followed by "_",
    /// the suffix configured for `variant`, and ".zip".
    ///
    /// - Parameter variant: Attention variant whose archive is requested.
    /// - Returns: The archive URL on huggingface.co.
    /// - Note: Traps with a descriptive message only if no URL can be formed even after
    ///   percent-escaping the id and suffix.
    func modelURL(for variant: AttentionVariant) -> URL {
        // Pattern: https://huggingface.co/pcuenq/coreml-stable-diffusion/resolve/main/coreml-stable-diffusion-v1-5_original_compiled.zip
        let suffix: String
        switch variant {
        case .original: suffix = originalAttentionSuffix
        case .splitEinsum: suffix = splitAttentionSuffix
        }

        // `split` omits empty pieces, so an empty or slash-only id has no last component.
        // Fall back to the raw id rather than force-unwrapping.
        let repo = modelId.split(separator: "/").last.map(String.init) ?? modelId
        let address = "https://huggingface.co/\(modelId)/resolve/main/\(repo)_\(suffix).zip"

        // Fast path: identical result to the previous implementation for every id it accepted.
        if let url = URL(string: address) {
            return url
        }

        // The id or suffix contains characters `URL(string:)` rejects (e.g. spaces):
        // retry with those characters percent-escaped instead of crashing.
        if let escaped = address.addingPercentEncoding(withAllowedCharacters: .urlQueryAllowed),
           let url = URL(string: escaped) {
            return url
        }

        preconditionFailure("Cannot build a download URL for model id '\(modelId)' with suffix '\(suffix)'")
    }

    /// Best variant for the current platform.
    /// Currently using `split_einsum` for iOS and simple performance heuristics for macOS.
    var bestURL: URL { modelURL(for: bestAttention) }

    /// Whether memory reduction should be requested: true exactly when `runningOnMac` is false.
    var reduceMemory: Bool {
        return !runningOnMac
    }
}
extension ModelInfo {
    // MARK: - Known models

    // TODO: repo does not exist yet
    static let v14Base = ModelInfo(modelId: "pcuenq/coreml-stable-diffusion-1-4", modelVersion: "CompVis/stable-diffusion-v1-4")

    static let v15Base = ModelInfo(modelId: "pcuenq/coreml-stable-diffusion-v1-5", modelVersion: "runwayml/stable-diffusion-v1-5")

    static let v2Base = ModelInfo(modelId: "pcuenq/coreml-stable-diffusion-2-base", modelVersion: "stabilityai/stable-diffusion-2-base")

    /// The only entry that is declared with `supportsEncoder` set to true.
    static let v21Base = ModelInfo(modelId: "pcuenq/coreml-stable-diffusion-2-1-base", modelVersion: "stabilityai/stable-diffusion-2-1-base", supportsEncoder: true)

    static let ofaSmall = ModelInfo(modelId: "pcuenq/coreml-small-stable-diffusion-v0", modelVersion: "OFA-Sys/small-stable-diffusion-v0")

    /// Every known model, in the order the lookups below search them.
    static let MODELS: [ModelInfo] = [.v14Base, .v15Base, .v2Base, .v21Base, .ofaSmall]

    // MARK: - Lookup

    /// Returns the first entry of `MODELS` whose `modelVersion` equals the argument, or `nil` if none does.
    static func from(modelVersion: String) -> ModelInfo? {
        for candidate in MODELS where candidate.modelVersion == modelVersion {
            return candidate
        }
        return nil
    }

    /// Returns the first entry of `MODELS` whose `modelId` equals the argument, or `nil` if none does.
    static func from(modelId: String) -> ModelInfo? {
        for candidate in MODELS where candidate.modelId == modelId {
            return candidate
        }
        return nil
    }
}
// MARK: - Equatable

extension ModelInfo: Equatable {
    /// Two model descriptions are equal when their `modelId`s match; no other property is compared.
    static func == (lhs: ModelInfo, rhs: ModelInfo) -> Bool {
        let leftId = lhs.modelId
        let rightId = rhs.modelId
        return leftId == rightId
    }
}