Skip to content
4 changes: 3 additions & 1 deletion kerngen/high_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def _delegate(self, command_str: str, context_seen: list[Context], symbols_map):
# Populate the polys map
data = Data.from_string(rest)
symbols_map[data.name] = Polys(
name=data.name, parts=data.parts, rns=context.max_rns
name=data.name,
parts=data.parts,
rns=context.current_rns,
)
return data
case _:
Expand Down
33 changes: 28 additions & 5 deletions kerngen/pisa_generators/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,8 @@ def get_pisa_op(num):
for digit, op in get_pisa_op(self.input1.digits):
input0_tmp = Polys.from_polys(self.input0)
input0_tmp.name += "_" + ascii_letters[digit]

# mul/mac for 0-current_rns
ls.extend(
op(
self.context.label,
Expand All @@ -273,7 +275,22 @@ def get_pisa_op(num):
)
for part, q, unit in it.product(
range(self.input1.start_parts, self.input1.parts),
range(self.input0.start_rns, self.input0.rns),
range(self.context.current_rns),
range(self.context.units),
)
)
# mul/mac for max_rns-krns terms
ls.extend(
op(
self.context.label,
self.output(part, q, unit),
input0_tmp(self.input0_fixed_part, q, unit),
self.input1(digit, part, q, unit),
q,
)
for part, q, unit in it.product(
range(self.input1.start_parts, self.input1.parts),
range(self.context.max_rns, self.context.key_rns),
range(self.context.units),
)
)
Expand All @@ -296,11 +313,17 @@ def extract_last_part_polys(input0: Polys, rns: int) -> Tuple[Polys, Polys, Poly
return input_last_part, last_coeff, upto_last_coeffs


def split_last_rns_polys(input0: Polys) -> Tuple[Polys, Polys]:
def split_last_rns_polys(input0: Polys, current_rns) -> Tuple[Polys, Polys]:
"""Split and extract last RNS of input0"""
return Polys.from_polys(input0, mode="last_rns"), Polys.from_polys(
input0, mode="drop_last_rns"
)
if input0.rns <= current_rns:
return Polys.from_polys(input0, mode="last_rns"), Polys.from_polys(
input0, mode="drop_last_rns"
)

# do not include consumed rns
remaining = Polys.from_polys(input0)
remaining.rns = current_rns
return Polys.from_polys(input0, mode="last_rns"), remaining


def duplicate_polys(input0: Polys, name: str) -> Polys:
Expand Down
30 changes: 28 additions & 2 deletions kerngen/pisa_generators/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def to_pisa(self) -> list[PIsaOp]:

ls: list[pisa_op] = []
for input_rns_index in range(self.input0.start_rns, self.input0.rns):
# muli for 0-current_rns
ls.extend(
pisa_op.Muli(
self.context.label,
Expand All @@ -44,13 +45,38 @@ def to_pisa(self) -> list[PIsaOp]:
)
for part, pq, unit in it.product(
range(self.input0.start_parts, self.input0.parts),
range(self.context.key_rns),
range(self.context.current_rns),
range(self.context.units),
)
)
# muli for krns
ls.extend(
pisa_op.Muli(
self.context.label,
self.output(part, pq, unit),
rns_poly(part, input_rns_index, unit),
r2(part, pq, unit),
pq,
)
for part, pq, unit in it.product(
range(self.input0.start_parts, self.input0.parts),
range(self.context.max_rns, self.context.key_rns),
range(self.context.units),
)
)

output_tmp = Polys.from_polys(self.output)
output_tmp.name += "_" + ascii_letters[input_rns_index]
ls.extend(NTT(self.context, output_tmp, self.output).to_pisa())
output_split = Polys.from_polys(self.output)
output_split.rns = self.context.current_rns
# ntt for 0-current_rns
ls.extend(NTT(self.context, output_tmp, output_split).to_pisa())

output_split = Polys.from_polys(self.output)
output_split.rns = self.context.key_rns
output_split.start_rns = self.context.max_rns
# ntt for krns
ls.extend(NTT(self.context, output_tmp, output_split).to_pisa())

return mixed_to_pisa_ops(
INTT(self.context, rns_poly, self.input0),
Expand Down
7 changes: 4 additions & 3 deletions kerngen/pisa_generators/mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,12 @@ class Mod(HighOp):
context: KernelContext
output: Polys
input0: Polys
var_suffix: str = MOD_QLAST # default to qlast, use mod_q otherwise
var_suffix: str = MOD_QLAST # default to qlast, use mod_p otherwise

def to_pisa(self) -> list[PIsaOp]:
"""Return the p-isa code to perform an mod switch down"""
# Immediates
last_q = self.input0.rns - 1
self.input0.start_rns = (self.context.key_rns - 1) - self.context.current_rns

it = Immediate(name="it" + self.var_suffix)
t = Immediate(name="t", rns=last_q)
Expand All @@ -48,7 +47,9 @@ def to_pisa(self) -> list[PIsaOp]:
)

# Drop down input rns
input_last_rns, input_remaining_rns = split_last_rns_polys(self.input0)
input_last_rns, input_remaining_rns = split_last_rns_polys(
self.input0, self.context.current_rns
)

# Temp.
temp_input_last_rns = duplicate_polys(input_last_rns, "y")
Expand Down
3 changes: 1 addition & 2 deletions kerngen/pisa_generators/relin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from high_parser.pisa_operations import PIsaOp, Comment
from high_parser import KernelContext, HighOp, KeyPolys, Polys
from .basic import Add, KeyMul, mixed_to_pisa_ops, extract_last_part_polys

from .mod import Mod
from .decomp import DigitDecompExtend

Expand All @@ -24,7 +25,6 @@ def to_pisa(self) -> list[PIsaOp]:
supports number of digits equal to the RNS size"""
self.output.parts = 2
self.input0.parts = 3

relin_key = KeyPolys(
"rlk", parts=2, rns=self.context.key_rns, digits=self.input0.rns
)
Expand All @@ -40,7 +40,6 @@ def to_pisa(self) -> list[PIsaOp]:

add_original = Polys.from_polys(mul_by_rlk_modded_down)
add_original.name = self.input0.name

return mixed_to_pisa_ops(
Comment("Start of relin kernel"),
Comment("Digit decomposition and extend base from Q to PQ"),
Expand Down
11 changes: 10 additions & 1 deletion kerngen/pisa_generators/rescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,30 @@
class Rescale(HighOp):
"""Class representing mod down operation"""

MOD_QLAST = "_mod_qLast"
context: KernelContext
output: Polys
input0: Polys
var_suffix: str = MOD_QLAST # default to qlast

def to_pisa(self) -> list[PIsaOp]:
"""Return the p-isa code to perform an mod switch down"""

# Immediates
last_q = self.input0.rns - 1
one, r2, iq = common_immediates(r2_rns=last_q, iq_rns=last_q)

one, r2, iq = common_immediates(
r2_rns=last_q, iq_rns=last_q, iq_suffix=self.var_suffix
)

q_last_half = Polys("qLastHalf", 1, self.input0.rns)
q_i_last_half = Polys("qiLastHalf", 1, rns=last_q)

# split input
input_last_rns, input_remaining_rns = split_last_rns_polys(self.input0)
input_last_rns, input_remaining_rns = split_last_rns_polys(
self.input0, self.context.current_rns
)

# Create temp vars for input_last/remaining
temp_input_last_rns = duplicate_polys(input_last_rns, "y")
Expand Down