diff --git a/Cargo.lock b/Cargo.lock
index 72f33c295..017a6afa6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -48,6 +48,12 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "alphanumeric-sort"
+version = "1.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "81149050d254e2b758c80dcf55949e5c45482e0c9cb3670b1c4b48eb51791f8e"
+
 [[package]]
 name = "approx"
 version = "0.5.1"
@@ -253,6 +259,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
 name = "circ"
 version = "0.1.0"
 dependencies = [
+ "alphanumeric-sort",
  "approx",
  "bellman",
  "bincode",
@@ -377,7 +384,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8f86c5252b7b745adc0bc8a724171018dd2fc1e63629f7f8ec2f28a66b0d6ed7"
 dependencies = [
  "coin_cbc_sys",
- "lazy_static",
 ]
 
 [[package]]
@@ -779,9 +785,8 @@ dependencies = [
 
 [[package]]
 name = "good_lp"
-version = "1.3.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0f3e07a962fcc2fdcf91ad2828974e8e77d7f8e9f4dacab3f45cfa5e2d609794"
+version = "1.3.2"
+source = "git+https://github.com/Clive2312/good_lp.git#11f60cb1087c0bfc7e2db36cf520db53ee775aef"
 dependencies = [
  "coin_cbc",
  "fnv",
diff --git a/Cargo.toml b/Cargo.toml
index f5d9f578e..fc4dcc9f0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,7 +26,7 @@ bellman = { git = "https://github.com/alex-ozdemir/bellman.git", branch = "mirag
 rug-polynomial = { version = "0.2.5", optional = true }
 ff = { version = "0.12", optional = true }
 fxhash = "0.2"
-good_lp = { version = "1.1", features = ["lp-solvers", "coin_cbc"], default-features = false, optional = true }
+good_lp = { git = "https://github.com/Clive2312/good_lp.git", features = ["lp-solvers", "coin_cbc"], optional = true }
 group = { version = "0.12", optional = true }
 lp-solvers = { version = "0.0.4", optional = true }
 serde_json = "1.0"
@@ -50,6 +50,7 @@ curve25519-dalek = {version = "3.2.0", features = ["serde"], optional = true}
 paste = "1.0"
 im = "15"
 once_cell = "1"
+alphanumeric-sort = "1.5.1"
 
 [dev-dependencies]
 quickcheck = "1"
diff --git a/examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c b/examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c
new file mode 100644
index 000000000..7a5183e42
--- /dev/null
+++ b/examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c
@@ -0,0 +1,205 @@
+#define D 2          // Dimension (fixed)
+#define NA 100       // Number of data points from Party A
+#define NB 100       // Number of data points from Party B
+#define NC 5         // Number of clusters
+#define PRECISION 4
+
+#define LEN (NA+NB)
+#define LEN_OUTER 10
+#define LEN_INNER (LEN/LEN_OUTER)
+
+
+typedef int coord_t;
+
+struct input_a {
+    int dataA[D*NA];
+};
+
+struct input_b {
+    int dataB[D*NB];
+};
+
+typedef struct
+{
+    coord_t cluster[D*NC];
+} Output;
+
+
+int dist2(int x1, int y1, int x2, int y2) {
+    return (x1-x2) * (x1-x2) + (y1 - y2) * (y1 - y2);
+}
+
+// Computes the index of the minimum element, together with its associated aux
+// element. The active implementation is a linear scan; a tree-based reduction
+// is kept below for reference.
+int min_with_aux(int *data, int *aux, int len, int stride) {
+    // if(stride > len) {
+    //     return aux[0];
+    // } else {
+    //     for(int i = 0; i + stride < len; i+=stride<<1) {
+    //         if(data[i+stride] < data[i]) {
+    //             data[i] = data[i+stride];
+    //             aux[i] = aux[i+stride];
+    //         }
+    //     }
+    //     return min_with_aux(data, aux, len, stride<<1);
+    // }
+    int min = data[0];
+    int res = 0;
+    for(int i = 1; i < NC; i++){
+        if(data[i] < min) {
+            min = data[i];
+            res = i;
+        }
+    }
+    return res;
+}
+
+
+#define ADD2(X,A) A[X] + A[X+1]
+#define ADD4(X,A) ADD2(X,A) + ADD2(X+2,A)
+#define ADD8(X,A) 
ADD4(X,A) + ADD4(X+4,A) +#define ADD10(X,A) ADD8(X,A) + ADD2(X+8,A) + +/** + * Iteration loop unrolled and depth minimized by computing minimum over tree structure + */ +void iteration_unrolled_inner_depth(int *data_inner, int *cluster, int *OUTPUT_cluster, int *OUTPUT_count) { + int i,c; + int dist[NC]; + int pos[NC]; + int bestMap_inner[LEN_INNER]; + + for(c = 0; c < NC; c++) { + OUTPUT_cluster[c*D] = 0; + OUTPUT_cluster[c*D+1] = 0; + OUTPUT_count[c] = 0; + } + + // Compute nearest clusters for Data item i + for(i = 0; i < LEN_INNER; i++) { + int dx = data_inner[i*D]; + int dy = data_inner[i*D+1]; + for(c = 0; c < NC; c++) { + pos[c]=c; + dist[c] = dist2(cluster[D*c], cluster[D*c+1], dx, dy); + } + bestMap_inner[i] = min_with_aux(dist, pos, NC, 1); + int cc = bestMap_inner[i]; + OUTPUT_cluster[cc*D] += data_inner[i*D]; + OUTPUT_cluster[cc*D+1] += data_inner[i*D+1]; + OUTPUT_count[cc]++; + } +} + +/** + * Iteration unrolled outer loop + */ +void iteration_unrolled_outer(int *data, int *cluster, int *OUTPUT_cluster) { + // int j, c; + int j,c; + // int count[NC]; + int count[NC]; + + // Set Outer result + for(c = 0; c < NC; c++) { + OUTPUT_cluster[c*D] = 0; + OUTPUT_cluster[c*D+1] = 0; + count[c] = 0; + } + + // TODO: loop_clusterD1 -- 2d arrays + int loop_clusterD1[NC][LEN_OUTER]; + int loop_clusterD2[NC][LEN_OUTER]; + // int loop_count[NC][LEN_OUTER]; + int loop_count[NC][LEN_OUTER]; + + + // Compute decomposition + for(j = 0; j < LEN_OUTER; j++) { + // Copy data, fasthack for scalability + int data_offset = j*LEN_INNER*D; + int data_inner[LEN_INNER*D]; + + // memcpy(data_inner, data+data_offset, LEN_INNER*D*sizeof(int)); + for (int i = 0; i < LEN_INNER * D; i++) + { + data_inner[i] = data[i + data_offset]; + } + + int cluster_inner[NC*D]; + // int count_inner[NC]; + int count_inner[NC]; + + iteration_unrolled_inner_depth(data_inner, cluster, cluster_inner, count_inner); + + // Depth: num_cluster Addition + for(c = 0; c < NC; c++) { + loop_clusterD1[c][j] = cluster_inner[c*D]; + loop_clusterD2[c][j] = cluster_inner[c*D+1]; + loop_count[c][j] = count_inner[c]; + } + } + + for(c = 0; c < NC; c++) { + OUTPUT_cluster[c*D] = ADD10(0,loop_clusterD1[c]); + OUTPUT_cluster[c*D+1] = ADD10(0,loop_clusterD2[c]); + count[c] = ADD10(0, loop_count[c]); + } + + // Recompute cluster Pos + // Compute mean + for(c = 0; c < NC; c++) { + if(count[c] > 0) { + OUTPUT_cluster[c*D] /= count[c]; + OUTPUT_cluster[c*D+1] /= count[c]; + } + } +} + +void kmeans(int *data, int *OUTPUT_res) { + // int c, p; + int c, p; + int cluster[NC*D]; + + // Assign random start cluster from data + for(c = 0; c < NC; c++) { + cluster[c*D] = data[((c+3)%LEN)*D]; + cluster[c*D+1] = data[((c+3)%LEN)*D+1]; + } + + for (p = 0; p < PRECISION; p++) { + int new_cluster[NC*D]; + iteration_unrolled_outer(data, cluster, new_cluster); + // iteration(data, cluster, new_cluster, len, num_cluster); + + // We need to copy inputs to outputs + for( c = 0; c < NC*D; c++) { + cluster[c] = new_cluster[c]; + } + } + + for(c = 0; c < NC; c++) { + OUTPUT_res[c*D] = cluster[c*D]; + OUTPUT_res[c*D+1] = cluster[c*D+1]; + } +} + + +Output main(__attribute__((private(0))) int a[200], __attribute__((private(1))) int b[200]) +{ + // init data + int data[LEN * D]; + for (int i = 0; i < D * NA; i++) + { + data[i] = a[i]; + } + int offset = D * NA; + for (int i = 0; i < D * NB; i++) + { + data[i + offset] = b[i]; + } + + Output output; + kmeans(data, output.cluster); + + return output; +} \ No newline at end of file diff --git 
a/examples/C/mpc/benchmarks/kmeans/2pc_kmeans_og.c b/examples/C/mpc/benchmarks/kmeans/2pc_kmeans_og.c index 92d6ba58c..d749d4f51 100644 --- a/examples/C/mpc/benchmarks/kmeans/2pc_kmeans_og.c +++ b/examples/C/mpc/benchmarks/kmeans/2pc_kmeans_og.c @@ -196,8 +196,8 @@ int main(__attribute__((private(0))) int a[20], __attribute__((private(1))) int { data[i + offset] = b[i]; } - - struct output output; + struct output output; + kmeans(data, output.cluster); int sum = 0; diff --git a/examples/circ.rs b/examples/circ.rs index 9a24a9095..d84aa81b2 100644 --- a/examples/circ.rs +++ b/examples/circ.rs @@ -112,6 +112,16 @@ enum Backend { cost_model: String, #[arg(long, default_value = "lp", name = "selection_scheme")] selection_scheme: String, + #[arg(long, default_value = "4000", name = "partition_size")] + partition_size: usize, + #[arg(long, default_value = "4", name = "mutation_level")] + mutation_level: usize, + #[arg(long, default_value = "1", name = "mutation_step_size")] + mutation_step_size: usize, + #[arg(long, default_value = "1", name = "hyper")] + hyper: usize, + #[arg(long, default_value = "3", name = "imbalance")] + imbalance: usize, }, } @@ -238,17 +248,17 @@ fn main() { opt( cs, vec![ - Opt::ScalarizeVars, + // Opt::ScalarizeVars, Opt::Flatten, Opt::Sha, Opt::ConstantFold(Box::new(ignore.clone())), Opt::Flatten, // Function calls return tuples - Opt::Tuple, - Opt::Obliv, + // Opt::Tuple, + // Opt::Obliv, // The obliv elim pass produces more tuples, that must be eliminated - Opt::Tuple, - Opt::LinearScan, + // Opt::Tuple, + // Opt::LinearScan, // The linear scan pass produces more tuples, that must be eliminated Opt::Tuple, Opt::ConstantFold(Box::new(ignore)), @@ -364,6 +374,11 @@ fn main() { Backend::Mpc { cost_model, selection_scheme, + partition_size, + mutation_level, + mutation_step_size, + hyper, + imbalance, } => { println!("Converting to aby"); let lang_str = match language { @@ -373,7 +388,18 @@ fn main() { }; println!("Cost model: {cost_model}"); println!("Selection scheme: {selection_scheme}"); - to_aby(cs, &path_buf, &lang_str, &cost_model, &selection_scheme); + to_aby( + cs, + &path_buf, + &lang_str, + &cost_model, + &selection_scheme, + &partition_size, + &mutation_level, + &mutation_step_size, + &hyper, + &imbalance, + ); } #[cfg(not(feature = "aby"))] Backend::Mpc { .. } => { diff --git a/examples/opa_bench.rs b/examples/opa_bench.rs index ffe07b5b3..b5fa26dd2 100644 --- a/examples/opa_bench.rs +++ b/examples/opa_bench.rs @@ -29,6 +29,6 @@ fn main() { outputs: vec![term![Op::Eq; t, v]], ..Default::default() }; - let _assignment = ilp::assign(&cs, "hycc"); + let _assignment = ilp::assign(&cs.to_cs(), "hycc"); //dbg!(&assignment); } diff --git a/scripts/aby_tests/c_test_aby.py b/scripts/aby_tests/c_test_aby.py index 2cbfd1345..398cc0b88 100755 --- a/scripts/aby_tests/c_test_aby.py +++ b/scripts/aby_tests/c_test_aby.py @@ -16,10 +16,10 @@ ite_tests + \ shift_tests + \ div_tests + \ - mod_tests + \ - struct_tests + \ - ptr_tests + \ - c_misc_tests + mod_tests + # struct_tests + \ + # ptr_tests + \ + # c_misc_tests # array_tests + \ # c_array_tests + \ # matrix_tests + \ diff --git a/scripts/build_kahypar.zsh b/scripts/build_kahypar.zsh index 99717ec44..c38273c8f 100755 --- a/scripts/build_kahypar.zsh +++ b/scripts/build_kahypar.zsh @@ -2,7 +2,7 @@ if [[ ! -z ${KAHYPAR_SOURCE} ]]; then cd ${KAHYPAR_SOURCE} - mkdir build && cd build + mkdir -p build && cd build cmake .. 
-DCMAKE_BUILD_TYPE=RELEASE make else diff --git a/scripts/build_mpc_c_test.zsh b/scripts/build_mpc_c_test.zsh index 2d79987f9..3a1d67cf3 100755 --- a/scripts/build_mpc_c_test.zsh +++ b/scripts/build_mpc_c_test.zsh @@ -8,6 +8,8 @@ disable -r time BIN=./target/release/examples/circ export CARGO_MANIFEST_DIR=$(pwd) +export KAHIP_SOURCE=../KaHIP/ +export KAHYPAR_SOURCE=../kahypar/ case "$OSTYPE" in darwin*) @@ -18,65 +20,82 @@ case "$OSTYPE" in ;; esac -function mpc_test { +function mpc_test_glp { + parties=$1 + cpath=$2 + RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "glp" +} + +function mpc_test_tlp { parties=$1 cpath=$2 - RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" + RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "tlp" } -function mpc_test_2 { +function mpc_test_css { parties=$1 cpath=$2 - RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "a+b" + RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "css" +} + + +# function mpc_test_2 { +# parties=$1 +# cpath=$2 +# RUST_BACKTRACE=1 measure_time $BIN --parties $parties $cpath mpc --cost-model "hycc" --selection-scheme "a+b" +# } + +function mpc_test { + mpc_test_tlp $1 $2 } # mpc_test_2 2 ./examples/C/mpc/playground.c -# build mpc arithmetic tests -mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add.c -mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_sub.c -mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mult.c -mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mult_add_pub.c -mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mod.c -mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add_unsigned.c +# # build mpc arithmetic tests +# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add.c +# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_sub.c +# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mult.c +# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mult_add_pub.c +# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_mod.c +# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_add_unsigned.c -mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_equals.c -mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_greater_than.c -mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_greater_equals.c -mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_less_than.c -mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_less_equals.c +# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_equals.c +# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_greater_than.c +# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_greater_equals.c +# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_less_than.c +# mpc_test 2 ./examples/C/mpc/unit_tests/arithmetic_tests/2pc_int_less_equals.c -# # build nary arithmetic tests -mpc_test 2 ./examples/C/mpc/unit_tests/nary_arithmetic_tests/2pc_nary_arithmetic_add.c +# # # build nary arithmetic tests +# mpc_test 2 ./examples/C/mpc/unit_tests/nary_arithmetic_tests/2pc_nary_arithmetic_add.c -# # build bitwise tests -mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_and.c -mpc_test 2 
./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_or.c -mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_xor.c +# # # build bitwise tests +# mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_and.c +# mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_or.c +# mpc_test 2 ./examples/C/mpc/unit_tests/bitwise_tests/2pc_bitwise_xor.c -# # build boolean tests -mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_and.c -mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_or.c -mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_equals.c +# # # build boolean tests +# mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_and.c +# mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_or.c +# mpc_test 2 ./examples/C/mpc/unit_tests/boolean_tests/2pc_boolean_equals.c -# # build nary boolean tests -mpc_test 2 ./examples/C/mpc/unit_tests/nary_boolean_tests/2pc_nary_boolean_and.c +# # # build nary boolean tests +# mpc_test 2 ./examples/C/mpc/unit_tests/nary_boolean_tests/2pc_nary_boolean_and.c -# # build const tests -mpc_test 2 ./examples/C/mpc/unit_tests/const_tests/2pc_const_arith.c -mpc_test 2 ./examples/C/mpc/unit_tests/const_tests/2pc_const_bool.c +# # # build const tests +# mpc_test 2 ./examples/C/mpc/unit_tests/const_tests/2pc_const_arith.c +# mpc_test 2 ./examples/C/mpc/unit_tests/const_tests/2pc_const_bool.c -# build if statement tests -mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_ret_bool.c -mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_ret_int.c -mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_only_if.c +# # build if statement tests +# mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_ret_bool.c +# mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_ret_int.c +# mpc_test 2 ./examples/C/mpc/unit_tests/ite_tests/2pc_ite_only_if.c -# build shift tests -mpc_test 2 ./examples/C/mpc/unit_tests/shift_tests/2pc_lhs.c -mpc_test 2 ./examples/C/mpc/unit_tests/shift_tests/2pc_rhs.c +# # build shift tests +# mpc_test 2 ./examples/C/mpc/unit_tests/shift_tests/2pc_lhs.c +# mpc_test 2 ./examples/C/mpc/unit_tests/shift_tests/2pc_rhs.c -# build div tests -mpc_test 2 ./examples/C/mpc/unit_tests/div_tests/2pc_div.c +# # build div tests +# mpc_test 2 ./examples/C/mpc/unit_tests/div_tests/2pc_div.c # # build array tests # mpc_test 2 ./examples/C/mpc/unit_tests/array_tests/2pc_array_sum.c @@ -90,13 +109,13 @@ mpc_test 2 ./examples/C/mpc/unit_tests/div_tests/2pc_div.c # mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_3.c # mpc_test 2 ./examples/C/mpc/unit_tests/c_array_tests/2pc_array_sum_c.c -# build function tests -mpc_test 2 ./examples/C/mpc/unit_tests/function_tests/2pc_function_add.c -mpc_test 2 ./examples/C/mpc/unit_tests/function_tests/function_arg_order.c +# # build function tests +# mpc_test 2 ./examples/C/mpc/unit_tests/function_tests/2pc_function_add.c +# mpc_test 2 ./examples/C/mpc/unit_tests/function_tests/function_arg_order.c -# build struct tests -mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_add.c +# # build struct tests +# mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_add.c # mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_array_add.c # # build matrix tests @@ -104,29 +123,30 @@ mpc_test 2 ./examples/C/mpc/unit_tests/struct_tests/2pc_struct_add.c # mpc_test 2 ./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_assign_add.c # mpc_test 2 
./examples/C/mpc/unit_tests/matrix_tests/2pc_matrix_ptr_add.c -# build ptr tests -mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_add.c +# # build ptr tests +# mpc_test 2 ./examples/C/mpc/unit_tests/ptr_tests/2pc_ptr_add.c -# build misc tests -mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_millionaires.c -mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_multi_var.c +# # build misc tests +# mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_millionaires.c +# mpc_test 2 ./examples/C/mpc/unit_tests/misc_tests/2pc_multi_var.c # build hycc benchmarks # mpc_test 2 ./examples/C/mpc/benchmarks/biomatch/2pc_biomatch.c # mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans.c +mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_.c # mpc_test 2 ./examples/C/mpc/benchmarks/kmeans/2pc_kmeans_og.c -# mpc_test_2 2 ./examples/C/mpc/benchmarks/db/db_join.c -# mpc_test_2 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss.c -# mpc_test_2 2 ./examples/C/mpc/benchmarks/mnist/2pc_mnist.c - -# # ilp benchmarks -# # mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_1.c -# # mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_2.c -# # mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_3.c -# # mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_4.c -# # mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_5.c -# # mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_6.c -# # mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_7.c -# # mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_8.c -# # mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_9.c -# # mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench.c +# mpc_test 2 ./examples/C/mpc/benchmarks/db/db_join.c +# mpc_test 2 ./examples/C/mpc/benchmarks/gauss/2pc_gauss.c +# mpc_test 2 ./examples/C/mpc/benchmarks/mnist/2pc_mnist.c + +# ilp benchmarks +# mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_1.c +# mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_2.c +# mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_3.c +# mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_4.c +# mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_5.c +# mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_6.c +# mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_7.c +# mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_8.c +# mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench_9.c +# mpc_test 2 ./examples/C/mpc/ilp_benchmarks/2pc_ilp_bench.c diff --git a/src/circify/mem.rs b/src/circify/mem.rs index 8d14b0315..278ac3250 100644 --- a/src/circify/mem.rs +++ b/src/circify/mem.rs @@ -116,6 +116,15 @@ impl MemManager { alloc.cur_term = val; } + /// Get the stored term in the allocation `id` + pub fn term(&self, id: AllocId) -> Term { + self.allocs + .get(&id) + .expect("Missing allocation") + .cur_term + .clone() + } + /// Is `offset` in bounds for the allocation `id`? 
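// [Editor's note] A minimal usage sketch for the new `term` accessor above;
// the allocation and its sizes are illustrative, not taken from the codebase:
//
//     let id = mem.zero_allocate(4, 32, 32);   // 4-element array of 32-bit values
//     let arr: Term = mem.term(id);            // whole backing array as one term
//     let elem = mem.load(id, bv_lit(0, 32));  // contrast: a single-element load
//
// Unlike `load`, which emits a `select` for one offset, `term` returns the
// allocation's entire current array term.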
    pub fn in_bounds(&self, id: AllocId, offset: Term) -> Term {
        let alloc = self.allocs.get(&id).expect("Missing allocation");
diff --git a/src/front/c/mod.rs b/src/front/c/mod.rs
index 84806a7a0..4c4645604 100644
--- a/src/front/c/mod.rs
+++ b/src/front/c/mod.rs
@@ -187,7 +187,7 @@ impl CGen {
             TypeSpecifier::Void => None,
             TypeSpecifier::Int => Some(Ty::Int(true, 32)),
             TypeSpecifier::Unsigned => Some(Ty::Int(false, 32)),
-            TypeSpecifier::Long => Some(Ty::Int(true, 32)), // TODO: not 32 bits
+            TypeSpecifier::Long => Some(Ty::Int(true, 64)),
             TypeSpecifier::Bool => Some(Ty::Bool),
             TypeSpecifier::TypedefName(td) => {
                 let name = &td.node.name;
@@ -494,10 +494,10 @@ impl CGen {
         }
     }
 
-    fn gen_lval(&mut self, expr: &Expression) -> CLoc {
-        match &expr {
+    fn gen_lval(&mut self, expr: Node<Expression>) -> CLoc {
+        match expr.node {
             Expression::Identifier(_) => {
-                let base_name = name_from_expr(expr);
+                let base_name = name_from_expr(&expr.node);
                 CLoc::Var(Loc::local(base_name))
             }
             Expression::BinaryOperator(ref node) => {
@@ -505,22 +505,45 @@ impl CGen {
                 match bin_op.operator.node {
                     BinaryOperator::Index => {
                         // get location
-                        let loc = self.gen_lval(&bin_op.lhs.node);
+                        let loc = self.gen_lval(*bin_op.lhs.clone());
 
-                        // get offset
-                        let index = self.gen_index(expr);
+                        // get base offset
+                        let index = self.gen_index(&expr.node);
                         let offset = self.index_offset(&index);
-                        let idx = cterm(CTermData::Int(true, 32, offset));
+                        let idx = cterm(CTermData::Int(true, 32, offset.clone()));
 
-                        if let Expression::BinaryOperator(_) = bin_op.lhs.node {
-                            // Matrix case
-                            let base = self.base_loc(loc);
-                            CLoc::Idx(Box::new(base), idx)
+                        if let Expression::BinaryOperator(op) = &bin_op.lhs.node {
+                            if let BinaryOperator::Index = &op.node.operator.node {
+                                // Matrix case
+                                let base = self.base_loc(loc);
+                                CLoc::Idx(Box::new(base), idx)
+                            } else {
+                                // Ptr Arithmetic case
+                                let idx = if let CLoc::Idx(_, o) = &loc {
+                                    let new_offset = term![BV_ADD; o.term.term(self.circ.borrow().cir_ctx()), offset];
+                                    cterm(CTermData::Int(true, 32, new_offset))
+                                } else {
+                                    idx
+                                };
+                                let base = self.base_loc(loc);
+                                CLoc::Idx(Box::new(base), idx)
+                            }
                         } else {
                             CLoc::Idx(Box::new(loc), idx)
                         }
                     }
-                    _ => unimplemented!("Invalid left hand value"),
+                    BinaryOperator::Plus => {
+                        // get location
+                        let loc = self.gen_lval(*bin_op.lhs.clone());
+
+                        // get offset
+                        let offset = self.gen_expr(&bin_op.rhs.node);
+
+                        CLoc::Idx(Box::new(loc), offset)
+                    }
+                    _ => {
+                        unimplemented!("Invalid left hand value")
+                    }
                 }
             }
             Expression::Member(node) => {
@@ -528,13 +551,10 @@ impl CGen {
                 let MemberExpression {
                     operator: _operator,
                     expression,
                     identifier,
-                } = &node.node;
+                } = node.node;
                 let base_name = name_from_expr(&expression.node);
-                let field_name = &identifier.node.name;
-                CLoc::Member(
-                    Box::new(CLoc::Var(Loc::local(base_name))),
-                    field_name.to_string(),
-                )
+                let field_name = identifier.node.name;
+                CLoc::Member(Box::new(CLoc::Var(Loc::local(base_name))), field_name)
             }
             _ => unimplemented!("Invalid left hand value"),
         }
@@ -732,15 +752,17 @@ impl CGen {
                 let bin_op = &node.node;
                 match bin_op.operator.node {
                     BinaryOperator::Assign => {
-                        let loc = self.gen_lval(&bin_op.lhs.node);
+                        let loc = self.gen_lval(*bin_op.lhs.clone());
                         let val = self.gen_expr(&bin_op.rhs.node);
                         self.gen_assign(loc, val)
                     }
-                    BinaryOperator::AssignPlus | BinaryOperator::AssignDivide => {
+                    BinaryOperator::AssignPlus
+                    | BinaryOperator::AssignDivide
+                    | BinaryOperator::AssignMultiply => {
                         let f = self.get_bin_op(&bin_op.operator.node);
                         let i = self.gen_expr(&bin_op.lhs.node);
                         let rhs = self.gen_expr(&bin_op.rhs.node);
-                        let loc = self.gen_lval(&bin_op.lhs.node);
+                        let loc = self.gen_lval(*bin_op.lhs.clone());
                         let val = f(i, rhs).unwrap();
                         self.gen_assign(loc, val)
                     }
@@ -755,7 +777,6 @@ impl CGen {
                 let diff = sizes.len() - index.indices.len();
                 let new_sizes: Vec<usize> = sizes.clone().into_iter().take(diff).collect();
-
                 let new_ty = Ty::Array(*size, new_sizes, Box::new(*t.clone()));
 
                 Ok(cterm(CTermData::StackPtr(new_ty, offset, id)))
@@ -802,7 +823,7 @@ impl CGen {
                 let f = self.get_u_op(&u_op.operator.node);
                 let i = self.gen_expr(&u_op.operand.node);
                 let one = cterm(CTermData::Int(true, 32, bv_lit(1, 32)));
-                let loc = self.gen_lval(&u_op.operand.node);
+                let loc = self.gen_lval(*u_op.operand.clone());
                 let val = f(i, one).unwrap();
                 self.gen_assign(loc, val)
             }
@@ -854,42 +875,49 @@ impl CGen {
             .get(&fname)
             .unwrap_or_else(|| panic!("No function '{}'", fname))
             .clone();
-        let ret_ty = f.ret_ty.clone();
-        // Typecheck parameters and arguments
-        let arg_sorts = args
-            .iter()
-            .map(|e| e.term.type_().sort())
-            .collect::<Vec<_>>();
-        let param_sorts = f.params.iter().map(|e| e.ty.sort()).collect::<Vec<_>>();
-        assert!(arg_sorts == param_sorts);
-
-        // Create return type
-        // All function return types are Tuples in order to handle
-        // references and pointers
-        let mut ret_sorts: Vec<Sort> = Vec::new();
-        ret_sorts.push(ret_ty.unwrap().sort());
-        for arg_sort in &arg_sorts {
-            if let Sort::Array(..) = arg_sort {
-                ret_sorts.push(arg_sort.clone());
-            }
-        }
-        let ret_sort = Sort::Tuple(ret_sorts.into());
+        let ret_sort = f.ret_ty.as_ref().map(|ty| ty.sort());
 
         // Create ordered list of arguments based on argument names
-        let metadata = self.circ_metadata();
-        let arg_names = metadata.ordered_input_names();
-        let mut args_map: FxHashMap<String, Term> = FxHashMap::default();
-        for (name, arg) in arg_names.iter().zip(args.iter()) {
-            args_map.insert(
-                name.to_string(),
-                arg.term.term(self.circ.borrow().cir_ctx()),
-            );
+        let mut args_map: FxHashMap<String, CTerm> = FxHashMap::default();
+        for (arg, param) in args.iter().zip(f.params.iter()) {
+            args_map.insert(param.name.clone(), arg.clone());
+        }
+
+        // Create ordered call term
+        let mut ordered_arg_names: Vec<String> = args_map.keys().cloned().collect();
+        ordered_arg_names.sort();
+        let ordered_args = ordered_arg_names
+            .iter()
+            .map(|name| {
+                args_map
+                    .get(name)
+                    .unwrap_or_else(|| panic!("Argument not found: {}", name))
+                    .clone()
+            })
+            .collect::<Vec<_>>();
+        let ordered_arg_terms = ordered_args
+            .iter()
+            .map(|arg| arg.term.term(self.circ.borrow().cir_ctx()).clone())
+            .collect::<Vec<_>>();
+        let ordered_arg_sorts =
+            ordered_arg_terms.iter().map(check).collect::<Vec<_>>();
+        let mut ret_sorts: Vec<Sort> = Vec::new();
+        if let Some(sort) = ret_sort {
+            ret_sorts.push(sort);
         }
+        for sort in &ordered_arg_sorts {
+            if let Sort::Array(..) = sort {
+                ret_sorts.push(sort.clone());
+            }
+        }
+        let ret_sort = Sort::Tuple(ret_sorts.clone().into());
 
-        let call_term = self
-            .circ_metadata()
-            .ordered_call_term(fname, args_map, ret_sort);
+        // Create call term
+        let call_term = term(
+            Op::Call(fname, ordered_arg_sorts.clone(), ret_sort),
+            ordered_arg_terms,
+        );
 
         // Add function to queue
         if !self.function_cache.contains(call_term.op()) {
@@ -898,15 +926,20 @@ impl CGen {
         }
 
         // Rewiring
+        let mut rewire_idx = match f.ret_ty {
+            None => 0,
+            Some(_) => 1,
+        };
-        for (i, arg_sort) in arg_sorts.iter().enumerate() {
-            if let Sort::Array(..) = arg_sort {
-                if let CTermData::Array(_, id) = args[i].term {
+        for (i, sort) in ordered_arg_sorts.iter().enumerate() {
+            if let Sort::Array(..) = sort {
+                if let CTermData::Array(_, id) = ordered_args[i].term {
                     self.circ_replace(
                         id.unwrap(),
-                        term![Op::Field(i); call_term.clone()],
+                        term![Op::Field(rewire_idx); call_term.clone()],
                     );
+                    rewire_idx += 1;
                 } else {
-                    unimplemented!("This should only be handling ptrs to arrays");
+                    panic!("This should only be handling ptrs to arrays");
                 }
             }
         }
@@ -1123,9 +1156,7 @@ impl CGen {
         match ret {
             Some(expr) => {
                 let ret = self.gen_expr(&expr.node);
-                let ret_ty = self.ret_ty_take();
-                let new_ret = cast(ret_ty, ret);
-                let ret_res = self.circ_return_(Some(new_ret));
+                let ret_res = self.circ_return_(Some(ret));
                 self.unwrap(ret_res);
             }
             None => {
@@ -1176,7 +1207,7 @@ impl CGen {
         self.circ_enter_fn(f.name.to_owned(), f.ret_ty.clone());
 
         for p in f.params.iter() {
-            let r = self.circ_declare_input(p.name.clone(), &p.ty, p.vis, None, false);
+            let r = self.circ_declare_input(p.name.clone(), &p.ty, p.vis, None, true);
             self.unwrap(r);
         }
 
@@ -1258,21 +1289,19 @@ impl CGen {
         // Keep track of the names of arguments that are references
        let mut ret_names: Vec<String> = Vec::new();
 
-        // define input parameters
+        // Define input parameters
         assert!(arg_sorts.len() == f.params.len());
-        for (i, param) in f.params.iter().enumerate() {
+        for (arg_sort, param) in arg_sorts.iter().zip(f.params.iter()) {
             let p_name = &param.name;
-            let p_sort = param.ty.sort();
-            assert!(p_sort == arg_sorts[i]);
             let p_ty = match &param.ty {
                 Ty::Ptr(_, t) => {
-                    if let Sort::Array(_, _, len) = p_sort {
-                        let dims = vec![len];
+                    if let Sort::Array(_, _, len) = arg_sort {
+                        let dims: Vec<usize> = vec![*len];
                         // Add reference
                         ret_names.push(p_name.clone());
-                        Ty::Array(len, dims, t.clone())
+                        Ty::Array(*len, dims, t.clone())
                     } else {
-                        panic!("Ptr type does not match with Array sort: {}", p_sort)
+                        panic!("Ptr type does not match with Array sort: {}", arg_sort)
                     }
                 }
                 _ => param.ty.clone(),
@@ -1286,7 +1315,7 @@ impl CGen {
         if let Some(returns) = self.circ_exit_fn_call(&ret_names) {
             let ret_terms = returns
                 .into_iter()
-                .flat_map(|x| x.unwrap_term().term.terms(self.circ.borrow().cir_ctx()))
+                .map(|x| x.unwrap_term().term.term(self.circ.borrow().cir_ctx()))
                 .collect::<Vec<_>>();
             let ret_term = term(Op::Tuple, ret_terms);
             assert!(check(&ret_term) == *rets);
diff --git a/src/front/c/term.rs b/src/front/c/term.rs
index e1904c93d..b2b095f39 100644
--- a/src/front/c/term.rs
+++ b/src/front/c/term.rs
@@ -30,35 +30,54 @@ impl CTermData {
     /// Get all IR terms inside this value, as a list.
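// [Editor's note] The rewritten `terms` below changes how aggregates are
// flattened: an Array-typed value is now pushed as one whole array term (via
// the new `MemManager::term`), and a stack pointer is materialized by folding
// `Op::Store` over a default-valued constant array. A sketch of that fold,
// with `elem_sort`/`len` standing in for the concrete element sort and length:
//
//     let mut base = term![Op::Const(Value::Array(Array::default(
//         Sort::BitVector(32), &elem_sort, len)))];
//     for i in 0..len {
//         let v = ctx.mem.borrow_mut().load(alloc_id, bv_lit(i, 32));
//         base = term![Op::Store; base, bv_lit(i, 32), v];
//     }
//     output.push(base);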
    pub fn terms(&self, ctx: &CirCtx) -> Vec<Term> {
         let mut output: Vec<Term> = Vec::new();
-        fn terms_tail(term_: &CTermData, output: &mut Vec<Term>, inner_ctx: &CirCtx) {
+        fn terms_tail(term_: &CTermData, output: &mut Vec<Term>, ctx: &CirCtx) {
             match term_ {
                 CTermData::Bool(t) => output.push(t.clone()),
                 CTermData::Int(_, _, t) => output.push(t.clone()),
                 CTermData::Array(t, a) => {
                     let alloc_id = a.unwrap_or_else(|| panic!("Unknown AllocID: {:#?}", a));
-                    if let Ty::Array(l, _, _) = t {
-                        for i in 0..*l {
-                            let offset = bv_lit(i, 32);
-                            let idx_term = inner_ctx.mem.borrow_mut().load(alloc_id, offset);
-                            output.push(idx_term);
-                        }
-                    }
+                    if let Ty::Array(_, _, _) = t {
+                        output.push(ctx.mem.borrow_mut().term(alloc_id));
+                    } else {
+                        panic!("CTermData::Array does not hold Array type");
+                    };
                 }
                 CTermData::StackPtr(t, _o, a) => {
                     let alloc_id = a.unwrap_or_else(|| panic!("Unknown AllocID: {:#?}", a));
-                    if let Ty::Array(l, _, _) = t {
-                        for i in 0..*l {
-                            let offset = bv_lit(i, 32);
-                            let idx_term = inner_ctx.mem.borrow_mut().load(alloc_id, offset);
-                            output.push(idx_term);
+                    match t {
+                        Ty::Array(l, _, ty) => {
+                            let mut base = term![Op::Const(Value::Array(Array::default(
+                                Sort::BitVector(32),
+                                &(*ty).sort(),
+                                *l
+                            )))];
+                            for i in 0..*l {
+                                let offset = bv_lit(i, 32);
+                                let idx_term = ctx.mem.borrow_mut().load(alloc_id, offset);
+                                base = term![Op::Store; base, bv_lit(i, 32), idx_term];
+                            }
+                            output.push(base);
                         }
-                    } else {
-                        panic!("Unsupported type for stack pointer: {:#?}", t);
+                        Ty::Int(_, s) => {
+                            let size = ctx.mem.borrow_mut().get_size(alloc_id);
+                            let mut base = term![Op::Const(Value::Array(Array::default(
+                                Sort::BitVector(32),
+                                &Sort::BitVector(*s),
+                                size
+                            )))];
+                            for i in 0..size {
+                                let offset = bv_lit(i, 32);
+                                let idx_term = ctx.mem.borrow_mut().load(alloc_id, offset);
+                                base = term![Op::Store; base, bv_lit(i, 32), idx_term];
+                            }
+                            output.push(base);
+                        }
+                        _ => panic!("Unsupported type for stack pointer: {:#?}", t),
                     }
                 }
                 CTermData::Struct(_, fs) => {
                     for (_name, ct) in fs.fields() {
-                        let mut ts = ct.term.terms(inner_ctx);
+                        let mut ts = ct.term.terms(ctx);
                         output.append(&mut ts);
                     }
                 }
@@ -609,15 +628,17 @@ impl Embeddable for Ct {
                 ),
                 udef: bool_lit(false),
             },
-            Ty::Array(n, _, ty) => {
+            Ty::Array(n, _, inner_ty) => {
                 assert!(precompute.is_none());
                 let v: Vec<Self::T> = (0..*n)
-                    .map(|i| self.declare_input(ctx, ty, idx_name(&name, i), visibility, None))
+                    .map(|i| {
+                        self.declare_input(ctx, inner_ty, idx_name(&name, i), visibility, None)
+                    })
                    .collect();
                 let mut mem = ctx.mem.borrow_mut();
-                let id = mem.zero_allocate(*n, 32, ty.num_bits());
+                let id = mem.zero_allocate(*n, 32, inner_ty.num_bits());
                 let arr = Self::T {
-                    term: CTermData::Array(*ty.clone(), Some(id)),
+                    term: CTermData::Array(ty.clone(), Some(id)),
                     udef: bool_lit(false),
                 };
                 for (i, t) in v.iter().enumerate() {
diff --git a/src/front/c/types.rs b/src/front/c/types.rs
index a806a31b9..aa9df07f4 100644
--- a/src/front/c/types.rs
+++ b/src/front/c/types.rs
@@ -88,7 +88,7 @@ impl Ty {
             Self::Struct(_name, fs) => {
                 Sort::Tuple(fs.fields().map(|(_f_name, f_ty)| f_ty.sort()).collect())
             }
-            Self::Ptr(_, _) => panic!("Ptrs don't have a CirC sort"),
+            Self::Ptr(_, _) => panic!("Cannot infer CirC sort for a Ptr"),
         }
     }
diff --git a/src/ir/term/mod.rs b/src/ir/term/mod.rs
index 7a37369b7..62fa73968 100644
--- a/src/ir/term/mod.rs
+++ b/src/ir/term/mod.rs
@@ -2002,34 +2002,6 @@ impl ComputationMetadata {
     pub fn remove_var(&mut self, name: &str) {
         self.vars.remove(name);
     }
-
-    /// Create a call term, given the input arguments in sorted order by 
argument names.
-    ///
-    /// ## Arguments
-    ///
-    /// * `name`: function name
-    /// * `args`: map of argument name (String) to argument term (Term)
-    /// * `ret_sort`: return sort of the function
-    ///
-    /// ## Returns
-    ///
-    /// A call term with the input arguments in sorted order by argument names.
-    ///
-    pub fn ordered_call_term(
-        &self,
-        name: String,
-        args: FxHashMap<String, Term>,
-        ret_sort: Sort,
-    ) -> Term {
-        let ordered_arg_names = self.ordered_input_names();
-        let ordered_args = ordered_arg_names
-            .iter()
-            .map(|name| args.get(name).expect("Argument not found: {}").clone())
-            .collect::<Vec<_>>();
-        let ordered_sorts = ordered_args.iter().map(check).collect::<Vec<_>>();
-
-        term(Op::Call(name, ordered_sorts, ret_sort), ordered_args)
-    }
 }
 
 /// A structured collection of variables that indicates the round structure: e.g., orderings,
@@ -2250,6 +2222,16 @@ impl Computation {
         terms.into_iter()
     }
 
+    /// Convert a Computation to a ComputationSubgraph
+    pub fn to_cs(&self) -> ComputationSubgraph {
+        let mut cs = ComputationSubgraph::new();
+        for t in self.terms_postorder() {
+            cs.insert_node(&t);
+        }
+        cs.insert_edges();
+        cs
+    }
+
     /// Evaluate the precompute, then this computation.
     pub fn eval_all(&self, values: &FxHashMap<String, Value>) -> Vec<Value> {
         let mut values = values.clone();
@@ -2298,6 +2280,73 @@ impl Computations {
     }
 }
 
+/// A graph representation of a Computation
+#[derive(Clone)]
+pub struct ComputationSubgraph {
+    /// List of terms in subgraph
+    pub nodes: TermSet,
+    /// Adjacency list of edges in subgraph
+    pub edges: TermMap<TermSet>,
+    /// Output leaf nodes
+    pub outs: TermSet,
+    /// Input leaf nodes
+    pub ins: TermSet,
+}
+
+impl Default for ComputationSubgraph {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ComputationSubgraph {
+    /// Default constructor
+    pub fn new() -> Self {
+        Self {
+            nodes: TermSet::default(),
+            edges: TermMap::default(),
+            outs: TermSet::default(),
+            ins: TermSet::default(),
+        }
+    }
+
+    /// Insert a node into the ComputationSubgraph
+    pub fn insert_node(&mut self, node: &Term) {
+        if !self.nodes.contains(node) {
+            self.nodes.insert(node.clone());
+        }
+    }
+
+    /// Insert edges based on the nodes in the subgraph
+    pub fn insert_edges(&mut self) {
+        let mut defs: FxHashSet<Term> = FxHashSet::default();
+        for t in self.nodes.iter() {
+            self.edges.insert(t.clone(), TermSet::default());
+            let mut flag = true;
+            for c in t.cs().iter() {
+                if self.nodes.contains(c) {
+                    self.edges.get_mut(t).unwrap().insert(c.clone());
+                    defs.insert(c.clone());
+                    flag = false;
+                }
+            }
+            if flag {
+                self.ins.insert(t.clone());
+            }
+        }
+
+        // Find the leaf nodes of the subgraph
+        // TODO: defs.difference(&_uses) ?
+        for t in self.nodes.iter() {
+            if !defs.contains(t) {
+                self.outs.insert(t.clone());
+            }
+        }
+        // println!("LOG: Input nodes of partition: {}", self.ins.len());
+        // println!("LOG: Output nodes of partition: {}", self.outs.len());
+    }
+}
+
 /// Compute a (deterministic) prime-field challenge.
 pub fn pf_challenge(name: &str, field: &FieldT) -> FieldV {
     use rand::SeedableRng;
diff --git a/src/lib.rs b/src/lib.rs
index 161fc29c4..a4dce16df 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,7 +3,7 @@
 //! A compiler infrastructure for compiling programs to circuits
 #![warn(missing_docs)]
-#![deny(warnings)]
+// #![deny(warnings)]
 
 #[macro_use]
 pub mod ir;
diff --git a/src/target/aby/assignment/def_uses.rs b/src/target/aby/assignment/def_uses.rs
new file mode 100644
index 000000000..0f5c1c53d
--- /dev/null
+++ b/src/target/aby/assignment/def_uses.rs
@@ -0,0 +1,873 @@
+use rug::Integer;
+
+use fxhash::FxHashMap;
+use fxhash::FxHashSet;
+
+use crate::ir::term::*;
+
+use std::cmp;
+use std::collections::HashMap;
+use std::time::Instant;
+
+/// A post-order iterator that skips the constant indices of select/store terms
+pub struct PostOrderIterV3 {
+    // (cs stacked, term)
+    stack: Vec<(bool, Term)>,
+    visited: TermSet,
+}
+
+impl PostOrderIterV3 {
+    /// Make an iterator over the descendants of `roots`.
+    pub fn new(roots: Vec<Term>) -> Self {
+        Self {
+            stack: roots.into_iter().map(|t| (false, t)).collect(),
+            visited: TermSet::default(),
+        }
+    }
+}
+
+impl std::iter::Iterator for PostOrderIterV3 {
+    type Item = Term;
+    fn next(&mut self) -> Option<Term> {
+        while let Some((children_pushed, t)) = self.stack.last() {
+            if self.visited.contains(t) {
+                self.stack.pop();
+            } else if !children_pushed {
+                if let Op::Select = t.op() {
+                    if let Op::Const(Value::BitVector(_)) = &t.cs()[1].op() {
+                        self.stack.last_mut().unwrap().0 = true;
+                        let last = self.stack.last().unwrap().1.clone();
+                        self.stack.push((false, last.cs()[0].clone()));
+                        continue;
+                    }
+                } else if let Op::Store = t.op() {
+                    if let Op::Const(Value::BitVector(_)) = &t.cs()[1].op() {
+                        self.stack.last_mut().unwrap().0 = true;
+                        let last = self.stack.last().unwrap().1.clone();
+                        self.stack.push((false, last.cs()[0].clone()));
+                        self.stack.push((false, last.cs()[2].clone()));
+                        continue;
+                    }
+                }
+                self.stack.last_mut().unwrap().0 = true;
+                let last = self.stack.last().unwrap().1.clone();
+                self.stack
+                    .extend(last.cs().iter().map(|c| (false, c.clone())));
+            } else {
+                break;
+            }
+        }
+        self.stack.pop().map(|(_, t)| {
+            self.visited.insert(t.clone());
+            t
+        })
+    }
+}
+
+fn get_sort_len(s: &Sort) -> usize {
+    let mut len = 0;
+    len += match s {
+        Sort::Bool => 1,
+        Sort::BitVector(_) => 1,
+        Sort::Array(_, _, n) => *n,
+        Sort::Tuple(sorts) => {
+            let mut inner_len = 0;
+            for inner_s in sorts.iter() {
+                inner_len += get_sort_len(inner_s);
+            }
+            inner_len
+        }
+        _ => panic!("Sort is not supported: {:#?}", s),
+    };
+    len
+}
+
+#[derive(Clone)]
+/// A def-use subgraph over the actual terms inside arrays and tuples
+pub struct DefUsesSubGraph {
+    /// List of terms in subgraph
+    pub nodes: TermSet,
+    /// Adjacency list of edges in subgraph
+    pub edges: TermMap<TermSet>,
+    /// Output leaf nodes
+    pub outs: TermSet,
+    /// Input leaf nodes
+    pub ins: TermSet,
+    /// Def-use pairs, for the ILP solver
+    pub def_use: FxHashSet<(Term, Term)>,
+    /// Map from each def to its uses, for the ILP solver
+    pub def_uses: FxHashMap<Term, Vec<Term>>,
+}
+
+impl DefUsesSubGraph {
+    /// Default constructor
+    pub fn new() -> Self {
+        Self {
+            nodes: TermSet::default(),
+            edges: TermMap::default(),
+            outs: TermSet::default(),
+            ins: TermSet::default(),
+            def_use: FxHashSet::default(),
+            def_uses: FxHashMap::default(),
+        }
+    }
+
+    /// Insert a node into the DefUsesSubGraph
+    pub fn insert_node(&mut self, node: &Term) {
+        if !self.nodes.contains(node) {
+            self.nodes.insert(node.clone());
+            self.def_uses.insert(node.clone(), Vec::new());
+        }
+    }
+
+    /// Insert edges based on the nodes in the subgraph
+    pub fn insert_edges(&mut self, dug: &DefUsesGraph) {
+        let mut defs: FxHashSet<Term> = FxHashSet::default();
+        for t in self.nodes.iter() {
+            self.edges.insert(t.clone(), TermSet::default());
+            let mut flag = true;
+            for c in dug.use_defs.get(t).unwrap().iter() {
+                if self.nodes.contains(c) {
+                    self.edges.get_mut(t).unwrap().insert(c.clone());
+                    self.def_use.insert((c.clone(), t.clone()));
+                    defs.insert(c.clone());
+                    flag = false;
+                }
+            }
+            if flag {
+                self.ins.insert(t.clone());
+            }
+        }
+
+        for t in self.nodes.iter() {
+            if !defs.contains(t) {
+                self.outs.insert(t.clone());
+            }
+        }
+
+        for (d, u) in self.def_use.iter() {
+            self.def_uses
+                .entry(d.clone())
+                .or_insert_with(Vec::new)
+                .push(u.clone());
+        }
+    }
+}
+
+/// Extend the given def-use subgraph outward by one level in `dug`
+pub fn extend_dusg(dusg: &DefUsesSubGraph, dug: &DefUsesGraph) -> DefUsesSubGraph {
+    let mut new_g: DefUsesSubGraph = DefUsesSubGraph::new();
+    new_g.nodes = dusg.nodes.clone();
+    for t in dusg.ins.iter() {
+        for d in dug.use_defs.get(t).unwrap().iter() {
+            new_g.insert_node(d);
+        }
+    }
+    for t in dusg.outs.iter() {
+        for u in dug.def_uses.get(t).unwrap().iter() {
+            new_g.insert_node(u);
+        }
+    }
+    new_g.insert_edges(dug);
+    new_g
+}
+
+#[derive(Clone)]
+/// Def-use graph for a computation
+pub struct DefUsesGraph {
+    /// All def-use edges as (def, use) pairs
+    pub def_use: FxHashSet<(Term, Term)>,
+    /// Map from a term to the set of terms that use it
+    pub def_uses: FxHashMap<Term, FxHashSet<Term>>,
+    /// Map from a term to the set of terms that define its inputs
+    pub use_defs: FxHashMap<Term, FxHashSet<Term>>,
+    /// The terms tracked by this graph
+    pub good_terms: TermSet,
+
+    const_terms: TermSet,
+    // call_args: TermMap<Vec<FxHashSet<usize>>>,
+    // call_rets: TermMap<Vec<FxHashSet<usize>>>,
+    call_args_terms: TermMap<Vec<Vec<Term>>>,
+    call_rets_terms: TermMap<Vec<Vec<Term>>>,
+    ret_good_terms: Vec<Term>,
+    self_ins: Vec<FxHashSet<Term>>,
+    self_outs: Vec<Vec<Term>>,
+    call_rets_to_term: HashMap<(Term, usize, usize, usize), Term>,
+    n_ref: TermMap<usize>,
+    is_main: bool,
+    var_used: TermMap<TermSet>,
+    depth_bool: usize,
+    depth_arith: usize,
+    num_bool: usize,
+    num_mul: usize,
+    num_calls: usize,
+}
+
+impl DefUsesGraph {
+    /// Build the def-use graph for a computation
+    pub fn new(c: &Computation) -> Self {
+        let now = Instant::now();
+        let mut dug = Self {
+            // term_to_terms_idx: TermMap::default(),
+            // term_to_terms: TermMap::default(),
+            def_use: FxHashSet::default(),
+            def_uses: FxHashMap::default(),
+            use_defs: FxHashMap::default(),
+            const_terms: TermSet::default(),
+            good_terms: TermSet::default(),
+            // call_args: TermMap::default(),
+            // call_rets: TermMap::default(),
+            call_args_terms: TermMap::default(),
+            call_rets_terms: TermMap::default(),
+            ret_good_terms: Vec::new(),
+            self_ins: Vec::new(),
+            self_outs: Vec::new(),
+            call_rets_to_term: HashMap::new(),
+            n_ref: TermMap::default(),
+            is_main: true,
+            var_used: TermMap::default(),
+            depth_bool: 0,
+            depth_arith: 0,
+            num_bool: 0,
+            num_mul: 0,
+            num_calls: 1,
+        };
+        println!("Entering Def Use Graph:");
+        dug.construct_def_use_general(c, false, &HashMap::new());
+        dug.construct_mapping();
+        println!("Time: Def Use Graph: {:?}", now.elapsed());
+        println!("DefUseGraph depth bool: {:?}", dug.depth_bool);
+        println!("DefUseGraph depth mul: {:?}", dug.depth_arith);
+        println!("DefUseGraph num_bool: {:?}", dug.num_bool);
+        println!("DefUseGraph num_mul: {:?}", dug.num_mul);
+        dug
+    }
+
+    /// Build the def-use graph for a function, using the def-use graphs of its callees
+    pub fn for_call_site(
+        c: &Computation,
+        dugs: &HashMap<String, DefUsesGraph>,
+        fname: &String,
+    ) -> Self {
+        let mut now = Instant::now();
+        let mut dug = Self {
+            // term_to_terms_idx: TermMap::default(),
+            // term_to_terms: TermMap::default(),
+            def_use: FxHashSet::default(),
+            def_uses: FxHashMap::default(),
+            use_defs: FxHashMap::default(),
+            const_terms: TermSet::default(),
+            good_terms: TermSet::default(),
+            // call_args: TermMap::default(),
+            // call_rets: TermMap::default(),
+            call_args_terms: TermMap::default(),
+            call_rets_terms: TermMap::default(),
+            ret_good_terms: Vec::new(),
+            self_ins: Vec::new(),
+            self_outs: Vec::new(),
+            call_rets_to_term: HashMap::new(),
+            n_ref: TermMap::default(),
+            is_main: fname 
== "main", + var_used: TermMap::default(), + depth_bool: 0, + depth_arith: 0, + num_bool: 0, + num_mul: 0, + num_calls: 1, + }; + dug.construct_def_use_general(c, true, dugs); + // moved this after insert context + println!("Time: Def Use Graph: {:?}", now.elapsed()); + now = Instant::now(); + dug.construct_mapping(); + println!("Time: Def Use Graph mapping: {:?}", now.elapsed()); + println!("DefUseGraph depth bool: {:?}", dug.depth_bool); + println!("DefUseGraph depth mul: {:?}", dug.depth_arith); + println!("DefUseGraph num_bool: {:?}", dug.num_bool); + println!("DefUseGraph num_mul: {:?}", dug.num_mul); + dug + } + + pub fn set_num_calls(&mut self, cnt: &usize) { + self.num_calls = *cnt; + } + pub fn get_k(&self) -> FxHashMap { + let mut k_map: FxHashMap = FxHashMap::default(); + let max_k: f64 = 1.0; + k_map.insert( + "a".to_string(), + max_k.min((self.depth_arith as f64) / ((self.num_mul * self.num_calls) as f64)), + ); + k_map.insert( + "b".to_string(), + max_k.min((self.depth_bool as f64) / ((self.num_bool * self.num_calls) as f64)), + ); + println!("num_calls: {}", self.num_calls); + println!("k_map: {:?}", k_map); + k_map + } + + // Cnt # of refs for each term + fn construct_n_ref(&mut self, c: &Computation) { + for t in PostOrderIterV3::new(c.outputs.clone()) { + for arg in t.cs().iter() { + *self.n_ref.entry(arg.clone()).or_insert(0) += 1; + } + } + for out in c.outputs.iter() { + *self.n_ref.entry(out.clone()).or_insert(0) += 1; + } + } + + fn get_and_de_ref( + &mut self, + term_to_terms: &mut TermMap>, + t: &Term, + ) -> Vec<(Term, usize, usize, usize)> { + let cnt = self.n_ref.get_mut(t).unwrap(); + *cnt -= 1; + if *cnt == 0 { + term_to_terms.remove(t).unwrap() + } else { + term_to_terms.get(t).unwrap().clone() + } + } + + fn construct_def_use_general( + &mut self, + c: &Computation, + css_flag: bool, + dugs: &HashMap, + ) { + self.construct_n_ref(c); + // A mapping from a term to real terms in the circuit + // An array term -> A vector of children terms + let mut term_to_terms: TermMap> = TermMap::default(); + for t in PostOrderIterV3::new(c.outputs().clone()) { + match &t.op() { + Op::Var(..) 
=> { + term_to_terms.insert(t.clone(), vec![(t.clone(), 0, 0, 0)]); + if self.is_main { + self.add_term(&t); + } + } + Op::Const(Value::BitVector(_)) => { + term_to_terms.insert(t.clone(), vec![(t.clone(), 0, 0, 0)]); + self.const_terms.insert(t.clone()); + self.add_term(&t); + } + Op::Const(Value::Tuple(tup)) => { + let mut terms: Vec<(Term, usize, usize, usize)> = Vec::new(); + for val in tup.iter() { + terms.push((leaf_term(Op::Const(val.clone())), 0, 0, 0)); + self.const_terms.insert(leaf_term(Op::Const(val.clone()))); + self.add_term(&leaf_term(Op::Const(val.clone()))); + } + term_to_terms.insert(t.clone(), terms); + } + Op::Tuple => { + let mut terms: Vec<(Term, usize, usize, usize)> = Vec::new(); + for c in t.cs().iter() { + terms.extend(self.get_and_de_ref(&mut term_to_terms, &c)); + } + term_to_terms.insert(t.clone(), terms); + } + Op::Field(i) => { + let tuple_terms = self.get_and_de_ref(&mut term_to_terms, &t.cs()[0]); + let tuple_sort = check(&t.cs()[0]); + let (offset, len) = match tuple_sort { + Sort::Tuple(t) => { + assert!(*i < t.len()); + let mut offset = 0; + for j in 0..*i { + offset += get_sort_len(&t[j]); + } + let len = get_sort_len(&t[*i]); + (offset, len) + } + _ => panic!("Field op on non-tuple"), + }; + // get ret slice + let field_terms = &tuple_terms[offset..offset + len]; + term_to_terms.insert(t.clone(), field_terms.to_vec()); + } + Op::Update(i) => { + let mut tuple_terms = self.get_and_de_ref(&mut term_to_terms, &t.cs()[0]); + let value_terms = self.get_and_de_ref(&mut term_to_terms, &t.cs()[1]); + tuple_terms[*i] = value_terms[0].clone(); + term_to_terms.insert(t.clone(), tuple_terms); + } + Op::Const(Value::Array(arr)) => { + let mut terms: Vec<(Term, usize, usize, usize)> = Vec::new(); + let sort = check(&t); + if let Sort::Array(_, _, n) = sort { + let n = n as i32; + for i in 0..n { + let idx = Value::BitVector(BitVector::new(Integer::from(i), 32)); + let v = match arr.map.get(&idx) { + Some(c) => c, + None => &*arr.default, + }; + terms.push((leaf_term(Op::Const(v.clone())), 0, 0, 0)); + self.const_terms.insert(leaf_term(Op::Const(v.clone()))); + self.add_term(&leaf_term(Op::Const(v.clone()))); + } + } else { + todo!("Const array sort not array????") + } + term_to_terms.insert(t.clone(), terms); + } + Op::Store => { + let mut array_terms = self.get_and_de_ref(&mut term_to_terms, &t.cs()[0]); + let value_terms = self.get_and_de_ref(&mut term_to_terms, &t.cs()[2]); + if let Op::Const(Value::BitVector(bv)) = &t.cs()[1].op() { + // constant indexing + let i = bv.uint().to_usize().unwrap().clone(); + // println!("Store the {} value on a {} size array.",idx , array_terms.len()); + array_terms[i] = value_terms[0].clone(); + term_to_terms.insert(t.clone(), array_terms); + } else { + self.get_and_de_ref(&mut term_to_terms, &t.cs()[1]); + for i in 0..array_terms.len() { + self.add_def_use(&array_terms[i].0, &t); + array_terms[i] = + (t.clone(), 0, array_terms[i].2 + 1, array_terms[i].3); + } + self.def_use.insert((value_terms[0].0.clone(), t.clone())); + term_to_terms.insert(t.clone(), array_terms); + } + } + Op::Select => { + let array_terms = self.get_and_de_ref(&mut term_to_terms, &t.cs()[0]); + if let Op::Const(Value::BitVector(bv)) = &t.cs()[1].op() { + // constant indexing + let i = bv.uint().to_usize().unwrap().clone(); + term_to_terms.insert(t.clone(), vec![array_terms[i].clone()]); + } else { + self.get_and_de_ref(&mut term_to_terms, &t.cs()[1]); + let mut depth_bool: usize = 0; + let mut depth_arith: usize = 0; + for idx in 0..array_terms.len() { + 
self.add_def_use(&array_terms[idx].0, &t);
+                            depth_bool = cmp::max(depth_bool, array_terms[idx].2);
+                            depth_arith = cmp::max(depth_arith, array_terms[idx].3);
+                        }
+                        term_to_terms
+                            .insert(t.clone(), vec![(t.clone(), 0, depth_bool + 1, depth_arith)]);
+                    }
+                }
+                Op::Call(callee, _, ret_sort) => {
+                    // Use the call term itself as the placeholder
+                    // Call terms will be ignored by the ILP solver later
+                    let mut ret_terms: Vec<(Term, usize, usize, usize)> = Vec::new();
+                    let mut num_rets: usize = 0;
+                    // ret_sort must be a tuple
+                    if let Sort::Tuple(sorts) = ret_sort {
+                        num_rets = sorts.iter().map(|ret| get_sort_len(ret)).sum();
+                    } else {
+                        panic!("call term return sort is not a tuple.")
+                    }
+                    // TODO: Add comment
+                    // let mut args: Vec<FxHashSet<usize>> = Vec::new();
+                    // let mut rets: Vec<FxHashSet<usize>> = Vec::new();
+                    let mut args_t: Vec<Vec<Term>> = Vec::new();
+                    let mut rets_t: Vec<Vec<Term>> = Vec::new();
+
+                    let mut depth_bool: usize = 0;
+                    let mut depth_arith: usize = 0;
+
+                    if css_flag {
+                        // css mode
+                        // insert callee's context into caller's context inside the call
+                        let context_args = dugs.get(callee).unwrap().self_ins.clone();
+                        let context_rets = dugs.get(callee).unwrap().self_outs.clone();
+
+                        // args -> call's in
+                        let mut arg_id = 0;
+                        for arg in t.cs().iter() {
+                            // Inlining callee's use
+                            let arg_terms = self.get_and_de_ref(&mut term_to_terms, arg);
+                            for tu in arg_terms.iter() {
+                                // Ignore terms ret from a call
+                                // No need to assign when a function call takes in another function's return as argument
+                                if !self.call_rets_to_term.contains_key(tu) {
+                                    let uses = context_args.get(arg_id).unwrap();
+                                    for u in uses.iter() {
+                                        self.add_def_use(&tu.0, &u);
+                                    }
+                                }
+                                depth_bool = cmp::max(depth_bool, tu.2);
+                                depth_arith = cmp::max(depth_arith, tu.3);
+                            }
+                            // TODO: Is this correct? Need to think carefully
+                            arg_id += 1;
+
+                            // Safe call site
+                            // let mut arg_set: FxHashSet<usize> = FxHashSet::default();
+                            let mut arg_vec: Vec<Term> = Vec::new();
+                            for aarg in arg_terms.iter() {
+                                // arg_set.insert(get_op_id(&aarg.0.op()));
+                                arg_vec.push(aarg.0.clone());
+                            }
+                            args_t.push(arg_vec);
+                            // args.push(arg_set);
+                        }
+
+                        let mut idx = 0;
+                        ret_terms = context_rets
+                            .into_iter()
+                            .flatten()
+                            .map(|ret| {
+                                // self.add_term(&ret);
+                                let tu = (ret, idx, depth_bool, depth_arith);
+                                idx += 1;
+                                self.call_rets_to_term.insert(tu.clone(), t.clone());
+                                // rets.push(FxHashSet::default());
+                                rets_t.push(Vec::new());
+                                tu
+                            })
+                            .collect();
+
+                    } else {
+                        // non-css mode
+                        for c in t.cs().iter() {
+                            let arg_terms = self.get_and_de_ref(&mut term_to_terms, c);
+                            // let mut arg_set: FxHashSet<usize> = FxHashSet::default();
+                            let mut arg_term: Vec<Term> = Vec::new();
+                            for arg in arg_terms.iter() {
+                                // arg_set.insert(get_op_id(&arg.0.op()));
+                                arg_term.push(arg.0.clone());
+                                depth_bool = cmp::max(depth_bool, arg.2);
+                                depth_arith = cmp::max(depth_arith, arg.3);
+                            }
+                            args_t.push(arg_term);
+                            // args.push(arg_set);
+                        }
+                        for idx in 0..num_rets {
+                            // rets.push(FxHashSet::default());
+                            ret_terms.push((t.clone(), idx, depth_bool, depth_arith));
+                            rets_t.push(Vec::new());
+                        }
+                    }
+                    assert_eq!(num_rets, ret_terms.len());
+
+                    term_to_terms.insert(t.clone(), ret_terms);
+                    // self.call_args.insert(t.clone(), args);
+                    // self.call_rets.insert(t.clone(), rets);
+                    self.call_args_terms.insert(t.clone(), args_t);
+                    self.call_rets_terms.insert(t.clone(), rets_t);
+                }
+                _ => {
+                    if matches!(t.op(), Op::Ite) && matches!(t.cs()[1].op(), Op::Store) {
+                        // assert_eq!(t.cs[2].op, Op::Store);
+                        let cond_terms = self.get_and_de_ref(&mut term_to_terms, &t.cs()[0]);
+                        assert_eq!(cond_terms.len(), 
1); + self.def_use.insert((cond_terms[0].0.clone(), t.clone())); + // true branch + let mut t_terms = self.get_and_de_ref(&mut term_to_terms, &t.cs()[1]); + // false branch + let f_terms = self.get_and_de_ref(&mut term_to_terms, &t.cs()[2]); + assert_eq!(t_terms.len(), f_terms.len()); + for i in 0..t_terms.len() { + self.add_def_use(&t_terms[i].0, &t); + self.add_def_use(&f_terms[i].0, &t); + t_terms[i] = (t.clone(), 0, t_terms[i].2 + 1, t_terms[i].3 + 1); + } + term_to_terms.insert(t.clone(), t_terms); + } else { + let mut depth_bool: usize = 0; + let mut depth_arith: usize = 0; + // TODO: add css mode here + + if css_flag { + for c in t.cs().iter() { + let terms = self.get_and_de_ref(&mut term_to_terms, c); + assert_eq!(terms.len(), 1); + if let Some(call_t) = self.call_rets_to_term.get(&terms[0]) { + // insert op to ret set + // let rets = self.call_rets.get_mut(&call_t).unwrap(); + // rets.get_mut(terms[0].1).unwrap().insert(get_op_id(&t.op())); + + // insert term to ret terms + let rets_t = self.call_rets_terms.get_mut(&call_t).unwrap(); + rets_t.get_mut(terms[0].1).unwrap().push(t.clone()); + + self.add_def_use(&terms[0].0, &t); + } else { + self.add_def_use(&terms[0].0, &t); + } + depth_bool = cmp::max(depth_bool, terms[0].2); + depth_arith = cmp::max(depth_arith, terms[0].3); + } + } else{ + for c in t.cs().iter() { + if let Op::Call(..) = c.op() { + continue; + } else { + let terms = self.get_and_de_ref(&mut term_to_terms, c); + assert_eq!(terms.len(), 1); + if let Op::Call(..) = terms[0].0.op() { + // insert op to ret set + // let rets = self.call_rets.get_mut(&terms[0].0).unwrap(); + // rets.get_mut(terms[0].1).unwrap().insert(get_op_id(&t.op())); + // insert term to ret terms + let rets_t = self.call_rets_terms.get_mut(&terms[0].0).unwrap(); + rets_t.get_mut(terms[0].1).unwrap().push(t.clone()); + } else { + self.def_use.insert((terms[0].0.clone(), t.clone())); + self.add_def_use(&terms[0].0, &t); + } + depth_bool = cmp::max(depth_bool, terms[0].2); + depth_arith = cmp::max(depth_arith, terms[0].3); + } + } + } + match &t.clone().op() { + Op::BvNaryOp(BvNaryOp::Mul) => { + term_to_terms.insert( + t.clone(), + vec![(t.clone(), 0, depth_bool + 1, depth_arith + 1)], + ); + } + _ => { + term_to_terms.insert( + t.clone(), + vec![(t.clone(), 0, depth_bool + 1, depth_arith)], + ); + } + } + } + self.add_term(&t); + } + } + } + + for out in c.outputs().iter() { + let out_terms = self.get_and_de_ref(&mut term_to_terms, out); + let mut out_v: Vec = Vec::new(); + for (t, _, depth_bool, depth_arith) in out_terms.iter() { + // v.push(t.clone()); + self.ret_good_terms.push(t.clone()); + out_v.push(t.clone()); + self.depth_bool = cmp::max(self.depth_bool, *depth_bool); + self.depth_arith = cmp::max(self.depth_arith, *depth_arith); + } + self.self_outs.push(out_v); + } + + // for (k, _) in self.term_to_terms.iter(){ + // println!("Left over ts: {} {}", k.op, self.n_ref.get(k).unwrap()); + // } + // todo!("TEsting") + + println!("Def Use Graph # of terms: {}", self.good_terms.len()); + println!("Def Use Graph # of edges: {}", self.def_use.len()); + } + + fn construct_mapping(&mut self) { + for (def, _use) in self.def_use.iter() { + if self.def_uses.contains_key(def) { + self.def_uses.get_mut(def).unwrap().insert(_use.clone()); + } else { + let mut uses: FxHashSet = FxHashSet::default(); + uses.insert(_use.clone()); + self.def_uses.insert(def.clone(), uses); + } + if self.use_defs.contains_key(_use) { + self.use_defs.get_mut(_use).unwrap().insert(def.clone()); + } else { + let mut defs: 
FxHashSet<Term> = FxHashSet::default();
+                defs.insert(def.clone());
+                self.use_defs.insert(_use.clone(), defs);
+            }
+        }
+    }
+
+    /// Record, for each ordered input of `c`, the set of terms that use it
+    pub fn gen_in_out(&mut self, c: &Computation) {
+        for n in c.metadata.ordered_input_names() {
+            // n is already an SSA name here
+            let s = c.metadata.input_sort(&n).clone();
+            let t = leaf_term(Op::Var(n.to_string(), s));
+            if let Some(uses) = self.def_uses.get(&t) {
+                self.self_ins.push(uses.clone());
+            } else {
+                // This argument is not being used at all!
+                self.self_ins.push(FxHashSet::default());
+            }
+        }
+    }
+
+    /// Output the call sites from this function's computation
+    pub fn get_call_site(
+        &mut self,
+    ) -> Vec<(Term, Vec<Vec<Term>>, Vec<Vec<Term>>)> {
+        let mut call_sites: Vec<(Term, Vec<Vec<Term>>, Vec<Vec<Term>>)> = Vec::new();
+
+        for (call_term, arg_terms) in self.call_args_terms.iter() {
+            let ret_terms = self.call_rets_terms.get(call_term).unwrap();
+            call_sites.push((call_term.clone(), arg_terms.clone(), ret_terms.clone()));
+        }
+        call_sites
+    }
+
+    /// Insert the caller's context
+    pub fn insert_context(
+        &mut self,
+        arg_values: &Vec<Vec<Term>>,
+        rets: &Vec<Vec<Term>>,
+        caller_dug: &DefUsesGraph,
+        callee: &Computation,
+        extra_level: usize,
+    ) {
+        let mut _input_set: TermSet = TermSet::default();
+        let mut _output_set: TermSet = TermSet::default();
+        todo!();
+        // insert def of args
+        // for (n, v) in arg_names.into_iter().zip(arg_values) {
+        //     let ssa_names = callee.metadata.input_ssa_name_from_nice_name(n);
+        //     for (sname, index) in ssa_names.iter() {
+        //         let s = callee.metadata.input_sort(&sname).clone();
+        //         // println!("Def: {}, Use: {}", v.get(*index).unwrap(), leaf_term(Op::Var(sname.clone(), s.clone())));
+        //         let def_t = v.get(*index).unwrap().clone();
+        //         let use_t = leaf_term(Op::Var(sname.to_string(), s));
+        //         if let Op::Call(..) = def_t.op {
+        //             continue;
+        //         }
+        //         if let Op::Var(..) = def_t.op {
+        //             continue;
+        //         }
+        //         if let Op::Const(_) = def_t.op {
+        //             continue;
+        //         }
+        //         // if !self.good_terms.contains(&use_t) {
+        //         //     // println!("FIX: {}", use_t.op);
+        //         //     // This is because the function doesn't use this arg
+        //         //     //todo!("Fix this...");
+        //         //     continue;
+        //         // }
+        //         if let Some(actual_used) = self.var_used.clone().get(&use_t) {
+        //             for actual_t in actual_used.iter() {
+        //                 self.add_term(&def_t);
+        //                 self.def_use.insert((def_t.clone(), actual_t.clone()));
+        //                 input_set.insert(def_t.clone());
+        //             }
+        //         }
+        //     }
+        // }
+
+        // // insert use of rets
+        // let outs = self.ret_good_terms.clone();
+
+        // assert_eq!(outs.len(), rets.len());
+        // for (d, uses) in outs.into_iter().zip(rets) {
+        //     for u in uses.iter() {
+        //         // TODO: so weird
+        //         self.add_term(&d);
+        //         self.add_term(u);
+        //         self.def_use.insert((d.clone(), u.clone()));
+        //     }
+        // }
+
+        // kind of mutation? 
+ // for i in 1..extra_level { + // // insert def of def + // for def in input_set.clone().iter() { + // let def_defs = caller_dug.def_uses.get(def).unwrap(); + // for def_def in def_defs.iter() { + // self.add_term(def_def); + // self.def_use.insert((def_def.clone(), def.clone())); + // input_set.insert(def_def.clone()); + // } + // } + + // // insert use of use + // for _use in output_set.clone().iter() { + // let use_uses = caller_dug.def_uses.get(_use).unwrap(); + // for use_use in use_uses.iter() { + // self.add_term(use_use); + // self.def_use.insert((_use.clone(), use_use.clone())); + // input_set.insert(use_use.clone()); + // } + // } + // } + + // self.construct_mapping(); + } + + fn add_term(&mut self, t: &Term) { + if !self.good_terms.contains(t) { + self.good_terms.insert(t.clone()); + let defs: FxHashSet = FxHashSet::default(); + let uses: FxHashSet = FxHashSet::default(); + self.def_uses.insert(t.clone(), uses); + self.use_defs.insert(t.clone(), defs); + match &t.op() { + Op::Ite => self.num_bool += 1, + Op::BvNaryOp(o) => { + match o { + BvNaryOp::Xor => { + self.num_bool += 0; + } + BvNaryOp::Or => { + self.num_bool += 1; + } + BvNaryOp::And => { + self.num_bool += 1; + } + BvNaryOp::Add => { + // self.num_bool += 1; + } + BvNaryOp::Mul => { + // self.num_bool += 1; + self.num_mul += 1; + } + } + } + Op::BvBinOp(o) => { + match o { + BvBinOp::Sub => { + // self.num_bool += 1; + } + BvBinOp::Udiv => { + self.num_bool += 1; + } + BvBinOp::Urem => { + self.num_bool += 1; + } + _ => {} + } + } + Op::Eq => { + self.num_bool += 1; + } + Op::BvBinPred(_) => { + self.num_bool += 1; + } + Op::Store | Op::Select => { + if let Sort::Array(_, _, length) = check(&t.cs()[0]) { + self.num_bool += length; + } + // self.num_bool += 1; + } + _ => {} + } + } + } + + fn add_def_use(&mut self, d: &Term, u: &Term) { + if self.is_main { + self.add_term(d); + self.add_term(u); + self.def_use.insert((d.clone(), u.clone())); + } else { + if let Op::Var(..) = d.op() { + self.add_term(u); + if self.var_used.contains_key(d) { + self.var_used.get_mut(d).unwrap().insert(u.clone()); + } else { + let mut var_used_set: TermSet = TermSet::default(); + var_used_set.insert(u.clone()); + self.var_used.insert(d.clone(), var_used_set); + } + return; + } else { + self.add_term(d); + self.add_term(u); + self.def_use.insert((d.clone(), u.clone())); + } + } + } +} diff --git a/src/target/aby/assignment/ilp.rs b/src/target/aby/assignment/ilp.rs index 0e81eb191..4b2cf302b 100644 --- a/src/target/aby/assignment/ilp.rs +++ b/src/target/aby/assignment/ilp.rs @@ -33,18 +33,24 @@ use fxhash::{FxHashMap, FxHashSet}; use super::{ShareType, SharingMap, SHARE_TYPES}; use crate::ir::term::*; -use crate::target::aby::assignment::CostModel; +use crate::target::aby::assignment::CostModel; use crate::target::ilp::{Expression, Ilp, Variable}; use good_lp::variable; +use crate::target::aby::assignment::def_uses::*; + +use std::collections::HashMap; use std::env::var; /// Uses an ILP to assign... 
-pub fn assign(c: &Computation, cm: &str) -> SharingMap { +pub fn assign(c: &ComputationSubgraph, cm: &str) -> SharingMap { let base_dir = match cm { "opa" => "opa", "hycc" => "hycc", + "empirical" => "empirical", + "empirical_wan" => "empirical_wan", + "synth" => "synthetic", _ => panic!("Unknown cost model type: {}", cm), }; let p = format!( @@ -52,19 +58,351 @@ pub fn assign(c: &Computation, cm: &str) -> SharingMap { var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR"), base_dir ); - let costs = CostModel::from_opa_cost_file(&p); + let costs = CostModel::from_opa_cost_file(&p, FxHashMap::default()); build_ilp(c, &costs) } -fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap { +/// Uses an ILP to assign and abandon the outer assignments +pub fn assign_mut(c: &ComputationSubgraph, cm: &str, co: &ComputationSubgraph) -> SharingMap { + let base_dir = match cm { + "opa" => "opa", + "hycc" => "hycc", + "empirical" => "empirical", + "empirical_wan" => "empirical_wan", + "synth" => "synthetic", + _ => panic!("Unknown cost model type: {}", cm), + }; + let p = format!( + "{}/third_party/{}/adapted_costs.json", + var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR"), + base_dir + ); + let costs = CostModel::from_opa_cost_file(&p, FxHashMap::default()); + let mut smap = TermMap::default(); + let mut cnt = 1; + while smap.len() == 0 { + // A hack for empty result during multi-threading + // Simply retry until get a non-empty result + if cnt > 5 { + panic!("MT BUG: Dead loop.") + } + smap = build_ilp(c, &costs); + cnt = cnt + 1; + } + let mut trunc_smap = TermMap::default(); + for node in co.nodes.clone() { + let share = smap.get_mut(&node).unwrap(); + trunc_smap.insert(node, *share); + } + trunc_smap +} + +/// Uses an ILP to assign and abandon the outer assignments +pub fn assign_mut_smart( + dusg: &DefUsesSubGraph, + cm: &str, + dusg_ref: &TermSet, + k_map: &FxHashMap, +) -> SharingMap { + let base_dir = match cm { + "opa" => "opa", + "hycc" => "hycc", + "empirical" => "empirical", + "empirical_wan" => "empirical_wan", + "synth" => "synthetic", + _ => panic!("Unknown cost model type: {}", cm), + }; + let p = format!( + "{}/third_party/{}/adapted_costs.json", + var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR"), + base_dir + ); + let costs = CostModel::from_opa_cost_file(&p, k_map.clone()); + let mut smap = TermMap::default(); + let mut cnt = 1; + while smap.len() == 0 { + // A hack for empty result during multi-threading + // Simply retry until get a non-empty result + if cnt > 5 { + panic!("MT BUG: Dead loop.") + } + smap = build_smart_ilp(dusg.nodes.clone(), &dusg.def_use, &costs); + cnt = cnt + 1; + } + let mut trunc_smap = TermMap::default(); + for node in dusg_ref.iter() { + let share = smap.get_mut(&node).unwrap(); + trunc_smap.insert(node.clone(), *share); + } + trunc_smap +} + +/// Uses an ILP to assign... 
+pub fn smart_global_assign( + terms: &TermSet, + def_uses: &FxHashSet<(Term, Term)>, + k_map: &FxHashMap, + cm: &str, +) -> SharingMap { + let base_dir = match cm { + "opa" => "opa", + "hycc" => "hycc", + "empirical" => "empirical", + "empirical_wan" => "empirical_wan", + "synth" => "synthetic", + _ => panic!("Unknown cost model type: {}", cm), + }; + let p = format!( + "{}/third_party/{}/adapted_costs.json", + var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR"), + base_dir + ); + let costs = CostModel::from_opa_cost_file(&p, k_map.clone()); + build_smart_ilp(terms.clone(), def_uses, &costs) +} + +fn build_smart_ilp( + term_set: TermSet, + def_uses: &FxHashSet<(Term, Term)>, + costs: &CostModel, +) -> SharingMap { + let terms: FxHashMap = term_set + .into_iter() + .enumerate() + .map(|(i, t)| (t, i)) + .collect(); + let mut term_vars: FxHashMap<(Term, ShareType), (Variable, f64, String)> = FxHashMap::default(); + let mut conv_vars: FxHashMap<(Term, ShareType, ShareType), (Variable, f64)> = + FxHashMap::default(); + let mut ilp = Ilp::new(); + + // build variables for all term assignments + for (t, i) in terms.iter() { + let mut vars = vec![]; + // println!("op: {}",&t.op); + match &t.op() { + Op::Var(..) | Op::Const(_) => { + for ty in &SHARE_TYPES { + let name = format!("t_{}_{}", i, ty.char()); + // println!("name: {:?}, op: {:?}", name, t.op); + let v = ilp.new_variable(variable().binary(), name.clone()); + if *ty == ShareType::Arithmetic { + term_vars.insert((t.clone(), *ty), (v, 0.1, name)); + } else if *ty == ShareType::Boolean { + term_vars.insert((t.clone(), *ty), (v, 0.12, name)); + } else { + term_vars.insert((t.clone(), *ty), (v, 0.11, name)); + } + // term_vars.insert((t.clone(), *ty), (v, 0.0, name)); + vars.push(v); + } + } + Op::Select | Op::Store => { + if let Sort::Array(_, _, length) = check(&t.cs()[0]) { + if let Some(costs) = costs.ops.get(&t.op().to_string()) { + for (ty, cost) in costs { + let name = format!("t_{}_{}", i, ty.char()); + // println!("name: {:?}, op: {:?}", name, t.op); + let v = ilp.new_variable(variable().binary(), name.clone()); + term_vars.insert((t.clone(), *ty), (v, *cost * (length as f64), name)); + vars.push(v); + } + } else { + panic!("No cost for op {}", &t.op()) + } + } else { + panic!("Not array sort {}", &t.cs()[1].op()) + } + } + // fix the select and store here for array size + _ => { + if let Some(costs) = costs.ops.get(&t.op().to_string()) { + for (ty, cost) in costs { + let name = format!("t_{}_{}", i, ty.char()); + // println!("name: {:?}, op: {:?}", name, t.op); + let v = ilp.new_variable(variable().binary(), name.clone()); + term_vars.insert((t.clone(), *ty), (v, *cost, name)); + vars.push(v); + } + } else { + panic!("No cost for op {}", &t.op()) + } + } + } + // Sum of assignments is at least 1. 
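+ // Each term must take at least one sharing; because the objective below maximizes the negated total cost, the solver never assigns more sharings than necessary.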
+ ilp.new_constraint( + vars.into_iter() + .fold((0.0).into(), |acc: Expression, v| acc + v) + >> 1.0, + ); + } + + // build variables for all conversions assignments + for (def, use_) in def_uses { + // println!("def op: {}", def.op); + // println!("use op: {}", use_.op); + let def_i = terms.get(def).unwrap(); + for from_ty in &SHARE_TYPES { + for to_ty in &SHARE_TYPES { + // if def can be from_ty, and use can be to_ty + if term_vars.contains_key(&(def.clone(), *from_ty)) + && term_vars.contains_key(&(use_.clone(), *to_ty)) + && from_ty != to_ty + { + let v = ilp.new_variable( + variable().binary(), + format!("c_{}_{}2{}", def_i, from_ty.char(), to_ty.char()), + ); + conv_vars.insert( + (def.clone(), *from_ty, *to_ty), + (v, *costs.conversions.get(&(*from_ty, *to_ty)).unwrap()), + ); + } + } + } + } + + let def_uses_map: FxHashMap> = { + let mut t = FxHashMap::default(); + for (d, u) in def_uses { + t.entry(d.clone()).or_insert_with(Vec::new).push(u.clone()); + } + t + }; + + for (def, uses) in def_uses_map.iter() { + for use_ in uses { + for from_ty in &SHARE_TYPES { + for to_ty in &SHARE_TYPES { + let ilp_version = true; + if ilp_version { + conv_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| { + term_vars.get(&(def.clone(), *from_ty)).map(|t_from| { + // c[term i from pi to pi'] >= t[term j with pi'] + t[term i with pi] - 1 + term_vars.get(&(use_.clone(), *to_ty)).map(|t_to| { + ilp.new_constraint(c.0 >> (t_from.0 + t_to.0 - 1.0)) + }) + }) + }); + } else { + // hardcoding here + // a2b > y2b + // y2a > b2a + // a2y > b2y + if *from_ty == ShareType::Arithmetic && *to_ty == ShareType::Boolean { + let cheap_ty = ShareType::Yao; + conv_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| { + term_vars.get(&(def.clone(), *from_ty)).map(|d_from| { + term_vars.get(&(def.clone(), *to_ty)).map(|d_to| { + term_vars.get(&(def.clone(), cheap_ty)).map(|d_ch| { + term_vars.get(&(use_.clone(), *to_ty)).map(|u_to| { + ilp.new_constraint( + c.0 >> (d_from.0 + u_to.0 + - 1.0 + - d_to.0 + - d_ch.0), + ) + }) + }) + }) + }) + }); + } else if *from_ty == ShareType::Yao && *to_ty == ShareType::Arithmetic { + let cheap_ty = ShareType::Boolean; + conv_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| { + term_vars.get(&(def.clone(), *from_ty)).map(|d_from| { + term_vars.get(&(def.clone(), *to_ty)).map(|d_to| { + term_vars.get(&(def.clone(), cheap_ty)).map(|d_ch| { + term_vars.get(&(use_.clone(), *to_ty)).map(|u_to| { + ilp.new_constraint( + c.0 >> (d_from.0 + u_to.0 + - 1.0 + - d_to.0 + - d_ch.0), + ) + }) + }) + }) + }) + }); + } else if *from_ty == ShareType::Arithmetic && *to_ty == ShareType::Yao { + let cheap_ty = ShareType::Boolean; + conv_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| { + term_vars.get(&(def.clone(), *from_ty)).map(|d_from| { + term_vars.get(&(def.clone(), *to_ty)).map(|d_to| { + term_vars.get(&(def.clone(), cheap_ty)).map(|d_ch| { + term_vars.get(&(use_.clone(), *to_ty)).map(|u_to| { + ilp.new_constraint( + c.0 >> (d_from.0 + u_to.0 + - 1.0 + - d_to.0 + - d_ch.0), + ) + }) + }) + }) + }) + }); + } else { + conv_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| { + term_vars.get(&(def.clone(), *from_ty)).map(|d_from| { + term_vars.get(&(def.clone(), *to_ty)).map(|d_to| { + term_vars.get(&(use_.clone(), *to_ty)).map(|u_to| { + ilp.new_constraint( + c.0 >> (d_from.0 + u_to.0 - 1.0 - d_to.0), + ) + }) + }) + }) + }); + } + } + } + } + } + } + + ilp.maximize( + -conv_vars + .values() + .map(|(a, b)| (a, b)) + .chain(term_vars.values().map(|(a, b, _)| (a, b))) + .fold(0.0.into(), 
|acc: Expression, (v, cost)| acc + *v * *cost), + ); + + let (_opt, solution) = ilp.default_solve().unwrap(); + + // println!("{:?}", solution); + + let mut assignment = TermMap::default(); + for ((term, ty), (_, _, var_name)) in &term_vars { + if solution.get(var_name).unwrap() == &1.0 { + // if assignment.contains_key(term){ + // println!("Duplicate assignment found: {}!!!", term.op); + // } + // if *ty == ShareType::Boolean{ + // println!("Boolean assignment found: {}!!!", term.op); + // if let Some(uses) = def_uses_map.get(term){ + // for t in uses.iter(){ + // println!("uses: {}!!!", t.op); + // } + // } + // } + assignment.insert(term.clone(), *ty); + } else if solution.get(var_name).unwrap() == &0.5 { + println!("Half op: {:?}", term.op()); + } + } + assignment +} + +fn build_ilp(c: &ComputationSubgraph, costs: &CostModel) -> SharingMap { let mut terms: TermSet = TermSet::default(); let mut def_uses: FxHashSet<(Term, Term)> = FxHashSet::default(); - for o in &c.outputs { - for t in PostOrderIter::new(o.clone()) { - terms.insert(t.clone()); - for c in t.cs() { - def_uses.insert((c.clone(), t.clone())); - } + for (node, edges) in c.edges.clone() { + terms.insert(node.clone()); + for c in edges.iter() { + def_uses.insert((c.clone(), node.clone())); } } let terms: FxHashMap<Term, usize> = @@ -77,13 +415,9 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap { // build variables for all term assignments for (t, i) in terms.iter() { let mut vars = vec![]; + // println!("op: {}",&t.op); match &t.op() { - Op::Var(..) - | Op::Const(_) - | Op::Call(..) - | Op::Field(_) - | Op::Update(..) - | Op::Tuple => { + Op::Var(..) | Op::Const(_) => { for ty in &SHARE_TYPES { let name = format!("t_{}_{}", i, ty.char()); let v = ilp.new_variable(variable().binary(), name.clone()); @@ -91,11 +425,9 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap { vars.push(v); } } - Op::Select | Op::Store => { - panic!("Requires def-use-graph, tests should not have secret indices.") - } _ => { - if let Some(costs) = costs.ops.get(t.op()) { + println!("op: {}", t.op()); + if let Some(costs) = costs.ops.get(&t.op().to_string()) { for (ty, cost) in costs { let name = format!("t_{}_{}", i, ty.char()); let v = ilp.new_variable(variable().binary(), name.clone()); @@ -182,8 +514,800 @@ fn build_ilp(c: &Computation, costs: &CostModel) -> SharingMap { assignment } +/// Uses an ILP to find an optimal combination of mutation assignments +pub fn comb_selection( + mut_maps: &HashMap<usize, HashMap<usize, SharingMap>>, + cs_map: &HashMap, + cm: &str, +) -> HashMap<usize, SharingMap> { + let base_dir = match cm { + "opa" => "opa", + "hycc" => "hycc", + "empirical" => "empirical", + "empirical_wan" => "empirical_wan", + "synth" => "synthetic", + _ => panic!("Unknown cost model type: {}", cm), + }; + let p = format!( + "{}/third_party/{}/adapted_costs.json", + var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR"), + base_dir + ); + let costs = CostModel::from_opa_cost_file(&p, FxHashMap::default()); + build_comb_ilp(mut_maps, cs_map, &costs) +} + +/** + * Combination algo: + * + * Notations: + * - P ^ i: partitions + * - l^i_j: assignment j for P^i + * - C^i_j: inner cost (node and inner edges) of P^i with l^i_j + * - X^i_j: Assign l^i_j to P^i + * - K_{p, p'}: conversion cost from p to p' + * - v^i_{k,p}: (edge) vertex k in partition P^i with assignment p + * - B^i_{k, p}: Set of indices j such that l^i_j assigns p to k + * - E_{u, w, p, p'}: cross partition edges that assign u, w with p, p' respectively + * + * Constraints: + * - [ for any i, \sum_j
X^i_j >= 1 ]: Each partition must have an assignment + * - [ for any i,k, \sum_p v^i_{k, p} >= 1 ]: Each node must have an assignment + * - [ E_{u, w, p, p'} >= v{u, p} + v{w, p'} - 1 ]: Edge is added if u and w are assigned p and p' resp. + * + * Objective: + * - min[\sum_{i, j} C^i_j X^i_j + \sum E{u,w,p,p'}K_{p, p'}] + * + */ + +fn build_comb_ilp( + mut_maps: &HashMap<usize, HashMap<usize, SharingMap>>, + cs_map: &HashMap, + costs: &CostModel, +) -> HashMap<usize, SharingMap> { + // global vars + let mut ilp = Ilp::new(); + + let mut x_vars: FxHashMap<(usize, usize), (Variable, f64, String)> = FxHashMap::default(); + let mut v_vars: FxHashMap<(Term, ShareType), (Variable, f64, String)> = FxHashMap::default(); + let mut e_vars: FxHashMap<(Term, ShareType, ShareType), (Variable, f64)> = FxHashMap::default(); + + let mut b_set: FxHashMap<(usize, Term, ShareType), Vec<Variable>> = FxHashMap::default(); + + // build variables for selection in each partition X^i_j + for (pid, smaps) in mut_maps.iter() { + let mut vars = vec![]; + for (mid, maps) in smaps.iter() { + let name = format!("X_{}_{}", pid, mid); + let v = ilp.new_variable(variable().binary(), name.clone()); + // TODO: update this buggy function + let map_cost = calculate_cost(&maps, costs); + x_vars.insert((*pid, *mid), (v, map_cost, name)); + vars.push(v); + } + // Sum of assignment selection is at least 1 + ilp.new_constraint( + vars.into_iter() + .fold((0.0).into(), |acc: Expression, v| acc + v) + >> 1.0, + ); + } + + // build b set + // build variables for each edge node v^i_{k,p} + let mut edge_terms: FxHashSet<(usize, usize, Term)> = FxHashSet::default(); + let mut def_uses: FxHashSet<(Term, Term, usize, usize)> = FxHashSet::default(); + for (pid, cs) in cs_map.iter() { + let mut index: usize = 1; + for t in cs.ins.clone() { + edge_terms.insert((*pid, index, t.clone())); + index = index + 1; + // get all cross partition edges: + for outer_node in t.cs().iter() { + def_uses.insert((outer_node.clone(), t.clone(), *pid, index)); + } + } + for t in cs.outs.clone() { + edge_terms.insert((*pid, index, t.clone())); + index = index + 1; + } + } + + for (pid, i, t) in edge_terms.iter() { + let mut vars = vec![]; + match &t.op() { + Op::Var(..) | Op::Const(_) => { + for ty in &SHARE_TYPES { + let name = format!("t_{}_{}_{}", pid, i, ty.char()); + let v = ilp.new_variable(variable().binary(), name.clone()); + v_vars.insert((t.clone(), *ty), (v, 0.0, name)); + vars.push(v); + // TODO: add constraints for B here + let mut x_vec = vec![]; + for (mid, maps) in mut_maps.get(pid).unwrap().iter() { + // buggy? + let a_ty = maps.get(t).unwrap(); + if ty == a_ty { + if !b_set.contains_key(&(*pid, t.clone(), *ty)) { + b_set.insert((*pid, t.clone(), *ty), Vec::new()); + } + b_set + .get_mut(&(*pid, t.clone(), *ty)) + .unwrap() + .push(x_vars.get(&(*pid, *mid)).unwrap().0); + x_vec.push(x_vars.get(&(*pid, *mid)).unwrap().0); + } + } + + ilp.new_constraint( + x_vec + .into_iter() + .fold((0.0).into(), |acc: Expression, v| acc + v) + >> v, + ); + } + } + _ => { + if let Some(costs) = costs.ops.get(&t.op().to_string()) { + for (ty, cost) in costs { + let name = format!("t_{}_{}_{}", pid, i, ty.char()); + let v = ilp.new_variable(variable().binary(), name.clone()); + v_vars.insert((t.clone(), *ty), (v, *cost, name)); + vars.push(v); + // TODO: add constraints for B here + let mut x_vec = vec![]; + for (mid, maps) in mut_maps.get(pid).unwrap().iter() { + // buggy?
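+ // Coverage constraint (B^i_{k,p}): x_vec collects the X variables of candidate assignments that give `t` share `ty`; the `sum(x_vec) >= v` constraint below lets v^i_{k,p} be 1 only if one of them is selected.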
+ let a_ty = maps.get(t).unwrap(); + if ty == a_ty { + if !b_set.contains_key(&(*pid, t.clone(), *ty)) { + b_set.insert((*pid, t.clone(), *ty), Vec::new()); + } + b_set + .get_mut(&(*pid, t.clone(), *ty)) + .unwrap() + .push(x_vars.get(&(*pid, *mid)).unwrap().0); + x_vec.push(x_vars.get(&(*pid, *mid)).unwrap().0); + } + } + ilp.new_constraint( + x_vec + .into_iter() + .fold((0.0).into(), |acc: Expression, v| acc + v) + >> v, + ); + } + } else { + panic!("No cost for op {}", &t.op()) + } + } + } + // Sum of assignments is at least 1. + ilp.new_constraint( + vars.into_iter() + .fold((0.0).into(), |acc: Expression, v| acc + v) + >> 1.0, + ); + } + + // build variables for conversion + for (def, use_, pid, idx) in &def_uses { + for from_ty in &SHARE_TYPES { + for to_ty in &SHARE_TYPES { + // if def can be from_ty, and use can be to_ty + if v_vars.contains_key(&(def.clone(), *from_ty)) + && v_vars.contains_key(&(use_.clone(), *to_ty)) + && from_ty != to_ty + { + let v = ilp.new_variable( + variable().binary(), + format!("c_{}_{}_{}2{}", pid, idx, from_ty.char(), to_ty.char()), + ); + e_vars.insert( + (def.clone(), *from_ty, *to_ty), + (v, *costs.conversions.get(&(*from_ty, *to_ty)).unwrap()), + ); + } + } + } + } + + let def_uses: FxHashMap> = { + let mut t = FxHashMap::default(); + for (d, u, _, _) in def_uses { + t.entry(d).or_insert_with(Vec::new).push(u); + } + t + }; + + for (def, uses) in def_uses { + for use_ in uses { + for from_ty in &SHARE_TYPES { + for to_ty in &SHARE_TYPES { + e_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| { + v_vars.get(&(def.clone(), *from_ty)).map(|t_from| { + // c[term i from pi to pi'] >= t[term j with pi'] + t[term i with pi] - 1 + v_vars + .get(&(use_.clone(), *to_ty)) + .map(|t_to| ilp.new_constraint(c.0 >> (t_from.0 + t_to.0 - 1.0))) + }) + }); + } + } + } + } + + ilp.maximize( + -e_vars + .values() + .map(|(a, b)| (a, b)) + .chain(x_vars.values().map(|(a, b, _)| (a, b))) + .fold(0.0.into(), |acc: Expression, (v, cost)| acc + *v * *cost), + ); + + let (_opt, solution) = ilp.default_solve().unwrap(); + + let mut local_assignments: HashMap = HashMap::new(); + + for (pid, smaps) in mut_maps.iter() { + for (mid, _maps) in smaps.iter() { + let name = format!("X_{}_{}", pid, mid); + if solution.get(&name).unwrap() == &1.0 { + let map = mut_maps.get(pid).unwrap().get(mid).unwrap().clone(); + local_assignments.insert(*pid, map); + } + } + } + + local_assignments +} + +/// Use a ILP to find a optimal combination of mutation assignments +pub fn comb_selection_smart( + dug: &DefUsesGraph, + mut_maps: &HashMap>, + dusg_map: &HashMap, + cm: &str, +) -> HashMap { + let base_dir = match cm { + "opa" => "opa", + "hycc" => "hycc", + "empirical" => "empirical", + "empirical_wan" => "empirical_wan", + "synth" => "synthetic", + _ => panic!("Unknown cost model type: {}", cm), + }; + let p = format!( + "{}/third_party/{}/adapted_costs.json", + var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR"), + base_dir + ); + let costs = CostModel::from_opa_cost_file(&p, dug.get_k()); + build_comb_ilp_smart(mut_maps, dug, dusg_map, &costs) +} + +/** + * Combination algo: + * + * Notations: + * - P ^ i: partitions + * - l^i_j: assignment j for P^i + * - C^i_j: inner cost (node and inner edges) of P^i with l^i_j + * - X^i_j: Assign l^i_j to P^i + * - K_{p, p'}: conversion cost from p to p' + * - v^i_{k,p}: (edge) vertex k in partition P^i with assignment p + * - B^i_{k, p}: Set of indices j such that l^i_j assign p to k + * - E_{u, w, p, p'}: cross partition 
edges that assign u, w with p, p' respectively + * + * Constraints: + * - [ for any i, \sum_j X^i_j >= 1 ]: Each partition must have an assignment + * - [ for any i,k, \sum_p v^i_{k, p} >= 1 ]: Each node must have an assignment + * - [ E_{u, w, p, p'} >= v{u, p} + v{w, p'} - 1 ]: Edge is added if u and w are assigned p and p' resp. + * + * Objective: + * - min[\sum_{i, j} C^i_j X^i_j + \sum E{u,w,p,p'}K_{p, p'}] + * + */ + +fn build_comb_ilp_smart( + mut_maps: &HashMap<usize, HashMap<usize, SharingMap>>, + dug: &DefUsesGraph, + dusg_map: &HashMap<usize, DefUsesSubGraph>, + costs: &CostModel, +) -> HashMap<usize, SharingMap> { + // global vars + let mut ilp = Ilp::new(); + + let mut x_vars: FxHashMap<(usize, usize), (Variable, f64, String)> = FxHashMap::default(); + let mut v_vars: FxHashMap<(Term, ShareType), (Variable, f64, String)> = FxHashMap::default(); + let mut e_vars: FxHashMap<(Term, ShareType, ShareType), (Variable, f64)> = FxHashMap::default(); + + let mut b_set: FxHashMap<(usize, Term, ShareType), Vec<Variable>> = FxHashMap::default(); + + // build variables for selection in each partition X^i_j + for (pid, smaps) in mut_maps.iter() { + let mut vars = vec![]; + let du = dusg_map.get(pid).unwrap(); + for (mid, maps) in smaps.iter() { + let name = format!("X_{}_{}", pid, mid); + let v = ilp.new_variable(variable().binary(), name.clone()); + // TODO: update this buggy function + let map_cost = calculate_cost_smart(&maps, costs, du); + x_vars.insert((*pid, *mid), (v, map_cost, name)); + vars.push(v); + } + // Sum of assignment selection is at least 1 + ilp.new_constraint( + vars.into_iter() + .fold((0.0).into(), |acc: Expression, v| acc + v) + >> 1.0, + ); + } + + // + // build b set + // build variables for each edge node v^i_{k,p} + let mut edge_terms: FxHashSet<(usize, usize, Term)> = FxHashSet::default(); + let mut def_uses: FxHashSet<(Term, Term, usize, usize)> = FxHashSet::default(); + for (pid, du) in dusg_map.iter() { + let mut index: usize = 1; + for t in du.ins.clone() { + edge_terms.insert((*pid, index, t.clone())); + index = index + 1; + // get all cross partition edges: + for outer_node in dug.use_defs.get(&t).unwrap().iter() { + def_uses.insert((outer_node.clone(), t.clone(), *pid, index)); + } + } + for t in du.outs.clone() { + edge_terms.insert((*pid, index, t.clone())); + index = index + 1; + } + } + + for (pid, i, t) in edge_terms.iter() { + let mut vars = vec![]; + match &t.op() { + Op::Var(..) | Op::Const(_) => { + for ty in &SHARE_TYPES { + let name = format!("t_{}_{}_{}", pid, i, ty.char()); + let v = ilp.new_variable(variable().binary(), name.clone()); + v_vars.insert((t.clone(), *ty), (v, 0.0, name)); + vars.push(v); + // TODO: add constraints for B here + let mut x_vec = vec![]; + for (mid, maps) in mut_maps.get(pid).unwrap().iter() { + // buggy?
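+ // Note: the `unwrap` below assumes every candidate SharingMap covers boundary term `t`.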
+ let a_ty = maps.get(t).unwrap(); + if ty == a_ty { + if !b_set.contains_key(&(*pid, t.clone(), *ty)) { + b_set.insert((*pid, t.clone(), *ty), Vec::new()); + } + b_set + .get_mut(&(*pid, t.clone(), *ty)) + .unwrap() + .push(x_vars.get(&(*pid, *mid)).unwrap().0); + x_vec.push(x_vars.get(&(*pid, *mid)).unwrap().0); + } + } + + ilp.new_constraint( + x_vec + .into_iter() + .fold((0.0).into(), |acc: Expression, v| acc + v) + >> v, + ); + } + } + Op::Select | Op::Store => { + if let Sort::Array(_, _, length) = check(&t.cs()[0]) { + if let Some(costs) = costs.ops.get(&t.op().to_string()) { + for (ty, cost) in costs { + let name = format!("t_{}_{}_{}", pid, i, ty.char()); + let v = ilp.new_variable(variable().binary(), name.clone()); + v_vars.insert((t.clone(), *ty), (v, *cost * (length as f64), name)); + vars.push(v); + // TODO: add constraints for B here + let mut x_vec = vec![]; + for (mid, maps) in mut_maps.get(pid).unwrap().iter() { + // buggy? + let a_ty = maps.get(t).unwrap(); + if ty == a_ty { + if !b_set.contains_key(&(*pid, t.clone(), *ty)) { + b_set.insert((*pid, t.clone(), *ty), Vec::new()); + } + b_set + .get_mut(&(*pid, t.clone(), *ty)) + .unwrap() + .push(x_vars.get(&(*pid, *mid)).unwrap().0); + x_vec.push(x_vars.get(&(*pid, *mid)).unwrap().0); + } + } + ilp.new_constraint( + x_vec + .into_iter() + .fold((0.0).into(), |acc: Expression, v| acc + v) + >> v, + ); + } + } else { + panic!("No cost for op {}", &t.op()) + } + } else { + panic!("Not array sort {}", &t.cs()[1].op()) + } + } + _ => { + if let Some(costs) = costs.ops.get(&t.op().to_string()) { + for (ty, cost) in costs { + let name = format!("t_{}_{}_{}", pid, i, ty.char()); + let v = ilp.new_variable(variable().binary(), name.clone()); + v_vars.insert((t.clone(), *ty), (v, *cost, name)); + vars.push(v); + // TODO: add constraints for B here + let mut x_vec = vec![]; + for (mid, maps) in mut_maps.get(pid).unwrap().iter() { + // buggy? + let a_ty = maps.get(t).unwrap(); + if ty == a_ty { + if !b_set.contains_key(&(*pid, t.clone(), *ty)) { + b_set.insert((*pid, t.clone(), *ty), Vec::new()); + } + b_set + .get_mut(&(*pid, t.clone(), *ty)) + .unwrap() + .push(x_vars.get(&(*pid, *mid)).unwrap().0); + x_vec.push(x_vars.get(&(*pid, *mid)).unwrap().0); + } + } + ilp.new_constraint( + x_vec + .into_iter() + .fold((0.0).into(), |acc: Expression, v| acc + v) + >> v, + ); + } + } else { + panic!("No cost for op {}", &t.op()) + } + } + } + // Sum of assignments is at least 1. 
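+ // As in `build_comb_ilp`: every boundary term must take at least one sharing.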
+ ilp.new_constraint( + vars.into_iter() + .fold((0.0).into(), |acc: Expression, v| acc + v) + >> 1.0, + ); + } + + // build variables for conversion + for (def, use_, pid, idx) in &def_uses { + for from_ty in &SHARE_TYPES { + for to_ty in &SHARE_TYPES { + // if def can be from_ty, and use can be to_ty + if v_vars.contains_key(&(def.clone(), *from_ty)) + && v_vars.contains_key(&(use_.clone(), *to_ty)) + && from_ty != to_ty + { + let v = ilp.new_variable( + variable().binary(), + format!("c_{}_{}_{}2{}", pid, idx, from_ty.char(), to_ty.char()), + ); + e_vars.insert( + (def.clone(), *from_ty, *to_ty), + (v, *costs.conversions.get(&(*from_ty, *to_ty)).unwrap()), + ); + } + } + } + } + + let def_uses: FxHashMap> = { + let mut t = FxHashMap::default(); + for (d, u, _, _) in def_uses { + t.entry(d).or_insert_with(Vec::new).push(u); + } + t + }; + + for (def, uses) in def_uses { + for use_ in uses { + for from_ty in &SHARE_TYPES { + for to_ty in &SHARE_TYPES { + let ilp_version = true; + if ilp_version { + e_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| { + v_vars.get(&(def.clone(), *from_ty)).map(|t_from| { + // c[term i from pi to pi'] >= t[term j with pi'] + t[term i with pi] - 1 + v_vars.get(&(use_.clone(), *to_ty)).map(|t_to| { + ilp.new_constraint(c.0 >> (t_from.0 + t_to.0 - 1.0)) + }) + }) + }); + } else { + // hardcoding here + // a2b > y2b + // y2a > b2a + // a2y > b2y + if *from_ty == ShareType::Arithmetic && *to_ty == ShareType::Boolean { + let cheap_ty = ShareType::Yao; + e_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| { + v_vars.get(&(def.clone(), *from_ty)).map(|d_from| { + v_vars.get(&(def.clone(), *to_ty)).map(|d_to| { + v_vars.get(&(def.clone(), cheap_ty)).map(|d_ch| { + v_vars.get(&(use_.clone(), *to_ty)).map(|u_to| { + ilp.new_constraint( + c.0 >> (d_from.0 + u_to.0 + - 1.0 + - d_to.0 + - d_ch.0), + ) + }) + }) + }) + }) + }); + } else if *from_ty == ShareType::Yao && *to_ty == ShareType::Arithmetic { + let cheap_ty = ShareType::Boolean; + e_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| { + v_vars.get(&(def.clone(), *from_ty)).map(|d_from| { + v_vars.get(&(def.clone(), *to_ty)).map(|d_to| { + v_vars.get(&(def.clone(), cheap_ty)).map(|d_ch| { + v_vars.get(&(use_.clone(), *to_ty)).map(|u_to| { + ilp.new_constraint( + c.0 >> (d_from.0 + u_to.0 + - 1.0 + - d_to.0 + - d_ch.0), + ) + }) + }) + }) + }) + }); + } else if *from_ty == ShareType::Arithmetic && *to_ty == ShareType::Yao { + let cheap_ty = ShareType::Boolean; + e_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| { + v_vars.get(&(def.clone(), *from_ty)).map(|d_from| { + v_vars.get(&(def.clone(), *to_ty)).map(|d_to| { + v_vars.get(&(def.clone(), cheap_ty)).map(|d_ch| { + v_vars.get(&(use_.clone(), *to_ty)).map(|u_to| { + ilp.new_constraint( + c.0 >> (d_from.0 + u_to.0 + - 1.0 + - d_to.0 + - d_ch.0), + ) + }) + }) + }) + }) + }); + } else { + e_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| { + v_vars.get(&(def.clone(), *from_ty)).map(|d_from| { + v_vars.get(&(def.clone(), *to_ty)).map(|d_to| { + v_vars.get(&(use_.clone(), *to_ty)).map(|u_to| { + ilp.new_constraint( + c.0 >> (d_from.0 + u_to.0 - 1.0 - d_to.0), + ) + }) + }) + }) + }); + } + } + } + } + } + } + + ilp.maximize( + -e_vars + .values() + .map(|(a, b)| (a, b)) + .chain(x_vars.values().map(|(a, b, _)| (a, b))) + .fold(0.0.into(), |acc: Expression, (v, cost)| acc + *v * *cost), + ); + + let (_opt, solution) = ilp.default_solve().unwrap(); + + let mut local_assignments: HashMap = HashMap::new(); + + for (pid, smaps) in mut_maps.iter() { + for (mid, 
_maps) in smaps.iter() { + let name = format!("X_{}_{}", pid, mid); + if solution.get(&name).unwrap() == &1.0 { + let map = mut_maps.get(pid).unwrap().get(mid).unwrap().clone(); + local_assignments.insert(*pid, map); + } + } + } + + local_assignments +} + +/// Calculate the cost of a global assignment +pub fn calculate_cost_smart(smap: &SharingMap, costs: &CostModel, dusg: &DefUsesSubGraph) -> f64 { + let mut cost: f64 = 0.0; + let mut conv_cost: HashMap<(Term, ShareType), f64> = HashMap::new(); + for (t, to_ty) in smap { + match &t.op() { + Op::Var(..) + | Op::Const(_) + | Op::BvConcat + | Op::BvExtract(..) + | Op::BoolToBv + | Op::BvBit(_) => { + cost = cost + 0.0; + } + Op::Select | Op::Store => { + if let Sort::Array(_, _, length) = check(&t.cs()[0]) { + cost = cost + + (length as f64) + * costs + .ops + .get(&t.op().to_string()) + .unwrap() + .get(to_ty) + .unwrap(); + } else { + panic!("Not array sort {}", &t.cs()[1].op()) + } + } + _ => { + // println!("op: {}", t.op); + cost = cost + + costs + .ops + .get(&t.op().to_string()) + .unwrap() + .get(to_ty) + .unwrap(); + } + } + for arg_t in dusg.def_uses.get(t).unwrap().iter() { + if smap.contains_key(&arg_t) { + let from_ty = smap.get(&arg_t).unwrap(); + if from_ty != to_ty { + // todo fix the calculation heres + let c = costs.conversions.get(&(*to_ty, *from_ty)).unwrap(); + conv_cost.insert((t.clone(), *from_ty), *c); + } + } + } + } + cost = cost + conv_cost.values().fold(0.0, |acc, &x| acc + x); + cost +} + +/// Calculate the cost of a global assignment +pub fn calculate_cost_smart_dug(smap: &SharingMap, cm: &str, dug: &DefUsesGraph) -> f64 { + let base_dir = match cm { + "opa" => "opa", + "hycc" => "hycc", + "empirical" => "empirical", + "empirical_wan" => "empirical_wan", + "synth" => "synthetic", + _ => panic!("Unknown cost model type: {}", cm), + }; + let p = format!( + "{}/third_party/{}/adapted_costs.json", + var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR"), + base_dir + ); + let costs = CostModel::from_opa_cost_file(&p, dug.get_k()); + let mut cost: f64 = 0.0; + let mut conv_cost: HashMap<(Term, ShareType), f64> = HashMap::new(); + for (t, to_ty) in smap { + match &t.op() { + Op::Var(..) + | Op::Const(_) + // | Op::BvConcat + // | Op::BvExtract(..) 
+ // | Op::BoolToBv + // | Op::BvBit(_) + => { + cost = cost + 0.0; + let mut new_to_ty = to_ty.clone(); + if *to_ty == ShareType::Yao { + new_to_ty = ShareType::Boolean; + } + for arg_t in dug.def_uses.get(t).unwrap().iter() { + if smap.contains_key(&arg_t) { + let from_ty = smap.get(&arg_t).unwrap(); + if *from_ty != new_to_ty { + // todo fix the calculation heres + // println!("conversion from {:?} to {:?}", *to_ty, *from_ty); + // println!("def: {:?} use {:?}", t.op, arg_t.op); + let c = costs.conversions.get(&(new_to_ty, *from_ty)).unwrap(); + conv_cost.insert((t.clone(), *from_ty), *c); + } + } + } + } + Op::Select | Op::Store => { + if let Sort::Array(_, _, length) = check(&t.cs()[0]){ + cost = cost + (length as f64) * costs.ops.get(&t.op().to_string()).unwrap().get(to_ty).unwrap(); + for arg_t in dug.def_uses.get(t).unwrap().iter() { + if smap.contains_key(&arg_t) { + let from_ty = smap.get(&arg_t).unwrap(); + if from_ty != to_ty { + // todo fix the calculation heres + // println!("conversion from {:?} to {:?}", *to_ty, *from_ty); + // println!("def: {:?} use {:?}", t.op, arg_t.op); + let c = costs.conversions.get(&(*to_ty, *from_ty)).unwrap(); + conv_cost.insert((t.clone(), *from_ty), *c); + } + } + } + } else{ + panic!("Not array sort {}", &t.cs()[1].op()) + } + } + _ => { + // println!("op: {}", t.op); + cost = cost + costs.ops.get(&t.op().to_string()).unwrap().get(to_ty).unwrap(); + for arg_t in dug.def_uses.get(t).unwrap().iter() { + if smap.contains_key(&arg_t) { + let from_ty = smap.get(&arg_t).unwrap(); + if from_ty != to_ty { + // todo fix the calculation heres + // println!("conversion from {:?} to {:?}", *to_ty, *from_ty); + // println!("def: {:?} use {:?}", t.op, arg_t.op); + let c = costs.conversions.get(&(*to_ty, *from_ty)).unwrap(); + conv_cost.insert((t.clone(), *from_ty), *c); + } + } + } + } + } + } + cost = cost + conv_cost.values().fold(0.0, |acc, &x| acc + x); + cost +} + +/// Calculate the cost of a global assignment +pub fn calculate_cost(smap: &SharingMap, costs: &CostModel) -> f64 { + let mut cost: f64 = 0.0; + let mut conv_cost: HashMap<(Term, ShareType), f64> = HashMap::new(); + for (t, to_ty) in smap { + match &t.op() { + Op::Var(..) + | Op::Const(_) + | Op::BvConcat + | Op::BvExtract(..) + | Op::BoolToBv + | Op::BvBit(_) => { + cost = cost + 0.0; + } + _ => { + // println!("op: {}", t.op); + cost = cost + + costs + .ops + .get(&t.op().to_string()) + .unwrap() + .get(to_ty) + .unwrap(); + } + } + for arg_t in t.cs().iter() { + if smap.contains_key(&arg_t) { + let from_ty = smap.get(&arg_t).unwrap(); + if from_ty != to_ty { + let c = costs.conversions.get(&(*to_ty, *from_ty)).unwrap(); + conv_cost.insert((arg_t.clone(), *to_ty), *c); + } + } + } + } + cost = cost + conv_cost.values().fold(0.0, |acc, &x| acc + x); + cost +} + #[cfg(test)] -mod tests { +mod test { use super::*; #[test] @@ -192,16 +1316,20 @@ mod tests { "{}/third_party/opa/adapted_costs.json", var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR") ); - let c = CostModel::from_opa_cost_file(&p); + let c = CostModel::from_opa_cost_file(&p, FxHashMap::default()); // random checks from the file... 
assert_eq!( &1127.0, - c.ops.get(&BV_MUL).unwrap().get(&ShareType::Yao).unwrap() + c.ops + .get(&BV_MUL.to_string()) + .unwrap() + .get(&ShareType::Yao) + .unwrap() ); assert_eq!( &1731.0, c.ops - .get(&BV_MUL) + .get(&BV_MUL.to_string()) .unwrap() .get(&ShareType::Boolean) .unwrap() @@ -209,7 +1337,7 @@ mod tests { assert_eq!( &7.0, c.ops - .get(&BV_XOR) + .get(&BV_XOR.to_string()) .unwrap() .get(&ShareType::Boolean) .unwrap() @@ -222,15 +1350,18 @@ mod tests { "{}/third_party/opa/adapted_costs.json", var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR") ); - let costs = CostModel::from_opa_cost_file(&p); + let costs = CostModel::from_opa_cost_file(&p, FxHashMap::default()); let cs = Computation { outputs: vec![term![BV_MUL; leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))), leaf_term(Op::Var("b".to_owned(), Sort::BitVector(32))) ]], - ..Default::default() + metadata: ComputationMetadata::default(), + precomputes: Default::default(), + persistent_arrays: Default::default(), }; - let _assignment = build_ilp(&cs, &costs); + let cg = cs.to_cs(); + let _assignment = build_ilp(&cg, &costs); } #[test] @@ -239,7 +1370,7 @@ mod tests { "{}/third_party/opa/adapted_costs.json", var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR") ); - let costs = CostModel::from_opa_cost_file(&p); + let costs = CostModel::from_opa_cost_file(&p, FxHashMap::default()); let cs = Computation { outputs: vec![term![Op::Eq; term![BV_MUL; @@ -266,9 +1397,12 @@ mod tests { ], leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))) ]], - ..Default::default() + metadata: ComputationMetadata::default(), + precomputes: Default::default(), + persistent_arrays: Default::default(), }; - let assignment = build_ilp(&cs, &costs); + let cg = cs.to_cs(); + let assignment = build_ilp(&cg, &costs); // Big enough to do the math with arith assert_eq!( &ShareType::Arithmetic, @@ -284,7 +1418,7 @@ mod tests { "{}/third_party/opa/adapted_costs.json", var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR") ); - let costs = CostModel::from_opa_cost_file(&p); + let costs = CostModel::from_opa_cost_file(&p, FxHashMap::default()); let cs = Computation { outputs: vec![term![Op::Eq; term![BV_MUL; @@ -302,9 +1436,12 @@ mod tests { ], leaf_term(Op::Var("a".to_owned(), Sort::BitVector(32))) ]], - ..Default::default() + metadata: ComputationMetadata::default(), + precomputes: Default::default(), + persistent_arrays: Default::default(), }; - let assignment = build_ilp(&cs, &costs); + let cg = cs.to_cs(); + let assignment = build_ilp(&cg, &costs); // All yao assert_eq!( &ShareType::Yao, diff --git a/src/target/aby/assignment/ilp_opa.rs b/src/target/aby/assignment/ilp_opa.rs new file mode 100644 index 000000000..f9a4edb33 --- /dev/null +++ b/src/target/aby/assignment/ilp_opa.rs @@ -0,0 +1,175 @@ +//! ILP-based sharing assignment +//! +//! Based on ["Efficient MPC via Program Analysis: A Framework for Efficient Optimal +//! Mixing"](https://dl.acm.org/doi/pdf/10.1145/3319535.3339818) by Ishaq, Muhammad and Milanova, +//! Ana L. and Zikas, Vassilis. + +use fxhash::{FxHashMap, FxHashSet}; + +use super::{ShareType, SharingMap}; +use crate::ir::term::*; + +use crate::target::aby::assignment::CostModel; +use crate::target::ilp::{Expression, Ilp, Variable}; +use good_lp::variable; + +use std::env::var; + +/// Uses an ILP to assign... 
+pub fn opa_smart_global_assign( + terms: &TermSet, + def_uses: &FxHashSet<(Term, Term)>, + k_map: &FxHashMap, + cm: &str, + share_types: &[ShareType; 2], +) -> SharingMap { + let base_dir = match cm { + "opa" => "opa", + "hycc" => "hycc", + "empirical" => "empirical", + "empirical_wan" => "empirical_wan", + "synth" => "synthetic", + _ => panic!("Unknown cost model type: {}", cm), + }; + let p = format!( + "{}/third_party/{}/adapted_costs.json", + var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR"), + base_dir + ); + let costs = CostModel::from_opa_cost_file(&p, k_map.clone()); + build_smart_ilp(terms.clone(), def_uses, &costs, share_types) +} + +fn build_smart_ilp( + term_set: TermSet, + def_uses: &FxHashSet<(Term, Term)>, + costs: &CostModel, + share_types: &[ShareType; 2], +) -> SharingMap { + let terms: FxHashMap = term_set + .into_iter() + .enumerate() + .map(|(i, t)| (t, i)) + .collect(); + let mut term_vars: FxHashMap<(Term, ShareType), (Variable, f64, String)> = FxHashMap::default(); + let mut conv_vars: FxHashMap<(Term, ShareType, ShareType), (Variable, f64)> = + FxHashMap::default(); + let mut ilp = Ilp::new(); + + // build variables for all term assignments + for (t, i) in terms.iter() { + let mut vars = vec![]; + // println!("op: {}",&t.op); + match &t.op() { + Op::Var(..) | Op::Const(_) => { + for ty in share_types { + let name = format!("t_{}_{}", i, ty.char()); + let v = ilp.new_variable(variable().min(0), name.clone()); + term_vars.insert((t.clone(), *ty), (v, 0.0, name)); + vars.push(v); + } + } + // fix the select and store here for array size + _ => { + if let Some(costs) = costs.ops.get(&t.op().to_string()) { + for ty in share_types { + if let Some(cost) = costs.get(ty) { + let name = format!("t_{}_{}", i, ty.char()); + let v = ilp.new_variable(variable().min(0), name.clone()); + term_vars.insert((t.clone(), *ty), (v, *cost, name)); + vars.push(v); + } + } + } else { + panic!("No cost for op {}", &t.op()) + } + } + } + // Sum of assignments is at least 1. 
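+ // Unlike the binary variables in `ilp.rs`, these are continuous (`min(0)`); with nonnegative costs, the at-least-one constraint below still drives each term toward a single cheapest sharing.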
+ ilp.new_constraint( + vars.into_iter() + .fold((0.0).into(), |acc: Expression, v| acc + v) + >> 1.0, + ); + } + + // build variables for all conversions assignments + for (def, use_) in def_uses { + let def_i = terms.get(def).unwrap(); + for from_ty in share_types { + for to_ty in share_types { + // if def can be from_ty, and use can be to_ty + if term_vars.contains_key(&(def.clone(), *from_ty)) + && term_vars.contains_key(&(use_.clone(), *to_ty)) + && from_ty != to_ty + { + let v = ilp.new_variable( + variable().min(0).max(1), + format!("c_{}_{}2{}", def_i, from_ty.char(), to_ty.char()), + ); + // println!("t: {:?} {:?}: {:?} cost: {:?}", def, v, format!("c_{}_{}2{}", def_i, from_ty.char(), to_ty.char()), *costs.conversions.get(&(*from_ty, *to_ty)).unwrap()); + conv_vars.insert( + (def.clone(), *from_ty, *to_ty), + (v, *costs.conversions.get(&(*from_ty, *to_ty)).unwrap()), + ); + } + } + } + } + + let def_uses_map: FxHashMap> = { + let mut t = FxHashMap::default(); + for (d, u) in def_uses { + t.entry(d.clone()).or_insert_with(Vec::new).push(u.clone()); + } + t + }; + + for (def, uses) in def_uses_map.iter() { + for use_ in uses { + for from_ty in share_types { + for to_ty in share_types { + // OPA formulation: + // conv_from_2_to >= def_from - use_from + // This is not correct + // conv_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| { + // term_vars.get(&(def.clone(), *from_ty)).map(|t_from| { + // term_vars.get(&(use_.clone(), *from_ty)).map(|t_to| { + // println!("{:?} >> {:?} - {:?}", c.0, def, use_); + // println!("{:?} >> {:?} - {:?}", c.0, t_from.2, t_to.2); + // ilp.new_constraint(c.0 >> (t_from.0 - t_to.0)) + // }) + // }) + // }); + conv_vars.get(&(def.clone(), *from_ty, *to_ty)).map(|c| { + term_vars.get(&(def.clone(), *from_ty)).map(|t_from| { + // c[term i from pi to pi'] >= t[term j with pi'] + t[term i with pi] - 1 + term_vars + .get(&(use_.clone(), *to_ty)) + .map(|t_to| ilp.new_constraint(c.0 >> (t_from.0 + t_to.0 - 1.0))) + }) + }); + } + } + } + } + + ilp.maximize( + -conv_vars + .values() + .map(|(a, b)| (a, b)) + .chain(term_vars.values().map(|(a, b, _)| (a, b))) + .fold(0.0.into(), |acc: Expression, (v, cost)| acc + *v * *cost), + ); + + let (_opt, solution) = ilp.default_solve().unwrap(); + + let mut assignment = TermMap::default(); + for ((term, ty), (_, _, var_name)) in &term_vars { + if solution.get(var_name).unwrap() == &1.0 { + assignment.insert(term.clone(), *ty); + } + } + // println!("{:?}", solution); + assignment +} diff --git a/src/target/aby/assignment/mod.rs b/src/target/aby/assignment/mod.rs index 109f146ec..3e276f504 100644 --- a/src/target/aby/assignment/mod.rs +++ b/src/target/aby/assignment/mod.rs @@ -4,8 +4,10 @@ use fxhash::FxHashMap; use serde_json::Value; use std::{env::var, fs::File, path::Path}; +pub mod def_uses; #[cfg(feature = "lp")] pub mod ilp; +pub mod ilp_opa; /// The sharing scheme used for an operation #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] @@ -16,6 +18,8 @@ pub enum ShareType { Boolean, /// Yao sharing (one party holds `k_a`, `k_b`, other knows the `{k_a, k_b} <-> {0, 1}` mapping) Yao, + /// None type, reserved for terms without circuit representation + None, } /// List of share types. 
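For orientation, a minimal sketch of how the reworked cost model is queried after this change (illustrative only; `p` is a path to an `adapted_costs.json`, and the lookups mirror the tests elsewhere in this diff; operator costs are now keyed by the op's string form, and the constructor takes a map of per-share depth factors, empty here):

    let costs = CostModel::from_opa_cost_file(&p, FxHashMap::default());
    // Per-op, per-share circuit cost, e.g. a bit-vector multiply under Yao sharing.
    let mul_yao = costs.ops.get(&BV_MUL.to_string()).unwrap()[&ShareType::Yao];
    // Pairwise share-conversion cost, e.g. Arithmetic -> Yao.
    let a2y = costs.conversions[&(ShareType::Arithmetic, ShareType::Yao)];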
@@ -28,6 +32,7 @@ impl ShareType { ShareType::Arithmetic => 'a', ShareType::Boolean => 'b', ShareType::Yao => 'y', + ShareType::None => 'n', } } } @@ -40,15 +45,34 @@ pub type SharingMap = TermMap; pub struct CostModel { #[allow(dead_code)] /// Conversion costs: maps (from, to) pairs to cost - conversions: FxHashMap<(ShareType, ShareType), f64>, + pub conversions: FxHashMap<(ShareType, ShareType), f64>, /// Operator costs: maps (op, type) to cost - ops: FxHashMap>, + pub ops: FxHashMap>, + + /// Zero costs + pub zero: FxHashMap, } impl CostModel { + /// Cost model constructor + pub fn new( + conversions: FxHashMap<(ShareType, ShareType), f64>, + ops: FxHashMap>, + ) -> CostModel { + let mut zero: FxHashMap = FxHashMap::default(); + zero.insert(ShareType::Arithmetic, 0.0); + zero.insert(ShareType::Boolean, 0.0); + zero.insert(ShareType::Yao, 0.0); + CostModel { + conversions, + ops, + zero, + } + } + /// Create a cost model from an OPA json file, like [this](https://github.com/ishaq/OPA/blob/d613c15ff715fa62c03e37b673548f94c16bfe0d/solver/sample-costs.json) - pub fn from_opa_cost_file(p: &impl AsRef) -> CostModel { + pub fn from_opa_cost_file(p: &impl AsRef, k: FxHashMap) -> CostModel { use ShareType::*; let get_cost_opt = |share_name: &str, obj: &serde_json::map::Map| -> Option { @@ -72,6 +96,18 @@ impl CostModel { ) .unwrap() }; + let get_depth = + |share_type: &str, obj: &serde_json::map::Map| -> Option { + let o = obj + .get("depth") + .unwrap_or_else(|| panic!("Missing {} in {:#?}", "depth", obj)); + Some( + o.get(share_type) + .unwrap_or_else(|| panic!("Missing {} entry in {:#?}", share_type, o)) + .as_f64() + .expect("not a number"), + ) + }; let mut conversions = FxHashMap::default(); let mut ops = FxHashMap::default(); let f = File::open(p).expect("Missing file"); @@ -85,55 +121,32 @@ impl CostModel { conversions.insert((Yao, Arithmetic), get_cost("y2a", costs)); conversions.insert((Arithmetic, Yao), get_cost("a2y", costs)); - let ops_from_name = |name: &str| { - match name { - // assume comparisions are unsigned - "ge" => vec![BV_UGE], - "le" => vec![BV_ULE], - "gt" => vec![BV_UGT], - "lt" => vec![BV_ULT], - // assume n-ary ops apply to BVs - "add" => vec![BV_ADD], - "mul" => vec![BV_MUL], - "and" => vec![BV_AND], - "or" => vec![BV_OR], - "xor" => vec![BV_XOR], - // assume eq applies to BVs - "eq" => vec![Op::Eq], - "shl" => vec![BV_SHL], - // assume shr is logical, not arithmetic - "shr" => vec![BV_LSHR], - "sub" => vec![BV_SUB], - "mux" => vec![ITE], - "ne" => vec![Op::Not, Op::Eq], - "div" => vec![BV_UDIV], - "rem" => vec![BV_UREM], - // added to pass test case - "&&" => vec![AND], - "||" => vec![OR], - _ => panic!("Unknown operator name: {}", name), - } - }; for (op_name, cost) in costs { // HACK: assumes the presence of 2 partitions names into conversion and otherwise. 
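+ // i.e. cost entries whose names contain '2' ("a2y", "b2a", ...) are conversions, and "depth" rows are metadata consumed by `get_depth` above; everything else is a per-op cost row.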
- if !op_name.contains('2') { - for op in ops_from_name(op_name) { - for (share_type, share_name) in &[(Arithmetic, "a"), (Boolean, "b"), (Yao, "y")] - { - if let Some(c) = get_cost_opt(share_name, cost.as_object().unwrap()) { - ops.entry(op.clone()) - .or_insert_with(FxHashMap::default) - .insert(*share_type, c); + if !op_name.contains('2') && !op_name.contains("depth") { + for (share_type, share_name) in &[(Arithmetic, "a"), (Boolean, "b"), (Yao, "y")] { + if let Some(c) = get_cost_opt(share_name, cost.as_object().unwrap()) { + let mut cost_depth: f64 = 0.0; + if *share_type != Yao { + if let Some(d) = get_depth(share_name, cost.as_object().unwrap()) { + cost_depth += k.get(share_name.clone()).unwrap_or_else(|| &1.0) + * d + * get_depth(share_name, costs).unwrap(); + } } + ops.entry(op_name.clone()) + .or_insert_with(FxHashMap::default) + .insert(*share_type, c + cost_depth); + // println!("Insert cost model:{}, {}, {}", op_name, share_name, c + cost_depth); } } } } - CostModel { conversions, ops } + CostModel::new(conversions, ops) } } -fn get_cost_model(cm: &str) -> CostModel { +pub fn get_cost_model(cm: &str) -> CostModel { let base_dir = match cm { "opa" => "opa", "hycc" => "hycc", @@ -144,7 +157,7 @@ fn get_cost_model(cm: &str) -> CostModel { var("CARGO_MANIFEST_DIR").expect("Could not find env var CARGO_MANIFEST_DIR"), base_dir ); - CostModel::from_opa_cost_file(&p) + CostModel::from_opa_cost_file(&p, FxHashMap::default()) } /// Assigns boolean sharing to all terms @@ -174,7 +187,7 @@ pub fn assign_arithmetic_and_boolean(c: &Computation, cm: &str) -> SharingMap { PostOrderIter::new(output.clone()).map(|term| { ( term.clone(), - if let Some(costs) = cost_model.ops.get(term.op()) { + if let Some(costs) = cost_model.ops.get(&term.op().to_string()) { let mut min_ty: ShareType = ShareType::Boolean; let mut min_cost: f64 = costs[&min_ty]; for ty in &[ShareType::Arithmetic] { @@ -204,7 +217,7 @@ pub fn assign_arithmetic_and_yao(c: &Computation, cm: &str) -> SharingMap { PostOrderIter::new(output.clone()).map(|term| { ( term.clone(), - if let Some(costs) = cost_model.ops.get(term.op()) { + if let Some(costs) = cost_model.ops.get(&term.op().to_string()) { let mut min_ty: ShareType = ShareType::Yao; let mut min_cost: f64 = costs[&min_ty]; for ty in &[ShareType::Arithmetic] { @@ -234,7 +247,7 @@ pub fn assign_greedy(c: &Computation, cm: &str) -> SharingMap { PostOrderIter::new(output.clone()).map(|term| { ( term.clone(), - if let Some(costs) = cost_model.ops.get(term.op()) { + if let Some(costs) = cost_model.ops.get(&term.op().to_string()) { let mut min_ty: ShareType = ShareType::Yao; let mut min_cost: f64 = costs[&min_ty]; for ty in &[ShareType::Arithmetic, ShareType::Boolean] { diff --git a/src/target/aby/call_site_similarity.rs b/src/target/aby/call_site_similarity.rs new file mode 100644 index 000000000..69b968d13 --- /dev/null +++ b/src/target/aby/call_site_similarity.rs @@ -0,0 +1,338 @@ +//! 
Call Site Similarity + +use crate::ir::term::*; +use crate::target::aby::assignment::def_uses::*; + +use std::collections::HashMap; +use std::collections::HashSet; + +#[derive(Clone)] +/// A structure that stores the context and all the call terms in one call site +struct CallSite { + // Context's fname + pub calls: Vec, +} + +impl CallSite { + pub fn new( + t: &Term, + ) -> Self { + Self { + calls: vec![t.clone()], + } + } +} + +pub struct CallSiteSimilarity { + comps: Computations, + dugs: HashMap, + visited: HashSet, + call_sites: HashMap<(String, Vec, Vec), CallSite>, + callee_caller: HashSet<(String, String)>, + func_to_cs: HashMap>, + dup_per_func: HashMap, + call_cnt: HashMap, + ml: usize, +} + +impl CallSiteSimilarity { + pub fn new(comps: &Computations, ml: &usize) -> Self { + let mut css = Self { + comps: comps.clone(), + dugs: HashMap::new(), + visited: HashSet::new(), + call_sites: HashMap::new(), + callee_caller: HashSet::new(), + func_to_cs: HashMap::new(), + dup_per_func: HashMap::new(), + call_cnt: HashMap::new(), + ml: ml.clone(), + }; + css + } + + fn traverse(&mut self, fname: &String) { + *self.call_cnt.entry(fname.clone()).or_insert(0) += 1; + let c = self.comps.get(fname).clone(); + for t in c.terms_postorder() { + if let Op::Call(callee, ..) = &t.op() { + self.traverse(&callee); + } + } + if !self.visited.contains(fname) { + println!("Building dug for {}", fname); + let mut dug = DefUsesGraph::for_call_site(&c, &self.dugs, fname); + dug.gen_in_out(&c); + let cs: Vec<(Term, Vec>, Vec>)> = dug.get_call_site(); + for (t, args_t, rets_t) in cs.iter() { + if let Op::Call(callee, _, _) = t.op() { + // convert term to op id + let key: (String, Vec, Vec) = + (callee.clone(), to_key(args_t), to_key(rets_t)); + if self.call_sites.contains_key(&key) { + self.call_sites.get_mut(&key).unwrap().calls.push(t.clone()); + } else { + // Use the first context + if let Op::Call(_, _, _) = &t.op() { + let cs = CallSite::new(t); + self.call_sites.insert(key, cs); + } + } + // recording callee-caller + self.callee_caller.insert((callee.clone(), fname.clone())); + } + } + self.dugs.insert(fname.clone(), dug); + self.dup_per_func.insert(fname.clone(), 0); + self.func_to_cs.insert(fname.clone(), HashMap::new()); + self.visited.insert(fname.clone()); + } + } + + pub fn call_site_similarity_smart(&mut self) -> (Computations, HashMap) { + let main = "main".to_string(); + self.traverse(&main); + + // Functions that have more than one call site + let mut duplicated_f: HashSet = HashSet::new(); + // Functions that need to be rewrote for calling to duplicated f + // If a callee is duplicated, the caller need to be rewrote + let mut rewriting_f: HashSet = HashSet::new(); + let mut call_map: TermMap = TermMap::default(); + + // Generating duplicate set + for (key, cs) in self.call_sites.iter() { + let call_id: usize = self.dup_per_func.get(&key.0).unwrap().clone(); + + if call_id > 0 { + // indicate this function need to be rewrote + duplicated_f.insert(key.0.clone()); + } + + for t in cs.calls.iter() { + call_map.insert(t.clone(), call_id); + } + self.dup_per_func.insert(key.0.clone(), call_id + 1); + let id_to_cs = self.func_to_cs.get_mut(&key.0).unwrap(); + id_to_cs.insert(call_id, cs.clone()); + } + + // Generating rewriting set + for (callee, caller) in self.callee_caller.iter() { + if duplicated_f.contains(callee) { + rewriting_f.insert(caller.clone()); + } + } + + remap( + &self.comps, + &rewriting_f, + &duplicated_f, + &call_map, + &self.call_cnt, + &self.func_to_cs, + self.ml, + ) + } +} + 
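A minimal sketch of how this pass is meant to be driven (hypothetical driver; `comps` is a `Computations` value from the front end, and the max-level argument `&1` is an illustrative value, not taken from this diff):

    // Deduplicate functions per distinct call-site context, returning the
    // rewritten computations plus a def-use graph for each.
    let mut css = CallSiteSimilarity::new(&comps, &1);
    let (dedup_comps, dugs) = css.call_site_similarity_smart();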
+/// Rewriting the call term to new call +fn rewrite_call(c: &mut Computation, call_map: &TermMap, duplicate_set: &HashSet) { + let mut cache = TermMap::::default(); + let mut children_added = TermSet::default(); + let mut stack = Vec::new(); + stack.extend(c.outputs.iter().cloned()); + while let Some(top) = stack.pop() { + if !cache.contains_key(&top) { + // was it missing? + if children_added.insert(top.clone()) { + stack.push(top.clone()); + stack.extend(top.cs().iter().filter(|c| !cache.contains_key(c)).cloned()); + } else { + let get_children = || -> Vec { + top.cs() + .iter() + .map(|c| cache.get(c).unwrap()) + .cloned() + .collect() + }; + let new_t_op: Op = match &top.op() { + Op::Call(name, arg_sorts, ret_sorts) => { + let mut new_t = top.op().clone(); + if duplicate_set.contains(name) { + if let Some(cid) = call_map.get(&top) { + let new_n = format_dup_call(name, cid); + // let mut new_arg_names: Vec = Vec::new(); + // todo!(); + // for an in arg_names.iter() { + // new_arg_names.push(an.replace(name, &new_n)); + // } + // TODO: maybe wrong + new_t = Op::Call( + new_n, + arg_sorts.clone(), + ret_sorts.clone(), + ); + } + } + new_t + } + _ => top.op().clone(), + }; + let new_t = term(new_t_op, get_children()); + cache.insert(top.clone(), new_t); + } + } + } + c.outputs = c + .outputs + .iter() + .map(|o| cache.get(o).unwrap().clone()) + .collect(); +} + +/// Rewriting the var term to new name +fn rewrite_var(c: &mut Computation, fname: &String, cid: &usize) { + let mut cache = TermMap::::default(); + let mut children_added = TermSet::default(); + let mut stack = Vec::new(); + stack.extend(c.outputs.iter().cloned()); + while let Some(top) = stack.pop() { + if !cache.contains_key(&top) { + // was it missing? + if children_added.insert(top.clone()) { + stack.push(top.clone()); + stack.extend(top.cs().iter().filter(|c| !cache.contains_key(c)).cloned()); + } else { + let get_children = || -> Vec { + top.cs() + .iter() + .map(|c| cache.get(c).unwrap()) + .cloned() + .collect() + }; + let new_t_op: Op = match &top.op() { + Op::Var(name, sort) => { + let new_call_n = format_dup_call(fname, cid); + let new_var_n = name.replace(fname, &new_call_n); + Op::Var(new_var_n.clone(), sort.clone()) + } + _ => top.op().clone(), + }; + let new_t = term(new_t_op, get_children()); + cache.insert(top.clone(), new_t); + } + } + } + c.outputs = c + .outputs + .iter() + .map(|o| cache.get(o).unwrap().clone()) + .collect(); +} + +fn traverse(comps: &Computations, fname: &String, dugs: &mut HashMap) { + if !dugs.contains_key(fname) { + let c = comps.get(fname).clone(); + for t in c.terms_postorder() { + if let Op::Call(callee, ..) 
+/// Recursively build a `DefUsesGraph` for `fname` and every callee
+fn traverse(comps: &Computations, fname: &String, dugs: &mut HashMap<String, DefUsesGraph>) {
+    if !dugs.contains_key(fname) {
+        let c = comps.get(fname).clone();
+        for t in c.terms_postorder() {
+            if let Op::Call(callee, ..) = &t.op() {
+                traverse(comps, callee, dugs);
+            }
+        }
+        let mut dug = DefUsesGraph::for_call_site(&c, dugs, fname);
+        dug.gen_in_out(&c);
+        dugs.insert(fname.clone(), dug);
+    }
+}
+
+fn remap(
+    comps: &Computations,
+    rewriting_set: &HashSet<String>,
+    duplicate_set: &HashSet<String>,
+    call_map: &TermMap<usize>,
+    call_cnt: &HashMap<String, usize>,
+    func_to_cs: &HashMap<String, HashMap<usize, CallSite>>,
+    ml: usize,
+) -> (Computations, HashMap<String, DefUsesGraph>) {
+    let mut n_comps = Computations::new();
+    let mut n_dugs: HashMap<String, DefUsesGraph> = HashMap::new();
+    let mut context_map: HashMap<String, CallSite> = HashMap::new();
+    let mut css_call_cnt: HashMap<String, usize> = HashMap::new();
+    for (fname, comp) in comps.comps.iter() {
+        let mut ncomp: Computation = comp.clone();
+        let id_to_cs = func_to_cs.get(fname).unwrap();
+
+        if rewriting_set.contains(fname) {
+            rewrite_call(&mut ncomp, call_map, duplicate_set);
+        }
+
+        if duplicate_set.contains(fname) {
+            for (cid, cs) in id_to_cs.iter() {
+                let new_n: String = format_dup_call(fname, cid);
+                let mut dup_comp: Computation = Computation {
+                    outputs: ncomp.outputs().clone(),
+                    metadata: ncomp.metadata.clone(),
+                    precomputes: ncomp.precomputes.clone(),
+                    persistent_arrays: Default::default(),
+                };
+                rewrite_var(&mut dup_comp, fname, cid);
+                n_comps.comps.insert(new_n.clone(), dup_comp);
+                context_map.insert(new_n.clone(), cs.clone());
+                css_call_cnt.insert(new_n, *call_cnt.get(fname).unwrap());
+            }
+        } else {
+            if let Some(cs) = id_to_cs.get(&0) {
+                context_map.insert(fname.clone(), cs.clone());
+                css_call_cnt.insert(fname.clone(), *call_cnt.get(fname).unwrap());
+            }
+            n_comps.comps.insert(fname.clone(), ncomp);
+        }
+    }
+    let main = "main".to_string();
+    traverse(&n_comps, &main, &mut n_dugs);
+
+    for (fname, cs) in context_map.iter() {
+        let dug = n_dugs.get_mut(fname).unwrap();
+        let comp = n_comps.get(fname);
+        dug.set_num_calls(css_call_cnt.get(fname).unwrap());
+        // TODO: enable this
+        // dug.insert_context(&cs.args, &cs.rets, &cs.caller_dug, comp, ml);
+    }
+
+    (n_comps, n_dugs)
+}
+
+fn format_dup_call(fname: &String, cid: &usize) -> String {
+    format!("{}_circ_v_{}", fname, cid)
+}
+
+/// Flatten nested term vectors into an order-insensitive vector of op ids
+fn to_key(vterms: &Vec<Vec<Term>>) -> Vec<usize> {
+    let mut key: Vec<usize> = Vec::new();
+    for terms in vterms {
+        let mut v: Vec<usize> = Vec::new();
+        for t in terms {
+            v.push(get_op_id(t.op()));
+        }
+        // Sort within each group so the key ignores term ordering
+        v.sort();
+        key.extend(v);
+    }
+    key
+}
+
+/// Map an op to a coarse bucket id; distinct operators of the same kind
+/// (e.g. all bit-vector n-ary ops) intentionally share one id
+fn get_op_id(op: &Op) -> usize {
+    match op {
+        Op::Var(..) => 1,
+        Op::Const(_) => 2,
+        Op::Eq => 3,
+        Op::Ite => 4,
+        Op::Not => 5,
+        Op::BoolNaryOp(_) => 6,
+        Op::BvBinPred(_) => 7,
+        Op::BvNaryOp(_) => 8,
+        Op::BvBinOp(_) => 9,
+        Op::Select => 10,
+        Op::Store => 11,
+        Op::Call(..) => 12,
+        _ => todo!("get_op_id: unhandled op {}", op),
+    }
+}
\ No newline at end of file
diff --git a/src/target/aby/graph/mod.rs b/src/target/aby/graph/mod.rs
new file mode 100644
index 000000000..46ac4f4a7
--- /dev/null
+++ b/src/target/aby/graph/mod.rs
@@ -0,0 +1,5 @@
+//! Graph partitioning backend
+
+pub mod tp;
+pub mod trans;
+pub mod utils;
diff --git a/src/target/aby/graph/tp.rs b/src/target/aby/graph/tp.rs
new file mode 100644
index 000000000..685a7fffb
--- /dev/null
+++ b/src/target/aby/graph/tp.rs
@@ -0,0 +1,168 @@
+//! Multi-level Partitioning Implementation
+//!
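+//! The flow, roughly: `TrivialPartition::traverse` inlines callees bottom-up
+//! into one `Computation`, `GraphWriter` dumps its def-use graph in
+//! Chaco/hMETIS format, an external partitioner (KaHIP or KaHyPar) is
+//! invoked, and each term is then mapped to a partition id.
+//!
+//! Illustrative usage (a sketch added for exposition, not a doctest; it
+//! assumes `KAHIP_SOURCE`/`KAHYPAR_SOURCE` point at built checkouts and that
+//! `comps` already holds a lowered `main`):
+//!
+//! ```ignore
+//! let mut tp = TrivialPartition::new(&comps, 0, 10, false);
+//! let graph_path = "scripts/aby_tests/tests/example.graph".to_string();
+//! // aim for roughly 1000 terms per partition
+//! let (comp, dug, term_to_part, num_parts) =
+//!     tp.run(&"main".to_string(), &graph_path, 1000);
+//! ```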
+
+use crate::ir::opt::link::link_one;
+use crate::ir::term::*;
+
+use crate::target::aby::assignment::def_uses::*;
+use crate::target::aby::graph::utils::graph_utils::*;
+use crate::target::aby::graph::utils::part::*;
+
+use std::collections::HashMap;
+
+/// Inlines all calls into one computation and hands it to a graph partitioner
+pub struct TrivialPartition {
+    partitioner: Partitioner,
+    gwriter: GraphWriter,
+    comps: Computations,
+    comp_history: HashMap<String, Computation>,
+}
+
+impl TrivialPartition {
+    pub fn new(
+        comps: &Computations,
+        time_limit: usize,
+        imbalance: usize,
+        hyper_mode: bool,
+    ) -> Self {
+        let tp = Self {
+            partitioner: Partitioner::new(time_limit, imbalance, hyper_mode),
+            gwriter: GraphWriter::new(hyper_mode),
+            comps: comps.clone(),
+            comp_history: HashMap::new(),
+        };
+        // for fname in fs.computations.keys() {
+        //     tp.traverse(fname);
+        // }
+        tp
+    }
+
+    /// Traverse the computation bottom-up and inline every callee
+    fn traverse(&mut self, fname: &String) {
+        if !self.comp_history.contains_key(fname) {
+            let mut c = self.comps.get(fname).clone();
+            for t in c.terms_postorder() {
+                if let Op::Call(callee, ..) = &t.op() {
+                    self.traverse(callee);
+                }
+            }
+            self.merge(&mut c);
+            self.comp_history.insert(fname.into(), c.clone());
+        }
+    }
+
+    fn merge(&mut self, computation: &mut Computation) {
+        let mut cache = TermMap::<Term>::default();
+        let mut children_added = TermSet::default();
+        let mut stack = Vec::new();
+        stack.extend(computation.outputs.iter().cloned());
+        while let Some(top) = stack.pop() {
+            if !cache.contains_key(&top) {
+                // was it missing?
+                if children_added.insert(top.clone()) {
+                    stack.push(top.clone());
+                    stack.extend(top.cs().iter().filter(|c| !cache.contains_key(c)).cloned());
+                } else {
+                    let get_children = || -> Vec<Term> {
+                        top.cs()
+                            .iter()
+                            .map(|c| cache.get(c).unwrap())
+                            .cloned()
+                            .collect()
+                    };
+                    let new_t_opt = self.visit(computation, &top, get_children);
+                    let new_t = new_t_opt.unwrap_or_else(|| term(top.op().clone(), get_children()));
+                    cache.insert(top.clone(), new_t);
+                }
+            }
+        }
+        computation.outputs = computation
+            .outputs
+            .iter()
+            .map(|o| cache.get(o).unwrap().clone())
+            .collect();
+    }
+
+    fn visit<F: Fn() -> Vec<Term>>(
+        &mut self,
+        _computation: &mut Computation,
+        orig: &Term,
+        rewritten_children: F,
+    ) -> Option<Term> {
+        if let Op::Call(fn_name, _, _) = &orig.op() {
+            // println!("Rewritten children: {:?}", rewritten_children());
+            let callee = self
+                .comp_history
+                .get(fn_name)
+                .expect("missing inlined callee");
+            let linked = link_one(callee, rewritten_children());
+            Some(linked)
+        } else {
+            None
+        }
+    }
+
+    pub fn inline_all(
+        &mut self,
+        fname: &String,
+        fnames: Vec<&String>,
+    ) -> (Computation, DefUsesGraph) {
+        for fname in fnames {
+            self.traverse(fname);
+        }
+        let c = self.comp_history.get(fname).unwrap().clone();
+        let dug = DefUsesGraph::new(&c);
+        (c, dug)
+    }
+
+    pub fn run(
+        &mut self,
+        fname: &String,
+        path: &String,
+        ps: usize,
+    ) -> (Computation, DefUsesGraph, TermMap<usize>, usize) {
+        let mut part_map = TermMap::default();
+        self.traverse(fname);
+        let c = self.comp_history.get(fname).unwrap();
+        let dug = DefUsesGraph::new(&c);
+        let num_parts = dug.good_terms.len() / ps + 1;
+        println!("LOG: Number of Partitions: {}", num_parts);
+        if num_parts > 1 {
+            let t_map = self.gwriter.build_from_dug(&dug);
+            self.gwriter.write(path);
+            let partition = self.partitioner.do_partition(path, &num_parts);
+            for (t, tid) in t_map.iter() {
+                part_map.insert(t.clone(), *partition.get(tid).unwrap());
+            }
+        }
+        (
+            self.comp_history.get(fname).unwrap().clone(),
+            dug,
+            part_map,
+            num_parts,
+        )
+    }
+
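+    // Note (exposition): `GraphWriter::build_from_dug` numbers graph nodes
+    // from 1, and `Partitioner::parse_partition` rebuilds its map with the
+    // same 1-based ids, which is what makes the `partition.get(tid)` lookups
+    // in `run` and `run_from_dug` line up.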
+    pub fn run_from_dug(
+        &mut self,
+        fname: &String,
+        dug: &DefUsesGraph,
+        path: &String,
+        ps: usize,
+    ) -> (TermMap<usize>, usize) {
+        let mut part_map = TermMap::default();
+        let _c = self.comps.get(fname);
+        let num_parts = dug.good_terms.len() / ps + 1;
+        println!("LOG: Number of Partitions: {}", num_parts);
+        if num_parts > 1 {
+            let t_map = self.gwriter.build_from_dug(dug);
+            self.gwriter.write(path);
+            let partition = self.partitioner.do_partition(path, &num_parts);
+            for (t, tid) in t_map.iter() {
+                part_map.insert(t.clone(), *partition.get(tid).unwrap());
+            }
+        }
+        (part_map, num_parts)
+    }
+}
diff --git a/src/target/aby/graph/trans.rs b/src/target/aby/graph/trans.rs
new file mode 100644
index 000000000..b94801b74
--- /dev/null
+++ b/src/target/aby/graph/trans.rs
@@ -0,0 +1,521 @@
+use crate::ir::term::*;
+
+use crate::target::aby::assignment::def_uses::*;
+use crate::target::aby::assignment::get_cost_model;
+use crate::target::aby::assignment::ilp::*;
+use crate::target::aby::assignment::ilp_opa::opa_smart_global_assign;
+use crate::target::aby::assignment::ShareType;
+use crate::target::aby::assignment::SharingMap;
+use crate::target::aby::graph::tp::TrivialPartition;
+use crate::target::aby::graph::utils::mutation::*;
+
+use std::collections::HashMap;
+use std::path::Path;
+use std::time::Duration;
+use std::time::Instant;
+
+use std::fs;
+
+// Get the file path to write the Chaco graph to
+fn get_graph_path(path: &Path, lang: &str, hyper_mode: bool) -> String {
+    let filename = Path::new(&path.iter().last().unwrap().to_os_string())
+        .file_stem()
+        .unwrap()
+        .to_os_string()
+        .into_string()
+        .unwrap();
+    let name = format!("{}_{}", filename, lang);
+    let mut path = format!("scripts/aby_tests/tests/{}.graph", name);
+    if hyper_mode {
+        path = format!("scripts/aby_tests/tests/{}_hyper.graph", name);
+    }
+    if Path::new(&path).exists() {
+        fs::remove_file(&path).expect("Failed to remove old graph file");
+    }
+    path
+}
+
+/// Inline all functions into main, partition the result, then assign share
+/// types per partition with mutation-based refinement
+pub fn partition_with_mut(
+    comps: &Computations,
+    cm: &str,
+    path: &Path,
+    lang: &str,
+    ps: &usize,
+    hyper_mode: bool,
+    ml: &usize,
+    mss: &usize,
+    imbalance: &usize,
+) -> (Computations, HashMap<String, SharingMap>) {
+    let mut now = Instant::now();
+    let mut tp = TrivialPartition::new(comps, 0, *imbalance, hyper_mode);
+    let main = "main";
+    let graph_path = get_graph_path(path, lang, hyper_mode);
+    let (c, _d, partition, num_parts) = tp.run(&main.to_string(), &graph_path, *ps);
+
+    println!("Time: Partition: {:?}", now.elapsed());
+    now = Instant::now();
+
+    // Construct ComputationSubgraphs
+    let mut tmp_css: HashMap<usize, ComputationSubgraph> = HashMap::new();
+    let mut css: HashMap<usize, ComputationSubgraph> = HashMap::new();
+
+    for part_id in 0..num_parts {
+        tmp_css.insert(part_id, ComputationSubgraph::new());
+    }
+
+    for (t, part_id) in partition.iter() {
+        if let Some(subgraph) = tmp_css.get_mut(part_id) {
+            subgraph.insert_node(t);
+        } else {
+            panic!("Subgraph not found for partition: {}", part_id);
+        }
+    }
+
+    for (part_id, mut cs) in tmp_css.into_iter() {
+        cs.insert_edges();
+        css.insert(part_id, cs);
+    }
+    println!("Time: To Subgraph: {:?}", now.elapsed());
+
+    now = Instant::now();
+    let assignment = get_share_map_with_mutation(&c, cm, &css, &partition, ml, mss);
+    println!("Time: ILP : {:?}", now.elapsed());
+    let mut s_map: HashMap<String, SharingMap> = HashMap::new();
+    s_map.insert(main.to_string(), assignment);
+    let mut comps = Computations::new();
+    comps.comps.insert(main.to_string(), c);
+    (comps, s_map)
+}
+
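+// Sizing note (exposition, not from the original sources): with a target of
+// `ps` terms per partition, `TrivialPartition` computes
+// `num_parts = good_terms.len() / ps + 1`, so e.g. 4500 good terms at
+// ps = 1000 give 5 parts, and any graph with fewer than `ps` terms stays in
+// a single part, in which case the external partitioner is skipped.
+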
+#[cfg(feature = "lp")]
+/// Partition and assign each function separately, reusing precomputed
+/// def-use graphs
+pub fn css_partition_with_mut_smart(
+    comps: &Computations,
+    dugs: &HashMap<String, DefUsesGraph>,
+    cm: &str,
+    path: &Path,
+    lang: &str,
+    ps: &usize,
+    hyper_mode: bool,
+    ml: &usize,
+    mss: &usize,
+    imbalance: &usize,
+) -> HashMap<String, SharingMap> {
+    let mut s_map: HashMap<String, SharingMap> = HashMap::new();
+
+    let mut part_duration: Duration = Duration::ZERO;
+    let mut ilp_duration: Duration = Duration::ZERO;
+
+    for (fname, _comp) in comps.comps.iter() {
+        let mut now = Instant::now();
+        println!("Partitioning: {}", fname);
+        let mut tp = TrivialPartition::new(comps, 0, *imbalance, hyper_mode);
+        let graph_path = get_graph_path(path, lang, hyper_mode);
+        let d = dugs.get(fname).unwrap();
+        let (partition, num_parts) = tp.run_from_dug(fname, d, &graph_path, *ps);
+
+        part_duration += now.elapsed();
+
+        let assignment: SharingMap;
+        if num_parts == 1 {
+            // No need to partition
+            now = Instant::now();
+            assignment = smart_global_assign(&d.good_terms, &d.def_use, &d.get_k(), cm);
+            ilp_duration += now.elapsed();
+        } else {
+            // Construct a DefUsesSubGraph per partition
+            now = Instant::now();
+            let mut tmp_dusg: HashMap<usize, DefUsesSubGraph> = HashMap::new();
+            let mut dusg: HashMap<usize, DefUsesSubGraph> = HashMap::new();
+
+            for part_id in 0..num_parts {
+                tmp_dusg.insert(part_id, DefUsesSubGraph::new());
+            }
+
+            for t in d.good_terms.iter() {
+                let part_id = partition.get(t).unwrap();
+                if let Some(du) = tmp_dusg.get_mut(part_id) {
+                    du.insert_node(t);
+                } else {
+                    panic!("Subgraph not found for partition: {}", part_id);
+                }
+            }
+
+            for (part_id, mut du) in tmp_dusg.into_iter() {
+                du.insert_edges(d);
+                dusg.insert(part_id, du);
+            }
+            assignment = get_share_map_with_mutation_smart(d, cm, &dusg, &partition, ml, mss);
+
+            ilp_duration += now.elapsed();
+        }
+        s_map.insert(fname.clone(), assignment);
+    }
+
+    println!("LOG: Partition time: {:?}", part_duration);
+    println!("LOG: ILP time: {:?}", ilp_duration);
+
+    s_map
+}
+
+#[cfg(feature = "lp")]
+/// Inline all functions into main, then partition and assign with mutation
+pub fn partition_with_mut_smart(
+    comps: &Computations,
+    cm: &str,
+    path: &Path,
+    lang: &str,
+    ps: &usize,
+    hyper_mode: bool,
+    ml: &usize,
+    mss: &usize,
+    imbalance: &usize,
+) -> (Computations, HashMap<String, SharingMap>) {
+    let mut now = Instant::now();
+    let mut tp = TrivialPartition::new(comps, 0, *imbalance, hyper_mode);
+    let main = "main";
+    let graph_path = get_graph_path(path, lang, hyper_mode);
+    let (c, d, partition, num_parts) = tp.run(&main.to_string(), &graph_path, *ps);
+
+    println!("LOG: Partition time: {:?}", now.elapsed());
+
+    let assignment: SharingMap;
+    if num_parts == 1 {
+        // No need to partition
+        now = Instant::now();
+        assignment = smart_global_assign(&d.good_terms, &d.def_use, &d.get_k(), cm);
+        println!("LOG: ILP time: {:?}", now.elapsed());
+    } else {
+        // Construct a DefUsesSubGraph per partition
+        now = Instant::now();
+        let mut tmp_dusg: HashMap<usize, DefUsesSubGraph> = HashMap::new();
+        let mut dusg: HashMap<usize, DefUsesSubGraph> = HashMap::new();
+
+        for part_id in 0..num_parts {
+            tmp_dusg.insert(part_id, DefUsesSubGraph::new());
+        }
+
+        for t in d.good_terms.iter() {
+            // println!("op: {}", t.op);
+            let part_id = partition.get(t).unwrap();
+            if let Some(du) = tmp_dusg.get_mut(part_id) {
+                du.insert_node(t);
+            } else {
+                panic!("Subgraph not found for partition: {}", part_id);
+            }
+        }
+
+        println!("Finish inserting terms");
+
+        for (part_id, mut du) in tmp_dusg.into_iter() {
+            du.insert_edges(&d);
+            dusg.insert(part_id, du);
+        }
+
+        println!("Finish inserting edges");
+
+        assignment = get_share_map_with_mutation_smart(&d, cm, &dusg, &partition, ml, mss);
+        println!("LOG: ILP time: {:?}", now.elapsed());
+    }
+
+    println!(
+        "Calculate cost: {}",
calculate_cost_smart_dug(&assignment, cm, &d) + ); + + let mut s_map: HashMap = HashMap::new(); + s_map.insert(main.to_string(), assignment); + let mut comps = Computations::new(); + comps.comps.insert(main.to_string(), c); + (comps, s_map) +} + +#[cfg(feature = "lp")] +/// inline all function into main +pub fn inline_all_and_assign_glp( + comps: &Computations, + cm: &str, +) -> (Computations, HashMap) { + let mut tp = TrivialPartition::new(comps, 0, 0, false); + let main = "main"; + let fnames = comps.comps.keys().collect::>(); + let (c, dug) = tp.inline_all(&main.to_string(), fnames); + + // println!("Terms after inline all."); + // for t in c.terms_postorder() { + // println!("t: {}", t); + // } + + let cs = c.to_cs(); + let assignment = assign(&cs, cm); + let mut s_map: HashMap = HashMap::new(); + s_map.insert(main.to_string(), assignment); + let mut comps = Computations::new(); + comps.comps.insert(main.to_string(), c); + (comps, s_map) +} + +#[cfg(feature = "lp")] +/// inline all function into main +pub fn inline_all_and_assign_smart_glp( + comps: &Computations, + cm: &str, +) -> (Computations, HashMap) { + let mut now = Instant::now(); + let mut tp = TrivialPartition::new(comps, 0, 0, false); + let main = "main"; + let fnames = comps.comps.keys().collect::>(); + let (c, dug) = tp.inline_all(&main.to_string(), fnames); + + let k_map = dug.get_k(); + + println!( + "Time: Inline and construction def uses: {:?}", + now.elapsed() + ); + + now = Instant::now(); + let assignment = smart_global_assign(&dug.good_terms, &dug.def_use, &k_map, cm); + println!( + "Calculate cost: {}", + calculate_cost_smart_dug(&assignment, cm, &dug) + ); + + println!("LOG: ILP time: {:?}", now.elapsed()); + + let mut s_map: HashMap = HashMap::new(); + s_map.insert(main.to_string(), assignment); + let mut comps = Computations::new(); + comps.comps.insert(main.to_string(), c); + (comps, s_map) +} + +#[cfg(feature = "lp")] +/// inline all function into main +pub fn inline_all_and_assign_y( + comps: &Computations, + cm: &str, +) -> (Computations, HashMap) { + let mut now = Instant::now(); + let mut tp = TrivialPartition::new(comps, 0, 0, false); + let main = "main"; + let fnames = comps.comps.keys().collect::>(); + let (c, dug) = tp.inline_all(&main.to_string(), fnames); + + println!( + "Time: Inline and construction def uses: {:?}", + now.elapsed() + ); + + now = Instant::now(); + let assignment: SharingMap = dug + .good_terms + .iter() + .map(|term| (term.clone(), ShareType::Yao)) + .collect(); + println!( + "Calculate cost: {}", + calculate_cost_smart_dug(&assignment, cm, &dug) + ); + + println!("LOG: ILP time: {:?}", now.elapsed()); + + let mut s_map: HashMap = HashMap::new(); + s_map.insert(main.to_string(), assignment); + let mut comps = Computations::new(); + comps.comps.insert(main.to_string(), c); + (comps, s_map) +} + +#[cfg(feature = "lp")] +/// inline all function into main +pub fn inline_all_and_assign_b( + comps: &Computations, + cm: &str, +) -> (Computations, HashMap) { + let mut now = Instant::now(); + let mut tp = TrivialPartition::new(comps, 0, 0, false); + let main = "main"; + let fnames = comps.comps.keys().collect::>(); + let (c, dug) = tp.inline_all(&main.to_string(), fnames); + + println!( + "Time: Inline and construction def uses: {:?}", + now.elapsed() + ); + + now = Instant::now(); + let assignment: SharingMap = dug + .good_terms + .iter() + .map(|term| (term.clone(), ShareType::Boolean)) + .collect(); + println!( + "Calculate cost: {}", + calculate_cost_smart_dug(&assignment, cm, &dug) + 
); + + println!("LOG: ILP time: {:?}", now.elapsed()); + + let mut s_map: HashMap = HashMap::new(); + s_map.insert(main.to_string(), assignment); + let mut comps = Computations::new(); + comps.comps.insert(main.to_string(), c); + (comps, s_map) +} + +#[cfg(feature = "lp")] +/// inline all function into main +pub fn inline_all_and_assign_a_y( + comps: &Computations, + cm: &str, +) -> (Computations, HashMap) { + let mut now = Instant::now(); + let mut tp = TrivialPartition::new(comps, 0, 0, false); + let main = "main"; + let fnames = comps.comps.keys().collect::>(); + let (c, dug) = tp.inline_all(&main.to_string(), fnames); + + println!( + "Time: Inline and construction def uses: {:?}", + now.elapsed() + ); + + now = Instant::now(); + let cost_model = get_cost_model(cm); + let assignment: SharingMap = dug + .good_terms + .iter() + .map(|term| { + ( + term.clone(), + if let Some(costs) = cost_model.ops.get(&term.op().to_string()) { + match &term.op() { + Op::Select | Op::Store => ShareType::Yao, + _ => { + let mut min_ty: ShareType = ShareType::Yao; + let mut min_cost: f64 = costs[&min_ty]; + for ty in &[ShareType::Arithmetic] { + if let Some(c) = costs.get(ty) { + if *c < min_cost { + min_ty = *ty; + min_cost = *c; + } + } + } + min_ty + } + } + } else { + ShareType::Yao + }, + ) + }) + .collect(); + println!( + "Calculate cost: {}", + calculate_cost_smart_dug(&assignment, cm, &dug) + ); + + println!("LOG: ILP time: {:?}", now.elapsed()); + + let mut s_map: HashMap = HashMap::new(); + s_map.insert(main.to_string(), assignment); + let mut comps = Computations::new(); + comps.comps.insert(main.to_string(), c); + (comps, s_map) +} + +#[cfg(feature = "lp")] +/// inline all function into main +pub fn inline_all_and_assign_a_b( + comps: &Computations, + cm: &str, +) -> (Computations, HashMap) { + let mut now = Instant::now(); + let mut tp = TrivialPartition::new(comps, 0, 0, false); + let main = "main"; + let fnames = comps.comps.keys().collect::>(); + let (c, dug) = tp.inline_all(&main.to_string(), fnames); + + println!( + "Time: Inline and construction def uses: {:?}", + now.elapsed() + ); + + now = Instant::now(); + let cost_model = get_cost_model(cm); + let assignment: SharingMap = dug + .good_terms + .iter() + .map(|term| { + ( + term.clone(), + if let Some(costs) = cost_model.ops.get(&term.op().to_string()) { + let mut min_ty: ShareType = ShareType::Boolean; + let mut min_cost: f64 = costs[&min_ty]; + for ty in &[ShareType::Arithmetic] { + if let Some(c) = costs.get(ty) { + if *c < min_cost { + min_ty = *ty; + min_cost = *c; + } + } + } + min_ty + } else { + ShareType::Boolean + }, + ) + }) + .collect(); + println!( + "Calculate cost: {}", + calculate_cost_smart_dug(&assignment, cm, &dug) + ); + + println!("LOG: ILP time: {:?}", now.elapsed()); + + let mut s_map: HashMap = HashMap::new(); + s_map.insert(main.to_string(), assignment); + let mut new_comps = Computations::new(); + new_comps.comps.insert(main.to_string(), c); + (new_comps, s_map) +} + +#[cfg(feature = "lp")] +/// inline all function into main +pub fn inline_all_and_assign_opa( + comps: &Computations, + cm: &str, + share_types: &[ShareType; 2], +) -> (Computations, HashMap) { + let mut now = Instant::now(); + let mut tp = TrivialPartition::new(comps, 0, 0, false); + let main = "main"; + let fnames = comps.comps.keys().collect::>(); + let (c, dug) = tp.inline_all(&main.to_string(), fnames); + let k_map = dug.get_k(); + + println!( + "Time: Inline and construction def uses: {:?}", + now.elapsed() + ); + + now = Instant::now(); + 
let assignment = + opa_smart_global_assign(&dug.good_terms, &dug.def_use, &k_map, cm, share_types); + println!( + "Calculate cost: {}", + calculate_cost_smart_dug(&assignment, cm, &dug) + ); + + println!("LOG: ILP time: {:?}", now.elapsed()); + + let mut s_map: HashMap = HashMap::new(); + s_map.insert(main.to_string(), assignment); + let mut comps = Computations::new(); + comps.comps.insert(main.to_string(), c); + (comps, s_map) +} diff --git a/src/target/aby/graph/utils/graph_utils.rs b/src/target/aby/graph/utils/graph_utils.rs new file mode 100644 index 000000000..672f4a07b --- /dev/null +++ b/src/target/aby/graph/utils/graph_utils.rs @@ -0,0 +1,277 @@ +//! Translation from IR to Chaco file input format +//! This input format can be found in [Jostle User Guide](https://chriswalshaw.co.uk/jostle/jostle-exe.pdf) + +use crate::ir::term::*; +use crate::target::aby::assignment::def_uses::*; + +use std::collections::HashMap; +use std::fs; +use std::io::prelude::*; +use std::path::Path; + +#[derive(Clone, PartialEq, Eq, Hash)] +struct HyperEdge { + idx: usize, +} + +#[derive(Clone)] +struct Edges { + vec: Vec, +} + +impl Edges { + fn add(&mut self, item: T) -> bool { + if !self.vec.contains(&item) { + self.vec.push(item); + return true; + } + false + } +} + +fn coarse_map_get(cm: &HashMap>, t: &Term, level: usize) -> usize { + let v = cm.get(t).unwrap(); + *(v.get(level).unwrap_or_else(|| v.last().unwrap())) +} + +/// +pub struct GraphWriter { + num_nodes: usize, + num_edges: usize, + num_hyper_edges: usize, + term_to_id: HashMap, + edges: HashMap>, + hyper_edges: HashMap>, + node_to_hyper_edge: HashMap, + hyper_mode: bool, +} + +impl GraphWriter { + pub fn new(hyper_mode: bool) -> Self { + let gw = Self { + num_nodes: 0, + num_edges: 0, + num_hyper_edges: 0, + term_to_id: HashMap::new(), + edges: HashMap::new(), + hyper_edges: HashMap::new(), + node_to_hyper_edge: HashMap::new(), + hyper_mode: hyper_mode, + }; + gw + } + + pub fn build( + &mut self, + cs: &Computation, + coarsen_map: &HashMap>, + level: usize, + num_nodes: usize, + ) { + self.num_nodes = num_nodes; + for t in cs.terms_postorder() { + match &t.op() { + Op::Ite + | Op::Not + | Op::Eq + | Op::Store + | Op::Select + | Op::Tuple + | Op::Field(_) + | Op::BvBinOp(_) + | Op::BvNaryOp(_) + | Op::BvBinPred(_) + | Op::BoolNaryOp(_) => { + let t_id = coarse_map_get(coarsen_map, &t, level); + for cs in t.cs().iter() { + let cs_id = coarse_map_get(coarsen_map, &cs, level); + if cs_id != t_id { + if self.hyper_mode { + self.insert_hyper_edge(&cs_id, &t_id); + } else { + self.insert_edge(&cs_id, &t_id); + self.insert_edge(&t_id, &cs_id); + } + } + } + } + _ => unimplemented!("Haven't implemented conversion of {:#?}, {:#?}", t, t.op()), + } + } + } + + pub fn build_from_dug(&mut self, dug: &DefUsesGraph) -> HashMap { + for t in dug.good_terms.iter() { + match &t.op() { + Op::Var(_, _) | Op::Const(_) => { + self.get_tid_or_assign(&t); + } + Op::Ite + | Op::Not + | Op::Eq + | Op::Store + | Op::Select + | Op::Tuple + | Op::Field(_) + | Op::BvBinOp(_) + | Op::BvNaryOp(_) + | Op::BvBinPred(_) + | Op::BoolNaryOp(_) => { + let t_id = self.get_tid_or_assign(&t); + for def in dug.use_defs.get(t).unwrap().iter() { + let def_id = self.get_tid_or_assign(&def); + if def_id != t_id { + if self.hyper_mode { + self.insert_hyper_edge(&def_id, &t_id); + } else { + self.insert_edge(&def_id, &t_id); + self.insert_edge(&t_id, &def_id); + } + } + } + } + _ => unimplemented!("Haven't implemented conversion of {:#?}", t.op()), + } + } + self.term_to_id.clone() + } + 
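+    // Illustration (added for exposition, not emitted by this code): for a
+    // 3-node chain a-b-c with undirected edges {a,b} and {b,c}, `write_graph`
+    // below emits a "num_nodes num_edges" header followed by one adjacency
+    // line per 1-based node id:
+    //
+    //   3 2
+    //   2
+    //   1 3
+    //   2
+    //
+    // Each undirected edge is stored twice internally (once per direction),
+    // which is why the header divides `num_edges` by 2.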
+ fn get_tid_or_assign(&mut self, t: &Term) -> usize { + if self.term_to_id.contains_key(t) { + return *(self.term_to_id.get(t).unwrap()); + } else { + self.num_nodes += 1; + self.term_to_id.insert(t.clone(), self.num_nodes); + return self.num_nodes; + } + } + + pub fn write(&mut self, path: &String) { + if self.hyper_mode { + self.write_hyper_graph(path); + } else { + self.write_graph(path); + } + } + + // Insert edge into PartitionGraph + fn insert_edge(&mut self, from: &usize, to: &usize) { + if !self.edges.contains_key(&from) { + self.edges.insert(from.clone(), Edges { vec: Vec::new() }); + } + let added = self.edges.get_mut(&from).unwrap().add(*to); + if added { + self.num_edges += 1; + } + } + + // Insert hyper edge into PartitionGraph + fn insert_hyper_edge(&mut self, from: &usize, to: &usize) { + // Assume each node will only have one output + // TODO: fix this? + if !self.node_to_hyper_edge.contains_key(from) { + self.num_hyper_edges += 1; + + let new_hyper_edge = HyperEdge { + idx: self.num_hyper_edges, + }; + self.node_to_hyper_edge + .insert(from.clone(), new_hyper_edge.clone()); + + self.hyper_edges + .insert(new_hyper_edge.clone(), Edges { vec: Vec::new() }); + // Add from node itself + self.hyper_edges + .get_mut(&new_hyper_edge) + .unwrap() + .add(*from); + self.hyper_edges.get_mut(&new_hyper_edge).unwrap().add(*to); + } else { + let hyper_edge = self.node_to_hyper_edge.get(&from).unwrap(); + self.hyper_edges.get_mut(&hyper_edge).unwrap().add(*to); + } + } + + // Write Chaco graph to file + fn write_graph(&mut self, path: &String) { + if !Path::new(path).exists() { + fs::File::create(path).expect("Failed to create graph file"); + } + let mut file = fs::OpenOptions::new() + .write(true) + .append(true) + .open(path) + .expect("Failed to open graph file"); + + // write number of nodes and edges + file.write_all(format!("{} {}\n", self.num_nodes, self.num_edges / 2).as_bytes()) + .expect("Failed to write to graph file"); + + // for Nodes 1..N, write their neighbors + for i in 0..(self.num_nodes) { + let id = i + 1; + match self.edges.get(&id) { + Some(edges) => { + let line = edges + .vec + .clone() + .into_iter() + .map(|nid| nid.to_string()) + .collect::>() + .join(" "); + file.write_all(line.as_bytes()) + .expect("Failed to write to graph file"); + } + None => { + let line = ""; + file.write_all(line.as_bytes()) + .expect("Failed to write to graph file"); + } + } + file.write_all("\n".as_bytes()) + .expect("Failed to write to graph file"); + } + } + + // Write Chaco graph to file + fn write_hyper_graph(&mut self, path: &String) { + if !Path::new(path).exists() { + println!("hyper_graph_path: {}", path); + fs::File::create(path).expect("Failed to create hyper graph file"); + } + let mut file = fs::OpenOptions::new() + .write(true) + .append(true) + .open(path) + .expect("Failed to open hyper graph file"); + + // write number of nodes and edges + file.write_all(format!("{} {}\n", self.num_hyper_edges, self.num_nodes).as_bytes()) + .expect("Failed to write to hyper graph file"); + + // for Nodes 1..N, write their neighbors + for i in 0..(self.num_hyper_edges) { + let hyper_edge = HyperEdge { idx: i + 1 }; + match self.hyper_edges.get(&hyper_edge) { + Some(nodes) => { + let line = nodes + .vec + .clone() + .into_iter() + .map(|nid| nid.to_string()) + .collect::>() + .join(" "); + file.write_all(line.as_bytes()) + .expect("Failed to write to graph file"); + } + None => { + let line = ""; + file.write_all(line.as_bytes()) + .expect("Failed to write to graph file"); + } + } + 
file.write_all("\n".as_bytes()) + .expect("Failed to write to graph file"); + } + } +} diff --git a/src/target/aby/graph/utils/mod.rs b/src/target/aby/graph/utils/mod.rs new file mode 100644 index 000000000..c1f5b25df --- /dev/null +++ b/src/target/aby/graph/utils/mod.rs @@ -0,0 +1,3 @@ +pub mod graph_utils; +pub mod mutation; +pub mod part; diff --git a/src/target/aby/graph/utils/mutation.rs b/src/target/aby/graph/utils/mutation.rs new file mode 100644 index 000000000..04d475caf --- /dev/null +++ b/src/target/aby/graph/utils/mutation.rs @@ -0,0 +1,218 @@ +//! Translation from IR to Chaco file input format +//! This input format can be found in [Jostle User Guide](https://chriswalshaw.co.uk/jostle/jostle-exe.pdf) +//! +//! + +use crate::ir::term::*; + +use crate::target::aby::assignment::ilp::assign_mut; +use crate::target::aby::assignment::ilp::assign_mut_smart; +use crate::target::aby::assignment::ilp::comb_selection; +use crate::target::aby::assignment::ilp::comb_selection_smart; + +use crate::target::aby::assignment::SharingMap; +use std::collections::HashMap; + +use crate::target::aby::assignment::def_uses::*; + +// use std::thread; + +fn get_outer_n(cs: &ComputationSubgraph, n: usize) -> ComputationSubgraph { + let mut last_cs = cs.clone(); + for _ in 0..n { + let mut mut_cs: ComputationSubgraph = ComputationSubgraph::new(); + for node in last_cs.nodes.clone() { + mut_cs.insert_node(&node); + } + for node in last_cs.ins.clone() { + for outer_node in node.cs().iter() { + mut_cs.insert_node(&outer_node) + } + } + mut_cs.insert_edges(); + last_cs = mut_cs; + } + last_cs +} + +/// Mutations with multi threading +fn mutate_partitions_mp_step( + cs: &HashMap, + cm: &str, + outer_level: usize, + step: usize, +) -> HashMap> { + // TODO: merge and stop + let mut mut_smaps: HashMap> = HashMap::new(); + + let mut mut_sets: HashMap<(usize, usize), (ComputationSubgraph, ComputationSubgraph)> = + HashMap::new(); + + for (i, c) in cs.iter() { + mut_smaps.insert(*i, HashMap::new()); + for j in 0..outer_level { + let outer_tmp = get_outer_n(c, j * step); + mut_sets.insert((*i, j), (outer_tmp.clone(), c.clone())); + } + } + + // let mut children = vec![]; + let _cm = cm.to_string(); + + // for ((i, j), (c, c_ref)) in mut_sets.iter() { + // let costm = _cm.clone(); + // let i = i.clone(); + // let j = j.clone(); + // let c = c.clone(); + // let c_ref = c_ref.clone(); + // children.push(thread::spawn(move || (i, j, assign_mut(&c, &costm, &c_ref)))); + // } + + // for child in children { + // let (i, j, smap) = child.join().unwrap(); + // mut_smaps.get_mut(&i).unwrap().insert(j, smap); + // } + + for ((i, j), (c, c_ref)) in mut_sets.iter() { + let costm = _cm.clone(); + mut_smaps + .get_mut(&i) + .unwrap() + .insert(*j, assign_mut(&c, &costm, &c_ref)); + } + + mut_smaps +} + +/// Mutations with multi threading +fn mutate_partitions_mp_step_smart( + dug: &DefUsesGraph, + dusg: &HashMap, + cm: &str, + outer_level: usize, + step: usize, +) -> HashMap> { + // TODO: merge and stop + let mut mut_smaps: HashMap> = HashMap::new(); + + let mut mut_sets: HashMap<(usize, usize), (DefUsesSubGraph, TermSet)> = HashMap::new(); + + for (i, du) in dusg.iter() { + mut_smaps.insert(*i, HashMap::new()); + mut_sets.insert((*i, 0), (du.clone(), du.nodes.clone())); + let mut old_du = du.clone(); + for j in 0..outer_level { + old_du = extend_dusg(&old_du, dug); + println!("Mutation {} for partition {}: {}", i, j, old_du.nodes.len()); + mut_sets.insert((*i, j), (old_du.clone(), du.nodes.clone())); + } + } + + // let mut 
children = vec![]; + let _cm = cm.to_string(); + let k_map = dug.get_k(); + + // for ((i, j), (du, du_ref)) in mut_sets.iter() { + // let costm = _cm.clone(); + // let i = i.clone(); + // let j = j.clone(); + // let du = du.clone(); + // let du_ref = du_ref.clone(); + // let k_map = k_map.clone(); + // children.push(thread::spawn(move || { + // (i, j, assign_mut_smart(&du, &costm, &du_ref, &k_map)) + // })); + // } + + // for child in children { + // let (i, j, smap) = child.join().unwrap(); + // mut_smaps.get_mut(&i).unwrap().insert(j, smap); + // } + + for ((i, j), (du, du_ref)) in mut_sets.iter() { + let costm = _cm.clone(); + mut_smaps + .get_mut(&i) + .unwrap() + .insert(*j, assign_mut_smart(&du, &costm, &du_ref, &k_map)); + } + + mut_smaps +} + +fn get_global_assignments( + cs: &Computation, + term_to_part: &TermMap, + local_smaps: &HashMap, +) -> SharingMap { + let mut global_smap: SharingMap = SharingMap::default(); + + let Computation { outputs, .. } = cs.clone(); + for term_ in &outputs { + for t in PostOrderIter::new(term_.clone()) { + // get term partition assignment + let part = term_to_part.get(&t).unwrap(); + + // get local assignment + let local_share = local_smaps.get(part).unwrap().get(&t).unwrap(); + + // TODO: mutate local assignments ilp + + // assign to global share + global_smap.insert(t.clone(), *local_share); + } + } + global_smap +} + +fn get_global_assignments_smart( + dug: &DefUsesGraph, + term_to_part: &TermMap, + local_smaps: &HashMap, +) -> SharingMap { + let mut global_smap: SharingMap = SharingMap::default(); + for t in dug.good_terms.iter() { + // get term partition assignment + let part = term_to_part.get(&t).unwrap(); + + // get local assignment + let local_share = local_smaps.get(part).unwrap().get(&t).unwrap(); + + // assign to global share + global_smap.insert(t.clone(), *local_share); + } + global_smap +} + +pub fn get_share_map_with_mutation( + cs: &Computation, + cm: &str, + partitions: &HashMap, + term_to_part: &TermMap, + mut_level: &usize, + mut_step_size: &usize, +) -> SharingMap { + let mutation_smaps = + mutate_partitions_mp_step(partitions, cm, mut_level.clone(), mut_step_size.clone()); + let selected_mut_maps = comb_selection(&mutation_smaps, &partitions, cm); + get_global_assignments(cs, term_to_part, &selected_mut_maps) +} + +pub fn get_share_map_with_mutation_smart( + dug: &DefUsesGraph, + cm: &str, + partitions: &HashMap, + term_to_part: &TermMap, + mut_level: &usize, + mut_step_size: &usize, +) -> SharingMap { + let mutation_smaps = mutate_partitions_mp_step_smart( + dug, + partitions, + cm, + mut_level.clone(), + mut_step_size.clone(), + ); + let selected_mut_maps = comb_selection_smart(dug, &mutation_smaps, &partitions, cm); + get_global_assignments_smart(dug, term_to_part, &selected_mut_maps) +} diff --git a/src/target/aby/graph/utils/part.rs b/src/target/aby/graph/utils/part.rs new file mode 100644 index 000000000..5deb83470 --- /dev/null +++ b/src/target/aby/graph/utils/part.rs @@ -0,0 +1,192 @@ +use std::collections::HashMap; +use std::env; +use std::fs::File; +use std::io::{self, BufRead}; +use std::path::Path; +use std::process::{Command, Stdio}; +// use std::time::Instant; + +pub struct Partitioner { + time_limit: usize, + imbalance: usize, + imbalance_f32: f32, + hyper_mode: bool, + kahip_source: String, + kahypar_source: String, +} + +impl Partitioner { + pub fn new(time_limit: usize, imbalance: usize, hyper_mode: bool) -> Self { + // TODO: Allow only kahip or kahypar + // Get KaHIP source directory + let kahip_source = 
match env::var("KAHIP_SOURCE") {
+            Ok(val) => val,
+            Err(e) => panic!("Missing env variable: KAHIP_SOURCE, {}", e),
+        };
+        // Get KaHyPar source directory
+        let kahypar_source = match env::var("KAHYPAR_SOURCE") {
+            Ok(val) => val,
+            Err(e) => panic!("Missing env variable: KAHYPAR_SOURCE, {}", e),
+        };
+        Self {
+            time_limit,
+            imbalance,
+            imbalance_f32: imbalance as f32 / 100.0,
+            hyper_mode,
+            kahip_source,
+            kahypar_source,
+        }
+    }
+
+    pub fn do_refinement(
+        &self,
+        graph_path: &String,
+        input_part_path: &String,
+        _output_part_path: &String,
+        num_parts: &usize,
+    ) -> HashMap<usize, usize> {
+        if self.hyper_mode {
+            let part_path = format!(
+                "{}.part{}.epsilon{}.seed-1.KaHyPar",
+                graph_path, num_parts, self.imbalance_f32
+            );
+            self.call_hyper_graph_refiner(graph_path, input_part_path, num_parts);
+            self.parse_partition(&part_path)
+        } else {
+            unimplemented!("Refinement using KaHIP not implemented");
+        }
+    }
+
+    pub fn do_partition(&self, graph_path: &String, num_parts: &usize) -> HashMap<usize, usize> {
+        if self.hyper_mode {
+            let part_path = format!(
+                "{}.part{}.epsilon{}.seed-1.KaHyPar",
+                graph_path, num_parts, self.imbalance_f32
+            );
+            self.call_hyper_graph_partitioner(graph_path, num_parts);
+            self.parse_partition(&part_path)
+        } else {
+            self.check_graph(graph_path);
+            let part_path = format!("{}.part", graph_path);
+            self.call_graph_partitioner(graph_path, &part_path, num_parts);
+            self.parse_partition(&part_path)
+        }
+    }
+
+    // Read a file line by line
+    fn read_lines<P>

(&self, filename: P) -> io::Result<io::Lines<io::BufReader<File>>>
+    where
+        P: AsRef<Path>,
+    {
+        let file = File::open(filename)?;
+        Ok(io::BufReader::new(file).lines())
+    }
+
+    fn parse_partition(&self, part_path: &String) -> HashMap<usize, usize> {
+        let mut part_map = HashMap::new();
+        if let Ok(lines) = self.read_lines(part_path) {
+            for (i, line) in lines.enumerate() {
+                if let Ok(part) = line {
+                    let part_num = part.parse::<usize>().unwrap();
+                    // node ids in the partition file are 1-based
+                    part_map.insert(i + 1, part_num);
+                }
+            }
+        }
+        part_map
+    }
+
+    // Call the hyper graph partitioning algorithm on the input hyper graph
+    fn call_hyper_graph_partitioner(&self, graph_path: &String, num_parts: &usize) {
+        let output = Command::new(format!(
+            "{}/build/kahypar/application/KaHyPar",
+            self.kahypar_source
+        ))
+        .arg("-h")
+        .arg(graph_path)
+        .arg("-k")
+        .arg(num_parts.to_string()) //TODO: make this a function on the number of terms
+        .arg("-e")
+        .arg(self.imbalance_f32.to_string())
+        .arg("--objective=cut")
+        .arg("--mode=direct")
+        .arg(format!(
+            "--preset={}/config/cut_kKaHyPar_sea20.ini",
+            self.kahypar_source
+        ))
+        .arg("--write-partition=true")
+        .stdout(Stdio::piped())
+        .output()
+        .unwrap();
+        let _stdout = String::from_utf8(output.stdout).unwrap();
+        // println!("stdout: {}", _stdout);
+    }
+
+    // Call the graph partitioning algorithm on the input graph
+    fn call_graph_partitioner(&self, graph_path: &String, part_path: &String, num_parts: &usize) {
+        let output = Command::new(format!("{}/deploy/kaffpa", self.kahip_source))
+            .arg(graph_path)
+            .arg("--k")
+            .arg(num_parts.to_string()) //TODO: make this a function on the number of terms
+            .arg("--preconfiguration=fast")
+            .arg("--imbalance")
+            .arg(self.imbalance.to_string())
+            .arg("--time_limit")
+            .arg(self.time_limit.to_string())
+            .arg(format!("--output_filename={}", part_path))
+            .stdout(Stdio::piped())
+            .output()
+            .unwrap();
+        let stdout = String::from_utf8(output.stdout).unwrap();
+        assert!(stdout.contains(&format!("writing partition to {}", part_path)));
+    }
+
+    // Call the hyper graph partitioner in refinement mode
+    fn call_hyper_graph_refiner(
+        &self,
+        graph_path: &String,
+        input_path: &String,
+        num_parts: &usize,
+    ) {
+        let input_part_arg = format!("--part-file={}", input_path);
+        let output = Command::new(format!(
+            "{}/build/kahypar/application/KaHyPar",
+            self.kahypar_source
+        ))
+        .arg("-h")
+        .arg(graph_path)
+        .arg("-k")
+        .arg(num_parts.to_string()) //TODO: make this a function on the number of terms
+        .arg("-e")
+        .arg(self.imbalance_f32.to_string())
+        .arg("--objective=cut")
+        .arg("--mode=direct")
+        .arg(format!(
+            "--preset={}/config/cut_kKaHyPar_sea20.ini",
+            self.kahypar_source
+        ))
+        .arg(input_part_arg)
+        .arg("--vcycles=3")
+        .arg("--write-partition=true")
+        .stdout(Stdio::piped())
+        .output()
+        .unwrap();
+        let _stdout = String::from_utf8(output.stdout).unwrap();
+        // println!("stdout: {}", _stdout);
+    }
+
+    // Check that the input graph is formatted correctly
+    fn check_graph(&self, graph_path: &String) {
+        let output = Command::new(format!("{}/deploy/graphchecker", self.kahip_source))
+            .arg(graph_path)
+            .stdout(Stdio::piped())
+            .output()
+            .unwrap();
+        let stdout = String::from_utf8(output.stdout).unwrap();
+        assert!(stdout.contains("The graph format seems correct."));
+    }
+}
diff --git a/src/target/aby/mod.rs b/src/target/aby/mod.rs
index 2c8197555..c83100309 100644
--- a/src/target/aby/mod.rs
+++ b/src/target/aby/mod.rs
@@ -1,4 +1,9 @@
 //!
ABY + +#[cfg(any(feature = "kahip", feature = "kahypar"))] +pub mod graph; + pub mod assignment; +pub mod call_site_similarity; pub mod trans; pub mod utils; diff --git a/src/target/aby/trans.rs b/src/target/aby/trans.rs index 660db14c8..820741b7b 100644 --- a/src/target/aby/trans.rs +++ b/src/target/aby/trans.rs @@ -13,9 +13,11 @@ use crate::target::aby::assignment::ilp::assign; use crate::target::aby::assignment::SharingMap; use crate::target::aby::utils::*; use std::collections::HashMap; +use std::collections::HashSet; use std::fs; use std::io; use std::path::Path; +use std::time::Instant; use super::assignment::assign_all_boolean; use super::assignment::assign_all_yao; @@ -24,11 +26,16 @@ use super::assignment::assign_arithmetic_and_yao; use super::assignment::assign_greedy; use super::assignment::ShareType; +#[cfg(feature = "lp")] +use crate::target::aby::graph::trans::*; + +use super::call_site_similarity::CallSiteSimilarity; + const PUBLIC: u8 = 2; const WRITE_SIZE: usize = 65536; struct ToABY<'a> { - cs: Computations, + fs: Computations, s_map: HashMap, path: &'a Path, lang: String, @@ -36,17 +43,18 @@ struct ToABY<'a> { // Input mapping inputs: Vec, // Term to share id - term_to_shares: TermMap, + term_to_shares: TermMap>, share_cnt: i32, // Cache - cache: HashMap<(Op, Vec), i32>, - // Const Cache - const_cache: HashMap>, + cache: HashMap<(Op, Vec), Vec>, // Outputs bytecode_input: Vec, bytecode_output: Vec, const_output: Vec, share_output: Vec, + term_share_output: Vec, + const_map: HashMap<(Integer, char), i32>, + written_const_set: HashSet, } impl Drop for ToABY<'_> { @@ -63,40 +71,32 @@ impl Drop for ToABY<'_> { } impl<'a> ToABY<'a> { - fn new( - cs: Computations, - s_map: HashMap, - path: &'a Path, - lang: &str, - ) -> Self { + fn new(fs: Computations, s_map: HashMap, path: &'a Path, lang: &str) -> Self { Self { - cs, + fs, s_map, path, lang: lang.to_string(), curr_comp: "".to_string(), inputs: Vec::new(), term_to_shares: TermMap::default(), - share_cnt: 0, + // 0 is used for not used terms + share_cnt: 1, cache: HashMap::new(), - const_cache: HashMap::new(), bytecode_input: Vec::new(), bytecode_output: Vec::new(), - const_output: Vec::new(), - share_output: Vec::new(), + const_output: vec!["2 1 0 32 0 CONS\n".to_string()], + share_output: vec!["0 a\n".to_string()], + term_share_output: Vec::new(), + const_map: HashMap::new(), + written_const_set: HashSet::new(), } } fn write_const_output(&mut self, flush: bool) { if flush || self.const_output.len() >= WRITE_SIZE { let const_output_path = get_path(self.path, &self.lang, "const", false); - let mut lines = self - .const_output - .clone() - .into_iter() - .collect::>(); - lines.dedup(); - write_lines(&const_output_path, &lines); + write_lines(&const_output_path, &self.const_output); self.const_output.clear(); } } @@ -117,14 +117,11 @@ impl<'a> ToABY<'a> { fn write_share_output(&mut self, flush: bool) { if flush || self.share_output.len() >= WRITE_SIZE { let share_output_path = get_path(self.path, &self.lang, "share_map", false); - let mut lines = self - .share_output - .clone() - .into_iter() - .collect::>(); - lines.dedup(); - write_lines(&share_output_path, &lines); + write_lines(&share_output_path, &self.share_output); self.share_output.clear(); + let term_share_output_path = get_path(self.path, &self.lang, "term_share_map", false); + write_lines(&term_share_output_path, &self.term_share_output); + self.term_share_output.clear(); } } @@ -137,103 +134,192 @@ impl<'a> ToABY<'a> { } fn get_md(&self) -> &ComputationMetadata { - 
&self.cs.comps.get(&self.curr_comp).unwrap().metadata + &self.fs.comps.get(&self.curr_comp).unwrap().metadata } fn get_term_share_type(&self, t: &Term) -> ShareType { let s_map = self.s_map.get(&self.curr_comp).unwrap(); - *s_map.get(t).unwrap() + if let Some(s) = s_map.get(&t) { + *s + } else { + if let Op::Const(_) = t.op(){ + ShareType::Arithmetic + } else{ + ShareType::None + } + + } } - fn insert_const(&mut self, t: &Term) { - if !self.const_cache.contains_key(t) { - let mut const_map: HashMap = HashMap::new(); + fn write_share(&mut self, t: &Term, s: i32) { + if !self.written_const_set.contains(&s) { + let share_type = self.get_term_share_type(t).char(); + let line = format!("{} {}\n", s, share_type); + self.share_output.push(line); + match t.op() { + Op::Var(..) | Op::Call(..) => {} + _ => { + let line2 = format!("{} {}\n", t.op(), share_type); + } + } + } + } - // a type - let s_a = self.share_cnt; - const_map.insert(ShareType::Arithmetic, s_a); - self.share_cnt += 1; + fn write_shares(&mut self, t: &Term, shares: &Vec) { + let share_type = self.get_term_share_type(t).char(); + for s in shares { + if !self.written_const_set.contains(s) { + let line = format!("{} {}\n", s, share_type); + self.share_output.push(line); + let line2 = format!("{} {}\n", t.op(), share_type); + self.term_share_output.push(line2); + } + } + } - // b type - let s_b = self.share_cnt; - const_map.insert(ShareType::Boolean, s_b); - self.share_cnt += 1; + // TODO: Rust ENTRY api on maps + fn get_share(&mut self, t: &Term) -> i32 { + match self.term_to_shares.get(t) { + Some(v) => { + assert!(v.len() == 1); + v[0] + } + None => { + match &t.op() { + Op::Const(Value::BitVector(b)) => { + let bi = b.as_sint(); + let share_type = self.get_term_share_type(t).char(); + let key = (bi, share_type); + if self.const_map.contains_key(&key) { + let s = self.const_map.get(&key).unwrap().clone(); + self.term_to_shares.insert(t.clone(), [s].to_vec()); + s + } else { + let s = self.share_cnt; + self.term_to_shares.insert(t.clone(), [s].to_vec()); + self.share_cnt += 1; + self.const_map.insert(key, s); + // Write share + self.write_share(t, s); + + s + } + } + _ => { + let s = self.share_cnt; + self.term_to_shares.insert(t.clone(), [s].to_vec()); + self.share_cnt += 1; - // y type - let s_y = self.share_cnt; - const_map.insert(ShareType::Yao, s_y); - self.share_cnt += 1; + // Write share + self.write_share(t, s); - self.const_cache.insert(t.clone(), const_map); + s + } + } + } } } - fn output_const_share(&mut self, t: &Term, to_share_type: ShareType) -> i32 { - if self.const_cache.contains_key(t) { - let output_share = *self - .const_cache - .get(t) - .unwrap() - .get(&to_share_type) - .unwrap(); - let op = "CONS"; + fn get_shares(&mut self, t: &Term) -> Vec { + match self.term_to_shares.get(t) { + Some(v) => v.clone(), + None => { + match &t.op() { + Op::Const(Value::Array(arr)) => { + let sort = check(t); + let num_shares = self.get_sort_len(&sort) as i32; + let mut shares: Vec = Vec::new(); + let share_type = self.get_term_share_type(t).char(); + for i in 0..num_shares { + let idx = Value::BitVector(BitVector::new(Integer::from(i), 32)); + let v = match arr.map.get(&idx) { + Some(c) => c, + + None => &*arr.default, + }; + + match v { + Value::BitVector(b) => { + let bi = b.as_sint(); + let key = (bi, share_type); + if self.const_map.contains_key(&key) { + let s = self.const_map.get(&key).unwrap().clone(); + shares.push(s); + } else { + let s = self.share_cnt; + self.share_cnt += 1; + self.const_map.insert(key, s); + 
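+                                    // (exposition) `s` was freshly allocated
+                                    // and memoized by (value, share-type)
+                                    // just above; any later term mentioning
+                                    // the same constant under the same share
+                                    // type hits the `contains_key` branch and
+                                    // reuses this id.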
shares.push(s); + } + } + _ => todo!(), + } + } + self.term_to_shares.insert(t.clone(), shares.clone()); - match &t.op() { - Op::Const(Value::BitVector(b)) => { - let value = b.as_sint(); - let bitlen = 32; - let line = format!("2 1 {value} {bitlen} {output_share} {op}\n"); - self.const_output.push(line); - } - Op::Const(Value::Bool(b)) => { - let value = *b as i32; - let bitlen = 1; - let line = format!("2 1 {value} {bitlen} {output_share} {op}\n"); - self.const_output.push(line); - } - _ => todo!(), - }; + // Write shares + self.write_shares(t, &shares); - // Add to share map - let line = format!("{} {}\n", output_share, to_share_type.char()); - self.share_output.push(line); + shares + } + Op::Const(Value::Tuple(tup)) => { + check(t); + let mut shares: Vec = Vec::new(); + let share_type = self.get_term_share_type(t).char(); + for val in tup.iter() { + match val { + Value::BitVector(b) => { + let bi = b.as_sint(); + let key = (bi, share_type); + if self.const_map.contains_key(&key) { + let s = self.const_map.get(&key).unwrap().clone(); + shares.push(s); + } else { + let s = self.share_cnt; + self.share_cnt += 1; + self.const_map.insert(key, s); + shares.push(s); + } + } + _ => todo!(), + } + } + self.term_to_shares.insert(t.clone(), shares.clone()); - output_share - } else { - panic!("const cache does not contain term: {}", t); - } - } + // Write shares + self.write_shares(t, &shares); - fn write_share(&mut self, t: &Term, s: i32) { - let share_type = self.get_term_share_type(t).char(); - let line = format!("{s} {share_type}\n"); - self.share_output.push(line); - } + shares + } + _ => { + let sort = check(t); + let num_shares = self.get_sort_len(&sort) as i32; - // TODO: Rust ENTRY api on maps - fn get_share(&mut self, t: &Term, to_share_type: ShareType) -> i32 { - if t.is_const() && check(t).is_scalar() { - self.output_const_share(t, to_share_type) - } else { - match self.term_to_shares.get(t) { - Some(v) => *v, - None => { - let s = self.share_cnt; - self.term_to_shares.insert(t.clone(), s); - self.share_cnt += 1; + let shares: Vec = (0..num_shares) + .map(|x| x + self.share_cnt) + .collect::>(); + self.term_to_shares.insert(t.clone(), shares.clone()); + + // Write shares + self.write_shares(t, &shares); - // Write share - self.write_share(t, s); + self.share_cnt += num_shares; - s + shares + } } } } } - // clippy doesn't like that self is only used in recursion - // allowing so this can remain an associated function - #[allow(clippy::only_used_in_recursion)] + fn rewirable(&self, s: &Sort) -> bool { + match s { + Sort::Array(..) => true, + Sort::Bool | Sort::BitVector(..) | Sort::Tuple(..) 
=> false, + _ => todo!(), + } + } + fn get_sort_len(&mut self, s: &Sort) -> usize { let mut len = 0; len += match s { @@ -262,92 +348,91 @@ impl<'a> ToABY<'a> { fn embed_eq(&mut self, t: &Term) { let op = "EQ"; - let to_share_type = self.get_term_share_type(t); - let a = self.get_share(&t.cs()[0], to_share_type); - let b = self.get_share(&t.cs()[1], to_share_type); + let a = self.get_share(&t.cs()[0]); + let b = self.get_share(&t.cs()[1]); + let key = (t.op().clone(), vec![a, b]); - let s = self.get_share(t, to_share_type); - if let std::collections::hash_map::Entry::Vacant(e) = self.cache.entry(key.clone()) { - e.insert(s); - let line = format!("2 1 {a} {b} {s} {op}\n"); - self.bytecode_output.push(line); - } else { - let s = *self.cache.get(&key).unwrap(); + if self.cache.contains_key(&key) { + let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t.clone(), s); + } else { + let s = self.get_shares(t); + self.cache.insert(key, s.clone()); + let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op); + self.bytecode_output.push(line); }; } fn embed_bool(&mut self, t: Term) { - let to_share_type = self.get_term_share_type(&t); + let s = self.get_share(&t); match &t.op() { Op::Var(name, Sort::Bool) => { let md = self.get_md(); if !self.inputs.contains(&t) && md.is_input(name) { let vis = self.unwrap_vis(name); - let s = self.get_share(&t, to_share_type); + let s = self.get_share(&t); let op = "IN"; if vis == PUBLIC { let bitlen = 1; - let line = format!("3 1 {name} {vis} {bitlen} {s} {op}\n"); + let line = format!("3 1 {} {} {} {} {}\n", name, vis, bitlen, s, op); self.bytecode_input.push(line); } else { - let line = format!("2 1 {name} {vis} {s} {op}\n"); + let line = format!("2 1 {} {} {} {}\n", name, vis, s, op); self.bytecode_input.push(line); } self.inputs.push(t.clone()); } } - Op::Const(_) => { - self.insert_const(&t); + Op::Const(Value::Bool(b)) => { + let op = "CONS"; + let line = format!("2 1 {} 1 {} {}\n", *b as i32, s, op); + self.const_output.push(line); } Op::Eq => { self.embed_eq(&t); } Op::Ite => { let op = "MUX"; - let to_share_type = self.get_term_share_type(&t); - let sel = self.get_share(&t.cs()[0], to_share_type); - let a = self.get_share(&t.cs()[1], to_share_type); - let b = self.get_share(&t.cs()[2], to_share_type); + let sel = self.get_share(&t.cs()[0]); + let a = self.get_share(&t.cs()[1]); + let b = self.get_share(&t.cs()[2]); let key = (t.op().clone(), vec![a, b]); - let s = self.get_share(&t, to_share_type); - if let std::collections::hash_map::Entry::Vacant(e) = self.cache.entry(key.clone()) - { - e.insert(s); - let line = format!("3 1 {sel} {a} {b} {s} {op}\n"); - self.bytecode_output.push(line); - } else { - let s = *self.cache.get(&key).unwrap(); + if self.cache.contains_key(&key) { + let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t.clone(), s); + } else { + let s = self.get_shares(&t); + self.cache.insert(key, s.clone()); + let line = format!("3 1 {} {} {} {} {}\n", sel, a, b, s[0], op); + self.bytecode_output.push(line); }; } Op::Not => { let op = "NOT"; - let a = self.get_share(&t.cs()[0], to_share_type); + let a = self.get_share(&t.cs()[0]); let key = (t.op().clone(), vec![a]); - let s = self.get_share(&t, to_share_type); - if let std::collections::hash_map::Entry::Vacant(e) = self.cache.entry(key.clone()) - { - e.insert(s); - let line = format!("1 1 {a} {s} {op}\n"); - self.bytecode_output.push(line); - } else { - let s = *self.cache.get(&key).unwrap(); + if self.cache.contains_key(&key) { + let s = 
self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t.clone(), s); + } else { + let s = self.get_shares(&t); + self.cache.insert(key, s.clone()); + let line = format!("1 1 {} {} {}\n", a, s[0], op); + self.bytecode_output.push(line); }; } Op::BoolNaryOp(o) => { if t.cs().len() == 1 { // HACK: Conditionals might not contain two variables - // If t.cs() len is 1, just output that term + // If t.cs len is 1, just output that term // This is to bypass adding an AND gate with a single conditional term // Refer to pub fn condition() in src/circify/mod.rs - let a = self.get_share(&t.cs()[0], to_share_type); + let a = self.get_share(&t.cs()[0]); match o { - BoolNaryOp::And => self.term_to_shares.insert(t.clone(), a), + BoolNaryOp::And => self.term_to_shares.insert(t.clone(), vec![a]), _ => { unimplemented!("Single operand boolean operation"); } @@ -359,20 +444,18 @@ impl<'a> ToABY<'a> { BoolNaryOp::Xor => "XOR", }; - let a = self.get_share(&t.cs()[0], to_share_type); - let b = self.get_share(&t.cs()[1], to_share_type); + let a = self.get_share(&t.cs()[0]); + let b = self.get_share(&t.cs()[1]); let key = (t.op().clone(), vec![a, b]); - let s = self.get_share(&t, to_share_type); - if let std::collections::hash_map::Entry::Vacant(e) = - self.cache.entry(key.clone()) - { - e.insert(s); - let line = format!("2 1 {a} {b} {s} {op}\n"); - self.bytecode_output.push(line); - } else { - let s = *self.cache.get(&key).unwrap(); + if self.cache.contains_key(&key) { + let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t.clone(), s); + } else { + let s = self.get_shares(&t); + self.cache.insert(key, s.clone()); + let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op); + self.bytecode_output.push(line); }; } } @@ -385,19 +468,18 @@ impl<'a> ToABY<'a> { _ => panic!("Non-field in bool BvBinPred: {}", o), }; - let a = self.get_share(&t.cs()[0], to_share_type); - let b = self.get_share(&t.cs()[1], to_share_type); + let a = self.get_share(&t.cs()[0]); + let b = self.get_share(&t.cs()[1]); let key = (t.op().clone(), vec![a, b]); - let s = self.get_share(&t, to_share_type); - if let std::collections::hash_map::Entry::Vacant(e) = self.cache.entry(key.clone()) - { - e.insert(s); - let line = format!("2 1 {a} {b} {s} {op}\n"); - self.bytecode_output.push(line); - } else { - let s = *self.cache.get(&key).unwrap(); + if self.cache.contains_key(&key) { + let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t.clone(), s); + } else { + let s = self.get_shares(&t); + self.cache.insert(key, s.clone()); + let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op); + self.bytecode_output.push(line); }; } _ => panic!("Non-field in embed_bool: {}", t), @@ -405,46 +487,50 @@ impl<'a> ToABY<'a> { } fn embed_bv(&mut self, t: Term) { - let to_share_type = self.get_term_share_type(&t); match &t.op() { Op::Var(name, Sort::BitVector(_)) => { let md = self.get_md(); if !self.inputs.contains(&t) && md.is_input(name) { let vis = self.unwrap_vis(name); - let s = self.get_share(&t, to_share_type); + let s = self.get_share(&t); let op = "IN"; if vis == PUBLIC { let bitlen = 32; - let line = format!("3 1 {name} {vis} {bitlen} {s} {op}\n"); + let line = format!("3 1 {} {} {} {} {}\n", name, vis, bitlen, s, op); self.bytecode_input.push(line); } else { - let line = format!("2 1 {name} {vis} {s} {op}\n"); + let line = format!("2 1 {} {} {} {}\n", name, vis, s, op); self.bytecode_input.push(line); } self.inputs.push(t.clone()); } } - Op::Const(Value::BitVector(_)) => { - // create all three shares 
    fn embed_bv(&mut self, t: Term) {
-        let to_share_type = self.get_term_share_type(&t);
         match &t.op() {
             Op::Var(name, Sort::BitVector(_)) => {
                 let md = self.get_md();
                 if !self.inputs.contains(&t) && md.is_input(name) {
                     let vis = self.unwrap_vis(name);
-                    let s = self.get_share(&t, to_share_type);
+                    let s = self.get_share(&t);
                     let op = "IN";
                     if vis == PUBLIC {
                         let bitlen = 32;
-                        let line = format!("3 1 {name} {vis} {bitlen} {s} {op}\n");
+                        let line = format!("3 1 {} {} {} {} {}\n", name, vis, bitlen, s, op);
                         self.bytecode_input.push(line);
                     } else {
-                        let line = format!("2 1 {name} {vis} {s} {op}\n");
+                        let line = format!("2 1 {} {} {} {}\n", name, vis, s, op);
                         self.bytecode_input.push(line);
                     }
                     self.inputs.push(t.clone());
                 }
             }
-            Op::Const(Value::BitVector(_)) => {
-                // create all three shares
-                self.insert_const(&t);
+            Op::Const(Value::BitVector(b)) => {
+                let s = self.get_share(&t);
+                if !self.written_const_set.contains(&s) {
+                    self.written_const_set.insert(s);
+                    let op = "CONS";
+                    let line = format!("2 1 {} 32 {} {}\n", b.as_sint(), s, op);
+                    self.const_output.push(line);
+                }
+                // self.cache.insert(t.clone(), EmbeddedTerm::Bv);
             }
             Op::Ite => {
                 let op = "MUX";
-                let sel = self.get_share(&t.cs()[0], to_share_type);
-                let a = self.get_share(&t.cs()[1], to_share_type);
-                let b = self.get_share(&t.cs()[2], to_share_type);
+                let sel = self.get_share(&t.cs()[0]);
+                let a = self.get_share(&t.cs()[1]);
+                let b = self.get_share(&t.cs()[2]);
                 let key = (t.op().clone(), vec![sel, a, b]);
-                let s = self.get_share(&t, to_share_type);
-                if let std::collections::hash_map::Entry::Vacant(e) = self.cache.entry(key.clone())
-                {
-                    e.insert(s);
-                    let line = format!("3 1 {sel} {a} {b} {s} {op}\n");
-                    self.bytecode_output.push(line);
-                } else {
-                    let s = *self.cache.get(&key).unwrap();
+                if self.cache.contains_key(&key) {
+                    let s = self.cache.get(&key).unwrap().clone();
                     self.term_to_shares.insert(t.clone(), s);
+                } else {
+                    let s = self.get_shares(&t);
+                    self.cache.insert(key, s.clone());
+                    let line = format!("3 1 {} {} {} {} {}\n", sel, a, b, s[0], op);
+                    self.bytecode_output.push(line);
                 };
             }
             Op::BvNaryOp(o) => {
@@ -455,19 +541,18 @@ impl<'a> ToABY<'a> {
                     BvNaryOp::Add => "ADD",
                     BvNaryOp::Mul => "MUL",
                 };
-                let a = self.get_share(&t.cs()[0], to_share_type);
-                let b = self.get_share(&t.cs()[1], to_share_type);
+                let a = self.get_share(&t.cs()[0]);
+                let b = self.get_share(&t.cs()[1]);
                 let key = (t.op().clone(), vec![a, b]);
-                let s = self.get_share(&t, to_share_type);
-                if let std::collections::hash_map::Entry::Vacant(e) = self.cache.entry(key.clone())
-                {
-                    e.insert(s);
-                    let line = format!("2 1 {a} {b} {s} {op}\n");
-                    self.bytecode_output.push(line);
-                } else {
-                    let s = *self.cache.get(&key).unwrap();
+                if self.cache.contains_key(&key) {
+                    let s = self.cache.get(&key).unwrap().clone();
                     self.term_to_shares.insert(t.clone(), s);
+                } else {
+                    let s = self.get_shares(&t);
+                    self.cache.insert(key, s.clone());
+                    let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op);
+                    self.bytecode_output.push(line);
                 };
             }
             Op::BvBinOp(o) => {
@@ -482,42 +567,36 @@ impl<'a> ToABY<'a> {
                 match o {
                     BvBinOp::Sub | BvBinOp::Udiv | BvBinOp::Urem => {
-                        let a = self.get_share(&t.cs()[0], to_share_type);
-                        let b = self.get_share(&t.cs()[1], to_share_type);
+                        let a = self.get_share(&t.cs()[0]);
+                        let b = self.get_share(&t.cs()[1]);
                         let key = (t.op().clone(), vec![a, b]);
-                        let s = self.get_share(&t, to_share_type);
-                        if let std::collections::hash_map::Entry::Vacant(e) =
-                            self.cache.entry(key.clone())
-                        {
-                            e.insert(s);
-                            let line = format!("2 1 {a} {b} {s} {op}\n");
-                            self.bytecode_output.push(line);
-                        } else {
-                            let s = *self.cache.get(&key).unwrap();
+                        if self.cache.contains_key(&key) {
+                            let s = self.cache.get(&key).unwrap().clone();
                             self.term_to_shares.insert(t, s);
+                        } else {
+                            let s = self.get_shares(&t);
+                            self.cache.insert(key, s.clone());
+                            let line = format!("2 1 {} {} {} {}\n", a, b, s[0], op);
+                            self.bytecode_output.push(line);
                         };
                     }
                     BvBinOp::Shl | BvBinOp::Lshr => {
-                        let a = self.get_share(&t.cs()[0], to_share_type);
+                        let a = self.get_share(&t.cs()[0]);
                         let const_shift_amount_term = fold(&t.cs()[1], &[]);
                         let const_shift_amount =
                             const_shift_amount_term.as_bv_opt().unwrap().uint();
-                        let key = (
-                            t.op().clone(),
-                            vec![a, const_shift_amount.to_i32().unwrap()],
-                        );
-                        let s = self.get_share(&t, to_share_type);
-                        if
let std::collections::hash_map::Entry::Vacant(e) = - self.cache.entry(key.clone()) - { - e.insert(s); - let line = format!("2 1 {a} {const_shift_amount} {s} {op}\n"); - self.bytecode_output.push(line); - } else { - let s = *self.cache.get(&key).unwrap(); + let key = (t.op().clone(), vec![a, const_shift_amount.to_i32().unwrap()]); + if self.cache.contains_key(&key) { + let s = self.cache.get(&key).unwrap().clone(); self.term_to_shares.insert(t, s); + } else { + let s = self.get_shares(&t); + self.cache.insert(key, s.clone()); + let line = + format!("2 1 {} {} {} {}\n", a, const_shift_amount, s[0], op); + self.bytecode_output.push(line); }; } _ => panic!("Binop not supported: {}", o), @@ -525,42 +604,51 @@ impl<'a> ToABY<'a> { } Op::Field(i) => { assert!(t.cs().len() == 1); - let tuple_share = self.get_share(&t.cs()[0], to_share_type); - let field_share = self.get_share(&t, to_share_type); - let op = "FIELD"; - let line = format!("2 1 {tuple_share} {i} {field_share} {op}\n"); - self.bytecode_output.push(line); - self.term_to_shares.insert(t.clone(), field_share); + let shares = self.get_shares(&t.cs()[0]); + assert!(*i < shares.len()); + self.term_to_shares.insert(t.clone(), vec![shares[*i]]); } Op::Select => { assert!(t.cs().len() == 2); - let select_share = self.get_share(&t, to_share_type); - let array_share = self.get_share(&t.cs()[0], to_share_type); - - let line = if let Op::Const(Value::BitVector(bv)) = &t.cs()[1].op() { - let op = "SELECT_CONS"; - let idx = bv.uint().to_usize().unwrap(); - let len = self.get_sort_len(&check(&t.cs()[0])); - assert!(idx < len, "{}", "idx: {idx}, len: {len}"); - format!("2 1 {array_share} {idx} {select_share} {op}\n") + let array_shares = self.get_shares(&t.cs()[0]); + + if let Op::Const(Value::BitVector(bv)) = &t.cs()[1].op() { + let idx = bv.uint().to_usize().unwrap().clone(); + assert!( + idx < array_shares.len(), + "idx: {}, shares: {}", + idx, + array_shares.len() + ); + + self.term_to_shares + .insert(t.clone(), vec![array_shares[idx]]); } else { let op = "SELECT"; - let idx_share = self.get_share(&t.cs()[1], to_share_type); - format!("2 1 {array_share} {idx_share} {select_share} {op}\n",) - }; - self.bytecode_output.push(line); - self.term_to_shares.insert(t.clone(), select_share); + let num_inputs = array_shares.len() + 1; + let index_share = self.get_share(&t.cs()[1]); + let output = self.get_share(&t); + let line = format!( + "{} 1 {} {} {} {}\n", + num_inputs, + self.shares_to_string(array_shares), + index_share, + output, + op + ); + self.bytecode_output.push(line); + self.term_to_shares.insert(t.clone(), vec![output]); + } } _ => panic!("Non-field in embed_bv: {:?}", t), } } fn embed_vector(&mut self, t: Term) { - let to_share_type = self.get_term_share_type(&t); match &t.op() { Op::Const(Value::Array(arr)) => { - let array_share = self.get_share(&t, to_share_type); let mut shares: Vec = Vec::new(); + for i in 0..arr.size { // TODO: sort of index might not be a 32-bit bitvector let idx = Value::BitVector(BitVector::new(Integer::from(i), 32)); @@ -571,105 +659,117 @@ impl<'a> ToABY<'a> { // TODO: sort of value might not be a 32-bit bitvector let v_term = leaf_term(Op::Const(v.clone())); - if self.const_cache.contains_key(&v_term) { + if self.term_to_shares.contains_key(&v_term) { // existing const - let s = self.get_share(&v_term, to_share_type); + let s = self.get_share(&v_term); shares.push(s); } else { // new const - self.insert_const(&v_term); - let s = self.get_share(&v_term, to_share_type); + let s_map = 
self.s_map.get(&self.curr_comp).unwrap(); + if !s_map.contains_key(&v_term) { + shares.push(0); + continue; + } + let s = self.get_share(&v_term); + match v { + Value::BitVector(b) => { + if !self.written_const_set.contains(&s) { + self.written_const_set.insert(s); + let op = "CONS"; + let line = format!("2 1 {} 32 {} {}\n", b.as_sint(), s, op); + self.const_output.push(line); + } + // self.cache.insert(t.clone(), EmbeddedTerm::Bv); + } + _ => todo!(), + } shares.push(s); } } - assert!(shares.len() == arr.size); - let op = "CONS_ARRAY"; - let line = format!( - "{} 1 {} {} {}\n", - arr.size, - self.shares_to_string(shares), - array_share, - op - ); - self.const_output.push(line); - self.term_to_shares.insert(t.clone(), array_share); + // println!("shares: {:?}", shares); + // println!("arr size: {}", arr.size); + assert!(shares.len() == arr.size); + self.term_to_shares.insert(t.clone(), shares); } Op::Const(Value::Tuple(tup)) => { - let tuple_share = self.get_share(&t, to_share_type); - let mut shares: Vec = Vec::new(); - for val in tup.iter() { + let shares = self.get_shares(&t); + assert!(shares.len() == tup.len()); + for (val, s) in tup.iter().zip(shares.iter()) { match val { Value::BitVector(b) => { - let v_term: Term = bv_lit(b.as_sint(), 32); - if self.const_cache.contains_key(&v_term) { - // existing const - let s = self.get_share(&v_term, to_share_type); - shares.push(s); - } else { - // new const - self.insert_const(&v_term); - let s = self.get_share(&v_term, to_share_type); - shares.push(s); + if !self.written_const_set.contains(s) { + self.written_const_set.insert(*s); + let op = "CONS"; + let line = format!("2 1 {} 32 {} {}\n", b.as_sint(), s, op); + self.const_output.push(line); } } _ => todo!(), } } - assert!(shares.len() == tup.len()); + } + Op::Ite => { + let op = "MUX"; + let shares = self.get_shares(&t); + + let sel = self.get_share(&t.cs()[0]); + let a = self.get_shares(&t.cs()[1]); + let b = self.get_shares(&t.cs()[2]); + + // assert scalar_term share lens are equivalent + assert!(shares.len() == a.len()); + assert!(shares.len() == b.len()); + + let num_inputs = 1 + shares.len() * 2; + let num_outputs = shares.len(); - let op = "CONS_TUPLE"; let line = format!( - "{} 1 {} {} {}\n", - tup.len(), - self.shares_to_string(shares.clone()), - tuple_share, + "{} {} {} {} {} {} {}\n", + num_inputs, + num_outputs, + sel, + self.shares_to_string(a), + self.shares_to_string(b), + self.shares_to_string(shares), op ); - self.const_output.push(line); - self.term_to_shares.insert(t.clone(), tuple_share); - } - Op::Ite => { - let op = "MUX"; - let mux_share = self.get_share(&t, to_share_type); - let sel = self.get_share(&t.cs()[0], to_share_type); - let a = self.get_share(&t.cs()[1], to_share_type); - let b = self.get_share(&t.cs()[2], to_share_type); - let line = format!("3 1 {sel} {a} {b} {mux_share} {op}\n"); self.bytecode_output.push(line); - self.term_to_shares.insert(t.clone(), mux_share); } Op::Store => { assert!(t.cs().len() == 3); - - let array_share = self.get_share(&t.cs()[0], to_share_type); - // let mut array_shares = self.get_shares(&t.cs()[0], to_share_type).clone(); - let value_share = self.get_share(&t.cs()[2], to_share_type); - let store_share = self.get_share(&t, to_share_type); - - let line = if let Op::Const(Value::BitVector(bv)) = &t.cs()[1].op() { - let op = "STORE_CONS"; - let idx = bv.uint().to_usize().unwrap(); - let len = self.get_sort_len(&check(&t.cs()[0])); - assert!(idx < len, "{}", "idx: {idx}, len: {len}"); - format!("3 1 {array_share} {idx} 
{value_share} {store_share} {op}\n",)
+                let mut array_shares = self.get_shares(&t.cs()[0]).clone();
+                let value_share = self.get_share(&t.cs()[2]);
+
+                if let Op::Const(Value::BitVector(bv)) = &t.cs()[1].op() {
+                    // constant indexing
+                    let idx = bv.uint().to_usize().unwrap();
+                    array_shares[idx] = value_share;
+                    self.term_to_shares.insert(t.clone(), array_shares.clone());
                 } else {
                     let op = "STORE";
-                    let index_share = self.get_share(&t.cs()[1], to_share_type);
-                    format!("3 1 {array_share} {index_share} {value_share} {store_share} {op}\n",)
-                };
-                self.bytecode_output.push(line);
-                self.term_to_shares.insert(t.clone(), store_share);
+                    let num_inputs = array_shares.len() + 2;
+                    let outputs = self.get_shares(&t);
+                    let num_outputs = outputs.len();
+                    let index_share = self.get_share(&t.cs()[1]);
+                    let line = format!(
+                        "{} {} {} {} {} {} {}\n",
+                        num_inputs,
+                        num_outputs,
+                        self.shares_to_string(array_shares),
+                        index_share,
+                        value_share,
+                        self.shares_to_string(outputs),
+                        op
+                    );
+
+                    self.bytecode_output.push(line);
+                }
             }
             Op::Field(i) => {
                 assert!(t.cs().len() == 1);
-
-                // let shares = self.get_shares(&t.cs()[0], to_share_type);
-                let tuple_share = self.get_share(&t.cs()[0], to_share_type);
-                let field_share = self.get_share(&t, to_share_type);
-
-                let op = "FIELD_VEC";
+                let shares = self.get_shares(&t.cs()[0]);

                 let tuple_sort = check(&t.cs()[0]);
                 let (offset, len) = match tuple_sort {
@@ -690,59 +790,69 @@ impl<'a> ToABY<'a> {
                     _ => panic!("Field op on non-tuple"),
                 };

-                let line = format!("3 1 {tuple_share} {offset} {len} {field_share} {op}\n");
-                self.bytecode_output.push(line);
-                self.term_to_shares.insert(t.clone(), field_share);
+                // get ret slice
+                let field_shares = &shares[offset..offset + len];
+
+                self.term_to_shares.insert(t.clone(), field_shares.to_vec());
             }
             Op::Update(i) => {
                 assert!(t.cs().len() == 2);
+                let mut tuple_shares = self.get_shares(&t.cs()[0]);
+                let value_share = self.get_share(&t.cs()[1]);

-                let tuple_share = self.get_share(&t.cs()[0], to_share_type);
-                let value_share = self.get_share(&t.cs()[1], to_share_type);
-                let update_share = self.get_share(&t, to_share_type);
+                // assert the index is in bounds
+                assert!(*i < tuple_shares.len());

-                let op = "UPDATE";
-                let line = format!("3 1 {tuple_share} {i} {value_share} {update_share} {op}\n",);
-                self.bytecode_output.push(line);
-                self.term_to_shares.insert(t.clone(), update_share);
+                // update shares in tuple
+                tuple_shares[*i] = value_share;
+
+                // store shares
+                self.term_to_shares.insert(t.clone(), tuple_shares);
             }
             Op::Tuple => {
-                let tuple_share = self.get_share(&t, to_share_type);
-
                 let mut shares: Vec<i32> = Vec::new();
                 for c in t.cs().iter() {
-                    shares.push(self.get_share(c, to_share_type));
+                    shares.append(&mut self.get_shares(c));
                 }
-
-                let op = "TUPLE";
-                let line = format!(
-                    "{} 1 {} {} {}\n",
-                    t.cs().len(),
-                    self.shares_to_string(shares.clone()),
-                    tuple_share,
-                    op
-                );
-                self.bytecode_output.push(line);
-                self.term_to_shares.insert(t.clone(), tuple_share);
+                self.term_to_shares.insert(t.clone(), shares);
             }
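+            // NOTE (illustrative): tuples and arrays are flattened into vectors of
+            // scalar shares, so Field/Update and constant-index Select/Store above
+            // are pure rewiring: they slice or patch the share vector and emit no
+            // gate. Only data-dependent indexing (the SELECT/STORE lines) reaches ABY.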
-            Op::Call(name, ..) => {
-                let call_share = self.get_share(&t, to_share_type);
-                let op = format!("CALL({name})");
-
-                let mut arg_shares: Vec<i32> = Vec::new();
+            Op::Call(name, arg_sorts, ret_sort) => {
+                let shares = self.get_shares(&t);
+                let op = format!("CALL({})", name);
+                let num_args: usize = arg_sorts.iter().map(|ret| self.get_sort_len(ret)).sum();
+                let num_rets: usize = self.get_sort_len(ret_sort);
+                let mut arg_shares: Vec<String> = Vec::new();
                 for c in t.cs().iter() {
-                    arg_shares.push(self.get_share(c, to_share_type));
+                    let sort = check(c);
+                    if self.rewirable(&sort) {
+                        arg_shares.extend(self.get_shares(c).iter().map(|&s| s.to_string()))
+                    } else {
+                        arg_shares.extend(self.get_shares(c).iter().map(|&s| s.to_string()))
+                    }
                 }
+                let mut ret_shares: Vec<String> = Vec::new();
+                let mut idx = 0;
+
+                // TODO: Optimize this after correctness
+                let len = self.get_sort_len(ret_sort);
+                assert!(idx + len <= shares.len());
+                if self.rewirable(ret_sort) {
+                    ret_shares.extend(shares[idx..(idx + len)].iter().map(|&s| s.to_string()))
+                } else {
+                    ret_shares.extend(shares[idx..(idx + len)].iter().map(|&s| s.to_string()))
+                }
+                idx += len;
+
                 let line = format!(
-                    "{} 1 {} {} {}\n",
-                    arg_shares.len(),
-                    self.shares_to_string(arg_shares),
-                    call_share,
+                    "{} {} {} {} {}\n",
+                    num_args,
+                    num_rets,
+                    arg_shares.join(" "),
+                    ret_shares.join(" "),
                     op
                 );
                 self.bytecode_output.push(line);
-                self.term_to_shares.insert(t.clone(), call_share);
             }
             _ => {
                 panic!("Non-field in embed_vector: {}", t.op())
@@ -775,7 +885,16 @@ impl<'a> ToABY<'a> {

     /// Given a term `t`, lower `t` to ABY Circuits
     fn lower(&mut self) {
-        let computations = self.cs.comps.clone();
+        let now = Instant::now();
+
+        let computations = self.fs.comps.clone();
+
+        // for (name, c) in computations.iter() {
+        //     println!("name: {}", name);
+        //     for t in c.terms_postorder() {
+        //         println!("t: {}", t);
+        //     }
+        // }

         // create output files
         get_path(self.path, &self.lang, "const", true);
@@ -791,31 +910,36 @@ impl<'a> ToABY<'a> {
             get_path(
                 self.path,
                 &self.lang,
-                &format!("{name}_bytecode_output"),
+                &format!("{}_bytecode_output", name),
                 true,
            );

+            println!("starting: {}, {}", name, comp.terms());
+
             for t in comp.outputs.iter() {
                 self.embed(t.clone());
                 let op = "OUT";
-                let to_share_type = self.get_term_share_type(t);
-                let share = self.get_share(t, to_share_type);
-                let line = format!("1 0 {share} {op}\n");
-                outputs.push(line);
+                let shares = self.get_shares(&t);
+
+                for s in shares {
+                    let line = format!("1 0 {} {}\n", s, op);
+                    outputs.push(line);
+                }
             }
             self.bytecode_output.append(&mut outputs);

             // reorder inputs
             let mut bytecode_input_map: HashMap<String, String> = HashMap::new();
             for line in &self.bytecode_input {
-                let key = line.split(' ').collect::<Vec<&str>>()[2];
+                let key = line.split(" ").collect::<Vec<&str>>()[2];
                 bytecode_input_map.insert(key.to_string(), line.to_string());
             }
-
-            let inputs: Vec<String> = comp
-                .metadata
-                .ordered_input_names()
+            let input_order: Vec<String> = comp
+                .metadata
+                .ordered_input_names();
+
+            let inputs: Vec<String> = input_order
                 .iter()
                 .map(|x| {
                     if bytecode_input_map.contains_key(x) {
@@ -830,14 +954,15 @@ impl<'a> ToABY<'a> {
             self.bytecode_input = inputs;

             // write input bytecode
-            let bytecode_path = get_path(self.path, &self.lang, &format!("{name}_bytecode"), true);
+            let bytecode_path =
+                get_path(self.path, &self.lang, &format!("{}_bytecode", name), true);
             write_lines(&bytecode_path, &self.bytecode_input);

             // write output bytecode
             let bytecode_output_path = get_path(
                 self.path,
                 &self.lang,
-                &format!("{name}_bytecode_output"),
+                &format!("{}_bytecode_output", name),
                 false,
             );
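+            // NOTE (illustrative): each computation ends up as a single
+            // "<name>_bytecode" file: the reordered input lines were just written,
+            // and the gate lines staged in "<name>_bytecode_output" are merged in
+            // below. Every scalar output share s is flushed as a "1 0 s OUT" line.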
             write_lines(&bytecode_output_path, &self.bytecode_output);
@@ -856,12 +981,10 @@ impl<'a> ToABY<'a> {
             io::copy(&mut bytecode_output, &mut bytecode).expect("Failed to merge bytecode files");

             // delete output bytecode files
-            fs::remove_file(&bytecode_output_path).unwrap_or_else(|_| {
-                panic!(
-                    "Failed to remove bytecode output: {}",
-                    &bytecode_output_path
-                )
-            });
+            fs::remove_file(&bytecode_output_path).expect(&format!(
+                "Failed to remove bytecode output: {}",
+                &bytecode_output_path
+            ));

             //reset for next function
             self.bytecode_input.clear();
@@ -874,35 +997,83 @@ impl<'a> ToABY<'a> {

         // write remaining shares
         self.write_share_output(true);
+        println!("Time: Lower: {:?}", now.elapsed());
     }
 }

 /// Convert this (IR) `ir` to ABY.
-pub fn to_aby(cs: Computations, path: &Path, lang: &str, cm: &str, ss: &str) {
-    // Protocol Assignments
-    let mut s_map: HashMap<String, SharingMap> = HashMap::new();
-
-    // TODO: change ILP to take in Functions instead of individual computations
-    for (name, comp) in cs.comps.iter() {
-        let assignments = match ss {
-            "b" => assign_all_boolean(comp, cm),
-            "y" => assign_all_yao(comp, cm),
-            "a+b" => assign_arithmetic_and_boolean(comp, cm),
-            "a+y" => assign_arithmetic_and_yao(comp, cm),
-            "greedy" => assign_greedy(comp, cm),
-            #[cfg(feature = "lp")]
-            "lp" => assign(comp, cm),
-            #[cfg(feature = "lp")]
-            "glp" => assign(comp, cm),
-            _ => {
-                panic!("Unsupported sharing scheme: {}", ss);
+pub fn to_aby(
+    ir: Computations,
+    path: &Path,
+    lang: &str,
+    cm: &str,
+    ss: &str,
+    #[allow(unused_variables)] ps: &usize,
+    #[allow(unused_variables)] ml: &usize,
+    #[allow(unused_variables)] mss: &usize,
+    #[allow(unused_variables)] hyper: &usize,
+    #[allow(unused_variables)] imbalance: &usize,
+) {
+    let now = Instant::now();
+    match ss {
+        #[cfg(feature = "lp")]
+        "glp" => {
+            let (fs, s_map) = inline_all_and_assign_smart_glp(&ir, cm);
+            println!("LOG: Assignment time: {:?}", now.elapsed());
+            let mut converter = ToABY::new(fs, s_map, path, lang);
+            converter.lower();
+        }
+        #[cfg(feature = "lp")]
+        "tlp" => {
+            let (fs, s_map) =
+                partition_with_mut_smart(&ir, cm, path, lang, ps, *hyper == 1, ml, mss, imbalance);
+            println!("LOG: Assignment time: {:?}", now.elapsed());
+            let mut converter = ToABY::new(fs, s_map, path, lang);
+            converter.lower();
+        }
+        #[cfg(feature = "lp")]
+        "css" => {
+            let mut css = CallSiteSimilarity::new(&ir, &ml);
+            let (comps, dugs) = css.call_site_similarity_smart();
+            let s_map = css_partition_with_mut_smart(
+                &comps,
+                &dugs,
+                cm,
+                path,
+                lang,
+                ps,
+                *hyper == 1,
+                ml,
+                mss,
+                imbalance,
+            );
+            println!("LOG: Assignment time: {:?}", now.elapsed());
+            let mut converter = ToABY::new(comps, s_map, path, lang);
+            converter.lower();
+        }
+        _ => {
+            // Protocol Assignments
+            let mut s_map: HashMap<String, SharingMap> = HashMap::new();
+            for (name, comp) in ir.comps.iter() {
+                let assignments = match ss {
+                    "b" => assign_all_boolean(&comp, cm),
+                    "y" => assign_all_yao(&comp, cm),
+                    "a+b" => assign_arithmetic_and_boolean(&comp, cm),
+                    "a+y" => assign_arithmetic_and_yao(&comp, cm),
+                    "greedy" => assign_greedy(&comp, cm),
+                    #[cfg(feature = "lp")]
+                    "lp" => assign(&comp.to_cs(), cm),
+                    #[cfg(feature = "lp")]
+                    "glp" => assign(&comp.to_cs(), cm),
+                    _ => {
+                        panic!("Unsupported sharing scheme: {}", ss);
+                    }
                };
+                s_map.insert(name.to_string(), assignments);
             }
-        };
-        #[cfg(feature = "bench")]
-        println!("LOG: Assignment {}: {:?}", name, now.elapsed());
-        s_map.insert(name.to_string(), assignments);
-    }
-
-    let mut converter = ToABY::new(cs, s_map, path, lang);
-    converter.lower();
+
println!("LOG: Assignment time: {:?}", now.elapsed()); + let mut converter = ToABY::new(ir, s_map, path, lang); + converter.lower(); + } + }; } diff --git a/third_party/hycc/adapted_costs.json b/third_party/hycc/adapted_costs.json index 888922c5f..f652f0738 100644 --- a/third_party/hycc/adapted_costs.json +++ b/third_party/hycc/adapted_costs.json @@ -1,298 +1,305 @@ { - "&&": { + "depth": { + "a": 8.96816269e-2, + "b": 1.09058202e-1 + }, + "bvadd": { + "a": { + "32": 0.0 + }, "b": { - "32": 117 + "32": 0.1096 }, "y": { - "32": 32 + "32": 0.0577 + }, + "depth": { + "a": 0, + "b": 6 } }, - "||": { + "and": { "b": { - "32": 123 + "32": 0.0155 }, "y": { - "32": 40 + "32": 0.0167 + }, + "depth": { + "b": 1 } }, - "a2b": { - "1": 334.02, - "8": 327.064, - "16": 322.452, - "32": 335.764 - }, - "a2y": { - "1": 325.398, - "8": 318.621, - "16": 314.128, - "32": 327.097 - }, - "add": { - "a": { - "1": 214.449, - "8": 108.999, - "16": 202.045, - "32": 116.869 - }, + "bvand": { "b": { - "1": 1056.463, - "8": 1059.493, - "16": 1066.69, - "32": 1049.803 + "32": 0.0155 }, "y": { - "1": 316.939, - "8": 317.848, - "16": 320.007, - "32": 314.941 + "32": 0.0167 + }, + "depth": { + "b": 1 } }, - "and": { + "=": { "b": { - "1": 322.521, - "8": 320.293, - "16": 313.943, - "32": 315.589 + "32": 0.0187 }, "y": { - "1": 318.946, - "8": 326.198, - "16": 328.816, - "32": 314.308 + "32": 0.0246 + }, + "depth": { + "b": 4 } }, - "b2a": { - "1": 314.401, - "8": 328.788, - "16": 328.053, - "32": 326.316 + "bvuge": { + "b": { + "32": 0.0503 + }, + "y": { + "32": 0.0419 + }, + "depth": { + "b": 6 + } }, - "b2y": { - "1": 329.222, - "8": 318.638, - "16": 316.215, - "32": 319.556 + "bvugt": { + "b": { + "32": 0.0503 + }, + "y": { + "32": 0.0419 + }, + "depth": { + "b": 6 + } }, - "eq": { + "bvule": { "b": { - "1": 200.331, - "8": 429.922, - "16": 439.933, - "32": 529.283 + "32": 0.0503 }, "y": { - "1": 309.717, - "8": 323.07, - "16": 319.867, - "32": 316.492 + "32": 0.0419 + }, + "depth": { + "b": 6 } }, - "ge": { + "bvult": { "b": { - "1": 6987.454, - "8": 6832.078, - "16": 6794.562, - "32": 6743.487 + "32": 0.0503 }, "y": { - "1": 487.781, - "8": 476.934, - "16": 474.315, - "32": 470.749 + "32": 0.0419 + }, + "depth": { + "b": 6 } }, - "gt": { + "bvmul": { + "a": { + "32": 0.007 + }, "b": { - "1": 4658.303, - "8": 4554.719, - "16": 4529.708, - "32": 4495.658 + "32": 1.287 }, "y": { - "1": 325.187, - "8": 317.956, - "16": 316.21, - "32": 313.833 + "32": 2.0644 + }, + "depth": { + "a": 1, + "b": 32 } }, - "le": { + "ite": { "b": { - "1": 6987.454, - "8": 6832.078, - "16": 6794.562, - "32": 6743.487 + "32": 0.023 }, "y": { - "1": 487.781, - "8": 476.934, - "16": 474.315, - "32": 470.749 + "32": 0.0338 + }, + "depth": { + "b": 1 } }, - "lt": { + "ne": { "b": { - "1": 4658.303, - "8": 4554.719, - "16": 4529.708, - "32": 4495.658 + "32": 0.0187 }, "y": { - "1": 325.187, - "8": 317.956, - "16": 316.21, - "32": 313.833 + "32": 0.0246 + }, + "depth": { + "b": 5 } }, - "mul": { - "a": { - "1": 323.99, - "8": 425.06, - "16": 324.338, - "32": 416.112 + "or": { + "b": { + "32": 0.0155 + }, + "y": { + "32": 0.0167 }, + "depth": { + "b": 1 + } + }, + "bvor": { "b": { - "1": 318.94, - "8": 654.57, - "16": 766.692, - "32": 1044.67 + "32": 0.0155 }, "y": { - "1": 318.776, - "8": 323.827, - "16": 320.097, - "32": 410.183 + "32": 0.0167 + }, + "depth": { + "b": 1 } }, - "mux": { + "not": { "b": { - "1": 320.388, - "8": 322.502, - "16": 324.969, - "32": 316.438 + "32": 0.0037 }, "y": { - "1": 330.528, - "8": 318.433, - "16": 325.117, - "32": 
319.27 + "32": 0.0084 + }, + "depth": { + "b": 0 } }, - "ne": { + "bvxor": { "b": { - "1": 200.331, - "8": 429.922, - "16": 439.933, - "32": 529.283 + "32": 0.0037 }, "y": { - "1": 309.717, - "8": 323.07, - "16": 319.867, - "32": 316.492 + "32": 0.0084 + }, + "depth": { + "b": 0 } }, - "or": { + "bvsub": { + "a": { + "32": 0.0 + }, "b": { - "1": 322.521, - "8": 320.293, - "16": 313.943, - "32": 315.589 + "32": 0.0331 }, "y": { - "1": 318.946, - "8": 326.198, - "16": 328.816, - "32": 314.308 + "32": 0.0574 + }, + "depth": { + "a": 0, + "b": 31 } }, - "shl": { + "bvudiv": { "b": { - "1": 981, - "8": 981, - "16": 981, - "32": 981 + "32": 1.034 }, "y": { - "1": 224, - "8": 224, - "16": 224, - "32": 224 + "32": 1.5924 + }, + "depth": { + "b": 652 } }, - "shr": { + "bvurem": { "b": { - "1": 1015, - "8": 1015, - "16": 1015, - "32": 1015 + "32": 1.0339 }, "y": { - "1": 224, - "8": 224, - "16": 224, - "32": 224 + "32": 1.5921 + }, + "depth": { + "b": 653 } }, - "sub": { - "a": { - "1": 214.449, - "8": 108.999, - "16": 202.045, - "32": 116.869 + "bvrem": { + "b": { + "32": 1.0339 }, + "y": { + "32": 1.5921 + }, + "depth": { + "b": 653 + } + }, + "bvshl": { "b": { - "1": 1056.463, - "8": 1059.493, - "16": 1066.69, - "32": 1049.803 + "32": 0.0024 }, "y": { - "1": 316.939, - "8": 317.848, - "16": 320.007, - "32": 314.941 + "32": 0.0052 + }, + "depth": { + "b": 0 } }, - "xor": { + "bvlshr": { "b": { - "1": 205.544, - "8": 120.112, - "16": 195.179, - "32": 125.154 + "32": 0.0025 }, "y": { - "1": 323.464, - "8": 314.425, - "16": 310.38, - "32": 322.615 + "32": 0.0056 + }, + "depth": { + "b": 0 } }, + "a2b": { + "32": 0.14496 + }, + "a2y": { + "32": 0.1306 + }, + "b2a": { + "32": 0.01225 + }, + "b2y": { + "32": 0.01922 + }, "y2a": { - "1": 491.107, - "8": 510.63, - "16": 499.249, - "32": 501.104 + "32": 0.00265 }, "y2b": { - "1": 312.428, - "8": 324.848, - "16": 317.608, - "32": 318.788 + "32": 0.008 }, - "div": { + "(field 0)": { + "a": { + "32": 0 + }, "b": { - "32": 10 + "32": 0 }, "y": { - "32": 100 + "32": 0 + }, + "depth": { + "a": 0, + "b": 0, + "y": 0 } }, - "rem": { + "select": { "b": { - "32": 10 + "32": 0.0417 }, "y": { - "32": 100 + "32": 0.0584 + }, + "depth": { + "b": 5 + } + }, + "store": { + "b": { + "32": 0.0417 + }, + "y": { + "32": 0.0584 + }, + "depth": { + "b": 5 } } } diff --git a/third_party/opa/adapted_costs.json b/third_party/opa/adapted_costs.json index 46d6f9834..f652f0738 100644 --- a/third_party/opa/adapted_costs.json +++ b/third_party/opa/adapted_costs.json @@ -1,181 +1,305 @@ { - "&&": { + "depth": { + "a": 8.96816269e-2, + "b": 1.09058202e-1 + }, + "bvadd": { + "a": { + "32": 0.0 + }, "b": { - "32": 117 + "32": 0.1096 }, "y": { - "32": 32 + "32": 0.0577 + }, + "depth": { + "a": 0, + "b": 6 } }, - "||": { + "and": { "b": { - "32": 123 + "32": 0.0155 }, "y": { - "32": 40 + "32": 0.0167 + }, + "depth": { + "b": 1 } }, - "a2b": { - "32": 2596.4 - }, - "a2y": { - "32": 2665.2 - }, - "add": { - "a": { - "32": 1 + "bvand": { + "b": { + "32": 0.0155 }, + "y": { + "32": 0.0167 + }, + "depth": { + "b": 1 + } + }, + "=": { "b": { - "32": 160 + "32": 0.0187 }, "y": { - "32": 48 + "32": 0.0246 + }, + "depth": { + "b": 4 } }, - "and": { + "bvuge": { "b": { - "32": 117 + "32": 0.0503 }, "y": { - "32": 32 + "32": 0.0419 + }, + "depth": { + "b": 6 } }, - "b2a": { - "32": 1868.3999999999999 + "bvugt": { + "b": { + "32": 0.0503 + }, + "y": { + "32": 0.0419 + }, + "depth": { + "b": 6 + } }, - "b2y": { - "32": 2293 + "bvule": { + "b": { + "32": 0.0503 + }, + "y": { + "32": 0.0419 + }, + "depth": 
{ + "b": 6 + } }, - "eq": { + "bvult": { "b": { - "32": 489 + "32": 0.0503 }, "y": { - "32": 39 + "32": 0.0419 + }, + "depth": { + "b": 6 } }, - "ge": { + "bvmul": { + "a": { + "32": 0.007 + }, "b": { - "32": 733 + "32": 1.287 }, "y": { - "32": 60 + "32": 2.0644 + }, + "depth": { + "a": 1, + "b": 32 } }, - "gt": { + "ite": { "b": { - "32": 573 + "32": 0.023 }, "y": { - "32": 40 + "32": 0.0338 + }, + "depth": { + "b": 1 } }, - "le": { + "ne": { "b": { - "32": 618 + "32": 0.0187 }, "y": { - "32": 41 + "32": 0.0246 + }, + "depth": { + "b": 5 } }, - "lt": { + "or": { "b": { - "32": 739 + "32": 0.0155 }, "y": { - "32": 60 + "32": 0.0167 + }, + "depth": { + "b": 1 } }, - "mul": { + "bvor": { "b": { - "32": 1731 + "32": 0.0155 }, "y": { - "32": 1127 + "32": 0.0167 }, - "a": { - "32": 104 + "depth": { + "b": 1 } }, - "mux": { + "not": { "b": { - "32": 108 + "32": 0.0037 }, "y": { - "32": 37 + "32": 0.0084 + }, + "depth": { + "b": 0 } }, - "ne": { + "bvxor": { "b": { - "32": 484 + "32": 0.0037 }, "y": { - "32": 38 + "32": 0.0084 + }, + "depth": { + "b": 0 } }, - "or": { + "bvsub": { + "a": { + "32": 0.0 + }, "b": { - "32": 123 + "32": 0.0331 }, "y": { - "32": 40 + "32": 0.0574 + }, + "depth": { + "a": 0, + "b": 31 } }, - "shl": { + "bvudiv": { "b": { - "32": 981 + "32": 1.034 }, "y": { - "32": 224 + "32": 1.5924 + }, + "depth": { + "b": 652 } }, - "shr": { + "bvurem": { "b": { - "32": 1015 + "32": 1.0339 }, "y": { - "32": 224 + "32": 1.5921 + }, + "depth": { + "b": 653 } }, - "sub": { - "a": { - "32": 1 + "bvrem": { + "b": { + "32": 1.0339 }, + "y": { + "32": 1.5921 + }, + "depth": { + "b": 653 + } + }, + "bvshl": { "b": { - "32": 52 + "32": 0.0024 }, "y": { - "32": 49 + "32": 0.0052 + }, + "depth": { + "b": 0 } }, - "xor": { + "bvlshr": { "b": { - "32": 7 + "32": 0.0025 }, "y": { - "32": 23 + "32": 0.0056 + }, + "depth": { + "b": 0 } }, + "a2b": { + "32": 0.14496 + }, + "a2y": { + "32": 0.1306 + }, + "b2a": { + "32": 0.01225 + }, + "b2y": { + "32": 0.01922 + }, "y2a": { - "32": 3207 + "32": 0.00265 }, "y2b": { - "32": 2040.2 + "32": 0.008 }, - "div": { + "(field 0)": { + "a": { + "32": 0 + }, "b": { - "32": 10 + "32": 0 }, "y": { - "32": 100 + "32": 0 + }, + "depth": { + "a": 0, + "b": 0, + "y": 0 } }, - "rem": { + "select": { "b": { - "32": 10 + "32": 0.0417 }, "y": { - "32": 100 + "32": 0.0584 + }, + "depth": { + "b": 5 + } + }, + "store": { + "b": { + "32": 0.0417 + }, + "y": { + "32": 0.0584 + }, + "depth": { + "b": 5 } } } diff --git a/util.py b/util.py index 353c7d20d..ace74784a 100644 --- a/util.py +++ b/util.py @@ -15,13 +15,13 @@ def set_env(features): for f in features: - if f == 'aby': + if f == "aby": if not os.getenv("ABY_SOURCE"): os.environ["ABY_SOURCE"] = ABY_SOURCE - if f == 'kahip': + if f == "kahip": if not os.getenv("KAHIP_SOURCE"): os.environ["KAHIP_SOURCE"] = KAHIP_SOURCE - if f == 'kahypar': + if f == "kahypar": if not os.getenv("KAHYPAR_SOURCE"): os.environ["KAHYPAR_SOURCE"] = KAHYPAR_SOURCE
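A note on the cost tables above: each operator entry maps a share type ("a" = arithmetic, "b" = boolean, "y" = Yao) to a per-bitwidth cost, plus a "depth" entry for latency-aware assignment. The sketch below shows one way such a table could be loaded and queried with serde_json (already a dependency of this crate); the helper names and the include path are hypothetical illustrations, not part of this change:

    use serde_json::Value;

    /// Cost of `op` under share type `share` at the given bitwidth, if listed;
    /// e.g. cost(&tbl, "bvadd", "y", "32") would be Some(0.0577) for the table above.
    fn cost(tbl: &Value, op: &str, share: &str, bitwidth: &str) -> Option<f64> {
        tbl.get(op)?.get(share)?.get(bitwidth)?.as_f64()
    }

    /// Pick the cheapest share type for `op` among arithmetic/boolean/Yao.
    fn cheapest_share(tbl: &Value, op: &str) -> Option<(&'static str, f64)> {
        ["a", "b", "y"]
            .iter()
            .filter_map(|s| cost(tbl, op, s, "32").map(|c| (*s, c)))
            .min_by(|x, y| x.1.partial_cmp(&y.1).unwrap())
    }

    fn main() {
        // Hypothetical path; adapted_costs.json is the file shown in this diff.
        let tbl: Value =
            serde_json::from_str(include_str!("adapted_costs.json")).unwrap();
        // For bvmul the arithmetic share is cheapest above: ("a", 0.007).
        println!("{:?}", cheapest_share(&tbl, "bvmul"));
    }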