@ -57,7 +57,6 @@ struct ControlsView: View {
@ State private var disclosedSteps = false
@ State private var disclosedSeed = false
@ State private var disclosedAdvanced = false
@ State private var useANE = ( Settings . shared . userSelectedAttentionVariant ? ? ModelInfo . defaultAttention ) = = . splitEinsum
// TODO: r e f a c t o r d o w n l o a d w i t h s i m i l a r c o d e i n L o a d i n g . s w i f t ( i O S )
@ State private var stateSubscriber : Cancellable ?
@ -82,14 +81,18 @@ struct ControlsView: View {
mustShowSafetyCheckerDisclaimer = generation . disableSafety && ! Settings . shared . safetyCheckerDisclaimerShown
}
func update ANE State( ) {
Settings . shared . userSelected AttentionVariant = useANE ? . splitEinsum : . original
func update ComputeUnits State( ) {
Settings . shared . userSelected ComputeUnits = generation . computeUnits
modelDidChange ( model : Settings . shared . currentModel )
}
func resetComputeUnitsState ( ) {
generation . computeUnits = Settings . shared . userSelectedComputeUnits ? ? ModelInfo . defaultComputeUnits
}
func modelDidChange ( model : ModelInfo ) {
guard pipelineLoader ? . model != model || pipelineLoader ? . variant != Settings . shared . userSelectedAttentionVariant else {
print ( " Reusing same model \( model ) with attention \( String ( describing : Settings . shared . userSelectedAttentionVariant ) ) " )
guard pipelineLoader ? . model != model || pipelineLoader ? . computeUnits != generation . computeUnits else {
print ( " Reusing same model \( model ) with units \( generation . computeUnits ) " )
return
}
@ -99,7 +102,7 @@ struct ControlsView: View {
pipelineLoader ? . cancel ( )
pipelineState = . downloading ( 0 )
Task . init {
let loader = PipelineLoader ( model : model , variant: Settings . shared . userSelectedAttentionVariant , maxSeed : maxSeed )
let loader = PipelineLoader ( model : model , computeUnits: generation . computeUnits , maxSeed : maxSeed )
self . pipelineLoader = loader
stateSubscriber = loader . statePublisher . sink { state in
DispatchQueue . main . async {
@ -128,8 +131,8 @@ struct ControlsView: View {
}
}
func isModelDownloaded ( _ model : ModelInfo , variant: AttentionVariant ? = nil ) -> Bool {
PipelineLoader ( model : model , variant: variant ? ? Settings . shared . userSelectedAttentionVariant ) . ready
func isModelDownloaded ( _ model : ModelInfo , computeUnits: ComputeUnits ? = nil ) -> Bool {
PipelineLoader ( model : model , computeUnits: computeUnits ? ? generation . computeUnits ) . ready
}
func modelLabel ( _ model : ModelInfo ) -> Text {
@ -301,21 +304,25 @@ struct ControlsView: View {
Divider ( )
DisclosureGroup ( isExpanded : $ disclosedAdvanced ) {
HStack {
Toggle ( " Use Neural Engine " , isOn : $ useANE ) . onChange ( of : useANE ) { value in
guard let currentModel = ModelInfo . from ( modelVersion : model ) else { return }
let variantDownloaded = isModelDownloaded ( currentModel , variant : useANE ? . splitEinsum : . original )
if variantDownloaded {
updateANEState ( )
} else {
mustShowModelDownloadDisclaimer . toggle ( )
}
}
. padding ( . leading , 10 )
Picker ( selection : $ generation . computeUnits , label : Text ( " Use " ) ) {
Text ( " GPU " ) . tag ( ComputeUnits . cpuAndGPU )
Text ( " Neural Engine " ) . tag ( ComputeUnits . cpuAndNeuralEngine )
Text ( " GPU and Neural Engine " ) . tag ( ComputeUnits . all )
} . pickerStyle ( . radioGroup ) . padding ( . leading )
Spacer ( )
}
. onChange ( of : generation . computeUnits ) { units in
guard let currentModel = ModelInfo . from ( modelVersion : model ) else { return }
let variantDownloaded = isModelDownloaded ( currentModel , computeUnits : units )
if variantDownloaded {
updateComputeUnitsState ( )
} else {
mustShowModelDownloadDisclaimer . toggle ( )
}
}
. alert ( " Download Required " , isPresented : $ mustShowModelDownloadDisclaimer , actions : {
Button ( " Cancel " , role : . destructive ) { useANE . toggle ( ) }
Button ( " Download " , role : . cancel ) { updateANEState ( ) }
Button ( " Cancel " , role : . destructive ) { resetComputeUnitsStat e( ) }
Button ( " Download " , role : . cancel ) { update ComputeUnits State( ) }
} , message : {
Text ( " This setting requires a new version of the selected model. " )
} )