diff --git a/src/solver/core/kktsolvers/direct/quasidef/ldlsolvers/config.rs b/src/solver/core/kktsolvers/direct/quasidef/ldlsolvers/config.rs index f5140416..5c5d9d76 100644 --- a/src/solver/core/kktsolvers/direct/quasidef/ldlsolvers/config.rs +++ b/src/solver/core/kktsolvers/direct/quasidef/ldlsolvers/config.rs @@ -13,65 +13,24 @@ use crate::{ }, }; -// The Julia version implements the mapping from user setting -// to LDL solver implementation using dynamic dispatch on -// ::Val{:qdldl} types of arguments. There is no equivalent -// in Rust, so we get these big match statements. - type LDLConstructor = fn(&CscMatrix, &[i8], &CoreSettings, Option>) -> BoxedDirectLDLSolver; -// Some solvers only support 64 bit variants, which presents -// a problem since most of the solver code is generic over FloatT -// and trait specialization is not avaiable in Rust yet. Hence -// this janky trait - pub trait LDLConfiguration: FloatT { fn get_ldlsolver_config( settings: &CoreSettings, - ) -> (MatrixTriangle, LDLConstructor) { - // default is to use the generic form, ignoring - // types that support f64 only - Self::get_ldlsolver_config_default(settings) - } - - // The default configurator for generic T - fn get_ldlsolver_config_default( - settings: &CoreSettings, - ) -> (MatrixTriangle, LDLConstructor) { - let ldlptr: LDLConstructor; - let kktshape: MatrixTriangle; - let case = settings.direct_solve_method.as_str(); - - match case { - "auto" => { - kktshape = AutoDirectLDLSolver::::required_matrix_shape(); - ldlptr = |M, D, S, P| AutoDirectLDLSolver::new(M, D, S, P); - } - "qdldl" => { - kktshape = QDLDLDirectLDLSolver::::required_matrix_shape(); - ldlptr = |M, D, S, P| Box::new(QDLDLDirectLDLSolver::new(M, D, S, P)); - } - #[cfg(feature = "faer-sparse")] - "faer" => { - kktshape = FaerDirectLDLSolver::::required_matrix_shape(); - ldlptr = |M, D, S, P| Box::new(FaerDirectLDLSolver::new(M, D, S, P)); - } - _ => { - panic!("Unrecognized LDL solver type: \"{}\"", case); - } - } - (kktshape, ldlptr) - } + ) -> (MatrixTriangle, LDLConstructor); } -// This cursed section of code exists because trait specialisation -// does not yet exist in rust. We want get_ldlsolver_config to -// construct pardiso solvers only when FloatT = f64, and to return -// an error for pardiso options for all other FloatT. +// The Julia version implements the mapping from user settings to LDL solver +// implementation using dynamic dispatch on ::Val{:qdldl} types of arguments. +// There is no equivalent in Rust, so we get these big match statemen + +// The implementation first handles the f64-specific solvers (MKL, Panua Pardiso). +// Since trait specialization doesn't exist in Rust, we use runtime TypeId +// checks within a blanket impl to conditionally enable these solvers only for f64. impl LDLConfiguration for T { - // fn get_ldlsolver_config(settings: &CoreSettings) -> (MatrixTriangle, LDLConstructor) { let ldlptr: LDLConstructor; let kktshape: MatrixTriangle; @@ -115,8 +74,23 @@ impl LDLConfiguration for T { ); } } - _ => (kktshape, ldlptr) = Self::get_ldlsolver_config_default(settings), - } + "auto" => { + kktshape = AutoDirectLDLSolver::::required_matrix_shape(); + ldlptr = |M, D, S, P| AutoDirectLDLSolver::new(M, D, S, P); + } + "qdldl" => { + kktshape = QDLDLDirectLDLSolver::::required_matrix_shape(); + ldlptr = |M, D, S, P| Box::new(QDLDLDirectLDLSolver::new(M, D, S, P)); + } + #[cfg(feature = "faer-sparse")] + "faer" => { + kktshape = FaerDirectLDLSolver::::required_matrix_shape(); + ldlptr = |M, D, S, P| Box::new(FaerDirectLDLSolver::new(M, D, S, P)); + } + _ => { + panic!("Unrecognized LDL solver type: \"{}\"", case); + } + }; (kktshape, ldlptr) } }