# Source code for tnpy.finite_tdvp

# flake8: noqa
from enum import Enum

from scipy.integrate import solve_ivp
from tensornetwork import Node

from tnpy import logger
from tnpy.linalg import qr
from tnpy.matrix_product_state import Environment
from tnpy.operators import MatrixProductOperator


class Evolve(Enum):
    """Selects which local ODE ``FiniteTDVP._unit_solver`` integrates."""

    # Integrate the site tensor of the MPS (the ``forward`` right-hand side).
    FORWARD = 1
    # Integrate a stored center matrix (the ``backward`` right-hand side).
    BACKWARD = -1


class FiniteTDVP:
    """Sweep-based time evolution of a finite matrix product state (MPS).

    Each site tensor is integrated over ``t_span`` with
    ``scipy.integrate.solve_ivp``, split by QR into a site tensor and a
    center matrix, and the center matrix is then integrated with the
    opposite sign before being absorbed into the neighbouring site.

    NOTE(review): the methods below read ``self.mpo``, ``self.N``,
    ``self.d``, ``self._mps``, ``self.left_envs``, ``self.right_envs``,
    ``self.left_norms``, ``self.right_norms`` and call ``self.mps_shape``,
    ``self._update_left_env``, ``self._update_left_norm``,
    ``self._update_right_env`` and ``self._update_right_norm``. None of
    these are defined in this class as visible here; they have to be
    provided elsewhere (for example by a base class or a subclass).
    """

    def __init__(self, mpo, chi, init_method):
        """Create the solver and its empty cache of center matrices.

        Args:
            mpo: Currently unused by this constructor.
            chi: Currently unused by this constructor.
            init_method: Currently unused by this constructor.
        """
        # NOTE(review): the original body held a bare ``NotImplemented``
        # expression here, which is a no-op. Initialisation of the MPS,
        # environments and norms from the arguments is still missing.

        # Maps a bond index to the center matrix (a ``Node``) produced by
        # the most recent QR split during a sweep.
        self.center_matrices = {}

    def _unit_solver(self, proceed, t_span, site):
        """Integrate one local ODE over ``t_span`` and return the end state.

        Args:
            proceed: ``Evolve.FORWARD`` to integrate the site tensor
                ``self._mps.nodes[site]``; ``Evolve.BACKWARD`` to integrate
                the center matrix ``self.center_matrices[site]``.
            t_span: Integration interval, forwarded to ``solve_ivp``.
            site: Index of the site (forward) or of the center matrix
                (backward).

        Returns:
            The flattened complex state at the last time point computed by
            ``solve_ivp``.

        Raises:
            ValueError: If ``proceed`` is neither ``Evolve.FORWARD`` nor
                ``Evolve.BACKWARD``.
        """

        def forward(t, y):
            """Right-hand side for the site tensor: ``-1j`` times the
            contraction of the tensor with the MPO and its environments."""
            M = Node(y.reshape(self.mps_shape(site)))
            W = self.mpo[site]
            if site == 0:
                # First site: only the right environment and norm are used.
                Renv = self.right_envs[site]
                Rnorm = self.right_norms[site]
                Renv[0] ^ M[2]
                Renv[1] ^ W[0]
                M[1] ^ W[1]
                Renv[2] ^ Rnorm[0]
                result = M @ W @ Renv @ Rnorm
            elif site == self.N - 1:
                # Last site: only the left environment and norm are used.
                Lenv = self.left_envs[site]
                Lnorm = self.left_norms[site]
                Lenv[0] ^ M[0]
                Lenv[1] ^ W[0]
                M[1] ^ W[1]
                Lenv[2] ^ Lnorm[0]
                result = Lenv @ Lnorm @ M @ W
            else:
                # Interior site: both environments and both norms are used.
                Lenv = self.left_envs[site]
                Lnorm = self.left_norms[site]
                Renv = self.right_envs[site]
                Rnorm = self.right_norms[site]
                Lenv[0] ^ M[0]
                Lenv[1] ^ W[0]
                M[1] ^ W[2]
                Renv[0] ^ M[2]
                Renv[1] ^ W[1]
                Lenv[2] ^ Lnorm[0]
                Renv[2] ^ Rnorm[0]
                result = Lenv @ Lnorm @ M @ W @ Renv @ Rnorm
            return -1j * result.tensor.reshape(y.shape)

        def backward(t, y):
            """Right-hand side for the center matrix: ``+1j`` times its
            contraction with the neighbouring environments."""
            C = Node(y.reshape(self.center_matrices[site].tensor.shape))
            # The center matrix sits on the bond between ``site`` and
            # ``site + 1``, hence the left environment of ``site + 1``.
            Lenv = self.left_envs[site + 1]
            Lnorm = self.left_norms[site + 1]
            Renv = self.right_envs[site]
            Rnorm = self.right_norms[site]
            Lenv[0] ^ C[0]
            Renv[0] ^ C[1]
            Lenv[1] ^ Renv[1]
            Lenv[2] ^ Lnorm[0]
            Renv[2] ^ Rnorm[0]
            result = Lenv @ Lnorm @ C @ Renv @ Rnorm
            return 1j * result.tensor.reshape(y.shape)

        # The initial state is cast to complex so that ``solve_ivp``
        # integrates in the complex domain.
        if proceed == Evolve.FORWARD:
            y0 = self._mps.nodes[site].tensor.reshape(-1).astype(complex)
            solution = solve_ivp(forward, t_span, y0)
        elif proceed == Evolve.BACKWARD:
            y0 = self.center_matrices[site].tensor.reshape(-1).astype(complex)
            solution = solve_ivp(backward, t_span, y0)
        else:
            # Previously this fell through and failed with an unbound local.
            raise ValueError(
                "proceed must be Evolve.FORWARD or Evolve.BACKWARD, "
                "got {!r}".format(proceed)
            )
        # ``solution.y`` has one column per time point; keep the last one.
        return solution.y[:, -1]

    def sweep(self, iterator, t_span):
        """Evolve every site in ``iterator`` once, in the given order.

        Args:
            iterator: Indexable sequence of site indices. The sweep runs
                left-to-right when ``iterator[0] < iterator[-1]`` and
                right-to-left otherwise (including a one-element sequence).
            t_span: Integration interval passed to ``_unit_solver``.
        """
        direction = 1 if iterator[0] < iterator[-1] else -1
        for site in iterator:
            theta = self._unit_solver(Evolve.FORWARD, t_span, site)
            # Flatten to a matrix so the bond facing the sweep direction is
            # the one split off by QR.
            if direction == 1:
                theta = theta.reshape(self.d * self.mps_shape(site)[0], -1)
            elif direction == -1:
                theta = theta.reshape(-1, self.d * self.mps_shape(site)[2])
            q, r = qr(theta, cutoff=self.mps_shape(site)[1 + direction])
            if direction == 1:
                # q becomes the site tensor, r the center matrix on the
                # bond to the right.
                self._mps.nodes[site] = Node(q.reshape(self.mps_shape(site)))
                self.center_matrices[site] = Node(r)
                self._update_left_env(site + 1)
                self._update_left_norm(site + 1)
                if site < self.N - 1:
                    C = Node(
                        self._unit_solver(
                            Evolve.BACKWARD, t_span, site
                        ).reshape(r.shape)
                    )
                    Mp = self._mps.nodes[site + 1]
                    C[1] ^ Mp[0]
                    self._mps.nodes[site + 1] = C @ Mp
            elif direction == -1:
                # r becomes the site tensor, q the center matrix on the
                # bond to the left.
                # NOTE(review): confirm that ``tnpy.linalg.qr`` returns
                # factors for which ``r`` is the appropriate site tensor in
                # a right-to-left sweep.
                self._mps.nodes[site] = Node(r.reshape(self.mps_shape(site)))
                self.center_matrices[site - 1] = Node(q)
                self._update_right_env(site - 1)
                self._update_right_norm(site - 1)
                if site > 0:
                    C = Node(
                        self._unit_solver(
                            Evolve.BACKWARD, t_span, site - 1
                        ).reshape(q.shape)
                    )
                    Mp = self._mps.nodes[site - 1]
                    Mp[2] ^ C[0]
                    # Absorb the evolved center matrix into the left
                    # neighbour. Previously the edge was connected but the
                    # contraction was never stored, unlike the
                    # left-to-right branch above.
                    self._mps.nodes[site - 1] = Mp @ C
            # @TODO: measure something here to check the status of mps, and
            # replace these debug prints with a log message reporting the
            # sweep progress and norm.
            print(site)
            print(self._mps.check_orthonormality("l", self.N - 1))
            print(self._mps.check_orthonormality("r", 0))
            print(self._mps.check_canonical())

    def evolve(self, t_span):
        """Placeholder: not implemented yet, does nothing and returns None.

        Args:
            t_span: Currently unused.
        """
        pass
#