# Source code for qlinks.basis.solvers.brute_force
from __future__ import annotations
from dataclasses import dataclass
from itertools import product
from typing import Sequence
import numpy as np
from qlinks.basis.basis import Basis
from qlinks.constraints import Constraint, SectorCondition, all_satisfied
from qlinks.variables import VariableLayout
[docs]
@dataclass(frozen=True, slots=True)
class BruteForceBasisSolver:
"""
Exhaustive product-space basis solver.
This is simple and useful for tests, but it scales as
prod_i dim(local_space_i)
so it should only be used for small systems.
"""
sort: bool = False
[docs]
def solve(
self,
layout: VariableLayout,
constraints: Sequence[Constraint] = (),
sectors: Sequence[SectorCondition] = (),
*,
max_states: int | None = None,
) -> Basis:
if max_states is not None and max_states < 0:
raise ValueError("max_states must be non-negative or None.")
if max_states == 0:
return Basis.empty(layout)
domains = [layout.local_space(i).values.tolist() for i in range(layout.n_variables)]
states: list[np.ndarray] = []
for values in product(*domains):
config = np.asarray(values, dtype=np.int64)
if all_satisfied(config, constraints=constraints, sectors=sectors):
states.append(config.copy())
if max_states is not None and len(states) >= max_states:
break
if len(states) == 0:
return Basis.empty(layout)
return Basis.from_states(layout, np.vstack(states), sort=self.sort)