BMP committed on
Commit
76003b0
·
verified ·
1 Parent(s): 97fc934

Upload example_usage.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. example_usage.py +89 -0
example_usage.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Example usage of the MLX ResNet34 speaker embedding model.
"""

import mlx.core as mx
import numpy as np
from resnet_embedding import load_resnet34_embedding

# Horizontal rule used to frame the start and end of the demo output.
_RULE = "=" * 70


def _random_mel(count, frames, bins):
    """Return a float32 MLX array of standard-normal noise, shape (count, frames, bins).

    Stands in for real mel spectrogram features in this demo.
    """
    noise = np.random.randn(count, frames, bins)
    return mx.array(noise.astype(np.float32))


def _print_similarity_table(table, count):
    """Print the top-left ``count`` x ``count`` corner of ``table`` with UttN labels."""
    header = "".join(f"Utt{col} " for col in range(count))
    print(" " + header)

    for row in range(count):
        cells = "".join(f" {table[row, col]:.3f} " for col in range(count))
        print(f" Utt{row}:" + cells)


def main():
    """Run the end-to-end demo: load the model, embed random inputs, compare them.

    Loads the model from ``weights.npz``, pushes randomly generated
    spectrogram-shaped inputs through it, and prints the output shapes,
    a cosine similarity between two embeddings, and a pairwise
    similarity matrix for a batch of five.
    """
    print(_RULE)
    print(" MLX ResNet34 Speaker Embedding - Example Usage")
    print(_RULE)

    # Step 1: load the model.
    print("\n[1] Loading model...")
    model = load_resnet34_embedding("weights.npz")
    print("✓ Model loaded successfully!")

    # Step 2: build an example input.
    # In practice, you would extract this from real audio.
    # Layout: (batch_size, time_frames, freq_bins)
    print("\n[2] Creating example mel spectrogram input...")
    batch_size = 2
    time_frames = 150  # ~1.5 seconds at 10ms hop
    freq_bins = 80  # mel filterbanks
    mel_spec = _random_mel(batch_size, time_frames, freq_bins)
    print(f" Input shape: {mel_spec.shape}")

    # Step 3: run the model to get embeddings.
    print("\n[3] Extracting speaker embeddings...")
    embeddings = model(mel_spec)
    print(f" Output shape: {embeddings.shape}")
    print(f" Embedding norms: {mx.linalg.norm(embeddings, axis=1)}")

    # Step 4: cosine similarity between the first two embeddings.
    print("\n[4] Computing speaker similarity...")
    first, second = embeddings[0], embeddings[1]
    dot = mx.sum(first * second)
    norm_product = mx.linalg.norm(first) * mx.linalg.norm(second)
    similarity = dot / norm_product

    print(f" Cosine similarity between speaker 1 and 2: {float(similarity):.4f}")
    for line in (
        " Interpretation:",
        " > 0.9 = Likely same speaker",
        " 0.5-0.9 = Uncertain",
        " < 0.5 = Likely different speakers",
    ):
        print(line)

    # Step 5: embed a larger batch in one call.
    print("\n[5] Batch processing multiple utterances...")
    num_utterances = 5
    utterances = _random_mel(num_utterances, time_frames, freq_bins)
    batch_embeddings = model(utterances)
    print(f" Processed {num_utterances} utterances")
    print(f" Embeddings shape: {batch_embeddings.shape}")

    # Step 6: pairwise cosine similarities.
    print("\n[6] Computing pairwise similarity matrix...")
    # Scaling each row to unit length first lets a single matrix product
    # yield every pairwise cosine similarity.
    row_norms = mx.linalg.norm(batch_embeddings, axis=1, keepdims=True)
    unit_rows = batch_embeddings / row_norms
    similarity_matrix = unit_rows @ unit_rows.T
    sim_np = np.array(similarity_matrix)

    print("\n Similarity Matrix:")
    _print_similarity_table(sim_np, num_utterances)

    print("\n" + _RULE)
    print(" ✓ Example completed successfully!")
    print(_RULE)


if __name__ == "__main__":
    main()