Class: Candle::EmbeddingModel
- Inherits: Object
- Defined in: lib/candle/embedding_model.rb
Constant Summary
- DEFAULT_MODEL_PATH = "jinaai/jina-embeddings-v2-base-en"
  Default model path for the Jina BERT embedding model.
- DEFAULT_TOKENIZER_PATH = "sentence-transformers/all-MiniLM-L6-v2"
  Default tokenizer path that works well with the default model.
- DEFAULT_EMBEDDING_MODEL_TYPE = "jina_bert"
  Default embedding model type.
Class Method Summary
- .from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, tokenizer: nil, model_type: nil, embedding_size: nil) ⇒ EmbeddingModel
  Load a pre-trained embedding model from HuggingFace.
- .new(model_path: DEFAULT_MODEL_PATH, tokenizer_path: DEFAULT_TOKENIZER_PATH, device: Candle::Device.best, model_type: DEFAULT_EMBEDDING_MODEL_TYPE, embedding_size: nil) ⇒ Object
  Deprecated. Use EmbeddingModel.from_pretrained instead.
Instance Method Summary
- #embedding(str, pooling_method: "pooled_normalized") ⇒ Object
  Returns the embedding for a string using the specified pooling method.
- #inspect ⇒ Object
  Improved inspect method.
Class Method Details
.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, tokenizer: nil, model_type: nil, embedding_size: nil) ⇒ EmbeddingModel
Load a pre-trained embedding model from HuggingFace
# File 'lib/candle/embedding_model.rb', line 19

def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, tokenizer: nil, model_type: nil, embedding_size: nil)
  # Auto-detect model type based on model_id if not provided
  if model_type.nil?
    model_type = case model_id.downcase
                 when /jina/
                   "jina_bert"
                 when /distilbert/
                   "distilbert"
                 when /minilm/
                   "minilm"
                 else
                   "standard_bert"
                 end
  end

  # Use model_id as tokenizer if not specified (usually what you want)
  tokenizer_id = tokenizer || model_id

  _create(model_id, tokenizer_id, device, model_type, embedding_size)
end
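A minimal usage sketch, assuming the gem is loaded with require "candle"; the MiniLM model id below is only an illustration of the auto-detection, not a documented default:

require "candle"

# Default Jina BERT model on the best available device
model = Candle::EmbeddingModel.from_pretrained

# model_type is auto-detected from the id ("minilm" here);
# pass model_type: explicitly to override the detection.
minilm = Candle::EmbeddingModel.from_pretrained(
  "sentence-transformers/all-MiniLM-L6-v2",
  device: Candle::Device.best
)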
.new(model_path: DEFAULT_MODEL_PATH, tokenizer_path: DEFAULT_TOKENIZER_PATH, device: Candle::Device.best, model_type: DEFAULT_EMBEDDING_MODEL_TYPE, embedding_size: nil) ⇒ Object
Deprecated. Use from_pretrained instead.
Constructor for creating a new EmbeddingModel with optional parameters.
# File 'lib/candle/embedding_model.rb', line 47

def self.new(model_path: DEFAULT_MODEL_PATH, tokenizer_path: DEFAULT_TOKENIZER_PATH, device: Candle::Device.best, model_type: DEFAULT_EMBEDDING_MODEL_TYPE, embedding_size: nil)
  $stderr.puts "[DEPRECATION] `EmbeddingModel.new` is deprecated. Please use `EmbeddingModel.from_pretrained` instead."
  _create(model_path, tokenizer_path, device, model_type, embedding_size)
end
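For reference, a migration sketch from the deprecated constructor to from_pretrained, using only the defaults documented above:

# Before: deprecated, prints a warning to $stderr
model = Candle::EmbeddingModel.new(
  model_path: "jinaai/jina-embeddings-v2-base-en",
  model_type: "jina_bert"
)

# After: same model; the tokenizer defaults to the model id
model = Candle::EmbeddingModel.from_pretrained("jinaai/jina-embeddings-v2-base-en")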
Instance Method Details
#embedding(str, pooling_method: "pooled_normalized") ⇒ Object
Returns the embedding for a string using the specified pooling method.
# File 'lib/candle/embedding_model.rb', line 58

def embedding(str, pooling_method: "pooled_normalized")
  _embedding(str, pooling_method)
end
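A short usage sketch; this section does not specify the exact class of the returned value, so it is treated here as an opaque embedding object:

model = Candle::EmbeddingModel.from_pretrained

# Uses the default pooling method, "pooled_normalized"
embedding = model.embedding("Ruby can run Candle models natively")

# The pooling method can also be passed explicitly
embedding = model.embedding("Ruby can run Candle models natively",
                            pooling_method: "pooled_normalized")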
#inspect ⇒ Object
Improved inspect method
# File 'lib/candle/embedding_model.rb', line 63

def inspect
  opts = embedding_options rescue {}
  parts = ["#<Candle::EmbeddingModel"]
  parts << "model=#{opts["model_id"] || "unknown"}"
  parts << "type=#{opts["model_type"]}" if opts["model_type"]
  parts << "device=#{opts["device"] || "unknown"}"
  parts << "size=#{opts["embedding_size"]}" if opts["embedding_size"]
  parts.join(" ") + ">"
end
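Based on the code above, the inspect output has roughly this shape; the actual values depend on the loaded model and on the metadata exposed by the native extension:

model = Candle::EmbeddingModel.from_pretrained
model.inspect
# => "#<Candle::EmbeddingModel model=jinaai/jina-embeddings-v2-base-en type=jina_bert device=cpu size=768>"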