In [None]:
import h5py
import torch
import torch.nn.functional as F
from torchvision.models import inception_v3
from torchvision.transforms import functional as TF
from scipy.linalg import sqrtm
import numpy as np
from collections import defaultdict
from tqdm import tqdm

In [None]:
DATASET_PATH = "./datasets/Ra_128_indexed.h5"
# NOTE(review): the "fake" path currently points at the same file as the real
# dataset, so SFID compares the data with itself (result should be close to 0).
# Presumably a placeholder / sanity check — point this at the generated samples.
DATASET_FAKE_PATH = "./datasets/Ra_128_indexed.h5"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def _read_images_and_labels(path):
    """Read the full ``images`` and ``labels`` datasets from an HDF5 file."""
    with h5py.File(path, "r") as h5:
        return h5["images"][:], h5["labels"][:]


images, labels = _read_images_and_labels(DATASET_PATH)
images_fake, labels_fake = _read_images_and_labels(DATASET_FAKE_PATH)

# Move the channel axis from last to second: (N, H, W, C) -> (N, C, H, W).
# Assumes the HDF5 files store channel-last images — TODO confirm.
images = images.transpose(0, 3, 1, 2)
images_fake = images_fake.transpose(0, 3, 1, 2)

# Report the shapes of both splits (one value per line).
print("Real Images", images.shape, labels.shape, sep="\n")
print("")
print("Fake Images", images_fake.shape, labels_fake.shape, sep="\n")

In [None]:
# Convert numpy images to float tensors, resize them, and normalize to [-1, 1].
def preprocess_batch_numpy_images(images_np, size=(299, 299)):
    """Turn a numpy image batch into an InceptionV3-ready float tensor.

    Args:
        images_np: array of shape [N, C, H, W] with pixel values in [0, 255]
            (uint8 or float). Single-channel batches (C == 1) are replicated
            to three channels, because InceptionV3's first convolution
            expects 3-channel input.
        size: target (H, W) for bilinear resizing; defaults to 299x299.

    Returns:
        float32 tensor of shape [N, C', size[0], size[1]] with values mapped
        from [0, 255] to [-1, 1], where C' is 3 for single-channel input and
        C otherwise.
    """
    images_tensor = torch.tensor(images_np, dtype=torch.float32) / 255.0  # [N, C, H, W]
    if images_tensor.shape[1] == 1:
        # Grayscale -> pseudo-RGB so the 3-channel network can consume it.
        images_tensor = images_tensor.repeat(1, 3, 1, 1)
    resized = F.interpolate(
        images_tensor, size=size, mode="bilinear", align_corners=False
    )
    # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]. compute_sfid builds the model with
    # transform_input=False, so this scaling is done here instead of in the model.
    normalized = (resized - 0.5) / 0.5
    return normalized


# Run the feature extractor over a numpy image array and return the features as numpy.
def get_features(images_np, model, batch_size=32):
    """Extract ``model`` outputs for every image in ``images_np``.

    The model is switched to eval mode, then images are processed in chunks of
    ``batch_size``: each chunk goes through ``preprocess_batch_numpy_images``,
    is moved to the device of the model's first parameter, and is fed to the
    model under ``torch.no_grad()``. Outputs are gathered on the CPU.

    Args:
        images_np: array of images indexed along the first axis.
        model: torch module producing one output row per image.
        batch_size: number of images per forward pass.

    Returns:
        numpy array of the per-chunk outputs concatenated along axis 0.
    """
    model.eval()
    model_device = next(model.parameters()).device
    with torch.no_grad():
        outputs = [
            model(
                preprocess_batch_numpy_images(
                    images_np[start : start + batch_size]
                ).to(model_device)
            ).cpu()
            for start in range(0, len(images_np), batch_size)
        ]
    return torch.cat(outputs, dim=0).numpy()


# Fréchet distance between two Gaussians.
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Squared Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2).

        d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))

    Args:
        mu1, mu2: mean vectors (scalars are promoted to length-1 vectors).
        sigma1, sigma2: covariance matrices (scalars are promoted to 1x1).
        eps: offset added to both covariance diagonals when the matrix square
            root of their product is not finite, which can happen when the
            covariances are singular (e.g. fewer samples than feature dims).

    Returns:
        The squared Fréchet distance as a Python float.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    covmean = sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # Regularize near-singular covariances so the square root is defined.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # Numerical error can leave small imaginary parts; keep only the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    diff = mu1 - mu2
    return float(
        diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)
    )


# Group feature rows by their label.
def group_features_by_label(features, labels):
    """Bucket feature rows by label.

    Args:
        features: sequence of per-sample feature rows.
        labels: sequence of labels aligned with ``features`` (pairing stops at
            the shorter of the two).

    Returns:
        dict mapping each label (in first-appearance order) to an array
        produced by ``np.stack`` of that label's rows, in input order.
    """
    buckets = {}
    for label, row in zip(labels, features):
        buckets.setdefault(label, []).append(row)
    return {key: np.stack(rows) for key, rows in buckets.items()}


# Main entry point: per-label FID averaged over labels present in both image sets.
def _build_inception_feature_extractor():
    """Load ImageNet-pretrained torchvision InceptionV3 with its classifier removed.

    Replacing ``fc`` with ``Identity`` makes the eval-mode forward pass return
    the input of the (removed) final classifier layer. The model is moved to
    the notebook-level ``device``.
    """
    # NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour
    # of `weights=...`; kept for compatibility with older torchvision installs.
    # Scores from these torchvision weights may not be comparable to FID computed
    # with the TF-ported Inception used by reference FID tools — confirm if needed.
    extractor = inception_v3(pretrained=True, transform_input=False)
    extractor.fc = torch.nn.Identity()
    extractor.eval().to(device)
    return extractor


def _mean_and_cov(feats):
    """Return the per-column mean and the sample covariance (columns = variables) of ``feats``."""
    return np.mean(feats, axis=0), np.cov(feats, rowvar=False)


def compute_sfid(images1_np, labels1_np, images2_np, labels2_np, batch_size=32, model=None):
    """Per-label FID: the unweighted mean of per-label Fréchet distances.

    Args:
        images1_np, images2_np: image arrays whose per-sample shapes must match.
        labels1_np, labels2_np: labels aligned with the first axis of each image array.
        batch_size: images per forward pass during feature extraction.
        model: optional prebuilt feature extractor. When None, an ImageNet-pretrained
            InceptionV3 with its classifier removed is loaded on each call; pass one
            in to reuse it across calls.

    Returns:
        Mean Fréchet distance over labels that have at least 2 samples in both
        sets, or None if no label qualifies.

    Raises:
        AssertionError: if image shapes differ or labels are misaligned with images.
    """
    assert images1_np.shape[1:] == images2_np.shape[1:], "Image size mismatch"
    assert images1_np.shape[0] == labels1_np.shape[0], "images1/labels1 length mismatch"
    assert images2_np.shape[0] == labels2_np.shape[0], "images2/labels2 length mismatch"

    if model is None:
        model = _build_inception_feature_extractor()

    features1 = get_features(images1_np, model, batch_size)
    features2 = get_features(images2_np, model, batch_size)

    label_feats1 = group_features_by_label(features1, labels1_np)
    label_feats2 = group_features_by_label(features2, labels2_np)

    # Sorted so the floating-point summation order is reproducible across runs.
    common_labels = sorted(set(label_feats1).intersection(label_feats2))

    fids = []
    for label in common_labels:
        f1, f2 = label_feats1[label], label_feats2[label]
        if len(f1) < 2 or len(f2) < 2:
            # A sample covariance needs at least two observations.
            continue
        mu1, sigma1 = _mean_and_cov(f1)
        mu2, sigma2 = _mean_and_cov(f2)
        fids.append(calculate_frechet_distance(mu1, sigma1, mu2, sigma2))

    return sum(fids) / len(fids) if fids else None

In [None]:
# compute_sfid inputs: image arrays of shape [N, C, H, W] with values in [0, 255]
# (uint8 or float), and label arrays of shape [N] (int or float).
sfid = compute_sfid(images, labels, images_fake, labels_fake)
if sfid is None:
    # compute_sfid returns None when no label has >= 2 samples in both sets.
    print("SFID: undefined (no label has at least 2 samples in both sets)")
else:
    print(f"SFID: {sfid:.4f}")