Cancel downloads when selecting new model.

I had to make Downloader explicitly cancellable because it waits forever
for the semaphore to toggle, so Task cancellation does not work here.
Cancellation is therefore exposed through PipelineLoader.
pull/28/head
Pedro Cuenca 1 year ago
parent 7fe3c03448
commit a0297f8c3a

@ -61,6 +61,7 @@ struct ControlsView: View {
// TODO: refactor download with similar code in Loading.swift (iOS)
@State private var stateSubscriber: Cancellable?
@State private var pipelineState: PipelineState = .downloading(0)
@State private var pipelineLoader: PipelineLoader? = nil
// TODO: make this computed, and observable, and easy to read
@State private var mustShowSafetyCheckerDisclaimer = false
@ -73,13 +74,16 @@ struct ControlsView: View {
print("Loading model \(model)")
Settings.shared.currentModel = model
pipelineLoader?.cancel()
pipelineState = .downloading(0)
Task.init {
let loader = PipelineLoader(model: model)
self.pipelineLoader = loader
stateSubscriber = loader.statePublisher.sink { state in
DispatchQueue.main.async {
switch state {
case .downloading(let progress):
print("\(loader.model.modelVersion): \(progress)")
pipelineState = .downloading(progress)
case .uncompressing:
pipelineState = .uncompressing

@ -23,19 +23,21 @@ class Downloader: NSObject, ObservableObject {
private(set) lazy var downloadState: CurrentValueSubject<DownloadState, Never> = CurrentValueSubject(.notStarted)
private var stateSubscriber: Cancellable?
private var urlSession: URLSession? = nil
init(from url: URL, to destination: URL) {
self.destination = destination
super.init()
// .background allows downloads to proceed in the background
let config = URLSessionConfiguration.background(withIdentifier: "net.pcuenca.diffusion.download")
let urlSession = URLSession(configuration: config, delegate: self, delegateQueue: OperationQueue())
urlSession = URLSession(configuration: config, delegate: self, delegateQueue: OperationQueue())
downloadState.value = .downloading(0)
urlSession.getAllTasks { tasks in
urlSession?.getAllTasks { tasks in
// If there's an existing pending background task, let it proceed, otherwise start a new one.
// TODO: check URL when we support downloading more models.
if tasks.first == nil {
urlSession.downloadTask(with: url).resume()
self.urlSession?.downloadTask(with: url).resume()
}
}
}
@ -59,6 +61,10 @@ class Downloader: NSObject, ObservableObject {
default: throw("Should never happen, lol")
}
}
func cancel() {
urlSession?.invalidateAndCancel()
}
}
extension Downloader: URLSessionDelegate, URLSessionDownloadDelegate {

@ -43,7 +43,8 @@ class PipelineLoader {
}
}
private(set) lazy var statePublisher: CurrentValueSubject<PipelinePreparationPhase, Never> = CurrentValueSubject(state)
private(set) var downloader: Downloader? = nil
func setInitialState() {
if ready {
state = .readyOnDisk
@ -63,6 +64,10 @@ extension PipelineLoader {
}
}
extension PipelineLoader {
func cancel() { downloader?.cancel() }
}
extension PipelineLoader {
var url: URL {
return model.bestURL
@ -107,9 +112,9 @@ extension PipelineLoader {
if ready || downloaded { return downloadedURL }
let downloader = Downloader(from: url, to: downloadedURL)
self.downloader = downloader
downloadSubscriber = downloader.downloadState.sink { state in
if case .downloading(let progress) = state {
print(progress)
self.state = .downloading(progress)
}
}

Loading…
Cancel
Save