SVGenius项目,配套资料

SVGenius-main\src\generation\style_trans
score_rubrics.json

[
  {
    "criteria_description": "Style Consistency",
    "score_1": "The transformed image completely fails to adopt the target style. Elements are rendered in the original style with no visible stylistic transformation.",
    "score_2": "The transformed image shows minimal attempt at the target style, but most elements remain in the original style. Only superficial changes are visible.",
    "score_3": "The transformed image partially adopts the target style. Some elements show clear stylistic transformation while others remain inconsistent.",
    "score_4": "The transformed image mostly achieves the target style with good consistency. Minor elements may not fully align with the target style.",
    "score_5": "The transformed image perfectly embodies the target style. All elements consistently reflect the stylistic characteristics with professional quality."
  },
  {
    "criteria_description": "Content Preservation",
    "score_1": "The transformed image completely loses the original content. Key elements, structure, or semantic meaning are unrecognizable.",
    "score_2": "The transformed image preserves very little of the original content. Major structural elements are missing or severely distorted.",
    "score_3": "The transformed image preserves some original content, but important details or structural elements are altered or lost.",
    "score_4": "The transformed image preserves most original content with good fidelity. Minor details may differ from the original.",
    "score_5": "The transformed image perfectly preserves all original content, structure, and semantic meaning while applying the target style."
  },
  {
    "criteria_description": "Visual Quality",
    "score_1": "The transformed image has severe visual artifacts, distortions, or rendering issues. The overall visual quality is unacceptable.",
    "score_2": "The transformed image has noticeable visual problems such as artifacts, poor rendering, or unappealing aesthetics.",
    "score_3": "The transformed image has acceptable visual quality but shows some imperfections in rendering or overall aesthetics.",
    "score_4": "The transformed image has good visual quality with clean rendering and appealing aesthetics. Minor imperfections may exist.",
    "score_5": "The transformed image has excellent visual quality with professional-grade rendering, smooth aesthetics, and no visible artifacts."
  }
]

svg_to_png.py

"""
辅助脚本:将 style_trans.py 生成的 SVG 文件转换为 PNG,
以便 evaluation.py 使用(evaluation.py 需要 PNG 路径进行 base64 编码)。

用法:
    cd src
    python -m generation.style_trans.svg_to_png results/style_trans/easy_gen.json

输入:style_trans.py 的输出 JSON(含 output_paths.gen_svg 等字段)
输出:同目录下 _png.json 文件,其中 transferred_image_path 已替换为 PNG 路径
"""

import cairosvg
import json
import os
import sys


def _swap_extension(path, new_ext):
    """Return ``path`` with only its final extension replaced by ``new_ext``."""
    stem, _old_ext = os.path.splitext(path)
    return stem + new_ext


def main():
    """Convert the SVG files referenced by a style_trans result JSON to PNG.

    Reads the JSON file named by ``sys.argv[1]``. For every entry of its
    ``results`` list, the ``gt_svg`` and ``gen_svg`` paths found under
    ``output_paths`` are rendered to PNG next to the SVG. Each successful
    conversion is recorded on the entry as ``gt_svg_png`` / ``gen_svg_png``,
    and a successful ``gen_svg`` conversion also sets
    ``transferred_image_path`` to the PNG path. The updated data is written
    to ``<input stem>_png.json`` beside the input file.

    Exits with status 1 when no input path is given.
    """
    if len(sys.argv) < 2:
        print("用法: python svg_to_png.py <style_trans_output.json>")
        sys.exit(1)

    gen_json = sys.argv[1]
    with open(gen_json, 'r', encoding='utf-8') as f:
        data = json.load(f)

    converted = 0
    for item in data.get('results', []):
        # Tolerate a missing or null ``output_paths`` field.
        paths = item.get('output_paths') or {}
        for key in ('gt_svg', 'gen_svg'):
            svg_path = paths.get(key, '')
            if not svg_path or not os.path.exists(svg_path):
                continue

            # Swap only the final extension: a plain str.replace would also
            # rewrite ".svg" inside directory names, and would leave the path
            # unchanged (overwriting the SVG itself) for e.g. ".SVG".
            png_path = _swap_extension(svg_path, '.png')
            try:
                cairosvg.svg2png(url=svg_path, write_to=png_path)
            except Exception as e:
                # Best effort: report the failure and carry on with the rest.
                print(f"  [FAIL] {svg_path}: {e}")
                continue

            item[f'{key}_png'] = png_path
            converted += 1
            print(f"  [OK] {svg_path} -> {png_path}")

            # Point transferred_image_path at the PNG only when that PNG was
            # actually produced, so downstream code is never handed a path
            # to a file that does not exist.
            if key == 'gen_svg':
                item['transferred_image_path'] = png_path

    # Derive the output name from the file stem so the input file is never
    # overwritten, even when it lacks a lowercase ".json" suffix.
    output_json = os.path.splitext(gen_json)[0] + '_png.json'
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    print(f"\n转换完成: {converted} 个 SVG -> PNG")
    print(f"输出文件: {output_json}")
    print("请将此文件作为 evaluation.py 的 --input 参数")


if __name__ == '__main__':
    main()

./style_trans/easy_icons/svg_captions.json

[
  {
    "image_path": "../../../data/easy/process/page_27_033-线面常用_47315_icon_7.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/easy/process/page_381_交通线图标_20160_icon_82.png",
    "question": "3D-style"
  },
  {
    "image_path": "../../../data/easy/process/page_43_zpy_icon_excel工作站_44295_icon_10.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/easy/process/page_397_多色施工类icon_6158_icon_21.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/easy/process/page_328_饰品icon_10850_icon_14.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/easy/process/page_42_一张图标签_44528_icon_46.png",
    "question": "3D-style"
  },
  {
    "image_path": "../../../data/easy/process/page_67_交通工具_40671_icon_19.png",
    "question": "Cartoon Style"
  },
  {
    "image_path": "../../../data/easy/process/page_32_酷趣创易_电商_icon031_46250_icon_2.png",
    "question": "3D-style"
  },
  {
    "image_path": "../../../data/easy/process/page_171_天气_22770_icon_23.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/easy/process/page_202_特步_20790_icon_12.png",
    "question": "3D-style"
  },
  {
    "image_path": "../../../data/easy/process/page_204_全棉时代_20750_icon_6.png",
    "question": "3D-style"
  },
  {
    "image_path": "../../../data/easy/process/page_36_008-美食_45398_icon_18.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/easy/process/page_307_双色线性ICON_12087_icon_4.png",
    "question": "Cartoon Style"
  },
  {
    "image_path": "../../../data/easy/process/page_78_互联网常用_38294_icon_5.png",
    "question": "Cartoon Style"
  },
  {
    "image_path": "../../../data/easy/process/page_42_面性双色图标_44434_icon_14.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/easy/process/page_72_工作台_39755_icon_36.png",
    "question": "Cartoon Style"
  },
  {
    "image_path": "../../../data/easy/process/服务logo_46449_icon_16.png",
    "question": "Cartoon Style"
  },
  {
    "image_path": "../../../data/easy/process/page_79_易魔方常用面性图标库_38236_icon_40.png",
    "question": "Cartoon Style"
  },
  {
    "image_path": "../../../data/easy/process/page_50_035-常用线面_47484_icon_25.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/easy/process/page_284_蜗牛小程序_25391_icon_104.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/easy/process/page_175_美妆图标_35309_icon_14.png",
    "question": "3D-style"
  }
]

./style_trans/complex_icons/svg_captions.json

[
  {
    "image_path": "../../../data/medium/process/page_51_小假哥_多彩_icon104_42838_icon_3.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/medium/process/page_372_春节系列图标_7969_icon_7.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/medium/process/page_332_花店通用icon_10772_icon_9.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/medium/process/page_397_多色施工类icon_6158_icon_9.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/medium/process/page_89_线性旅行图标_36037_icon_19.png",
    "question": "3D-style"
  },
  {
    "image_path": "../../../data/medium/process/page_72_32位像素风emoji_39823_icon_13.png",
    "question": "3D-style"
  },
  {
    "image_path": "../../../data/medium/process/page_12_教育学习_48779_icon_17.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/medium/process/page_10_旅行旅游景点_48951_icon_5.png",
    "question": "3D-style"
  },
  {
    "image_path": "../../../data/medium/process/page_392_扁平化科学领域图标_7084_icon_84.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/medium/process/page_207_衣服饰品图标_20559_icon_72.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/medium/process/page_54_小假哥_多彩_icon070_42746_icon_25.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/medium/process/page_268_美食icon_14970_icon_6.png",
    "question": "Cartoon Style"
  },
  {
    "image_path": "../../../data/medium/process/page_392_扁平化科学领域图标_7084_icon_41.png",
    "question": "3D-style"
  },
  {
    "image_path": "../../../data/medium/process/page_394_礼物多色小图标_6623_icon_18.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/medium/process/page_309_水果_11820_icon_21.png",
    "question": "3D-style"
  },
  {
    "image_path": "../../../data/medium/process/page_33_旅游_46080_icon_59.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/medium/process/page_61_小假哥_多彩_icon022_42149_icon_14.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/medium/process/page_208_食物多色图标_20440_icon_22.png",
    "question": "Cartoon Style"
  },
  {
    "image_path": "../../../data/medium/process/page_208_常购商品_20444_icon_4.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/medium/process/page_701_小icon_905_icon_37.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/medium/process/page_412_美味的食物_2134_icon_68.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/medium/process/page_393_面性图标_6996_icon_83.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/medium/process/page_676_中间页前链_1653_icon_5.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/medium/process/page_208_商家活动页设计大赛图标_20505_icon_15.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/medium/process/page_53_小假哥_多彩_icon083_42776_icon_262.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/medium/process/page_74_政务信息_39417_icon_48.png",
    "question": "Cartoon Style"
  },
  {
    "image_path": "../../../data/medium/process/page_59_天气tianqi_42524_icon_12.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/medium/process/page_228_工作圈_18217_icon_12.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/medium/process/page_19_赛博图标_49623_icon_14.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/medium/process/page_92_美妆图标_35312_icon_6.png",
    "question": "Pixel-art"
  }
]

./style_trans/illustrations/svg_captions.json

[
  {
    "image_path": "../../../data/hard/process/page_37_教育图标_45273_icon_0.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_92_线性游戏兴趣爱好_35314_icon_5.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_92_教育_35436_icon_17.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_34_树木相关森林图标_45903_icon_11.png",
    "question": "3D-style"
  },
  {
    "image_path": "../../../data/hard/process/page_413_互联网系列_2001_icon_15.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_393_食物2_6971_icon_17.png",
    "question": "Cartoon Style"
  },
  {
    "image_path": "../../../data/hard/process/page_95_手绘人物_34981_icon_5.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_297_研究生系统_24310_icon_11.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_70_甜甜圈_40334_icon_4.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_220_IT软件开发职业_30298_icon_29.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_300_马上创业网_12940_icon_61.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_11_肖像_48949_icon_2.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/hard/process/page_59_婴幼玩具_42427_icon_2.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_11_酷趣创易_美食_icon086_48856_icon_8.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/hard/process/page_70_像素鞋子图标_40194_icon_12.png",
    "question": "Cartoon Style"
  },
  {
    "image_path": "../../../data/hard/process/page_98_牙齿类彩色图标_34256_icon_45.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_3_海边商业休闲元素_50207_icon_24.png",
    "question": "3D-style"
  },
  {
    "image_path": "../../../data/hard/process/page_14_商务_48561_icon_44.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/hard/process/page_33_美食大全_46021_icon_0.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/hard/process/page_412_食物图标_2108_icon_43.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_7_花_49471_icon_9.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_10_06-植物icon_49104_icon_16.png",
    "question": "3D-style"
  },
  {
    "image_path": "../../../data/hard/process/page_3_健身器材_50284_icon_19.png",
    "question": "Cartoon Style"
  },
  {
    "image_path": "../../../data/hard/process/page_251_费哲软件IWMS_27516_icon_60.png",
    "question": "Cartoon Style"
  },
  {
    "image_path": "../../../data/hard/process/page_79_漫威英雄_38147_icon_15.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_54_小假哥_债务_icon077_42753_icon_11.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_30_线性3d图标_46725_icon_56.png",
    "question": "Line-art"
  },
  {
    "image_path": "../../../data/hard/process/page_309_水果_11820_icon_15.png",
    "question": "Pixel-art"
  },
  {
    "image_path": "../../../data/hard/process/page_395_设计器图标_6517_icon_5.png",
    "question": "Cartoon Style"
  }
]

download_models_offline.py

"""
SVGenius 离线模型预下载脚本
============================
用途:在有网络的环境中运行此脚本,将评估所需的 6 个 CV 模型提前下载到本地。
离线运行评估时,模型将从本地缓存加载,无需联网。

下载的模型:
  1. CLIP ViT-L/14          → ~/.cache/clip/          (~890MB)
  2. aesthetic-predictor-v2-5 → ~/.cache/huggingface/  (~1.5GB)
  3. HPSv2 v2.1             → 通过 hpsv2 库自动管理     (~3GB)
  4. DINOv2-base            → ./models/dinov2-base/    (~350MB)
  5. LPIPS (VGG)            → ~/.cache/torch/hub/      (~50MB)
  6. InceptionV3 (FID)      → ~/.cache/torch/hub/      (~100MB)

使用方法(在有网络的机器上):
  python download_models_offline.py [--target-dir ./models]

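Then copy the SVGenius directory AND the cache directories listed above
(~/.cache/clip, ~/.cache/huggingface, ~/.cache/torch/hub) to the same paths on
the offline machine; the script prints the full copy list when it finishes.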
"""
import os
import sys
import argparse


def download_clip():
    """Fetch the CLIP ViT-L/14 checkpoint by loading the model once.

    The loaded model and preprocessing objects are discarded; the call is
    made only so that the checkpoint ends up in the cache named in the
    final message.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("[1/6] 下载 CLIP ViT-L/14 模型 (~890MB)")
    print(rule)

    import clip
    import torch

    # Load onto the GPU when one is available, otherwise onto the CPU.
    if torch.cuda.is_available():
        target_device = 'cuda'
    else:
        target_device = 'cpu'
    clip.load("ViT-L/14", device=target_device, jit=False)
    print("  ✓ CLIP ViT-L/14 下载完成,缓存于 ~/.cache/clip/")


def download_aesthetic():
    """Fetch the aesthetic-predictor-v2-5 scoring model by building it once.

    ``convert_v2_5_from_siglip`` is called with the same keyword arguments
    the metric code uses; the returned model and preprocessor are discarded
    because only the download side effect is wanted here.
    """
    print("\n" + "=" * 60)
    print("[2/6] 下载 aesthetic-predictor-v2-5 美学评分模型 (~1.5GB)")
    print("=" * 60)
    from aesthetic_predictor_v2_5 import convert_v2_5_from_siglip

    # No device selection is needed: the model is never moved or run here.
    convert_v2_5_from_siglip(
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    print("  ✓ aesthetic-predictor-v2-5 下载完成,缓存于 ~/.cache/huggingface/")


def download_hpsv2():
    """Trigger the HPSv2 v2.1 model download by scoring a dummy image.

    A failure of the test scoring is reported but not raised: the function
    still prints that the download was triggered and returns normally.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("[3/6] 下载 HPSv2 v2.1 人类偏好模型 (~3GB)")
    print(rule)

    import hpsv2
    from PIL import Image

    # A blank white 224x224 image is enough; the score itself is irrelevant
    # and is only echoed as a sanity check.
    blank_image = Image.new('RGB', (224, 224), color='white')
    dummy_prompt = "a simple white image"
    try:
        score = hpsv2.score(blank_image, dummy_prompt, hps_version="v2.1")
        print(f"  ✓ HPSv2 v2.1 下载完成,测试得分: {score}")
    except Exception as e:
        print(f"  ⚠ HPSv2 测试评分失败(可能是正常的): {e}")
        print("  ✓ 模型已触发下载,缓存于 ~/.cache/huggingface/")


def download_dinov2(target_dir):
    """Download DINOv2-base and save it under ``target_dir``.

    Args:
        target_dir: directory that receives a ``dinov2-base`` sub-folder
            holding both the model and its image processor.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("[4/6] 下载 DINOv2-base 模型 (~350MB)")
    print(rule)

    from transformers import AutoModel, AutoImageProcessor
    import torch

    repo_id = "facebook/dinov2-base"
    save_dir = os.path.join(target_dir, "dinov2-base")
    os.makedirs(save_dir, exist_ok=True)
    print(f"  目标目录: {save_dir}")

    backbone = AutoModel.from_pretrained(repo_id)
    image_processor = AutoImageProcessor.from_pretrained(repo_id)

    # Model first, then processor, both into the same folder.
    for artifact in (backbone, image_processor):
        artifact.save_pretrained(save_dir)
    print(f"  ✓ DINOv2-base 已保存到 {save_dir}")


def download_lpips():
    """Fetch the LPIPS (VGG backbone) weights by instantiating the model once.

    The instance is moved to the GPU when one is available (CPU otherwise)
    and then discarded.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("[5/6] 下载 LPIPS (VGG backbone) 模型 (~50MB)")
    print(rule)

    import lpips
    import torch

    if torch.cuda.is_available():
        target_device = 'cuda'
    else:
        target_device = 'cpu'
    lpips.LPIPS(net='vgg').to(target_device)
    print("  ✓ LPIPS (VGG) 下载完成,缓存于 ~/.cache/torch/hub/")


def download_inception(target_dir):
    """Download the InceptionV3 weights used for FID.

    Two things are fetched: the default torchvision InceptionV3 weights, and
    the FID-specific Inception state dict, which is saved to
    ``<target_dir>/fid/pt_inception-2015-12-05-6726825d.pth`` unless that
    file already exists.

    Args:
        target_dir: directory that receives the ``fid`` sub-folder.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("[6/6] 下载 InceptionV3 FID 权重 (~100MB)")
    print(rule)

    import torch
    from torchvision import models
    from torch.hub import load_state_dict_from_url

    # Step 1: the stock torchvision weights; the model object is not kept.
    models.inception_v3(weights='DEFAULT')
    print("  ✓ torchvision InceptionV3 权重下载完成")

    # Step 2: the FID-specific weights, stored as a file under target_dir.
    FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
    fid_dir = os.path.join(target_dir, "fid")
    os.makedirs(fid_dir, exist_ok=True)
    fid_path = os.path.join(fid_dir, "pt_inception-2015-12-05-6726825d.pth")

    if os.path.exists(fid_path):
        print(f"  ✓ FID 权重已存在: {fid_path}")
        return

    print(f"  下载 FID Inception 权重到 {fid_path}...")
    fid_weights = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    torch.save(fid_weights, fid_path)
    print(f"  ✓ FID 权重已保存到 {fid_path}")


def print_offline_guide(target_dir):
    """Print the offline deployment guide.

    The guide lists, for each model group, the source directory on this
    machine and where it has to be placed on the offline machine, followed
    by the environment variables that disable online downloads.

    Args:
        target_dir: directory the ``dinov2-base`` and ``fid`` folders were
            saved under; it is shown both as given and as an absolute path.
    """
    print("\n" + "=" * 60)
    print(" 离线部署指南")
    print("=" * 60)
    # Cache paths are expanded for the current user's home directory, and the
    # target_dir entries are made absolute, so the printed source paths can
    # be copied as-is.
    print(f"""
要将模型部署到离线机器,需要复制以下目录:

1. CLIP ViT-L/14 (~/.cache/clip/)
   源路径: {os.path.expanduser('~/.cache/clip/')}
   目标路径: 离线机器的相同路径 (~/.cache/clip/)

2. HuggingFace 模型缓存 (~/.cache/huggingface/)
   包含: aesthetic-predictor-v2-5, HPSv2
   源路径: {os.path.expanduser('~/.cache/huggingface/hub/')}
   目标路径: 离线机器的相同路径

3. DINOv2-base ({target_dir}/dinov2-base/)
   源路径: {os.path.abspath(os.path.join(target_dir, 'dinov2-base'))}
   目标路径: SVGenius/src/metrics/models/dinov2-base/(相对项目根目录)

4. PyTorch Hub 缓存 (~/.cache/torch/hub/)
   包含: LPIPS (VGG), InceptionV3
   源路径: {os.path.expanduser('~/.cache/torch/hub/')}
   目标路径: 离线机器的相同路径

5. FID 专用权重 ({target_dir}/fid/)
   源路径: {os.path.abspath(os.path.join(target_dir, 'fid'))}
   目标路径: SVGenius/src/metrics/models/fid/

====================================================================
 重要:离线机器上需设置环境变量以禁用网络访问
====================================================================
  export HF_HUB_OFFLINE=1           # 禁用 HuggingFace 在线下载
  export TRANSFORMERS_OFFLINE=1      # 禁用 Transformers 在线下载
  export HF_DATASETS_OFFLINE=1       # 禁用 Datasets 在线下载

 或者直接在 Python 代码中设置:
  import os
  os.environ['HF_HUB_OFFLINE'] = '1'
  os.environ['TRANSFORMERS_OFFLINE'] = '1'
====================================================================
""")


def main():
    """Download every evaluation model that was not skipped, then print the guide.

    Command-line options:
        --target-dir: directory for models saved as files (default ``./models``).
        --skip: zero or more of clip, aesthetic, hpsv2, dinov2, lpips,
            inception to leave out.

    A failing download does not stop the remaining ones. After the offline
    guide has been printed, any failed models are listed and the process
    exits with status 1 so that wrapper scripts can detect an incomplete
    download.
    """
    parser = argparse.ArgumentParser(description="SVGenius 离线模型预下载脚本")
    parser.add_argument(
        "--target-dir", type=str, default="./models",
        help="模型保存目录 (默认: ./models,相对于 SVGenius 项目根目录)"
    )
    parser.add_argument(
        "--skip", type=str, nargs="*", default=[],
        choices=["clip", "aesthetic", "hpsv2", "dinov2", "lpips", "inception"],
        help="跳过的模型 (可选: clip, aesthetic, hpsv2, dinov2, lpips, inception)"
    )
    args = parser.parse_args()

    target_dir = os.path.abspath(args.target_dir)
    os.makedirs(target_dir, exist_ok=True)

    print("=" * 60)
    print(" SVGenius 离线模型预下载")
    print(f" 保存目录: {target_dir}")
    print("=" * 60)
    print("\n需要下载 6 个模型,总大小约 6GB,请确保磁盘空间充足。\n")

    # Insertion order defines the download order (matches the [n/6] banners).
    models_to_download = {
        "clip": download_clip,
        "aesthetic": download_aesthetic,
        "hpsv2": download_hpsv2,
        "dinov2": lambda: download_dinov2(target_dir),
        "lpips": download_lpips,
        "inception": lambda: download_inception(target_dir),
    }

    failed = []
    for name, func in models_to_download.items():
        if name in args.skip:
            print(f"\n  ⏭ 跳过 {name}")
            continue
        try:
            func()
        except Exception as e:
            # Keep going so one broken download does not block the others,
            # but remember the failure for the final summary.
            failed.append(name)
            print(f"\n  ✗ {name} 下载失败: {e}")
            print(f"    请检查网络连接后重试,或使用 --skip {name} 跳过此模型")

    print_offline_guide(target_dir)

    if failed:
        # Without this the run looks successful (exit status 0) even though
        # some models are missing from the caches.
        print(f"\n  ✗ 以下模型下载失败: {', '.join(failed)}")
        sys.exit(1)


if __name__ == "__main__":
    main()

\SVGenius-main\src\metrics\compute_clip_score.py

import os
class CLIPScoreCalculator(BaseMetric):
    """Metric that holds a CLIP ViT-L/14 model and its preprocessing transform."""

    def __init__(self, download_root=None):
        """Load CLIP ViT-L/14 onto the GPU when available, else the CPU.

        Args:
            download_root: directory handed to ``clip.load``. When ``None``,
                the ``models/clip`` folder next to this file is used if it
                already contains ``ViT-L-14.pt``; otherwise ``~/.cache/clip``.
        """
        super().__init__()
        self.class_name = self.__class__.__name__
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Resolve the checkpoint directory: project-local copy first, then
        # the default per-user cache.
        if download_root is None:
            bundled_dir = os.path.join(os.path.dirname(__file__), 'models', 'clip')
            bundled_checkpoint = os.path.join(bundled_dir, 'ViT-L-14.pt')
            if os.path.exists(bundled_checkpoint):
                download_root = bundled_dir
            else:
                download_root = os.path.expanduser("~/.cache/clip")

        loaded = clip.load(
            "ViT-L/14",
            device=self.device,
            jit=False,
            download_root=download_root,
        )
        self.clip_model, self.preprocess = loaded

        # Use float weights on CPU; on any other device let CLIP convert the
        # weights itself.
        if self.device != "cpu":
            clip.model.convert_weights(self.clip_model)
        else:
            self.clip_model.float()

        # logit_scale must not receive gradients.
        self.clip_model.logit_scale.requires_grad_(False)

SVGenius-main\src\metrics\compute_dino_score.py

class DINOScoreCalculator(BaseMetric):
    """Metric that holds a locally stored DINOv2 model and its image processor."""

    def __init__(self, config=None, device='cuda'):
        """Load the base-size DINOv2 model and move it to ``device``.

        Args:
            config: stored on the instance as ``self.config``; not used here.
            device: device the model is moved to (default ``'cuda'``).
        """
        super().__init__()
        self.class_name = self.__class__.__name__
        self.config = config
        self.model, self.processor = self.get_DINOv2_model("base")
        self.model = self.model.to(device)
        self.device = device

        self.metric = self.calculate_DINOv2_similarity_score

    def get_DINOv2_model(self, model_size):
        """Load a DINOv2 model and image processor from local files only.

        The directory ``dinov2-<model_size>`` is looked up first in the
        ``models`` folder next to this file and otherwise relative to the
        current working directory. Nothing is downloaded.

        Args:
            model_size: ``"small"``, ``"base"`` or ``"large"``.

        Returns:
            A ``(model, processor)`` tuple.

        Raises:
            ValueError: if ``model_size`` is not one of the accepted values.
        """
        # Validate first. Previously an existing local directory bypassed
        # this check, and every size silently resolved to the base model.
        if model_size not in ("small", "base", "large"):
            raise ValueError(f"model_size should be either 'small', 'base' or 'large', got {model_size}")

        dir_name = f"dinov2-{model_size}"
        # Prefer the project-local copy (offline friendly), then the working
        # directory; for "base" this is the same lookup order as before.
        bundled_dir = os.path.join(os.path.dirname(__file__), 'models', dir_name)
        if os.path.isdir(bundled_dir):
            model_path = bundled_dir
        else:
            model_path = f"./{dir_name}"

        model = AutoModel.from_pretrained(model_path, local_files_only=True)
        processor = AutoImageProcessor.from_pretrained(model_path, local_files_only=True)
        return model, processor

SVGenius-main\src\metrics\compute_aesthetic_score.py

class AestheticScoreMetric(BaseMetric):
    """Metric that holds the aesthetic-predictor-v2-5 model and preprocessor."""

    def __init__(self, batch_size=1):
        """Load the aesthetic predictor onto a single device in eval mode.

        The model is first loaded with the default settings. If that raises,
        the reason is printed and the load is retried with
        ``local_files_only=True``.

        Args:
            batch_size: stored on the instance as ``self.batch_size``.
        """
        super().__init__()
        self.class_name = self.__class__.__name__

        # Use single device instead of device_map="auto"
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        # Load model and preprocessor. The first attempt uses the default
        # settings; local_files_only=True is applied only in the fallback.
        print("Loading model...")
        try:
            self.model, self.preprocessor = convert_v2_5_from_siglip(
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
        except Exception as e:
            # Report why the first attempt failed instead of discarding it;
            # otherwise a non-network error is hidden behind the retry.
            print(f"Default model load failed ({e}); retrying with local_files_only=True")
            import os
            os.environ.setdefault('HF_HUB_OFFLINE', '1')
            self.model, self.preprocessor = convert_v2_5_from_siglip(
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                local_files_only=True,
            )

        # Explicitly move entire model to same device
        self.model = self.model.float().to(self.device)
        print("Model loading complete!")

        # Ensure model is in evaluation mode
        self.model.eval()

        self.batch_size = batch_size

SVGenius-main\src\metrics\inception.py

import os
# 220
    # 优先从本地文件加载 FID 权重(离线兼容)
    local_fid_path = os.path.join(os.path.dirname(__file__), 'models', 'fid', 'pt_inception-2015-12-05-6726825d.pth')
    if os.path.exists(local_fid_path):
        state_dict = torch.load(local_fid_path, map_location='cpu', weights_only=True)
    else:
        state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
    inception.load_state_dict(state_dict)
    return inception
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值