import numpy as np
from collections import Counter
from typing import Optional, List, Sequence
from dataclasses import dataclass
import matplotlib.pyplot as plt
###########################################################################################
## MCMC Chain and States ##
###########################################################################################
@dataclass
class MCMCState:
    """
    Represents a single step in an MCMC trajectory.

    Stores the proposed configuration, whether it was accepted by the
    Metropolis rule, its energy, and the position of the step in the chain.
    """

    # NOTE(review): no dataclass fields are declared here, yet MCMCChain reads
    # ``bitstring``, ``accepted``, ``energy`` and ``position`` from every state.
    # Confirm the intended fields (names, order, defaults) before relying on
    # the generated constructor.
    # TODO: update this
    # class MCMCState:
    # def __init__(self, bits):
    # # Store the "source of truth" as a compact array
    # self._bits = np.asanyarray(bits, dtype=np.int8)
    # @property
    # def bits(self) -> np.ndarray:
    # return self._bits
    # @property
    # def spinstring(self) -> str:
    # # Map 0 -> +1, 1 -> -1 (or vice versa) and join
    # spins = np.where(self._bits == 0, '+', '-')
    # return "".join(spins)
    # @property
    # def bitstring(self) -> str:
    # return "".join(self._bits.astype(str))
    # @property
    # def spins(self) -> np.ndarray:
    # return 1 - 2 * self._bits
@dataclass(init=True, eq=False)
class MCMCChain:
    """
    Container for the sequence of states produced during an MCMC run.

    This class records all proposed states, tracks accepted configurations,
    and provides helper methods for extracting trajectories, energies, and
    empirical distributions from the Markov chain.

    The first recorded state is treated as the starting point of the chain:
    it seeds the Markov-chain trajectory even if it is not flagged as
    accepted. Every later rejected proposal repeats the previous entry.

    ``eq=False`` keeps identity-based equality and hashing. Because no
    dataclass fields are declared, a generated ``__eq__`` would report every
    pair of chains as equal and make chains unhashable. The hand-written
    ``__init__`` is kept because dataclass never overwrites a method defined
    in the class body.
    """

    def __init__(self, states: Optional[List["MCMCState"]] = None, name: Optional[str] = "MCMC"):
        """
        Parameters:
            states: Previously recorded states, or None for an empty chain.
                A given list is used as-is (not copied), so later
                ``add_state`` calls append to it.
            name: Label for the chain, stored as ``self.name``.
        """
        self.name = name
        self._states: List["MCMCState"] = [] if states is None else states
        self._states_accepted: List["MCMCState"] = [s for s in self._states if s.accepted]
        # The current state is the most recently accepted one (None if none yet).
        self._current_state = self._states_accepted[-1] if self._states_accepted else None
        self.markov_chain: List[str] = self.get_list_markov_chain()

    def _carry_forward(self, attr: str) -> list:
        """Return ``attr`` of the chain's current state at every recorded step.

        Accepted states contribute their own value and rejected proposals
        repeat the previous value. The first state always contributes its own
        value because it is the chain's starting point.
        """
        values = []
        for step, state in enumerate(self._states):
            if step == 0 or state.accepted:
                values.append(getattr(state, attr))
            else:
                values.append(values[-1])
        return values

    def add_state(self, state: "MCMCState"):
        """Record a proposed state and extend the Markov-chain trajectory."""
        if state.accepted:
            self._current_state = state
            self._states_accepted.append(state)
            self.markov_chain.append(state.bitstring)
        elif self.markov_chain:
            # Rejected proposal: the chain stays where it was.
            self.markov_chain.append(self.markov_chain[-1])
        else:
            # First state of an empty chain: it is the starting point even if
            # it is not flagged accepted (matches get_list_markov_chain).
            self.markov_chain.append(state.bitstring)
        self._states.append(state)

    @property
    def states(self):
        """All recorded states (accepted and rejected), in insertion order."""
        return self._states

    def get_accepted_energies(self):
        """Return ``(energies, positions)`` arrays of the accepted states.

        Also stores them as ``self.accepted_energies`` and
        ``self.accepted_positions``.
        """
        self.accepted_energies = np.array([s.energy for s in self._states_accepted])
        self.accepted_positions = np.array([s.position for s in self._states_accepted])
        return self.accepted_energies, self.accepted_positions

    def get_current_energy_array(self):
        """Energy of the chain's current state at every step (useful for plotting)."""
        return np.array(self._carry_forward("energy"))

    def get_pos_array(self):
        """Position of every recorded state, accepted or not (useful for plotting)."""
        return np.array([s.position for s in self._states])

    def get_current_state_array(self):
        """Bitstring of the chain's current state at every step (useful for plotting)."""
        return np.array(self._carry_forward("bitstring"))

    def get_all_energies(self):
        """Energies of all recorded states; also stored as ``self.energies``."""
        self.energies = [s.energy for s in self._states]
        return self.energies

    @property
    def current_state(self):
        """The most recently accepted state, or None if none was accepted."""
        return self._current_state

    @property
    def accepted_states(self) -> List[str]:
        """Bitstrings of the accepted states, in order."""
        return [state.bitstring for state in self._states_accepted]

    def get_list_markov_chain(self) -> List[str]:
        """Rebuild ``self.markov_chain`` from the recorded states and return it.

        An empty chain yields an empty list.
        """
        self.markov_chain = self._carry_forward("bitstring")
        return self.markov_chain

    def get_accepted_dict(self, normalize: bool = False, until_index: int = -1):
        """Count visits per bitstring along the Markov chain.

        Parameters:
            normalize: If True, return visit frequencies instead of counts.
            until_index: Only use ``markov_chain[:until_index]``. The default
                -1 is a sentinel meaning "the whole chain"; other negative
                values slice as usual.

        Returns:
            Counter: bitstring -> count (or frequency when normalized).
        """
        visited = self.markov_chain if until_index == -1 else self.markov_chain[:until_index]
        counts = Counter(visited)
        if not normalize:
            return counts
        total = len(visited)
        return Counter({s: c / total for s, c in counts.items()})
def plot_chains(chains: list["MCMCChain"], color: str, label: str, plot_individual_chains: bool = True):
    """
    Plot per-step energy trajectories of several chains and their average.

    Draws on the current Matplotlib axes via ``plt.plot``.

    Parameters:
        chains: Non-empty list of chains. The average is taken element-wise,
            so the chains are expected to have the same number of steps; the
            x-axis positions of the average curve come from the last chain.
        color: Matplotlib colour used for individual and average curves.
        label: Suffix of the legend label; the average is "Average {label}".
        plot_individual_chains: If True, also draw each chain faintly (alpha=0.1).

    Raises:
        ValueError: If ``chains`` is empty.
    """
    if not chains:
        raise ValueError("plot_chains requires at least one chain")
    energy_arrays = []
    pos = None
    for chain in chains:
        # Compute each chain's trajectory once and reuse it for the average.
        energies = chain.get_current_energy_array()
        pos = chain.get_pos_array()
        energy_arrays.append(energies)
        if plot_individual_chains:
            plt.plot(pos, energies, color=color, alpha=0.1)
    avg_energy = sum(energy_arrays) / len(energy_arrays)
    plt.plot(pos, avg_energy, color=color, label=f"Average {label}")
def get_random_state(num_spins: int) -> str:
    """
    Draw a random bitstring with one independent, uniformly chosen bit per spin.

    Uses NumPy's global generator (``np.random.randint``), so results are
    reproducible via ``np.random.seed``.

    Parameters:
        num_spins (int): The number of spins in the system.
    Returns:
        str: A string of '0'/'1' characters of length ``num_spins``.
    """
    bits = np.random.randint(0, 2, size=num_spins, dtype=np.int8)
    return "".join(map(str, bits))
def get_all_possible_states(num_spins: int) -> list:
    """
    Enumerate every binary string of length ``num_spins``.

    Parameters:
        num_spins: Length n of each bitstring.
    Returns:
        list: All 2**n zero-padded binary strings in increasing numeric
        order ('0...0' first, '1...1' last).
    """
    width_spec = f"0{num_spins}b"
    return [format(k, width_spec) for k in range(2 ** num_spins)]
def magnetization_of_state(bitstring: str) -> float:
    """
    Return the mean magnetization of a bitstring.

    Each '1' counts as +1 and every other character as -1. The sum is
    divided by the string length, so the result lies in [-1, 1].

    Parameters:
        bitstring: e.g. '010'. Any ``str`` subclass is accepted, including
            the ``numpy.str_`` elements of ``MCMCChain.get_current_state_array()``.

    Returns:
        float: (number of '1's - number of other characters) / len(bitstring);
        for '010' this is -1/3.

    Raises:
        TypeError: If ``bitstring`` is not a string.
        ValueError: If ``bitstring`` is empty.
    """
    if not isinstance(bitstring, str):
        raise TypeError("bitstring must be a string in magnetization_of_state")
    n_spins = len(bitstring)
    if n_spins == 0:
        raise ValueError("bitstring must be non-empty in magnetization_of_state")
    num_times_one = bitstring.count("1")
    num_times_zero = n_spins - num_times_one
    return (num_times_one - num_times_zero) / n_spins
def dict_magnetization_of_all_states(list_all_possible_states: list) -> dict:
    """
    Map each given state to its magnetization.

    Parameters:
        list_all_possible_states: Bitstrings to evaluate, e.g. the output of
            ``get_all_possible_states``.
    Returns:
        dict: ``{state: magnetization_of_state(state)}`` with keys in input order.
    """
    return {state: magnetization_of_state(state) for state in list_all_possible_states}
def hamming_dist(str1, str2):
    """
    Return the Hamming distance between two equal-length sequences.

    Parameters:
        str1, str2: Sequences (typically bitstrings) of the same length.

    Returns:
        int: Number of positions at which the elements differ.

    Raises:
        ValueError: If the lengths differ. Previously a shorter ``str2``
            raised IndexError while a longer one was silently truncated.
    """
    if len(str1) != len(str2):
        raise ValueError(
            f"hamming_dist requires equal lengths, got {len(str1)} and {len(str2)}"
        )
    return sum(1 for a, b in zip(str1, str2) if a != b)
# ###########################################################################################
# ======================= Coarse Graining helper functions: ===============================
# ###########################################################################################
@dataclass
class CoarseGrainingConfig:
    """
    Coarse-graining settings: groups of spin indices plus one probability
    weight per group.

    Field order (subgroups, subgroup_probs) defines the positional
    constructor signature.
    """

    # Each inner list holds the spin indices forming one subgroup.
    subgroups: list[list[int]]
    # One weight per subgroup; validate_subgroups requires them to be
    # non-negative and to sum to 1.
    subgroup_probs: list[float]
def validate_subgroups(subgroups, subgroup_probs, n_spins):
    """
    Validate coarse-graining subgroups and their selection probabilities.

    Requirements:
    - subgroups is a non-empty list of non-empty sequences
    - each element is an integer index in [0, n_spins - 1] (Python ``int``
      or NumPy integer; ``bool`` is rejected)
    - each subgroup has no duplicate indices
    - every spin from 0 to n_spins - 1 appears in at least one subgroup
      (subgroups may overlap)
    - if subgroup_probs is not None: one entry per subgroup, all
      non-negative, summing to 1 (within ``np.isclose`` tolerance)

    Parameters
    ----------
    subgroups : list of sequences of int
    subgroup_probs : sequence of float or None
    n_spins : int

    Raises
    ------
    TypeError
        If a subgroup index is not an integer.
    ValueError
        For any other violated requirement.
    """
    if not isinstance(subgroups, list) or len(subgroups) == 0:
        raise ValueError("subgroups must be a non-empty list")
    for g in subgroups:
        if not isinstance(g, Sequence) or len(g) == 0:
            raise ValueError("each subgroup must be a non-empty list")
        # NumPy integers are not ``int`` subclasses, so accept them explicitly;
        # bool *is* an ``int`` subclass but is never a meaningful spin index.
        if not all(isinstance(i, (int, np.integer)) and not isinstance(i, bool) for i in g):
            raise TypeError("subgroup indices must be integers")
        if not all(0 <= i < n_spins for i in g):
            raise ValueError("subgroup index out of bounds")
        if len(set(g)) != len(g):
            raise ValueError("duplicate indices inside a subgroup")
    covered = {i for g in subgroups for i in g}
    # All indices are already known to be in range, so only missing spins remain.
    missing = sorted(set(range(n_spins)) - covered)
    if missing:
        raise ValueError(f"subgroups must cover all spins at least once; missing spins: {missing}")
    if subgroup_probs is not None:
        if len(subgroup_probs) != len(subgroups):
            raise ValueError("subgroup_probs must match subgroups length")
        if any(p < 0 for p in subgroup_probs):
            raise ValueError("subgroup_probs must be non-negative")
        if not np.isclose(sum(subgroup_probs), 1.0):
            raise ValueError("subgroup_probs must sum to 1")