# Source code for tropt.optimizer.utils.scheduler

from __future__ import annotations
import math
from abc import ABC, abstractmethod

"""
Schedulers for the n_flip parameter used in optimizers to control the number
of token positions flipped during each optimization step.
"""

class NFlipScheduler(ABC):
    """
    Abstract interface for objects that decide how many token positions
    (n_flip) an optimizer flips at each optimization step.

    Subclasses must implement :meth:`get_n_flip`.
    """

    @abstractmethod
    def get_n_flip(self, step: int) -> int:
        """Return n_flip for ``step``, where steps are counted from 0."""
        ...
class ConstantScheduler(NFlipScheduler):
    """
    A scheduler that always returns the same n_flip value, regardless of
    the optimization step.
    """

    def __init__(self, n_flip: int):
        # Stored exactly as given; no validation or clamping is applied.
        self.n_flip = n_flip

    def get_n_flip(self, step: int) -> int:
        """Return the fixed n_flip; ``step`` is accepted but ignored."""
        value = self.n_flip
        return value
class LinearScheduler(NFlipScheduler):
    """
    A scheduler that linearly decreases n_flip from an initial value to 1
    over the course of optimization steps, starting from a specified step.

    Before ``decline_start_step`` the scheduler returns ``initial_n_flip``.
    From that step on, n_flip is ``initial_n_flip`` scaled by the fraction of
    the decline window (``decline_start_step`` .. ``total_steps``) that is
    still remaining. The result is rounded up and clamped to at least 1. The
    value is 1 by ``total_steps`` at the latest and stays 1 for any later step.
    """

    def __init__(
        self,
        initial_n_flip: int,
        total_steps: int,
        decline_start: int | float,
    ):
        """
        Args:
            initial_n_flip: n_flip returned before the decline begins.
            total_steps: Step at which the decline window ends.
            decline_start: Where the decline begins. A ``float`` is read as a
                fraction of ``total_steps`` (truncated to an int step). Any
                other value is converted with ``int()`` and used directly as
                a step index.
        """
        self.initial_n_flip = initial_n_flip
        self.total_steps = total_steps
        if isinstance(decline_start, float):
            # Fractional start, e.g. 0.5 -> begin declining halfway through.
            self.decline_start_step = int(total_steps * decline_start)
        else:
            self.decline_start_step = int(decline_start)

    def get_n_flip(self, step: int) -> int:
        """Return the n_flip value for ``step`` (0-indexed)."""
        if step < self.decline_start_step:
            return self.initial_n_flip
        decline_duration = self.total_steps - self.decline_start_step
        if decline_duration <= 0:
            # Empty decline window: drop straight to the minimum.
            return 1
        steps_remaining = self.total_steps - step
        # Exact integer ceiling of initial_n_flip * steps_remaining / decline_duration.
        # A float ratio can overshoot an exact integer (e.g. 100 * (7 / 100)
        # == 7.000000000000001), which math.ceil would turn into one extra flip.
        # ceil(x) == -floor(-x) holds for any sign, and decline_duration > 0 here.
        scaled = self.initial_n_flip * steps_remaining
        return max(1, -(-scaled // decline_duration))