Class: Candle::EmbeddingModel

Inherits:
Object
Defined in:
lib/candle/embedding_model.rb

Constant Summary

DEFAULT_MODEL_PATH = "jinaai/jina-embeddings-v2-base-en"

    Default model path (a Hugging Face repository ID) for the Jina BERT embedding model.

DEFAULT_TOKENIZER_PATH = "sentence-transformers/all-MiniLM-L6-v2"

    Default tokenizer path; this tokenizer works well with the default model.

DEFAULT_EMBEDDING_MODEL_TYPE = "jina_bert"

    Default embedding model type.

Class Method Summary

Instance Method Summary

Class Method Details

.new(model_path: DEFAULT_MODEL_PATH, tokenizer_path: DEFAULT_TOKENIZER_PATH, device: nil, model_type: DEFAULT_EMBEDDING_MODEL_TYPE, embedding_size: nil) ⇒ Object

Creates a new EmbeddingModel. Every parameter is an optional keyword argument; any you omit falls back to the default shown below.

Parameters:

  • model_path (String, nil) (defaults to: DEFAULT_MODEL_PATH)

    The model's path (repository ID) on Hugging Face

  • tokenizer_path (String, nil) (defaults to: DEFAULT_TOKENIZER_PATH)

    The tokenizer's path (repository ID) on Hugging Face

  • device (Candle::Device, nil) (defaults to: nil)

    The device to run computation on; nil means the CPU

  • model_type (String, nil) (defaults to: DEFAULT_EMBEDDING_MODEL_TYPE)

    The type of embedding model to use

  • embedding_size (Integer, nil) (defaults to: nil)

    Optional override for the embedding size



# File 'lib/candle/embedding_model.rb', line 18

def self.new(model_path: DEFAULT_MODEL_PATH,
  tokenizer_path: DEFAULT_TOKENIZER_PATH,
  device: nil,
  model_type: DEFAULT_EMBEDDING_MODEL_TYPE,
  embedding_size: nil)
  _create(model_path, tokenizer_path, device, model_type, embedding_size)
end
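The source above overrides `new` so that callers pass keyword arguments, which are then forwarded positionally to the native `_create` method. The sketch below reproduces that forwarding pattern in plain Ruby so it runs without the gem. `SketchEmbeddingModel` and its recording `_create` are hypothetical stand-ins, not part of Candle; the real `_create` loads the model natively.

```ruby
# Hypothetical stand-in for Candle::EmbeddingModel; runs without the gem.
class SketchEmbeddingModel
  DEFAULT_MODEL_PATH = "jinaai/jina-embeddings-v2-base-en"
  DEFAULT_TOKENIZER_PATH = "sentence-transformers/all-MiniLM-L6-v2"
  DEFAULT_EMBEDDING_MODEL_TYPE = "jina_bert"

  # The positional arguments _create received, kept for inspection.
  attr_reader :args

  # Stand-in for the native constructor, which takes positional arguments only.
  # It uses allocate because new is overridden below.
  def self._create(model_path, tokenizer_path, device, model_type, embedding_size)
    model = allocate
    model.instance_variable_set(:@args,
                                [model_path, tokenizer_path, device, model_type, embedding_size])
    model
  end

  # Same shape as the documented .new: keyword arguments with defaults,
  # forwarded positionally to _create.
  def self.new(model_path: DEFAULT_MODEL_PATH,
               tokenizer_path: DEFAULT_TOKENIZER_PATH,
               device: nil,
               model_type: DEFAULT_EMBEDDING_MODEL_TYPE,
               embedding_size: nil)
    _create(model_path, tokenizer_path, device, model_type, embedding_size)
  end
end

# Only the keyword passed changes; the others fall back to their defaults.
model = SketchEmbeddingModel.new(embedding_size: 768)
p model.args
# prints ["jinaai/jina-embeddings-v2-base-en", "sentence-transformers/all-MiniLM-L6-v2", nil, "jina_bert", 768]
```

With the gem installed, the corresponding real call would be `Candle::EmbeddingModel.new(embedding_size: 768)`; 768 here is only an illustrative value.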

Instance Method Details

#embedding(str, pooling_method: "pooled_normalized") ⇒ Object

Returns the embedding for a string using the specified pooling method.

Parameters:

  • str (String)

    The input text

  • pooling_method (String) (defaults to: "pooled_normalized")

    Pooling method: "pooled", "pooled_normalized", or "cls"
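The three pooling options reduce the model's per-token output vectors to one vector per string. The sketch below shows what these names conventionally mean, assuming "pooled" is mean pooling over all tokens, "pooled_normalized" is that mean scaled to unit L2 length, and "cls" takes the first (`[CLS]`) token's vector. It works on plain Ruby arrays rather than Candle tensors, and `pool` is a hypothetical helper, not part of the gem.

```ruby
# Hypothetical helper: reduce a list of per-token vectors (one Array of Floats
# per token) to a single vector. Assumes the conventional meanings of the
# three pooling names; the gem's native implementation may differ in detail.
def pool(token_vectors, pooling_method = "pooled_normalized")
  case pooling_method
  when "cls"
    # Use the vector of the first token ([CLS] in BERT-style models).
    token_vectors.first.dup
  when "pooled", "pooled_normalized"
    n = token_vectors.length.to_f
    # Mean pooling: average each dimension across all tokens.
    mean = token_vectors.transpose.map { |dim| dim.sum / n }
    return mean if pooling_method == "pooled"

    # Scale the mean vector to unit L2 norm.
    norm = Math.sqrt(mean.sum { |x| x * x })
    mean.map { |x| x / norm }
  else
    raise ArgumentError, "unknown pooling method: #{pooling_method}"
  end
end

tokens = [[1.0, 0.0], [3.0, 4.0], [2.0, 2.0]]
p pool(tokens, "cls")                                      # prints [1.0, 0.0]
p pool(tokens, "pooled")                                   # prints [2.0, 2.0]
p pool(tokens, "pooled_normalized").map { |x| x.round(4) } # prints [0.7071, 0.7071]
```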



# File 'lib/candle/embedding_model.rb', line 28

def embedding(str, pooling_method: "pooled_normalized")
  _embedding(str, pooling_method)
end
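A common use of these embeddings is comparing two strings by cosine similarity. Assuming "pooled_normalized" returns unit-length vectors, the cosine similarity of two such embeddings reduces to their dot product. The plain-Ruby sketch below shows this on small hand-made vectors. `dot`, `norm`, and `cosine_similarity` are hypothetical helpers, not part of the gem, and converting the object returned by `embedding` into Ruby arrays is not shown.

```ruby
# Hypothetical helpers; not part of the Candle gem.
def dot(a, b)
  a.zip(b).sum { |x, y| x * y }
end

def norm(a)
  Math.sqrt(dot(a, a))
end

# General cosine similarity: dot product divided by the product of the norms.
def cosine_similarity(a, b)
  dot(a, b) / (norm(a) * norm(b))
end

a = [3.0, 4.0]
b = [4.0, 3.0]
p cosine_similarity(a, b) # prints 0.96

# After L2 normalization both vectors have norm 1, so the dot product alone
# gives the same value (up to floating-point rounding).
unit_a = a.map { |x| x / norm(a) }
unit_b = b.map { |x| x / norm(b) }
p dot(unit_a, unit_b).round(10) # prints 0.96
```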