SVGenius-main\src\generation\style_trans
score_rubrics.json
[
{
"criteria_description": "Style Consistency",
"score_1": "The transformed image completely fails to adopt the target style. Elements are rendered in the original style with no visible stylistic transformation.",
"score_2": "The transformed image shows minimal attempt at the target style, but most elements remain in the original style. Only superficial changes are visible.",
"score_3": "The transformed image partially adopts the target style. Some elements show clear stylistic transformation while others remain inconsistent.",
"score_4": "The transformed image mostly achieves the target style with good consistency. Minor elements may not fully align with the target style.",
"score_5": "The transformed image perfectly embodies the target style. All elements consistently reflect the stylistic characteristics with professional quality."
},
{
"criteria_description": "Content Preservation",
"score_1": "The transformed image completely loses the original content. Key elements, structure, or semantic meaning are unrecognizable.",
"score_2": "The transformed image preserves very little of the original content. Major structural elements are missing or severely distorted.",
"score_3": "The transformed image preserves some original content, but important details or structural elements are altered or lost.",
"score_4": "The transformed image preserves most original content with good fidelity. Minor details may differ from the original.",
"score_5": "The transformed image perfectly preserves all original content, structure, and semantic meaning while applying the target style."
},
{
"criteria_description": "Visual Quality",
"score_1": "The transformed image has severe visual artifacts, distortions, or rendering issues. The overall visual quality is unacceptable.",
"score_2": "The transformed image has noticeable visual problems such as artifacts, poor rendering, or unappealing aesthetics.",
"score_3": "The transformed image has acceptable visual quality but shows some imperfections in rendering or overall aesthetics.",
"score_4": "The transformed image has good visual quality with clean rendering and appealing aesthetics. Minor imperfections may exist.",
"score_5": "The transformed image has excellent visual quality with professional-grade rendering, smooth aesthetics, and no visible artifacts."
}
]
svg_to_png.py
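"""
Helper script: convert the SVG files produced by style_trans.py to PNG
so that evaluation.py can consume them (evaluation.py needs PNG paths for base64 encoding).
Usage:
cd src
python -m generation.style_trans.svg_to_png results/style_trans/easy_gen.json
Input: the output JSON of style_trans.py (containing output_paths.gen_svg and related fields)
Output: a _png.json file in the same directory, in which transferred_image_path has been replaced with the PNG path
"""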
import cairosvg
import json
import os
import sys
def _with_png_suffix(svg_path):
    """Return *svg_path* with only its final extension swapped for ``.png``."""
    root, _ext = os.path.splitext(svg_path)
    return root + '.png'


def _png_json_path(gen_json):
    """Return the ``<name>_png.json`` sibling path for the input JSON path."""
    root, ext = os.path.splitext(gen_json)
    # Fall back to ".json" so an extension-less input never maps onto itself.
    return f"{root}_png{ext or '.json'}"


def main():
    """Convert the SVGs referenced by a style_trans result JSON to PNG.

    Reads the JSON file named by ``sys.argv[1]``, and for every entry of its
    ``results`` list renders ``output_paths.gt_svg`` / ``output_paths.gen_svg``
    (when the file exists) to a PNG next to the SVG via ``cairosvg``.  The PNG
    path is recorded on the item as ``gt_svg_png`` / ``gen_svg_png``, and
    ``transferred_image_path`` is set to the PNG path derived from ``gen_svg``.
    The updated data is written to ``<input>_png.json``.

    Paths are derived with ``os.path.splitext`` so that only the final
    extension changes; a ``.svg`` / ``.json`` substring inside a directory
    name is left alone and the input files are never overwritten.

    Exits with status 1 when no input path is given on the command line.
    """
    if len(sys.argv) < 2:
        print("用法: python svg_to_png.py <style_trans_output.json>")
        sys.exit(1)

    gen_json = sys.argv[1]
    with open(gen_json, 'r', encoding='utf-8') as f:
        data = json.load(f)

    converted = 0
    for item in data.get('results', []):
        paths = item.get('output_paths', {})
        for key in ['gt_svg', 'gen_svg']:
            svg_path = paths.get(key, '')
            if not (svg_path and os.path.exists(svg_path)):
                continue
            png_path = _with_png_suffix(svg_path)
            try:
                cairosvg.svg2png(url=svg_path, write_to=png_path)
                item[f'{key}_png'] = png_path
                converted += 1
                print(f" [OK] {svg_path} -> {png_path}")
            except Exception as e:
                # Best effort: report the failing file and keep converting the rest.
                print(f" [FAIL] {svg_path}: {e}")

        # Point transferred_image_path at the PNG counterpart of gen_svg.
        gen_svg = paths.get('gen_svg', '')
        if gen_svg:
            item['transferred_image_path'] = _with_png_suffix(gen_svg)

    output_json = _png_json_path(gen_json)
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    print(f"\n转换完成: {converted} 个 SVG -> PNG")
    print(f"输出文件: {output_json}")
    print("请将此文件作为 evaluation.py 的 --input 参数")


if __name__ == '__main__':
    main()
./style_trans/easy_icons/svg_captions.json
[
{
"image_path": "../../../data/easy/process/page_27_033-线面常用_47315_icon_7.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/easy/process/page_381_交通线图标_20160_icon_82.png",
"question": "3D-style"
},
{
"image_path": "../../../data/easy/process/page_43_zpy_icon_excel工作站_44295_icon_10.png",
"question": "Line-art"
},
{
"image_path": "../../../data/easy/process/page_397_多色施工类icon_6158_icon_21.png",
"question": "Line-art"
},
{
"image_path": "../../../data/easy/process/page_328_饰品icon_10850_icon_14.png",
"question": "Line-art"
},
{
"image_path": "../../../data/easy/process/page_42_一张图标签_44528_icon_46.png",
"question": "3D-style"
},
{
"image_path": "../../../data/easy/process/page_67_交通工具_40671_icon_19.png",
"question": "Cartoon Style"
},
{
"image_path": "../../../data/easy/process/page_32_酷趣创易_电商_icon031_46250_icon_2.png",
"question": "3D-style"
},
{
"image_path": "../../../data/easy/process/page_171_天气_22770_icon_23.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/easy/process/page_202_特步_20790_icon_12.png",
"question": "3D-style"
},
{
"image_path": "../../../data/easy/process/page_204_全棉时代_20750_icon_6.png",
"question": "3D-style"
},
{
"image_path": "../../../data/easy/process/page_36_008-美食_45398_icon_18.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/easy/process/page_307_双色线性ICON_12087_icon_4.png",
"question": "Cartoon Style"
},
{
"image_path": "../../../data/easy/process/page_78_互联网常用_38294_icon_5.png",
"question": "Cartoon Style"
},
{
"image_path": "../../../data/easy/process/page_42_面性双色图标_44434_icon_14.png",
"question": "Line-art"
},
{
"image_path": "../../../data/easy/process/page_72_工作台_39755_icon_36.png",
"question": "Cartoon Style"
},
{
"image_path": "../../../data/easy/process/服务logo_46449_icon_16.png",
"question": "Cartoon Style"
},
{
"image_path": "../../../data/easy/process/page_79_易魔方常用面性图标库_38236_icon_40.png",
"question": "Cartoon Style"
},
{
"image_path": "../../../data/easy/process/page_50_035-常用线面_47484_icon_25.png",
"question": "Line-art"
},
{
"image_path": "../../../data/easy/process/page_284_蜗牛小程序_25391_icon_104.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/easy/process/page_175_美妆图标_35309_icon_14.png",
"question": "3D-style"
}
]
./style_trans/complex_icons/svg_captions.json
[
{
"image_path": "../../../data/medium/process/page_51_小假哥_多彩_icon104_42838_icon_3.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/medium/process/page_372_春节系列图标_7969_icon_7.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/medium/process/page_332_花店通用icon_10772_icon_9.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/medium/process/page_397_多色施工类icon_6158_icon_9.png",
"question": "Line-art"
},
{
"image_path": "../../../data/medium/process/page_89_线性旅行图标_36037_icon_19.png",
"question": "3D-style"
},
{
"image_path": "../../../data/medium/process/page_72_32位像素风emoji_39823_icon_13.png",
"question": "3D-style"
},
{
"image_path": "../../../data/medium/process/page_12_教育学习_48779_icon_17.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/medium/process/page_10_旅行旅游景点_48951_icon_5.png",
"question": "3D-style"
},
{
"image_path": "../../../data/medium/process/page_392_扁平化科学领域图标_7084_icon_84.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/medium/process/page_207_衣服饰品图标_20559_icon_72.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/medium/process/page_54_小假哥_多彩_icon070_42746_icon_25.png",
"question": "Line-art"
},
{
"image_path": "../../../data/medium/process/page_268_美食icon_14970_icon_6.png",
"question": "Cartoon Style"
},
{
"image_path": "../../../data/medium/process/page_392_扁平化科学领域图标_7084_icon_41.png",
"question": "3D-style"
},
{
"image_path": "../../../data/medium/process/page_394_礼物多色小图标_6623_icon_18.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/medium/process/page_309_水果_11820_icon_21.png",
"question": "3D-style"
},
{
"image_path": "../../../data/medium/process/page_33_旅游_46080_icon_59.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/medium/process/page_61_小假哥_多彩_icon022_42149_icon_14.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/medium/process/page_208_食物多色图标_20440_icon_22.png",
"question": "Cartoon Style"
},
{
"image_path": "../../../data/medium/process/page_208_常购商品_20444_icon_4.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/medium/process/page_701_小icon_905_icon_37.png",
"question": "Line-art"
},
{
"image_path": "../../../data/medium/process/page_412_美味的食物_2134_icon_68.png",
"question": "Line-art"
},
{
"image_path": "../../../data/medium/process/page_393_面性图标_6996_icon_83.png",
"question": "Line-art"
},
{
"image_path": "../../../data/medium/process/page_676_中间页前链_1653_icon_5.png",
"question": "Line-art"
},
{
"image_path": "../../../data/medium/process/page_208_商家活动页设计大赛图标_20505_icon_15.png",
"question": "Line-art"
},
{
"image_path": "../../../data/medium/process/page_53_小假哥_多彩_icon083_42776_icon_262.png",
"question": "Line-art"
},
{
"image_path": "../../../data/medium/process/page_74_政务信息_39417_icon_48.png",
"question": "Cartoon Style"
},
{
"image_path": "../../../data/medium/process/page_59_天气tianqi_42524_icon_12.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/medium/process/page_228_工作圈_18217_icon_12.png",
"question": "Line-art"
},
{
"image_path": "../../../data/medium/process/page_19_赛博图标_49623_icon_14.png",
"question": "Line-art"
},
{
"image_path": "../../../data/medium/process/page_92_美妆图标_35312_icon_6.png",
"question": "Pixel-art"
}
]
./style_trans/illustrations/svg_captions.json
[
{
"image_path": "../../../data/hard/process/page_37_教育图标_45273_icon_0.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_92_线性游戏兴趣爱好_35314_icon_5.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_92_教育_35436_icon_17.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_34_树木相关森林图标_45903_icon_11.png",
"question": "3D-style"
},
{
"image_path": "../../../data/hard/process/page_413_互联网系列_2001_icon_15.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_393_食物2_6971_icon_17.png",
"question": "Cartoon Style"
},
{
"image_path": "../../../data/hard/process/page_95_手绘人物_34981_icon_5.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_297_研究生系统_24310_icon_11.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_70_甜甜圈_40334_icon_4.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_220_IT软件开发职业_30298_icon_29.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_300_马上创业网_12940_icon_61.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_11_肖像_48949_icon_2.png",
"question": "Line-art"
},
{
"image_path": "../../../data/hard/process/page_59_婴幼玩具_42427_icon_2.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_11_酷趣创易_美食_icon086_48856_icon_8.png",
"question": "Line-art"
},
{
"image_path": "../../../data/hard/process/page_70_像素鞋子图标_40194_icon_12.png",
"question": "Cartoon Style"
},
{
"image_path": "../../../data/hard/process/page_98_牙齿类彩色图标_34256_icon_45.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_3_海边商业休闲元素_50207_icon_24.png",
"question": "3D-style"
},
{
"image_path": "../../../data/hard/process/page_14_商务_48561_icon_44.png",
"question": "Line-art"
},
{
"image_path": "../../../data/hard/process/page_33_美食大全_46021_icon_0.png",
"question": "Line-art"
},
{
"image_path": "../../../data/hard/process/page_412_食物图标_2108_icon_43.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_7_花_49471_icon_9.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_10_06-植物icon_49104_icon_16.png",
"question": "3D-style"
},
{
"image_path": "../../../data/hard/process/page_3_健身器材_50284_icon_19.png",
"question": "Cartoon Style"
},
{
"image_path": "../../../data/hard/process/page_251_费哲软件IWMS_27516_icon_60.png",
"question": "Cartoon Style"
},
{
"image_path": "../../../data/hard/process/page_79_漫威英雄_38147_icon_15.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_54_小假哥_债务_icon077_42753_icon_11.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_30_线性3d图标_46725_icon_56.png",
"question": "Line-art"
},
{
"image_path": "../../../data/hard/process/page_309_水果_11820_icon_15.png",
"question": "Pixel-art"
},
{
"image_path": "../../../data/hard/process/page_395_设计器图标_6517_icon_5.png",
"question": "Cartoon Style"
}
]
download_models_offline.py
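"""
SVGenius offline model pre-download script
==========================================
Purpose: run this script on a machine with network access to download the 6 CV models required for evaluation ahead of time.
When the evaluation is later run offline, the models are loaded from the local cache and no network access is needed.
Models downloaded:
1. CLIP ViT-L/14 → ~/.cache/clip/ (~890MB)
2. aesthetic-predictor-v2-5 → ~/.cache/huggingface/ (~1.5GB)
3. HPSv2 v2.1 → managed automatically by the hpsv2 library (~3GB)
4. DINOv2-base → ./models/dinov2-base/ (~350MB)
5. LPIPS (VGG) → ~/.cache/torch/hub/ (~50MB)
6. InceptionV3 (FID) → ~/.cache/torch/hub/ (~100MB)
Usage (on a machine with network access):
python download_models_offline.py [--target-dir ./models]
Then copy the whole SVGenius directory, together with the cache directories listed in the deployment guide printed at the end of the run, to the offline machine.
"""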
import os
import sys
import argparse
def download_clip():
    """Fetch the CLIP ViT-L/14 weights by loading the model once.

    The loaded model and preprocess transform are discarded; the call is made
    only for its download side effect.  The success message reports
    ``~/.cache/clip/`` as the cache location.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("[1/6] 下载 CLIP ViT-L/14 模型 (~890MB)")
    print(rule)

    import clip
    import torch

    if torch.cuda.is_available():
        load_device = 'cuda'
    else:
        load_device = 'cpu'
    clip.load("ViT-L/14", device=load_device, jit=False)
    print(" ✓ CLIP ViT-L/14 下载完成,缓存于 ~/.cache/clip/")
def download_aesthetic():
    """Fetch the aesthetic-predictor-v2-5 scoring model.

    Calls ``convert_v2_5_from_siglip`` once so that the weights are downloaded;
    the returned model and preprocessor are not used.  The success message
    reports ``~/.cache/huggingface/`` as the cache location.
    """
    print("\n" + "=" * 60)
    print("[2/6] 下载 aesthetic-predictor-v2-5 美学评分模型 (~1.5GB)")
    print("=" * 60)

    from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip

    # Only the download matters here, so no device is selected and the
    # returned objects are dropped immediately.
    convert_v2_5_from_siglip(
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    print(" ✓ aesthetic-predictor-v2-5 下载完成,缓存于 ~/.cache/huggingface/")
def download_hpsv2():
    """Trigger the HPSv2 v2.1 download by scoring a blank test image.

    A scoring failure is reported as a warning rather than raised: the
    messages treat it as possibly normal once the download has been triggered.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("[3/6] 下载 HPSv2 v2.1 人类偏好模型 (~3GB)")
    print(rule)

    import hpsv2
    from PIL import Image

    # Scoring any image forces the library to fetch its weights.
    probe_image = Image.new('RGB', (224, 224), color='white')
    try:
        probe_score = hpsv2.score(probe_image, "a simple white image", hps_version="v2.1")
        print(f" ✓ HPSv2 v2.1 下载完成,测试得分: {probe_score}")
    except Exception as err:
        print(f" ⚠ HPSv2 测试评分失败(可能是正常的): {err}")
        print(" ✓ 模型已触发下载,缓存于 ~/.cache/huggingface/")
def download_dinov2(target_dir):
    """Download DINOv2-base and save it under ``<target_dir>/dinov2-base``.

    Args:
        target_dir: Directory in which the ``dinov2-base`` folder is created.

    Both the model and its image processor are loaded from
    ``facebook/dinov2-base`` and written with ``save_pretrained`` so they can
    later be loaded from that directory.
    """
    print("\n" + "=" * 60)
    print("[4/6] 下载 DINOv2-base 模型 (~350MB)")
    print("=" * 60)

    from transformers import AutoModel, AutoImageProcessor

    dinov2_dir = os.path.join(target_dir, "dinov2-base")
    os.makedirs(dinov2_dir, exist_ok=True)
    print(f" 目标目录: {dinov2_dir}")

    model = AutoModel.from_pretrained("facebook/dinov2-base")
    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    model.save_pretrained(dinov2_dir)
    processor.save_pretrained(dinov2_dir)
    print(f" ✓ DINOv2-base 已保存到 {dinov2_dir}")
def download_lpips():
    """Fetch the LPIPS weights (VGG backbone) by instantiating the model once.

    The instance is moved to CUDA when available, otherwise CPU, and then
    discarded.  The success message reports ``~/.cache/torch/hub/`` as the
    cache location.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("[5/6] 下载 LPIPS (VGG backbone) 模型 (~50MB)")
    print(rule)

    import lpips
    import torch

    if torch.cuda.is_available():
        load_device = 'cuda'
    else:
        load_device = 'cpu'
    lpips.LPIPS(net='vgg').to(load_device)
    print(" ✓ LPIPS (VGG) 下载完成,缓存于 ~/.cache/torch/hub/")
def download_inception(target_dir):
    """Download the InceptionV3 weights used for FID.

    Args:
        target_dir: Directory in which the ``fid`` folder is created.

    Two downloads happen: the default torchvision InceptionV3 weights, and the
    pytorch-fid Inception state dict, which is saved to
    ``<target_dir>/fid/pt_inception-2015-12-05-6726825d.pth`` unless that file
    already exists.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("[6/6] 下载 InceptionV3 FID 权重 (~100MB)")
    print(rule)

    import torch
    from torchvision import models
    from torch.hub import load_state_dict_from_url

    # Step 1: the standard torchvision weights; the model object is not kept.
    models.inception_v3(weights='DEFAULT')
    print(" ✓ torchvision InceptionV3 权重下载完成")

    # Step 2: the FID-specific Inception weights, stored under target_dir.
    FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
    fid_dir = os.path.join(target_dir, "fid")
    os.makedirs(fid_dir, exist_ok=True)
    fid_path = os.path.join(fid_dir, "pt_inception-2015-12-05-6726825d.pth")

    if os.path.exists(fid_path):
        print(f" ✓ FID 权重已存在: {fid_path}")
        return

    print(f" 下载 FID Inception 权重到 {fid_path}...")
    fid_weights = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    torch.save(fid_weights, fid_path)
    print(f" ✓ FID 权重已保存到 {fid_path}")
def print_offline_guide(target_dir):
    """Print the offline deployment guide.

    Args:
        target_dir: Model directory; its ``dinov2-base`` and ``fid``
            sub-directories are shown as absolute source paths.

    The guide lists the directories to copy to the offline machine and the
    environment variables to set there.
    """
    rule = "=" * 60
    print("\n" + rule)
    print(" 离线部署指南")
    print(rule)

    # Resolve every interpolated path up front so the template stays readable.
    clip_cache = os.path.expanduser('~/.cache/clip/')
    hf_hub_cache = os.path.expanduser('~/.cache/huggingface/hub/')
    torch_hub_cache = os.path.expanduser('~/.cache/torch/hub/')
    dinov2_source = os.path.abspath(os.path.join(target_dir, 'dinov2-base'))
    fid_source = os.path.abspath(os.path.join(target_dir, 'fid'))

    print(f"""
要将模型部署到离线机器,需要复制以下目录:
1. CLIP ViT-L/14 (~/.cache/clip/)
源路径: {clip_cache}
目标路径: 离线机器的相同路径 (~/.cache/clip/)
2. HuggingFace 模型缓存 (~/.cache/huggingface/)
包含: aesthetic-predictor-v2-5, HPSv2
源路径: {hf_hub_cache}
目标路径: 离线机器的相同路径
3. DINOv2-base ({target_dir}/dinov2-base/)
源路径: {dinov2_source}
目标路径: SVGenius/src/metrics/models/dinov2-base/(相对项目根目录)
4. PyTorch Hub 缓存 (~/.cache/torch/hub/)
包含: LPIPS (VGG), InceptionV3
源路径: {torch_hub_cache}
目标路径: 离线机器的相同路径
5. FID 专用权重 ({target_dir}/fid/)
源路径: {fid_source}
目标路径: SVGenius/src/metrics/models/fid/
====================================================================
重要:离线机器上需设置环境变量以禁用网络访问
====================================================================
export HF_HUB_OFFLINE=1 # 禁用 HuggingFace 在线下载
export TRANSFORMERS_OFFLINE=1 # 禁用 Transformers 在线下载
export HF_DATASETS_OFFLINE=1 # 禁用 Datasets 在线下载
或者直接在 Python 代码中设置:
import os
os.environ['HF_HUB_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'
====================================================================
""")
def main():
    """Download every evaluation model that was not skipped, then print the guide.

    Command-line options:
        --target-dir: directory for models saved by this script (default ``./models``).
        --skip: zero or more model names to skip.

    Each download runs independently, so one failure does not stop the
    remaining downloads.  Failed model names are collected, summarised after
    the deployment guide, and make the process exit with status 1 so that a
    partial download is not mistaken for a complete one.
    """
    parser = argparse.ArgumentParser(description="SVGenius 离线模型预下载脚本")
    parser.add_argument(
        "--target-dir", type=str, default="./models",
        help="模型保存目录 (默认: ./models,相对于 SVGenius 项目根目录)"
    )
    parser.add_argument(
        "--skip", type=str, nargs="*", default=[],
        choices=["clip", "aesthetic", "hpsv2", "dinov2", "lpips", "inception"],
        help="跳过的模型 (可选: clip, aesthetic, hpsv2, dinov2, lpips, inception)"
    )
    args = parser.parse_args()

    target_dir = os.path.abspath(args.target_dir)
    os.makedirs(target_dir, exist_ok=True)

    print("=" * 60)
    print(" SVGenius 离线模型预下载")
    print(f" 保存目录: {target_dir}")
    print("=" * 60)
    print("\n需要下载 6 个模型,总大小约 6GB,请确保磁盘空间充足。\n")

    # Insertion order is the download order and matches the [n/6] banners.
    models_to_download = {
        "clip": download_clip,
        "aesthetic": download_aesthetic,
        "hpsv2": download_hpsv2,
        "dinov2": lambda: download_dinov2(target_dir),
        "lpips": download_lpips,
        "inception": lambda: download_inception(target_dir),
    }

    failed = []
    for name, func in models_to_download.items():
        if name in args.skip:
            print(f"\n ⏭ 跳过 {name}")
            continue
        try:
            func()
        except Exception as e:
            # Keep going so the other models still download; remember the failure.
            failed.append(name)
            print(f"\n ✗ {name} 下载失败: {e}")
            print(f" 请检查网络连接后重试,或使用 --skip {name} 跳过此模型")

    print_offline_guide(target_dir)

    if failed:
        print(f"\n ✗ 以下模型下载失败: {', '.join(failed)}")
        sys.exit(1)


if __name__ == "__main__":
    main()
\SVGenius-main\src\metrics\compute_clip_score.py
import os
class CLIPScoreCalculator(BaseMetric):
    """Metric object that owns a loaded CLIP ViT-L/14 model and its preprocess transform."""

    def __init__(self, download_root=None):
        """Load CLIP ViT-L/14 onto CUDA when available, otherwise CPU.

        Args:
            download_root: Directory passed to ``clip.load`` as the weight
                location.  When ``None``, the ``models/clip`` folder next to
                this file is used if it contains ``ViT-L-14.pt``; otherwise
                ``~/.cache/clip`` is used.
        """
        super().__init__()
        self.class_name = self.__class__.__name__
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        if download_root is None:
            # Prefer weights shipped inside the project over the user cache.
            bundled_dir = os.path.join(os.path.dirname(__file__), 'models', 'clip')
            bundled_weights = os.path.join(bundled_dir, 'ViT-L-14.pt')
            if os.path.exists(bundled_weights):
                download_root = bundled_dir
            else:
                download_root = os.path.expanduser("~/.cache/clip")

        self.clip_model, self.preprocess = clip.load(
            "ViT-L/14",
            device=self.device,
            jit=False,
            download_root=download_root,
        )

        # Cast to float32 on CPU; elsewhere apply CLIP's own weight conversion.
        if self.device != "cpu":
            clip.model.convert_weights(self.clip_model)
        else:
            self.clip_model.float()

        # The logit scale is kept frozen.
        self.clip_model.logit_scale.requires_grad_(False)
SVGenius-main\src\metrics\compute_dino_score.py
class DINOScoreCalculator(BaseMetric):
    """Metric object that owns a DINOv2 model and its image processor."""

    def __init__(self, config=None, device='cuda'):
        """Load the DINOv2 "base" model and move it to *device*.

        Args:
            config: Stored on the instance as ``self.config``.
            device: Device the model is moved to (default ``'cuda'``).

        ``self.metric`` is bound to ``calculate_DINOv2_similarity_score``,
        which is defined outside the lines shown here.
        """
        super().__init__()
        self.class_name = self.__class__.__name__
        self.config = config
        self.model, self.processor = self.get_DINOv2_model("base")
        self.model = self.model.to(device)
        self.device = device
        self.metric = self.calculate_DINOv2_similarity_score

    def get_DINOv2_model(self, model_size):
        """Load a DINOv2 model and image processor from local files only.

        Args:
            model_size: One of ``"small"``, ``"base"`` or ``"large"``.

        Returns:
            A ``(model, processor)`` tuple.

        Raises:
            ValueError: If *model_size* is not a supported size.

        The size is validated before any path lookup, and the path is specific
        to the requested size, so a different size is never silently replaced
        by the base model.  Lookup order: ``models/dinov2-<size>`` next to this
        file, then ``./dinov2-<size>``, then the ``facebook/dinov2-<size>`` hub
        id.  All loads pass ``local_files_only=True``, so the hub id is only
        resolved from files already present locally.
        """
        # Function-scope import so the lookup does not depend on a module-level import.
        import os

        if model_size not in ("small", "base", "large"):
            raise ValueError(f"model_size should be either 'small', 'base' or 'large', got {model_size}")

        dir_name = f"dinov2-{model_size}"
        candidates = (
            os.path.join(os.path.dirname(__file__), 'models', dir_name),
            f"./{dir_name}",
        )
        model_path = next(
            (path for path in candidates if os.path.isdir(path)),
            f"facebook/{dir_name}",
        )

        model = AutoModel.from_pretrained(model_path, local_files_only=True)
        processor = AutoImageProcessor.from_pretrained(model_path, local_files_only=True)
        return model, processor
SVGenius-main\src\metrics\compute_aesthetic_score.py
class AestheticScoreMetric(BaseMetric):
    """Metric object that owns the aesthetic-predictor-v2-5 model and preprocessor."""

    def __init__(self, batch_size=1):
        """Load the aesthetic predictor onto a single device.

        Args:
            batch_size: Stored on the instance as ``self.batch_size``.

        The model is first loaded with the default settings.  If that raises,
        the reason is printed and the load is retried with
        ``local_files_only=True``; an error from the retry propagates to the
        caller.
        """
        super().__init__()
        self.class_name = self.__class__.__name__

        # Use a single device instead of device_map="auto".
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        print("Loading model...")
        try:
            self.model, self.preprocessor = convert_v2_5_from_siglip(
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
        except Exception as exc:
            # Offline fallback: say why the first attempt failed instead of
            # hiding it, then retry using only locally available files.
            import os
            print(f"Model load failed ({exc!r}); retrying with local files only...")
            # NOTE(review): huggingface_hub may read HF_HUB_OFFLINE at import
            # time, in which case setting it here comes too late to affect
            # this process - confirm.
            os.environ.setdefault('HF_HUB_OFFLINE', '1')
            self.model, self.preprocessor = convert_v2_5_from_siglip(
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                local_files_only=True,
            )

        # Explicitly move the whole model to the selected device as float32.
        self.model = self.model.float().to(self.device)
        print("Model loading complete!")

        # Ensure the model is in evaluation mode.
        self.model.eval()
        self.batch_size = batch_size
SVGenius-main\src\metrics\inception.py
import os
# 220
# 优先从本地文件加载 FID 权重(离线兼容)
local_fid_path = os.path.join(os.path.dirname(__file__), 'models', 'fid', 'pt_inception-2015-12-05-6726825d.pth')
if os.path.exists(local_fid_path):
state_dict = torch.load(local_fid_path, map_location='cpu', weights_only=True)
else:
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
inception.load_state_dict(state_dict)
return inception

2457

被折叠的 条评论
为什么被折叠?



