Class: Candle::LLM

Inherits:
Object
Defined in:
lib/candle/llm.rb

Constant Summary

TOKENIZER_REGISTRY =

Tokenizer registry for automatic detection

{
  # Exact model matches
  "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" => "mistralai/Mistral-7B-Instruct-v0.2",
  "TheBloke/Mistral-7B-v0.1-GGUF" => "mistralai/Mistral-7B-v0.1",
  "TheBloke/Llama-2-7B-Chat-GGUF" => "meta-llama/Llama-2-7b-chat-hf",
  "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  
  # Qwen official GGUF models
  "Qwen/Qwen3-8B-GGUF" => "Qwen/Qwen3-8B",
  "Qwen/Qwen3-4B-GGUF" => "Qwen/Qwen3-4B",
  "Qwen/Qwen3-14B-GGUF" => "Qwen/Qwen3-14B",
  "Qwen/Qwen3-32B-GGUF" => "Qwen/Qwen3-32B",
  "Qwen/Qwen3-72B-GGUF" => "Qwen/Qwen3-72B",
  
  # Phi GGUF models
  "TheBloke/phi-2-GGUF" => "microsoft/phi-2",
  "microsoft/phi-4-gguf" => "microsoft/phi-4",
  "bartowski/Phi-3.5-mini-instruct-GGUF" => "microsoft/Phi-3.5-mini-instruct",
  
  # Pattern-based fallbacks (evaluated in order)
  :patterns => [
    # Mistral models
    [/mistral.*?7b.*?instruct.*?v0\.2/i, "mistralai/Mistral-7B-Instruct-v0.2"],
    [/mistral.*?7b.*?instruct.*?v0\.1/i, "mistralai/Mistral-7B-Instruct-v0.1"],
    [/mistral.*?7b/i, "mistralai/Mistral-7B-v0.1"],
    
    # Llama models
    [/llama.*?3.*?8b/i, "meta-llama/Meta-Llama-3-8B"],
    [/llama.*?3.*?70b/i, "meta-llama/Meta-Llama-3-70B"],
    [/llama.*?2.*?7b.*?chat/i, "meta-llama/Llama-2-7b-chat-hf"],
    [/llama.*?2.*?13b.*?chat/i, "meta-llama/Llama-2-13b-chat-hf"],
    [/llama.*?2.*?70b.*?chat/i, "meta-llama/Llama-2-70b-chat-hf"],
    [/tinyllama/i, "TinyLlama/TinyLlama-1.1B-Chat-v1.0"],
    
    # Gemma models
    [/gemma.*?2.*?9b/i, "google/gemma-2-9b"],
    [/gemma.*?2.*?2b/i, "google/gemma-2-2b"],
    [/gemma.*?7b/i, "google/gemma-7b"],
    [/gemma.*?2b/i, "google/gemma-2b"],
    
    # Qwen models
    [/qwen.*?3.*?72b/i, "Qwen/Qwen3-72B"],
    [/qwen.*?3.*?32b/i, "Qwen/Qwen3-32B"],
    [/qwen.*?3.*?14b/i, "Qwen/Qwen3-14B"],
    [/qwen.*?3.*?8b/i, "Qwen/Qwen3-8B"],
    [/qwen.*?3.*?4b/i, "Qwen/Qwen3-4B"],
    [/qwen.*?3.*?1\.8b/i, "Qwen/Qwen3-1.8B"],
    [/qwen.*?3.*?0\.5b/i, "Qwen/Qwen3-0.5B"],
    [/qwen.*?2\.5/i, "Qwen/Qwen2.5-0.5B"],
    [/qwen.*?2/i, "Qwen/Qwen2-1.5B"],
    [/qwen/i, "Qwen/Qwen-1_8B"],
    
    # Phi models (order matters - more specific patterns first)
    [/phi.*?3\.5.*?mini/i, "microsoft/Phi-3.5-mini-instruct"],
    [/phi.*?3.*?mini.*?4k/i, "microsoft/Phi-3-mini-4k-instruct"],
    [/phi.*?3.*?medium/i, "microsoft/Phi-3-medium-4k-instruct"],
    [/phi.*?3.*?small/i, "microsoft/Phi-3-small-8k-instruct"],
    [/phi.*?3.*?mini/i, "microsoft/Phi-3-mini-4k-instruct"],
    [/phi.*?3/i, "microsoft/Phi-3-mini-4k-instruct"],
    [/phi-4/i, "microsoft/phi-4"],
    [/phi.*?2/i, "microsoft/phi-2"],
    [/phi.*?1\.5/i, "microsoft/phi-1_5"],
    [/phi/i, "microsoft/phi-2"]
  ]
}

Class Method Summary

Instance Method Summary

Class Method Details

.from_pretrained(model_id, device: Candle::Device.best, gguf_file: nil, tokenizer: nil) ⇒ Object



# File 'lib/candle/llm.rb', line 248

def self.from_pretrained(model_id, device: Candle::Device.best, gguf_file: nil, tokenizer: nil)
  model_str = if gguf_file
    "#{model_id}@#{gguf_file}"
  else
    model_id
  end
  
  # Handle GGUF models that need tokenizer
  if model_str.downcase.include?("gguf") && tokenizer.nil?
    # Try to load without tokenizer first
    begin
      _from_pretrained(model_str, device)
    rescue => e
      if e.message.include?("No tokenizer found")
        # Auto-detect tokenizer
        detected_tokenizer = guess_tokenizer(model_id)
        warn "No tokenizer found in GGUF repo. Using tokenizer from: #{detected_tokenizer}"
        model_str = "#{model_str}@@#{detected_tokenizer}"
        _from_pretrained(model_str, device)
      else
        raise e
      end
    end
  elsif tokenizer
    # User specified tokenizer
    model_str = "#{model_str}@@#{tokenizer}"
    _from_pretrained(model_str, device)
  else
    # Non-GGUF model or GGUF with embedded tokenizer
    _from_pretrained(model_str, device)
  end
end
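
A usage sketch. The model IDs come from the tokenizer registry above; the gguf_file name is illustrative only and should be checked against the actual repository:

require "candle"

# Safetensors model; the tokenizer is loaded from the same repo.
llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Quantized GGUF model with an explicit quant file and tokenizer override.
# The filename below is a placeholder; use a file that exists in the repo.
quantized = Candle::LLM.from_pretrained(
  "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
  gguf_file: "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
  tokenizer: "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
)

If tokenizer is omitted for a GGUF repo without an embedded tokenizer, the method falls back to guess_tokenizer as shown in the implementation above.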

.guess_tokenizer(model_id) ⇒ Object

Guess the tokenizer for a model



# File 'lib/candle/llm.rb', line 165

def self.guess_tokenizer(model_id)
  # Check exact matches first
  return TOKENIZER_REGISTRY[model_id] if TOKENIZER_REGISTRY[model_id]
  
  # Check patterns
  if patterns = TOKENIZER_REGISTRY[:patterns]
    patterns.each do |pattern, tokenizer|
      return tokenizer if model_id.match?(pattern)
    end
  end
  
  # Default: try removing common GGUF suffixes
  base_model = model_id.gsub(/-gguf|-q\d+_\w+$/i, "")
  base_model
end
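
The lookup order is exact match, then the :patterns list, then stripping common GGUF suffixes. A sketch of expected results (the second and third model IDs are made up to illustrate the fallbacks):

Candle::LLM.guess_tokenizer("TheBloke/Mistral-7B-Instruct-v0.2-GGUF")
# => "mistralai/Mistral-7B-Instruct-v0.2"   (exact match)

Candle::LLM.guess_tokenizer("someuser/llama-2-7b-chat-GGUF")
# => "meta-llama/Llama-2-7b-chat-hf"        (pattern match)

Candle::LLM.guess_tokenizer("someuser/UnknownModel-GGUF")
# => "someuser/UnknownModel"                (suffix stripped)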

.register_tokenizer(model_pattern, tokenizer_id) ⇒ Object

Allow users to register custom tokenizer mappings



# File 'lib/candle/llm.rb', line 153

def self.register_tokenizer(model_pattern, tokenizer_id)
  if model_pattern.is_a?(String)
    TOKENIZER_REGISTRY[model_pattern] = tokenizer_id
  elsif model_pattern.is_a?(Regexp)
    TOKENIZER_REGISTRY[:patterns] ||= []
    TOKENIZER_REGISTRY[:patterns].unshift([model_pattern, tokenizer_id])
  else
    raise ArgumentError, "model_pattern must be a String or Regexp"
  end
end
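
Both exact-string and regex registrations are supported; Regexp entries are prepended so they take priority over the built-in patterns. The IDs below are placeholders:

# Exact repository match
Candle::LLM.register_tokenizer("myorg/my-model-GGUF", "myorg/my-model")

# Regex pattern, checked before the built-in patterns
Candle::LLM.register_tokenizer(/my-model/i, "myorg/my-model")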

Instance Method Details

#cached_eos_token ⇒ Object

Cached EOS token, memoized to avoid repeated lookups



# File 'lib/candle/llm.rb', line 6

def cached_eos_token
  @cached_eos_token ||= begin
    if respond_to?(:eos_token)
      eos_token rescue nil
    end
  end
end

#chat(messages, **options) ⇒ Object

Simple chat interface for instruction models



# File 'lib/candle/llm.rb', line 182

def chat(messages, **options)
  prompt = apply_chat_template(messages)
  generate(prompt, **options)
end
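
A usage sketch, assuming the conventional role/content message hashes accepted by apply_chat_template and that GenerationConfig lives under the Candle namespace; extra options are forwarded to #generate:

llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

messages = [
  { role: "system", content: "You are a concise assistant." },
  { role: "user",   content: "What is the capital of France?" }
]

answer = llm.chat(messages, config: Candle::GenerationConfig.balanced)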

#chat_stream(messages, **options, &block) ⇒ Object

Streaming chat interface



# File 'lib/candle/llm.rb', line 188

def chat_stream(messages, **options, &block)
  prompt = apply_chat_template(messages)
  generate_stream(prompt, **options, &block)
end
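
Same message format as #chat, with each generated chunk yielded to the block (a sketch, assuming llm and messages are set up as in the #chat example above):

llm.chat_stream(messages) do |token|
  print token
end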

#constraint_from_regex(pattern) ⇒ Object

Create a structured constraint from a regex pattern



# File 'lib/candle/llm.rb', line 41

def constraint_from_regex(pattern)
  pattern_str = pattern.is_a?(Regexp) ? pattern.source : pattern.to_s
  StructuredConstraint.from_regex(pattern_str, tokenizer)
end
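
A sketch: the returned constraint can be passed to a generation config via the constraint option, as #generate_regex does below. The phone-number pattern is illustrative:

# llm loaded via Candle::LLM.from_pretrained as above
constraint = llm.constraint_from_regex(/\d{3}-\d{3}-\d{4}/)
config = Candle::GenerationConfig.balanced(constraint: constraint)
llm.generate("My phone number is: ", config: config)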

#constraint_from_schema(schema) ⇒ Object

Create a structured constraint from a JSON schema



# File 'lib/candle/llm.rb', line 35

def constraint_from_schema(schema)
  schema_str = schema.is_a?(String) ? schema : JSON.generate(schema)
  StructuredConstraint.from_schema(schema_str, tokenizer)
end
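
A sketch; the schema may be a Hash (converted with JSON.generate) or a JSON string:

schema = {
  type: "object",
  properties: {
    name: { type: "string" },
    age:  { type: "integer" }
  },
  required: ["name", "age"]
}

constraint = llm.constraint_from_schema(schema)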

#generate(prompt, config: GenerationConfig.balanced, reset_cache: true) ⇒ Object



# File 'lib/candle/llm.rb', line 232

def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
  begin
    _generate(prompt, config)
  ensure
    clear_cache if reset_cache
  end
end
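
A minimal sketch relying on the defaults in the signature; reset_cache: false skips the clear_cache call after generation:

text = llm.generate("Write a haiku about Ruby:")

text = llm.generate(
  "Write a haiku about Ruby:",
  config: Candle::GenerationConfig.balanced,
  reset_cache: false
)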

#generate_regex(prompt, pattern:, stop_on_match: true, **options) ⇒ Object

Generate with regex constraint



# File 'lib/candle/llm.rb', line 47

def generate_regex(prompt, pattern:, stop_on_match: true, **options)
  constraint = constraint_from_regex(pattern)
  
  # Configure generation with early stopping by default
  config_opts = options.merge(
    constraint: constraint,
    stop_on_constraint_satisfaction: options.fetch(:stop_on_constraint_satisfaction, stop_on_match),
    stop_on_match: stop_on_match
  )
  config = options[:config] || GenerationConfig.balanced(**config_opts)
  
  generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
end
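
A sketch constraining the output to an IPv4-like shape (the pattern is illustrative):

llm.generate_regex(
  "The server's IP address is ",
  pattern: /\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/
)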

#generate_stream(prompt, config: GenerationConfig.balanced, reset_cache: true, &block) ⇒ Object



# File 'lib/candle/llm.rb', line 240

def generate_stream(prompt, config: GenerationConfig.balanced, reset_cache: true, &block)
  begin
    _generate_stream(prompt, config, &block)
  ensure
    clear_cache if reset_cache
  end
end
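
A sketch mirroring #generate, but yielding chunks to the block as they are produced:

llm.generate_stream("Tell me a short story:") do |token|
  print token
end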

#generate_structured(prompt, schema:, **options) ⇒ Object

Generate and parse structured output from a JSON schema



# File 'lib/candle/llm.rb', line 62

def generate_structured(prompt, schema:, **options)
  constraint = constraint_from_schema(schema)
  
  # Configure generation with early stopping by default
  config_opts = options.merge(
    constraint: constraint,
    stop_on_constraint_satisfaction: options.fetch(:stop_on_constraint_satisfaction, true)
  )
  config = options[:config] || GenerationConfig.balanced(**config_opts)
  
  result = generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
  
  # Try to parse as JSON
  begin
    # First, try to extract JSON if there's content after stop tokens
    json_content = extract_json_content(result)
    JSON.parse(json_content)
  rescue JSON::ParserError => e
    # Return the raw string if parsing fails
    warn "Warning: Generated output is not valid JSON: #{e.message}" if options[:warn_on_parse_error]
    result
  end
end
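
A sketch: on success the parsed Hash is returned, otherwise the raw generated String (with a warning when warn_on_parse_error is set):

schema = {
  type: "object",
  properties: {
    city:    { type: "string" },
    country: { type: "string" }
  },
  required: ["city", "country"]
}

result = llm.generate_structured(
  "Return the capital of France as JSON.",
  schema: schema,
  warn_on_parse_error: true
)

result["city"] if result.is_a?(Hash)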

#inspect ⇒ Object

Inspect method for debugging and exploration



# File 'lib/candle/llm.rb', line 194

def inspect
  opts = options rescue {}
  
  # Extract key information
  model_type = opts["model_type"] || "Unknown"
  device = opts["device"] || self.device.to_s rescue "unknown"
  
  # Build the inspect string
  parts = ["#<Candle::LLM"]
  
  # Add base model or model_id
  if opts["base_model"]
    parts << "model=#{opts["base_model"]}"
  elsif opts["model_id"]
    parts << "model=#{opts["model_id"]}"
  elsif respond_to?(:model_id)
    parts << "model=#{model_id}"
  end
  
  # Add GGUF file if present
  if opts["gguf_file"]
    parts << "gguf=#{opts["gguf_file"]}"
  end
  
  # Add device
  parts << "device=#{device}"
  
  # Add model type
  parts << "type=#{model_type}"
  
  # Add architecture for GGUF models
  if opts["architecture"]
    parts << "arch=#{opts["architecture"]}"
  end
  
  parts.join(" ") + ">"
end

#model_eos_tokens ⇒ Object

Get model-specific EOS tokens



# File 'lib/candle/llm.rb', line 15

def model_eos_tokens
  @model_eos_tokens ||= begin
    tokens = []
    if model_eos = cached_eos_token
      tokens << model_eos
      # For Gemma, also include end_of_turn for chat scenarios and </s>
      # Even though </s> is technically an HTML tag in Gemma's vocabulary,
      # it seems to use it as a generation boundary in practice
      if model_name.downcase.include?("gemma")
        tokens << "<end_of_turn>"
        tokens << "</s>"
      end
    else
      # Fallback to common tokens only if model doesn't provide one
      tokens = ["</s>", "<|endoftext|>", "<|im_end|>", "<end>"]
    end
    tokens.uniq
  end
end