Models API Reference¶
Python API reference for ChimeraLM model classes and components.
Overview¶
ChimeraLM's model architecture consists of:
- Factory class: ChimeraLM for loading pretrained models
- Base module: ClassificationLit for PyTorch Lightning integration
ChimeraLM Factory¶
Usage¶
Load Pretrained Model¶
from chimeralm.models.lm import ChimeraLM
# Load from Hugging Face Hub (default)
model = ChimeraLM.from_pretrained("yangliz5/chimeralm")
# Load from local checkpoint
model = ChimeraLM.from_pretrained("/path/to/checkpoint.ckpt")
Create New Model¶
from chimeralm.models.lm import ChimeraLM
# Create new ChimeraLM instance
model = ChimeraLM.new(
model_name="hyena",
num_classes=2,
optimizer_config={"lr": 1e-4, "weight_decay": 0.01}
)
Methods¶
from_pretrained()¶
Load a pretrained ChimeraLM model.
Parameters:
- model_name_or_path (str): Hugging Face model ID or local checkpoint path
- **kwargs: Additional arguments passed to ClassificationLit
Returns:
ClassificationLit: Loaded model ready for inference or fine-tuning
Example:
# From Hugging Face Hub
model = ChimeraLM.from_pretrained("yangliz5/chimeralm")
# From local file
model = ChimeraLM.from_pretrained("logs/train/runs/2025-10-25/checkpoints/best.ckpt")
# With custom config
model = ChimeraLM.from_pretrained(
"yangliz5/chimeralm",
map_location="cpu" # Load on CPU
)
new()¶
Create a new ChimeraLM model instance.
Parameters:
- model_name (str): Model architecture ("hyena", "cnn", or "mamba")
- num_classes (int): Number of output classes (default: 2)
- optimizer_config (dict): Optimizer configuration
- **kwargs: Additional model-specific arguments
Returns:
ClassificationLit: New model instance
Example:
# HyenaDNA model
model = ChimeraLM.new(
model_name="hyena",
num_classes=2,
optimizer_config={"lr": 1e-4}
)
# CNN model
model = ChimeraLM.new(
model_name="cnn",
num_classes=2,
optimizer_config={"lr": 1e-3}
)
ClassificationLit¶
Training Loop¶
The ClassificationLit base module implements the standard PyTorch Lightning hooks, so it works directly with a Lightning Trainer:
import lightning as L
# Create trainer
trainer = L.Trainer(max_epochs=50, accelerator="gpu", devices=1)
# Train model
trainer.fit(model, datamodule=data_module)
# Test model
trainer.test(model, datamodule=data_module)
Methods¶
forward()¶
Forward pass through the model.
Parameters:
- x (torch.Tensor): Input sequence tensor (shape: [batch_size, seq_length])
Returns:
torch.Tensor: Logits (shape: [batch_size, num_classes])
Example:
import torch
# Input: batch of 16 sequences, each 1024 tokens
x = torch.randint(0, 5, (16, 1024))
# Forward pass
logits = model(x) # Shape: [16, 2]
# Get predictions
predictions = torch.argmax(logits, dim=-1) # Shape: [16]
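If you need class probabilities rather than hard labels, apply a softmax over the logits (class 1 corresponds to the chimeric label, as in the inference example below):
probs = torch.softmax(logits, dim=-1)  # Shape: [16, 2]
chimeric_prob = probs[:, 1]            # Probability of the chimeric class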
training_step()¶
Training step for one batch.
Parameters:
- batch (dict): Batch dictionary with "input_ids" and "labels"
- batch_idx (int): Batch index
Returns:
torch.Tensor: Loss value
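A minimal sketch of what a training step of this shape typically looks like (the cross-entropy loss and metric names are taken from the Metrics section below; the exact ChimeraLM implementation may differ):
import torch.nn.functional as F

def training_step(self, batch, batch_idx):
    logits = self(batch["input_ids"])                # [batch_size, num_classes]
    loss = F.cross_entropy(logits, batch["labels"])  # logged as train/loss
    acc = (logits.argmax(dim=-1) == batch["labels"]).float().mean()
    self.log("train/loss", loss)
    self.log("train/acc", acc)                       # logged as train/acc
    return loss
validation_step() and test_step() below follow the same pattern, logging under the val/ and test/ prefixes.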
validation_step()¶
Validation step for one batch.
Parameters:
- batch (dict): Batch dictionary
- batch_idx (int): Batch index
test_step()¶
Test step for one batch.
Parameters:
- batch (dict): Batch dictionary
- batch_idx (int): Batch index
Metrics¶
The module logs the following metrics:
Training:
- train/loss: Cross-entropy loss
- train/acc: Accuracy
Validation:
- val/loss: Validation loss
- val/acc: Validation accuracy
- val/precision: Precision score
- val/recall: Recall score
- val/f1: F1 score
Test:
- test/loss: Test loss
- test/acc: Test accuracy
- test/precision: Precision score
- test/recall: Recall score
- test/f1: F1 score
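Logged values are also available programmatically after a run through Lightning's standard trainer.callback_metrics mapping (standard Lightning API; the keys are the metric names listed above):
trainer.fit(model, datamodule=data_module)

# callback_metrics holds the most recently logged values as tensors
print(trainer.callback_metrics["val/acc"])
print(trainer.callback_metrics["val/f1"])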
Callbacks¶
PredictionWriter¶
Custom Lightning callback for writing predictions to disk.
Usage:
from chimeralm.models.callbacks import PredictionWriter
import lightning as L
# Create callback
writer = PredictionWriter(
output_dir="predictions/",
write_interval="batch"
)
# Use with trainer
trainer = L.Trainer(callbacks=[writer])
trainer.predict(model, datamodule=data_module)
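write_interval follows Lightning's BasePredictionWriter convention ("batch", "epoch", or "batch_and_epoch"). For orientation, a minimal sketch of that callback pattern using standard Lightning APIs (this is not the ChimeraLM PredictionWriter source; the file naming and torch.save format are illustrative assumptions):
import os
import torch
from lightning.pytorch.callbacks import BasePredictionWriter

class SimplePredictionWriter(BasePredictionWriter):
    """Illustrative writer that saves each batch of predictions to disk."""

    def __init__(self, output_dir, write_interval="batch"):
        super().__init__(write_interval)
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def write_on_batch_end(self, trainer, pl_module, prediction,
                           batch_indices, batch, batch_idx, dataloader_idx):
        # Hypothetical on-disk layout: one tensor file per batch
        torch.save(prediction, os.path.join(self.output_dir, f"batch_{batch_idx}.pt"))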
Complete Example: Inference¶
from chimeralm.models.lm import ChimeraLM
import torch
# 1. Load pretrained model
model = ChimeraLM.from_pretrained("yangliz5/chimeralm")
model.eval()
# 2. Prepare input (tokenized sequences)
# Assume we have tokenized DNA sequences
input_ids = torch.randint(0, 5, (4, 1024)) # 4 sequences, 1024 tokens each
# 3. Run inference
with torch.no_grad():
logits = model(input_ids)
predictions = torch.argmax(logits, dim=-1)
# 4. Interpret predictions
for i, pred in enumerate(predictions):
label = "Biological" if pred == 0 else "Chimeric"
print(f"Sequence {i}: {label}")
Output (illustrative; actual labels vary because the input above is random):
Sequence 0: Biological
Sequence 1: Chimeric
Sequence 2: Biological
Sequence 3: Chimeric
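The example above uses random integer ids as stand-in input. A character-level encoding sketch consistent with the 5-symbol vocabulary implied by torch.randint(0, 5, ...) (the actual symbol-to-id mapping and padding scheme used by ChimeraLM are assumptions here):
import torch

VOCAB = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}  # hypothetical mapping

def encode(seq: str, length: int = 1024) -> torch.Tensor:
    ids = [VOCAB.get(base, VOCAB["N"]) for base in seq.upper()[:length]]
    ids += [VOCAB["N"]] * (length - len(ids))  # pad to a fixed length (assumed)
    return torch.tensor(ids)

input_ids = torch.stack([encode("ACGT" * 8), encode("TTAGGC")])  # shape: [2, 1024]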
See Also¶
- CLI Commands - Command-line interface