Class: Candle::LLM

Inherits: Object

Defined in: lib/candle/llm.rb

Constant Summary

TOKENIZER_REGISTRY =

Tokenizer registry for automatic detection

{
  # Exact model matches
  "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" => "mistralai/Mistral-7B-Instruct-v0.2",
  "TheBloke/Mistral-7B-v0.1-GGUF" => "mistralai/Mistral-7B-v0.1",
  "TheBloke/Llama-2-7B-Chat-GGUF" => "meta-llama/Llama-2-7b-chat-hf",
  "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
  
  # Qwen official GGUF models
  "Qwen/Qwen3-8B-GGUF" => "Qwen/Qwen3-8B",
  "Qwen/Qwen3-4B-GGUF" => "Qwen/Qwen3-4B",
  "Qwen/Qwen3-14B-GGUF" => "Qwen/Qwen3-14B",
  "Qwen/Qwen3-32B-GGUF" => "Qwen/Qwen3-32B",
  "Qwen/Qwen3-72B-GGUF" => "Qwen/Qwen3-72B",
  
  # Phi GGUF models
  "TheBloke/phi-2-GGUF" => "microsoft/phi-2",
  "microsoft/phi-4-gguf" => "microsoft/phi-4",
  "bartowski/Phi-3.5-mini-instruct-GGUF" => "microsoft/Phi-3.5-mini-instruct",
  
  # Pattern-based fallbacks (evaluated in order)
  :patterns => [
    # Mistral models
    [/mistral.*?7b.*?instruct.*?v0\.2/i, "mistralai/Mistral-7B-Instruct-v0.2"],
    [/mistral.*?7b.*?instruct.*?v0\.1/i, "mistralai/Mistral-7B-Instruct-v0.1"],
    [/mistral.*?7b/i, "mistralai/Mistral-7B-v0.1"],
    
    # Llama models
    [/llama.*?3.*?8b/i, "meta-llama/Meta-Llama-3-8B"],
    [/llama.*?3.*?70b/i, "meta-llama/Meta-Llama-3-70B"],
    [/llama.*?2.*?7b.*?chat/i, "meta-llama/Llama-2-7b-chat-hf"],
    [/llama.*?2.*?13b.*?chat/i, "meta-llama/Llama-2-13b-chat-hf"],
    [/llama.*?2.*?70b.*?chat/i, "meta-llama/Llama-2-70b-chat-hf"],
    [/tinyllama/i, "TinyLlama/TinyLlama-1.1B-Chat-v1.0"],
    
    # Gemma models
    [/gemma.*?2.*?9b/i, "google/gemma-2-9b"],
    [/gemma.*?2.*?2b/i, "google/gemma-2-2b"],
    [/gemma.*?7b/i, "google/gemma-7b"],
    [/gemma.*?2b/i, "google/gemma-2b"],
    
    # Qwen models
    [/qwen.*?3.*?72b/i, "Qwen/Qwen3-72B"],
    [/qwen.*?3.*?32b/i, "Qwen/Qwen3-32B"],
    [/qwen.*?3.*?14b/i, "Qwen/Qwen3-14B"],
    [/qwen.*?3.*?8b/i, "Qwen/Qwen3-8B"],
    [/qwen.*?3.*?4b/i, "Qwen/Qwen3-4B"],
    [/qwen.*?3.*?1\.8b/i, "Qwen/Qwen3-1.8B"],
    [/qwen.*?3.*?0\.5b/i, "Qwen/Qwen3-0.5B"],
    [/qwen.*?2\.5/i, "Qwen/Qwen2.5-0.5B"],
    [/qwen.*?2/i, "Qwen/Qwen2-1.5B"],
    [/qwen/i, "Qwen/Qwen-1_8B"],
    
    # Phi models (order matters - more specific patterns first)
    [/phi.*?3\.5.*?mini/i, "microsoft/Phi-3.5-mini-instruct"],
    [/phi.*?3.*?mini.*?4k/i, "microsoft/Phi-3-mini-4k-instruct"],
    [/phi.*?3.*?medium/i, "microsoft/Phi-3-medium-4k-instruct"],
    [/phi.*?3.*?small/i, "microsoft/Phi-3-small-8k-instruct"],
    [/phi.*?3.*?mini/i, "microsoft/Phi-3-mini-4k-instruct"],
    [/phi.*?3/i, "microsoft/Phi-3-mini-4k-instruct"],
    [/phi-4/i, "microsoft/phi-4"],
    [/phi.*?2/i, "microsoft/phi-2"],
    [/phi.*?1\.5/i, "microsoft/phi-1_5"],
    [/phi/i, "microsoft/phi-2"]
  ]
}

Class Method Summary

  .from_pretrained(model_id, device: Candle::Device.cpu, gguf_file: nil, tokenizer: nil) ⇒ Object
  .guess_tokenizer(model_id) ⇒ Object
  .register_tokenizer(model_pattern, tokenizer_id) ⇒ Object

Instance Method Summary

  #chat(messages, **options) ⇒ Object
  #chat_stream(messages, **options, &block) ⇒ Object
  #constraint_from_regex(pattern) ⇒ Object
  #constraint_from_schema(schema) ⇒ Object
  #generate(prompt, config: GenerationConfig.balanced, reset_cache: true) ⇒ Object
  #generate_regex(prompt, pattern:, **options) ⇒ Object
  #generate_stream(prompt, config: GenerationConfig.balanced, reset_cache: true, &block) ⇒ Object
  #generate_structured(prompt, schema:, **options) ⇒ Object

Class Method Details

.from_pretrained(model_id, device: Candle::Device.cpu, gguf_file: nil, tokenizer: nil) ⇒ Object



# File 'lib/candle/llm.rb', line 196

def self.from_pretrained(model_id, device: Candle::Device.cpu, gguf_file: nil, tokenizer: nil)
  model_str = if gguf_file
    "#{model_id}@#{gguf_file}"
  else
    model_id
  end
  
  # Handle GGUF models that need tokenizer
  if model_str.downcase.include?("gguf") && tokenizer.nil?
    # Try to load without tokenizer first
    begin
      _from_pretrained(model_str, device)
    rescue => e
      if e.message.include?("No tokenizer found")
        # Auto-detect tokenizer
        detected_tokenizer = guess_tokenizer(model_id)
        warn "No tokenizer found in GGUF repo. Using tokenizer from: #{detected_tokenizer}"
        model_str = "#{model_str}@@#{detected_tokenizer}"
        _from_pretrained(model_str, device)
      else
        raise e
      end
    end
  elsif tokenizer
    # User specified tokenizer
    model_str = "#{model_str}@@#{tokenizer}"
    _from_pretrained(model_str, device)
  else
    # Non-GGUF model or GGUF with embedded tokenizer
    _from_pretrained(model_str, device)
  end
end
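
from_pretrained loads a model from the Hugging Face Hub, auto-detecting a tokenizer for GGUF repos that don't embed one. A minimal usage sketch; the GGUF filename below is illustrative and assumed to exist in the repo:

require "candle"

# Plain safetensors model on the default CPU device
llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Quantized GGUF model with an explicit tokenizer repo
quantized = Candle::LLM.from_pretrained(
  "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
  gguf_file: "mistral-7b-instruct-v0.2.Q4_K_M.gguf",
  tokenizer: "mistralai/Mistral-7B-Instruct-v0.2"
)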

.guess_tokenizer(model_id) ⇒ Object

Guess the tokenizer for a model



# File 'lib/candle/llm.rb', line 145

def self.guess_tokenizer(model_id)
  # Check exact matches first
  return TOKENIZER_REGISTRY[model_id] if TOKENIZER_REGISTRY[model_id]
  
  # Check patterns
  if patterns = TOKENIZER_REGISTRY[:patterns]
    patterns.each do |pattern, tokenizer|
      return tokenizer if model_id.match?(pattern)
    end
  end
  
  # Default: try removing common GGUF suffixes
  base_model = model_id.gsub(/-gguf|-q\d+_\w+$/i, "")
  base_model
end
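
Resolution order: exact registry match, then the regex patterns, then a fallback that strips common GGUF suffixes. For example, given the registry above:

Candle::LLM.guess_tokenizer("TheBloke/Mistral-7B-Instruct-v0.2-GGUF")
# => "mistralai/Mistral-7B-Instruct-v0.2"  (exact match)

Candle::LLM.guess_tokenizer("bartowski/Meta-Llama-3-8B-GGUF")
# => "meta-llama/Meta-Llama-3-8B"  (pattern match)

Candle::LLM.guess_tokenizer("someorg/custom-model-GGUF")
# => "someorg/custom-model"  (fallback: "-GGUF" stripped)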

.register_tokenizer(model_pattern, tokenizer_id) ⇒ Object

Allow users to register custom tokenizer mappings



# File 'lib/candle/llm.rb', line 133

def self.register_tokenizer(model_pattern, tokenizer_id)
  if model_pattern.is_a?(String)
    TOKENIZER_REGISTRY[model_pattern] = tokenizer_id
  elsif model_pattern.is_a?(Regexp)
    TOKENIZER_REGISTRY[:patterns] ||= []
    TOKENIZER_REGISTRY[:patterns].unshift([model_pattern, tokenizer_id])
  else
    raise ArgumentError, "model_pattern must be a String or Regexp"
  end
end
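
A String pattern registers an exact model ID; a Regexp is unshifted onto the pattern list, so it takes precedence over the built-in patterns. The repo names below are illustrative:

# Exact match for a specific GGUF repo
Candle::LLM.register_tokenizer("myorg/my-model-GGUF", "myorg/my-model")

# Regexp match covering any quantization of the same family
Candle::LLM.register_tokenizer(/my-model/i, "myorg/my-model")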

Instance Method Details

#chat(messages, **options) ⇒ Object

Simple chat interface for instruction models



# File 'lib/candle/llm.rb', line 162

def chat(messages, **options)
  prompt = apply_chat_template(messages)
  generate(prompt, **options)
end
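
A short sketch, assuming the role/content hash format that apply_chat_template expects (llm is an instance returned by from_pretrained):

messages = [
  { role: "system", content: "You are a concise assistant." },
  { role: "user",   content: "What is a GGUF file?" }
]

reply = llm.chat(messages)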

#chat_stream(messages, **options, &block) ⇒ Object

Streaming chat interface



# File 'lib/candle/llm.rb', line 168

def chat_stream(messages, **options, &block)
  prompt = apply_chat_template(messages)
  generate_stream(prompt, **options, &block)
end
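
Same message format as #chat, but each generated chunk is yielded to the block as it is produced:

llm.chat_stream(messages) do |token|
  print token
end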

#constraint_from_regex(pattern) ⇒ Object

Create a structured constraint from a regex pattern



# File 'lib/candle/llm.rb', line 12

def constraint_from_regex(pattern)
  pattern_str = pattern.is_a?(Regexp) ? pattern.source : pattern.to_s
  StructuredConstraint.from_regex(pattern_str, tokenizer)
end
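
The constraint is passed to generation through a GenerationConfig, which is what #generate_regex does internally. A sketch, assuming the config class is reachable as Candle::GenerationConfig:

constraint = llm.constraint_from_regex(/\d{3}-\d{3}-\d{4}/)
config = Candle::GenerationConfig.balanced(constraint: constraint)
llm.generate("Call me at: ", config: config)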

#constraint_from_schema(schema) ⇒ Object

Create a structured constraint from a JSON schema



# File 'lib/candle/llm.rb', line 6

def constraint_from_schema(schema)
  schema_str = schema.is_a?(String) ? schema : JSON.generate(schema)
  StructuredConstraint.from_schema(schema_str, tokenizer)
end
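
Schemas may be given as a Hash (serialized with JSON.generate) or as an already-serialized JSON string:

schema = {
  type: "object",
  properties: {
    name: { type: "string" },
    age:  { type: "integer" }
  },
  required: ["name", "age"]
}

constraint = llm.constraint_from_schema(schema)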

#generate(prompt, config: GenerationConfig.balanced, reset_cache: true) ⇒ Object



# File 'lib/candle/llm.rb', line 173

def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
  begin
    result = _generate(prompt, config)
    
    # If there's a constraint, clean up common EOS tokens that appear after the constrained content
    if config.constraint
      result = result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '').strip
    end
    
    result
  ensure
    clear_cache if reset_cache
  end
end
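
generate runs a single completion and clears the model's cache afterwards unless reset_cache: false is passed. A minimal sketch:

# Default balanced settings; cache is cleared afterwards
llm.generate("Write a haiku about Ruby.")

# Keep the cache across related generations
llm.generate("Continue the haiku: ", reset_cache: false)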

#generate_regex(prompt, pattern:, **options) ⇒ Object

Generate with regex constraint



# File 'lib/candle/llm.rb', line 18

def generate_regex(prompt, pattern:, **options)
  constraint = constraint_from_regex(pattern)
  
  # Add common EOS tokens as stop sequences for regex generation
  stop_sequences = options[:stop_sequences] || []
  stop_sequences += ["</s>", "<|endoftext|>", "<|im_end|>", "<end>", "\n"] unless options[:no_auto_stop]
  
  config_opts = options.merge(constraint: constraint, stop_sequences: stop_sequences)
  config = options[:config] || GenerationConfig.balanced(**config_opts)
  
  result = generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
  
  # Clean up any trailing EOS tokens
  result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '').strip
end
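
For example, constraining output to an IPv4-shaped address (pattern and prompt are illustrative); pass no_auto_stop: true to skip the automatic EOS/newline stop sequences:

ip = llm.generate_regex(
  "The server's IP address is ",
  pattern: /\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/
)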

#generate_stream(prompt, config: GenerationConfig.balanced, reset_cache: true, &block) ⇒ Object



# File 'lib/candle/llm.rb', line 188

def generate_stream(prompt, config: GenerationConfig.balanced, reset_cache: true, &block)
  begin
    _generate_stream(prompt, config, &block)
  ensure
    clear_cache if reset_cache
  end
end
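
As with #chat_stream, each chunk is yielded to the block as it is generated:

llm.generate_stream("Once upon a time") do |token|
  print token
end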

#generate_structured(prompt, schema:, **options) ⇒ Object

Generate and parse structured output from a JSON schema



# File 'lib/candle/llm.rb', line 35

def generate_structured(prompt, schema:, **options)
  constraint = constraint_from_schema(schema)
  config_opts = options.merge(constraint: constraint)
  config = options[:config] || GenerationConfig.balanced(**config_opts)
  
  result = generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
  
  # Clean up the result - remove common end-of-sequence tokens
  # that might appear after valid JSON
  cleaned_result = result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '')
  
  # Try to parse as JSON
  begin
    JSON.parse(cleaned_result)
  rescue JSON::ParserError => e
    # If cleaning didn't help, try to extract JSON from the result
    # Look for the first complete JSON object/array
    if match = cleaned_result.match(/(\{[^{}]*\}|\[[^\[\]]*\])/m)
      begin
        return JSON.parse(match[1])
      rescue JSON::ParserError
        # Fall through to warning
      end
    end
    
    # Return the raw string if parsing fails
    warn "Warning: Generated output is not valid JSON: #{e.message}" if options[:warn_on_parse_error]
    result
  end
end
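
A sketch of schema-constrained extraction (schema and prompt are illustrative). On success the parsed Hash is returned; if parsing fails, the raw string is returned, and a warning is emitted when warn_on_parse_error is set:

schema = {
  type: "object",
  properties: {
    city:   { type: "string" },
    temp_c: { type: "number" }
  },
  required: ["city", "temp_c"]
}

data = llm.generate_structured(
  "Extract the city and temperature: 'It is 21 degrees Celsius in Lisbon.'",
  schema: schema
)
# => a parsed Hash on success, or the raw generated string otherwise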