Co-authored-by: aseemw <aseem.elec@gmail.com> Co-authored-by: msiracusa <msiracusa+github@gmail.com> pull/31/head 0.1.0
@ -0,0 +1,144 @@
|
|||||||
|
# Swift Package
|
||||||
|
.DS_Store
|
||||||
|
/.build
|
||||||
|
/Packages
|
||||||
|
/*.xcodeproj
|
||||||
|
.swiftpm
|
||||||
|
.vscode
|
||||||
|
.*.sw?
|
||||||
|
*.docc-build
|
||||||
|
*.vs
|
||||||
|
Package.resolved
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# macOS filesystem
|
||||||
|
*.DS_Store
|
@ -0,0 +1,555 @@
|
|||||||
|
Acknowledgements
|
||||||
|
Portions of this software may utilize the following copyrighted
|
||||||
|
material, the use of which is hereby acknowledged.
|
||||||
|
|
||||||
|
_____________________
|
||||||
|
The Hugging Face team (diffusers)
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
The Hugging Face team (transformers)
|
||||||
|
Copyright 2018- The Hugging Face team. All rights reserved.
|
||||||
|
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
Facebook, Inc (PyTorch)
|
||||||
|
From PyTorch:
|
||||||
|
|
||||||
|
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
||||||
|
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
||||||
|
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
||||||
|
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
||||||
|
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
||||||
|
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
||||||
|
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
||||||
|
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
||||||
|
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
||||||
|
|
||||||
|
From Caffe2:
|
||||||
|
|
||||||
|
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
||||||
|
|
||||||
|
All contributions by Facebook:
|
||||||
|
Copyright (c) 2016 Facebook Inc.
|
||||||
|
|
||||||
|
All contributions by Google:
|
||||||
|
Copyright (c) 2015 Google Inc.
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
All contributions by Yangqing Jia:
|
||||||
|
Copyright (c) 2015 Yangqing Jia
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
All contributions by Kakao Brain:
|
||||||
|
Copyright 2019-2020 Kakao Brain
|
||||||
|
|
||||||
|
All contributions by Cruise LLC:
|
||||||
|
Copyright (c) 2022 Cruise LLC.
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
All contributions from Caffe:
|
||||||
|
Copyright(c) 2013, 2014, 2015, the respective contributors
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
All other contributions:
|
||||||
|
Copyright(c) 2015, 2016 the respective contributors
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
||||||
|
copyright over their contributions to Caffe2. The project versioning records
|
||||||
|
all such contribution and copyright details. If a contributor wants to further
|
||||||
|
mark their specific copyright on a particular contribution, they should
|
||||||
|
indicate their copyright solely in the commit message of the change when it is
|
||||||
|
committed.
|
||||||
|
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer in the
|
||||||
|
documentation and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
|
||||||
|
and IDIAP Research Institute nor the names of its contributors may be
|
||||||
|
used to endorse or promote products derived from this software without
|
||||||
|
specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||||
|
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
||||||
|
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||||
|
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||||
|
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||||
|
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||||
|
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||||
|
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||||
|
POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
NumPy (RandomKit 1.3)
|
||||||
|
|
||||||
|
Copyright (c) 2003-2005, Jean-Sebastien Roy (js@jeannot.org)
|
||||||
|
|
||||||
|
The rk_random and rk_seed functions algorithms and the original design of
|
||||||
|
the Mersenne Twister RNG:
|
||||||
|
|
||||||
|
Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura,
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions
|
||||||
|
are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer in the
|
||||||
|
documentation and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
3. The names of its contributors may not be used to endorse or promote
|
||||||
|
products derived from this software without specific prior written
|
||||||
|
permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
|
||||||
|
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
||||||
|
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
||||||
|
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
|
||||||
|
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||||
|
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||||
|
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||||
|
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
Original algorithm for the implementation of rk_interval function from
|
||||||
|
Richard J. Wagner's implementation of the Mersenne Twister RNG, optimised by
|
||||||
|
Magnus Jonsson.
|
||||||
|
|
||||||
|
Constants used in the rk_double implementation by Isaku Wada.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
copy of this software and associated documentation files (the
|
||||||
|
"Software"), to deal in the Software without restriction, including
|
||||||
|
without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included
|
||||||
|
in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||||
|
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||||
|
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||||
|
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||||
|
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||||
|
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -0,0 +1,71 @@
|
|||||||
|
# Code of Conduct
|
||||||
|
|
||||||
|
## Our Pledge
|
||||||
|
|
||||||
|
In the interest of fostering an open and welcoming environment, we as
|
||||||
|
contributors and maintainers pledge to making participation in our project and
|
||||||
|
our community a harassment-free experience for everyone, regardless of age, body
|
||||||
|
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
||||||
|
level of experience, education, socio-economic status, nationality, personal
|
||||||
|
appearance, race, religion, or sexual identity and orientation.
|
||||||
|
|
||||||
|
## Our Standards
|
||||||
|
|
||||||
|
Examples of behavior that contributes to creating a positive environment
|
||||||
|
include:
|
||||||
|
|
||||||
|
* Using welcoming and inclusive language
|
||||||
|
* Being respectful of differing viewpoints and experiences
|
||||||
|
* Gracefully accepting constructive criticism
|
||||||
|
* Focusing on what is best for the community
|
||||||
|
* Showing empathy towards other community members
|
||||||
|
|
||||||
|
Examples of unacceptable behavior by participants include:
|
||||||
|
|
||||||
|
* The use of sexualized language or imagery and unwelcome sexual attention or
|
||||||
|
advances
|
||||||
|
* Trolling, insulting/derogatory comments, and personal or political attacks
|
||||||
|
* Public or private harassment
|
||||||
|
* Publishing others' private information, such as a physical or electronic
|
||||||
|
address, without explicit permission
|
||||||
|
* Other conduct which could reasonably be considered inappropriate in a
|
||||||
|
professional setting
|
||||||
|
|
||||||
|
## Our Responsibilities
|
||||||
|
|
||||||
|
Project maintainers are responsible for clarifying the standards of acceptable
|
||||||
|
behavior and are expected to take appropriate and fair corrective action in
|
||||||
|
response to any instances of unacceptable behavior.
|
||||||
|
|
||||||
|
Project maintainers have the right and responsibility to remove, edit, or
|
||||||
|
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||||
|
that are not aligned to this Code of Conduct, or to ban temporarily or
|
||||||
|
permanently any contributor for other behaviors that they deem inappropriate,
|
||||||
|
threatening, offensive, or harmful.
|
||||||
|
|
||||||
|
## Scope
|
||||||
|
|
||||||
|
This Code of Conduct applies within all project spaces, and it also applies when
|
||||||
|
an individual is representing the project or its community in public spaces.
|
||||||
|
Examples of representing a project or community include using an official
|
||||||
|
project e-mail address, posting via an official social media account, or acting
|
||||||
|
as an appointed representative at an online or offline event. Representation of
|
||||||
|
a project may be further defined and clarified by project maintainers.
|
||||||
|
|
||||||
|
## Enforcement
|
||||||
|
|
||||||
|
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||||
|
reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
|
||||||
|
complaints will be reviewed and investigated and will result in a response that
|
||||||
|
is deemed necessary and appropriate to the circumstances. The project team is
|
||||||
|
obligated to maintain confidentiality with regard to the reporter of an incident.
|
||||||
|
Further details of specific enforcement policies may be posted separately.
|
||||||
|
|
||||||
|
Project maintainers who do not follow or enforce the Code of Conduct in good
|
||||||
|
faith may face temporary or permanent repercussions as determined by other
|
||||||
|
members of the project's leadership.
|
||||||
|
|
||||||
|
## Attribution
|
||||||
|
|
||||||
|
This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
|
||||||
|
available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
|
@ -0,0 +1,11 @@
|
|||||||
|
# Contribution Guide
|
||||||
|
|
||||||
|
Thanks for your interest in contributing. This project was released for system demonstration purposes and there are limited plans for future development of the repository.
|
||||||
|
|
||||||
|
While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
|
||||||
|
|
||||||
|
## Before you get started
|
||||||
|
|
||||||
|
By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
|
||||||
|
|
||||||
|
We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
|
@ -0,0 +1,39 @@
|
|||||||
|
Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
IMPORTANT: This Apple software is supplied to you by Apple
|
||||||
|
Inc. ("Apple") in consideration of your agreement to the following
|
||||||
|
terms, and your use, installation, modification or redistribution of
|
||||||
|
this Apple software constitutes acceptance of these terms. If you do
|
||||||
|
not agree with these terms, please do not use, install, modify or
|
||||||
|
redistribute this Apple software.
|
||||||
|
|
||||||
|
In consideration of your agreement to abide by the following terms, and
|
||||||
|
subject to these terms, Apple grants you a personal, non-exclusive
|
||||||
|
license, under Apple's copyrights in this original Apple software (the
|
||||||
|
"Apple Software"), to use, reproduce, modify and redistribute the Apple
|
||||||
|
Software, with or without modifications, in source and/or binary forms;
|
||||||
|
provided that if you redistribute the Apple Software in its entirety and
|
||||||
|
without modifications, you must retain this notice and the following
|
||||||
|
text and disclaimers in all such redistributions of the Apple Software.
|
||||||
|
Neither the name, trademarks, service marks or logos of Apple Inc. may
|
||||||
|
be used to endorse or promote products derived from the Apple Software
|
||||||
|
without specific prior written permission from Apple. Except as
|
||||||
|
expressly stated in this notice, no other rights or licenses, express or
|
||||||
|
implied, are granted by Apple herein, including but not limited to any
|
||||||
|
patent rights that may be infringed by your derivative works or by other
|
||||||
|
works in which the Apple Software may be incorporated.
|
||||||
|
|
||||||
|
The Apple Software is provided by Apple on an "AS IS" basis. APPLE
|
||||||
|
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
|
||||||
|
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
|
||||||
|
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
|
||||||
|
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
|
||||||
|
|
||||||
|
IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
|
||||||
|
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||||
|
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||||
|
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
|
||||||
|
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
|
||||||
|
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
|
||||||
|
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
|
||||||
|
POSSIBILITY OF SUCH DAMAGE.
|
@ -0,0 +1,43 @@
|
|||||||
|
// swift-tools-version: 5.7
|
||||||
|
// The swift-tools-version declares the minimum version of Swift required to build this package.
|
||||||
|
|
||||||
|
import PackageDescription
|
||||||
|
|
||||||
|
let package = Package(
|
||||||
|
name: "stable-diffusion",
|
||||||
|
platforms: [
|
||||||
|
.macOS(.v13),
|
||||||
|
.iOS(.v16),
|
||||||
|
],
|
||||||
|
products: [
|
||||||
|
.library(
|
||||||
|
name: "StableDiffusion",
|
||||||
|
targets: ["StableDiffusion"]),
|
||||||
|
.executable(
|
||||||
|
name: "StableDiffusionSample",
|
||||||
|
targets: ["StableDiffusionCLI"])
|
||||||
|
],
|
||||||
|
dependencies: [
|
||||||
|
.package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.2.0")
|
||||||
|
],
|
||||||
|
targets: [
|
||||||
|
.target(
|
||||||
|
name: "StableDiffusion",
|
||||||
|
dependencies: [],
|
||||||
|
path: "swift/StableDiffusion"),
|
||||||
|
.executableTarget(
|
||||||
|
name: "StableDiffusionCLI",
|
||||||
|
dependencies: [
|
||||||
|
"StableDiffusion",
|
||||||
|
.product(name: "ArgumentParser", package: "swift-argument-parser")],
|
||||||
|
path: "swift/StableDiffusionCLI"),
|
||||||
|
.testTarget(
|
||||||
|
name: "StableDiffusionTests",
|
||||||
|
dependencies: ["StableDiffusion"],
|
||||||
|
path: "swift/StableDiffusionTests",
|
||||||
|
resources: [
|
||||||
|
.copy("Resources/vocab.json"),
|
||||||
|
.copy("Resources/merges.txt")
|
||||||
|
]),
|
||||||
|
]
|
||||||
|
)
|
@ -0,0 +1,305 @@
|
|||||||
|
# Core ML Stable Diffusion
|
||||||
|
|
||||||
|
Run Stable Diffusion on Apple Silicon with Core ML
|
||||||
|
|
||||||
|
<img src="assets/readme_reel.png">
|
||||||
|
|
||||||
|
This repository comprises:
|
||||||
|
|
||||||
|
- `python_coreml_stable_diffusion`, a Python package for converting PyTorch models to Core ML format and performing image generation with Hugging Face [diffusers](https://github.com/huggingface/diffusers) in Python
|
||||||
|
- `StableDiffusion`, a Swift package that developers can add to their Xcode projects as a dependency to deploy image generation capabilities in their apps. The Swift package relies on the Core ML model files generated by `python_coreml_stable_diffusion`
|
||||||
|
|
||||||
|
If you run into issues during installation or runtime, please refer to the [FAQ](#faq) section.
|
||||||
|
|
||||||
|
|
||||||
|
## <a name="example-results"></a> Example Results
|
||||||
|
|
||||||
|
There are numerous versions of Stable Diffusion available on the [Hugging Face Hub](https://huggingface.co/models?search=stable-diffusion). Here are example results from three of those models:
|
||||||
|
|
||||||
|
`--model-version` | [stabilityai/stable-diffusion-2-base](https://huggingface.co/stabilityai/stable-diffusion-2-base) | [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) | [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) |
|
||||||
|
:------:|:------:|:------:|:------:
|
||||||
|
Output | ![](assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_11_computeUnit_CPU_AND_GPU_modelVersion_stabilityai_stable-diffusion-2-base.png) | ![](assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_13_computeUnit_CPU_AND_NE_modelVersion_CompVis_stable-diffusion-v1-4.png) | ![](assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_93_computeUnit_CPU_AND_NE_modelVersion_runwayml_stable-diffusion-v1-5.png)
|
||||||
|
M1 iPad Pro 8GB Latency (s) | 29 | 38 | 38 |
|
||||||
|
M1 MacBook Pro 16GB Latency (s) | 24 | 35 | 35 |
|
||||||
|
M2 MacBook Air 8GB Latency (s) | 18 | 23 | 23 |
|
||||||
|
|
||||||
|
Please see [Important Notes on Performance Benchmarks](#important-notes-on-performance-benchmarks) section for details.
|
||||||
|
|
||||||
|
|
||||||
|
## <a name="converting-models-to-coreml"></a> Converting Models to Core ML
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> Click to expand </summary>
|
||||||
|
|
||||||
|
**Step 1:** Create a Python environment and install dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda create -n coreml_stable_diffusion python=3.8 -y
|
||||||
|
conda activate coreml_stable_diffusion
|
||||||
|
cd /path/to/cloned/ml-stable-diffusion/repository
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2:** Log in to or register for your [Hugging Face account](https://huggingface.co), generate a [User Access Token](https://huggingface.co/settings/tokens) and use this token to set up Hugging Face API access by running `huggingface-cli login` in a Terminal window.
|
||||||
|
|
||||||
|
**Step 3:** Navigate to the version of Stable Diffusion that you would like to use on [Hugging Face Hub](https://huggingface.co/models?search=stable-diffusion) and accept its Terms of Use. The default model version is [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4). The model version may be changed by the user as described in the next step.
|
||||||
|
|
||||||
|
**Step 4:** Execute the following command from the Terminal to generate Core ML model files (`.mlpackage`)
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python -m python_coreml_stable_diffusion.torch2coreml --convert-unet --convert-text-encoder --convert-vae-decoder --convert-safety-checker -o <output-mlpackages-directory>
|
||||||
|
```
|
||||||
|
|
||||||
|
**WARNING:** This command will download several GB worth of PyTorch checkpoints from Hugging Face.
|
||||||
|
|
||||||
|
This generally takes 15-20 minutes on an M1 MacBook Pro. Upon successful execution, the 4 neural network models that comprise Stable Diffusion will have been converted from PyTorch to Core ML (`.mlpackage`) and saved into the specified `<output-mlpackages-directory>`. Some additional notable arguments:
|
||||||
|
|
||||||
|
- `--model-version`: The model version defaults to [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4). Developers may specify other versions that are available on [Hugging Face Hub](https://huggingface.co/models?search=stable-diffusion), e.g. [stabilityai/stable-diffusion-2-base](https://huggingface.co/stabilityai/stable-diffusion-2-base) & [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
|
||||||
|
|
||||||
|
|
||||||
|
- `--bundle-resources-for-swift-cli`: Compiles all 4 models and bundles them along with necessary resources for text tokenization into `<output-mlpackages-directory>/Resources` which should be provided as input to the Swift package. This flag is not necessary for the diffusers-based Python pipeline.
|
||||||
|
|
||||||
|
- `--chunk-unet`: Splits the Unet model in two approximately equal chunks (each with less than 1GB of weights) for mobile-friendly deployment. This is **required** for ANE deployment on iOS and iPadOS. This is not required for macOS. Swift CLI is able to consume both the chunked and regular versions of the Unet model but prioritizes the former. Note that chunked unet is not compatible with the Python pipeline because Python pipeline is intended for macOS only. Chunking is for on-device deployment with Swift only.
|
||||||
|
|
||||||
|
- `--attention-implementation`: Defaults to `SPLIT_EINSUM` which is the implementation described in [Deploying Transformers on the Apple Neural Engine](https://machinelearning.apple.com/research/neural-engine-transformers). `--attention-implementation ORIGINAL` will switch to an alternative that should be used for non-ANE deployment. Please refer to the [Performance Benchmark](#performance-benchmark) section for further guidance.
|
||||||
|
|
||||||
|
- `--check-output-correctness`: Compares original PyTorch model's outputs to final Core ML model's outputs. This flag increases RAM consumption significantly so it is recommended only for debugging purposes.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## <a name="image-generation-with-python"></a> Image Generation with Python
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> Click to expand </summary>
|
||||||
|
|
||||||
|
Run text-to-image generation using the example Python pipeline based on [diffusers](https://github.com/huggingface/diffusers):
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python -m python_coreml_stable_diffusion.pipeline --prompt "a photo of an astronaut riding a horse on mars" -i <output-mlpackages-directory> -o </path/to/output/image> --compute-unit ALL --seed 93
|
||||||
|
```
|
||||||
|
Please refer to the help menu for all available arguments: `python -m python_coreml_stable_diffusion.pipeline -h`. Some notable arguments:
|
||||||
|
|
||||||
|
- `-i`: Should point to the `-o` directory from Step 4 of [Converting Models to Core ML](#converting-models-to-coreml) section from above.
|
||||||
|
- `--model-version`: If you overrode the default model version while converting models to Core ML, you will need to specify the same model version here.
|
||||||
|
- `--compute-unit`: Note that the most performant compute unit for this particular implementation may differ across different hardware. `CPU_AND_GPU` or `CPU_AND_NE` may be faster than `ALL`. Please refer to the [Performance Benchmark](#performance-benchmark) section for further guidance.
|
||||||
|
- `--scheduler`: If you would like to experiment with different schedulers, you may specify it here. For available options, please see the help menu. You may also specify a custom number of inference steps by `--num-inference-steps` which defaults to 50.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Image Generation with Swift
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> Click to expand </summary>
|
||||||
|
|
||||||
|
### <a name="swift-requirements"></a> System Requirements
|
||||||
|
Building the Swift projects requires:
|
||||||
|
- macOS 13 or newer
|
||||||
|
- Xcode 14.1 or newer with command line tools installed. Please check [developer.apple.com](https://developer.apple.com/download/all/?q=xcode) for the latest version.
|
||||||
|
- Core ML models and tokenization resources. Please see `--bundle-resources-for-swift-cli` from the [Converting Models to Core ML](#converting-models-to-coreml) section above
|
||||||
|
|
||||||
|
If deploying this model to:
|
||||||
|
- iPhone
|
||||||
|
- iOS 16.2 or newer
|
||||||
|
- iPhone 12 or newer
|
||||||
|
- iPad
|
||||||
|
- iPadOS 16.2 or newer
|
||||||
|
- M1 or newer
|
||||||
|
- Mac
|
||||||
|
- macOS 13.1 or newer
|
||||||
|
- M1 or newer
|
||||||
|
|
||||||
|
### Example CLI Usage
|
||||||
|
```shell
|
||||||
|
swift run StableDiffusionSample "a photo of an astronaut riding a horse on mars" --resource-path <output-mlpackages-directory>/Resources/ --seed 93 --output-path </path/to/output/image>
|
||||||
|
```
|
||||||
|
The output will be named based on the prompt and random seed:
|
||||||
|
e.g. `</path/to/output/image>/a_photo_of_an_astronaut_riding_a_horse_on_mars.93.final.png`
|
||||||
|
|
||||||
|
Please use the `--help` flag to learn about batched generation and more.
|
||||||
|
|
||||||
|
### Example Library Usage
|
||||||
|
|
||||||
|
```swift
|
||||||
|
import StableDiffusion
|
||||||
|
...
|
||||||
|
let pipeline = try StableDiffusionPipeline(resourcesAt: resourceURL)
|
||||||
|
let image = try pipeline.generateImages(prompt: prompt, seed: seed).first
|
||||||
|
```
|
||||||
|
|
||||||
|
### Swift Package Details
|
||||||
|
|
||||||
|
This Swift package contains two products:
|
||||||
|
|
||||||
|
- `StableDiffusion` library
|
||||||
|
- `StableDiffusionSample` command-line tool
|
||||||
|
|
||||||
|
Both of these products require the Core ML models and tokenization resources to be supplied. When specifying resources via a directory path that directory must contain the following:
|
||||||
|
|
||||||
|
- `TextEncoder.mlmodelc` (text embedding model)
|
||||||
|
- `Unet.mlmodelc` or `UnetChunk1.mlmodelc` & `UnetChunk2.mlmodelc` (denoising autoencoder model)
|
||||||
|
- `VAEDecoder.mlmodelc` (image decoder model)
|
||||||
|
- `vocab.json` (tokenizer vocabulary file)
|
||||||
|
- `merges.txt` (merges for byte pair encoding file)
|
||||||
|
|
||||||
|
Optionally, it may also include the safety checker model that some versions of Stable Diffusion include:
|
||||||
|
|
||||||
|
- `SafetyChecker.mlmodelc`
|
||||||
|
|
||||||
|
Note that the chunked version of Unet is checked for first. Only if it is not present will the full `Unet.mlmodelc` be loaded. Chunking is required for iOS and iPadOS and not necessary for macOS.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## <a name="performance-benchmark"></a> Performance Benchmark
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> Click to expand </summary>
|
||||||
|
|
||||||
|
Standard [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) Benchmark
|
||||||
|
|
||||||
|
| Device | `--compute-unit`| `--attention-implementation` | Latency (seconds) |
|
||||||
|
| ---------------------------------- | -------------- | ---------------------------- | ----------------- |
|
||||||
|
| Mac Studio (M1 Ultra, 64-core GPU) | `CPU_AND_GPU` | `ORIGINAL` | 9 |
|
||||||
|
| Mac Studio (M1 Ultra, 48-core GPU) | `CPU_AND_GPU` | `ORIGINAL` | 13 |
|
||||||
|
| MacBook Pro (M1 Max, 32-core GPU) | `CPU_AND_GPU` | `ORIGINAL` | 18 |
|
||||||
|
| MacBook Pro (M1 Max, 24-core GPU) | `CPU_AND_GPU` | `ORIGINAL` | 20 |
|
||||||
|
| MacBook Pro (M1 Pro, 16-core GPU) | `ALL` | `SPLIT_EINSUM (default)` | 26 |
|
||||||
|
| MacBook Pro (M2) | `CPU_AND_NE` | `SPLIT_EINSUM (default)` | 23 |
|
||||||
|
| MacBook Pro (M1) | `CPU_AND_NE` | `SPLIT_EINSUM (default)` | 35 |
|
||||||
|
| iPad Pro (5th gen, M1) | `CPU_AND_NE` | `SPLIT_EINSUM (default)` | 38 |
|
||||||
|
|
||||||
|
|
||||||
|
Please see [Important Notes on Performance Benchmarks](#important-notes-on-performance-benchmarks) section for details.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## <a name="important-notes-on-performance-benchmarks"></a> Important Notes on Performance Benchmarks
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> Click to expand </summary>
|
||||||
|
|
||||||
|
- This benchmark was conducted by Apple using public beta versions of iOS 16.2, iPadOS 16.2 and macOS 13.1 in November 2022.
|
||||||
|
- The executed program is `python_coreml_stable_diffusion.pipeline` for macOS devices and a minimal Swift test app built on the `StableDiffusion` Swift package for iOS and iPadOS devices.
|
||||||
|
- The median value across 3 end-to-end executions is reported.
|
||||||
|
- Performance may materially differ across different versions of Stable Diffusion due to architecture changes in the model itself. Each reported number is specific to the model version mentioned in that context.
|
||||||
|
- The image generation procedure follows the standard configuration: 50 inference steps, 512x512 output image resolution, 77 text token sequence length, classifier-free guidance (batch size of 2 for unet).
|
||||||
|
- The actual prompt length does not impact performance because the Core ML model is converted with a static shape that computes the forward pass for all of the 77 elements (`tokenizer.model_max_length`) in the text token sequence regardless of the actual length of the input text.
|
||||||
|
- Pipelining across the 4 models is not optimized and these performance numbers are subject to variance under increased system load from other applications. Given these factors, we do not report sub-second variance in latency.
|
||||||
|
- Weights and activations are in float16 precision for both the GPU and the ANE.
|
||||||
|
- The Swift CLI program consumes a peak memory of approximately 2.6GB (without the safety checker), 2.1GB of which is model weights in float16 precision. We applied [8-bit weight quantization](https://coremltools.readme.io/docs/compressing-ml-program-weights#use-affine-quantization) to reduce peak memory consumption by approximately 1GB. However, we observed that it had an adverse effect on generated image quality and we rolled it back. We encourage developers to experiment with other advanced weight compression techniques such as [palettization](https://coremltools.readme.io/docs/compressing-ml-program-weights#use-a-lookup-table) and/or [pruning](https://coremltools.readme.io/docs/compressing-ml-program-weights#use-sparse-representation) which may yield better results.
|
||||||
|
- In the [benchmark table](#performance-benchmark), we report the best performing `--compute-unit` and `--attention-implementation` values per device. The former does not modify the Core ML model and can be applied during runtime. The latter modifies the Core ML model. Note that the best performing compute unit is model version and hardware-specific.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
## <a name="results-with-different-compute-units"></a> Results with Different Compute Units
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> Click to expand </summary>
|
||||||
|
|
||||||
|
It is highly probable that there will be slight differences across generated images using different compute units.
|
||||||
|
|
||||||
|
The following images were generated on an M1 MacBook Pro and macOS 13.1 with the prompt *"a photo of an astronaut riding a horse on mars"* using the [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) model version. The random seed was set to 93:
|
||||||
|
|
||||||
|
CPU_AND_NE | CPU_AND_GPU | ALL |
|
||||||
|
:------------:|:-------------:|:------:
|
||||||
|
![](assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_93_computeUnit_CPU_AND_NE_modelVersion_runwayml_stable-diffusion-v1-5.png) | ![](assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_93_computeUnit_CPU_AND_GPU_modelVersion_runwayml_stable-diffusion-v1-5.png) | ![](assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_93_computeUnit_ALL_modelVersion_runwayml_stable-diffusion-v1-5.png) |
|
||||||
|
|
||||||
|
Differences may be less or more pronounced for different inputs. Please see the [FAQ](#faq) Q8 for a detailed explanation.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## FAQ
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> Click to expand </summary>
|
||||||
|
<details>
|
||||||
|
|
||||||
|
|
||||||
|
<summary> <b> Q1: </b> <code> ERROR: Failed building wheel for tokenizers or error: can't find Rust compiler </code> </summary>
|
||||||
|
|
||||||
|
<b> A1: </b> Please review this [potential solution](https://github.com/huggingface/transformers/issues/2831#issuecomment-592724471).
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> <b> Q2: </b> <code> RuntimeError: {NSLocalizedDescription = "Error computing NN outputs." </code> </summary>
|
||||||
|
|
||||||
|
<b> A2: </b> There are many potential causes for this error. In this context, it is highly likely to be encountered when your system is under increased memory pressure from other applications. Reducing memory utilization of other applications is likely to help alleviate the issue.
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> <b> Q3: </b> My Mac has 8GB RAM and I am converting models to Core ML using the example command. The process is getting killed because of memory issues. How do I fix this issue? </summary>
|
||||||
|
|
||||||
|
<b> A3: </b> In order to minimize the memory impact of the model conversion process, please execute the following command instead:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m python_coreml_stable_diffusion.torch2coreml --convert-vae-decoder -o <output-mlpackages-directory> && \
|
||||||
|
python -m python_coreml_stable_diffusion.torch2coreml --convert-unet -o <output-mlpackages-directory> && \
|
||||||
|
python -m python_coreml_stable_diffusion.torch2coreml --convert-text-encoder -o <output-mlpackages-directory> && \
|
||||||
|
python -m python_coreml_stable_diffusion.torch2coreml --convert-safety-checker -o <output-mlpackages-directory>
|
||||||
|
```
|
||||||
|
|
||||||
|
If you need `--chunk-unet`, you may do so in yet another independent command which will reuse the previously exported Unet model and simply chunk it in place:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m python_coreml_stable_diffusion.torch2coreml --convert-unet --chunk-unet -o <output-mlpackages-directory>
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> <b> Q4: </b> My Mac has 8GB RAM, should image generation work on my machine? </summary>
|
||||||
|
|
||||||
|
<b> A4: </b> Yes! Especially the `--compute-unit CPU_AND_NE` option should work under reasonable system load from other applications. Note that part of the [Example Results](#example-results) were generated using an M2 MacBook Air with 8GB RAM.
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> <b> Q5: </b> Every time I generate an image using the Python pipeline, loading all the Core ML models takes 2-3 minutes. Is this expected? </summary>
|
||||||
|
|
||||||
|
<b> A5: </b> Yes and using the Swift library reduces this to just a few seconds. The reason is that `coremltools` loads Core ML models (`.mlpackage`) and each model is compiled to be run on the requested compute unit during load time. Because of the size and number of operations of the unet model, it takes around 2-3 minutes to compile it for Neural Engine execution. Other models should take at most a few seconds. Note that `coremltools` does not cache the compiled model for later loads so each load takes equally long. In order to benefit from compilation caching, `StableDiffusion` Swift package by default relies on compiled Core ML models (`.mlmodelc`) which will be compiled down for the requested compute unit upon first load but then the cache will be reused on subsequent loads until it is purged due to lack of use.
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> <b> Q6: </b> I want to deploy <code>StableDiffusion</code>, the Swift package, in my mobile app. What should I be aware of? </summary>
|
||||||
|
|
||||||
|
<b> A6: </b> [This section](#swift-requirements) describes the minimum SDK and OS versions as well as the device models supported by this package. In addition to these requirements, for best practice, we recommend testing the package on the device with the least amount of RAM available among your deployment targets. This is due to the fact that `StableDiffusion` consumes approximately 2.6GB of peak memory during runtime while using `.cpuAndNeuralEngine` (the Swift equivalent of `coremltools.ComputeUnit.CPU_AND_NE`). Other compute units may have a higher peak memory consumption so `.cpuAndNeuralEngine` is recommended for iOS and iPadOS deployment (Please refer to this [section](#swift-requirements) for minimum device model requirements). If your app crashes during image generation, please try adding the [Increased Memory Limit](https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_developer_kernel_increased-memory-limit) capability to your Xcode project which should significantly increase your app's memory limit.
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> <b> Q7: </b> How do I generate images with different resolutions using the same Core ML models? </summary>
|
||||||
|
|
||||||
|
<b> A7: </b> The current version of `python_coreml_stable_diffusion` does not support single-model multi-resolution out of the box. However, developers may fork this project and leverage the [flexible shapes](https://coremltools.readme.io/docs/flexible-inputs) support from coremltools to extend the `torch2coreml` script by using `coremltools.EnumeratedShapes`. Note that, while the `text_encoder` is agnostic to the image resolution, the inputs and outputs of `vae_decoder` and `unet` models are dependent on the desired image resolution.
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> <b> Q8: </b> Are the Core ML and PyTorch generated images going to be identical? </summary>
|
||||||
|
|
||||||
|
<b> A8: </b> If desired, the generated images across PyTorch and Core ML can be made approximately identical. However, it is not guaranteed by default. There are several factors that might lead to different images across PyTorch and Core ML:
|
||||||
|
|
||||||
|
|
||||||
|
<b> 1. Random Number Generator Behavior </b>
|
||||||
|
|
||||||
|
The main source of potentially different results across PyTorch and Core ML is the Random Number Generator ([RNG](https://en.wikipedia.org/wiki/Random_number_generation)) behavior. PyTorch and Numpy have different sources of randomness. `python_coreml_stable_diffusion` generally relies on Numpy for RNG (e.g. latents initialization) and `StableDiffusion` Swift Library reproduces this RNG behavior. However, PyTorch-based pipelines such as Hugging Face `diffusers` rely on PyTorch's RNG behavior.
|
||||||
|
|
||||||
|
<b> 2. PyTorch </b>
|
||||||
|
|
||||||
|
*"Completely reproducible results are not guaranteed across PyTorch releases, individual commits, or different platforms. Furthermore, results may not be reproducible between CPU and GPU executions, even when using identical seeds."* ([source](https://pytorch.org/docs/stable/notes/randomness.html#reproducibility)).
|
||||||
|
|
||||||
|
<b> 3. Model Function Drift During Conversion </b>
|
||||||
|
|
||||||
|
The difference in outputs across corresponding PyTorch and Core ML models is a potential cause. The signal integrity is tested during the conversion process (enabled via `--check-output-correctness` argument to `python_coreml_stable_diffusion.torch2coreml`) and it is verified to be above a minimum [PSNR](https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio) value as tested on random inputs. Note that this is simply a sanity check and does not guarantee this minimum PSNR across all possible inputs. Furthermore, the results are not guaranteed to be identical when executing the same Core ML models across different compute units. This is not expected to be a major source of difference as the sample visual results indicate in [this section](#results-with-different-compute-units).
|
||||||
|
|
||||||
|
<b> 4. Weights and Activations Data Type </b>
|
||||||
|
|
||||||
|
When quantizing models from float32 to lower-precision data types such as float16, the generated images are [known to vary slightly](https://lambdalabs.com/blog/inference-benchmark-stable-diffusion) in semantics even when using the same PyTorch model. Core ML models generated by coremltools have float16 weights and activations by default [unless explicitly overridden](https://github.com/apple/coremltools/blob/main/coremltools/converters/_converters_entry.py#L256). This is not expected to be a major source of difference.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary> <b> Q9: </b> The model files are very large, how do I avoid a large binary for my App? </summary>
|
||||||
|
|
||||||
|
<b> A9: </b> The recommended option is to prompt the user to download these assets upon first launch of the app. This keeps the app binary size independent of the Core ML models being deployed. Disclosing the size of the download to the user is extremely important as there could be data charges or storage impact that the user might not be comfortable with.
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
</details>
|
After Width: | Height: | Size: 395 KiB |
After Width: | Height: | Size: 430 KiB |
After Width: | Height: | Size: 444 KiB |
After Width: | Height: | Size: 428 KiB |
After Width: | Height: | Size: 444 KiB |
After Width: | Height: | Size: 507 KiB |
After Width: | Height: | Size: 520 KiB |
After Width: | Height: | Size: 507 KiB |
After Width: | Height: | Size: 423 KiB |
After Width: | Height: | Size: 427 KiB |
After Width: | Height: | Size: 467 KiB |
After Width: | Height: | Size: 446 KiB |
After Width: | Height: | Size: 468 KiB |
After Width: | Height: | Size: 460 KiB |
After Width: | Height: | Size: 456 KiB |
After Width: | Height: | Size: 461 KiB |
After Width: | Height: | Size: 1.3 MiB |
@ -0,0 +1 @@
|
|||||||
|
from ._version import __version__
|
@ -0,0 +1 @@
|
|||||||
|
__version__ = "0.1.0"
|
@ -0,0 +1,337 @@
|
|||||||
|
#
|
||||||
|
# For licensing see accompanying LICENSE.md file.
|
||||||
|
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import coremltools as ct
|
||||||
|
from coremltools.converters.mil import Block, Program, Var
|
||||||
|
from coremltools.converters.mil.frontend.milproto.load import load as _milproto_to_pymil
|
||||||
|
from coremltools.converters.mil.mil import Builder as mb
|
||||||
|
from coremltools.converters.mil.mil import Placeholder
|
||||||
|
from coremltools.converters.mil.mil import types as types
|
||||||
|
from coremltools.converters.mil.mil.passes.helper import block_context_manager
|
||||||
|
from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY
|
||||||
|
from coremltools.converters.mil.testing_utils import random_gen_input_feature_type
|
||||||
|
|
||||||
|
import gc
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
from python_coreml_stable_diffusion import torch2coreml
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_output_correctness_of_chunks(full_model, first_chunk_model,
                                         second_chunk_model):
    """ Compares the full (original) model against the two chunks run back-to-back

    Random inputs are generated from the full model's input description and fed to
    both the full model and the first chunk. The second chunk is then fed the first
    chunk's outputs (falling back to the random inputs for any input name the first
    chunk did not produce) and each of the full model's outputs is reported against
    the second chunk's output of the same name.
    """
    # Random inputs shared by the full model and the first chunk
    input_dict = {
        desc.name: random_gen_input_feature_type(desc)
        for desc in full_model._spec.description.input
    }

    # Reference outputs, then the intermediate outputs of the first chunk
    outputs_from_full_model = full_model.predict(input_dict)
    outputs_from_first_chunk_model = first_chunk_model.predict(input_dict)

    # Second chunk inputs: prefer the first chunk's outputs, else the regular inputs
    second_chunk_input_dict = {}
    for desc in second_chunk_model._spec.description.input:
        key = desc.name
        source = (outputs_from_first_chunk_model
                  if key in outputs_from_first_chunk_model else input_dict)
        second_chunk_input_dict[key] = source[key]

    outputs_from_second_chunk_model = second_chunk_model.predict(
        second_chunk_input_dict)

    # Report correctness for every output of the full model
    for out_name, reference in outputs_from_full_model.items():
        torch2coreml.report_correctness(
            original_outputs=reference,
            final_outputs=outputs_from_second_chunk_model[out_name],
            log_prefix=f"{out_name}")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_prog_from_mlmodel(model):
    """ Build a MIL Program from an MLModel's spec and weights directory

    The elapsed load time is logged at INFO level.
    """
    spec = model.get_spec()

    began_at = time.time()
    logger.info(
        "Loading MLModel object into a MIL Program object (including the weights).."
    )
    loaded_prog = _milproto_to_pymil(
        model_spec=spec,
        specification_version=spec.specificationVersion,
        file_weights_dir=model.weights_dir,
    )
    logger.info(f"Program loaded in {time.time() - began_at:.1f} seconds")

    return loaded_prog
|
||||||
|
|
||||||
|
|
||||||
|
def _get_op_idx_split_location(prog: Program):
|
||||||
|
""" Find the op that approximately bisects the graph as measure by weights size on each side
|
||||||
|
"""
|
||||||
|
main_block = prog.functions["main"]
|
||||||
|
total_size_in_mb = 0
|
||||||
|
|
||||||
|
for op in main_block.operations:
|
||||||
|
if op.op_type == "const" and isinstance(op.val.val, np.ndarray):
|
||||||
|
size_in_mb = op.val.val.size * op.val.val.itemsize / (1024 * 1024)
|
||||||
|
total_size_in_mb += size_in_mb
|
||||||
|
half_size = total_size_in_mb / 2
|
||||||
|
|
||||||
|
# Find the first non const op (single child), where the total cumulative size exceeds
|
||||||
|
# the half size for the first time
|
||||||
|
cumulative_size_in_mb = 0
|
||||||
|
for op in main_block.operations:
|
||||||
|
if op.op_type == "const" and isinstance(op.val.val, np.ndarray):
|
||||||
|
size_in_mb = op.val.val.size * op.val.val.itemsize / (1024 * 1024)
|
||||||
|
cumulative_size_in_mb += size_in_mb
|
||||||
|
|
||||||
|
if (cumulative_size_in_mb > half_size and op.op_type != "const"
|
||||||
|
and len(op.outputs) == 1
|
||||||
|
and len(op.outputs[0].child_ops) == 1):
|
||||||
|
op_idx = main_block.operations.index(op)
|
||||||
|
return op_idx, cumulative_size_in_mb, total_size_in_mb
|
||||||
|
|
||||||
|
|
||||||
|
def _get_first_chunk_outputs(block, op_idx):
|
||||||
|
# Get the list of all vars that go across from first program (all ops from 0 to op_idx (inclusive))
|
||||||
|
# to the second program (all ops from op_idx+1 till the end). These all vars need to be made the output
|
||||||
|
# of the first program and the input of the second program
|
||||||
|
boundary_vars = set()
|
||||||
|
for i in range(op_idx + 1):
|
||||||
|
op = block.operations[i]
|
||||||
|
for var in op.outputs:
|
||||||
|
if var.val is None: # only consider non const vars
|
||||||
|
for child_op in var.child_ops:
|
||||||
|
child_op_idx = block.operations.index(child_op)
|
||||||
|
if child_op_idx > op_idx:
|
||||||
|
boundary_vars.add(var)
|
||||||
|
return list(boundary_vars)
|
||||||
|
|
||||||
|
|
||||||
|
@block_context_manager
def _add_fp32_casts(block, boundary_vars):
    """ Replace every fp16 boundary var by an fp32 cast of it

    Vars whose dtype is not fp16 are passed through untouched; fp16 vars are cast
    to fp32 with the cast output reusing the var's name. The order of
    `boundary_vars` is preserved.
    """
    # NOTE(review): `block` is only consumed by the `block_context_manager`
    # decorator, which presumably makes it the active block for `mb.cast` - confirm
    return [
        mb.cast(x=var, dtype="fp32", name=var.name)
        if var.dtype == types.fp16 else var for var in boundary_vars
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_first_chunk_prog(prog, op_idx):
    """ Turn `prog` (in-place) into the first chunk

    The vars crossing the split at `op_idx` are declared as the only outputs of
    the main function and the dead code elimination pass is then run on the
    program. Returns the same (modified) `prog` object.
    """
    main_fn = prog.functions["main"]

    # Due to possible numerical issues, any fp16 boundary var is cast to fp32
    chunk_outputs = _add_fp32_casts(
        main_fn, _get_first_chunk_outputs(main_fn, op_idx))

    main_fn.outputs.clear()
    main_fn.set_outputs(chunk_outputs)
    PASS_REGISTRY["common::dead_code_elimination"](prog)
    return prog
|
||||||
|
|
||||||
|
|
||||||
|
def _make_second_chunk_prog(prog, op_idx):
    """ Build second chunk by rebuilding a pristine MIL Program from MLModel

    Args:
        prog: MIL Program to modify in-place. NOTE(review): the caller in this file
            passes a freshly loaded program, not the one already used for the
            first chunk.
        op_idx: Index of the incision op in the main block's operations.

    Returns:
        The same `prog` object, modified so that the boundary vars are inputs.
    """
    block = prog.functions["main"]
    block.opset_version = ct.target.iOS16

    # First chunk outputs are second chunk inputs (e.g. skip connections)
    boundary_vars = _get_first_chunk_outputs(block, op_idx)

    # This op will not be included in this program. Its output var will be made into an input
    boundary_op = block.operations[op_idx]

    # Add all boundary ops as inputs
    with block:
        for var in boundary_vars:
            # fp16 boundary vars are exposed as fp32 inputs, mirroring the fp32
            # casts added to the first chunk's outputs
            new_placeholder = Placeholder(
                sym_shape=var.shape,
                dtype=var.dtype if var.dtype != types.fp16 else types.fp32,
                name=var.name,
            )

            # Register the placeholder's output var as a function input
            block._input_dict[
                new_placeholder.outputs[0].name] = new_placeholder.outputs[0]

            block.function_inputs = tuple(block._input_dict.values())
            new_var = None
            if var.dtype == types.fp16:
                # Cast the fp32 input back to the fp16 dtype of the var it replaces
                # NOTE(review): `before_op=var.op` presumably inserts the cast
                # ahead of the op that originally produced `var` - confirm
                new_var = mb.cast(x=new_placeholder.outputs[0],
                                  dtype="fp16",
                                  before_op=var.op)
            else:
                new_var = new_placeholder.outputs[0]

            # Point consumers located after the incision op at the new input
            block.replace_uses_of_var_after_op(
                anchor_op=boundary_op,
                old_var=var,
                new_var=new_var,
            )

    # NOTE(review): presumably this drops the ops of the first half, which no
    # longer feed any output - confirm
    PASS_REGISTRY["common::dead_code_elimination"](prog)

    # Remove any unused inputs
    new_input_dict = OrderedDict()
    for k, v in block._input_dict.items():
        if len(v.child_ops) > 0:
            new_input_dict[k] = v
    block._input_dict = new_input_dict
    block.function_inputs = tuple(block._input_dict.values())

    return prog
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """ Split an .mlpackage into two chunks of approximately equal weights size

    Args:
        args: Parsed command line arguments providing `mlpackage_path`, `o`,
            `remove_original` and `check_output_correctness`.

    The chunks are saved in `args.o` as `<name>_chunk1.mlpackage` and
    `<name>_chunk2.mlpackage`. If requested, the original model is removed only
    after both chunks were saved so that a failed save does not lose the model.
    """
    os.makedirs(args.o, exist_ok=True)

    # Check filename extension (normpath drops a trailing separator, which is
    # common since an .mlpackage is a directory)
    mlpackage_name = os.path.basename(os.path.normpath(args.mlpackage_path))
    name, ext = os.path.splitext(mlpackage_name)
    assert ext == ".mlpackage", f"`--mlpackage-path` ({args.mlpackage_path}) is not an .mlpackage file"

    # Load CoreML model
    logger.info("Loading model from {}".format(args.mlpackage_path))
    start_ = time.time()
    model = ct.models.MLModel(
        args.mlpackage_path,
        compute_units=ct.ComputeUnit.CPU_ONLY,
    )
    logger.info(
        f"Loading {args.mlpackage_path} took {time.time() - start_:.1f} seconds"
    )

    # Load the MIL Program from MLModel
    prog = _load_prog_from_mlmodel(model)

    # Compute the incision point by bisecting the program based on weights size
    op_idx, first_chunk_weights_size, total_weights_size = _get_op_idx_split_location(
        prog)
    main_block = prog.functions["main"]
    incision_op = main_block.operations[op_idx]
    logger.info(f"{args.mlpackage_path} will be chunked into two pieces.")
    logger.info(
        f"The incision op: name={incision_op.name}, type={incision_op.op_type}, index={op_idx}/{len(main_block.operations)}"
    )
    logger.info(f"First chunk size = {first_chunk_weights_size:.2f} MB")
    logger.info(
        f"Second chunk size = {total_weights_size - first_chunk_weights_size:.2f} MB"
    )

    # Build first chunk (in-place modifies prog by declaring early exits and removing unused subgraph)
    prog_chunk1 = _make_first_chunk_prog(prog, op_idx)

    # Build the second chunk (from a freshly loaded program since `prog` was modified above)
    prog_chunk2 = _make_second_chunk_prog(_load_prog_from_mlmodel(model),
                                          op_idx)

    if not args.check_output_correctness:
        # Original model no longer needed in memory
        del model
        gc.collect()

    # Convert the MIL Program objects into MLModels
    logger.info("Converting the two programs")
    model_chunk1 = ct.convert(
        prog_chunk1,
        convert_to="mlprogram",
        compute_units=ct.ComputeUnit.CPU_ONLY,
        minimum_deployment_target=ct.target.iOS16,
    )
    del prog_chunk1
    gc.collect()
    logger.info("Conversion of first chunk done.")

    model_chunk2 = ct.convert(
        prog_chunk2,
        convert_to="mlprogram",
        compute_units=ct.ComputeUnit.CPU_ONLY,
        minimum_deployment_target=ct.target.iOS16,
    )
    del prog_chunk2
    gc.collect()
    logger.info("Conversion of second chunk done.")

    # Verify output correctness
    if args.check_output_correctness:
        logger.info("Verifying output correctness of chunks")
        _verify_output_correctness_of_chunks(
            full_model=model,
            first_chunk_model=model_chunk1,
            second_chunk_model=model_chunk2,
        )

    # Save the chunked models to disk
    out_path_chunk1 = os.path.join(args.o, name + "_chunk1.mlpackage")
    out_path_chunk2 = os.path.join(args.o, name + "_chunk2.mlpackage")

    model_chunk1.save(out_path_chunk1)
    model_chunk2.save(out_path_chunk2)
    logger.info(
        f"Saved chunks in {args.o} with the suffix _chunk1.mlpackage and _chunk2.mlpackage"
    )

    # Remove original (non-chunked) model if requested, now that the chunks are on disk
    if args.remove_original:
        logger.info(
            f"Removing original (non-chunked) model at {args.mlpackage_path}")
        shutil.rmtree(args.mlpackage_path)

    logger.info("Done.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command line entry point: parse the arguments and hand them to `main`
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mlpackage-path",
        required=True,
        help=
        "Path to the mlpackage file to be split into two mlpackages of approximately same file size.",
    )
    parser.add_argument(
        "-o",
        required=True,
        help=
        "Path to output directory where the two model chunks should be saved.",
    )
    parser.add_argument(
        "--remove-original",
        action="store_true",
        help=
        "If specified, removes the original (non-chunked) model to avoid duplicating storage."
    )
    parser.add_argument(
        "--check-output-correctness",
        action="store_true",
        # Adjacent string literals (no comma in between) so that `help` is a
        # single string; a tuple here makes argparse fail when rendering --help
        help=
        ("If specified, compares the outputs of original Core ML model with that of pipelined CoreML model chunks and reports PSNR in dB. "
         "Enabling this feature uses more memory. Disable it if your machine runs out of memory."
         ))

    args = parser.parse_args()
    main(args)
|
@ -0,0 +1,102 @@
|
|||||||
|
#
|
||||||
|
# For licensing see accompanying LICENSE.md file.
|
||||||
|
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
|
||||||
|
import coremltools as ct
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class CoreMLModel:
    """ Wrapper for running CoreML models using coremltools

    Loads an .mlpackage for the requested compute unit, records the name, shape
    and dtype of every model input and validates keyword inputs against that
    record before each prediction.
    """

    def __init__(self, model_path, compute_unit):
        """
        Args:
            model_path: Path to an existing `.mlpackage`
            compute_unit: Name of a `coremltools.ComputeUnit` member
        """
        assert os.path.exists(model_path) and model_path.endswith(".mlpackage")

        logger.info(f"Loading {model_path}")

        start = time.time()
        self.model = ct.models.MLModel(
            model_path, compute_units=ct.ComputeUnit[compute_unit])
        load_time = time.time() - start
        logger.info(f"Done. Took {load_time:.1f} seconds.")

        if load_time > LOAD_TIME_INFO_MSG_TRIGGER:
            logger.info(
                "Loading a CoreML model through coremltools triggers compilation every time. "
                "The Swift package we provide uses precompiled Core ML models (.mlmodelc) to avoid compile-on-load."
            )

        # Maps the `multiArrayType.dataType` value found in the model spec to
        # the numpy dtype expected for the corresponding input
        DTYPE_MAP = {
            65552: np.float16,
            65568: np.float32,
            131104: np.int32,
        }

        self.expected_inputs = {
            input_tensor.name: {
                "shape": tuple(input_tensor.type.multiArrayType.shape),
                "dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType],
            }
            for input_tensor in self.model._spec.description.input
        }

    def _verify_inputs(self, **kwargs):
        """ Validate keyword inputs against `self.expected_inputs`

        Raises:
            ValueError: If a kwarg does not name a model input
            TypeError: If a value is not a numpy.ndarray or its dtype or shape
                differs from the expected one
        """
        for k, v in kwargs.items():
            if k not in self.expected_inputs:
                raise ValueError(f"Received unexpected input kwarg: {k}")

            if not isinstance(v, np.ndarray):
                raise TypeError(
                    f"Expected numpy.ndarray, got {v} for input: {k}")

            expected_dtype = self.expected_inputs[k]["dtype"]
            if not v.dtype == expected_dtype:
                raise TypeError(
                    f"Expected dtype {expected_dtype}, got {v.dtype} for input: {k}"
                )

            expected_shape = self.expected_inputs[k]["shape"]
            if not v.shape == expected_shape:
                raise TypeError(
                    f"Expected shape {expected_shape}, got {v.shape} for input: {k}"
                )

    def __call__(self, **kwargs):
        """ Validate the inputs and return the model's prediction for them
        """
        self._verify_inputs(**kwargs)
        return self.model.predict(kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Model load durations above this threshold make `CoreMLModel.__init__` log a
# hint about compile-on-load
LOAD_TIME_INFO_MSG_TRIGGER = 10 # seconds
|
||||||
|
|
||||||
|
|
||||||
|
def _load_mlpackage(submodule_name, mlpackages_dir, model_version,
                    compute_unit):
    """ Load one Core ML (mlpackage) submodule from disk (As exported by torch2coreml.py)

    The file name is derived from `model_version` and `submodule_name`, with any
    "/" replaced by "_". Raises FileNotFoundError if the resulting path does not
    exist in `mlpackages_dir`.
    """
    logger.info(f"Loading {submodule_name} mlpackage")

    raw_name = f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage"
    mlpackage_path = os.path.join(mlpackages_dir, raw_name.replace("/", "_"))

    if os.path.exists(mlpackage_path):
        return CoreMLModel(mlpackage_path, compute_unit)

    raise FileNotFoundError(
        f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}")
|
||||||
|
|
||||||
|
def get_available_compute_units():
    """ Return the member names of `coremltools.ComputeUnit` as a tuple
    """
    return tuple(ct.ComputeUnit._member_names_)
|
@ -0,0 +1,80 @@
|
|||||||
|
#
|
||||||
|
# For licensing see accompanying LICENSE.md file.
|
||||||
|
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
# Reference: https://github.com/apple/ml-ane-transformers/blob/main/ane_transformers/reference/layer_norm.py
|
||||||
|
class LayerNormANE(nn.Module):
    """ LayerNorm optimized for Apple Neural Engine (ANE) execution

    Note: This layer only supports normalization over the final dim. It expects `num_channels`
    as an argument and not `normalized_shape` which is used by `torch.nn.LayerNorm`.
    """

    def __init__(self,
                 num_channels,
                 clip_mag=None,
                 eps=1e-5,
                 elementwise_affine=True):
        """
        Args:
            num_channels: Number of channels (C) where the expected input data format is BC1S. S stands for sequence length.
            clip_mag: Optional float value to use for clamping the input range before layer norm is applied.
                If specified, helps reduce risk of overflow.
            eps: Small value to avoid dividing by zero
            elementwise_affine: If true, adds learnable channel-wise shift (bias) and scale (weight) parameters
        """
        super().__init__()
        # Principle 1: Picking the Right Data Format (machinelearning.apple.com/research/apple-neural-engine)
        self.expected_rank = len("BC1S")

        self.num_channels = num_channels
        self.eps = eps
        self.clip_mag = clip_mag
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.Tensor(num_channels))
            self.bias = nn.Parameter(torch.Tensor(num_channels))

        self._reset_parameters()

    def _reset_parameters(self):
        """ Initialize weight to ones and bias to zeros (if present)
        """
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, inputs):
        """ Normalize `inputs` over the channel dim

        Args:
            inputs: Tensor in BC1S format, or a rank-3 BSC tensor whose last dim
                equals `num_channels` (migrated to BC1S internally)

        Returns:
            Normalized tensor in BC1S format. `inputs` itself is left unmodified.
        """
        input_rank = len(inputs.size())

        # Principle 1: Picking the Right Data Format (machinelearning.apple.com/research/apple-neural-engine)
        # Migrate the data format from BSC to BC1S (most conducive to ANE)
        if input_rank == 3 and inputs.size(2) == self.num_channels:
            inputs = inputs.transpose(1, 2).unsqueeze(2)
            input_rank = len(inputs.size())

        assert input_rank == self.expected_rank
        assert inputs.size(1) == self.num_channels

        if self.clip_mag is not None:
            # Out-of-place clamp: an in-place `clamp_` would overwrite the caller's
            # tensor (also through the transposed view above) and fails on leaf
            # tensors that require grad
            inputs = inputs.clamp(-self.clip_mag, self.clip_mag)

        channels_mean = inputs.mean(dim=1, keepdims=True)

        zero_mean = inputs - channels_mean

        zero_mean_sq = zero_mean * zero_mean

        denom = (zero_mean_sq.mean(dim=1, keepdims=True) + self.eps).rsqrt()

        out = zero_mean * denom

        if self.elementwise_affine:
            # Note the order: bias is added before scaling by weight
            out = (out + self.bias.view(1, self.num_channels, 1, 1)
                   ) * self.weight.view(1, self.num_channels, 1, 1)

        return out
|
@ -0,0 +1,534 @@
|
|||||||
|
#
|
||||||
|
# For licensing see accompanying LICENSE.md file.
|
||||||
|
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from diffusers.pipeline_utils import DiffusionPipeline
|
||||||
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||||
|
from diffusers.schedulers import (
|
||||||
|
DDIMScheduler,
|
||||||
|
DPMSolverMultistepScheduler,
|
||||||
|
EulerAncestralDiscreteScheduler,
|
||||||
|
EulerDiscreteScheduler,
|
||||||
|
LMSDiscreteScheduler,
|
||||||
|
PNDMScheduler,
|
||||||
|
)
|
||||||
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
from python_coreml_stable_diffusion.coreml_model import (
|
||||||
|
CoreMLModel,
|
||||||
|
_load_mlpackage,
|
||||||
|
get_available_compute_units,
|
||||||
|
)
|
||||||
|
|
||||||
|
import time
|
||||||
|
import torch # Only used for `torch.from_tensor` in `pipe.scheduler.step()`
|
||||||
|
from transformers import CLIPFeatureExtractor, CLIPTokenizer
|
||||||
|
from typing import Union, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class CoreMLStableDiffusionPipeline(DiffusionPipeline):
    """ Core ML version of
    `diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline`

    The text encoder, UNet, VAE decoder and (optional) safety checker are
    `CoreMLModel` objects that are called with NumPy keyword inputs and return
    a dict of outputs. The tokenizer, scheduler and feature extractor are
    registered through `register_modules`.
    """

    def __init__(
        self,
        text_encoder: CoreMLModel,
        unet: CoreMLModel,
        vae_decoder: CoreMLModel,
        feature_extractor: CLIPFeatureExtractor,
        safety_checker: Optional[CoreMLModel],
        scheduler: Union[DDIMScheduler,
                         DPMSolverMultistepScheduler,
                         EulerAncestralDiscreteScheduler,
                         EulerDiscreteScheduler,
                         LMSDiscreteScheduler,
                         PNDMScheduler],
        tokenizer: CLIPTokenizer,
    ):
        """ Store the pipeline components and derive the output image size.

        `safety_checker` may be `None`, in which case a warning is logged and
        `run_safety_checker` becomes a pass-through.
        """
        super().__init__()

        # Register non-Core ML components of the pipeline similar to the original pipeline
        self.register_modules(
            tokenizer=tokenizer,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        if safety_checker is None:
            # Reproduce original warning:
            # https://github.com/huggingface/diffusers/blob/v0.9.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L119
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # Register Core ML components of the pipeline
        self.safety_checker = safety_checker
        self.text_encoder = text_encoder
        self.unet = unet
        # Number of latent channels is read from the UNet's declared "sample" input
        self.unet.in_channels = self.unet.expected_inputs["sample"]["shape"][1]

        self.vae_decoder = vae_decoder

        VAE_DECODER_UPSAMPLE_FACTOR = 8

        # In PyTorch, users can determine the tensor shapes dynamically by default
        # In CoreML, tensors have static shapes unless flexible shapes were used during export
        # See https://coremltools.readme.io/docs/flexible-inputs
        latent_h, latent_w = self.unet.expected_inputs["sample"]["shape"][2:]
        self.height = latent_h * VAE_DECODER_UPSAMPLE_FACTOR
        self.width = latent_w * VAE_DECODER_UPSAMPLE_FACTOR

        logger.info(
            f"Stable Diffusion configured to generate {self.height}x{self.width} images"
        )

    def _encode_prompt(self, prompt, num_images_per_prompt,
                       do_classifier_free_guidance, negative_prompt):
        """ Tokenize `prompt` and run the Core ML text encoder on it.

        When `do_classifier_free_guidance` is set, embeddings for the
        unconditional input (`negative_prompt`, or empty strings when it is
        `None`) are computed as well and concatenated in front of the text
        embeddings. `num_images_per_prompt` is accepted but not used here.

        Returns:
            The embeddings after `transpose(0, 2, 1)` with a new axis inserted
            at position 2.

        Raises:
            TypeError: `negative_prompt` and `prompt` have different types.
            ValueError: `negative_prompt` is a list whose length differs from
                the prompt batch size.
        """
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="np",
        )
        text_input_ids = text_inputs.input_ids

        # The tokenizer call above does not truncate, so cut overlong input here
        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(
                text_input_ids[:, self.tokenizer.model_max_length:])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}")
            text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]

        text_embeddings = self.text_encoder(
            input_ids=text_input_ids.astype(np.float32))["last_hidden_state"]

        if do_classifier_free_guidance:
            # `uncond_tokens` ends up as a list of strings, one per prompt
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}.")
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt] * batch_size
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`.")
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="np",
            )

            uncond_embeddings = self.text_encoder(
                input_ids=uncond_input.input_ids.astype(
                    np.float32))["last_hidden_state"]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = np.concatenate(
                [uncond_embeddings, text_embeddings])

        text_embeddings = text_embeddings.transpose(0, 2, 1)[:, :, None, :]

        return text_embeddings

    def run_safety_checker(self, image):
        """ Run the optional Core ML safety checker on `image`.

        Returns:
            `(image, has_nsfw_concept)`. With a safety checker, `image` is the
            model's "filtered_images" output and `has_nsfw_concept` its
            "has_nsfw_concepts" output; without one, `image` is returned
            unchanged and `has_nsfw_concept` is `None`.
        """
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(image),
                return_tensors="np",
            )

            safety_checker_outputs = self.safety_checker(
                clip_input=safety_checker_input.pixel_values.astype(
                    np.float16),
                images=image.astype(np.float16),
                adjustment=np.array([0.]).astype(
                    np.float16),  # defaults to 0 in original pipeline
            )

            # Unpack dict
            has_nsfw_concept = safety_checker_outputs["has_nsfw_concepts"]
            image = safety_checker_outputs["filtered_images"]

            logger.info(
                f"Generated image has nsfw concept={has_nsfw_concept.any()}")
        else:
            has_nsfw_concept = None

        return image, has_nsfw_concept

    def decode_latents(self, latents):
        """ Decode `latents` with the Core ML VAE decoder.

        The latents are divided by 0.18215 before decoding; the decoder's
        "image" output is mapped with `x / 2 + 0.5`, clipped to [0, 1] and its
        axes are reordered with `transpose((0, 2, 3, 1))`.
        """
        latents = 1 / 0.18215 * latents
        image = self.vae_decoder(z=latents.astype(np.float16))["image"]
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))

        return image

    def prepare_latents(self,
                        batch_size,
                        num_channels_latents,
                        height,
                        width,
                        latents=None):
        """ Create (or validate) the initial latents and scale them for the scheduler.

        The `height` and `width` arguments are not used: the latent shape is
        derived from `self.height` and `self.width`, which are fixed by the
        Core ML UNet in `__init__`.

        Raises:
            ValueError: user-provided `latents` do not have the expected shape.
        """
        latents_shape = (batch_size, num_channels_latents, self.height // 8,
                         self.width // 8)
        if latents is None:
            latents = np.random.randn(*latents_shape).astype(np.float16)
        elif latents.shape != latents_shape:
            raise ValueError(
                f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}"
            )

        latents = latents * self.scheduler.init_noise_sigma

        return latents

    def check_inputs(self, prompt, height, width, callback_steps):
        """ Validate the arguments of `__call__`.

        A mismatch between the requested size and the size the Core ML models
        generate only logs a warning; the remaining checks raise.

        Raises:
            ValueError: `prompt` is neither `str` nor `list`, `height`/`width`
                are not divisible by 8, or `callback_steps` is not a positive
                integer.
        """
        if height != self.height or width != self.width:
            logger.warning(
                "`height` and `width` dimensions (of the output image tensor) are fixed when exporting the Core ML models "
                "unless flexible shapes are used during export (https://coremltools.readme.io/docs/flexible-inputs). "
                f"This pipeline was provided with Core ML models that generate {self.height}x{self.width} images (user requested {height}x{width})"
            )

        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        # `None` is not an `int`, so it is rejected by the isinstance check as well
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}.")

    def prepare_extra_step_kwargs(self, eta):
        """ Return the extra keyword arguments accepted by `self.scheduler.step`.

        Currently this is `{"eta": eta}` when the scheduler's `step` signature
        has an `eta` parameter and `{}` otherwise.
        """
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in inspect.signature(self.scheduler.step).parameters
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        return extra_step_kwargs

    def __call__(
        self,
        prompt,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        negative_prompt=None,
        num_images_per_prompt=1,
        eta=0.0,
        latents=None,
        output_type="pil",
        return_dict=True,
        callback=None,
        callback_steps=1,
        **kwargs,
    ):
        """ Generate an image for `prompt`.

        Only a single prompt and a single image per prompt are supported.
        Extra `**kwargs` are accepted and ignored.

        Returns:
            `StableDiffusionPipelineOutput(images=..., nsfw_content_detected=...)`,
            or the tuple `(image, has_nsfw_concept)` when `return_dict` is false.
            Images are converted with `numpy_to_pil` when `output_type == "pil"`.

        Raises:
            ValueError: see `check_inputs`.
            NotImplementedError: more than one prompt or image per prompt was requested.
        """
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        if batch_size > 1 or num_images_per_prompt > 1:
            raise NotImplementedError(
                "For batched generation of multiple images and/or multiple prompts, please refer to the Swift package."
            )

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(eta)

        # 7. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate(
                [latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t)

            # predict the noise residual
            # One timestep entry per sample in the batch (two with guidance enabled)
            noise_pred = self.unet(
                sample=latent_model_input.astype(np.float16),
                timestep=np.array([t] * latent_model_input.shape[0],
                                  np.float16),
                encoder_hidden_states=text_embeddings.astype(np.float16),
            )["noise_pred"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(torch.from_numpy(noise_pred),
                                          t,
                                          torch.from_numpy(latents),
                                          **extra_step_kwargs,
                                          ).prev_sample.numpy()

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Run safety checker
        image, has_nsfw_concept = self.run_safety_checker(image)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept)
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_schedulers():
    """ Map short scheduler names to the supported diffusers scheduler classes.

    The key is the class name with "Scheduler" removed
    (e.g. `DDIMScheduler` is registered as "DDIM").
    """
    scheduler_classes = (
        DDIMScheduler,
        DPMSolverMultistepScheduler,
        EulerAncestralDiscreteScheduler,
        EulerDiscreteScheduler,
        LMSDiscreteScheduler,
        PNDMScheduler,
    )
    # Read the name from the class itself; there is no need to construct a
    # scheduler instance just to look up its class name.
    return {
        scheduler.__name__.replace("Scheduler", ""): scheduler
        for scheduler in scheduler_classes
    }
|
||||||
|
|
||||||
|
# Short scheduler name -> scheduler class, built once when the module is loaded
SCHEDULER_MAP = get_available_schedulers()
|
||||||
|
|
||||||
|
def get_coreml_pipe(pytorch_pipe,
                    mlpackages_dir,
                    model_version,
                    compute_unit,
                    delete_original_pipe=True,
                    scheduler_override=None):
    """ Build a `CoreMLStableDiffusionPipeline` that mirrors a diffusers PyTorch pipeline.

    The tokenizer, feature extractor and scheduler (unless `scheduler_override`
    is given) are taken from `pytorch_pipe`; the text encoder, UNet, VAE decoder
    and, if the PyTorch pipeline has one, the safety checker are loaded from
    `mlpackages_dir` with `_load_mlpackage`.

    NOTE(review): `del pytorch_pipe` below only drops this function's own
    reference; memory is presumably only reclaimed if the caller does not keep
    the pipeline alive as well - confirm against callers.
    """
    # Ensure `scheduler_override` object is of correct type if specified
    if scheduler_override is not None:
        assert isinstance(scheduler_override, SchedulerMixin)
        logger.warning(
            "Overriding scheduler in pipeline: "
            f"Default={pytorch_pipe.scheduler}, Override={scheduler_override}")

    # Gather configured tokenizer and scheduler attributes from the original pipe
    if scheduler_override is None:
        selected_scheduler = pytorch_pipe.scheduler
    else:
        selected_scheduler = scheduler_override

    pipe_components = dict(
        tokenizer=pytorch_pipe.tokenizer,
        scheduler=selected_scheduler,
        feature_extractor=pytorch_pipe.feature_extractor,
    )

    coreml_model_names = ["text_encoder", "unet", "vae_decoder"]
    has_safety_checker = getattr(pytorch_pipe, "safety_checker", None) is not None
    if has_safety_checker:
        coreml_model_names.append("safety_checker")
    else:
        logger.warning(
            f"Original diffusers pipeline for {model_version} does not have a safety_checker, "
            "Core ML pipeline will mirror this behavior.")
        pipe_components["safety_checker"] = None

    if delete_original_pipe:
        del pytorch_pipe
        gc.collect()
        logger.info("Removed PyTorch pipe to reduce peak memory consumption")

    # Load Core ML models
    logger.info(f"Loading Core ML models in memory from {mlpackages_dir}")
    for name in coreml_model_names:
        pipe_components[name] = _load_mlpackage(
            name,
            mlpackages_dir,
            model_version,
            compute_unit,
        )
    logger.info("Done.")

    logger.info("Initializing Core ML pipe for image generation")
    pipe = CoreMLStableDiffusionPipeline(**pipe_components)
    logger.info("Done.")

    return pipe
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_path(args, **override_kwargs):
    """ mkdir output folder and encode metadata in the filename

    The folder is `args.o/<prompt>` where "/" and " " in the prompt are
    replaced by "_". The file name encodes seed, compute unit, model version,
    (optionally) scheduler and number of inference steps.

    Each of `seed`, `compute_unit`, `model_version`, `scheduler` and
    `num_inference_steps` may be overridden through `override_kwargs`; an
    override of `None` (or a missing one) falls back to the value in `args`.
    The scheduler is only encoded when `args.scheduler` is not `None`.

    Returns:
        Path of the `.png` file inside the (created) output folder.
    """
    def _resolve(name):
        # Compare against None so that valid falsy overrides (e.g. seed=0) are honored
        override = override_kwargs.get(name)
        return getattr(args, name) if override is None else override

    prompt_folder = "_".join(args.prompt.replace("/", "_").rsplit(" "))
    out_folder = os.path.join(args.o, prompt_folder)
    os.makedirs(out_folder, exist_ok=True)

    # Sanitize the model version regardless of where it came from, so that the
    # file name never contains a path separator
    model_version = str(_resolve("model_version")).replace("/", "_")

    out_fname = f"randomSeed_{_resolve('seed')}"
    out_fname += f"_computeUnit_{_resolve('compute_unit')}"
    out_fname += f"_modelVersion_{model_version}"

    if args.scheduler is not None:
        out_fname += f"_customScheduler_{_resolve('scheduler')}"
    out_fname += f"_numInferenceSteps{_resolve('num_inference_steps')}"

    return os.path.join(out_folder, out_fname + ".png")
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
    """ Generate one image for `args.prompt` with the Core ML pipeline and save it.

    Seeds NumPy, loads the reference diffusers pipeline for
    `args.model_version`, builds the Core ML pipeline from the models in
    `args.i` and writes the first generated image to the path returned by
    `get_image_path(args)`.
    """
    logger.info(f"Setting random seed to {args.seed}")
    np.random.seed(args.seed)

    logger.info("Initializing PyTorch pipe for reference configuration")
    from diffusers import StableDiffusionPipeline
    pytorch_pipe = StableDiffusionPipeline.from_pretrained(
        args.model_version, use_auth_token=True)

    # Optionally swap the default scheduler for the one selected by the user,
    # configured like the default one
    if args.scheduler is None:
        user_specified_scheduler = None
    else:
        scheduler_cls = SCHEDULER_MAP[args.scheduler]
        user_specified_scheduler = scheduler_cls.from_config(
            pytorch_pipe.scheduler.config)

    coreml_pipe = get_coreml_pipe(
        pytorch_pipe=pytorch_pipe,
        mlpackages_dir=args.i,
        model_version=args.model_version,
        compute_unit=args.compute_unit,
        scheduler_override=user_specified_scheduler,
    )

    logger.info("Beginning image generation.")
    generation_kwargs = dict(
        prompt=args.prompt,
        height=coreml_pipe.height,
        width=coreml_pipe.width,
        num_inference_steps=args.num_inference_steps,
    )
    result = coreml_pipe(**generation_kwargs)

    out_path = get_image_path(args)
    logger.info(f"Saving generated image to {out_path}")
    result["images"][0].save(out_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Command-line entry point: parse the generation options and run `main`.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        required=True,
        help="The text prompt to be used for text-to-image generation.")
    parser.add_argument(
        "-i",
        required=True,
        help=("Path to input directory with the .mlpackage files generated by "
              "python_coreml_stable_diffusion.torch2coreml"))
    parser.add_argument(
        "-o",
        required=True,
        help=("Path to output directory; the generated image is saved in a "
              "subfolder named after the prompt"))
    parser.add_argument("--seed",
                        "-s",
                        default=93,
                        type=int,
                        help="Random seed to be able to reproduce results")
    parser.add_argument(
        "--model-version",
        default="CompVis/stable-diffusion-v1-4",
        help=
        ("The pre-trained model checkpoint and configuration to restore. "
         "For available versions: https://huggingface.co/models?search=stable-diffusion"
         ))

    # Query the available compute units once and reuse them for choices and help
    compute_units = get_available_compute_units()
    parser.add_argument(
        "--compute-unit",
        choices=compute_units,
        default="ALL",
        help=("The compute units to be used when executing Core ML models. "
              f"Options: {compute_units}"))
    parser.add_argument(
        "--scheduler",
        choices=tuple(SCHEDULER_MAP.keys()),
        default=None,
        help=("The scheduler to use for running the reverse diffusion process. "
              "If not specified, the default scheduler from the diffusers pipeline is utilized"))
    parser.add_argument(
        "--num-inference-steps",
        default=50,
        type=int,
        help="The number of iterations the unet model will be executed throughout the reverse diffusion process")

    args = parser.parse_args()
    main(args)
|
@ -0,0 +1,5 @@
|
|||||||
|
coremltools
|
||||||
|
diffusers[torch]
|
||||||
|
torch
|
||||||
|
transformers
|
||||||
|
scipy
|
@ -0,0 +1,36 @@
|
|||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
from python_coreml_stable_diffusion._version import __version__
|
||||||
|
|
||||||
|
# The README doubles as the long description of the package.
# Read it with an explicit encoding so the result does not depend on the
# locale of the machine running the build.
with open('README.md', encoding='utf-8') as f:
    readme = f.read()

setup(
    name='python_coreml_stable_diffusion',
    version=__version__,
    url='https://github.com/apple/ml-stable-diffusion',
    description="Run Stable Diffusion on Apple Silicon with Core ML (Python and Swift)",
    long_description=readme,
    long_description_content_type='text/markdown',
    author='Apple Inc.',
    install_requires=[
        "coremltools>=6.1",
        "diffusers[torch]",
        "torch",
        "transformers",
        "scipy",
    ],
    packages=find_packages(),
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Operating System :: MacOS :: MacOS X",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Topic :: Artificial Intelligence",
        "Topic :: Scientific/Engineering",
        "Topic :: Software Development",
    ],
)
|
@ -0,0 +1,109 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import CoreML
|
||||||
|
import Accelerate
|
||||||
|
|
||||||
|
/// A decoder model which produces RGB images from latent samples
public struct Decoder {

    /// VAE decoder model
    var model: MLModel

    /// Create decoder from Core ML model
    ///
    /// - Parameters
    ///     - model: Core ML model for VAE decoder
    public init(model: MLModel) {
        self.model = model
    }

    /// Queue on which `decode` runs its batch prediction synchronously
    let queue = DispatchQueue(label: "decoder.predict")

    /// Batch decode latent samples into images
    ///
    ///  - Parameters:
    ///    - latents: Batch of latent samples to decode
    ///  - Returns: decoded images, one per latent sample
    public func decode(_ latents: [MLShapedArray<Float32>]) throws -> [CGImage] {

        // Wrap every latent sample in a feature provider keyed by the model's input name
        let providers: [MLFeatureProvider] = try latents.map { latent in
            // Reference pipeline scales the latent samples before decoding
            let scaled = MLShapedArray<Float32>(
                scalars: latent.scalars.map { $0 / 0.18215 },
                shape: latent.shape)

            let features = [inputName: MLMultiArray(scaled)]
            return try MLDictionaryFeatureProvider(dictionary: features)
        }
        let inputBatch = MLArrayBatchProvider(array: providers)

        // Batch predict with model
        let predictions = try queue.sync { try model.predictions(fromBatch: inputBatch) }

        // Transform the first output feature of every prediction into a CGImage
        let decoded: [CGImage] = (0..<predictions.count).map { index in
            let features = predictions.features(at: index)
            let firstOutputName = features.featureNames.first!
            let multiArray = features.featureValue(for: firstOutputName)!.multiArrayValue!

            return toRGBCGImage(MLShapedArray<Float32>(multiArray))
        }

        return decoded
    }

    /// Name of the first input in the model description
    var inputName: String {
        model.modelDescription.inputDescriptionsByName.first!.key
    }

    typealias PixelBufferPFx1 = vImage.PixelBuffer<vImage.PlanarF>
    typealias PixelBufferP8x3 = vImage.PixelBuffer<vImage.Planar8x3>
    typealias PixelBufferIFx3 = vImage.PixelBuffer<vImage.InterleavedFx3>
    typealias PixelBufferI8x3 = vImage.PixelBuffer<vImage.Interleaved8x3>

    /// Convert the first image of a decoder output into an 8-bit RGB CGImage
    ///
    ///  - Parameters:
    ///    - array: Decoder output indexed as [N,C,H,W], where C==3
    ///  - Returns: RGB image without alpha
    func toRGBCGImage(_ array: MLShapedArray<Float32>) -> CGImage {

        // array is [N,C,H,W], where C==3
        let channelCount = array.shape[1]
        assert(channelCount == 3,
               "Decoding model output has \(channelCount) channels, expected 3")
        let imageHeight = array.shape[2]
        let imageWidth = array.shape[3]

        // Normalize each channel into a float between 0 and 1.0
        let normalizedPlanes = (0..<channelCount).map { channel in

            // Normalized channel output
            let plane = PixelBufferPFx1(width: imageWidth, height: imageHeight)

            // Reference this channel in the array and normalize
            array[0][channel].withUnsafeShapedBufferPointer { pointer, _, strides in
                let source = PixelBufferPFx1(data: .init(mutating: pointer.baseAddress!),
                                             width: imageWidth, height: imageHeight,
                                             byteCountPerRow: strides[0]*4)
                // Map [-1.0 1.0] -> [0.0 1.0]
                source.multiply(by: 0.5, preBias: 1.0, postBias: 0.0, destination: plane)
            }
            return plane
        }

        // Convert to interleaved and then to UInt8
        let interleavedFloat = PixelBufferIFx3(planarBuffers: normalizedPlanes)
        let interleaved8Bit = PixelBufferI8x3(width: imageWidth, height: imageHeight)
        interleavedFloat.convert(to: interleaved8Bit) // maps [0.0 1.0] -> [0 255] and clips

        // Convert to uint8x3 to RGB CGImage (no alpha)
        let noAlpha = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue)
        let format = vImage_CGImageFormat(bitsPerComponent: 8,
                                          bitsPerPixel: 3*8,
                                          colorSpace: CGColorSpaceCreateDeviceRGB(),
                                          bitmapInfo: noAlpha)!

        return interleaved8Bit.makeCGImage(cgImageFormat: format)!
    }
}
|
@ -0,0 +1,118 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import CoreML
|
||||||
|
|
||||||
|
/// A random source consistent with NumPy
///
///  This implementation matches:
///  [NumPy's older randomkit.c](https://github.com/numpy/numpy/blob/v1.0/numpy/random/mtrand/randomkit.c)
///
struct NumPyRandomSource: RandomNumberGenerator {

    /// Generator state: the 624-word key, the read position in it, and a
    /// cached second value produced by the last Gaussian draw (if any)
    struct State {
        var key = [UInt32](repeating: 0, count: 624)
        var pos: Int = 0
        var nextGauss: Double? = nil
    }

    var state: State

    /// Initialize with a random seed
    ///
    ///  - Parameters
    ///     - seed: Seed for underlying Mersenne Twister 19937 generator
    ///  - Returns random source
    init(seed: UInt32) {
        state = .init()
        var word = seed & 0xffffffff
        for index in 0 ..< state.key.count {
            state.key[index] = word
            let mixed = UInt64(1812433253) * UInt64(word ^ (word >> 30)) + UInt64(index) + 1
            word = UInt32(mixed & 0xffffffff)
        }
        // Position at the end of the key forces a regeneration on the first draw
        state.pos = state.key.count
        state.nextGauss = nil
    }

    /// Generate next UInt32 using fast 32bit Mersenne Twister
    mutating func nextUInt32() -> UInt32 {
        let n = 624
        let m = 397
        let matrixA: UInt64 = 0x9908b0df
        let upperMask: UInt32 = 0x80000000
        let lowerMask: UInt32 = 0x7fffffff

        // (y >> 1) xor-ed with matrixA when the lowest bit of y is set, with 0 otherwise
        func twist(_ y: UInt32) -> UInt32 {
            (y >> 1) ^ UInt32((UInt64(~(y & 1)) + 1) & matrixA)
        }

        // Regenerate the whole key once every word has been consumed
        if state.pos == state.key.count {
            for i in 0 ..< (n - m) {
                let combined = (state.key[i] & upperMask) | (state.key[i + 1] & lowerMask)
                state.key[i] = state.key[i + m] ^ twist(combined)
            }
            for i in (n - m) ..< (n - 1) {
                let combined = (state.key[i] & upperMask) | (state.key[i + 1] & lowerMask)
                state.key[i] = state.key[i + (m - n)] ^ twist(combined)
            }
            let wrapped = (state.key[n - 1] & upperMask) | (state.key[0] & lowerMask)
            state.key[n - 1] = state.key[m - 1] ^ twist(wrapped)
            state.pos = 0
        }

        var output = state.key[state.pos]
        state.pos += 1

        // Tempering
        output ^= (output >> 11)
        output ^= (output << 7) & 0x9d2c5680
        output ^= (output << 15) & 0xefc60000
        output ^= (output >> 18)

        return output
    }

    /// Generate next UInt64 from two consecutive 32-bit draws (first draw is the low half)
    mutating func next() -> UInt64 {
        let lowBits = UInt64(nextUInt32())
        let highBits = UInt64(nextUInt32())
        return (highBits << 32) | lowBits
    }

    /// Generate next random double value
    mutating func nextDouble() -> Double {
        // 27 bits from the first draw and 26 bits from the second form a 53-bit fraction
        let upper = Double(nextUInt32() >> 5)
        let lower = Double(nextUInt32() >> 6)
        return (upper * 67108864.0 + lower) / 9007199254740992.0
    }

    /// Generate next random value from a standard normal
    mutating func nextGauss() -> Double {
        // Every pair of draws yields two values; hand out the cached one first
        if let cached = state.nextGauss {
            state.nextGauss = nil
            return cached
        }

        var x1: Double
        var x2: Double
        var radiusSquared: Double
        repeat {
            x1 = 2.0 * nextDouble() - 1.0
            x2 = 2.0 * nextDouble() - 1.0
            radiusSquared = x1 * x1 + x2 * x2
        } while radiusSquared >= 1.0 || radiusSquared == 0.0

        // Box-Muller transform (polar form)
        let factor = sqrt(-2.0 * log(radiusSquared) / radiusSquared)
        state.nextGauss = factor * x1
        return factor * x2
    }

    /// Generates a random value from a normal distribution with given mean and standard deviation.
    mutating func nextNormal(mean: Double = 0.0, stdev: Double = 1.0) -> Double {
        return nextGauss() * stdev + mean
    }

    /// Generates an array of random values from a normal distribution with given mean and standard deviation.
    mutating func normalArray(count: Int, mean: Double = 0.0, stdev: Double = 1.0) -> [Double] {
        var values = [Double]()
        values.reserveCapacity(count)
        for _ in 0 ..< count {
            values.append(nextNormal(mean: mean, stdev: stdev))
        }
        return values
    }

    /// Generate a shaped array with scalars from a normal distribution with given mean and standard deviation.
    mutating func normalShapedArray(_ shape: [Int], mean: Double = 0.0, stdev: Double = 1.0) -> MLShapedArray<Double> {
        let scalarCount = shape.reduce(1, *)
        let scalars = normalArray(count: scalarCount, mean: mean, stdev: stdev)
        return .init(scalars: scalars, shape: shape)
    }
}
|
@ -0,0 +1,154 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import CoreML
|
||||||
|
import Accelerate
|
||||||
|
|
||||||
|
/// Image safety checking model
|
||||||
|
public struct SafetyChecker {
|
||||||
|
|
||||||
|
/// Safety checking Core ML model
|
||||||
|
var model: MLModel
|
||||||
|
|
||||||
|
/// Creates safety checker
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - model: Underlying model which performs the safety check
|
||||||
|
/// - Returns: Safety checker ready for checks
|
||||||
|
public init(model: MLModel) {
|
||||||
|
self.model = model
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prediction queue
|
||||||
|
let queue = DispatchQueue(label: "safetycheker.predict")
|
||||||
|
|
||||||
|
typealias PixelBufferPFx1 = vImage.PixelBuffer<vImage.PlanarF>
|
||||||
|
typealias PixelBufferP8x1 = vImage.PixelBuffer<vImage.Planar8>
|
||||||
|
typealias PixelBufferPFx3 = vImage.PixelBuffer<vImage.PlanarFx3>
|
||||||
|
typealias PixelBufferP8x3 = vImage.PixelBuffer<vImage.Planar8x3>
|
||||||
|
typealias PixelBufferIFx3 = vImage.PixelBuffer<vImage.InterleavedFx3>
|
||||||
|
typealias PixelBufferI8x3 = vImage.PixelBuffer<vImage.Interleaved8x3>
|
||||||
|
typealias PixelBufferI8x4 = vImage.PixelBuffer<vImage.Interleaved8x4>
|
||||||
|
|
||||||
|
enum SafetyCheckError: Error {
|
||||||
|
case imageResizeFailure
|
||||||
|
case imageToFloatFailure
|
||||||
|
case modelInputFailure
|
||||||
|
case unexpectedModelOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if image is safe
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - image: Image to check
|
||||||
|
/// - Returns: Whether the model considers the image to be safe
|
||||||
|
public func isSafe(_ image: CGImage) throws -> Bool {
|
||||||
|
|
||||||
|
let inputName = "clip_input"
|
||||||
|
let adjustmentName = "adjustment"
|
||||||
|
let imagesNames = "images"
|
||||||
|
|
||||||
|
let inputInfo = model.modelDescription.inputDescriptionsByName
|
||||||
|
let inputShape = inputInfo[inputName]!.multiArrayConstraint!.shape
|
||||||
|
|
||||||
|
let width = inputShape[2].intValue
|
||||||
|
let height = inputShape[3].intValue
|
||||||
|
|
||||||
|
let resizedImage = try resizeToRGBA(image, width: width, height: height)
|
||||||
|
|
||||||
|
let bufferP8x3 = try getRGBPlanes(of: resizedImage)
|
||||||
|
|
||||||
|
let arrayPFx3 = normalizeToFloatShapedArray(bufferP8x3)
|
||||||
|
|
||||||
|
guard let input = try? MLDictionaryFeatureProvider(
|
||||||
|
dictionary:[
|
||||||
|
// Input that is analyzed for safety
|
||||||
|
inputName : MLMultiArray(arrayPFx3),
|
||||||
|
// No adjustment, use default threshold
|
||||||
|
adjustmentName : MLMultiArray(MLShapedArray<Float32>(scalars: [0], shape: [1])),
|
||||||
|
// Supplying dummy images to be filtered (will be ignored)
|
||||||
|
imagesNames : MLMultiArray(shape:[1, 512, 512, 3], dataType: .float16)
|
||||||
|
]
|
||||||
|
) else {
|
||||||
|
throw SafetyCheckError.modelInputFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = try queue.sync { try model.prediction(from: input) }
|
||||||
|
|
||||||
|
let output = result.featureValue(for: "has_nsfw_concepts")
|
||||||
|
|
||||||
|
guard let unsafe = output?.multiArrayValue?[0].boolValue else {
|
||||||
|
throw SafetyCheckError.unexpectedModelOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
return !unsafe
|
||||||
|
}
|
||||||
|
|
||||||
|
func resizeToRGBA(_ image: CGImage,
|
||||||
|
width: Int, height: Int) throws -> CGImage {
|
||||||
|
|
||||||
|
guard let context = CGContext(
|
||||||
|
data: nil,
|
||||||
|
width: width,
|
||||||
|
height: height,
|
||||||
|
bitsPerComponent: 8,
|
||||||
|
bytesPerRow: width*4,
|
||||||
|
space: CGColorSpaceCreateDeviceRGB(),
|
||||||
|
bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue) else {
|
||||||
|
throw SafetyCheckError.imageResizeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
context.interpolationQuality = .high
|
||||||
|
context.draw(image, in: CGRect(x: 0, y: 0, width: width, height: height))
|
||||||
|
guard let resizedImage = context.makeImage() else {
|
||||||
|
throw SafetyCheckError.imageResizeFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
return resizedImage
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRGBPlanes(of rgbaImage: CGImage) throws -> PixelBufferP8x3 {
|
||||||
|
// Reference as interleaved 8 bit vImage PixelBuffer
|
||||||
|
var emptyFormat = vImage_CGImageFormat()
|
||||||
|
guard let bufferI8x4 = try? PixelBufferI8x4(
|
||||||
|
cgImage: rgbaImage,
|
||||||
|
cgImageFormat:&emptyFormat) else {
|
||||||
|
throw SafetyCheckError.imageToFloatFailure
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop the alpha channel, keeping RGB
|
||||||
|
let bufferI8x3 = PixelBufferI8x3(width: rgbaImage.width, height:rgbaImage.height)
|
||||||
|
bufferI8x4.convert(to: bufferI8x3, channelOrdering: .RGBA)
|
||||||
|
|
||||||
|
// De-interleave into 8-bit planes
|
||||||
|
return PixelBufferP8x3(interleavedBuffer: bufferI8x3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeToFloatShapedArray(_ bufferP8x3: PixelBufferP8x3) -> MLShapedArray<Float32> {
|
||||||
|
let width = bufferP8x3.width
|
||||||
|
let height = bufferP8x3.height
|
||||||
|
|
||||||
|
let means = [0.485, 0.456, 0.406] as [Float]
|
||||||
|
let stds = [0.229, 0.224, 0.225] as [Float]
|
||||||
|
|
||||||
|
// Convert to normalized float 1x3xWxH input (plannar)
|
||||||
|
let arrayPFx3 = MLShapedArray<Float32>(repeating: 0.0, shape: [1, 3, width, height])
|
||||||
|
for c in 0..<3 {
|
||||||
|
arrayPFx3[0][c].withUnsafeShapedBufferPointer { ptr, _, strides in
|
||||||
|
let floatChannel = PixelBufferPFx1(data: .init(mutating: ptr.baseAddress!),
|
||||||
|
width: width, height: height,
|
||||||
|
byteCountPerRow: strides[0]*4)
|
||||||
|
|
||||||
|
bufferP8x3.withUnsafePixelBuffer(at: c) { uint8Channel in
|
||||||
|
uint8Channel.convert(to: floatChannel) // maps [0 255] -> [0 1]
|
||||||
|
floatChannel.multiply(by: 1.0/stds[c],
|
||||||
|
preBias: -means[c],
|
||||||
|
postBias: 0.0,
|
||||||
|
destination: floatChannel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return arrayPFx3
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,77 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// A utility for timing events and tracking time statistics
///
/// Typical usage
/// ```
/// let timer = SampleTimer()
///
/// for i in 0..<iterationCount {
///     timer.start()
///     doStuff()
///     timer.stop()
/// }
///
/// print(String(format: "mean: %.2f, var: %.2f",
///              timer.mean, timer.variance))
/// ```
public final class SampleTimer: Codable {
    // Time noted by the most recent `start()`, if any
    var startTime: CFAbsoluteTime?
    // Running sum of all recorded samples
    var sum: Double = 0.0
    // Running sum of the squares of all recorded samples
    var sumOfSquares: Double = 0.0
    // Number of recorded samples
    var count = 0
    // Every recorded sample, in recording order
    var samples: [Double] = []

    public init() {}

    /// Start a sample, noting the current time
    public func start() {
        startTime = CFAbsoluteTimeGetCurrent()
    }

    /// Stop a sample and record the elapsed time
    ///
    /// - Returns: Time elapsed since the last `start()`, or 0 (recording nothing)
    ///   if `start()` was never called
    @discardableResult public func stop() -> Double {
        guard let startTime = startTime else {
            return 0
        }

        let elapsed = CFAbsoluteTimeGetCurrent() - startTime
        sum += elapsed
        sumOfSquares += elapsed * elapsed
        count += 1
        samples.append(elapsed)
        return elapsed
    }

    /// Mean of all sampled times (NaN when nothing has been sampled)
    public var mean: Double { sum / Double(count) }

    /// Unbiased sample variance of all sampled times (0 with fewer than two samples)
    public var variance: Double {
        guard count > 1 else {
            return 0.0
        }
        // Sum of squared deviations: Σx² − n·mean², where n·mean² == sum·mean.
        // Dividing only Σx² by (n − 1), as before, mixed two different estimators.
        let squaredDeviations = sumOfSquares - sum * mean
        // Clamp tiny negative values caused by floating point cancellation
        return max(0.0, squaredDeviations / Double(count - 1))
    }

    /// Standard deviation of all sampled times
    public var stdev: Double { variance.squareRoot() }

    /// Median of all sampled times (0 when nothing has been sampled)
    public var median: Double {
        let sorted = samples.sorted()
        // Without samples there is no middle element to index
        guard !sorted.isEmpty else {
            return 0.0
        }
        let (q, r) = sorted.count.quotientAndRemainder(dividingBy: 2)
        if r == 0 {
            // Even count: average the two middle samples
            return (sorted[q] + sorted[q - 1]) / 2.0
        } else {
            return sorted[q]
        }
    }

    /// All sampled times in the order they were recorded
    public var allSamples: [Double] {
        samples
    }
}
|
@ -0,0 +1,68 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import CoreML
|
||||||
|
|
||||||
|
public extension StableDiffusionPipeline {

    /// Builds a stable diffusion pipeline from the model resources found in a directory
    ///
    /// The Unet is loaded from `UnetChunk1.mlmodelc` + `UnetChunk2.mlmodelc` when both
    /// exist and from `Unet.mlmodelc` otherwise. The safety checker is loaded only when
    /// it is not disabled and `SafetyChecker.mlmodelc` exists.
    ///
    /// - Parameters:
    ///   - baseURL: URL pointing to directory holding all model
    ///              and tokenization resources
    ///   - configuration: The configuration to load model resources with
    ///   - disableSafety: Load time disable of safety to save memory
    /// - Returns:
    ///  Pipeline ready for image generation if all necessary resources loaded
    init(resourcesAt baseURL: URL,
         configuration config: MLModelConfiguration = .init(),
         disableSafety: Bool = false) throws {

        // Expected location of a named resource inside the base directory
        func resource(_ name: String) -> URL {
            baseURL.appending(path: name)
        }
        let fileManager = FileManager.default

        // Text tokenizer and encoder
        let tokenizer = try BPETokenizer(mergesAt: resource("merges.txt"),
                                         vocabularyAt: resource("vocab.json"))
        let textEncoderModel = try MLModel(contentsOf: resource("TextEncoder.mlmodelc"),
                                           configuration: config)
        let textEncoder = TextEncoder(tokenizer: tokenizer, model: textEncoderModel)

        // Unet model: prefer the chunked form when every chunk is present
        let chunkURLs = [resource("UnetChunk1.mlmodelc"), resource("UnetChunk2.mlmodelc")]
        let unet: Unet
        if chunkURLs.allSatisfy({ fileManager.fileExists(atPath: $0.path) }) {
            let chunks = try chunkURLs.map { try MLModel(contentsOf: $0, configuration: config) }
            unet = Unet(chunks: chunks)
        } else {
            let unetModel = try MLModel(contentsOf: resource("Unet.mlmodelc"), configuration: config)
            unet = Unet(model: unetModel)
        }

        // Image Decoder
        let decoderModel = try MLModel(contentsOf: resource("VAEDecoder.mlmodelc"),
                                       configuration: config)
        let decoder = Decoder(model: decoderModel)

        // Optional safety checker
        let safetyCheckerURL = resource("SafetyChecker.mlmodelc")
        let safetyChecker: SafetyChecker?
        if disableSafety || !fileManager.fileExists(atPath: safetyCheckerURL.path) {
            safetyChecker = nil
        } else {
            let checkerModel = try MLModel(contentsOf: safetyCheckerURL, configuration: config)
            safetyChecker = SafetyChecker(model: checkerModel)
        }

        // Construct pipeline
        self.init(textEncoder: textEncoder,
                  unet: unet,
                  decoder: decoder,
                  safetyChecker: safetyChecker)
    }
}
|
@ -0,0 +1,233 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import CoreML
|
||||||
|
import Accelerate
|
||||||
|
import CoreGraphics
|
||||||
|
|
||||||
|
/// A pipeline used to generate image samples from text input using stable diffusion
///
/// This implementation matches:
/// [Hugging Face Diffusers Pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py)
public struct StableDiffusionPipeline {

    /// Model to generate embeddings for tokenized input text
    var textEncoder: TextEncoder

    /// Model used to predict noise residuals given an input, diffusion time step, and conditional embedding
    var unet: Unet

    /// Model used to generate final image from latent diffusion process
    var decoder: Decoder

    /// Optional model for checking safety of generated image
    var safetyChecker: SafetyChecker? = nil

    /// Controls the influence of the text prompt on sampling process (0=random images)
    var guidanceScale: Float = 7.5

    /// Reports whether this pipeline can perform safety checks
    public var canSafetyCheck: Bool {
        safetyChecker != nil
    }

    /// Creates a pipeline using the specified models and tokenizer
    ///
    /// - Parameters:
    ///   - textEncoder: Model for encoding tokenized text
    ///   - unet: Model for noise prediction on latent samples
    ///   - decoder: Model for decoding latent sample to image
    ///   - safetyChecker: Optional model for checking safety of generated images
    ///   - guidanceScale: Influence of the text prompt on generation process
    /// - Returns: Pipeline ready for image generation
    public init(textEncoder: TextEncoder,
                unet: Unet,
                decoder: Decoder,
                safetyChecker: SafetyChecker? = nil,
                guidanceScale: Float = 7.5) {
        self.textEncoder = textEncoder
        self.unet = unet
        self.decoder = decoder
        self.safetyChecker = safetyChecker
        self.guidanceScale = guidanceScale
    }

    /// Text to image generation using stable diffusion
    ///
    /// - Parameters:
    ///   - prompt: Text prompt to guide sampling
    ///   - imageCount: Number of samples/images to generate for the input prompt
    ///   - stepCount: Number of inference steps to perform
    ///   - seed: Random seed which determines the initial latent samples; only its
    ///     low 32 bits are used
    ///   - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety`
    ///   - progressHandler: Callback to perform after each step, stops on receiving false response
    /// - Returns: An array of `imageCount` optional images.
    ///            The images will be nil if safety checks were performed and found the result to be un-safe.
    ///            The array is empty if `imageCount` is not positive or the handler requested a stop.
    public func generateImages(
        prompt: String,
        imageCount: Int = 1,
        stepCount: Int = 50,
        seed: Int = 0,
        disableSafety: Bool = false,
        progressHandler: (Progress) -> Bool = { _ in true }
    ) throws -> [CGImage?] {

        // Nothing to generate; also protects the `scheduler[0]` accesses below
        guard imageCount > 0 else {
            return []
        }

        // Encode the input prompt as well as a blank unconditioned input
        let promptEmbedding = try textEncoder.encode(prompt)
        let blankEmbedding = try textEncoder.encode("")

        // Convert to Unet hidden state representation
        let concatEmbedding = MLShapedArray<Float32>(
            concatenating: [blankEmbedding, promptEmbedding],
            alongAxis: 0
        )

        let hiddenStates = toHiddenStates(concatEmbedding)

        // Setup schedulers, one per image
        let scheduler = (0..<imageCount).map { _ in Scheduler(stepCount: stepCount) }
        let stdev = scheduler[0].initNoiseSigma

        // Generate random latent samples from specified seed
        var latents = generateLatentSamples(imageCount, stdev: stdev, seed: seed)

        // De-noising loop
        for (step,t) in scheduler[0].timeSteps.enumerated() {

            // Expand the latents for classifier-free guidance
            // and input to the Unet noise prediction model
            let latentUnetInput = latents.map {
                MLShapedArray<Float32>(concatenating: [$0, $0], alongAxis: 0)
            }

            // Predict noise residuals from latent samples
            // and current time step conditioned on hidden states
            var noise = try unet.predictNoise(
                latents: latentUnetInput,
                timeStep: t,
                hiddenStates: hiddenStates
            )

            noise = performGuidance(noise)

            // Have the scheduler compute the previous (t-1) latent
            // sample given the predicted noise and current sample
            for i in 0..<imageCount {
                latents[i] = scheduler[i].step(
                    output: noise[i],
                    timeStep: t,
                    sample: latents[i]
                )
            }

            // Report progress
            let progress = Progress(
                pipeline: self,
                prompt: prompt,
                step: step,
                stepCount: stepCount,
                currentLatentSamples: latents,
                isSafetyEnabled: canSafetyCheck && !disableSafety
            )
            if !progressHandler(progress) {
                // Stop if requested by handler
                return []
            }
        }

        // Decode the latent samples to images
        return try decodeToImages(latents, disableSafety: disableSafety)
    }

    /// Generates the initial random latent samples
    ///
    /// - Parameters:
    ///   - count: Number of samples to generate
    ///   - stdev: Standard deviation of the normal distribution sampled from
    ///   - seed: Seed for the random source; truncated to its low 32 bits
    /// - Returns: `count` samples, each with the Unet latent sample shape and a leading dimension of 1
    func generateLatentSamples(_ count: Int, stdev: Float, seed: Int) -> [MLShapedArray<Float32>] {
        var sampleShape = unet.latentSampleShape
        sampleShape[0] = 1

        // `UInt32(seed)` traps on negative or > 32-bit seeds; truncation yields the
        // same value for every seed that previously worked
        var random = NumPyRandomSource(seed: UInt32(truncatingIfNeeded: seed))
        let samples = (0..<count).map { _ in
            MLShapedArray<Float32>(
                converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev)))
        }
        return samples
    }

    /// Rearranges a text embedding into the hidden state layout given to the Unet
    ///
    /// - Parameters:
    ///   - embedding: Rank 3 embedding, e.g. `[2, 77, 768]`
    /// - Returns: Rank 4 array with the last two axes swapped and a unit axis inserted, e.g. `[2, 768, 1, 77]`
    func toHiddenStates(_ embedding: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        // Unoptimized manual transpose [0, 2, None, 1]
        // e.g. From [2, 77, 768] to [2, 768, 1, 77]
        let fromShape = embedding.shape
        let stateShape = [fromShape[0],fromShape[2], 1, fromShape[1]]
        var states = MLShapedArray<Float32>(repeating: 0.0, shape: stateShape)
        for i0 in 0..<fromShape[0] {
            for i1 in 0..<fromShape[1] {
                for i2 in 0..<fromShape[2] {
                    states[scalarAt:i0,i2,0,i1] = embedding[scalarAt:i0, i1, i2]
                }
            }
        }
        return states
    }

    /// Applies classifier-free guidance to every noise prediction in a batch
    func performGuidance(_ noise: [MLShapedArray<Float32>]) -> [MLShapedArray<Float32>] {
        noise.map { performGuidance($0) }
    }

    /// Applies classifier-free guidance to one noise prediction
    ///
    /// - Parameters:
    ///   - noise: Prediction whose first entry along axis 0 is the unconditioned (blank prompt)
    ///     result and whose second entry is the text conditioned result
    /// - Returns: `unconditioned + guidanceScale * (text - unconditioned)` with a leading dimension of 1
    func performGuidance(_ noise: MLShapedArray<Float32>) -> MLShapedArray<Float32> {

        let blankNoiseScalars = noise[0].scalars
        let textNoiseScalars = noise[1].scalars

        var resultScalars = blankNoiseScalars

        for i in 0..<resultScalars.count {
            // unconditioned + guidance*(text - unconditioned)
            resultScalars[i] += guidanceScale*(textNoiseScalars[i]-blankNoiseScalars[i])
        }

        var shape = noise.shape
        shape[0] = 1
        return MLShapedArray<Float32>(scalars: resultScalars, shape: shape)
    }

    /// Decodes latent samples to images, optionally filtering them through the safety checker
    ///
    /// - Parameters:
    ///   - latents: Latent samples to decode
    ///   - disableSafety: When true the decoded images are returned unchecked
    /// - Returns: One entry per latent; nil where the safety checker found the image unsafe
    func decodeToImages(_ latents: [MLShapedArray<Float32>],
                        disableSafety: Bool) throws -> [CGImage?] {

        let images = try decoder.decode(latents)

        // If safety is disabled return what was decoded
        if disableSafety {
            return images
        }

        // If there is no safety checker return what was decoded
        guard let safetyChecker = safetyChecker else {
            return images
        }

        // Otherwise change images which are not safe to nil
        let safeImages = try images.map { image in
            try safetyChecker.isSafe(image) ? image : nil
        }

        return safeImages
    }

}
|
||||||
|
|
||||||
|
extension StableDiffusionPipeline {
    /// Snapshot of the sampling state handed to a progress handler
    public struct Progress {
        /// Pipeline which produced this progress report
        public let pipeline: StableDiffusionPipeline
        /// Prompt being sampled
        public let prompt: String
        /// Index of the step this report was created for
        public let step: Int
        /// Total number of steps requested
        public let stepCount: Int
        /// Latent samples as of this step
        public let currentLatentSamples: [MLShapedArray<Float32>]
        /// Whether decoded images are passed through the safety checker
        public let isSafetyEnabled: Bool

        /// Images decoded from the current latent samples
        ///
        /// Decoding happens on every access and a decoding failure is a runtime trap.
        public var currentImages: [CGImage?] {
            let skipSafety = !isSafetyEnabled
            return try! pipeline.decodeToImages(currentLatentSamples,
                                                disableSafety: skipSafety)
        }
    }
}
|
@ -0,0 +1,76 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import CoreML
|
||||||
|
|
||||||
|
/// A model for encoding text
public struct TextEncoder {

    /// Text tokenizer
    var tokenizer: BPETokenizer

    /// Embedding model
    var model: MLModel

    /// Creates text encoder which embeds a tokenized string
    ///
    /// - Parameters:
    ///   - tokenizer: Tokenizer for input text
    ///   - model: Model for encoding tokenized text
    public init(tokenizer: BPETokenizer, model: MLModel) {
        self.tokenizer = tokenizer
        self.model = model
    }

    /// Encode input text/string
    ///
    /// Input which tokenizes to more tokens than the model accepts is truncated
    /// and a notice is printed.
    ///
    /// - Parameters:
    ///   - text: Input text to be tokenized and then embedded
    /// - Returns: Embedding representing the input text
    /// - Throws: Any error raised while building the model input or running the prediction
    public func encode(_ text: String) throws -> MLShapedArray<Float32> {

        // Get models expected input length
        let inputLength = inputShape.last!

        // Tokenize, padding to the expected length
        var (tokens, ids) = tokenizer.tokenize(input: text, minCount: inputLength)

        // Truncate if necessary
        if ids.count > inputLength {
            tokens = tokens.dropLast(tokens.count - inputLength)
            ids = ids.dropLast(ids.count - inputLength)
            let truncated = tokenizer.decode(tokens: tokens)
            print("Needed to truncate input '\(text)' to '\(truncated)'")
        }

        // Use the model to generate the embedding
        return try encode(ids: ids)
    }

    /// Prediction queue; predictions are issued synchronously on it
    let queue = DispatchQueue(label: "textencoder.predict")

    /// Runs the embedding model on token identifiers
    ///
    /// - Parameters:
    ///   - ids: Token identifiers; their count must match the model input shape
    /// - Returns: The model's `last_hidden_state` output converted to Float32
    /// - Throws: Any error raised while building the model input or running the prediction
    func encode(ids: [Int]) throws -> MLShapedArray<Float32> {
        let inputName = inputDescription.name
        let modelInputShape = inputShape

        // The model takes the identifiers as floating point values
        let floatIds = ids.map { Float32($0) }
        let inputArray = MLShapedArray<Float32>(scalars: floatIds, shape: modelInputShape)
        // Propagate a feature provider failure to the caller rather than crashing (`try!`)
        let inputFeatures = try MLDictionaryFeatureProvider(
            dictionary: [inputName: MLMultiArray(inputArray)])

        let result = try queue.sync { try model.prediction(from: inputFeatures) }
        let embeddingFeature = result.featureValue(for: "last_hidden_state")
        return MLShapedArray<Float32>(converting: embeddingFeature!.multiArrayValue!)
    }

    /// Description of the model input used for the token identifiers
    /// (the first entry of the model's input descriptions)
    var inputDescription: MLFeatureDescription {
        model.modelDescription.inputDescriptionsByName.first!.value
    }

    /// Shape of the model input; its last extent is the expected token count
    var inputShape: [Int] {
        inputDescription.multiArrayConstraint!.shape.map { $0.intValue }
    }

}
|
@ -0,0 +1,143 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import CoreML
|
||||||
|
|
||||||
|
/// U-Net noise prediction model for stable diffusion
public struct Unet {

    /// Model used to predict noise residuals given an input, diffusion time step, and conditional embedding
    ///
    /// Holds either a single model or several stages which are run in order
    var models: [MLModel]

    /// Creates a U-Net noise prediction model backed by one Core ML model
    ///
    /// - Parameters:
    ///   - model: U-Net held in single Core ML model
    /// - Returns: Ready for prediction
    public init(model: MLModel) {
        models = [model]
    }

    /// Creates a U-Net noise prediction model backed by several Core ML models
    ///
    /// - Parameters:
    ///   - chunks: U-Net held chunked into multiple Core ML models
    /// - Returns: Ready for prediction
    public init(chunks: [MLModel]) {
        models = chunks
    }

    /// Description of the `sample` input of the first stage
    var latentSampleDescription: MLFeatureDescription {
        let firstStage = models.first!
        return firstStage.modelDescription.inputDescriptionsByName["sample"]!
    }

    /// The expected shape of the models latent sample input
    public var latentSampleShape: [Int] {
        let extents = latentSampleDescription.multiArrayConstraint!.shape
        return extents.map { $0.intValue }
    }

    /// Batch prediction noise from latent samples
    ///
    /// - Parameters:
    ///   - latents: Batch of latent samples in an array
    ///   - timeStep: Current diffusion timestep
    ///   - hiddenStates: Hidden state to condition on
    /// - Returns: Array of predicted noise residuals, one per latent, as Float32
    func predictNoise(
        latents: [MLShapedArray<Float32>],
        timeStep: Int,
        hiddenStates: MLShapedArray<Float32>
    ) throws -> [MLShapedArray<Float32>] {

        // Match time step batch dimension to the model / latent samples
        let stepValue = Float(timeStep)
        let t = MLShapedArray<Float32>(scalars: [stepValue, stepValue], shape: [2])

        // Form batch input to model, one feature provider per latent sample
        var providers: [MLFeatureProvider] = []
        for sample in latents {
            let features: [String: Any] = [
                "sample" : MLMultiArray(sample),
                "timestep" : MLMultiArray(t),
                "encoder_hidden_states": MLMultiArray(hiddenStates)
            ]
            providers.append(try MLDictionaryFeatureProvider(dictionary: features))
        }
        let batch = MLArrayBatchProvider(array: providers)

        // Make predictions
        let results = try predictions(from: batch)

        // Pull out the results in Float32 format
        var noise: [MLShapedArray<Float32>] = []
        for index in 0..<results.count {
            let features = results.features(at: index)
            let outputName = features.featureNames.first!
            let rawNoise = features.featureValue(for: outputName)!.multiArrayValue!

            // The concatenating constructor of MLMultiArray converts the data type,
            // which guarantees Float32 regardless of the model's output type
            let fp32Noise = MLMultiArray(
                concatenating: [rawNoise],
                axis: 0,
                dataType: .float32
            )
            noise.append(MLShapedArray<Float32>(fp32Noise))
        }

        return noise
    }

    /// Prediction queue; predictions are issued synchronously on it
    let queue = DispatchQueue(label: "unet.predict")

    /// Runs a batch through every stage of the model
    ///
    /// With several stages, each later stage receives the previous stage's outputs
    /// combined with the original inputs (stage outputs win on a name clash).
    ///
    /// - Parameters:
    ///   - batch: Inputs for the first stage
    /// - Returns: Outputs of the last stage
    func predictions(from batch: MLBatchProvider) throws -> MLBatchProvider {

        let firstStage = models.first!
        var results = try queue.sync {
            try firstStage.predictions(fromBatch: batch)
        }

        // A single model needs no manual pipelining
        guard models.count != 1 else {
            return results
        }

        // Manual pipeline batch prediction
        let originalInputs = batch.arrayOfFeatureValueDictionaries
        for stage in models.dropFirst() {

            // Combine the original inputs with the outputs of the last stage
            var stageInputs: [MLFeatureProvider] = []
            for (index, stageOutput) in results.arrayOfFeatureValueDictionaries.enumerated() {
                var combined = originalInputs[index]
                combined.merge(stageOutput) { _, output in output }
                stageInputs.append(try MLDictionaryFeatureProvider(dictionary: combined))
            }
            let nextBatch = MLArrayBatchProvider(array: stageInputs)

            // Predict
            results = try queue.sync {
                try stage.predictions(fromBatch: nextBatch)
            }
        }

        return results
    }
}
|
||||||
|
|
||||||
|
extension MLFeatureProvider {
    /// All features of this provider keyed by feature name
    ///
    /// Names for which the provider returns no value are left out.
    var featureValueDictionary: [String : MLFeatureValue] {
        var values: [String : MLFeatureValue] = [:]
        for name in featureNames {
            if let value = featureValue(for: name) {
                values[name] = value
            }
        }
        return values
    }
}
|
||||||
|
|
||||||
|
extension MLBatchProvider {
    /// The feature dictionaries of every entry in the batch, in batch order
    var arrayOfFeatureValueDictionaries: [[String : MLFeatureValue]] {
        var dictionaries: [[String : MLFeatureValue]] = []
        for index in 0..<count {
            dictionaries.append(features(at: index).featureValueDictionary)
        }
        return dictionaries
    }
}
|
@ -0,0 +1,34 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
extension BPETokenizer {
    /// Errors raised while reading tokenizer resource files
    enum FileReadError: Error {
        /// A merges file line did not consist of exactly two space separated tokens;
        /// the payload is the 1-based position of the line among the non-empty lines
        case invalidMergeFileLine(Int)
    }

    /// Read vocab.json file at URL into a dictionary mapping a String to its Int token id
    ///
    /// - Throws: If the file cannot be read or is not a JSON object of string to integer
    static func readVocabulary(url: URL) throws -> [String: Int] {
        let content = try Data(contentsOf: url)
        return try JSONDecoder().decode([String: Int].self, from: content)
    }

    /// Read merges.txt file at URL into a dictionary mapping bigrams to the line number/rank/priority
    ///
    /// Lines starting with `#` are skipped; they still count towards the rank of later lines.
    ///
    /// - Throws: If the file cannot be read, or `FileReadError.invalidMergeFileLine`
    ///   for a line which is not a pair of tokens
    static func readMerges(url: URL) throws -> [TokenPair: Int] {
        let content = try String(contentsOf: url)
        // Split on any newline character. "\r\n" is a single Character in Swift, so
        // splitting on "\n" alone never matches in files with CRLF line endings and
        // would treat the whole file as one (invalid) line.
        let lines = content.split(whereSeparator: { $0.isNewline })

        let merges: [(TokenPair, Int)] = try lines.enumerated().compactMap { (index, line) in
            if line.hasPrefix("#") {
                return nil
            }
            let pair = line.split(separator: " ")
            if pair.count != 2 {
                throw FileReadError.invalidMergeFileLine(index+1)
            }
            return (TokenPair(String(pair[0]), String(pair[1])),index)
        }
        return [TokenPair : Int](uniqueKeysWithValues: merges)
    }
}
|
@ -0,0 +1,181 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// A tokenizer based on byte pair encoding.
|
||||||
|
public struct BPETokenizer {
|
||||||
|
/// A dictionary that maps pairs of tokens to the rank/order of the merge.
|
||||||
|
let merges: [TokenPair : Int]
|
||||||
|
|
||||||
|
/// A dictionary from of tokens to identifiers.
|
||||||
|
let vocabulary: [String: Int]
|
||||||
|
|
||||||
|
/// The start token.
|
||||||
|
let startToken: String = "<|startoftext|>"
|
||||||
|
|
||||||
|
/// The end token.
|
||||||
|
let endToken: String = "<|endoftext|>"
|
||||||
|
|
||||||
|
/// The token used for padding
|
||||||
|
let padToken: String = "<|endoftext|>"
|
||||||
|
|
||||||
|
/// The unknown token.
|
||||||
|
let unknownToken: String = "<|endoftext|>"
|
||||||
|
|
||||||
|
/// Identifier of the unknown token, or 0 when the vocabulary does not contain it
var unknownTokenID: Int {
    return vocabulary[unknownToken] ?? 0
}
|
||||||
|
|
||||||
|
/// Creates a tokenizer from merges and a vocabulary already held in memory.
///
/// - Parameters:
///   - merges: A dictionary that maps pairs of tokens to the rank/order of the merge.
///   - vocabulary: A dictionary from tokens to identifiers.
public init(merges: [TokenPair: Int], vocabulary: [String: Int]) {
    self.vocabulary = vocabulary
    self.merges = merges
}
|
||||||
|
|
||||||
|
/// Creates a tokenizer by loading merges and vocabulary from URLs.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - mergesURL: The URL of a text file containing merges.
|
||||||
|
/// - vocabularyURL: The URL of a JSON file containing the vocabulary.
|
||||||
|
public init(mergesAt mergesURL: URL, vocabularyAt vocabularyURL: URL) throws {
|
||||||
|
self.merges = try Self.readMerges(url: mergesURL)
|
||||||
|
self.vocabulary = try! Self.readVocabulary(url: vocabularyURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tokenizes an input string.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - input: A string.
|
||||||
|
/// - minCount: The minimum number of tokens to return.
|
||||||
|
/// - Returns: An array of tokens and an array of token identifiers.
|
||||||
|
public func tokenize(input: String, minCount: Int? = nil) -> (tokens: [String], tokenIDs: [Int]) {
|
||||||
|
var tokens: [String] = []
|
||||||
|
|
||||||
|
tokens.append(startToken)
|
||||||
|
tokens.append(contentsOf: encode(input: input))
|
||||||
|
tokens.append(endToken)
|
||||||
|
|
||||||
|
// Pad if there was a min length specified
|
||||||
|
if let minLen = minCount, minLen > tokens.count {
|
||||||
|
tokens.append(contentsOf: repeatElement(padToken, count: minLen - tokens.count))
|
||||||
|
}
|
||||||
|
|
||||||
|
let ids = tokens.map({ vocabulary[$0, default: unknownTokenID] })
|
||||||
|
return (tokens: tokens, tokenIDs: ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the token identifier for a token.
|
||||||
|
public func tokenID(for token: String) -> Int? {
|
||||||
|
vocabulary[token]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the token for a token identifier.
|
||||||
|
public func token(id: Int) -> String? {
|
||||||
|
vocabulary.first(where: { $0.value == id })?.key
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decodes a sequence of tokens into a fully formed string
|
||||||
|
public func decode(tokens: [String]) -> String {
|
||||||
|
String(tokens.joined())
|
||||||
|
.replacingOccurrences(of: "</w>", with: " ")
|
||||||
|
.replacingOccurrences(of: startToken, with: "")
|
||||||
|
.replacingOccurrences(of: endToken, with: "")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode an input string to a sequence of tokens
|
||||||
|
func encode(input: String) -> [String] {
|
||||||
|
let normalized = input.trimmingCharacters(in: .whitespacesAndNewlines).lowercased()
|
||||||
|
let words = normalized.split(separator: " ")
|
||||||
|
return words.flatMap({ encode(word: $0) })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encode a single word into a sequence of tokens
|
||||||
|
func encode(word: Substring) -> [String] {
|
||||||
|
var tokens = word.map { String($0) }
|
||||||
|
if let last = tokens.indices.last {
|
||||||
|
tokens[last] = tokens[last] + "</w>"
|
||||||
|
}
|
||||||
|
|
||||||
|
while true {
|
||||||
|
let pairs = pairs(for: tokens)
|
||||||
|
let canMerge = pairs.filter { merges[$0] != nil }
|
||||||
|
|
||||||
|
if canMerge.isEmpty {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// If multiple merges are found, use the one with the lowest rank
|
||||||
|
let shouldMerge = canMerge.min { merges[$0]! < merges[$1]! }!
|
||||||
|
tokens = update(tokens, merging: shouldMerge)
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the set of adjacent pairs / bigrams from a sequence of tokens
|
||||||
|
func pairs(for tokens: [String]) -> Set<TokenPair> {
|
||||||
|
guard tokens.count > 1 else {
|
||||||
|
return Set()
|
||||||
|
}
|
||||||
|
|
||||||
|
var pairs = Set<TokenPair>(minimumCapacity: tokens.count - 1)
|
||||||
|
var prev = tokens.first!
|
||||||
|
for current in tokens.dropFirst() {
|
||||||
|
pairs.insert(TokenPair(prev, current))
|
||||||
|
prev = current
|
||||||
|
}
|
||||||
|
return pairs
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update the sequence of tokens by greedily merging instance of a specific bigram
|
||||||
|
func update(_ tokens: [String], merging bigram: TokenPair) -> [String] {
|
||||||
|
guard tokens.count > 1 else {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
var newTokens = [String]()
|
||||||
|
newTokens.reserveCapacity(tokens.count - 1)
|
||||||
|
|
||||||
|
var index = 0
|
||||||
|
while index < tokens.count {
|
||||||
|
let remainingTokens = tokens[index...]
|
||||||
|
if let startMatchIndex = remainingTokens.firstIndex(of: bigram.first) {
|
||||||
|
// Found a possible match, append everything before it
|
||||||
|
newTokens.append(contentsOf: tokens[index..<startMatchIndex])
|
||||||
|
|
||||||
|
if index < tokens.count - 1 && tokens[startMatchIndex + 1] == bigram.second {
|
||||||
|
// Full match, merge
|
||||||
|
newTokens.append(bigram.first + bigram.second)
|
||||||
|
index = startMatchIndex + 2
|
||||||
|
} else {
|
||||||
|
// Only matched the first, no merge
|
||||||
|
newTokens.append(bigram.first)
|
||||||
|
index = startMatchIndex + 1
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Didn't find any more matches, append the rest unmerged
|
||||||
|
newTokens.append(contentsOf: remainingTokens)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return newTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension BPETokenizer {

    /// A hashable tuple of strings
    public struct TokenPair: Hashable {
        /// The left-hand token of the pair.
        let first: String

        /// The right-hand token of the pair.
        let second: String

        /// Creates a pair from its two tokens.
        ///
        /// Public so that clients outside the module can build the `[TokenPair: Int]` table
        /// expected by `BPETokenizer.init(merges:vocabulary:)`.
        ///
        /// - Parameters:
        ///   - first: The left-hand token.
        ///   - second: The right-hand token.
        public init(_ first: String, _ second: String) {
            self.first = first
            self.second = second
        }
    }
}
|
@ -0,0 +1,186 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import ArgumentParser
|
||||||
|
import CoreGraphics
|
||||||
|
import CoreML
|
||||||
|
import Foundation
|
||||||
|
import StableDiffusion
|
||||||
|
import UniformTypeIdentifiers
|
||||||
|
|
||||||
|
/// Command-line sample that generates images from a text prompt and writes them as PNG files.
struct StableDiffusionSample: ParsableCommand {

    static let configuration = CommandConfiguration(
        abstract: "Run stable diffusion to generate images guided by a text prompt",
        version: "0.1"
    )

    @Argument(help: "Input string prompt")
    var prompt: String

    @Option(
        help: ArgumentHelp(
            "Path to stable diffusion resources.",
            discussion: "The resource directory should contain\n" +
                " - *compiled* models: {TextEncoder,Unet,VAEDecoder}.mlmodelc\n" +
                " - tokenizer info: vocab.json, merges.txt",
            valueName: "directory-path"
        )
    )
    var resourcePath: String = "./"

    @Option(help: "Number of images to sample / generate")
    var imageCount: Int = 1

    @Option(help: "Number of diffusion steps to perform")
    var stepCount: Int = 50

    @Option(
        help: ArgumentHelp(
            "How often to save samples at intermediate steps",
            discussion: "Set to 0 to only save the final sample"
        )
    )
    var saveEvery: Int = 0

    @Option(help: "Output path")
    var outputPath: String = "./"

    @Option(help: "Random seed")
    var seed: Int = 93

    @Option(help: "Compute units to load model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine}")
    var computeUnits: ComputeUnits = .all

    @Flag(help: "Disable safety checking")
    var disableSafety: Bool = false

    /// Loads the pipeline from `resourcePath`, generates the images and saves the final ones.
    ///
    /// - Throws: `RunError.resources` when `resourcePath` does not exist, plus any error
    ///   thrown while creating the pipeline, generating images or saving them.
    mutating func run() throws {
        guard FileManager.default.fileExists(atPath: resourcePath) else {
            throw RunError.resources("Resource path does not exist \(resourcePath)")
        }

        let config = MLModelConfiguration()
        config.computeUnits = computeUnits.asMLComputeUnits
        let resourceURL = URL(filePath: resourcePath)

        log("Loading resources and creating pipeline\n")
        log("(Note: This can take a while the first time using these resources)\n")
        let pipeline = try StableDiffusionPipeline(resourcesAt: resourceURL,
                                                   configuration: config,
                                                   disableSafety: disableSafety)

        log("Sampling ...\n")
        let sampleTimer = SampleTimer()
        sampleTimer.start()

        let images = try pipeline.generateImages(
            prompt: prompt,
            imageCount: imageCount,
            stepCount: stepCount,
            seed: seed
        ) { progress in
            // Close the sample for the step that just finished before reporting it, so the
            // time spent logging and saving is not attributed to a step.
            sampleTimer.stop()
            handleProgress(progress,sampleTimer)
            if progress.stepCount != progress.step {
                sampleTimer.start()
            }
            // NOTE(review): presumably `true` asks the pipeline to keep going — confirm
            // against the pipeline's progress-handler contract.
            return true
        }

        _ = try saveImages(images, logNames: true)
    }

    /// Rewrites the previous terminal line with step timing statistics and, every
    /// `saveEvery` steps, saves the intermediate images.
    ///
    /// - Parameters:
    ///   - progress: The progress report received from the pipeline.
    ///   - sampleTimer: The timer holding the per-step samples collected so far.
    func handleProgress(
        _ progress: StableDiffusionPipeline.Progress,
        _ sampleTimer: SampleTimer
    ) {
        // ANSI escapes: move the cursor up one line, then erase that line.
        log("\u{1B}[1A\u{1B}[K")
        log("Step \(progress.step) of \(progress.stepCount) ")
        log(" [")
        log(String(format: "mean: %.2f, ", 1.0/sampleTimer.mean))
        log(String(format: "median: %.2f, ", 1.0/sampleTimer.median))
        // `run()` stops the timer right before calling this, so a sample is expected here.
        log(String(format: "last %.2f", 1.0/sampleTimer.allSamples.last!))
        log("] step/sec")

        if saveEvery > 0, progress.step % saveEvery == 0 {
            // Saving intermediate images is best effort: a failure is reported as 0 saved.
            let saveCount = (try? saveImages(progress.currentImages, step: progress.step)) ?? 0
            log(" saved \(saveCount) image\(saveCount != 1 ? "s" : "")")
        }
        log("\n")
    }

    /// Writes every non-nil image as a PNG file into `outputPath`.
    ///
    /// - Parameters:
    ///   - images: The images to save; `nil` entries are skipped.
    ///   - step: The diffusion step the images belong to, or `nil` for final images.
    ///   - logNames: Whether to log each saved file name and each skipped image.
    /// - Returns: The number of images written.
    /// - Throws: `RunError.saving` when an image destination cannot be created or finalized.
    func saveImages(
        _ images: [CGImage?],
        step: Int? = nil,
        logNames: Bool = false
    ) throws -> Int {
        let url = URL(filePath: outputPath)
        var saved = 0
        for (i, candidate) in images.enumerated() {

            guard let image = candidate else {
                if logNames {
                    // `log` adds no terminator, so the newline has to be part of the message.
                    log("Image \(i) failed safety check and was not saved\n")
                }
                continue
            }

            let name = imageName(i, step: step)
            let fileURL = url.appending(path:name)

            guard let dest = CGImageDestinationCreateWithURL(fileURL as CFURL, UTType.png.identifier as CFString, 1, nil) else {
                throw RunError.saving("Failed to create destination for \(fileURL)")
            }
            CGImageDestinationAddImage(dest, image, nil)
            if !CGImageDestinationFinalize(dest) {
                throw RunError.saving("Failed to save \(fileURL)")
            }
            if logNames {
                log("Saved \(name)\n")
            }
            saved += 1
        }
        return saved
    }

    /// Builds the file name for an image: the prompt with spaces replaced by underscores,
    /// the sample index (only when several images are generated), the seed, the step or
    /// `final`, and the `.png` extension.
    ///
    /// - Parameters:
    ///   - sample: The index of the image within the batch.
    ///   - step: The diffusion step, or `nil` for the final image.
    func imageName(_ sample: Int, step: Int? = nil) -> String {
        var name = prompt.replacingOccurrences(of: " ", with: "_")
        if imageCount != 1 {
            name += ".\(sample)"
        }

        name += ".\(seed)"

        if let step = step {
            name += ".\(step)"
        } else {
            name += ".final"
        }
        name += ".png"
        return name
    }

    /// Prints `str` to standard output followed by `term` (nothing by default).
    func log(_ str: String, term: String = "") {
        print(str, terminator: term)
    }
}
|
||||||
|
|
||||||
|
/// Failures reported by the sample command; each case carries a human-readable message.
enum RunError: Error {
    // `resources`: raised when the resource path is unusable.
    // `saving`: raised when an image cannot be written.
    case resources(String), saving(String)
}
|
||||||
|
|
||||||
|
/// Command-line spelling of the Core ML compute unit selection.
enum ComputeUnits: String, ExpressibleByArgument, CaseIterable {
    case all
    case cpuAndGPU
    case cpuOnly
    case cpuAndNeuralEngine

    /// The `MLComputeUnits` value with the same name as this case.
    var asMLComputeUnits: MLComputeUnits {
        let units: MLComputeUnits
        switch self {
        case .all:
            units = .all
        case .cpuAndGPU:
            units = .cpuAndGPU
        case .cpuOnly:
            units = .cpuOnly
        case .cpuAndNeuralEngine:
            units = .cpuAndNeuralEngine
        }
        return units
    }
}
|
||||||
|
|
||||||
|
// Script entry point: invoke the command's static `main()`.
StableDiffusionSample.main()
|
@ -0,0 +1,62 @@
|
|||||||
|
// For licensing see accompanying LICENSE.md file.
|
||||||
|
// Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
import XCTest
|
||||||
|
import CoreML
|
||||||
|
@testable import StableDiffusion
|
||||||
|
|
||||||
|
/// Unit tests for the BPE tokenizer and the NumPy-compatible random source.
final class StableDiffusionTests: XCTestCase {

    /// Looks up a resource in the test bundle and aborts with `failureMessage` when absent.
    private func bundledResourceURL(named name: String, ext: String, failureMessage: String) -> URL {
        if let location = Bundle.module.url(forResource: name, withExtension: ext) {
            return location
        }
        fatalError(failureMessage)
    }

    /// Location of `vocab.json` inside the test bundle.
    var vocabFileInBundleURL: URL {
        bundledResourceURL(named: "vocab",
                           ext: "json",
                           failureMessage: "BPE tokenizer vocabulary file is missing from bundle")
    }

    /// Location of `merges.txt` inside the test bundle.
    var mergesFileInBundleURL: URL {
        bundledResourceURL(named: "merges",
                           ext: "txt",
                           failureMessage: "BPE tokenizer merges file is missing from bundle")
    }

    /// Tokenizes known prompts and compares the identifiers with the expected ones.
    func testBPETokenizer() throws {

        let tokenizer = try BPETokenizer(mergesAt: mergesFileInBundleURL, vocabularyAt: vocabFileInBundleURL)

        func testPrompt(prompt: String, expectedIds: [Int]) {
            let result = tokenizer.tokenize(input: prompt)
            let expectedTokens = expectedIds.map({ tokenizer.token(id: $0) })

            print("Tokens = \(result.tokens)\n")
            print("Expected tokens = \(expectedTokens)")
            print("ids = \(result.tokenIDs)\n")
            print("Expected Ids = \(expectedIds)\n")

            XCTAssertEqual(result.tokenIDs, expectedIds)
        }

        let cases: [(prompt: String, expectedIds: [Int])] = [
            (prompt: "a photo of an astronaut riding a horse on mars",
             expectedIds: [49406, 320, 1125, 539, 550, 18376, 6765, 320, 4558, 525, 7496, 49407]),
            (prompt: "Apple CoreML developer tools on a Macbook Air are fast",
             expectedIds: [49406, 3055, 19622, 5780, 10929, 5771, 525, 320, 20617,
                           1922, 631, 1953, 49407]),
        ]
        for testCase in cases {
            testPrompt(prompt: testCase.prompt, expectedIds: testCase.expectedIds)
        }
    }

    /// Compares the last five of 10,000 normal samples with reference values from NumPy.
    func test_randomNormalValues_matchNumPyRandom() {
        var source = NumPyRandomSource(seed: 12345)
        let tail = source.normalArray(count: 10_000).suffix(5)

        // numpy.random.seed(12345); print(numpy.random.randn(10000)[-5:])
        let reference = [-0.86285345, 2.15229409, -0.00670556, -1.21472309, 0.65498866]

        for pair in zip(tail, reference) {
            XCTAssertEqual(pair.0, pair.1, accuracy: .ulpOfOne.squareRoot())
        }
    }
}
|
@ -0,0 +1,410 @@
|
|||||||
|
#
|
||||||
|
# For licensing see accompanying LICENSE.md file.
|
||||||
|
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import contextlib
|
||||||
|
import coremltools as ct
|
||||||
|
from diffusers import StableDiffusionPipeline
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from PIL import Image
|
||||||
|
from statistics import median
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch.set_grad_enabled(False)
|
||||||
|
|
||||||
|
from python_coreml_stable_diffusion import torch2coreml, pipeline, coreml_model
|
||||||
|
|
||||||
|
# Module logger; INFO level so conversion and benchmark progress messages are shown.
logger = logging.getLogger(__name__)
logger.setLevel("INFO")

# Testing configuration
TEST_SEED = 93
TEST_PROMPT = "a high quality photo of an astronaut riding a horse in space"
# coremltools compute unit names to test (narrowed by the script entry point)
TEST_COMPUTE_UNIT = ["CPU_AND_GPU", "ALL", "CPU_AND_NE"]
TEST_PSNR_THRESHOLD = 35  # dB
TEST_ABSOLUTE_MAX_LATENCY = 90  # seconds
TEST_WARMUP_INFERENCE_STEPS = 3
TEST_TEXT_TO_IMAGE_SPEED_REPEATS = 3
TEST_MINIMUM_PROMPT_TO_IMAGE_CLIP_COSINE_SIMILARITY = 0.3  # in range [0.,1.]
|
||||||
|
|
||||||
|
|
||||||
|
class TestStableDiffusionForTextToImage(unittest.TestCase):
    """ Test Stable Diffusion text-to-image pipeline for:

    - PyTorch to CoreML conversion via coremltools
    - Speed of CoreML runtime across several compute units
    - Integration with `diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py`
    - Efficacy of the safety_checker
    - Affinity of the generated image with the original prompt via CLIP score
    - The bridge between Python and Swift CLI
    - The signal parity of Swift CLI generated image with that of Python CLI

    The tests are order dependent: later tests read the models and images
    written by earlier ones into ``cli_args.o``. State that has to survive from
    one test method to the next (``pytorch_pipe``, ``coreml_pipe``,
    ``active_compute_unit``) is kept on the class, because unittest creates a
    separate instance for every test method.
    """
    # Parsed command-line namespace; must be assigned before the suite runs
    cli_args = None

    @classmethod
    def setUpClass(cls):
        """ Loads the PyTorch pipeline once for all tests
        """
        cls.pytorch_pipe = StableDiffusionPipeline.from_pretrained(
            cls.cli_args.model_version,
            use_auth_token=True,
        )

        # To be initialized after test_torch_to_coreml_conversion is run
        cls.coreml_pipe = None
        cls.active_compute_unit = None

    @classmethod
    def tearDownClass(cls):
        """ Drops the references to the pipelines
        """
        cls.pytorch_pipe = None
        cls.coreml_pipe = None
        cls.active_compute_unit = None

    def test_torch_to_coreml_conversion(self):
        """ Tests:
        - PyTorch to CoreML conversion via coremltools
        """
        with self.subTest(model="vae_decoder"):
            logger.info("Converting vae_decoder")
            torch2coreml.convert_vae_decoder(self.pytorch_pipe, self.cli_args)
            logger.info("Successfuly converted vae_decoder")

        with self.subTest(model="unet"):
            logger.info("Converting unet")
            torch2coreml.convert_unet(self.pytorch_pipe, self.cli_args)
            logger.info("Successfuly converted unet")

        with self.subTest(model="text_encoder"):
            logger.info("Converting text_encoder")
            torch2coreml.convert_text_encoder(self.pytorch_pipe, self.cli_args)
            logger.info("Successfuly converted text_encoder")

        with self.subTest(model="safety_checker"):
            logger.info("Converting safety_checker")
            torch2coreml.convert_safety_checker(self.pytorch_pipe,
                                                self.cli_args)
            logger.info("Successfuly converted safety_checker")

    def test_end_to_end_image_generation_speed(self):
        """ Tests:
        - Speed of CoreML runtime across several compute units
        - Integration with `diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py`

        The median latency per compute unit is also written to
        ``benchmark.json`` in the artifacts directory.
        """
        latency = {
            compute_unit:
            self._coreml_text_to_image_with_compute_unit(compute_unit)
            for compute_unit in TEST_COMPUTE_UNIT
        }
        latency["num_repeats_for_median"] = TEST_TEXT_TO_IMAGE_SPEED_REPEATS

        json_path = os.path.join(self.cli_args.o, "benchmark.json")
        logger.info(f"Saving inference benchmark results to {json_path}")
        with open(json_path, "w") as f:
            json.dump(latency, f)

        for compute_unit in TEST_COMPUTE_UNIT:
            with self.subTest(compute_unit=compute_unit):
                self.assertGreater(TEST_ABSOLUTE_MAX_LATENCY,
                                   latency[compute_unit])

    def test_image_to_prompt_clip_score(self):
        """ Tests:
        Affinity of the generated image with the original prompt via CLIP score

        Reads the images saved by `test_end_to_end_image_generation_speed`.
        """
        logger.warning(
            "This test will download the CLIP ViT-B/16 model (approximately 600 MB) from Hugging Face"
        )

        # Imported here so that the dependency is only needed for this opt-in test
        from transformers import CLIPProcessor, CLIPModel

        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
        processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch16")

        for compute_unit in TEST_COMPUTE_UNIT:
            with self.subTest(compute_unit=compute_unit):
                image_path = pipeline.get_image_path(self.cli_args,
                                                     prompt=TEST_PROMPT,
                                                     compute_unit=compute_unit)
                image = Image.open(image_path)

                # Preprocess images and text for inference with CLIP
                inputs = processor(text=[TEST_PROMPT],
                                   images=image,
                                   return_tensors="pt",
                                   padding=True)
                outputs = model(**inputs)

                # Compute cosine similarity between image and text embeddings
                image_text_cosine_similarity = outputs.image_embeds @ outputs.text_embeds.T
                # One image and one prompt: compare the single score as a plain float
                clip_score = image_text_cosine_similarity[0].item()
                logger.info(
                    f"Image ({image_path}) to text ({TEST_PROMPT}) CLIP score: {clip_score:.2f}"
                )

                # Ensure that the minimum cosine similarity threshold is achieved
                self.assertGreater(
                    clip_score,
                    TEST_MINIMUM_PROMPT_TO_IMAGE_CLIP_COSINE_SIMILARITY,
                )

    def test_safety_checker_efficacy(self):
        """ Tests:
        - Efficacy of the safety_checker
        """
        # Reuse the pipe left behind by an earlier test when there is one;
        # otherwise start from the first configured compute unit.
        compute_unit = self.active_compute_unit or TEST_COMPUTE_UNIT[0]
        self._init_coreml_pipe(compute_unit=compute_unit)

        safety_checker_test_prompt = "NSFW"
        image = self.coreml_pipe(safety_checker_test_prompt)

        # Image must have been erased by the safety checker
        self.assertEqual(np.array(image["images"][0]).sum(), 0.)
        self.assertTrue(image["nsfw_content_detected"].any())

    def test_swift_cli_image_generation(self):
        """ Tests:
        - The bridge between Python and Swift CLI
        - The signal parity of Swift CLI generated image with that of Python CLI

        Raises:
            subprocess.CalledProcessError: if the Swift CLI exits with a
                non-zero status.
        """
        import shlex
        import subprocess

        # coremltools to Core ML compute unit mapping
        compute_unit_map = {
            "ALL": "all",
            "CPU_AND_GPU": "cpuAndGPU",
            "CPU_AND_NE": "cpuAndNeuralEngine"
        }

        # Prepare resources for Swift CLI
        resources_dir = torch2coreml.bundle_resources_for_swift_cli(
            self.cli_args)
        logger.info("Bundled resources for Swift CLI")

        # Execute image generation with Swift CLI
        # Note: First time takes ~5 minutes due to project building and so on
        # An argument list (no shell) keeps paths and prompts with spaces or
        # shell metacharacters intact.
        cmd = [
            "swift", "run", "StableDiffusionSample", TEST_PROMPT,
            "--resource-path", str(resources_dir),
            "--seed", str(TEST_SEED),
            "--output-path", str(self.cli_args.o),
            "--compute-units", compute_unit_map[TEST_COMPUTE_UNIT[-1]],
        ]
        printable_cmd = " ".join(shlex.quote(arg) for arg in cmd)
        logger.info(f"Executing `{printable_cmd}`")
        # check=True: fail here if the CLI fails, not later on a missing image
        subprocess.run(cmd, check=True)
        logger.info("Image generation with Swift CLI is complete")

        # Load Swift CLI generated image
        swift_image_name = (TEST_PROMPT.replace(" ", "_") + "." +
                            str(TEST_SEED) + ".final.png")
        swift_cli_image = Image.open(
            os.path.join(self.cli_args.o, swift_image_name))

        # Load Python CLI (pipeline.py) generated image
        python_cli_image = Image.open(pipeline.get_image_path(self.cli_args,
                                                              prompt=TEST_PROMPT,
                                                              compute_unit=TEST_COMPUTE_UNIT[-1]))

        # Compute signal parity
        swift2torch_psnr = torch2coreml.report_correctness(
            np.array(swift_cli_image.convert("RGB")),
            np.array(python_cli_image.convert("RGB")),
            "Swift CLI and Python CLI generated images")
        self.assertGreater(swift2torch_psnr, torch2coreml.ABSOLUTE_MIN_PSNR)

    def _init_coreml_pipe(self, compute_unit):
        """ Initializes CoreML pipe for the requested compute_unit

        The pipe and its compute unit are stored on the class so that a later
        test method (which runs on a different instance) can reuse them.
        """
        assert compute_unit in ct.ComputeUnit._member_names_, f"Not a valid coremltools.ComputeUnit: {compute_unit}"

        cls = type(self)
        if cls.active_compute_unit == compute_unit:
            logger.info(
                "self.coreml_pipe matches requested compute_unit, skipping reinitialization"
            )
            assert \
                isinstance(cls.coreml_pipe, pipeline.CoreMLStableDiffusionPipeline), \
                type(cls.coreml_pipe)
        else:
            cls.active_compute_unit = compute_unit
            cls.coreml_pipe = pipeline.get_coreml_pipe(
                pytorch_pipe=self.pytorch_pipe,
                mlpackages_dir=self.cli_args.o,
                model_version=self.cli_args.model_version,
                compute_unit=cls.active_compute_unit,)

    def _coreml_text_to_image_with_compute_unit(self, compute_unit):
        """ Benchmark end-to-end text-to-image generation with the requested compute_unit

        Returns the median wall-clock latency in seconds over
        TEST_TEXT_TO_IMAGE_SPEED_REPEATS runs. Every run also saves its image
        to the path given by `pipeline.get_image_path`.
        """
        self._init_coreml_pipe(compute_unit)

        # Warm up (not necessary in all settings but improves consistency for benchmarking)
        logger.info(
            f"Warmup image generation with {TEST_WARMUP_INFERENCE_STEPS} inference steps"
        )
        self.coreml_pipe(
            TEST_PROMPT, num_inference_steps=TEST_WARMUP_INFERENCE_STEPS)

        # Test end-to-end speed
        logger.info(
            f"Run full image generation {TEST_TEXT_TO_IMAGE_SPEED_REPEATS} times and report median"
        )

        def test_coreml_text_to_image_speed():
            """ Execute Core ML based image generation
            """
            _reset_seed()
            image = self.coreml_pipe(TEST_PROMPT)["images"][0]
            out_path = pipeline.get_image_path(self.cli_args,
                                               prompt=TEST_PROMPT,
                                               compute_unit=compute_unit)
            logger.info(f"Saving generated image to {out_path}")
            image.save(out_path)

        def collect_timings(func, n):
            """ Collect user latency (seconds, 2 decimals) for n calls of func
            """
            user_latencies = []
            for _ in range(n):
                started = time.perf_counter()
                func()
                user_latencies.append(round(time.perf_counter() - started, 2))
            return user_latencies

        coreml_latencies = collect_timings(
            func=test_coreml_text_to_image_speed,
            n=TEST_TEXT_TO_IMAGE_SPEED_REPEATS)
        coreml_median_latency = median(coreml_latencies)

        logger.info(
            f"End-to-end latencies with coremltools.ComputeUnit.{compute_unit}: median={coreml_median_latency:.2f}"
        )

        return coreml_median_latency
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_seed():
    """ Seed the PyTorch and NumPy global RNGs with TEST_SEED

    Called before each generation so that repeated runs are reproducible.
    """
    for seed_rng in (torch.manual_seed, np.random.seed):
        seed_rng(TEST_SEED)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_test_artifacts_dir(args):
|
||||||
|
if cli_args.persistent_test_artifacts_dir is not None:
|
||||||
|
os.makedirs(cli_args.persistent_test_artifacts_dir, exist_ok=True)
|
||||||
|
return contextlib.nullcontext(
|
||||||
|
enter_result=cli_args.persistent_test_artifacts_dir)
|
||||||
|
else:
|
||||||
|
return tempfile.TemporaryDirectory(
|
||||||
|
prefix="python_coreml_stable_diffusion_tests")
|
||||||
|
|
||||||
|
|
||||||
|
def _extend_parser(parser):
|
||||||
|
parser.add_argument(
|
||||||
|
"--persistent-test-artifacts-dir",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=
|
||||||
|
("If specified, test artifacts such as Core ML models and generated images are saved in this directory. ",
|
||||||
|
"Otherwise, all artifacts are erased after the test program terminates."
|
||||||
|
))
|
||||||
|
parser.add_argument(
|
||||||
|
"--fast",
|
||||||
|
action="store_true",
|
||||||
|
help=
|
||||||
|
"If specified, runs fewer repeats for `test_end_to_end_image_generation_speed`"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-image-to-prompt-clip-score-opt-in",
|
||||||
|
action="store_true",
|
||||||
|
help=
|
||||||
|
("If specified, enables `test_image_to_prompt_clip_score` to verify the relevance of the "
|
||||||
|
"generated image content to the original text prompt. This test is an opt-in "
|
||||||
|
"test because it involves an additional one time 600MB model download."
|
||||||
|
))
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-swift-cli-opt-in",
|
||||||
|
action="store_true",
|
||||||
|
help=
|
||||||
|
("If specified, compiles all models and builds the Swift CLI to run image generation and compares "
|
||||||
|
"results across Python and Swift runtime"))
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-safety-checker-efficacy-opt-in",
|
||||||
|
action="store_true",
|
||||||
|
help=
|
||||||
|
("If specified, generates a potentially NSFW image to check whether the `safety_checker` "
|
||||||
|
"accurately detects and removes the content"))
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import sys

    # Reproduce the CLI of the original pipeline
    parser = torch2coreml.parser_spec()
    parser = _extend_parser(parser)
    cli_args = parser.parse_args()

    # Settings the tests rely on, regardless of what was passed on the CLI
    cli_args.check_output_correctness = True
    cli_args.prompt = TEST_PROMPT
    cli_args.seed = TEST_SEED
    cli_args.compute_unit = TEST_COMPUTE_UNIT[0]
    cli_args.scheduler = None  # use default
    torch2coreml.ABSOLUTE_MIN_PSNR = TEST_PSNR_THRESHOLD

    if cli_args.fast:
        logger.info(
            "`--fast` detected: Image generation will be run once " \
            f"(instead of {TEST_TEXT_TO_IMAGE_SPEED_REPEATS } times) " \
            "with ComputeUnit.ALL (other compute units are skipped)" \
            " (median can not be reported)")
        TEST_TEXT_TO_IMAGE_SPEED_REPEATS = 1
        TEST_COMPUTE_UNIT = ["ALL"]

        logger.info("`--fast` detected: Skipping `--check-output-correctness` tests")
        cli_args.check_output_correctness = False
    elif cli_args.attention_implementation == "ORIGINAL":
        TEST_COMPUTE_UNIT = ["CPU_AND_GPU", "ALL"]
    elif cli_args.attention_implementation == "SPLIT_EINSUM":
        TEST_COMPUTE_UNIT = ["ALL", "CPU_AND_NE"]

    logger.info(f"Testing compute units: {TEST_COMPUTE_UNIT}")

    # Save CoreML model files and generated images into the artifacts dir
    with _get_test_artifacts_dir(cli_args) as test_artifacts_dir:
        cli_args.o = test_artifacts_dir
        logger.info(f"Test artifacts will be saved under {test_artifacts_dir}")

        TestStableDiffusionForTextToImage.cli_args = cli_args

        # Run the following tests in sequential order
        suite = unittest.TestSuite()
        suite.addTest(
            TestStableDiffusionForTextToImage(
                "test_torch_to_coreml_conversion"))
        suite.addTest(
            TestStableDiffusionForTextToImage(
                "test_end_to_end_image_generation_speed"))

        if cli_args.test_safety_checker_efficacy_opt_in:
            suite.addTest(
                TestStableDiffusionForTextToImage("test_safety_checker_efficacy"))

        if cli_args.test_image_to_prompt_clip_score_opt_in:
            suite.addTest(
                TestStableDiffusionForTextToImage(
                    "test_image_to_prompt_clip_score"))

        if cli_args.test_swift_cli_opt_in:
            suite.addTest(
                TestStableDiffusionForTextToImage(
                    "test_swift_cli_image_generation"))

        if os.getenv("DEBUG", False):
            # debug() lets the first failure propagate as an exception, so
            # reaching the next line means every test passed.
            suite.debug()
            run_succeeded = True
        else:
            runner = unittest.TextTestRunner()
            run_succeeded = runner.run(suite).wasSuccessful()

    # Report failures through the exit status (after the artifacts directory
    # has been cleaned up) so that callers and CI can detect them.
    if not run_succeeded:
        sys.exit(1)
|