Skip to content

Commit 7dba1d0

Browse files
joeindancixx
authored and committed
refactor: split tokenize into _tokenize and tokenize to respect MRO (#566)
1 parent 39b0d7f commit 7dba1d0

22 files changed

Lines changed: 90 additions & 144 deletions

fastembed/late_interaction/colbert.py

Lines changed: 8 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -80,11 +80,16 @@ def _preprocess_onnx_input(
8080
)
8181
return onnx_input
8282

83-
def tokenize(self, texts: list[str], is_doc: bool = True, **kwargs: Any) -> list[Encoding]: # type: ignore[override]
83+
def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
84+
return self._tokenize(documents, **kwargs)
85+
86+
def _tokenize(
87+
self, documents: list[str], is_doc: bool = True, **kwargs: Any
88+
) -> list[Encoding]:
8489
return (
85-
self._tokenize_documents(documents=texts)
90+
self._tokenize_documents(documents=documents)
8691
if is_doc
87-
else self._tokenize_query(query=next(iter(texts)))
92+
else self._tokenize_query(query=next(iter(documents)))
8893
)
8994

9095
def _tokenize_query(self, query: str) -> list[Encoding]:

fastembed/late_interaction/late_interaction_embedding_base.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -21,7 +21,7 @@ def __init__(
2121
self._local_files_only = kwargs.pop("local_files_only", False)
2222
self._embedding_size: Optional[int] = None
2323

24-
def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]:
24+
def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
2525
raise NotImplementedError()
2626

2727
def embed(

fastembed/late_interaction/late_interaction_text_embedding.py

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -116,18 +116,18 @@ def get_embedding_size(cls, model_name: str) -> int:
116116
)
117117
return embedding_size
118118

119-
def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]:
119+
def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
120120
"""
121121
Tokenize input texts using the model's tokenizer.
122122
123123
Args:
124-
texts: List of strings to tokenize
124+
documents: List of strings to tokenize
125125
**kwargs: Additional arguments passed to the tokenizer
126126
127127
Returns:
128128
List of tokenizer Encodings
129129
"""
130-
return self.model.tokenize(texts, **kwargs)
130+
return self.model.tokenize(documents, **kwargs)
131131

132132
def embed(
133133
self,

fastembed/late_interaction/token_embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -25,7 +25,7 @@
2525
]
2626

2727

28-
class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase): # type: ignore[misc]
28+
class TokenEmbeddingsModel(OnnxTextEmbedding, LateInteractionTextEmbeddingBase):
2929
@classmethod
3030
def _list_supported_models(cls) -> list[DenseModelDescription]:
3131
"""Lists the supported models.

fastembed/late_interaction_multimodal/colpali.py

Lines changed: 5 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -160,9 +160,12 @@ def _post_process_onnx_text_output(
160160
"""
161161
return output.model_output
162162

163-
def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]: # type: ignore[override]
163+
def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
164+
return self._tokenize(documents, **kwargs)
165+
166+
def _tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
164167
texts_query: list[str] = []
165-
for query in texts:
168+
for query in documents:
166169
query = self.BOS_TOKEN + self.QUERY_PREFIX + query + self.PAD_TOKEN * 10
167170
query += "\n"
168171

fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -119,18 +119,18 @@ def get_embedding_size(cls, model_name: str) -> int:
119119
)
120120
return embedding_size
121121

122-
def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]:
122+
def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
123123
"""
124124
Tokenize input texts using the model's tokenizer.
125125
126126
Args:
127-
texts: List of strings to tokenize
127+
documents: List of strings to tokenize
128128
**kwargs: Additional arguments passed to the tokenizer
129129
130130
Returns:
131131
List of tokenizer Encodings
132132
"""
133-
return self.model.tokenize(texts, **kwargs)
133+
return self.model.tokenize(documents, **kwargs)
134134

135135
def embed_text(
136136
self,

fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -22,7 +22,7 @@ def __init__(
2222
self._local_files_only = kwargs.pop("local_files_only", False)
2323
self._embedding_size: Optional[int] = None
2424

25-
def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]:
25+
def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
2626
raise NotImplementedError()
2727

2828
def embed_text(

fastembed/late_interaction_multimodal/onnx_multimodal_model.py

Lines changed: 3 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -80,17 +80,17 @@ def _load_onnx_model(
8080
def load_onnx_model(self) -> None:
8181
raise NotImplementedError("Subclasses must implement this method")
8282

83-
def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]:
83+
def _tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
8484
if self.tokenizer is None:
8585
raise RuntimeError("Tokenizer not initialized")
86-
return self.tokenizer.encode_batch(texts, **kwargs) # type: ignore[union-attr]
86+
return self.tokenizer.encode_batch(documents, **kwargs) # type: ignore[union-attr]
8787

8888
def onnx_embed_text(
8989
self,
9090
documents: list[str],
9191
**kwargs: Any,
9292
) -> OnnxOutputContext:
93-
encoded = self.tokenize(documents, **kwargs)
93+
encoded = self._tokenize(documents, **kwargs)
9494
input_ids = np.array([e.ids for e in encoded])
9595
attention_mask = np.array([e.attention_mask for e in encoded]) # type: ignore[union-attr]
9696
input_names = {node.name for node in self.model.get_inputs()} # type: ignore[union-attr]

fastembed/rerank/cross_encoder/onnx_text_model.py

Lines changed: 0 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -46,8 +46,6 @@ def _load_onnx_model(
4646
assert self.tokenizer is not None
4747

4848
def tokenize(self, pairs: list[tuple[str, str]], **kwargs: Any) -> list[Encoding]:
49-
if self.tokenizer is None:
50-
raise RuntimeError("Tokenizer not initialized")
5149
return self.tokenizer.encode_batch(pairs, **kwargs) # type: ignore[union-attr]
5250

5351
def _build_onnx_input(self, tokenized_input: list[Encoding]) -> dict[str, NumpyArray]:

fastembed/sparse/bm25.py

Lines changed: 11 additions & 11 deletions
Original file line number · Diff line number · Diff line change
@@ -137,25 +137,25 @@ def __init__(
137137

138138
self.tokenizer = SimpleTokenizer
139139

140-
def tokenize(self, texts: list[str], **kwargs: Any) -> list[Encoding]:
140+
def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
141141
"""Tokenize texts using SimpleTokenizer.
142142
143143
Returns a list of simple Encoding-like objects with token strings.
144144
Note: BM25 uses a simple word tokenizer, not a learned tokenizer.
145145
"""
146146
result = []
147-
for text in texts:
148-
tokens = self.tokenizer.tokenize(text)
149147

150-
# Create a simple object that mimics Encoding interface
151-
class SimpleEncoding:
152-
def __init__(self, tokens: list[str]):
153-
self.tokens = tokens
154-
self.ids = tokens # For BM25, tokens are the IDs
155-
self.attention_mask = [1] * len(tokens)
148+
class SimpleEncoding:
149+
def __init__(self, tokens: list[str]):
150+
self.tokens = tokens
151+
self.ids = tokens # For BM25, tokens are the IDs
152+
self.attention_mask = [1] * len(tokens)
153+
154+
for document in documents:
155+
tokens = self.tokenizer.tokenize(document)
156+
result.append(SimpleEncoding(tokens))
156157

157-
result.append(SimpleEncoding(tokens)) # type: ignore[arg-type]
158-
return result # type: ignore[return-value]
158+
return result
159159

160160
@classmethod
161161
def _list_supported_models(cls) -> list[SparseModelDescription]:

0 commit comments

Comments (0)