Rescorla-Wagner with Forgetting Tutorial
This tutorial extends the basic Rescorla-Wagner model by adding a forgetting mechanism for non-chosen actions. You’ll learn how to:
- Implement forgetting mechanisms in SPICE
- Work with multiple cognitive mechanisms simultaneously
- Understand how SPICE discovers interaction effects
Prerequisites
Before starting this tutorial, make sure you have:
- Completed the Basic Rescorla-Wagner Tutorial
- SPICE installed with all dependencies
- Understanding of basic reinforcement learning concepts
The Forgetting Mechanism
In real-world learning scenarios, humans tend to forget information about options they haven’t chosen recently. The forgetting mechanism models this by (a minimal sketch of the update follows the list):
- Gradually decreasing the value of non-chosen actions
- Maintaining separate learning rates for chosen and non-chosen actions
- Allowing for dynamic adjustment of forgetting rates
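As a rough illustration, one common formulation updates the chosen action with a reward prediction error while all other action values decay toward their initial value. This is a minimal sketch under that assumption, not necessarily the exact update implemented by AgentQ (the decay target, zero vs. the initial value, varies across implementations):

```python
import numpy as np

def update_q(q, choice, reward, alpha_reward=0.3, forget_rate=0.2, q_init=0.5):
    """One-trial update: RW learning for the chosen action, decay for the rest."""
    q = q.copy()
    not_chosen = np.arange(len(q)) != choice
    # Rescorla-Wagner update for the chosen action
    q[choice] += alpha_reward * (reward - q[choice])
    # Non-chosen values decay toward q_init at the forgetting rate
    q[not_chosen] += forget_rate * (q_init - q[not_chosen])
    return q

# Example: two actions, action 0 chosen and rewarded
q = update_q(np.array([0.5, 0.5]), choice=0, reward=1.0)
print(q)
```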
Tutorial Contents
- Setting up the environment with forgetting
- Creating a Q-learning agent with forgetting
- Training SPICE with multiple mechanisms
- Analyzing the discovered equations
- Implementing custom forgetting mechanisms
Interactive Version
This is the static web version of the tutorial. For an interactive version:
- Go to the SPICE repository
- Navigate to tutorials/2_rescorla-wagner_forgetting.ipynb
- Run the notebook in Jupyter
Full Tutorial
View or download the complete notebook
Step-by-Step Guide
1. Setup and Imports
```python
import numpy as np
import torch

from spice.resources.bandits import BanditsDrift, AgentQ, create_dataset

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
```
2. Create Environment and Agent
Now we’ll set up the environment and create an agent with forgetting:
```python
# Set up the environment
n_actions = 2
sigma = 0.2

environment = BanditsDrift(sigma=sigma, n_actions=n_actions)

# Set up the agent with forgetting
agent = AgentQ(
    n_actions=n_actions,
    alpha_reward=0.3,  # Learning rate for rewards
    forget_rate=0.2,   # Rate of forgetting for non-chosen actions
)

# Generate dataset
n_trials = 200
n_sessions = 256

dataset, _, _ = create_dataset(
    agent=agent,
    environment=environment,
    n_trials=n_trials,
    n_sessions=n_sessions,
)
```
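As a quick sanity check, you can inspect the generated dataset before training. This is an illustrative sketch; the exact attribute layout and shapes depend on how create_dataset packs sessions, trials, and features:

```python
# Quick sanity check on the generated data (illustrative; exact shapes depend
# on how create_dataset packs sessions, trials, and features)
print(type(dataset))
print(dataset.xs.shape)
print(dataset.ys.shape)
```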
3. Using the Precoded Forgetting RNN
SPICE provides a precoded RNN that includes forgetting mechanisms:
```python
from spice.precoded import ForgettingRNN, FORGETTING_CONFIG
from spice.estimator import SpiceEstimator

# Create and train the SPICE model
spice_estimator = SpiceEstimator(
    rnn_class=ForgettingRNN,
    spice_config=FORGETTING_CONFIG,
    hidden_size=8,
    learning_rate=5e-3,
    epochs=16,
    verbose=True,
)

spice_estimator.fit(dataset.xs, dataset.ys)
```
4. Analyzing the Results
Extract and examine the learned features:
```python
# Extract the discovered feature names and coefficients for each agent
features = spice_estimator.spice_agent.get_spice_features()

for agent_id, feat in features.items():
    print(f"\nAgent {agent_id}:")
    for model_name, (feat_names, coeffs) in feat.items():
        print(f"  {model_name}:")
        for name, coeff in zip(feat_names, coeffs):
            print(f"    {name}: {coeff}")
```
5. Custom Forgetting Mechanisms
You can also specify your own forgetting mechanism by defining a custom SpiceConfig:
```python
from spice.estimator import SpiceConfig

CUSTOM_FORGETTING_CONFIG = SpiceConfig(
    # Candidate control inputs available to each RNN module
    library_setup={
        'x_value_reward': ['c_reward'],
        'x_value_forget': ['c_action'],
    },
    # Restrict each module to trials where the control signal takes the given
    # value: reward updates for chosen actions (c_action == 1), forgetting for
    # non-chosen actions (c_action == 0)
    filter_setup={
        'x_value_reward': ['c_action', 1, True],
        'x_value_forget': ['c_action', 0, True],
    },
    control_parameters=['c_action', 'c_reward'],
    rnn_modules=['x_value_reward', 'x_value_forget'],
)
```
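To train with this custom configuration, pass it to SpiceEstimator in place of the precoded config. The sketch below assumes the module names in CUSTOM_FORGETTING_CONFIG ('x_value_reward', 'x_value_forget') match the submodules defined by ForgettingRNN; if you pair the config with your own RNN class, use that class instead:

```python
# Sketch: reuse the precoded ForgettingRNN with the custom library above.
# Assumes the module names in CUSTOM_FORGETTING_CONFIG match the submodules
# defined by ForgettingRNN; adjust if you define your own RNN class.
custom_estimator = SpiceEstimator(
    rnn_class=ForgettingRNN,
    spice_config=CUSTOM_FORGETTING_CONFIG,
    hidden_size=8,
    learning_rate=5e-3,
    epochs=16,
    verbose=True,
)
custom_estimator.fit(dataset.xs, dataset.ys)
```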
Understanding the Results
When analyzing the results, look for:
- Forgetting Rate: The coefficient that determines how quickly non-chosen values decay (see the inspection sketch after this list)
- Interaction Effects: How forgetting interacts with reward learning
- Value Updates: Different update rules for chosen vs non-chosen actions
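As an example, the snippet below pulls out only the forgetting module's terms. The module name 'x_value_forget' is an assumption based on the configuration above; check the full printout from Step 4 for the exact names used in your run:

```python
# Illustrative: inspect only the forgetting module's coefficients.
# 'x_value_forget' is assumed to be the module name; adjust to match
# the output of get_spice_features() in your run.
features = spice_estimator.spice_agent.get_spice_features()

for agent_id, modules in features.items():
    feat_names, coeffs = modules['x_value_forget']
    print(f"Agent {agent_id} forgetting terms:")
    for name, coeff in zip(feat_names, coeffs):
        print(f"  {name}: {coeff}")
```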
Common Patterns
The model typically discovers:
- Faster learning rates for chosen actions
- Gradual decay for non-chosen actions
- Balance between exploration and exploitation
Next Steps
After completing this tutorial, you can:
- Experiment with different forgetting rates (see the sketch after this list)
- Implement more complex forgetting mechanisms
- Move on to Working with Hardcoded Equations
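As a starting point for the first item above, here is a sketch that sweeps over several forgetting rates, regenerates data, and refits SPICE each time. The parameter values are arbitrary examples, not recommendations:

```python
# Illustrative sweep over forgetting rates; values are arbitrary examples.
from spice.resources.bandits import BanditsDrift, AgentQ, create_dataset
from spice.precoded import ForgettingRNN, FORGETTING_CONFIG
from spice.estimator import SpiceEstimator

environment = BanditsDrift(sigma=0.2, n_actions=2)

for forget_rate in (0.0, 0.1, 0.3, 0.5):
    agent = AgentQ(n_actions=2, alpha_reward=0.3, forget_rate=forget_rate)
    dataset, _, _ = create_dataset(
        agent=agent,
        environment=environment,
        n_trials=200,
        n_sessions=256,
    )
    estimator = SpiceEstimator(
        rnn_class=ForgettingRNN,
        spice_config=FORGETTING_CONFIG,
        hidden_size=8,
        learning_rate=5e-3,
        epochs=16,
    )
    estimator.fit(dataset.xs, dataset.ys)
    print(f"forget_rate={forget_rate}:")
    print(estimator.spice_agent.get_spice_features())
```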
Common Issues and Solutions
- Unstable Learning: Try reducing the learning rate or increasing the batch size
- Poor Forgetting: Adjust the forgetting rate or increase the amount of training data
- Convergence Issues: Increase the number of epochs or adjust the optimizer settings (see the sketch below)
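For instance, if training is unstable or slow to converge, a lower learning rate combined with more epochs often helps. The values below are illustrative adjustments of the settings used in Step 3, not tuned recommendations:

```python
# Illustrative adjustment: lower learning rate, more epochs (example values only)
from spice.precoded import ForgettingRNN, FORGETTING_CONFIG
from spice.estimator import SpiceEstimator

spice_estimator = SpiceEstimator(
    rnn_class=ForgettingRNN,
    spice_config=FORGETTING_CONFIG,
    hidden_size=8,
    learning_rate=1e-3,  # reduced from 5e-3
    epochs=64,           # increased from 16
    verbose=True,
)
spice_estimator.fit(dataset.xs, dataset.ys)
```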