rsa_pss.h

#define hash_size SHASIZE/8
#define rsa_size RSAKEYSIZE/8
#define db_size (RSAKEYSIZE - SHASIZE - 8)/8
#define pad2_size (db_size - SHASIZE/8)

rsa_pss.h ファイルに頻繁に繰り返され、可読性が落ちるので別に define をしておいた。 hash_size, rsa_sizeはbyte単位で変換をしたものであり、db_sizeはDBの大きさをbyte単位で示し、pad2_sizeもDBのpaddingの大きさ(padding_2)をbyte単位で示した。

rsassa_pss_sign

int rsassa_pss_sign(const void *m, size_t mLen, const void *d, const void *n, void *s)
{
    unsigned char m_hash[hash_size];
    unsigned char salt[hash_size];
    unsigned char m_prime[8 + (hash_size) *2];
    unsigned char h[hash_size];
    unsigned char mgf_h[db_size];
    unsigned char db[db_size];
    unsigned char masked_db[db_size];
    unsigned char em[rsa_size];

    uint64_t is_long = 1;
    uint8_t bc = 0xbc;

    // error check
    is_long = (is_long << 61) - 1;
    if (mLen > is_long && SHASIZE <= 256)
        return EM_MSG_TOO_LONG;
    if (RSAKEYSIZE < SHASIZE * 2 + 16)
        return EM_HASH_TOO_LONG;

    // 1. M을 Hash
    (*sha)(m, mLen, m_hash);

    // 2. M'을 계산한다.

    // 8 byte의 padding1
    memset(m_prime, 0x00, 8);
    // salt
    *salt = arc4random_uniform(SHASIZE);
    
    memcpy(m_prime + 8, m_hash, hash_size);
    memcpy(m_prime + 8 + hash_size, salt, hash_size);

    // 3. M'을 Hash한 H를 구한다.
    (*sha)(m_prime, 8 + hash_size * 2, h);
    
    // 4. DB를 계산한다.

    // padding2를 만든다.
    memset(db, 0x00, pad2_size);
    db[pad2_size - 1] = 0x01;

    memcpy(db + pad2_size, salt, hash_size);

    // 5. H를 Mask Generation Function을 통과시킨다.
    mgf(h, hash_size, mgf_h, db_size);

    // 6. DB와 MGF를 통과한 H의 XOR 값을 계산한 maskedDB를 구한다.
    for (int i = 0; i < db_size; i++)
        masked_db[i] = db[i] ^ mgf_h[i];
    
    // 7. EM을 계산한다.
    memcpy(em, masked_db, db_size);
    memcpy(em + db_size, h, hash_size);
    memcpy(em + rsa_size - 1, &bc, 1);

		// MSB를 0으로 바꾼다.
    if(em[0] >> 7)
        em[0] &= 0x7f;

    // error check
    if (rsa_cipher(em, d, n))
        return EM_MSG_OUT_OF_RANGE;

    // 최종적으로 s값을 구한다.
    memcpy(s, em, rsa_size);
    return 0;
}

Message Mに署名をする。

最初にerrorcheckを実施するが、 SHA 入力最大値が$2^{64}$-1 bits であるが、64 乗を mLen との比較のため byte 単位に切り替えなければならないので、8 を割ると$2^{61}$-1 と考えられる。 そのようにSHA ハッシュ関数の長さを超える際にEM_MSG_TOO_LONG で処理した。 その次にEMにはハッシュ出力2個と2byte(16bits)が入らなければならないので(SHA(M')、salt、0xbc、0x01) EM_HASH_TOO_LONG を処理した。

errorcheckが終わったら、MをハッシュしてM'を計算する。 M'の計算は8 bytes であるpadding_1 とM をハッシュした値、そしてsalt(乱数)が入る。 salt乱数の生成は、以前のプロジェクトで使用したarc4random_uniformを使用してSHASIZE範囲内で乱数を生成するようにした。 その後、memcpyを通じてM'を最終的に求める。

M'の計算が終われば、再びハッシュをもう一度したHを求める。

次の過程でDBを計算するが、DBには最後が01のpadding_2とsalt値が入る。 memcpyを通じてDB値を求め、最終的にEMを求めるための過程が必要である。

まず、M'をハッシュしたHに対するMGF値を求め、その値とDBの値をXORしてmaskedDB値を求める。

最後に、makedDB、H、bc(0xbc)が入ったEM値を導出する。 EM値の最初のビットを0にするために0111111と&演算をする。

そのようにEM 値まで求めるようになればプライベート鍵で暗号化し、異常がなければ最終的にs 値で保存する。

rsassa_pss_verify

int rsassa_pss_verify(const void *m, size_t mLen, const void *e, const void *n, const void *s)
{
    unsigned char m_hash[hash_size];
    unsigned char salt[hash_size];
    unsigned char m_prime[8 + (hash_size) *2];
    unsigned char h_prime[hash_size];
    unsigned char h[hash_size];
    unsigned char mgf_h[db_size];
    unsigned char db[db_size];
    unsigned char masked_db[db_size];
    unsigned char em[rsa_size];

    uint64_t is_long = 1;
		uint8_t bc = 0xbc;

    memcpy(em, s, rsa_size);

    // error check
		is_long = (is_long << 61) - 1;
    if (mLen > is_long)
        return EM_MSG_TOO_LONG;
    if (RSAKEYSIZE < SHASIZE * 2 + 16)
        return EM_HASH_TOO_LONG;
    if (rsa_cipher(em, e, n))
        return EM_MSG_OUT_OF_RANGE;
		if (em[0] >> 7 & 1)
        return EM_INVALID_INIT;
    if (em[RSAKEYSIZE/8 - 1] ^ bc)
        return EM_INVALID_LAST;

    // 1. M을 Hash 한다.
    (*sha)(m, mLen, m_hash);

    // 2. EM을 분리하여 계산한다.
    // em에서 masked_db를 추출
    memcpy(masked_db, em, db_size);
    // em에서 H를 추출
    memcpy(h, em + db_size, hash_size);

    // 3. DB를 위한 연산을 수행
    mgf(h, hash_size, mgf_h, db_size);

    for(int i = 0; i < db_size; i++)
        db[i] = masked_db[i] ^ mgf_h[i];
    
    // 4. DB에서 salt값 추출
    memcpy(salt, db + pad2_size, hash_size);

    // error check
    for (int i = 1; i < pad2_size - 1; i++) {
        if (db[i] & 1)
            return EM_INVALID_PD2;
    }
    if (db[pad2_size - 1] != 0x01)
        return EM_INVALID_PD2;

    // 5. mHash와 salt로 M'을 계산
    memset(m_prime, 0x00, 8);
    memcpy(m_prime + 8, m_hash, hash_size);
    memcpy(m_prime + 8 + hash_size, salt, hash_size);

    // 6. M'을 Hash하여 H'을 구함
    (*sha)(m_prime, 8 + hash_size * 2, h_prime);

    // error check
    if (memcmp(h, h_prime, hash_size) != 0)
        return EM_HASH_MISMATCH;

    return 0;
}

まず、sの値をEMにコピーする。

そして、以前の署名と同様にmLen に対して検査を行い、EM_MSG_TOO_LONG を処理し、EM_HASH_TOO_LONG も処理する。 ここで重要なポイントは、0x01 も必須で入らなければならないので、2*SHASIZE + 16 となるが、その理由は、上記の署名過程でMSB 部分を 0 に変える過程があるが、salt 値に損傷を与えることがあるため、それを防ぐために必須に入らなければならない。

その後、eとnを復号するが、もし1がrsa_cipher関数で1がreturnになったら、EM_MSG_OUT_OF_RANGEにエラーメッセージを送る。