jaepil commited on
Commit
dc5983c
·
verified ·
1 Parent(s): 37ef779

Upload tokenization_cognica_poe.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tokenization_cognica_poe.py +212 -0
tokenization_cognica_poe.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cognica-PoE tokenizer: HuggingFace `PreTrainedTokenizer` wrapper around the
2
+ nanochat tiktoken BPE. Loaded by `AutoTokenizer.from_pretrained(...,
3
+ trust_remote_code=True)`.
4
+
5
+ The underlying encoding is pickled as `tokenizer.pkl` (a `tiktoken.Encoding`
6
+ object). Special tokens are assigned to their tiktoken ids so the HF tokenize-
7
+ around-specials flow and raw tiktoken encoding produce identical id sequences.
8
+ """
9
+
10
+ import os
11
+ import pickle
12
+ from typing import Dict, List, Optional, Tuple
13
+
14
+ import tiktoken
15
+ from transformers import PreTrainedTokenizer
16
+ from transformers.tokenization_utils import AddedToken
17
+
18
+
19
+ SPECIAL_TOKENS = [
20
+ "<|bos|>",
21
+ "<|user_start|>",
22
+ "<|user_end|>",
23
+ "<|assistant_start|>",
24
+ "<|assistant_end|>",
25
+ "<|python_start|>",
26
+ "<|python_end|>",
27
+ "<|output_start|>",
28
+ "<|output_end|>",
29
+ ]
30
+
31
+
32
+ class CognicaPoETokenizer(PreTrainedTokenizer):
33
+ """BPE tokenizer backed by a pickled `tiktoken.Encoding` (nanochat format)."""
34
+
35
+ vocab_files_names = {"vocab_file": "tokenizer.pkl"}
36
+ model_input_names = ["input_ids", "attention_mask"]
37
+
38
+ def __init__(
39
+ self,
40
+ vocab_file: str,
41
+ bos_token: str = "<|bos|>",
42
+ eos_token: str = "<|bos|>",
43
+ pad_token: Optional[str] = None,
44
+ unk_token: Optional[str] = None,
45
+ **kwargs,
46
+ ):
47
+ if not os.path.exists(vocab_file):
48
+ raise ValueError(
49
+ f"tokenizer.pkl not found at {vocab_file}. "
50
+ "Make sure it was downloaded from the model repo."
51
+ )
52
+ with open(vocab_file, "rb") as f:
53
+ enc = pickle.load(f)
54
+ if not isinstance(enc, tiktoken.Encoding):
55
+ raise TypeError(
56
+ f"Expected tiktoken.Encoding in {vocab_file}, got {type(enc).__name__}"
57
+ )
58
+ self.enc = enc
59
+ self._vocab_file = vocab_file
60
+
61
+ # Respect a pre-built added_tokens_decoder from tokenizer_config.json;
62
+ # otherwise synthesize one from the tiktoken special-tokens set.
63
+ added_decoder = kwargs.pop("added_tokens_decoder", None)
64
+ if not added_decoder:
65
+ added_decoder = {}
66
+ for tok in SPECIAL_TOKENS:
67
+ try:
68
+ tid = enc.encode_single_token(tok)
69
+ except (KeyError, ValueError):
70
+ continue
71
+ added_decoder[tid] = AddedToken(
72
+ tok,
73
+ lstrip=False,
74
+ rstrip=False,
75
+ single_word=False,
76
+ normalized=False,
77
+ special=True,
78
+ )
79
+
80
+ super().__init__(
81
+ bos_token=bos_token,
82
+ eos_token=eos_token,
83
+ pad_token=pad_token,
84
+ unk_token=unk_token,
85
+ added_tokens_decoder=added_decoder,
86
+ **kwargs,
87
+ )
88
+
89
+ @property
90
+ def vocab_size(self) -> int:
91
+ return self.enc.n_vocab
92
+
93
+ def get_vocab(self) -> Dict[str, int]:
94
+ vocab: Dict[str, int] = {}
95
+ for tid in range(self.enc.n_vocab):
96
+ try:
97
+ raw = self.enc.decode_single_token_bytes(tid)
98
+ token = raw.decode("utf-8", errors="replace")
99
+ except Exception:
100
+ token = f"<id_{tid}>"
101
+ vocab[token] = tid
102
+ for tok in SPECIAL_TOKENS:
103
+ try:
104
+ vocab[tok] = self.enc.encode_single_token(tok)
105
+ except (KeyError, ValueError):
106
+ pass
107
+ vocab.update(self.added_tokens_encoder)
108
+ return vocab
109
+
110
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
111
+ # Base class splits around special tokens and calls _tokenize on the
112
+ # non-special chunks. We use tiktoken's ordinary encoder (which does
113
+ # not recognize specials), and return ids as strings.
114
+ ids = self.enc.encode_ordinary(text)
115
+ return [str(i) for i in ids]
116
+
117
+ def _convert_token_to_id(self, token: str) -> int:
118
+ try:
119
+ return int(token)
120
+ except ValueError:
121
+ try:
122
+ return self.enc.encode_single_token(token)
123
+ except (KeyError, ValueError) as e:
124
+ raise ValueError(f"Unknown token: {token!r}") from e
125
+
126
+ def _convert_id_to_token(self, index: int) -> str:
127
+ try:
128
+ raw = self.enc.decode_single_token_bytes(index)
129
+ return raw.decode("utf-8", errors="replace")
130
+ except Exception:
131
+ return str(index)
132
+
133
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
134
+ # The token list here is a mix of: (a) stringified integer ids from
135
+ # `_tokenize`, (b) UTF-8 text chunks from `_convert_id_to_token`, and
136
+ # (c) raw special-token literals from the added-tokens splitter. We
137
+ # can't disambiguate (a) from (b) by `int(tok)` alone — a token whose
138
+ # decoded UTF-8 text happens to be numeric (e.g. "17") would otherwise
139
+ # be mis-cast to the *id* 17 and decode to byte 0x11. Resolve each
140
+ # entry carefully and fall back to a literal UTF-8 re-encode.
141
+ ids: List[int] = []
142
+ for tok in tokens:
143
+ if tok in self.added_tokens_encoder:
144
+ ids.append(self.added_tokens_encoder[tok])
145
+ continue
146
+ try:
147
+ tid = int(tok)
148
+ except ValueError:
149
+ tid = None
150
+ if tid is not None and 0 <= tid < self.enc.n_vocab:
151
+ raw = self.enc.decode_single_token_bytes(tid)
152
+ if raw.decode("utf-8", errors="replace") == tok:
153
+ ids.append(tid)
154
+ continue
155
+ # Treat as a literal text fragment from `_convert_id_to_token`.
156
+ ids.extend(self.enc.encode_ordinary(tok))
157
+ return self.enc.decode(ids)
158
+
159
+ def _decode(
160
+ self,
161
+ token_ids,
162
+ skip_special_tokens: bool = False,
163
+ clean_up_tokenization_spaces: Optional[bool] = None,
164
+ spaces_between_special_tokens: bool = True,
165
+ **kwargs,
166
+ ) -> str:
167
+ # Bypass HF's token-string round-trip: decoding ids through
168
+ # `_convert_id_to_token` -> `convert_tokens_to_string` loses byte
169
+ # boundaries for multi-byte UTF-8 tokens and is fragile around
170
+ # numeric-looking tokens. Go directly through tiktoken.
171
+ if isinstance(token_ids, int):
172
+ token_ids = [token_ids]
173
+ elif hasattr(token_ids, "tolist"):
174
+ token_ids = token_ids.tolist()
175
+ token_ids = [int(t) for t in token_ids]
176
+ special_ids = set(self.all_special_ids)
177
+ if skip_special_tokens:
178
+ return self.enc.decode([t for t in token_ids if t not in special_ids])
179
+ if not special_ids.intersection(token_ids):
180
+ return self.enc.decode(token_ids)
181
+ # Emit specials literally (between non-special byte runs).
182
+ out_parts: List[str] = []
183
+ run: List[int] = []
184
+ id_to_special = {
185
+ self.added_tokens_encoder[t]: t for t in self.added_tokens_encoder
186
+ }
187
+ for tid in token_ids:
188
+ if tid in special_ids:
189
+ if run:
190
+ out_parts.append(self.enc.decode(run))
191
+ run = []
192
+ out_parts.append(id_to_special.get(tid, f"<|id_{tid}|>"))
193
+ else:
194
+ run.append(tid)
195
+ if run:
196
+ out_parts.append(self.enc.decode(run))
197
+ return "".join(out_parts)
198
+
199
+ def save_vocabulary(
200
+ self,
201
+ save_directory: str,
202
+ filename_prefix: Optional[str] = None,
203
+ ) -> Tuple[str, ...]:
204
+ os.makedirs(save_directory, exist_ok=True)
205
+ prefix = f"{filename_prefix}-" if filename_prefix else ""
206
+ out = os.path.join(save_directory, f"{prefix}tokenizer.pkl")
207
+ with open(out, "wb") as f:
208
+ pickle.dump(self.enc, f)
209
+ return (out,)
210
+
211
+
212
+ __all__ = ["CognicaPoETokenizer"]