Class: Candle::Reranker
- Inherits: Object
- Defined in: lib/candle/reranker.rb
Constant Summary
- DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2"
Default model: the cross-encoder/ms-marco-MiniLM-L-12-v2 cross-encoder from Hugging Face.
Class Method Summary
- .new(model_path: DEFAULT_MODEL_PATH, cuda: false) ⇒ Object
Creates a new Reranker from the model at model_path, using CUDA when cuda: true.
Instance Method Summary
- #rerank(query, documents, pooling_method: "pooler", apply_sigmoid: true) ⇒ Object
Scores each document in documents against query with the cross-encoder and returns an array of hashes with the keys :doc_id, :score, and :text.
Class Method Details
.new(model_path: DEFAULT_MODEL_PATH, cuda: false) ⇒ Object
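Creates a new Reranker from the model at model_path. When cuda is true the instance is built by _create_cuda; otherwise it is built by _create, the default non-CUDA path.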
# File 'lib/candle/reranker.rb', line 9

def self.new(model_path: DEFAULT_MODEL_PATH, cuda: false)
  # Dispatch to the native factory for the requested device.
  if cuda
    _create_cuda(model_path)
  else
    _create(model_path)
  end
end
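Because the class overrides self.new, instances come from a factory method rather than from Ruby's usual new/initialize path. The Ruby sketch below reproduces that dispatch pattern in a self-contained way. FakeReranker, its _create/_create_cuda bodies, the private build helper, and the model ID "my-org/my-cross-encoder" are all hypothetical; in the real gem the factories are native methods that presumably load the model, and they are not shown in this listing.

```ruby
# Hypothetical stand-in for Candle::Reranker. The real _create and
# _create_cuda are native methods; these bodies only record their inputs.
class FakeReranker
  DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2"

  attr_reader :model_path, :device

  # Same dispatch as Candle::Reranker.new: pick a factory based on cuda:.
  def self.new(model_path: DEFAULT_MODEL_PATH, cuda: false)
    if cuda
      _create_cuda(model_path)
    else
      _create(model_path)
    end
  end

  # Hypothetical default (non-CUDA) factory.
  def self._create(model_path)
    build(model_path, :cpu)
  end

  # Hypothetical CUDA factory.
  def self._create_cuda(model_path)
    build(model_path, :cuda)
  end

  # `new` is overridden above, so calling it here would recurse.
  # Allocate the object directly and set its state instead.
  def self.build(model_path, device)
    obj = allocate
    obj.instance_variable_set(:@model_path, model_path)
    obj.instance_variable_set(:@device, device)
    obj
  end
  private_class_method :build
end

cpu = FakeReranker.new
gpu = FakeReranker.new(model_path: "my-org/my-cross-encoder", cuda: true)
puts "#{cpu.device} #{cpu.model_path}"
puts "#{gpu.device} #{gpu.model_path}"
```

One consequence of overriding new is that initialize is never called. Any per-instance setup therefore has to happen inside the factories.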
Instance Method Details
#rerank(query, documents, pooling_method: "pooler", apply_sigmoid: true) ⇒ Object
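Scores every document in documents against query with the cross-encoder and returns one hash per document, with the keys :doc_id, :score, and :text. pooling_method (default "pooler") and apply_sigmoid (default true) are passed unchanged to the native scoring call.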
# File 'lib/candle/reranker.rb', line 22

def rerank(query, documents, pooling_method: "pooler", apply_sigmoid: true)
  # Turn each (doc, score, doc_id) triple from the native call into a result hash.
  (query, documents, pooling_method, apply_sigmoid).collect { |doc, score, doc_id|
    { doc_id: doc_id, score: score, text: doc }
  }
end
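The Ruby side of rerank only reshapes what the native scorer returns: each [doc, score, doc_id] triple becomes a { doc_id:, score:, text: } hash. The sketch below reproduces that step on made-up triples. It also shows the logistic sigmoid that apply_sigmoid: true presumably applies to raw cross-encoder logits, which maps them into (0, 1). The raw_triples values, the sigmoid helper, and the final sort_by are illustrative assumptions. This listing does not say whether the native call already orders its output, so the sketch sorts explicitly.

```ruby
# Logistic sigmoid: maps any real-valued logit into (0, 1).
# Assumed to be what apply_sigmoid: true does; not confirmed by this listing.
def sigmoid(x)
  1.0 / (1.0 + Math.exp(-x))
end

# Made-up [doc, raw_score, doc_id] triples standing in for the native output.
raw_triples = [
  ["Paris is the capital of France.", 8.2, 0],
  ["The Eiffel Tower is in Paris.", 1.5, 1],
  ["Bananas are rich in potassium.", -9.7, 2]
]

# Same mapping as Reranker#rerank, with the sigmoid applied to each score.
results = raw_triples.collect do |doc, score, doc_id|
  { doc_id: doc_id, score: sigmoid(score), text: doc }
end

# Highest score first; keep the two best matches.
top = results.sort_by { |r| -r[:score] }.first(2)
top.each { |r| puts format("%d %.4f %s", r[:doc_id], r[:score], r[:text]) }
```

With the real gem, you would apply the same sort_by/first calls to the array returned by reranker.rerank(query, documents).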