图像生成：PyTorch从零开始实现一个简单的扩散模型

在这里插入图片描述

在这里插入图片描述

前言

环境要求

以下是运行本文代码所用环境中 `pip list` 的输出（已安装的包及其版本）：

Package                               Version             Editable project location
------------------------------------- ------------------- -------------------------
absl-py                               1.4.0
absolufy-imports                      0.3.1
accelerate                            1.9.0
aiofiles                              22.1.0
aiohappyeyeballs                      2.6.1
aiohttp                               3.12.15
aiosignal                             1.4.0
aiosqlite                             0.21.0
alabaster                             1.0.0
albucore                              0.0.24
albumentations                        2.0.8
ale-py                                0.11.2
alembic                               1.16.5
altair                                5.5.0
annotated-types                       0.7.0
annoy                                 1.17.3
ansicolors                            1.1.8
antlr4-python3-runtime                4.9.3
anyio                                 4.11.0
anywidget                             0.9.18
argon2-cffi                           25.1.0
argon2-cffi-bindings                  21.2.0
args                                  0.1.0
array_record                          0.7.2
arrow                                 1.3.0
arviz                                 0.21.0
astropy                               7.1.0
astropy-iers-data                     0.2025.7.21.0.41.39
asttokens                             3.0.0
astunparse                            1.6.3
atpublic                              5.1
attrs                                 25.3.0
audioread                             3.0.1
Authlib                               1.6.4
autograd                              1.8.0
babel                                 2.17.0
backcall                              0.2.0
backports.tarfile                     1.2.0
bayesian-optimization                 3.1.0
beartype                              0.21.0
beautifulsoup4                        4.13.4
betterproto                           2.0.0b7
bigframes                             2.12.0
bigquery-magics                       0.10.1
black                                 25.9.0
bleach                                6.2.0
blinker                               1.9.0
blis                                  1.3.0
blobfile                              3.0.0
blosc2                                3.6.1
bokeh                                 3.7.3
Boruta                                0.4.3
boto3                                 1.40.39
botocore                              1.40.39
Bottleneck                            1.4.2
bq_helper                             0.4.1               /root/src/BigQuery_Helper
bqplot                                0.12.45
branca                                0.8.1
Brotli                                1.1.0
build                                 1.2.2.post1
CacheControl                          0.14.3
cachetools                            5.5.2
Cartopy                               0.24.1
catalogue                             2.0.10
catboost                              1.2.8
category_encoders                     2.7.0
certifi                               2025.8.3
cesium                                0.12.4
cffi                                  2.0.0
chardet                               5.2.0
charset-normalizer                    3.4.3
Chessnut                              0.4.1
chex                                  0.1.90
clarabel                              0.11.1
click                                 8.3.0
click-plugins                         1.1.1.2
cligj                                 0.7.2
clint                                 0.5.1
cloudpathlib                          0.21.1
cloudpickle                           3.1.1
cmake                                 3.31.6
cmdstanpy                             1.2.5
colorama                              0.4.6
colorcet                              3.1.0
colorlog                              6.9.0
colorlover                            0.3.0
colour                                0.1.5
comm                                  0.2.3
community                             1.0.0b1
confection                            0.1.5
cons                                  0.4.7
contourpy                             1.3.2
coverage                              7.10.7
cramjam                               2.10.0
cryptography                          46.0.1
cuda-bindings                         12.9.2
cuda-pathfinder                       1.2.3
cuda-python                           12.9.2
cudf-cu12                             25.2.2
cudf-polars-cu12                      25.6.0
cufflinks                             0.17.3
cuml-cu12                             25.2.1
cupy-cuda12x                          13.6.0
curl_cffi                             0.12.0
cuvs-cu12                             25.2.1
cvxopt                                1.3.2
cvxpy                                 1.6.7
cycler                                0.12.1
cyipopt                               1.5.0
cymem                                 2.0.11
Cython                                3.0.12
cytoolz                               1.0.1
daal                                  2025.8.0
dacite                                1.9.2
dask                                  2024.12.1
dask-cuda                             25.2.0
dask-cudf-cu12                        25.2.2
dask-expr                             1.1.21
dataclasses-json                      0.6.7
dataproc-spark-connect                0.8.3
datasets                              4.1.1
db-dtypes                             1.4.3
dbus-python                           1.2.18
deap                                  1.4.3
debugpy                               1.8.15
decorator                             4.4.2
deepdiff                              8.6.1
defusedxml                            0.7.1
Deprecated                            1.2.18
diffusers                             0.34.0
dill                                  0.4.0
dipy                                  1.11.0
distributed                           2024.12.1
distributed-ucxx-cu12                 0.42.0
distro                                1.9.0
dlib                                  19.24.6
dm-tree                               0.1.9
dnspython                             2.8.0
docker                                7.1.0
docstring_parser                      0.17.0
docstring-to-markdown                 0.17
docutils                              0.21.2
dopamine_rl                           4.1.2
duckdb                                1.3.2
earthengine-api                       1.5.24
easydict                              1.13
easyocr                               1.7.2
editdistance                          0.8.1
eerepr                                0.1.2
einops                                0.8.1
eli5                                  0.13.0
email-validator                       2.3.0
emoji                                 2.15.0
en_core_web_sm                        3.8.0
entrypoints                           0.4
et_xmlfile                            2.0.0
etils                                 1.13.0
etuples                               0.3.10
execnb                                0.1.14
Farama-Notifications                  0.0.4
fastai                                2.8.4
fastapi                               0.116.1
fastcore                              1.8.11
fastdownload                          0.0.7
fastjsonschema                        2.21.1
fastprogress                          1.0.3
fastrlock                             0.8.3
fasttext                              0.9.3
fasttransform                         0.0.2
featuretools                          1.31.0
ffmpy                                 0.6.1
filelock                              3.19.1
filetype                              1.2.0
fiona                                 1.10.1
firebase-admin                        6.9.0
Flask                                 3.1.1
flatbuffers                           25.2.10
flax                                  0.10.6
folium                                0.20.0
fonttools                             4.59.0
fqdn                                  1.5.1
frozendict                            2.4.6
frozenlist                            1.7.0
fsspec                                2025.9.0
funcy                                 2.0
fury                                  0.12.0
future                                1.0.0
fuzzywuzzy                            0.18.0
gast                                  0.6.0
gatspy                                0.3
gcsfs                                 2025.3.0
GDAL                                  3.8.4
gdown                                 5.2.0
geemap                                0.35.3
gensim                                4.3.3
geocoder                              1.38.1
geographiclib                         2.0
geojson                               3.2.0
geopandas                             0.14.4
geopy                                 2.4.1
ghapi                                 1.0.8
gin-config                            0.5.0
gitdb                                 4.0.12
GitPython                             3.1.45
glob2                                 0.7
google                                2.0.3
google-adk                            1.14.1
google-ai-generativelanguage          0.6.15
google-api-core                       1.34.1
google-api-python-client              2.177.0
google-auth                           2.40.3
google-auth-httplib2                  0.2.0
google-auth-oauthlib                  1.2.2
google-cloud-aiplatform               1.105.0
google-cloud-appengine-logging        1.6.2
google-cloud-audit-log                0.3.2
google-cloud-automl                   1.0.1
google-cloud-bigquery                 3.25.0
google-cloud-bigquery-connection      1.18.3
google-cloud-bigtable                 2.32.0
google-cloud-core                     2.4.3
google-cloud-dataproc                 5.21.0
google-cloud-datastore                2.21.0
google-cloud-firestore                2.21.0
google-cloud-functions                1.20.4
google-cloud-iam                      2.19.1
google-cloud-language                 2.17.2
google-cloud-logging                  3.12.1
google-cloud-resource-manager         1.14.2
google-cloud-secret-manager           2.24.0
google-cloud-spanner                  3.56.0
google-cloud-speech                   2.33.0
google-cloud-storage                  2.19.0
google-cloud-trace                    1.16.2
google-cloud-translate                3.12.1
google-cloud-videointelligence        2.16.2
google-cloud-vision                   3.10.2
google-colab                          1.0.0
google-crc32c                         1.7.1
google-genai                          1.27.0
google-generativeai                   0.8.5
google-pasta                          0.2.0
google-resumable-media                2.7.2
googleapis-common-protos              1.70.0
googledrivedownloader                 1.1.0
gpxpy                                 1.6.2
gradio                                5.38.1
gradio_client                         1.11.0
graphviz                              0.21
greenlet                              3.2.3
groovy                                0.1.2
grpc-google-iam-v1                    0.14.2
grpc-interceptor                      0.15.4
grpcio                                1.75.1
grpcio-status                         1.49.0rc1
grpclib                               0.4.8
gspread                               6.2.1
gspread-dataframe                     4.0.0
gym                                   0.25.2
gym-notices                           0.0.8
gymnasium                             0.29.0
h11                                   0.16.0
h2                                    4.3.0
h2o                                   3.46.0.7
h5netcdf                              1.6.3
h5py                                  3.14.0
haversine                             2.9.0
hdbscan                               0.8.40
hep_ml                                0.8.0
hf_transfer                           0.1.9
hf-xet                                1.1.10
highspy                               1.11.0
holidays                              0.77
holoviews                             1.21.0
hpack                                 4.1.0
html5lib                              1.1
httpcore                              1.0.9
httpimport                            1.4.1
httplib2                              0.22.0
httpx                                 0.28.1
httpx-sse                             0.4.1
huggingface-hub                       1.0.0rc2
humanize                              4.12.3
hyperframe                            6.1.0
hyperopt                              0.2.7
ibis-framework                        9.5.0
id                                    1.5.0
idna                                  3.10
igraph                                0.11.9
ImageHash                             4.3.1
imageio                               2.37.0
imageio-ffmpeg                        0.6.0
imagesize                             1.4.1
imbalanced-learn                      0.13.0
immutabledict                         4.2.1
importlib_metadata                    8.7.0
importlib_resources                   6.5.2
imutils                               0.5.4
in-toto-attestation                   0.9.3
inflect                               7.5.0
iniconfig                             2.1.0
intel-cmplr-lib-rt                    2024.2.0
intel-cmplr-lib-ur                    2024.2.0
intel-openmp                          2024.2.0
ipyevents                             2.0.2
ipyfilechooser                        0.6.0
ipykernel                             6.17.1
ipyleaflet                            0.20.0
ipympl                                0.9.7
ipyparallel                           8.8.0
ipython                               7.34.0
ipython-genutils                      0.2.0
ipython_pygments_lexers               1.1.1
ipython-sql                           0.5.0
ipytree                               0.2.2
ipywidgets                            8.1.5
isoduration                           20.11.0
isoweek                               1.3.3
itsdangerous                          2.2.0
Janome                                0.5.0
jaraco.classes                        3.4.0
jaraco.context                        6.0.1
jaraco.functools                      4.2.1
jax                                   0.5.2
jax-cuda12-pjrt                       0.5.1
jax-cuda12-plugin                     0.5.1
jaxlib                                0.5.1
jedi                                  0.19.2
jeepney                               0.9.0
jieba                                 0.42.1
Jinja2                                3.1.6
jiter                                 0.10.0
jmespath                              1.0.1
joblib                                1.5.2
json5                                 0.12.1
jsonpatch                             1.33
jsonpickle                            4.1.1
jsonpointer                           3.0.0
jsonschema                            4.25.0
jsonschema-specifications             2025.4.1
jupyter_client                        8.6.3
jupyter-console                       6.1.0
jupyter_core                          5.8.1
jupyter-events                        0.12.0
jupyter_kernel_gateway                2.5.2
jupyter-leaflet                       0.20.0
jupyter-lsp                           1.5.1
jupyter_server                        2.12.5
jupyter_server_fileid                 0.9.3
jupyter_server_terminals              0.5.3
jupyter_server_ydoc                   0.8.0
jupyter-ydoc                          0.2.5
jupyterlab                            3.6.8
jupyterlab-lsp                        3.10.2
jupyterlab_pygments                   0.3.0
jupyterlab_server                     2.27.3
jupyterlab_widgets                    3.0.15
jupytext                              1.17.2
kaggle                                1.7.4.5
kaggle-environments                   1.18.0
kagglehub                             0.3.13
keras                                 3.8.0
keras-core                            0.1.7
keras-cv                              0.9.0
keras-hub                             0.18.1
keras-nlp                             0.18.1
keras-tuner                           1.4.7
keyring                               25.6.0
keyrings.google-artifactregistry-auth 1.1.2
kiwisolver                            1.4.8
kornia                                0.8.1
kornia_rs                             0.1.9
kt-legacy                             1.0.5
langchain                             0.3.27
langchain-core                        0.3.72
langchain-text-splitters              0.3.9
langcodes                             3.5.0
langid                                1.1.6
langsmith                             0.4.8
language_data                         1.3.0
lark                                  1.3.0
launchpadlib                          1.10.16
lazr.restfulclient                    0.14.4
lazr.uri                              1.0.6
lazy_loader                           0.4
learntools                            0.3.5
libclang                              18.1.1
libcudf-cu12                          25.2.2
libcugraph-cu12                       25.6.0
libcuml-cu12                          25.2.1
libcuvs-cu12                          25.2.1
libkvikio-cu12                        25.2.1
libpysal                              4.9.2
libraft-cu12                          25.2.0
librmm-cu12                           25.6.0
librosa                               0.11.0
libucx-cu12                           1.18.1
libucxx-cu12                          0.42.0
lightgbm                              4.6.0
lightning-utilities                   0.15.2
lime                                  0.2.0.1
line_profiler                         5.0.0
linkify-it-py                         2.0.3
llvmlite                              0.43.0
lml                                   0.2.0
locket                                1.0.0
logical-unification                   0.4.6
lxml                                  5.4.0
Mako                                  1.3.10
mamba                                 0.11.3
marisa-trie                           1.2.1
Markdown                              3.8.2
markdown-it-py                        4.0.0
MarkupSafe                            3.0.2
marshmallow                           3.26.1
matplotlib                            3.7.2
matplotlib-inline                     0.1.7
matplotlib-venn                       1.1.2
mcp                                   1.15.0
mdit-py-plugins                       0.4.2
mdurl                                 0.1.2
minify_html                           0.16.4
miniKanren                            1.0.5
missingno                             0.5.2
mistune                               0.8.4
mizani                                0.13.5
mkl                                   2025.2.0
mkl-fft                               1.3.8
mkl-random                            1.2.4
mkl-service                           2.4.1
mkl-umath                             0.1.1
ml_collections                        1.1.0
ml-dtypes                             0.4.1
mlcrate                               0.2.0
mlxtend                               0.23.4
mne                                   1.10.1
model-signing                         1.0.1
more-itertools                        10.7.0
moviepy                               1.0.3
mpld3                                 0.5.11
mpmath                                1.3.0
msgpack                               1.1.1
multidict                             6.6.4
multimethod                           1.12
multipledispatch                      1.0.0
multiprocess                          0.70.16
multitasking                          0.0.12
murmurhash                            1.0.13
music21                               9.3.0
mypy_extensions                       1.1.0
namex                                 0.1.0
narwhals                              1.48.1
natsort                               8.4.0
nbclassic                             1.3.1
nbclient                              0.5.13
nbconvert                             6.4.5
nbdev                                 2.4.5
nbformat                              5.10.4
ndindex                               1.10.0
nest-asyncio                          1.6.0
networkx                              3.5
nibabel                               5.3.2
nilearn                               0.10.4
ninja                                 1.13.0
nltk                                  3.9.1
notebook                              6.5.4
notebook_shim                         0.2.4
numba                                 0.60.0
numba-cuda                            0.2.0
numexpr                               2.11.0
numpy                                 1.26.4
nvidia-cublas-cu12                    12.5.3.2
nvidia-cuda-cupti-cu12                12.5.82
nvidia-cuda-nvcc-cu12                 12.5.82
nvidia-cuda-nvrtc-cu12                12.5.82
nvidia-cuda-runtime-cu12              12.5.82
nvidia-cudnn-cu12                     9.3.0.75
nvidia-cufft-cu12                     11.2.3.61
nvidia-curand-cu12                    10.3.6.82
nvidia-cusolver-cu12                  11.6.3.83
nvidia-cusparse-cu12                  12.5.1.3
nvidia-cusparselt-cu12                0.6.2
nvidia-ml-py                          12.575.51
nvidia-nccl-cu12                      2.21.5
nvidia-nvcomp-cu12                    4.2.0.11
nvidia-nvjitlink-cu12                 12.5.82
nvidia-nvtx-cu12                      12.4.127
nvtx                                  0.2.13
nx-cugraph-cu12                       25.6.0
oauth2client                          4.1.3
oauthlib                              3.3.1
odfpy                                 1.4.1
olefile                               0.47
omegaconf                             2.3.0
onnx                                  1.18.0
open_spiel                            1.6.1
openai                                1.97.1
opencv-contrib-python                 4.12.0.88
opencv-python                         4.12.0.88
opencv-python-headless                4.12.0.88
openpyxl                              3.1.5
openslide-bin                         4.0.0.8
openslide-python                      1.4.2
opentelemetry-api                     1.37.0
opentelemetry-exporter-gcp-trace      1.9.0
opentelemetry-resourcedetector-gcp    1.9.0a0
opentelemetry-sdk                     1.37.0
opentelemetry-semantic-conventions    0.58b0
opt_einsum                            3.4.0
optax                                 0.2.5
optree                                0.16.0
optuna                                4.5.0
orbax-checkpoint                      0.11.19
orderly-set                           5.5.0
orjson                                3.11.0
osqp                                  1.0.4
overrides                             7.7.0
packaging                             25.0
pandas                                2.2.3
pandas-datareader                     0.10.0
pandas-gbq                            0.29.2
pandas-profiling                      3.6.6
pandas-stubs                          2.2.2.240909
pandasql                              0.7.3
pandocfilters                         1.5.1
panel                                 1.7.5
papermill                             2.6.0
param                                 2.2.1
parso                                 0.8.4
parsy                                 2.1
partd                                 1.4.2
path                                  17.1.1
path.py                               12.5.0
pathos                                0.3.2
pathspec                              0.12.1
patsy                                 1.0.1
pdf2image                             1.17.0
peewee                                3.18.2
peft                                  0.16.0
pettingzoo                            1.24.0
pexpect                               4.9.0
phik                                  0.12.5
pickleshare                           0.7.5
pillow                                11.3.0
pip                                   24.1.2
platformdirs                          4.4.0
plotly                                5.24.1
plotly-express                        0.4.1
plotnine                              0.14.5
pluggy                                1.6.0
plum-dispatch                         2.5.7
ply                                   3.11
polars                                1.25.0
pooch                                 1.8.2
portpicker                            1.5.2
pox                                   0.3.6
ppft                                  1.7.7
preprocessing                         0.1.13
preshed                               3.0.10
prettytable                           3.16.0
proglog                               0.1.12
progressbar2                          4.5.0
prometheus_client                     0.22.1
promise                               2.3
prompt_toolkit                        3.0.51
propcache                             0.3.2
prophet                               1.1.7
proto-plus                            1.26.1
protobuf                              3.20.3
psutil                                7.1.0
psycopg2                              2.9.10
psygnal                               0.14.0
ptyprocess                            0.7.0
pudb                                  2025.1.1
puremagic                             1.30
py-cpuinfo                            9.0.0
py4j                                  0.10.9.7
pyaml                                 25.7.0
PyArabic                              0.6.15
pyarrow                               19.0.1
pyasn1                                0.6.1
pyasn1_modules                        0.4.2
pybind11                              3.0.1
pycairo                               1.28.0
pyclipper                             1.3.0.post6
pycocotools                           2.0.10
pycparser                             2.23
pycryptodome                          3.23.0
pycryptodomex                         3.23.0
pycuda                                2025.1.2
pydantic                              2.12.0a1
pydantic_core                         2.37.2
pydantic-settings                     2.11.0
pydata-google-auth                    1.9.1
pydegensac                            0.1.2
pydicom                               3.0.1
pydot                                 3.0.4
pydotplus                             2.0.2
PyDrive                               1.3.1
PyDrive2                              1.21.3
pydub                                 0.25.1
pyemd                                 1.0.0
pyerfa                                2.0.1.5
pyexcel-io                            0.6.7
pyexcel-ods                           0.6.0
pygame                                2.6.1
pygit2                                1.18.0
pygltflib                             1.16.5
Pygments                              2.19.2
PyGObject                             3.42.0
PyJWT                                 2.10.1
pyLDAvis                              3.4.1
pylibcudf-cu12                        25.2.2
pylibcugraph-cu12                     25.6.0
pylibraft-cu12                        25.2.0
pymc                                  5.25.1
pymc3                                 3.11.4
pymongo                               4.15.1
Pympler                               1.1
pynndescent                           0.5.13
pynvjitlink-cu12                      0.5.2
pynvml                                12.0.0
pyogrio                               0.11.0
pyomo                                 6.9.2
PyOpenGL                              3.1.9
pyOpenSSL                             25.3.0
pyparsing                             3.0.9
pypdf                                 6.1.0
pyperclip                             1.9.0
pyproj                                3.7.1
pyproject_hooks                       1.2.0
pyshp                                 2.3.1
PySocks                               1.7.1
pyspark                               3.5.1
pytensor                              2.31.7
pytesseract                           0.3.13
pytest                                8.4.1
python-apt                            0.0.0
python-bidi                           0.6.6
python-box                            7.3.2
python-dateutil                       2.9.0.post0
python-dotenv                         1.1.1
python-json-logger                    3.3.0
python-louvain                        0.16
python-lsp-jsonrpc                    1.1.2
python-lsp-server                     1.13.1
python-multipart                      0.0.20
python-slugify                        8.0.4
python-snappy                         0.7.3
python-utils                          3.9.1
pytokens                              0.1.10
pytools                               2025.2.4
pytorch-ignite                        0.5.2
pytorch-lightning                     2.5.5
pytz                                  2025.2
PyUpSet                               0.1.1.post7
pyviz_comms                           3.0.6
PyWavelets                            1.8.0
PyYAML                                6.0.3
pyzmq                                 26.2.1
qgrid                                 1.3.1
qtconsole                             5.7.0
QtPy                                  2.4.3
raft-dask-cu12                        25.2.0
rapids-dask-dependency                25.2.0
rapids-logger                         0.1.1
ratelim                               0.1.6
ray                                   2.49.2
referencing                           0.36.2
regex                                 2025.9.18
requests                              2.32.5
requests-oauthlib                     2.0.0
requests-toolbelt                     1.0.0
requirements-parser                   0.9.0
rfc3161-client                        1.0.5
rfc3339-validator                     0.1.4
rfc3986-validator                     0.1.1
rfc3987-syntax                        1.1.0
rfc8785                               0.1.4
rgf-python                            3.12.0
rich                                  14.1.0
rmm-cu12                              25.2.0
roman-numerals-py                     3.1.0
rpds-py                               0.26.0
rpy2                                  3.5.17
rsa                                   4.9.1
rtree                                 1.4.1
ruff                                  0.12.5
s3fs                                  0.4.2
s3transfer                            0.14.0
safehttpx                             0.1.6
safetensors                           0.5.3
scikit-image                          0.25.2
scikit-learn                          1.2.2
scikit-learn-intelex                  2025.8.0
scikit-multilearn                     0.2.0
scikit-optimize                       0.10.2
scikit-plot                           0.3.7
scikit-surprise                       1.1.4
scipy                                 1.15.3
scooby                                0.10.1
scs                                   3.2.7.post2
seaborn                               0.12.2
SecretStorage                         3.3.3
securesystemslib                      1.3.1
segment_anything                      1.0
semantic-version                      2.10.0
semver                                3.0.4
Send2Trash                            1.8.3
sentence-transformers                 4.1.0
sentencepiece                         0.2.0
sentry-sdk                            2.33.2
setuptools                            75.2.0
setuptools-scm                        9.2.0
shap                                  0.44.1
shapely                               2.1.2
shellingham                           1.5.4
Shimmy                                1.3.0
sigstore                              4.0.0
sigstore-models                       0.0.5
sigstore-rekor-types                  0.0.18
simple-parsing                        0.1.7
simpleitk                             2.5.2
simplejson                            3.20.1
simsimd                               6.5.0
siphash24                             1.8
six                                   1.17.0
sklearn-compat                        0.1.3
sklearn-pandas                        2.2.0
slicer                                0.0.7
smart_open                            7.3.0.post1
smmap                                 5.0.2
sniffio                               1.3.1
snowballstemmer                       3.0.1
sortedcontainers                      2.4.0
soundfile                             0.13.1
soupsieve                             2.7
soxr                                  0.5.0.post1
spacy                                 3.8.7
spacy-legacy                          3.0.12
spacy-loggers                         1.0.5
spanner-graph-notebook                1.1.7
Sphinx                                8.2.3
sphinx-rtd-theme                      0.2.4
sphinxcontrib-applehelp               2.0.0
sphinxcontrib-devhelp                 2.0.0
sphinxcontrib-htmlhelp                2.1.0
sphinxcontrib-jsmath                  1.0.1
sphinxcontrib-qthelp                  2.0.0
sphinxcontrib-serializinghtml         2.0.0
SQLAlchemy                            2.0.41
sqlalchemy-spanner                    1.16.0
sqlglot                               25.20.2
sqlparse                              0.5.3
squarify                              0.4.4
srsly                                 2.5.1
sse-starlette                         3.0.2
stable-baselines3                     2.1.0
stanio                                0.5.1
starlette                             0.47.2
statsmodels                           0.14.5
stopit                                1.1.2
stringzilla                           3.12.5
stumpy                                1.13.0
sympy                                 1.13.1
tables                                3.10.2
tabulate                              0.9.0
tbb                                   2022.2.0
tbb4py                                2022.2.0
tblib                                 3.1.0
tcmlib                                1.4.0
tenacity                              8.5.0
tensorboard                           2.18.0
tensorboard-data-server               0.7.2
tensorflow                            2.18.0
tensorflow-cloud                      0.1.5
tensorflow-datasets                   4.9.9
tensorflow_decision_forests           1.11.0
tensorflow-hub                        0.16.1
tensorflow-io                         0.37.1
tensorflow-io-gcs-filesystem          0.37.1
tensorflow-metadata                   1.17.2
tensorflow-probability                0.25.0
tensorflow-text                       2.18.1
tensorstore                           0.1.74
termcolor                             3.1.0
terminado                             0.18.1
testpath                              0.6.0
text-unidecode                        1.3
textblob                              0.19.0
texttable                             1.7.0
tf_keras                              2.18.0
tf-slim                               1.1.0
Theano                                1.0.5
Theano-PyMC                           1.1.2
thinc                                 8.3.6
threadpoolctl                         3.6.0
tifffile                              2025.6.11
tiktoken                              0.9.0
timm                                  1.0.19
tinycss2                              1.4.0
tokenizers                            0.21.2
toml                                  0.10.2
tomlkit                               0.13.3
toolz                                 1.0.0
torch                                 2.6.0+cu124
torchao                               0.10.0
torchaudio                            2.6.0+cu124
torchdata                             0.11.0
torchinfo                             1.8.0
torchmetrics                          1.8.2
torchsummary                          1.5.1
torchtune                             0.6.1
torchvision                           0.21.0+cu124
tornado                               6.5.2
TPOT                                  0.12.1
tqdm                                  4.67.1
traitlets                             5.7.1
traittypes                            0.2.1
transformers                          4.53.3
treelite                              4.4.1
treescope                             0.1.9
triton                                3.2.0
trx-python                            0.3
tsfresh                               0.21.0
tuf                                   6.0.0
tweepy                                4.16.0
typeguard                             4.4.4
typer                                 0.16.0
typer-slim                            0.19.2
types-python-dateutil                 2.9.0.20250822
types-pytz                            2025.2.0.20250516
types-setuptools                      80.9.0.20250529
typing_extensions                     4.15.0
typing-inspect                        0.9.0
typing-inspection                     0.4.1
tzdata                                2025.2
tzlocal                               5.3.1
uc-micro-py                           1.0.3
ucx-py-cu12                           0.42.0
ucxx-cu12                             0.42.0
ujson                                 5.11.0
umap-learn                            0.5.9.post2
umf                                   0.11.0
update-checker                        0.18.0
uri-template                          1.3.0
uritemplate                           4.2.0
urllib3                               2.5.0
urwid                                 3.0.3
urwid_readline                        0.15.1
uvicorn                               0.35.0
vega-datasets                         0.9.0
visions                               0.8.1
vtk                                   9.3.1
wadllib                               1.3.6
Wand                                  0.6.13
wandb                                 0.21.0
wasabi                                1.1.3
watchdog                              6.0.0
wavio                                 0.0.9
wcwidth                               0.2.13
weasel                                0.4.1
webcolors                             24.11.1
webencodings                          0.5.1
websocket-client                      1.8.0
websockets                            15.0.1
Werkzeug                              3.1.3
wheel                                 0.45.1
widgetsnbextension                    4.0.14
woodwork                              0.31.0
wordcloud                             1.9.4
wrapt                                 1.17.2
wurlitzer                             3.1.1
xarray                                2025.7.1
xarray-einstats                       0.9.1
xgboost                               2.0.3
xlrd                                  2.0.2
xvfbwrapper                           0.2.14
xxhash                                3.5.0
xyzservices                           2025.4.0
y-py                                  0.6.2
yarl                                  1.20.1
ydata-profiling                       4.17.0
ydf                                   0.9.0
yellowbrick                           1.5
yfinance                              0.2.65
ypy-websocket                         0.8.4
zict                                  3.0.0
zipp                                  3.23.0
zstandard                             0.23.0

相关介绍

  • Python是一种跨平台的计算机程序设计语言,是一种高层次的、结合了解释性、编译性、互动性和面向对象的脚本语言。最初被设计用于编写自动化脚本(shell),随着版本的不断更新和语言新功能的添加,越来越多地被用于独立的、大型项目的开发。
  • PyTorch 是一个深度学习框架,封装好了很多网络和深度学习相关的工具方便我们调用,而不用我们一个个去单独写了。它分为 CPU 和 GPU 版本,其他框架还有 TensorFlow、Caffe 等。PyTorch 是由 Facebook 人工智能研究院(FAIR)基于 Torch 推出的,它是一个基于 Python 的科学计算包,提供两个高级功能:1、具有强大的 GPU 加速的张量计算(如 NumPy);2、构建深度神经网络时的自动微分机制。
  • 扩散模型(Diffusion Models)是一类强大的生成模型,近年来在图像、音频、文本等生成任务中取得了突破性成果。它们通过模拟一个逐步加噪(前向过程)和逐步去噪(反向过程)的机制,学习如何从纯噪声中重建出真实数据。
  • 核心思想
    扩散模型的灵感来源于非平衡热力学:
    • 前向过程(Forward Process):将真实数据(如一张图像)逐步加入高斯噪声,经过若干步后,数据最终变成完全的随机噪声。
    • 反向过程(Reverse Process):训练一个神经网络,学习如何从噪声中一步步“去噪”,最终还原出类似原始数据的新样本。
  • 这个过程类似于“破坏-重建”:先慢慢把一张清晰的图片弄模糊直至完全看不清,再教会模型如何从模糊中恢复清晰图像。
    在这里插入图片描述

具体实现

导入相关库

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

准备数据集

# MNIST training split; ToTensor turns each image into a float tensor in [0, 1]
dataset = torchvision.datasets.MNIST(
    root="mnist/",
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)
# A small loader (8 images per batch) used only to inspect one batch
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
# Tile the batch into a single image grid and show its first channel
grid = torchvision.utils.make_grid(x)
plt.imshow(grid[0], cmap='Greys');
Input shape: torch.Size([8, 1, 28, 28])
Labels: tensor([2, 2, 0, 8, 8, 6, 2, 2])

在这里插入图片描述

定义加噪函数

  • 假设你从未读过任何扩散模型论文,但你知道该过程涉及添加噪声。你会如何实现?
  • 通过一个简单的方法来控制加噪的程度。那么,如果我们引入一个参数来指定要添加的噪声量,然后执行以下操作:
noise = torch.rand_like(x)
noisy_x = (1 - amount) * x + amount * noise
  • 如果amount=0,我们就会原封不动地返回输入值。如果amount达到1,我们就会返回与输入值x毫无关联的噪声。
  • 通过这种方式将输入值与噪声混合,我们可以保持输出值在相同的范围内(0到1)。
def corrupt(x, amount):
    """Corrupt ``x`` by blending it with uniform noise: ``x*(1-amount) + noise*amount``.

    Args:
        x: Input tensor whose first dimension is the batch dimension
            (any rank, e.g. (B, C, H, W) images).
        amount: Mixing weight(s), intended to lie in [0, 1]. May be a Python
            number, a 0-d tensor, or a 1-D tensor with one value per batch
            element. 0 returns ``x`` unchanged; 1 returns pure noise.

    Returns:
        A tensor with the same shape as ``x``.
    """
    noise = torch.rand_like(x)  # uniform [0, 1), same shape/dtype/device as x
    # Reshape amount to (batch, 1, 1, ...) with as many trailing singleton
    # dims as x has, so it broadcasts per-sample for inputs of any rank.
    # A hard-coded view(-1, 1, 1, 1) would only be correct for 4-D x.
    amount = torch.as_tensor(amount, device=x.device)
    amount = amount.reshape(-1, *([1] * (x.dim() - 1)))
    return x * (1 - amount) + noise * amount

定义网络模型

在这里插入图片描述

class BasicUNet(nn.Module):
    """A tiny three-level UNet built from 5x5 convolutions and SiLU activations.

    After each of the first two encoder convolutions, the feature map is
    stored as a skip connection and then halved with max-pooling. Before each
    of the last two decoder convolutions, the feature map is doubled with
    ``nn.Upsample`` (default nearest mode) and the matching stored encoder
    activation is added back. Height and width must therefore be divisible by 4.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output image.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # Every conv uses kernel 5 with padding 2, so H and W are preserved.
        self.down_layers = nn.ModuleList(
            nn.Conv2d(c_in, c_out, kernel_size=5, padding=2)
            for c_in, c_out in [(in_channels, 32), (32, 64), (64, 64)]
        )
        self.up_layers = nn.ModuleList(
            nn.Conv2d(c_in, c_out, kernel_size=5, padding=2)
            for c_in, c_out in [(64, 64), (64, 32), (32, out_channels)]
        )
        self.act = nn.SiLU()  # applied after every conv, including the last one
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        """Map ``x`` of shape (B, in_channels, H, W) to (B, out_channels, H, W)."""
        skips = []
        last_down = len(self.down_layers) - 1
        for depth, conv in enumerate(self.down_layers):
            x = self.act(conv(x))
            if depth == last_down:
                break  # bottleneck: no pooling and nothing to save
            skips.append(x)  # kept for the decoder's skip connection
            x = self.downscale(x)

        first, *rest = self.up_layers
        x = self.act(first(x))
        for conv in rest:
            # Upsample, then add the deepest still-unused encoder activation
            x = self.upscale(x) + skips.pop()
            x = self.act(conv(x))
        return x

# Sanity check: the UNet maps an (8, 1, 28, 28) batch to a tensor of the same
# shape. Also report the total number of parameters.
# Both values are printed explicitly: a bare expression that is not the last
# line of a cell/script is silently discarded.
net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
with torch.no_grad():  # shape check only, no gradients needed
    print(net(x).shape)
print(sum(p.numel() for p in net.parameters()))
torch.Size([8, 1, 28, 28])
309057

训练模型

# Dataloader (you can mess with batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# How many runs through the data should we do?
n_epochs = 3

# Create the network
net = BasicUNet()
net.to(device)

# Our loss function: MSE between the prediction and the clean image
loss_fn = nn.MSELoss()

# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# Keeping a record of the losses for later viewing
losses = []

# The training loop
for epoch in range(n_epochs):

    for x, _ in train_dataloader:  # labels are not needed for denoising

        # Get some data and prepare the corrupted version
        x = x.to(device)
        # One random noise amount in [0, 1) per image, sampled directly on
        # `device` (avoids creating it on the CPU and copying it over)
        noise_amount = torch.rand(x.shape[0], device=device)
        noisy_x = corrupt(x, noise_amount)

        # The model only sees the noisy image and must predict the clean one
        pred = net(noisy_x)

        # Calculate the loss: how close is the output to the true 'clean' x?
        loss = loss_fn(pred, x)

        # Backprop and update the params
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        losses.append(loss.item())

    # Print out the average of the loss values for this epoch
    # (the last len(train_dataloader) entries are exactly this epoch's batches)
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# View the loss curve
plt.plot(losses)
plt.ylim(0, 0.1);
Finished epoch 0. Average loss for this epoch: 0.025291
Finished epoch 1. Average loss for this epoch: 0.019563
Finished epoch 2. Average loss for this epoch: 0.017892

在这里插入图片描述

扩散生成

#@markdown Visualizing model predictions on noisy inputs:

# Grab one batch and keep only the first 8 images so the grid stays readable
x, y = next(iter(train_dataloader))
x = x[:8]

# Corruption grows linearly from 0 (leftmost image) to 1 (rightmost image)
amount = torch.linspace(0, 1, x.shape[0])
noised_x = corrupt(x, amount)

# Denoise on the model's device, then bring the results back for plotting
with torch.no_grad():
    preds = net(noised_x.to(device)).cpu()

# One row per stage: clean input, corrupted input, model output
fig, axs = plt.subplots(3, 1, figsize=(12, 7))
panels = [
    ('Input data', x),
    ('Corrupted data', noised_x),
    ('Network Predictions', preds),
]
for panel_ax, (title, images) in zip(axs, panels):
    panel_ax.set_title(title)
    panel_ax.imshow(torchvision.utils.make_grid(images)[0].clip(0, 1), cmap='Greys')

输出结果

在这里插入图片描述

对于较低的噪声量(amount 较小),预测结果相当不错!但是,当 amount 变得很高时,模型可以利用的信息就少了;当 amount=1 时,它会输出一个接近数据集平均值的模糊结果。

#@markdown Sampling strategy: Break the process into 5 steps and move 1/5'th of the way there each time:
n_steps = 5
x = torch.rand(8, 1, 28, 28).to(device)  # 8 samples of uniform noise as the starting point
step_history = [x.detach().cpu()]  # model input at every step, plus the final result
pred_output_history = []  # the model's denoised prediction at every step

for i in range(n_steps):
    # Inference only, so skip building an autograd graph
    with torch.no_grad():
        pred = net(x)
    pred_output_history.append(pred.detach().cpu())

    # With k = n_steps - i steps left, cover 1/k of the remaining distance;
    # on the last step the weight is 1, so x becomes the prediction itself.
    steps_left = n_steps - i
    mix_factor = 1 / steps_left
    x = x * (1 - mix_factor) + pred * mix_factor
    step_history.append(x.detach().cpu())

# Left column: what the model saw; right column: what it predicted
fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
axs[0, 0].set_title('x (model input)')
axs[0, 1].set_title('model prediction')
for row, model_in, model_out in zip(axs, step_history, pred_output_history):
    for cell, images in ((row[0], model_in), (row[1], model_out)):
        cell.imshow(torchvision.utils.make_grid(images)[0].clip(0, 1), cmap='Greys')

输出结果

在这里插入图片描述

#@markdown Showing more results, using 40 sampling steps
n_steps = 40
x = torch.rand(64, 1, 28, 28).to(device)  # 64 samples, starting from uniform noise
for i in range(n_steps):
    # BasicUNet.forward takes only the image, so no per-step noise level is
    # computed or passed in.
    with torch.no_grad():
        pred = net(x)
    # Move 1/(remaining steps) of the way towards the prediction; on the final
    # step mix_factor == 1, so x lands exactly on the model's output.
    mix_factor = 1 / (n_steps - i)
    x = x * (1 - mix_factor) + pred * mix_factor
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')

输出结果

在这里插入图片描述

虽然不是很好,但也有一些可识别的数字!你可以尝试延长训练时间(比如 10 或 20 个 epoch),并调整模型配置、学习率、优化器等。此外,如果你想尝试难度稍高的数据集,可以使用 FashionMNIST 数据集。

参考

[1] https://huggingface.co/learn/diffusion-course/

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

FriendshipT

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值