From a0297f8c3a74b90bd2b401c3710bd5391f5ca9d4 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 7 Feb 2023 15:00:32 +0100 Subject: [PATCH] Cancel downloads when selecting new model. I had to make Downloader explicitly cancellable because it waits forever for the semaphore to toggle, so Task cancellation does not work here. Cancellation is therefore exposed through PipelineLoader. --- Diffusion-macOS/ControlsView.swift | 4 ++++ Diffusion/Downloader.swift | 12 +++++++++--- Diffusion/Pipeline/PipelineLoader.swift | 9 +++++++-- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/Diffusion-macOS/ControlsView.swift b/Diffusion-macOS/ControlsView.swift index 46509fa..f56b1fc 100644 --- a/Diffusion-macOS/ControlsView.swift +++ b/Diffusion-macOS/ControlsView.swift @@ -61,6 +61,7 @@ struct ControlsView: View { // TODO: refactor download with similar code in Loading.swift (iOS) @State private var stateSubscriber: Cancellable? @State private var pipelineState: PipelineState = .downloading(0) + @State private var pipelineLoader: PipelineLoader? = nil // TODO: make this computed, and observable, and easy to read @State private var mustShowSafetyCheckerDisclaimer = false @@ -73,13 +74,16 @@ struct ControlsView: View { print("Loading model \(model)") Settings.shared.currentModel = model + pipelineLoader?.cancel() pipelineState = .downloading(0) Task.init { let loader = PipelineLoader(model: model) + self.pipelineLoader = loader stateSubscriber = loader.statePublisher.sink { state in DispatchQueue.main.async { switch state { case .downloading(let progress): + print("\(loader.model.modelVersion): \(progress)") pipelineState = .downloading(progress) case .uncompressing: pipelineState = .uncompressing diff --git a/Diffusion/Downloader.swift b/Diffusion/Downloader.swift index 24a85a5..847e34a 100644 --- a/Diffusion/Downloader.swift +++ b/Diffusion/Downloader.swift @@ -23,19 +23,21 @@ class Downloader: NSObject, ObservableObject { private(set) lazy var downloadState: CurrentValueSubject = CurrentValueSubject(.notStarted) private var stateSubscriber: Cancellable? + private var urlSession: URLSession? = nil + init(from url: URL, to destination: URL) { self.destination = destination super.init() // .background allows downloads to proceed in the background let config = URLSessionConfiguration.background(withIdentifier: "net.pcuenca.diffusion.download") - let urlSession = URLSession(configuration: config, delegate: self, delegateQueue: OperationQueue()) + urlSession = URLSession(configuration: config, delegate: self, delegateQueue: OperationQueue()) downloadState.value = .downloading(0) - urlSession.getAllTasks { tasks in + urlSession?.getAllTasks { tasks in // If there's an existing pending background task, let it proceed, otherwise start a new one. // TODO: check URL when we support downloading more models. if tasks.first == nil { - urlSession.downloadTask(with: url).resume() + self.urlSession?.downloadTask(with: url).resume() } } } @@ -59,6 +61,10 @@ class Downloader: NSObject, ObservableObject { default: throw("Should never happen, lol") } } + + func cancel() { + urlSession?.invalidateAndCancel() + } } extension Downloader: URLSessionDelegate, URLSessionDownloadDelegate { diff --git a/Diffusion/Pipeline/PipelineLoader.swift b/Diffusion/Pipeline/PipelineLoader.swift index 85fb517..b0dfb50 100644 --- a/Diffusion/Pipeline/PipelineLoader.swift +++ b/Diffusion/Pipeline/PipelineLoader.swift @@ -43,7 +43,8 @@ class PipelineLoader { } } private(set) lazy var statePublisher: CurrentValueSubject = CurrentValueSubject(state) - + private(set) var downloader: Downloader? = nil + func setInitialState() { if ready { state = .readyOnDisk @@ -63,6 +64,10 @@ extension PipelineLoader { } } +extension PipelineLoader { + func cancel() { downloader?.cancel() } +} + extension PipelineLoader { var url: URL { return model.bestURL @@ -107,9 +112,9 @@ extension PipelineLoader { if ready || downloaded { return downloadedURL } let downloader = Downloader(from: url, to: downloadedURL) + self.downloader = downloader downloadSubscriber = downloader.downloadState.sink { state in if case .downloading(let progress) = state { - print(progress) self.state = .downloading(progress) } }