
Neural Linear Bandit

NeuralLinearBandit(network, buffer, n_embedding_size, min_samples_required_for_training=1024, selector=None, train_batch_size=32, lazy_uncertainty_update=False, lambda_=1.0, eps=0.01, weight_decay=0.0, learning_rate=0.001, learning_rate_decay=1.0, learning_rate_scheduler_step_size=1, early_stop_threshold=0.001, initial_train_steps=1024, contextualization_after_network=False, n_arms=None, warm_start=True)

Bases: LinearTSBandit[ActionInputType]

Lightning Module implementing a Neural Linear bandit.

A Neural Linear bandit model consists of a neural network that produces embeddings of the input data and a linear head that is trained on those embeddings. Since updating the neural network that encodes the inputs into embeddings is computationally expensive, the network is only retrained once more than `min_samples_required_for_training` samples have been collected; otherwise, only the linear head is updated.
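A minimal usage sketch is shown below. The import path is inferred from the source location quoted later on this page and may differ in your installation; the layer sizes are placeholders.

```python
import torch

# Import path inferred from "src/calvera/bandits/neural_linear_bandit.py";
# verify it against your installed package layout.
from calvera.bandits.neural_linear_bandit import NeuralLinearBandit

# A small encoder mapping 16-dimensional contexts to 8-dimensional embeddings.
encoder = torch.nn.Sequential(
    torch.nn.Linear(16, 8),
    torch.nn.ReLU(),
)

bandit = NeuralLinearBandit(
    network=encoder,
    buffer=None,         # or any AbstractBanditDataBuffer implementation
    n_embedding_size=8,  # must match the encoder's output size
    min_samples_required_for_training=1024,  # retrain the encoder after 1024 new samples
)
```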

Type Parameters:

`ActionInputType`: The type of the input data to the neural network. Can be a single tensor or a tuple of tensors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `network` | `Module` | The neural network used to encode the input data into an embedding. | *required* |
| `buffer` | `AbstractBanditDataBuffer[ActionInputType, Any] \| None` | The buffer used to store the data for continuously updating the neural network and the embeddings for the linear head. | *required* |
| `n_embedding_size` | `int` | The size of the embedding produced by the neural network. Must be greater than 0. If `contextualization_after_network` is `True`, `n_embedding_size` is the size of the network output multiplied by `n_arms` (using disjoint contextualization). | *required* |
| `selector` | `AbstractSelector \| None` | The selector used to choose the best action. Defaults to `ArgMaxSelector` if `None`. | `None` |
| `train_batch_size` | `int` | The batch size for the neural network update. Must be greater than 0. | `32` |
| `min_samples_required_for_training` | `int \| None` | The number of samples that must be collected before the neural network is retrained. `None` means the neural network is never updated. If not `None`, it must be greater than 0. | `1024` |
| `lazy_uncertainty_update` | `bool` | If `True`, the precision matrix is not updated during the forward pass but during the update step. | `False` |
| `lambda_` | `float` | The regularization parameter for the linear head. Must be greater than 0. | `1.0` |
| `eps` | `float` | Small value added to the diagonal to ensure invertibility of the precision matrix. Must be greater than 0. | `0.01` |
| `learning_rate` | `float` | The learning rate for the optimizer of the neural network. Passed to `lr` of `torch.optim.Adam`. Must be greater than 0. | `0.001` |
| `weight_decay` | `float` | The regularization parameter for the neural network. Passed to `weight_decay` of `torch.optim.Adam`. Must be greater than or equal to 0. | `0.0` |
| `learning_rate_decay` | `float` | Multiplicative factor for learning rate decay. Passed to `gamma` of `torch.optim.lr_scheduler.StepLR`. A value of 1.0 means no decay. Must be greater than 0. | `1.0` |
| `learning_rate_scheduler_step_size` | `int` | The step size for the learning rate decay. Passed to `step_size` of `torch.optim.lr_scheduler.StepLR`. Must be greater than 0. The scheduler is stepped every time the neural network is updated. | `1` |
| `early_stop_threshold` | `float \| None` | Loss threshold for early stopping. `None` to disable. Must be greater than or equal to 0. | `0.001` |
| `initial_train_steps` | `int` | Number of initial training steps (in samples). Must be greater than or equal to 0. | `1024` |
| `contextualization_after_network` | `bool` | If `True`, contextualization is applied after the network. Useful when you want to retrieve a single embedding from the model and then use it for multiple actions. | `False` |
| `n_arms` | `int \| None` | The number of arms to contextualize after the network. Only needed if `contextualization_after_network` is `True`; otherwise the number of arms is determined by the input data. Must be greater than or equal to 0. | `None` |
| `warm_start` | `bool` | If `False`, the network parameters are reset via `network.reset_parameters()` every time the network is retrained, so it is retrained from scratch. If `True`, training continues from the current state. | `True` |
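The sizing rule for `contextualization_after_network` can be illustrated with the sketch below. The import paths and the `ListDataBuffer` constructor arguments are assumptions; the `ListDataBuffer` requirement itself comes from an assertion in the constructor source quoted below.

```python
import torch

from calvera.bandits.neural_linear_bandit import NeuralLinearBandit  # path assumed
from calvera.utils.data_storage import ListDataBuffer  # hypothetical path; check your install

n_arms = 4
network_output_size = 8

encoder = torch.nn.Sequential(torch.nn.Linear(16, network_output_size), torch.nn.ReLU())

bandit = NeuralLinearBandit(
    network=encoder,
    # Only the ListDataBuffer currently supports contextualization_after_network;
    # its constructor arguments are not shown on this page.
    buffer=ListDataBuffer(),
    n_embedding_size=n_arms * network_output_size,  # network output size * n_arms = 32
    contextualization_after_network=True,
    n_arms=n_arms,
)
```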
Source code in src/calvera/bandits/neural_linear_bandit.py
def __init__(
    self,
    network: torch.nn.Module,
    buffer: AbstractBanditDataBuffer[ActionInputType, Any] | None,
    n_embedding_size: int,
    min_samples_required_for_training: int | None = 1024,
    selector: AbstractSelector | None = None,
    train_batch_size: int = 32,
    lazy_uncertainty_update: bool = False,
    lambda_: float = 1.0,
    eps: float = 1e-2,
    weight_decay: float = 0.0,
    learning_rate: float = 1e-3,
    learning_rate_decay: float = 1.0,
    learning_rate_scheduler_step_size: int = 1,
    early_stop_threshold: float | None = 1e-3,
    initial_train_steps: int = 1024,
    contextualization_after_network: bool = False,
    n_arms: int | None = None,
    warm_start: bool = True,
) -> None:
    """Initializes the NeuralLinearBanditModule.

    Args:
        network: The neural network to be used to encode the input data into an embedding.
        buffer: The buffer used for storing the data for continuously updating the neural network and
            storing the embeddings for the linear head.
        n_embedding_size: The size of the embedding produced by the neural network. Must be greater than 0.
            If `contextualization_after_network` is `True`, `n_embedding_size` is the size of the network
            output multiplied by `n_arms` (using disjoint contextualization).
        selector: The selector used to choose the best action. Default is `ArgMaxSelector` (if None).
        train_batch_size: The batch size for the neural network update. Must be greater than 0.
        min_samples_required_for_training: The number of samples that must be collected before the neural
            network is retrained. None means the neural network is never updated. If not None, it must be
            greater than 0. Default is 1024.
        lazy_uncertainty_update: If `True` the precision matrix will not be updated during forward, but during the
            update step.
        lambda_: The regularization parameter for the linear head. Must be greater than 0.
        eps: Small value to ensure invertibility of the precision matrix. Added to the diagonal.
            Must be greater than 0.
        learning_rate: The learning rate for the optimizer of the neural network.
            Passed to `lr` of `torch.optim.Adam`.
            Must be greater than 0.
        weight_decay: The regularization parameter for the neural network.
            Passed to `weight_decay` of `torch.optim.Adam`.
            Must be greater than or equal to 0.
        learning_rate_decay: Multiplicative factor for learning rate decay.
            Passed to `gamma` of `torch.optim.lr_scheduler.StepLR`.
            Default is 1.0 (i.e. no decay). Must be greater than 0.
        learning_rate_scheduler_step_size: The step size for the learning rate decay.
            Passed to `step_size` of `torch.optim.lr_scheduler.StepLR`.
            Must be greater than 0.
            The learning rate scheduler is called every time the neural network is updated.
        early_stop_threshold: Loss threshold for early stopping. None to disable.
            Must be greater than or equal to 0.
        initial_train_steps: Number of initial training steps (in samples).
            Defaults to 1024. Must be greater than or equal to 0.
        contextualization_after_network: If `True`, the contextualization is applied after the network. Useful
            for situations where you want to retrieve a single embedding from the model and then use it for
            multiple actions.
        n_arms: The number of arms to contextualize after the network. Only needed if
            `contextualization_after_network` is `True`; otherwise the number of arms is determined by the
            input data. Must be greater than or equal to 0.
        warm_start: If `False` the parameters of the network are reset in order to be retrained from scratch
            using `network.reset_parameters()` every time the network is retrained. If `True` the network is
            trained from the current state.
    """
    assert n_embedding_size > 0, "The embedding size must be greater than 0."
    assert min_samples_required_for_training is None or min_samples_required_for_training > 0, (
        "The min_samples_required_for_training must be greater than 0. "
        "Set it to None to never update the neural network."
    )
    assert lambda_ > 0, "The lambda_ must be greater than 0."
    assert eps > 0, "The eps must be greater than 0."
    assert weight_decay >= 0, "The weight_decay must be greater than or equal to 0."
    assert learning_rate > 0, "The learning rate must be greater than 0."
    assert learning_rate_decay >= 0, "The learning rate decay must be greater than or equal to 0."
    assert learning_rate_scheduler_step_size > 0, "The learning rate decay step size must be greater than 0."
    assert (
        early_stop_threshold is None or early_stop_threshold >= 0
    ), "Early stop threshold must be greater than or equal to 0."
    assert initial_train_steps >= 0, "Initial training steps must be greater than or equal to 0."

    assert (
        not contextualization_after_network or n_arms is not None
    ), "`n_arms` needs to be provided when performing `contextualization_after_network`."
    n_linear_features = (
        n_embedding_size if not contextualization_after_network else cast(int, n_arms) * n_embedding_size
    )

    super().__init__(
        n_features=n_linear_features,
        selector=selector,
        buffer=buffer,
        train_batch_size=train_batch_size,
        eps=eps,
        lambda_=lambda_,
        lazy_uncertainty_update=lazy_uncertainty_update,
        clear_buffer_after_train=False,
    )

    self.save_hyperparameters(
        {
            "n_embedding_size": n_embedding_size,
            "min_samples_required_for_training": min_samples_required_for_training,
            "train_batch_size": train_batch_size,
            "weight_decay": weight_decay,
            "learning_rate": learning_rate,
            "learning_rate_decay": learning_rate_decay,
            "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size,
            "early_stop_threshold": early_stop_threshold,
            "initial_train_steps": initial_train_steps,
            "contextualization_after_network": contextualization_after_network,
            "n_arms": n_arms,
            "warm_start": warm_start,
        }
    )

    if contextualization_after_network:
        n_embedding_size *= cast(int, n_arms)

    self.network = network.to(self.device)

    self.register_buffer(
        "contextualized_actions", torch.empty(0, device=self.device)
    )  # shape: (buffer_size, n_parts, n_network_input_size)
    self.register_buffer(
        "embedded_actions", torch.empty(0, device=self.device)
    )  # shape: (buffer_size, n_embedding_size)
    self.register_buffer("rewards", torch.empty(0, device=self.device))  # shape: (buffer_size,)

    # Disable Lightning's automatic optimization. Has to be kept in sync with should_train_network.
    self.automatic_optimization = False

    self.contextualizer: MultiClassContextualizer | None = None
    if self.hparams["contextualization_after_network"]:
        assert n_arms is not None, "The number of arms must be provided if contextualization_after_network is True."

        assert n_embedding_size % n_arms == 0, (
            "If `contextualization_after_network` is True, `n_embedding_size` is the size of the output of the "
            "network * n_arms (Using disjoint contextualization). "
            "Therefore, `n_embedding_size` must be divisible by `n_arms`."
        )

        assert isinstance(buffer, ListDataBuffer), (
            "Currently only the `ListDataBuffer` supports `contextualization_after_network`."
        )

        self.contextualizer = MultiClassContextualizer(n_arms=n_arms)

    # We use this network to train the encoder model. We mock a linear head with the final layer of the encoder,
    # hence the single output dimension.
    self._helper_network = HelperNetwork(
        self.network,
        n_embedding_size,
        self.contextualizer,
    ).to(self.device)

    self._helper_network_init = self._helper_network.state_dict().copy() if not self.hparams["warm_start"] else None
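Note that the snapshot in the last line is only taken when `warm_start=False`. How it is consumed is not shown on this page; presumably the retraining step restores it (or calls `network.reset_parameters()`, as the docstring states) before training from scratch. A sketch of the restore variant, as an assumption rather than the library's actual code:

```python
# Hypothetical restore step before retraining when warm_start=False:
if not self.hparams["warm_start"] and self._helper_network_init is not None:
    # Reset the helper network (encoder + mock head) to its initial weights
    # so the retraining starts from scratch.
    self._helper_network.load_state_dict(self._helper_network_init)
```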

HelperNetwork(network, output_size, contextualizer=None)

Bases: Module

A helper network that is used to train the neural network of the NeuralLinearBandit.

It adds a linear head to the neural network that mocks the linear head of the NeuralLinearBandit, hence the single output dimension of the linear layer. This allows training an embedding that is useful for the linear head of the NeuralLinearBandit.
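The forward pass is not quoted on this page; given the attributes set in `__init__` below, it presumably looks like the following sketch (an assumption, not the library's code):

```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Encode the input, optionally expand it via disjoint contextualization,
    # then score it with the mock linear head (single output dimension).
    z = self.network(x)
    if self.contextualizer is not None:
        z = self.contextualizer(z)
    return self.linear_head(z)
```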

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `network` | `Module` | The neural network used to encode the input data into an embedding. | *required* |
| `output_size` | `int` | The size of the output of the neural network. | *required* |
| `contextualizer` | `MultiClassContextualizer \| None` | If provided, disjoint model contextualization is applied to the embeddings. | `None` |
Source code in src/calvera/bandits/neural_linear_bandit.py
def __init__(
    self, network: torch.nn.Module, output_size: int, contextualizer: MultiClassContextualizer | None = None
) -> None:
    """Initialize the HelperNetwork.

    Args:
        network: The neural network to be used to encode the input data into an embedding.
        output_size: The size of the output of the neural network.
        contextualizer: If provided, disjoint model contextualization will be applied to the embeddings.
    """
    super().__init__()
    self.network = network
    self.linear_head = torch.nn.Linear(
        output_size, 1
    )  # mock linear head so we can learn an embedding that is useful for the linear head
    self.contextualizer = contextualizer