ShieldGemma-2: Non-Reproducible Outputs Bug & Fix

#10
by hmeisheri - opened

Bug

ShieldGemma2ForImageClassification gives different outputs every time you load the model, even with the exact same image and settings.

Root Cause

The checkpoint has text_config.tie_word_embeddings = True, meaning lm_head.weight should be shared with embed_tokens.weight (and is NOT stored separately in the checkpoint). However, the outer wrapper class (ShieldGemma2ForImageClassification) does not propagate _tied_weights_keys, so lm_head.weight is randomly initialized on every load.
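The failure mode is easy to see in a toy model (a sketch with made-up dimensions; `emb` stands in for `embed_tokens`, `head` for `lm_head` — neither is the real ShieldGemma-2 module):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)            # stands in for embed_tokens (loaded from the checkpoint)
head = nn.Linear(4, 10, bias=False)  # stands in for lm_head (randomly initialized on each load)

# Untied: head carries its own random weight, unrelated to the embedding matrix
print(torch.equal(head.weight, emb.weight))  # False

# Tying shares the Parameter object -- exactly what the one-line fix below does
head.weight = emb.weight
print(head.weight is emb.weight)  # True: same tensor, not a copy
```

Because the tie shares the `Parameter` object rather than copying values, the fix costs no extra memory and survives subsequent in-place updates to the embedding.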

You can confirm this from the load report:

ShieldGemma2ForImageClassification LOAD REPORT from: google/shieldgemma-2-4b-it
Key                  | Status  |
---------------------+---------+-
model.lm_head.weight | MISSING |

Fix (one line)

After loading, manually tie lm_head.weight to embed_tokens.weight:

model.model.lm_head.weight = model.model.get_input_embeddings().weight

Affected: transformers >= 4.49 (when ShieldGemma2ForImageClassification was added)

Model: google/shieldgemma-2-4b-it


Setup

# pip install transformers torch pillow requests

import torch, gc
from transformers import AutoProcessor, ShieldGemma2ForImageClassification
from PIL import Image
import requests

model_id = "google/shieldgemma-2-4b-it"
processor = AutoProcessor.from_pretrained(model_id)

# Load a test image (a bee -- should be SAFE for all policies)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
inputs = processor(images=[image], return_tensors="pt")

Reproduce: BEFORE fix (outputs change on every load)

POLICIES = ["dangerous", "sexual", "violence"]

for i in range(3):
    model = ShieldGemma2ForImageClassification.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.bfloat16
    ).eval()

    inp = {k: v.to(device=model.device, dtype=torch.bfloat16) if v.is_floating_point()
           else v.to(device=model.device) for k, v in inputs.items()}

    with torch.inference_mode():
        scores = model(**inp)

    # probabilities[:, 0] = "Yes" (violates/UNSAFE), [:, 1] = "No" (SAFE)
    print(f"Load {i+1}:", {p: f"safe={scores.probabilities[j,1].item():.4f}"
          for j, p in enumerate(POLICIES)})

    del model; gc.collect(); torch.cuda.empty_cache()
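The `lm_head.weight  sum=... norm=...` lines in the dump below came from a small diagnostic along these lines (a sketch; in the real loop you would pass `model.model.lm_head.weight` instead of the toy random matrices used here for illustration):

```python
import torch

def weight_fingerprint(w: torch.Tensor) -> str:
    # The sum shifts with every random init, while the norm stays roughly
    # constant (the init preserves scale) -- a cheap tell that the tensor
    # was re-initialized rather than loaded from the checkpoint.
    return f"sum={w.float().sum().item():.6f}   norm={w.float().norm().item():.6f}"

print("lm_head.weight ", weight_fingerprint(torch.randn(256, 64)))
print("lm_head.weight ", weight_fingerprint(torch.randn(256, 64)))  # different sum
```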

Output (BEFORE fix)

Notice that lm_head.weight has a different sum on each load, and the safe/unsafe probabilities vary wildly. (The dump below is from an extended run covering four test images; the script above shows the single-image bee case.)

MODEL LOAD 1/3
  lm_head.weight  sum=568.720520   norm=518.187866
  [bee         ] dangerous  -> Safe=0.000000  Unsafe=1.000000
  [bee         ] sexual     -> Safe=0.000032  Unsafe=1.000000
  [bee         ] violence   -> Safe=0.000025  Unsafe=1.000000
  [food        ] dangerous  -> Safe=0.039551  Unsafe=0.960938
  [food        ] sexual     -> Safe=0.003769  Unsafe=0.996094
  [food        ] violence   -> Safe=0.115723  Unsafe=0.882812
  [cat         ] dangerous  -> Safe=0.000067  Unsafe=1.000000
  [cat         ] sexual     -> Safe=0.003021  Unsafe=0.996094
  [cat         ] violence   -> Safe=0.000265  Unsafe=1.000000
  [woman-sexy  ] dangerous  -> Safe=0.003326  Unsafe=0.996094
  [woman-sexy  ] sexual     -> Safe=0.216797  Unsafe=0.785156
  [woman-sexy  ] violence   -> Safe=0.275391  Unsafe=0.726562

MODEL LOAD 2/3
  lm_head.weight  sum=359.003418   norm=518.184021
  [bee         ] dangerous  -> Safe=0.000315  Unsafe=1.000000
  [bee         ] sexual     -> Safe=0.003891  Unsafe=0.996094
  [bee         ] violence   -> Safe=0.000062  Unsafe=1.000000
  [food        ] dangerous  -> Safe=0.341797  Unsafe=0.660156
  [food        ] sexual     -> Safe=0.156250  Unsafe=0.843750
  [food        ] violence   -> Safe=0.016602  Unsafe=0.984375
  [cat         ] dangerous  -> Safe=0.015442  Unsafe=0.984375
  [cat         ] sexual     -> Safe=0.061035  Unsafe=0.937500
  [cat         ] violence   -> Safe=0.000012  Unsafe=1.000000
  [woman-sexy  ] dangerous  -> Safe=0.005920  Unsafe=0.992188
  [woman-sexy  ] sexual     -> Safe=0.225586  Unsafe=0.773438
  [woman-sexy  ] violence   -> Safe=0.220703  Unsafe=0.777344

MODEL LOAD 3/3
  lm_head.weight  sum=-883.367554  norm=518.172913
  [bee         ] dangerous  -> Safe=1.000000  Unsafe=0.000001
  [bee         ] sexual     -> Safe=1.000000  Unsafe=0.000018
  [bee         ] violence   -> Safe=1.000000  Unsafe=0.000002
  [food        ] dangerous  -> Safe=1.000000  Unsafe=0.000000
  [food        ] sexual     -> Safe=0.996094  Unsafe=0.003479
  [food        ] violence   -> Safe=0.855469  Unsafe=0.146484
  [cat         ] dangerous  -> Safe=1.000000  Unsafe=0.001724
  [cat         ] sexual     -> Safe=0.976562  Unsafe=0.025146
  [cat         ] violence   -> Safe=0.816406  Unsafe=0.182617
  [woman-sexy  ] dangerous  -> Safe=1.000000  Unsafe=0.000017
  [woman-sexy  ] sexual     -> Safe=1.000000  Unsafe=0.000033
  [woman-sexy  ] violence   -> Safe=1.000000  Unsafe=0.000210

CONFIRMED: lm_head.weight CHANGES across loads (randomly initialized). This is the root cause of the non-reproducible outputs.


AFTER fix (outputs are identical across loads)

for i in range(3):
    model = ShieldGemma2ForImageClassification.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.bfloat16
    ).eval()

    # THE FIX -- tie lm_head.weight to embed_tokens.weight
    model.model.lm_head.weight = model.model.get_input_embeddings().weight

    inp = {k: v.to(device=model.device, dtype=torch.bfloat16) if v.is_floating_point()
           else v.to(device=model.device) for k, v in inputs.items()}

    with torch.inference_mode():
        scores = model(**inp)

    print(f"Load {i+1}:", {p: f"safe={scores.probabilities[j,1].item():.4f}"
          for j, p in enumerate(POLICIES)})

    del model; gc.collect(); torch.cuda.empty_cache()

Output (AFTER fix)

All 3 loads produce identical results. lm_head is now correctly tied to embed_tokens:

MODEL LOAD 1/3  (with tie fix)
  lm_head tied to embed_tokens: True  |  lm_head sum: -105549.4219
  [bee         ] dangerous  -> Safe=1.000000  Unsafe=0.000000
  [bee         ] sexual     -> Safe=1.000000  Unsafe=0.000000
  [bee         ] violence   -> Safe=1.000000  Unsafe=0.000000
  [food        ] dangerous  -> Safe=1.000000  Unsafe=0.000005
  [food        ] sexual     -> Safe=1.000000  Unsafe=0.000028
  [food        ] violence   -> Safe=1.000000  Unsafe=0.000000
  [cat         ] dangerous  -> Safe=1.000000  Unsafe=0.000000
  [cat         ] sexual     -> Safe=0.996094  Unsafe=0.004608
  [cat         ] violence   -> Safe=1.000000  Unsafe=0.000000
  [woman-sexy  ] dangerous  -> Safe=1.000000  Unsafe=0.000109
  [woman-sexy  ] sexual     -> Safe=0.000085  Unsafe=1.000000
  [woman-sexy  ] violence   -> Safe=1.000000  Unsafe=0.000001

MODEL LOAD 2/3  (with tie fix)
  lm_head tied to embed_tokens: True  |  lm_head sum: -105549.4219
  [identical results]

MODEL LOAD 3/3  (with tie fix)
  lm_head tied to embed_tokens: True  |  lm_head sum: -105549.4219
  [identical results]

Variance summary (AFTER fix)

     image    policy  safe_std  safe_range  unsafe_std  unsafe_range
       bee dangerous  0.000000    0.000000    0.000000      0.000000
       bee    sexual  0.000000    0.000000    0.000000      0.000000
       bee  violence  0.000000    0.000000    0.000000      0.000000
       cat dangerous  0.000000    0.000000    0.000000      0.000000
       cat    sexual  0.000000    0.000000    0.000000      0.000000
       cat  violence  0.000000    0.000000    0.000000      0.000000
      food dangerous  0.000000    0.000000    0.000000      0.000000
      food    sexual  0.000000    0.000000    0.000000      0.000000
      food  violence  0.000000    0.000000    0.000000      0.000000
woman-sexy dangerous  0.000000    0.000000    0.000000      0.000000
woman-sexy    sexual  0.000000    0.000000    0.000000      0.000000
woman-sexy  violence  0.000000    0.000000    0.000000      0.000000

SUCCESS: max range across loads = 0.00000000. Outputs are now perfectly reproducible across model loads.
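The variance summary above can be computed with something like this (a sketch; the `runs` list and its `(image, policy) -> (safe, unsafe)` dict layout are assumptions about how you collect the per-load results from the loops above):

```python
import statistics

def variance_summary(runs):
    """runs: list of per-load dicts mapping (image, policy) -> (safe, unsafe)."""
    rows = []
    for key in runs[0]:
        safes = [r[key][0] for r in runs]
        unsafes = [r[key][1] for r in runs]
        rows.append((*key,
                     statistics.pstdev(safes), max(safes) - min(safes),
                     statistics.pstdev(unsafes), max(unsafes) - min(unsafes)))
    return rows

# After the fix, identical loads give zero std and zero range:
runs = [{("bee", "dangerous"): (1.0, 0.0)}] * 3
print(variance_summary(runs))  # [('bee', 'dangerous', 0.0, 0.0, 0.0, 0.0)]
```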


Helper function (copy-paste into your code)

def load_shieldgemma2(model_id="google/shieldgemma-2-4b-it", **kwargs):
    """Load ShieldGemma-2 with the lm_head weight tying fix for reproducible outputs."""
    model = ShieldGemma2ForImageClassification.from_pretrained(model_id, **kwargs).eval()
    model.model.lm_head.weight = model.model.get_input_embeddings().weight
    return model

# Usage:
model = load_shieldgemma2(device_map="auto", torch_dtype=torch.bfloat16)

I have opened a PR on the Transformers repo: https://github.com/huggingface/transformers/pull/44358

@lkv @pannaga10 can you please review?

thanks, was super helpful for me too!
