A super-simplified, single-file PyTorch replica of the Continuous Thought Machine (CTM) from Sakana AI.
Modern AI often ignores the timing and synchronization in biological brains for efficiency. CTM bridges that gap, using neural dynamics as the core of computation. This simplified version distills the essence of the original repo into one Python file (ctm.py), training on MNIST to classify digits while demonstrating key CTM concepts like neuron synchronization and temporal thinking.
Inspired by the original work: Continuous Thought Machines and Sakana AI.
- Clone the repo:

  ```bash
  git clone https://github.com/xandykati98/SimpleCTM.git
  cd SimpleCTM
  ```

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```

  (Requires PyTorch and Torchvision – see requirements.txt for details.)

- Run the training script:

  ```bash
  python ctm.py
  ```

  This will:
  - Download MNIST
  - Train the SimplifiedCTM model for 10 epochs
  - Print progress and accuracy
  - Save the model to ctm_model.pth
Customize hyperparameters in main() – such as the number of neurons, the maximum number of ticks, and the number of epochs.
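As a rough sketch, the tunable knobs might look like the following. The names here are illustrative assumptions, not the actual identifiers in ctm.py – check main() for the real ones:

```python
# Hypothetical hyperparameter block; adjust names to match main() in ctm.py.
config = {
    "n_neurons": 64,   # number of neurons in the CTM core
    "max_ticks": 10,   # maximum internal "thinking" steps per input
    "epochs": 10,      # training epochs over MNIST
    "lr": 1e-3,        # optimizer learning rate
}
print(config)
```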
Example output:

```
Using device: cpu
=== Model Information ===
Total trainable parameters: 276,052
...
Epoch 1/10 completed - Average Loss: 2.3026, Accuracy: 11.24%
...
Model saved to ctm_model.pth
```
- Image Encoding: Flattens and reduces MNIST images to a 4D vector.
- Internal Ticks: Loops over time steps, updating pre/post activations.
- Synapse Model: Connects neurons with image input.
- Neuron-Level Models: Each neuron processes its activation history.
- Synchronization: Computes weighted dot products between neuron pairs with learnable decay.
- Attention: Synchronizations query the encoded image.
- Prediction: Reads combined features for classification.
- Early Stopping: Stops thinking when confident (>80%) or max ticks reached.
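The synchronization and early-stopping steps above can be sketched in a few lines of PyTorch. This is a minimal, self-contained illustration with made-up shapes and random data – names like `decay`, `post_history`, and `synchronization` are assumptions for this sketch, not the actual ctm.py API:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

n_neurons, n_ticks, n_classes = 8, 5, 10
conf_threshold = 0.8  # stop "thinking" once the top class probability exceeds 80%

# Fake post-activation history: one activation per neuron per internal tick.
post_history = torch.randn(n_ticks, n_neurons)

# Per-pair decay (stands in for a learnable parameter): recent ticks weigh more.
decay = torch.rand(n_neurons, n_neurons)
ages = torch.arange(n_ticks - 1, -1, -1).float()  # tick age, newest tick = 0

def synchronization(history):
    """Decay-weighted dot products between all neuron pairs over the tick history."""
    t = history.shape[0]
    # weight[k, i, j] = exp(-decay[i, j] * age_of_tick_k), shape (t, N, N)
    w = torch.exp(-decay.unsqueeze(0) * ages[-t:].view(t, 1, 1))
    # Outer product of activations at each tick, decayed, then summed over ticks.
    outer = history.unsqueeze(2) * history.unsqueeze(1)  # (t, N, N)
    return (w * outer).sum(dim=0)                        # (N, N)

classifier = torch.nn.Linear(n_neurons * n_neurons, n_classes)

for tick in range(1, n_ticks + 1):
    sync = synchronization(post_history[:tick])
    logits = classifier(sync.flatten())
    probs = F.softmax(logits, dim=-1)
    if probs.max() > conf_threshold:  # early stopping on confidence
        break
```

The key design point this sketch mirrors: the classifier never sees raw activations, only the synchronization matrix, so *when* neurons fire together – not just whether they fire – carries the signal.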
For full details, dive into ctm.py or the original Sakana AI CTM page.
Built with ❤️ by xandykati98
