diff --git a/kerngen/high_parser/parser.py b/kerngen/high_parser/parser.py index 32ad2583..0df064df 100644 --- a/kerngen/high_parser/parser.py +++ b/kerngen/high_parser/parser.py @@ -124,7 +124,9 @@ def _delegate(self, command_str: str, context_seen: list[Context], symbols_map): # Populate the polys map data = Data.from_string(rest) symbols_map[data.name] = Polys( - name=data.name, parts=data.parts, rns=context.max_rns + name=data.name, + parts=data.parts, + rns=context.current_rns, ) return data case _: diff --git a/kerngen/pisa_generators/basic.py b/kerngen/pisa_generators/basic.py index 34399413..08d21341 100644 --- a/kerngen/pisa_generators/basic.py +++ b/kerngen/pisa_generators/basic.py @@ -263,6 +263,8 @@ def get_pisa_op(num): for digit, op in get_pisa_op(self.input1.digits): input0_tmp = Polys.from_polys(self.input0) input0_tmp.name += "_" + ascii_letters[digit] + + # mul/mac for 0-current_rns ls.extend( op( self.context.label, @@ -273,7 +275,22 @@ def get_pisa_op(num): ) for part, q, unit in it.product( range(self.input1.start_parts, self.input1.parts), - range(self.input0.start_rns, self.input0.rns), + range(self.context.current_rns), + range(self.context.units), + ) + ) + # mul/mac for max_rns-krns terms + ls.extend( + op( + self.context.label, + self.output(part, q, unit), + input0_tmp(self.input0_fixed_part, q, unit), + self.input1(digit, part, q, unit), + q, + ) + for part, q, unit in it.product( + range(self.input1.start_parts, self.input1.parts), + range(self.context.max_rns, self.context.key_rns), range(self.context.units), ) ) @@ -296,11 +313,17 @@ def extract_last_part_polys(input0: Polys, rns: int) -> Tuple[Polys, Polys, Poly return input_last_part, last_coeff, upto_last_coeffs -def split_last_rns_polys(input0: Polys) -> Tuple[Polys, Polys]: +def split_last_rns_polys(input0: Polys, current_rns) -> Tuple[Polys, Polys]: """Split and extract last RNS of input0""" - return Polys.from_polys(input0, mode="last_rns"), Polys.from_polys( - input0, mode="drop_last_rns" - ) + if input0.rns <= current_rns: + return Polys.from_polys(input0, mode="last_rns"), Polys.from_polys( + input0, mode="drop_last_rns" + ) + + # do not include consumed rns + remaining = Polys.from_polys(input0) + remaining.rns = current_rns + return Polys.from_polys(input0, mode="last_rns"), remaining def duplicate_polys(input0: Polys, name: str) -> Polys: diff --git a/kerngen/pisa_generators/decomp.py b/kerngen/pisa_generators/decomp.py index f8262489..1149fda5 100644 --- a/kerngen/pisa_generators/decomp.py +++ b/kerngen/pisa_generators/decomp.py @@ -34,6 +34,7 @@ def to_pisa(self) -> list[PIsaOp]: ls: list[pisa_op] = [] for input_rns_index in range(self.input0.start_rns, self.input0.rns): + # muli for 0-current_rns ls.extend( pisa_op.Muli( self.context.label, @@ -44,13 +45,38 @@ def to_pisa(self) -> list[PIsaOp]: ) for part, pq, unit in it.product( range(self.input0.start_parts, self.input0.parts), - range(self.context.key_rns), + range(self.context.current_rns), range(self.context.units), ) ) + # muli for krns + ls.extend( + pisa_op.Muli( + self.context.label, + self.output(part, pq, unit), + rns_poly(part, input_rns_index, unit), + r2(part, pq, unit), + pq, + ) + for part, pq, unit in it.product( + range(self.input0.start_parts, self.input0.parts), + range(self.context.max_rns, self.context.key_rns), + range(self.context.units), + ) + ) + output_tmp = Polys.from_polys(self.output) output_tmp.name += "_" + ascii_letters[input_rns_index] - ls.extend(NTT(self.context, output_tmp, self.output).to_pisa()) + output_split = Polys.from_polys(self.output) + output_split.rns = self.context.current_rns + # ntt for 0-current_rns + ls.extend(NTT(self.context, output_tmp, output_split).to_pisa()) + + output_split = Polys.from_polys(self.output) + output_split.rns = self.context.key_rns + output_split.start_rns = self.context.max_rns + # ntt for krns + ls.extend(NTT(self.context, output_tmp, output_split).to_pisa()) return mixed_to_pisa_ops( INTT(self.context, rns_poly, self.input0), diff --git a/kerngen/pisa_generators/mod.py b/kerngen/pisa_generators/mod.py index 6d10e255..dabeecc6 100644 --- a/kerngen/pisa_generators/mod.py +++ b/kerngen/pisa_generators/mod.py @@ -33,13 +33,12 @@ class Mod(HighOp): context: KernelContext output: Polys input0: Polys - var_suffix: str = MOD_QLAST # default to qlast, use mod_q otherwise + var_suffix: str = MOD_QLAST # default to qlast, use mod_p otherwise def to_pisa(self) -> list[PIsaOp]: """Return the p-isa code to perform an mod switch down""" # Immediates last_q = self.input0.rns - 1 - self.input0.start_rns = (self.context.key_rns - 1) - self.context.current_rns it = Immediate(name="it" + self.var_suffix) t = Immediate(name="t", rns=last_q) @@ -48,7 +47,9 @@ def to_pisa(self) -> list[PIsaOp]: ) # Drop down input rns - input_last_rns, input_remaining_rns = split_last_rns_polys(self.input0) + input_last_rns, input_remaining_rns = split_last_rns_polys( + self.input0, self.context.current_rns + ) # Temp. temp_input_last_rns = duplicate_polys(input_last_rns, "y") diff --git a/kerngen/pisa_generators/relin.py b/kerngen/pisa_generators/relin.py index 0cae1398..9f390172 100644 --- a/kerngen/pisa_generators/relin.py +++ b/kerngen/pisa_generators/relin.py @@ -6,6 +6,7 @@ from high_parser.pisa_operations import PIsaOp, Comment from high_parser import KernelContext, HighOp, KeyPolys, Polys from .basic import Add, KeyMul, mixed_to_pisa_ops, extract_last_part_polys + from .mod import Mod from .decomp import DigitDecompExtend @@ -24,7 +25,6 @@ def to_pisa(self) -> list[PIsaOp]: supports number of digits equal to the RNS size""" self.output.parts = 2 self.input0.parts = 3 - relin_key = KeyPolys( "rlk", parts=2, rns=self.context.key_rns, digits=self.input0.rns ) @@ -40,7 +40,6 @@ def to_pisa(self) -> list[PIsaOp]: add_original = Polys.from_polys(mul_by_rlk_modded_down) add_original.name = self.input0.name - return mixed_to_pisa_ops( Comment("Start of relin kernel"), Comment("Digit decomposition and extend base from Q to PQ"), diff --git a/kerngen/pisa_generators/rescale.py b/kerngen/pisa_generators/rescale.py index 3bfe56b3..5b7441f3 100644 --- a/kerngen/pisa_generators/rescale.py +++ b/kerngen/pisa_generators/rescale.py @@ -28,21 +28,30 @@ class Rescale(HighOp): """Class representing mod down operation""" + MOD_QLAST = "_mod_qLast" context: KernelContext output: Polys input0: Polys + var_suffix: str = MOD_QLAST # default to qlast def to_pisa(self) -> list[PIsaOp]: """Return the p-isa code to perform an mod switch down""" + # Immediates last_q = self.input0.rns - 1 one, r2, iq = common_immediates(r2_rns=last_q, iq_rns=last_q) + one, r2, iq = common_immediates( + r2_rns=last_q, iq_rns=last_q, iq_suffix=self.var_suffix + ) + q_last_half = Polys("qLastHalf", 1, self.input0.rns) q_i_last_half = Polys("qiLastHalf", 1, rns=last_q) # split input - input_last_rns, input_remaining_rns = split_last_rns_polys(self.input0) + input_last_rns, input_remaining_rns = split_last_rns_polys( + self.input0, self.context.current_rns + ) # Create temp vars for input_last/remaining temp_input_last_rns = duplicate_polys(input_last_rns, "y")