Class: Candle::EmbeddingModel

Inherits:
Object
Defined in:
lib/candle/embedding_model.rb

Constant Summary

DEFAULT_MODEL_PATH = "jinaai/jina-embeddings-v2-base-en"
  Default model path for the Jina BERT embedding model

DEFAULT_TOKENIZER_PATH = "sentence-transformers/all-MiniLM-L6-v2"
  Default tokenizer path that works well with the default model

DEFAULT_EMBEDDING_MODEL_TYPE = "jina_bert"
  Default embedding model type

Class Method Summary

Instance Method Summary

Class Method Details

.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, tokenizer: nil, model_type: nil, embedding_size: nil) ⇒ EmbeddingModel

Load a pre-trained embedding model from HuggingFace

Parameters:

  • model_id (String) (defaults to: DEFAULT_MODEL_PATH)

    HuggingFace model ID (defaults to jinaai/jina-embeddings-v2-base-en)

  • device (Candle::Device) (defaults to: Candle::Device.best)

    The device to use for computation (defaults to best available)

  • tokenizer (String, nil) (defaults to: nil)

    The tokenizer to use (defaults to using the model’s tokenizer)

  • model_type (String, nil) (defaults to: nil)

    The type of embedding model (auto-detected if nil)

  • embedding_size (Integer, nil) (defaults to: nil)

    Override for the embedding size (optional)

Returns:

  • (EmbeddingModel)

    The loaded embedding model


# File 'lib/candle/embedding_model.rb', line 19

def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, tokenizer: nil, model_type: nil, embedding_size: nil)
  # Auto-detect model type based on model_id if not provided
  if model_type.nil?
    model_type = case model_id.downcase
    when /jina/
      "jina_bert"
    when /distilbert/
      "distilbert"
    when /minilm/
      "minilm"
    else
      "standard_bert"
    end
  end
  
  # Use model_id as tokenizer if not specified (usually what you want)
  tokenizer_id = tokenizer || model_id
  
  _create(model_id, tokenizer_id, device, model_type, embedding_size)
end
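
The auto-detection step above dispatches on substrings of the model ID. A standalone sketch of that logic (the example IDs are illustrative; only the first two appear in this class's defaults) can be run without the gem:

```ruby
# Standalone sketch of the model-type auto-detection performed by
# EmbeddingModel.from_pretrained: the lowercased model ID is matched
# against substrings in order, falling back to "standard_bert".
def detect_model_type(model_id)
  case model_id.downcase
  when /jina/       then "jina_bert"
  when /distilbert/ then "distilbert"
  when /minilm/     then "minilm"
  else                   "standard_bert"
  end
end

puts detect_model_type("jinaai/jina-embeddings-v2-base-en")      # => jina_bert
puts detect_model_type("sentence-transformers/all-MiniLM-L6-v2") # => minilm
puts detect_model_type("bert-base-uncased")                      # => standard_bert
```

Note that `/distilbert/` is checked before the generic fallback, so a DistilBERT model ID is never misclassified as `standard_bert` merely because it contains "bert".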

.new(model_path: DEFAULT_MODEL_PATH, tokenizer_path: DEFAULT_TOKENIZER_PATH, device: Candle::Device.best, model_type: DEFAULT_EMBEDDING_MODEL_TYPE, embedding_size: nil) ⇒ Object

Deprecated.

Use from_pretrained instead

Constructor for creating a new EmbeddingModel with optional parameters

Parameters:

  • model_path (String, nil) (defaults to: DEFAULT_MODEL_PATH)

    The path to the model on Hugging Face

  • tokenizer_path (String, nil) (defaults to: DEFAULT_TOKENIZER_PATH)

    The path to the tokenizer on Hugging Face

  • device (Candle::Device) (defaults to: Candle::Device.best)

    The device to use for computation

  • model_type (String, nil) (defaults to: DEFAULT_EMBEDDING_MODEL_TYPE)

    The type of embedding model to use

  • embedding_size (Integer, nil) (defaults to: nil)

    Override for the embedding size (optional)



# File 'lib/candle/embedding_model.rb', line 47

def self.new(model_path: DEFAULT_MODEL_PATH,
  tokenizer_path: DEFAULT_TOKENIZER_PATH,
  device: Candle::Device.best,
  model_type: DEFAULT_EMBEDDING_MODEL_TYPE,
  embedding_size: nil)
  $stderr.puts "[DEPRECATION] `EmbeddingModel.new` is deprecated. Please use `EmbeddingModel.from_pretrained` instead."
  _create(model_path, tokenizer_path, device, model_type, embedding_size)
end

Instance Method Details

#embedding(str, pooling_method: "pooled_normalized") ⇒ Object

Returns the embedding for a string using the specified pooling method.

Parameters:

  • str (String)

    The input text

  • pooling_method (String) (defaults to: "pooled_normalized")

    Pooling method: “pooled”, “pooled_normalized”, or “cls”. Default: “pooled_normalized”



# File 'lib/candle/embedding_model.rb', line 58

def embedding(str, pooling_method: "pooled_normalized")
  _embedding(str, pooling_method)
end
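
Embedding vectors are typically compared by cosine similarity. The helper below operates on plain Ruby arrays of floats; the tiny vectors are stand-ins for real model output (actual return values of #embedding are tensors, assumed here to be convertible to arrays):

```ruby
# Cosine similarity between two equal-length float arrays, used to
# compare embedding vectors. The sample vectors are hypothetical
# stand-ins for real embeddings.
def cosine_similarity(a, b)
  dot  = a.zip(b).sum { |x, y| x * y }
  norm = ->(v) { Math.sqrt(v.sum { |x| x * x }) }
  dot / (norm.call(a) * norm.call(b))
end

v1 = [1.0, 0.0, 1.0]
v2 = [1.0, 0.0, 1.0]
v3 = [0.0, 1.0, 0.0]

puts cosine_similarity(v1, v2) # identical vectors => 1.0
puts cosine_similarity(v1, v3) # orthogonal vectors => 0.0
```

With the default "pooled_normalized" pooling the vectors are already unit-length, so the dot product alone gives the cosine similarity.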

#inspectObject

Improved inspect method



# File 'lib/candle/embedding_model.rb', line 63

def inspect
  opts = options rescue {}
  
  parts = ["#<Candle::EmbeddingModel"]
  parts << "model=#{opts["model_id"] || "unknown"}"
  parts << "type=#{opts["model_type"]}" if opts["model_type"]
  parts << "device=#{opts["device"] || "unknown"}"
  parts << "size=#{opts["embedding_size"]}" if opts["embedding_size"]
  
  parts.join(" ") + ">"
end