Introduction to CKKS

cryptography
Author

Trevor Tomlin

Published

October 25, 2023

An introduction to the HEAAN / CKKS homomorphic encryption scheme

The HEAAN / CKKS scheme was introduced in the paper Homomorphic Encryption for Arithmetic of Approximate Numbers and is named after its authors, Jung Hee Cheon, Andrey Kim, Miran Kim, and Yongsoo Song. It is a leveled homomorphic encryption scheme based upon the Ring Learning with Errors (RLWE) problem that allows encrypted computation on real numbers. This scheme differs from the BGV/BFV schemes because it works on approximate real (and complex) numbers rather than integers.

Notation

Understanding the following notations will be helpful for reading this blog post:

  • \(\mathcal{R} = \mathbb{Z}[X]/(X^N +1)\)
  • \(\mathbb{Z}_q\) is the set of integers modulo \(q\), represented in the range \((-q/2, q/2]\)
  • \(\lfloor x \rceil\) represents the rounding of the real number x to the closest integer

Parameters

  • N is the ring dimension and is typically a power of 2, such as 2^11
  • \(\Delta\) is the scaling factor
  • Q and P are two large numbers that must both be somewhere in the range of 2^20 to 2^200
  • L represents the max number of multiplications allowed
  • \(\sigma\) represents the standard deviation of the error distribution and is typically set to 3.2
  • h is the Hamming weight of the secret key and is usually set to 64

There are multiple lattice hardness estimators such as Security Estimates for Lattice Problems that allow you to input parameters and see how many bits of security the parameters provide. Ideally the parameters provide at least 128 bits of security.

Homomorphic Encryption Standard provides standard parameters for FHE schemes and can be used as a reference.

Encoding

This scheme allows us to encode a vector in \(\mathbb{C}^{N/2}\) into the polynomial ring \(\mathcal{R}_q = (\mathbb{Z}/q\mathbb{Z})[X]/(X^N +1)\) using a ring isomorphism and the canonical embedding map.

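The \(n\)th cyclotomic polynomial is the polynomial with integer coefficients whose roots are exactly the primitive \(n\)th roots of unity. An example would be the 4th cyclotomic polynomial, \(\Phi_4(x) = x^2 + 1\). One important property of cyclotomic polynomials is that if \(N\) is a power of two, then the \(N\)th cyclotomic polynomial is equal to \(x^{N/2} + 1\). This means that \(X^N + 1\) is the \(2N\)th cyclotomic polynomial, and its roots are the primitive \(2N\)th roots of unity.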

A Toolkit for Ring-LWE Cryptography

More information about these polynomials can be found on Wikipedia.
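As a quick check of the power-of-two property, the sketch below rebuilds a cyclotomic polynomial directly from its primitive roots. The helper name `cyclotomic_coeffs` is hypothetical and only meant for illustration; it relies on floating-point arithmetic, so it is only suitable for small \(n\).

```python
import math
import numpy as np

def cyclotomic_coeffs(n: int) -> np.ndarray:
    """Coefficients (highest degree first) of the n-th cyclotomic polynomial."""
    # Primitive n-th roots of unity: exp(2*pi*i*k/n) with gcd(k, n) == 1
    roots = [np.exp(2j * np.pi * k / n) for k in range(1, n + 1) if math.gcd(k, n) == 1]
    # np.poly expands prod(x - r); imaginary parts cancel up to rounding error
    return np.rint(np.real(np.poly(roots))).astype(int)

print(cyclotomic_coeffs(4))  # coefficients of x^2 + 1
print(cyclotomic_coeffs(8))  # coefficients of x^4 + 1
```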

For any polynomial \(a\) with real coefficients, we can say that \(a(\xi^{-j}) = \overline{a(\xi^{j})}\), because \(\xi^{-j} = \overline{\xi^{j}}\).
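This conjugate symmetry can be checked numerically. The sketch below is only an illustration: the coefficients are arbitrary real example values and `evaluate` is a hypothetical helper. It evaluates the polynomial at \(\xi^{-j}\) and at \(\xi^{j}\) and confirms that the two results are complex conjugates of each other.

```python
import numpy as np

N = 4
xi = np.exp(2j * np.pi / (2 * N))          # a primitive 2N-th root of unity
coeffs = np.array([1.0, -2.0, 0.5, 3.0])   # arbitrary real coefficients, constant term first

def evaluate(coeffs: np.ndarray, x: complex) -> complex:
    # Horner evaluation, starting from the highest-degree coefficient
    result = 0j
    for c in coeffs[::-1]:
        result = result * x + c
    return result

for j in (1, 3):
    lhs = evaluate(coeffs, xi ** (-j))
    rhs = np.conjugate(evaluate(coeffs, xi ** j))
    print(j, np.isclose(lhs, rhs))
```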

import numpy as np
import math

def primitive_nth_roots_of_unity(n: int) -> np.ndarray:
    roots = []
    
    for i in range(1, n+1):
        if math.gcd(i, n) == 1:
            root = np.e ** (2 * np.pi * 1j * i / n)
            roots.append(root)

    return np.array(roots)

eigth_prim_roots = primitive_nth_roots_of_unity(8)

for i, r in enumerate(eigth_prim_roots):
    print(f"Root #{i+1}: {r}\nConjugate: {np.conjugate(r)}\n")
Root #1: (0.7071067811865476+0.7071067811865475j)
Conjugate: (0.7071067811865476-0.7071067811865475j)

Root #2: (-0.7071067811865475+0.7071067811865476j)
Conjugate: (-0.7071067811865475-0.7071067811865476j)

Root #3: (-0.7071067811865477-0.7071067811865475j)
Conjugate: (-0.7071067811865477+0.7071067811865475j)

Root #4: (0.7071067811865474-0.7071067811865477j)
Conjugate: (0.7071067811865474+0.7071067811865477j)

From this we can see that \(\xi^1 = \overline{\xi^7}\) and \(\xi^3 = \overline{\xi^5}\).

import numpy as np
import matplotlib.pyplot as plt

n = 8

roots = [np.exp(2j * np.pi * k / n) for k in range(n) if np.gcd(k, n) == 1]

real_parts = [root.real for root in roots]
imaginary_parts = [root.imag for root in roots]

labels = ['ξ^1', 'ξ^3', 'ξ^5', 'ξ^7']

for i, (x, y) in enumerate(zip(real_parts, imaginary_parts)):
    plt.scatter(x, y, color='red', label=labels[i])
    plt.text(x + 0.1, y, labels[i], fontsize=12)

for root in roots:
    plt.plot([0, root.real], [0, root.imag], linestyle='--', color='blue')

plt.xlabel('Real Part')
plt.ylabel('Imaginary Part')
plt.axhline(0, color='black', linewidth=0.5)
plt.axvline(0, color='black', linewidth=0.5)
plt.grid(color='gray', linestyle='--', linewidth=0.5)

plt.axis('equal')

plt.xlim(-1.5, 1.5)
plt.ylim(-1.5, 1.5)

plt.title(f'Primitive 8th Roots of Unity')

plt.show()

This makes intuitive sense if we look at a graph of the roots and see that primitive roots of unity are always symmetric over the x axis.

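\(\mathbb{H}\) is the subring of \(\mathbb{C}^N\) whose elements \(z\) satisfy \(z_{-j} = \overline{z_j}\): the entry attached to the root \(\xi^{-j}\) is the complex conjugate of the entry attached to \(\xi^{j}\), so the second half of the vector is the reversed conjugate of the first half.

We define an operation \(\pi^{-1}\) which expands a vector in \(\mathbb{C}^{N/2}\) into an element of \(\mathbb{H}\).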

def pi_inverse(z: np.ndarray) -> np.ndarray:
    return np.concatenate([z, [np.conjugate(x) for x in z[::-1]]])

z = np.array([3+4j,2-1j])

for i in range(len(z)):
    # Entry i is paired with its conjugate at the mirrored position of the expanded vector
    print(f"i: {i}, {z[i]} = i: {(-i-1) % (2 * len(z))}, {np.conjugate(z[i])}")

print()
print(list(pi_inverse(z)))
i: 0, (3+4j) = i: 3, (3-4j)
i: 1, (2-1j) = i: 2, (2+1j)

[(3+4j), (2-1j), (2+1j), (3-4j)]

\(\sigma(x)\) is equivalent to the evaluation of a polynomial on the roots of unity for the given cyclotomic polynomial which will produce an element in \(\mathbb{C}^N\).

Imagine \(\Delta\) is 10,000. We can say that \(\frac{(\Delta m * \Delta n)}{\Delta^2} = m * n\). For example let \(m = 3.14\) and \(n = 2.72\). We expect the result to equal 8.5408.

  1. Multiply \(m\) by \(\Delta\), \(\Delta * m = 31,400\)
  2. Multiply \(n\) by \(\Delta\), \(\Delta * n = 27,200\)
  3. Multiply the scaled values, \(\Delta m * \Delta n = 854,080,000\)
  4. Divide result by \(\Delta^2\), \(854,080,000 / \Delta^2 = 8.5408\)
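The same arithmetic can be written as a few lines of Python. This is a minimal sketch of fixed-point scaling with plain integers, not part of the scheme's implementation; the helper name `scale` is hypothetical.

```python
DELTA = 10_000

def scale(x: float) -> int:
    # Fixed-point encoding: keep about four decimal digits of x as an integer
    return round(x * DELTA)

m, n = 3.14, 2.72
product = scale(m) * scale(n)        # the product carries a factor of DELTA**2
print(scale(m), scale(n), product)
print(product / DELTA**2)            # dividing by DELTA**2 recovers m * n
```

Note that the product of two scaled values has scale \(\Delta^2\), which is why the scale has to be brought back down after every multiplication.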

The canonical embedding \(\sigma\) is defined as \(\sigma : \mathbb{R}[X]/(X^N +1) \rightarrow \mathbb{C}^N\), which evaluates the polynomial at the primitive complex roots of unity to get a vector of points. Its inverse is \(\sigma^{-1} : \mathbb{C}^N \rightarrow \mathbb{R}[X]/(X^N +1)\), which takes a list of points and produces a polynomial.

An element in \(\mathbb{H}\) is not necessarily an element in \(\sigma(\mathcal{R})\) so we have to project each element into \(\sigma(\mathcal{R})\).

We want to project the elements of \(\mathbb{H}\) onto the basis of \(\sigma(\mathcal{R})\), which is \(\beta = (\sigma(1),\sigma(X),\dots,\sigma(X^{N-1}))\), and to have integer coefficients rather than complex ones.

The first step is to project our vector \(z\) onto the basis of \(\sigma(\mathcal{R})\).

This can be done by using the vector projection equation first learned in a linear algebra course. For each basis vector \(\beta_i\) it is defined as \(proj_{\beta_i} z = \frac{\langle z, \beta_i \rangle}{\left\| \beta_i \right\|^2} \beta_i\).

After that we will have a vector expressed in the basis of \(\sigma(\mathcal{R})\), but we still need to round its coordinates so that they are integers.

A technique for this discretization is called “coordinate-wise randomized rounding” and is explained in section 2.4.2 of A Toolkit for Ring-LWE Cryptography.

Let \(\Lambda = \mathcal{L}(B) \subset \mathbb{R}^n\) be a lattice represented by the basis \(B = \{b_i\}\). Given a point \(x \in \mathbb{R}^n\) and a coset \(c \in \mathbb{R}^n\), the goal is to find an element \(y\) of \(\Lambda + c\) (for our encoding, simply a point on the lattice \(\Lambda\)) that keeps the length of \(y-x\) small.

We define a new coset, \(c' = \sum_{i=1}^{n} a_i b_i \mod \Lambda\), where the coefficients \(a_i\) are the decimal parts of the coordinates of our point \(x\), between 0 and 1. We choose each \(f_i\) randomly from \(\{a_i-1, a_i\}\) so that its expectation is zero, which rounds each coordinate to one of its two neighbouring integers. We then multiply by the basis to complete the projection.

In other words, we first compute \(randround(\frac{\langle z, b_i \rangle}{\langle b_i, b_i\rangle})\) for every basis vector \(b_i\) and then we multiply it by \(b_i\) like the previous projection formula told us.

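def rand_round(x: float) -> int:
    # Fractional part of x, in [0, 1)
    decimal = x - np.floor(x)
    # Choose f from {decimal - 1, decimal} with probabilities that give E[f] = 0
    f = np.random.choice([decimal - 1, decimal], p=[decimal, 1 - decimal])
    # x - f is floor(x) + 1 or floor(x); rint removes floating-point residue
    return int(np.rint(x - f))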


def cwrr(coords: np.ndarray) -> np.ndarray:
    return np.array([rand_round(x) for x in coords])


def discretization(z: np.ndarray) -> np.ndarray:
    bases = np.vander(primitive_nth_roots_of_unity(2*N), N, increasing=True).T
    proj = np.array([np.real(np.vdot(z, b) / np.vdot(b,b)) for b in bases])

    return cwrr(proj) @ bases
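The Toolkit paper picks each \(f_i\) so that its expectation is zero, which makes the rounding unbiased. The standalone sketch below illustrates that property empirically. The name `unbiased_round` is hypothetical, and the fixed seed is only there to make the run repeatable.

```python
import numpy as np

rng = np.random.default_rng(0)

def unbiased_round(x: float) -> int:
    # Round up with probability equal to the fractional part and down otherwise,
    # so that the expected value of the result is x itself
    low = int(np.floor(x))
    frac = x - low
    return low + int(rng.random() < frac)

samples = [unbiased_round(45.25) for _ in range(100_000)]
print(sorted(set(samples)))  # only the two neighbouring integers appear
print(np.mean(samples))      # close to 45.25
```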

After that we have to compute \(\sigma^{-1}(\lfloor\Delta \pi^{-1}(m)\rceil_{\sigma(\mathcal{R})})\) to get our final transformation into the polynomial ring \(\mathbb{Z}_q[X]/(X^N +1)\).

We have a list of points, so we can use the Vandermonde matrix to find the coefficients by solving \(Va = y\), where \(V\) is the Vandermonde matrix, \(a\) are the polynomial coefficients, and \(y\) are the points that we have.

from numpy.polynomial import Polynomial


def sigma_inverse(b: np.ndarray) -> Polynomial:
    A = np.vander(primitive_nth_roots_of_unity(2*N), N, increasing=True)
    coeffs = np.linalg.solve(A, b)
    return Polynomial(coeffs)

Once we put it together we have our complete encoding functionality.

DELTA = 64
N = 4


def pi_inverse(z: np.ndarray) -> np.ndarray:
    return np.concatenate([z, [np.conjugate(x) for x in z[::-1]]])


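def rand_round(x: float) -> int:
    # Fractional part of x, in [0, 1)
    decimal = x - np.floor(x)
    # Choose f from {decimal - 1, decimal} with probabilities that give E[f] = 0
    f = np.random.choice([decimal - 1, decimal], p=[decimal, 1 - decimal])
    # x - f is floor(x) + 1 or floor(x); rint removes floating-point residue
    return int(np.rint(x - f))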


def cwrr(coords: np.ndarray) -> np.ndarray:
    return np.array([rand_round(x) for x in coords])


def discretization(z: np.ndarray) -> np.ndarray:
    bases = np.vander(primitive_nth_roots_of_unity(2*N), N, increasing=True).T
    proj = np.array([np.real(np.vdot(z, b) / np.vdot(b,b)) for b in bases])

    return cwrr(proj) @ bases


def sigma_inverse(b: np.ndarray) -> Polynomial:
    A = np.vander(primitive_nth_roots_of_unity(2*N), N, increasing=True)
    coeffs = np.linalg.solve(A, b)
    return Polynomial(coeffs)


def encode(z: np.ndarray) -> Polynomial:
    out = pi_inverse(z)
    out *= DELTA
    out = discretization(out)
    out = sigma_inverse(out)
    
    out = np.round(np.real(out.coef)).astype(int)
    return Polynomial(out)


z = np.array([3+4j, 2-1j])
encoded_z = encode(z)
print(encoded_z)
159.0 + 90.0·x + 160.0·x² + 44.0·x³

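If we compare with the paper’s example, \(160 + 91X + 160X^2 + 45X^3\), we can see that each of our coefficients is within 1 of theirs. The small differences come from the randomized rounding step, so the result can vary between runs.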

To recap how our encoding scheme encodes \(m \in \mathbb{C}^{N/2} \rightarrow \mathbb{Z}_q[X]/(X^N +1)\):

  1. \(\pi^{-1}(m)\) applies \(\mathbb{C}^{N/2} \rightarrow \mathbb{H}\)
  2. Scale our element: \(\Delta \pi^{-1}(m)\)
  3. Project and randomly round our element into the space of \(\sigma(\mathcal{R})\)
  4. Apply \(\sigma^{-1}(\lfloor\Delta \pi^{-1}(m)\rceil_{\sigma(\mathcal{R})}) \in \mathbb{Z}_q[X]/(X^N +1)\)

Decoding

Decoding is very easy to understand as it is pretty much the opposite of encoding.

Our goal is to convert the polynomial in the ring \(\mathbb{Z}_q[X]/(X^N +1)\) to our plaintext in the space \(\mathbb{C}^{N/2}\).

First we have to multiply our polynomial by \(\Delta^{-1}\).

Then we use the canonical embedding map \(\sigma: \mathbb{R}[X]/(X^N +1) \rightarrow \mathbb{C}^{N}\), which is equivalent to evaluating the polynomial, p, at all of our primitive roots of unity and can be calculated with the Vandermonde matrix.

def sigma(p: Polynomial) -> np.ndarray:
    V = np.vander(primitive_nth_roots_of_unity(2*N), N, increasing=True)
    a = p.coef
    y = V @ a
    return y

Next we have to perform \(\pi\) to take an element \(\in \mathbb{H}\) and transform it into an element \(\in \mathbb{C}^{N/2}\) by keeping only the first half of the vector.

def pi(z: np.ndarray) -> np.ndarray:
    return z[:N//2]

Now we put it all together to get our decoding function.

def decode(p: Polynomial) -> np.ndarray:
    p *= 1/DELTA
    y = sigma(p)
    return pi(y)

print(decode(encoded_z))
[2.992608+3.98050482j 1.976142-1.01949518j]

We can see that the decoded value is very close to our original value.

To summarize the decoding, we follow these steps:

  1. Multiply our polynomial by \(\Delta^{-1}\)
  2. Use the canonical embedding map on our polynomial: \(\sigma: \mathbb{R}[X]/(X^N +1) \rightarrow \mathbb{C}^{N}\)
  3. Transform our element from \(\mathbb{H} \rightarrow \mathbb{C}^{N/2}\) using \(\pi\)

RLWE

The ring learning with errors (RLWE) problem is closely related to the learning with errors (LWE) problem, and its hardness is likewise supported by reductions from worst-case lattice problems. RLWE works with polynomials in a finite quotient ring, often denoted \(\mathbb{Z}_q[x]/\Phi(x)\), where \(\Phi(x)\) is an irreducible polynomial.

There are three polynomials that are used in the scheme: a random but publicly known polynomial sampled from the ring, called \(a(x)\), a random unknown error polynomial with small coefficients, \(e(x)\), and an unknown secret polynomial, \(s(x)\). The polynomial \(b(x)\) is formed from the equation \(b(x) = a(x) * s(x) + e(x)\) and is publicly known.

The RLWE problem has two different sub-problems. The first is called the search version, and the goal is to determine the secret polynomial \(s(x)\) from the public pair \((a(x), b(x))\). The second is called the decision version, and the goal is to determine whether a polynomial \(b(x)\) comes from the equation \(b(x) = a(x) * s(x) + e(x)\) or was sampled uniformly at random from the quotient ring \(\mathbb{Z}_q[x]/\Phi(x)\). Both problems are believed to be hard to solve even for quantum computers, which makes them an excellent choice for cryptography.
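To make the RLWE equation concrete, the sketch below builds one toy sample with \(\Phi(x) = x^N + 1\). The parameters are arbitrary and far too small to be secure, and `negacyclic_mul` is a hypothetical helper written for this illustration rather than code from any library.

```python
import numpy as np

rng = np.random.default_rng(1)
N, q = 8, 3329  # toy parameters, not secure

def negacyclic_mul(f: np.ndarray, g: np.ndarray) -> np.ndarray:
    # Multiply in Z[x]/(x^N + 1): every x^N wraps around to -1
    full = np.convolve(f, g)            # length 2N - 1
    low, high = full[:N], full[N:]
    low[: len(high)] -= high
    return low

s = rng.integers(-1, 2, size=N)                             # small secret polynomial
a = rng.integers(0, q, size=N)                              # uniform public polynomial
e = np.rint(rng.normal(0, 3.2, size=N)).astype(np.int64)    # small error polynomial

b = (negacyclic_mul(a, s) + e) % q                          # the public pair is (a, b)

# Anyone who knows s can remove a*s and is left with only the small error
recovered = (b - negacyclic_mul(a, s)) % q
recovered = np.where(recovered > q // 2, recovered - q, recovered)
print(np.array_equal(recovered, e))
```

Without knowledge of \(s\), the pair \((a, b)\) is meant to look like two uniformly random polynomials, which is exactly the decision version of the problem.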

Key Generation

The key generation phase creates a private key and uses that to create a public key and an evaluation key.

We sample \(s\) from a function that returns a polynomial with coefficients in \(\{-1, 0, 1\}\) (the paper uses signed binary coefficients) and with a specified Hamming weight, denoted \(s \leftarrow HWT(h)\).
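One possible sampler is sketched below. It assumes signed coefficients in \(\{-1, 0, 1\}\), following the paper's definition of \(HWT(h)\). The name `sample_hwt` is hypothetical and is not necessarily how the `gen_hwt_poly` helper used later in this post is written.

```python
import random
from numpy.polynomial import Polynomial

def sample_hwt(n: int, h: int) -> Polynomial:
    """Sample a polynomial with n coefficients, exactly h of them nonzero."""
    if h > n:
        raise ValueError("Hamming weight cannot exceed the number of coefficients")
    coeffs = [0] * n
    # Choose h distinct positions and give each one a random sign
    for i in random.sample(range(n), h):
        coeffs[i] = random.choice([-1, 1])
    return Polynomial(coeffs)

s = sample_hwt(8, 3)
print(s.coef)
```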

The secret key, \(sk\), is defined as (1, s).

We uniformly sample a polynomial \(a\) from \(\mathcal{R}_{q_L}\), with coefficients in the range \([-q_L/2, q_L/2-1]\).

The code for this function is shown below:

import random

def gen_uniform_poly(n: int, modulus: int) -> Polynomial:
    # randrange excludes its upper bound, giving exactly `modulus` possible values
    coeffs = [random.randrange(-modulus // 2, modulus // 2) for _ in range(n)]
    return Polynomial(coeffs)

Then we sample an error polynomial from a discrete Gaussian distribution with variance \(\sigma^2\), denoted \(e \leftarrow DG(\sigma^2)\).

An implementation of the discrete gaussian follows:

def gen_gaussian_poly(n: int, modulus: int, sigma: float = 3.2) -> Polynomial:
    # np.random.normal takes the standard deviation (sigma), not the variance
    return Polynomial(np.rint(np.random.normal(0, sigma, size=n)))

For the final part of our public key, we set \(b \leftarrow -as + e \space (mod \space qL)\).

Our public key, \(pk\), is set to (b, a).

This is enough to do simple additions, but multiplication needs a technique called relinearization, covered in more depth later, which requires an evaluation key.

Like before we sample an \(a'\) and \(e'\) polynomial with the same techniques.

This time, however, we set \(b' \leftarrow -a's + e' + Ps^2 \space (mod \space P qL)\).

And our evaluation key, \(evk\), is (b’, a’).

An example function in Python might look like:

N = 4
P = 2**15
Q = 2**15
L = 3
h = 2  # Hamming weight of the secret key; it cannot exceed N = 4 in this toy example
poly_mod = Polynomial([1, 0, 0, 0, 1])

def keygen():
    Ql = P ** L * Q

    s = gen_hwt_poly(N, h)
    a = gen_uniform_poly(N, Ql)
    e = gen_gaussian_poly(N, Ql)

    sk = (1, s)

    b = poly_coef_mod((-a * s + e) % poly_mod, Ql)
    pk = (b, a)
                      
    # evk
    a1 = gen_uniform_poly(N, P * Ql)
    e1 = gen_gaussian_poly(N, P * Ql)

    b1 = poly_coef_mod((-a1 * s + e1 + P * (s ** 2)) % poly_mod, P * Ql)
    evk = (b1, a1)

    return pk, sk, evk
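The key generation code relies on a helper called `poly_coef_mod` that is not shown in this post. The sketch below is one possible implementation, written under the assumption that it reduces every coefficient modulo the given modulus into the centered range \((-q/2, q/2]\); the original helper may differ.

```python
import numpy as np
from numpy.polynomial import Polynomial

def poly_coef_mod(p: Polynomial, modulus: int) -> Polynomial:
    # Convert each coefficient to a Python integer before reducing it
    coeffs = [int(round(c)) % modulus for c in p.coef]
    # Move representatives above modulus/2 down into (-modulus/2, modulus/2]
    centered = [c - modulus if c > modulus // 2 else c for c in coeffs]
    return Polynomial(np.array(centered, dtype=object))

poly_mod = Polynomial([1, 0, 0, 0, 1])     # x^4 + 1
p = Polynomial([0, 0, 9, -7, 0, 3])        # 9x^2 - 7x^3 + 3x^5
reduced = poly_coef_mod(p % poly_mod, 10)
print(reduced.coef)
```

Here `p % poly_mod` first reduces the polynomial modulo \(x^4 + 1\), giving \(-3x + 9x^2 - 7x^3\), and the helper then maps the coefficients to \(0, -3, -1, 3\) modulo 10.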

Encryption

The encryption step takes an encoded message in the form of a polynomial and hides it using the public key.

Encryption uses another sampling method called \(\mathcal{Z}O(\rho)\) that outputs either -1, 0, or 1. The probability of outputting each of -1 and 1 is \(\rho/2\), and the probability of outputting 0 is \(1 - \rho\).

Here is an implementation we can use:

def sample_triangle(n: int) -> Polynomial:
    coeffs = [0] * n

    for i in range(n):
        r = random.randrange(0, 4)
        if r == 0: coeffs[i] = -1
        elif r == 1: coeffs[i] = 1
        else: coeffs[i] = 0
    return Polynomial(coeffs)

The steps for encryption are:

  1. Sample \(v \leftarrow \mathcal{Z}O(0.5)\)
  2. Sample \(e_1, e_2 \leftarrow DG(\sigma^2)\)
  3. Output \(v · pk + (m + e_1, e_2) \space (mod \space qL)\)

def encrypt(pk: tuple, m: Polynomial) -> tuple:
    Ql = P ** L * Q

    v = sample_triangle(N)
    e1 = gen_gaussian_poly(N, Ql)
    e2 = gen_gaussian_poly(N, Ql)

    pk1, pk2 = pk

    c1 = poly_coef_mod((v * pk1 + m + e1) % poly_mod, Ql)
    c2 = poly_coef_mod((v * pk2 + e2) % poly_mod, Ql)

    return ((c1, c2), L)

Decryption

def decrypt(sk: tuple, ctx: tuple):
    b, a = ctx[0]
    l = ctx[1]
    Ql = (P ** l) * Q
    return poly_coef_mod((b + a * sk[1]) % poly_mod, Ql)
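To see why this recovers the message, we can substitute the public key \(b = -as + e\) into a fresh ciphertext \((c_1, c_2) = (v b + m + e_1, v a + e_2)\). The short derivation below uses only the definitions from the key generation and encryption sections:

\[
\begin{aligned}
c_1 + c_2 s &= (v b + m + e_1) + (v a + e_2) s \\
            &= v(-as + e) + m + e_1 + v a s + e_2 s \\
            &= m + (v e + e_1 + e_2 s) \pmod{q_L}
\end{aligned}
\]

The term in parentheses only involves small polynomials, so decryption returns \(m\) plus a small error, which CKKS treats as part of the approximation.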

Addition of Constants

Addition of constants in this scheme is very simple. We just add the encoded (but not encrypted) constant to the first element of the ciphertext. The following code demonstrates this addition.

def add_const(c: tuple, const: Polynomial) -> tuple:
    Ql = P ** c[1] * Q
    c1, c2 = c[0]
    cadd1 = poly_coef_mod((c1 + const) % poly_mod, Ql)
    return ((cadd1, c2), c[1])

Addition of Ciphertexts

Addition of ciphertexts in CKKS is also very simple. We just output the component-wise sum of the ciphertexts. Addition in most homomorphic encryption schemes does not add much noise and we usually do not have to do any noise reduction. However, if a very large number of additions are performed, a noise reduction technique will have to be employed.

def add(c1: tuple, c2: tuple) -> tuple:
    Ql = P ** c1[1] * Q
    c11, c12 = c1[0]
    c21, c22 = c2[0]
    cadd1 = poly_coef_mod((c11 + c21) % poly_mod, Ql)
    cadd2 = poly_coef_mod((c12 + c22) % poly_mod, Ql)
    return ((cadd1, cadd2), c1[1])

Multiplication of Constants

Multiplying a ciphertext by a plaintext is another simple operation as we only have to multiply each component of the ciphertext by the encoded constant.

def mult_const(c: tuple, const: Polynomial) -> tuple:
    Ql = P ** c[1] * Q
    c1, c2 = c[0]
    cmult1 = poly_coef_mod((c1 * const) % poly_mod, Ql)
    cmult2 = poly_coef_mod((c2 * const) % poly_mod, Ql)
    return ((cmult1, cmult2), c[1])

Multiplication of Ciphertexts

The most complex operation is multiplying two ciphertexts together. The noise of each ciphertext greatly increases the product’s noise and we have to have methods to deal with it so that it can be correctly decrypted.

Multiplication is mathematically defined as follows. Given \(c_1 = (b_1, a_1)\) and \(c_2 = (b_2, a_2)\), compute \((d_0, d_1, d_2) = (b_1b_2, a_1b_2 + a_2b_1, a_1a_2) \space (mod \space q_l)\).

We notice that this increases our ciphertext space from 2 dimensions to 3 dimensions. This is similar to the effect of multiplying two degree 1 polynomials and getting a degree 2 polynomial, e.g. \((x+1)(x+3) = x^2 + 4x + 3\).

To decrypt we need to have a ciphertext of two dimensions. To convert the new ciphertext into 2 dimensions we introduce an operation called relinearization.

If you remember from the key generation algorithm, we created an evaluation key that involved masking the squared secret key.

Because of this evaluation key we can define our relinearization algorithm as: \(c_{mult} = (d_0, d_1) + \lfloor P^{-1} * d_2 * evk \rceil \space (mod \space q_l)\)
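One way to see why relinearization works is to expand the decryption of the product. Writing the three product terms as \((d_0, d_1, d_2) = (b_1b_2, a_1b_2 + a_2b_1, a_1a_2)\) and the evaluation key as \(evk = (b', a')\) with \(b' = -a's + e' + Ps^2\), the following sketch ignores rounding:

\[
\begin{aligned}
(b_1 + a_1 s)(b_2 + a_2 s) &= d_0 + d_1 s + d_2 s^2 \\
P^{-1} d_2 (b' + a' s) &= P^{-1} d_2 (P s^2 + e') = d_2 s^2 + P^{-1} d_2 e'
\end{aligned}
\]

So adding \(\lfloor P^{-1} d_2 \cdot evk \rceil\) to \((d_0, d_1)\) gives a two-component ciphertext that decrypts to \(d_0 + d_1 s + d_2 s^2\) plus a small error. Dividing by \(P\) is what keeps the extra term \(d_2 e'\) small.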

There is still one more step needed to implement the multiplication algorithm.

The rescaling procedure is similar to the modulus switching technique used in other schemes like BGV and has the purpose of eliminating some of the least significant bits (LSBs). It also keeps the scaling of the ciphertext consistent, because when two ciphertexts with a scale factor of \(\Delta\) are multiplied together we get a scale factor of \(\Delta^2\).

Rescaling is defined as \(\lfloor \frac{q_{l'}}{q_l}c \rceil \space (mod \space q_{l'})\), where \(l' < l\) is the new, lower level.
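The effect of rescaling on the scale factor can be illustrated with plain integers. The sketch below is not an implementation of ciphertext rescaling: it works on a single unencrypted value, assumes for simplicity that each level drops exactly one factor of \(\Delta\), and uses the hypothetical helper name `rescale`.

```python
from typing import List

DELTA = 2**10  # scale factor; in this toy example each level drops a factor of DELTA

def rescale(coeffs: List[int], q_l: int, q_next: int) -> List[int]:
    # Compute round(q_next / q_l * c) with exact integer arithmetic (rounding half up)
    return [(2 * c * q_next + q_l) // (2 * q_l) for c in coeffs]

q_l, q_next = DELTA**3, DELTA**2
x, y = round(3.14 * DELTA), round(2.72 * DELTA)   # both values carry scale DELTA
prod = x * y                                      # the product carries scale DELTA**2
(rescaled,) = rescale([prod], q_l, q_next)        # back to scale DELTA
print(rescaled / DELTA)                           # approximately 3.14 * 2.72
```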

So for multiplication we:

  1. Multiply two ciphertexts to get a 3 dimensional ciphertext
  2. Relinearize the result back to 2 dimensions
  3. Rescale the ciphertext so that the scale factor is correct

An implementation is omitted because multiplication in this scheme is unlikely to work unless you are using a language that supports arbitrarily sized numbers.

There is a paper that will be described below that shows how we can use this scheme with 64 bit numbers.

Next Steps

If you are implementing the scheme from the paper, you will likely run into issues. There are many optimizations that one can do to improve it. One easy one is to replace the Vandermonde matrix in the encoding with the Fast Fourier Transform (FFT) to improve the time complexity. The more important optimization is related to the modulus P*Q. You need a very large number of bits to represent these large numbers. A fix for this is to use the Chinese Remainder Theorem (CRT) to represent each large number by its residues modulo several smaller primes. The paper that implements this is called A Full RNS Variant of Approximate Homomorphic Encryption and is recommended if you plan on implementing this scheme.
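The CRT idea can be sketched with ordinary integers. The example below is only an illustration of residue arithmetic, not the RNS variant from the paper; the moduli are arbitrary primes chosen for the example, and `to_rns` and `from_rns` are hypothetical helper names.

```python
from math import prod
from typing import List

def to_rns(x: int, moduli: List[int]) -> List[int]:
    # Represent x by its residues modulo each (pairwise coprime) modulus
    return [x % m for m in moduli]

def from_rns(residues: List[int], moduli: List[int]) -> int:
    # Standard CRT reconstruction; pow(Mi, -1, m) is the modular inverse (Python 3.8+)
    M = prod(moduli)
    total = 0
    for r, m in zip(residues, moduli):
        Mi = M // m
        total += r * Mi * pow(Mi, -1, m)
    return total % M

moduli = [2**31 - 1, 2**61 - 1, 1_000_000_007]   # three distinct primes
x, y = 2**70 + 12345, 2**65 + 678

# Multiply residue by residue: each channel only stores numbers below its own modulus
z_rns = [(a * b) % m for a, b, m in zip(to_rns(x, moduli), to_rns(y, moduli), moduli)]
print(from_rns(z_rns, moduli) == (x * y) % prod(moduli))
```

Addition and multiplication both work channel by channel, which is what lets an implementation avoid big-integer arithmetic as long as the product of the moduli is large enough for the values involved.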