Class: Candle::Reranker

Inherits:
Object
Defined in:
lib/candle/reranker.rb

Constant Summary

DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2"

Default Hugging Face model ID. from_pretrained (and the deprecated new) use it when no model is specified.

Class Method Details

.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, max_length: 512) ⇒ Reranker

Loads a pre-trained reranker model from Hugging Face.

Parameters:

  • model_id (String) (defaults to: DEFAULT_MODEL_PATH)

    Hugging Face model ID (defaults to cross-encoder/ms-marco-MiniLM-L-12-v2)

  • device (Candle::Device) (defaults to: Candle::Device.best)

    The device to use for computation (defaults to best available)

  • max_length (Integer) (defaults to: 512)

    Maximum sequence length for truncation (defaults to 512)

Returns:

  • (Reranker)

    A new Reranker instance



# File 'lib/candle/reranker.rb', line 11

def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: Candle::Device.best, max_length: 512)
  _create(model_id, device, max_length)
end
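To show how the arguments resolve, here is a minimal sketch that uses a hypothetical stand-in class, `RerankerStub`. It copies only the documented signature of from_pretrained. The real `_create` is native code that is not shown on this page, so the stub's `_create` just echoes its inputs. The symbols `:best` and `:cpu` are assumed placeholders for `Candle::Device.best` and `Candle::Device.cpu`. The model ID `my-org/my-cross-encoder` is also hypothetical.

```ruby
# Hypothetical stand-in for Candle::Reranker: mirrors only the documented
# signature of from_pretrained so argument handling can run without the gem.
class RerankerStub
  DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2"

  # Assumption: :best / :cpu symbols stand in for Candle::Device objects.
  # model_id is positional; device and max_length are keywords.
  def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: :best, max_length: 512)
    _create(model_id, device, max_length)
  end

  # Hypothetical: the real _create loads the model; this one echoes its inputs.
  def self._create(model_id, device, max_length)
    { model_id: model_id, device: device, max_length: max_length }
  end
end

# No arguments: every default applies.
defaults = RerankerStub.from_pretrained
# Positional model ID plus keyword overrides ("my-org/my-cross-encoder" is made up).
custom = RerankerStub.from_pretrained("my-org/my-cross-encoder", device: :cpu, max_length: 256)
```

With the real gem, the equivalent call has the form `Candle::Reranker.from_pretrained(model_id, device: Candle::Device.cpu, max_length: 256)`. The model ID is passed positionally, not as a keyword.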

.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.best, max_length: 512) ⇒ Reranker

Deprecated.

Use from_pretrained instead.

Creates a new Reranker. Before loading, it prints a deprecation warning to $stderr, then delegates to the same loader (_create) that from_pretrained uses.

Parameters:

  • model_path (String, nil) (defaults to: DEFAULT_MODEL_PATH)

    Hugging Face model ID (the same value from_pretrained takes as model_id)

  • device (Candle::Device) (defaults to: Candle::Device.best)

    The device to use for computation

  • max_length (Integer) (defaults to: 512)

    Maximum sequence length for truncation (defaults to 512)



# File 'lib/candle/reranker.rb', line 20

def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.best, max_length: 512)
  $stderr.puts "[DEPRECATION] `Reranker.new` is deprecated. Please use `Reranker.from_pretrained` instead."
  _create(model_path, device, max_length)
end
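The sketch below reproduces the documented deprecation shim on a hypothetical stand-in class, `LegacyRerankerStub`, and shows how to migrate: the `model_path:` keyword becomes the positional `model_id`. Both calls reach the same `_create`. As in the previous sketch, `_create` is a stub that echoes its inputs, `:best` is a placeholder for `Candle::Device.best`, and the model ID is made up.

```ruby
require "stringio"

# Hypothetical stand-in mirroring the documented shim: .new takes model_path:
# as a keyword, warns on $stderr, then delegates to the same _create.
class LegacyRerankerStub
  DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2"

  def self.from_pretrained(model_id = DEFAULT_MODEL_PATH, device: :best, max_length: 512)
    _create(model_id, device, max_length)
  end

  def self.new(model_path: DEFAULT_MODEL_PATH, device: :best, max_length: 512)
    $stderr.puts "[DEPRECATION] `Reranker.new` is deprecated. Please use `Reranker.from_pretrained` instead."
    _create(model_path, device, max_length)
  end

  # Hypothetical: echoes its inputs instead of loading a model.
  def self._create(model_id, device, max_length)
    [model_id, device, max_length]
  end
end

# Temporarily redirect $stderr so the warning can be inspected.
captured = StringIO.new
original_stderr = $stderr
$stderr = captured
begin
  old_style = LegacyRerankerStub.new(model_path: "my-org/my-cross-encoder", max_length: 256)
ensure
  $stderr = original_stderr
end

# Migrated call: same model and options, no warning.
new_style = LegacyRerankerStub.from_pretrained("my-org/my-cross-encoder", max_length: 256)
warning = captured.string
```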

Instance Method Details

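#inspect ⇒ String

Returns a short summary of the reranker that shows its model ID and device. If options cannot be read, or a key is missing, the value is shown as "unknown".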



# File 'lib/candle/reranker.rb', line 37

def inspect
  opts = options rescue {}
  parts = ["#<Candle::Reranker"]
  parts << "model=#{opts["model_id"] || "unknown"}"
  parts << "device=#{opts["device"] || "unknown"}"
  parts.join(" ") + ">"
end
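This sketch runs the inspect body on a hypothetical class, `InspectSketch`, and shows how it degrades. The real `options` is not shown on this page, so the sketch assumes it returns a Hash with string keys, which is what `inspect` reads. The stub's `options` raises when no hash is given, which exercises the `rescue {}` fallback. The `"cpu"` device string is a made-up value.

```ruby
# Hypothetical stand-in: options returns a Hash with string keys, or raises
# when none was supplied, to exercise the `rescue {}` fallback.
class InspectSketch
  def initialize(opts = nil)
    @opts = opts
  end

  # Assumption: the real #options is native; this stub raises RuntimeError
  # when no hash was given.
  def options
    raise "options unavailable" if @opts.nil?
    @opts
  end

  # Same body as the documented Candle::Reranker#inspect.
  def inspect
    opts = options rescue {}
    parts = ["#<Candle::Reranker"]
    parts << "model=#{opts["model_id"] || "unknown"}"
    parts << "device=#{opts["device"] || "unknown"}"
    parts.join(" ") + ">"
  end
end

loaded = InspectSketch.new({ "model_id" => "cross-encoder/ms-marco-MiniLM-L-12-v2", "device" => "cpu" }).inspect
fallback = InspectSketch.new.inspect # options raises, so both fields read "unknown"
```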

#rerank(query, documents, pooling_method: "pooler", apply_sigmoid: true) ⇒ Array<Hash>

Returns documents ranked by relevance to the query, scored with the specified pooling method. Each element is a Hash with the keys :doc_id, :score, and :text.

Parameters:

  • query (String)

    The query that each document is scored against

  • documents (Array<String>)

    The list of documents to compare against

  • pooling_method (String) (defaults to: "pooler")

    Pooling method: “pooler”, “cls”, or “mean”. Default: “pooler”

  • apply_sigmoid (Boolean) (defaults to: true)

    Whether to apply sigmoid to the scores. Default: true
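This page does not define the pooling options. The following is a generic illustration of the usual meanings, not Candle's code. "cls" takes the hidden state of the first token. "mean" averages the hidden states over all tokens. "pooler", by its name, presumably follows the BERT-style pooler, a learned dense layer with tanh applied to the first token, which a toy sketch cannot reproduce. The sigmoid maps a raw relevance logit into (0, 1). The module name `PoolingSketch` and the toy hidden states are hypothetical.

```ruby
# Generic illustration of pooling and sigmoid; not Candle's implementation.
module PoolingSketch
  module_function

  # "cls"-style pooling: keep the vector of the first token.
  def cls_pool(hidden)
    hidden.first
  end

  # "mean"-style pooling: average each dimension across tokens.
  # The toy input has no padding, so no attention mask is applied.
  def mean_pool(hidden)
    hidden.transpose.map { |column| column.sum / column.size.to_f }
  end

  # Logistic sigmoid: squashes a raw logit into the open interval (0, 1).
  def sigmoid(x)
    1.0 / (1.0 + Math.exp(-x))
  end
end

# Toy hidden states: one 2-dimensional vector per token.
hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]]
cls  = PoolingSketch.cls_pool(hidden)  # => [1.0, 2.0]
mean = PoolingSketch.mean_pool(hidden) # => [3.0, 2.0]
```

The sigmoid is strictly increasing. If results are ordered by score, toggling apply_sigmoid should therefore change the scale of the scores but not the ranking.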



# File 'lib/candle/reranker.rb', line 30

def rerank(query, documents, pooling_method: "pooler", apply_sigmoid: true)
  rerank_with_options(query, documents, pooling_method, apply_sigmoid).collect { |doc, score, doc_id|
    { doc_id: doc_id, score: score, text: doc }
  }
end
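The sketch below shows the shape of the value that rerank returns. In the real class `rerank_with_options` is native and not shown here, so the hypothetical `RerankShapeSketch` stubs it. The stub uses made-up scores, takes doc_id to be each document's original index, and returns `[document, score, doc_id]` triples sorted by descending score, following the "ranked by relevance" description. All three of these behaviors are assumptions. Only the `rerank` body is copied from the page.

```ruby
# Hypothetical stand-in: stubs the native rerank_with_options so the
# documented #rerank body can be run on its own.
class RerankShapeSketch
  # Assumptions: doc_id is the original index, scores are made up (not model
  # output), and triples come back sorted by descending score.
  def rerank_with_options(query, documents, pooling_method, apply_sigmoid)
    fake_scores = [0.12, 0.93, 0.40]
    documents.each_with_index
             .map { |doc, i| [doc, fake_scores[i], i] }
             .sort_by { |triple| -triple[1] }
  end

  # Same body as the documented Candle::Reranker#rerank.
  def rerank(query, documents, pooling_method: "pooler", apply_sigmoid: true)
    rerank_with_options(query, documents, pooling_method, apply_sigmoid).collect { |doc, score, doc_id|
      { doc_id: doc_id, score: score, text: doc }
    }
  end
end

docs = ["Paris is in France.", "Ruby is a programming language.", "Gems package Ruby libraries."]
results = RerankShapeSketch.new.rerank("What is Ruby?", docs)
# results.first => { doc_id: 1, score: 0.93, text: "Ruby is a programming language." }
```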