Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions Examples/qspace_lanczos/compute_spectral_Xpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
"""
Q-Space Lanczos: phonon spectral function at the X point of SnTe
================================================================

This example computes the anharmonic phonon spectral function for SnTe
at the X point (zone boundary) using the Q-space Lanczos algorithm.

The Q-space Lanczos exploits Bloch momentum conservation to drastically
reduce the size of the two-phonon sector, giving a speedup proportional
to the number of unit cells in the supercell (N_cell = 8 for this 2x2x2
supercell).

Requirements:
- Julia with SparseArrays package
- spglib
- The SnTe ensemble from tests/test_julia/data/

Usage:
python compute_spectral_Xpoint.py

# Or with MPI parallelism:
mpirun -np 4 python compute_spectral_Xpoint.py
"""
from __future__ import print_function

import numpy as np
import os, sys

import cellconstructor as CC
import cellconstructor.Phonons
import cellconstructor.Units

import sscha, sscha.Ensemble
import tdscha.QSpaceLanczos as QL

from tdscha.Parallel import pprint as print

# ========================
# Parameters
# ========================
# Path to the SnTe ensemble data (2x2x2 supercell, 2 atoms/unit cell)
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'..', '..', 'tests', 'test_julia', 'data')
NQIRR = 3 # Number of irreducible q-points in the dynamical matrix
TEMPERATURE = 250 # Temperature in Kelvin
N_STEPS = 50 # Number of Lanczos steps (increase for production)
SAVE_DIR = "output" # Directory for checkpoints and results


def main():
# ========================
# 1. Load ensemble
# ========================
dyn = CC.Phonons.Phonons(os.path.join(DATA_DIR, "dyn_gen_pop1_"), NQIRR)
ens = sscha.Ensemble.Ensemble(dyn, TEMPERATURE)
ens.load_bin(DATA_DIR, 1)

# ========================
# 2. Create Q-space Lanczos
# ========================
qlanc = QL.QSpaceLanczos(ens)
qlanc.ignore_v3 = False # Include 3-phonon interactions
qlanc.ignore_v4 = False # Include 4-phonon interactions
qlanc.init(use_symmetries=True)

# ========================
# 3. Inspect q-points and pick the X point
# ========================
print("Available q-points:")
for iq, q in enumerate(qlanc.q_points):
freqs = qlanc.w_q[:, iq] * CC.Units.RY_TO_CM
print(" iq={}: q = ({:8.5f}, {:8.5f}, {:8.5f}) "
"freqs = {} cm-1".format(iq, q[0], q[1], q[2],
np.array2string(freqs, precision=1, separator=', ')))

# The zone-boundary X point in a 2x2x2 FCC supercell is at iq=5,6,7
# (equivalent by cubic symmetry). We pick iq=5.
iq_pert = 5
print()
print("Selected q-point: iq={}".format(iq_pert))
print(" q = {}".format(qlanc.q_points[iq_pert]))
print()

# ========================
# 4. Run Lanczos for each band at the X point
# ========================
n_bands = qlanc.n_bands # 6 bands for SnTe (3 * 2 atoms)

for band in range(n_bands):
freq = qlanc.w_q[band, iq_pert] * CC.Units.RY_TO_CM

# Skip acoustic modes (zero frequency at Gamma only; at X all modes
# have finite frequency, but we check anyway for safety)
if qlanc.w_q[band, iq_pert] < 1e-6:
print("Skipping acoustic band {} (freq = {:.2f} cm-1)".format(
band, freq))
continue

print("=" * 50)
print("Band {}: {:.2f} cm-1".format(band, freq))
print("=" * 50)

qlanc.prepare_mode_q(iq_pert, band)
qlanc.run_FT(N_STEPS, save_each=10, save_dir=SAVE_DIR,
prefix="Xpoint_band{}".format(band), verbose=True)
qlanc.save_status(os.path.join(SAVE_DIR,
"Xpoint_band{}_final.npz".format(band)))

# ========================
# 5. Plot the total spectral function at X
# ========================
print()
print("=" * 50)
print("Computing total spectral function at X point")
print("=" * 50)

# Frequency grid (cm-1)
w_cm = np.linspace(0, 200, 500)
w_ry = w_cm / CC.Units.RY_TO_CM
smearing = 3.0 / CC.Units.RY_TO_CM # 3 cm-1 broadening

total_spectral = np.zeros_like(w_cm)

for band in range(n_bands):
if qlanc.w_q[band, iq_pert] < 1e-6:
continue

result_file = os.path.join(SAVE_DIR,
"Xpoint_band{}_final.npz".format(band))
if not os.path.exists(result_file):
print(" Band {} result not found, skipping".format(band))
continue

# Load and compute Green function
from tdscha.DynamicalLanczos import Lanczos
lanc_tmp = Lanczos()
lanc_tmp.load_status(result_file)

gf = lanc_tmp.get_green_function_continued_fraction(
w_ry, smearing=smearing, use_terminator=True)
spectral = -np.imag(gf)
total_spectral += spectral

# Print the renormalized frequency from G(0)
gf0 = lanc_tmp.get_green_function_continued_fraction(
np.array([0.0]), smearing=smearing, use_terminator=True)
w2 = 1.0 / np.real(gf0[0])
w_ren = np.sign(w2) * np.sqrt(np.abs(w2)) * CC.Units.RY_TO_CM
print(" Band {}: harmonic = {:.2f} cm-1, "
"renormalized = {:.2f} cm-1".format(
band,
qlanc.w_q[band, iq_pert] * CC.Units.RY_TO_CM,
w_ren))

# Save the spectrum to a text file
output_file = os.path.join(SAVE_DIR, "spectral_Xpoint.dat")
np.savetxt(output_file,
np.column_stack([w_cm, total_spectral]),
header="omega(cm-1) spectral_function(arb.units)",
fmt="%.6f")
print()
print("Spectral function saved to {}".format(output_file))

# Plot if matplotlib available
try:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(w_cm, total_spectral, 'b-', linewidth=1.5)

# Mark harmonic frequencies
for band in range(n_bands):
freq = qlanc.w_q[band, iq_pert] * CC.Units.RY_TO_CM
if freq > 1e-3:
ax.axvline(freq, color='r', linestyle='--', alpha=0.5,
linewidth=0.8)

ax.set_xlabel("Frequency (cm$^{-1}$)")
ax.set_ylabel("Spectral function (arb. units)")
ax.set_title("SnTe phonon spectral function at X point (T = {} K)".format(
TEMPERATURE))
ax.set_xlim(0, 200)
ax.legend(["Anharmonic (TD-SCHA)", "Harmonic frequencies"],
loc="upper right")
fig.tight_layout()
fig.savefig(os.path.join(SAVE_DIR, "spectral_Xpoint.pdf"))
print("Plot saved to {}/spectral_Xpoint.pdf".format(SAVE_DIR))
except ImportError:
print("matplotlib not available, skipping plot")


if __name__ == "__main__":
main()
23 changes: 17 additions & 6 deletions Modules/DynamicalLanczos.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def __init__(self, ensemble = None, mode = None, unwrap_symmetries = False, sele
self.gamma_only = False
self.trans_projector = None # (n_modes, n_modes) matrix P = (1/n_cells) Σ_R T_R^mode
self.trans_operators = None # list of (n_modes, n_modes) T_R^mode matrices
self.trans_cart_perms = None # list of Cartesian permutation index arrays for fast projection

# Set to True if we want to use the Wigner equations
self.use_wigner = use_wigner
Expand Down Expand Up @@ -764,15 +765,21 @@ def prepare_symmetrization(self, no_sym = False, verbose = True, symmetries = No
len(pg_symmetries), n_total_syms, n_total_syms // max(len(pg_symmetries), 1)))

# Build translation operators in mode space: T_R^mode = pols^T @ P_R @ pols
# Also store Cartesian permutation index arrays for fast projection
self.trans_operators = []
self.trans_cart_perms = []
nat_sc = super_structure.N_atoms
for t_sym in translations:
irt = CC.symmetries.GetIRT(super_structure, t_sym)
# Build Cartesian permutation matrix P_R (3*nat_sc x 3*nat_sc)
P_R = np.zeros((3 * nat_sc, 3 * nat_sc), dtype=np.double)
# Build inverse permutation index array for fast O(n^2) projection
inv_perm = np.zeros(3 * nat_sc, dtype=np.intp)
for i_at in range(nat_sc):
j_at = irt[i_at]
P_R[3*j_at:3*j_at+3, 3*i_at:3*i_at+3] = np.eye(3)
inv_perm[3*j_at:3*j_at+3] = [3*i_at, 3*i_at+1, 3*i_at+2]
self.trans_cart_perms.append(inv_perm)
# Project into mode space
T_mode = self.pols.T @ P_R @ self.pols # (n_modes, n_modes)
self.trans_operators.append(T_mode)
Expand Down Expand Up @@ -3171,12 +3178,16 @@ def get_combined_proc(start_end):
# Project f_pert_av: P_trans @ f
f_pert_av = self.trans_projector @ f_pert_av

# Project d2v_pert_av: (1/n_cells) Σ_R T_R @ d2v @ T_R^T
n_cells = len(self.trans_operators)
d2v_proj = np.zeros_like(d2v_pert_av)
for T_R in self.trans_operators:
d2v_proj += T_R @ d2v_pert_av @ T_R.T
d2v_pert_av = d2v_proj / n_cells
# Project d2v_pert_av using Cartesian-space permutations (O(n^2) per translation)
# Math: (1/N) Σ_R T_R @ d2v @ T_R^T = pols^T @ [(1/N) Σ_R P_R @ (pols @ d2v @ pols^T) @ P_R^T] @ pols
# where P_R @ M @ P_R^T is just row/column permutation (fancy indexing)
n_cells = len(self.trans_cart_perms)
d2v_cart = self.pols @ d2v_pert_av @ self.pols.T
d2v_cart_avg = np.zeros_like(d2v_cart)
for inv_perm in self.trans_cart_perms:
d2v_cart_avg += d2v_cart[np.ix_(inv_perm, inv_perm)]
d2v_cart_avg /= n_cells
d2v_pert_av = self.pols.T @ d2v_cart_avg @ self.pols
_t_trans_end = time.time()
if self.verbose:
print("Time for translational projection: {:.6f} s".format(_t_trans_end - _t_trans_start))
Expand Down
Loading
Loading