ShieldGemma-2: Non-Reproducible Outputs Bug & Fix
Bug
ShieldGemma2ForImageClassification gives different outputs every time you load the model, even with the exact same image and settings.
Root Cause
The checkpoint has text_config.tie_word_embeddings = True, meaning lm_head.weight should be shared with embed_tokens.weight (and is NOT stored separately in the checkpoint). However, the outer wrapper class (ShieldGemma2ForImageClassification) does not propagate _tied_weights_keys, so lm_head.weight is randomly initialized on every load.
You can confirm this from the load report:
ShieldGemma2ForImageClassification LOAD REPORT from: google/shieldgemma-2-4b-it
Key                  | Status  |
---------------------+---------+-
model.lm_head.weight | MISSING |
Fix (one line)
After loading, manually tie lm_head.weight to embed_tokens.weight:
model.model.lm_head.weight = model.model.get_input_embeddings().weight
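Why a single assignment is enough: in PyTorch, pointing one module's `weight` attribute at another module's `Parameter` makes both modules share the same underlying storage, which is exactly what weight tying means. A minimal sketch with generic `nn` modules standing in for `embed_tokens` and `lm_head` (toy shapes, not the real model):

```python
import torch
import torch.nn as nn

def weights_tied(emb: nn.Embedding, head: nn.Linear) -> bool:
    # Two parameters are "tied" when they share the same underlying storage.
    return emb.weight.data_ptr() == head.weight.data_ptr()

emb = nn.Embedding(10, 4)            # stands in for embed_tokens
head = nn.Linear(4, 10, bias=False)  # stands in for lm_head
print(weights_tied(emb, head))       # independently initialised -> not tied
head.weight = emb.weight             # the same one-line fix as above
print(weights_tied(emb, head))       # now one tensor, two views
```

The same `data_ptr()` comparison can be used on the real model after loading to confirm the tie took effect.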
Affected: transformers >= 4.49 (when ShieldGemma2ForImageClassification was added)
Model: google/shieldgemma-2-4b-it
Setup
# pip install transformers torch pillow requests
import torch, gc
from transformers import AutoProcessor, ShieldGemma2ForImageClassification
from PIL import Image
import requests
model_id = "google/shieldgemma-2-4b-it"
processor = AutoProcessor.from_pretrained(model_id)
# Load a test image (a bee -- should be SAFE for all policies)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
inputs = processor(images=[image], return_tensors="pt")
Reproduce: BEFORE fix (outputs change on every load)
POLICIES = ["dangerous", "sexual", "violence"]

for i in range(3):
    model = ShieldGemma2ForImageClassification.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.bfloat16
    ).eval()
    inp = {k: v.to(device=model.device, dtype=torch.bfloat16) if v.is_floating_point()
           else v.to(device=model.device) for k, v in inputs.items()}
    with torch.inference_mode():
        scores = model(**inp)
    # probabilities[:, 0] = "Yes" (violates/UNSAFE), [:, 1] = "No" (SAFE)
    print(f"Load {i+1}:", {p: f"safe={scores.probabilities[j,1].item():.4f}"
                           for j, p in enumerate(POLICIES)})
    del model; gc.collect(); torch.cuda.empty_cache()
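For context on what `probabilities` means here: ShieldGemma-2 frames each policy as a Yes/No question, and as far as I can tell from the model class the two probabilities are a softmax over the final-position logits of the "Yes" and "No" tokens. A toy sketch of that computation with made-up logits (not real model outputs):

```python
import torch

# Hypothetical final-position logits for the "Yes" (violates) and "No" tokens.
yes_logit = torch.tensor(2.0)
no_logit = torch.tensor(-1.5)

probs = torch.softmax(torch.stack([yes_logit, no_logit]), dim=0)
unsafe = probs[0].item()  # index 0 = "Yes" = policy violated
safe = probs[1].item()    # index 1 = "No"  = safe
print(f"safe={safe:.4f} unsafe={unsafe:.4f}")
```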
Output (BEFORE fix)
Notice that lm_head.weight has a different sum on each load, and the safe/unsafe probabilities vary wildly:
MODEL LOAD 1/3
lm_head.weight sum=568.720520 norm=518.187866
[bee ] dangerous -> Safe=0.000000 Unsafe=1.000000
[bee ] sexual -> Safe=0.000032 Unsafe=1.000000
[bee ] violence -> Safe=0.000025 Unsafe=1.000000
[food ] dangerous -> Safe=0.039551 Unsafe=0.960938
[food ] sexual -> Safe=0.003769 Unsafe=0.996094
[food ] violence -> Safe=0.115723 Unsafe=0.882812
[cat ] dangerous -> Safe=0.000067 Unsafe=1.000000
[cat ] sexual -> Safe=0.003021 Unsafe=0.996094
[cat ] violence -> Safe=0.000265 Unsafe=1.000000
[woman-sexy ] dangerous -> Safe=0.003326 Unsafe=0.996094
[woman-sexy ] sexual -> Safe=0.216797 Unsafe=0.785156
[woman-sexy ] violence -> Safe=0.275391 Unsafe=0.726562
MODEL LOAD 2/3
lm_head.weight sum=359.003418 norm=518.184021
[bee ] dangerous -> Safe=0.000315 Unsafe=1.000000
[bee ] sexual -> Safe=0.003891 Unsafe=0.996094
[bee ] violence -> Safe=0.000062 Unsafe=1.000000
[food ] dangerous -> Safe=0.341797 Unsafe=0.660156
[food ] sexual -> Safe=0.156250 Unsafe=0.843750
[food ] violence -> Safe=0.016602 Unsafe=0.984375
[cat ] dangerous -> Safe=0.015442 Unsafe=0.984375
[cat ] sexual -> Safe=0.061035 Unsafe=0.937500
[cat ] violence -> Safe=0.000012 Unsafe=1.000000
[woman-sexy ] dangerous -> Safe=0.005920 Unsafe=0.992188
[woman-sexy ] sexual -> Safe=0.225586 Unsafe=0.773438
[woman-sexy ] violence -> Safe=0.220703 Unsafe=0.777344
MODEL LOAD 3/3
lm_head.weight sum=-883.367554 norm=518.172913
[bee ] dangerous -> Safe=1.000000 Unsafe=0.000001
[bee ] sexual -> Safe=1.000000 Unsafe=0.000018
[bee ] violence -> Safe=1.000000 Unsafe=0.000002
[food ] dangerous -> Safe=1.000000 Unsafe=0.000000
[food ] sexual -> Safe=0.996094 Unsafe=0.003479
[food ] violence -> Safe=0.855469 Unsafe=0.146484
[cat ] dangerous -> Safe=1.000000 Unsafe=0.001724
[cat ] sexual -> Safe=0.976562 Unsafe=0.025146
[cat ] violence -> Safe=0.816406 Unsafe=0.182617
[woman-sexy ] dangerous -> Safe=1.000000 Unsafe=0.000017
[woman-sexy ] sexual -> Safe=1.000000 Unsafe=0.000033
[woman-sexy ] violence -> Safe=1.000000 Unsafe=0.000210
CONFIRMED: lm_head.weight CHANGES across loads (randomly initialised). This is the root cause of non-reproducible outputs.
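The "randomly initialised" part is easy to demonstrate in isolation: when a parameter is missing from the checkpoint, `from_pretrained` leaves it at its fresh random init, and two fresh initialisations essentially never match. A toy illustration with plain `nn.Linear` layers, not the real model:

```python
import torch
import torch.nn as nn

# Two independently constructed layers draw different random weights --
# just as lm_head does on every load while it is missing from the checkpoint.
a = nn.Linear(8, 8, bias=False)
b = nn.Linear(8, 8, bias=False)
print(torch.equal(a.weight, b.weight))  # virtually always False
```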
AFTER fix (outputs are identical across loads)
for i in range(3):
    model = ShieldGemma2ForImageClassification.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.bfloat16
    ).eval()

    # THE FIX -- tie lm_head.weight to embed_tokens.weight
    model.model.lm_head.weight = model.model.get_input_embeddings().weight

    inp = {k: v.to(device=model.device, dtype=torch.bfloat16) if v.is_floating_point()
           else v.to(device=model.device) for k, v in inputs.items()}
    with torch.inference_mode():
        scores = model(**inp)
    print(f"Load {i+1}:", {p: f"safe={scores.probabilities[j,1].item():.4f}"
                           for j, p in enumerate(POLICIES)})
    del model; gc.collect(); torch.cuda.empty_cache()
Output (AFTER fix)
All 3 loads produce identical results. lm_head is now correctly tied to embed_tokens:
MODEL LOAD 1/3 (with tie fix)
lm_head tied to embed_tokens: True | lm_head sum: -105549.4219
[bee ] dangerous -> Safe=1.000000 Unsafe=0.000000
[bee ] sexual -> Safe=1.000000 Unsafe=0.000000
[bee ] violence -> Safe=1.000000 Unsafe=0.000000
[food ] dangerous -> Safe=1.000000 Unsafe=0.000005
[food ] sexual -> Safe=1.000000 Unsafe=0.000028
[food ] violence -> Safe=1.000000 Unsafe=0.000000
[cat ] dangerous -> Safe=1.000000 Unsafe=0.000000
[cat ] sexual -> Safe=0.996094 Unsafe=0.004608
[cat ] violence -> Safe=1.000000 Unsafe=0.000000
[woman-sexy ] dangerous -> Safe=1.000000 Unsafe=0.000109
[woman-sexy ] sexual -> Safe=0.000085 Unsafe=1.000000
[woman-sexy ] violence -> Safe=1.000000 Unsafe=0.000001
MODEL LOAD 2/3 (with tie fix)
lm_head tied to embed_tokens: True | lm_head sum: -105549.4219
[identical results]
MODEL LOAD 3/3 (with tie fix)
lm_head tied to embed_tokens: True | lm_head sum: -105549.4219
[identical results]
Variance summary (AFTER fix)
image       policy     safe_std  safe_range  unsafe_std  unsafe_range
bee         dangerous  0.000000  0.000000    0.000000    0.000000
bee         sexual     0.000000  0.000000    0.000000    0.000000
bee         violence   0.000000  0.000000    0.000000    0.000000
cat         dangerous  0.000000  0.000000    0.000000    0.000000
cat         sexual     0.000000  0.000000    0.000000    0.000000
cat         violence   0.000000  0.000000    0.000000    0.000000
food        dangerous  0.000000  0.000000    0.000000    0.000000
food        sexual     0.000000  0.000000    0.000000    0.000000
food        violence   0.000000  0.000000    0.000000    0.000000
woman-sexy  dangerous  0.000000  0.000000    0.000000    0.000000
woman-sexy  sexual     0.000000  0.000000    0.000000    0.000000
woman-sexy  violence   0.000000  0.000000    0.000000    0.000000
SUCCESS: max range across loads = 0.00000000. Outputs are now perfectly reproducible across model loads.
Helper function (copy-paste into your code)
def load_shieldgemma2(model_id="google/shieldgemma-2-4b-it", **kwargs):
    """Load ShieldGemma-2 with the lm_head weight-tying fix for reproducible outputs."""
    model = ShieldGemma2ForImageClassification.from_pretrained(model_id, **kwargs).eval()
    model.model.lm_head.weight = model.model.get_input_embeddings().weight
    return model
# Usage:
model = load_shieldgemma2(device_map="auto", torch_dtype=torch.bfloat16)