File size: 2,863 Bytes
76003b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
Example usage of the MLX ResNet34 speaker embedding model.
"""

import mlx.core as mx
import numpy as np
from resnet_embedding import load_resnet34_embedding


def main(weights_path: str = "weights.npz") -> None:
    """Run an end-to-end demo of the MLX ResNet34 speaker embedding model.

    The demo loads the model, feeds it randomly generated stand-in mel
    spectrograms, and prints the embedding shapes and norms, the cosine
    similarity between two embeddings, and a pairwise similarity matrix for
    a small batch of utterances. All results are written to stdout.

    Args:
        weights_path: Path handed to ``load_resnet34_embedding``. Defaults to
            ``"weights.npz"``, the location this example used originally.
    """
    banner = "=" * 70
    print(banner)
    print(" MLX ResNet34 Speaker Embedding - Example Usage")
    print(banner)

    # Load model
    print("\n[1] Loading model...")
    model = load_resnet34_embedding(weights_path)
    print("✓ Model loaded successfully!")

    # Create example mel spectrogram input
    print("\n[2] Creating example mel spectrogram input...")
    # In practice, you would extract this from real audio
    # Shape: (batch_size, time_frames, freq_bins)
    batch_size = 2
    time_frames = 150  # ~1.5 seconds at 10ms hop
    freq_bins = 80     # mel filterbanks

    mel_spec = mx.array(
        np.random.randn(batch_size, time_frames, freq_bins).astype(np.float32)
    )
    print(f"  Input shape: {mel_spec.shape}")

    # Extract embeddings
    print("\n[3] Extracting speaker embeddings...")
    embeddings = model(mel_spec)
    print(f"  Output shape: {embeddings.shape}")
    print(f"  Embedding norms: {mx.linalg.norm(embeddings, axis=1)}")

    # Compute speaker similarity
    print("\n[4] Computing speaker similarity...")
    emb1 = embeddings[0]
    emb2 = embeddings[1]

    # Cosine similarity: dot product divided by the product of the norms
    similarity = mx.sum(emb1 * emb2) / (
        mx.linalg.norm(emb1) * mx.linalg.norm(emb2)
    )

    print(f"  Cosine similarity between speaker 1 and 2: {float(similarity):.4f}")
    # Plain strings: these lines have no placeholders, so no f-prefix is needed.
    print("  Interpretation:")
    print("    > 0.9  = Likely same speaker")
    print("    0.5-0.9 = Uncertain")
    print("    < 0.5  = Likely different speakers")

    # Batch processing example
    print("\n[5] Batch processing multiple utterances...")
    num_utterances = 5
    utterances = mx.array(
        np.random.randn(num_utterances, time_frames, freq_bins).astype(np.float32)
    )

    batch_embeddings = model(utterances)
    print(f"  Processed {num_utterances} utterances")
    print(f"  Embeddings shape: {batch_embeddings.shape}")

    # Compute pairwise similarities
    print("\n[6] Computing pairwise similarity matrix...")
    # Normalize each embedding to unit length so that a single matrix product
    # yields every pairwise cosine similarity at once.
    emb_normalized = batch_embeddings / mx.linalg.norm(
        batch_embeddings, axis=1, keepdims=True
    )

    # Similarity matrix via matrix multiplication
    similarity_matrix = emb_normalized @ emb_normalized.T
    sim_np = np.array(similarity_matrix)

    # Header and data cells share one fixed width so the columns line up even
    # when a similarity is negative ("-0.123" is one character wider than
    # "0.123").
    labels = [f"Utt{i}" for i in range(num_utterances)]
    cell_width = 8
    # Row prefix is two spaces of indent, the label, and a colon.
    prefix_width = max(len(label) for label in labels) + 3

    print("\n  Similarity Matrix:")
    header_cells = "".join(f"{label:>{cell_width}}" for label in labels)
    print(" " * prefix_width + header_cells)

    for label, row in zip(labels, sim_np):
        row_prefix = f"  {label}:".ljust(prefix_width)
        row_cells = "".join(f"{value:>{cell_width}.3f}" for value in row)
        print(row_prefix + row_cells)

    print("\n" + banner)
    print(" ✓ Example completed successfully!")
    print(banner)


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()