001/*
002 *  Licensed under the Apache License, Version 2.0 (the "License");
003 *  you may not use this file except in compliance with the License.
004 *  You may obtain a copy of the License at
005 *
006 *       http://www.apache.org/licenses/LICENSE-2.0
007 *
008 *  Unless required by applicable law or agreed to in writing, software
009 *  distributed under the License is distributed on an "AS IS" BASIS,
010 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
011 *  See the License for the specific language governing permissions and
 *  limitations under the License.
014 */
015
016package org.apache.commons.imaging.formats.jpeg.decoder;
017
018import java.awt.image.BufferedImage;
019import java.awt.image.ColorModel;
020import java.awt.image.DataBuffer;
021import java.awt.image.DirectColorModel;
022import java.awt.image.Raster;
023import java.awt.image.WritableRaster;
024import java.io.ByteArrayInputStream;
025import java.io.IOException;
026import java.util.ArrayList;
027import java.util.Arrays;
028import java.util.List;
029import java.util.Properties;
030
031import org.apache.commons.imaging.ImageReadException;
032import org.apache.commons.imaging.color.ColorConversions;
033import org.apache.commons.imaging.common.BinaryFileParser;
034import org.apache.commons.imaging.common.bytesource.ByteSource;
035import org.apache.commons.imaging.formats.jpeg.JpegConstants;
036import org.apache.commons.imaging.formats.jpeg.JpegUtils;
037import org.apache.commons.imaging.formats.jpeg.segments.DhtSegment;
038import org.apache.commons.imaging.formats.jpeg.segments.DqtSegment;
039import org.apache.commons.imaging.formats.jpeg.segments.SofnSegment;
040import org.apache.commons.imaging.formats.jpeg.segments.SosSegment;
041
042import static org.apache.commons.imaging.common.BinaryFunctions.read2Bytes;
043import static org.apache.commons.imaging.common.BinaryFunctions.readBytes;
044
045public class JpegDecoder extends BinaryFileParser implements JpegUtils.Visitor {
046    /*
047     * JPEG is an advanced image format that takes significant computation to
048     * decode. Keep decoding fast: - Don't allocate memory inside loops,
049     * allocate it once and reuse. - Minimize calculations per pixel and per
050     * block (using lookup tables for YCbCr->RGB conversion doubled
051     * performance). - Math.round() is slow, use (int)(x+0.5f) instead for
052     * positive numbers.
053     */
054
055    private final DqtSegment.QuantizationTable[] quantizationTables = new DqtSegment.QuantizationTable[4];
056    private final DhtSegment.HuffmanTable[] huffmanDCTables = new DhtSegment.HuffmanTable[4];
057    private final DhtSegment.HuffmanTable[] huffmanACTables = new DhtSegment.HuffmanTable[4];
058    private SofnSegment sofnSegment;
059    private SosSegment sosSegment;
060    private final float[][] scaledQuantizationTables = new float[4][];
061    private BufferedImage image;
062    private ImageReadException imageReadException;
063    private IOException ioException;
064    private final int[] zz = new int[64];
065    private final int[] blockInt = new int[64];
066    private final float[] block = new float[64];
067
068    @Override
069    public boolean beginSOS() {
070        return true;
071    }
072
073    @Override
074    public void visitSOS(final int marker, final byte[] markerBytes, final byte[] imageData) {
075        final ByteArrayInputStream is = new ByteArrayInputStream(imageData);
076        try {
077            // read the scan header
078            final int segmentLength = read2Bytes("segmentLength", is,"Not a Valid JPEG File", getByteOrder());
079            final byte[] sosSegmentBytes = readBytes("SosSegment", is, segmentLength - 2, "Not a Valid JPEG File");
080            sosSegment = new SosSegment(marker, sosSegmentBytes);
081            // read the payload of the scan, this is the remainder of image data after the header
082            // the payload contains the entropy-encoded segments (or ECS) divided by RST markers
083            // or only one ECS if the entropy-encoded data is not divided by RST markers
084            // length of payload = length of image data - length of data already read
085            final int[] scanPayload = new int[imageData.length - segmentLength];
086            int payloadReadCount = 0;
087            while (payloadReadCount < scanPayload.length) {
088                scanPayload[payloadReadCount] = is.read();
089                payloadReadCount++;
090            }
091
092            int hMax = 0;
093            int vMax = 0;
094            for (int i = 0; i < sofnSegment.numberOfComponents; i++) {
095                hMax = Math.max(hMax,
096                        sofnSegment.getComponents(i).horizontalSamplingFactor);
097                vMax = Math.max(vMax,
098                        sofnSegment.getComponents(i).verticalSamplingFactor);
099            }
100            final int hSize = 8 * hMax;
101            final int vSize = 8 * vMax;
102
103            final int xMCUs = (sofnSegment.width + hSize - 1) / hSize;
104            final int yMCUs = (sofnSegment.height + vSize - 1) / vSize;
105            final Block[] mcu = allocateMCUMemory();
106            final Block[] scaledMCU = new Block[mcu.length];
107            for (int i = 0; i < scaledMCU.length; i++) {
108                scaledMCU[i] = new Block(hSize, vSize);
109            }
110            final int[] preds = new int[sofnSegment.numberOfComponents];
111            ColorModel colorModel;
112            WritableRaster raster;
113            if (sofnSegment.numberOfComponents == 4) {
114                colorModel = new DirectColorModel(24, 0x00ff0000, 0x0000ff00, 0x000000ff);
115                int bandMasks[] = new int[] { 0x00ff0000, 0x0000ff00, 0x000000ff };
116                raster = Raster.createPackedRaster(DataBuffer.TYPE_INT, sofnSegment.width, sofnSegment.height, bandMasks, null);
117            } else if (sofnSegment.numberOfComponents == 3) {
118                colorModel = new DirectColorModel(24, 0x00ff0000, 0x0000ff00,
119                        0x000000ff);
120                raster = Raster.createPackedRaster(DataBuffer.TYPE_INT,
121                        sofnSegment.width, sofnSegment.height, new int[] {
122                                0x00ff0000, 0x0000ff00, 0x000000ff }, null);
123            } else if (sofnSegment.numberOfComponents == 1) {
124                colorModel = new DirectColorModel(24, 0x00ff0000, 0x0000ff00,
125                        0x000000ff);
126                raster = Raster.createPackedRaster(DataBuffer.TYPE_INT,
127                        sofnSegment.width, sofnSegment.height, new int[] {
128                                0x00ff0000, 0x0000ff00, 0x000000ff }, null);
129                // FIXME: why do images come out too bright with CS_GRAY?
130                // colorModel = new ComponentColorModel(
131                // ColorSpace.getInstance(ColorSpace.CS_GRAY), false, true,
132                // Transparency.OPAQUE, DataBuffer.TYPE_BYTE);
133                // raster = colorModel.createCompatibleWritableRaster(
134                // sofnSegment.width, sofnSegment.height);
135            } else {
136                throw new ImageReadException(sofnSegment.numberOfComponents
137                        + " components are invalid or unsupported");
138            }
139            final DataBuffer dataBuffer = raster.getDataBuffer();
140
141            final JpegInputStream[] bitInputStreams = splitByRstMarkers(scanPayload);
142            int bitInputStreamCount = 0;
143            JpegInputStream bitInputStream = bitInputStreams[0];
144
145            for (int y1 = 0; y1 < vSize * yMCUs; y1 += vSize) {
146                for (int x1 = 0; x1 < hSize * xMCUs; x1 += hSize) {
147                    // Provide the next interval if an interval is read until it's end
148                    // as long there are unread intervals available
149                    if (!bitInputStream.hasNext()) {
150                        bitInputStreamCount++;
151                        if (bitInputStreamCount < bitInputStreams.length) {
152                            bitInputStream = bitInputStreams[bitInputStreamCount];
153                        }
154                    }
155
156                    readMCU(bitInputStream, preds, mcu);
157                    rescaleMCU(mcu, hSize, vSize, scaledMCU);
158                    int srcRowOffset = 0;
159                    int dstRowOffset = y1 * sofnSegment.width + x1;
160                    for (int y2 = 0; y2 < vSize && y1 + y2 < sofnSegment.height; y2++) {
161                        for (int x2 = 0; x2 < hSize
162                                && x1 + x2 < sofnSegment.width; x2++) {
163                            if (scaledMCU.length == 4) {
164                                final int C = scaledMCU[0].samples[srcRowOffset + x2];
165                                final int M = scaledMCU[1].samples[srcRowOffset + x2];
166                                final int Y = scaledMCU[2].samples[srcRowOffset + x2];
167                                final int K = scaledMCU[3].samples[srcRowOffset + x2];
168                                final int rgb = ColorConversions.convertCMYKtoRGB(C, M, Y, K);
169                                dataBuffer.setElem(dstRowOffset + x2, rgb);
170                            } else if (scaledMCU.length == 3) {
171                                final int Y = scaledMCU[0].samples[srcRowOffset + x2];
172                                final int Cb = scaledMCU[1].samples[srcRowOffset + x2];
173                                final int Cr = scaledMCU[2].samples[srcRowOffset + x2];
174                                final int rgb = YCbCrConverter.convertYCbCrToRGB(Y,
175                                        Cb, Cr);
176                                dataBuffer.setElem(dstRowOffset + x2, rgb);
177                            } else if (mcu.length == 1) {
178                                final int Y = scaledMCU[0].samples[srcRowOffset + x2];
179                                dataBuffer.setElem(dstRowOffset + x2, (Y << 16)
180                                        | (Y << 8) | Y);
181                            } else {
182                                throw new ImageReadException(
183                                        "Unsupported JPEG with " + mcu.length
184                                                + " components");
185                            }
186                        }
187                        srcRowOffset += hSize;
188                        dstRowOffset += sofnSegment.width;
189                    }
190                }
191            }
192            image = new BufferedImage(colorModel, raster,
193                    colorModel.isAlphaPremultiplied(), new Properties());
194            // byte[] remainder = super.getStreamBytes(is);
195            // for (int i = 0; i < remainder.length; i++)
196            // {
197            // System.out.println("" + i + " = " +
198            // Integer.toHexString(remainder[i]));
199            // }
200        } catch (final ImageReadException imageReadEx) {
201            imageReadException = imageReadEx;
202        } catch (final IOException ioEx) {
203            ioException = ioEx;
204        } catch (final RuntimeException ex) {
205            // Corrupt images can throw NPE and IOOBE
206            imageReadException = new ImageReadException("Error parsing JPEG",ex);
207        }
208    }
209
210    @Override
211    public boolean visitSegment(final int marker, final byte[] markerBytes,
212            final int segmentLength, final byte[] segmentLengthBytes, final byte[] segmentData)
213            throws ImageReadException, IOException {
214        final int[] sofnSegments = {
215                JpegConstants.SOF0_MARKER,
216                JpegConstants.SOF1_MARKER,
217                JpegConstants.SOF2_MARKER,
218                JpegConstants.SOF3_MARKER,
219                JpegConstants.SOF5_MARKER,
220                JpegConstants.SOF6_MARKER,
221                JpegConstants.SOF7_MARKER,
222                JpegConstants.SOF9_MARKER,
223                JpegConstants.SOF10_MARKER,
224                JpegConstants.SOF11_MARKER,
225                JpegConstants.SOF13_MARKER,
226                JpegConstants.SOF14_MARKER,
227                JpegConstants.SOF15_MARKER,
228        };
229
230        if (Arrays.binarySearch(sofnSegments, marker) >= 0) {
231            if (marker != JpegConstants.SOF0_MARKER) {
232                throw new ImageReadException("Only sequential, baseline JPEGs "
233                        + "are supported at the moment");
234            }
235            sofnSegment = new SofnSegment(marker, segmentData);
236        } else if (marker == JpegConstants.DQT_MARKER) {
237            final DqtSegment dqtSegment = new DqtSegment(marker, segmentData);
238            for (int i = 0; i < dqtSegment.quantizationTables.size(); i++) {
239                final DqtSegment.QuantizationTable table = dqtSegment.quantizationTables.get(i);
240                if (0 > table.destinationIdentifier
241                        || table.destinationIdentifier >= quantizationTables.length) {
242                    throw new ImageReadException(
243                            "Invalid quantization table identifier "
244                                    + table.destinationIdentifier);
245                }
246                quantizationTables[table.destinationIdentifier] = table;
247                final int[] quantizationMatrixInt = new int[64];
248                ZigZag.zigZagToBlock(table.getElements(), quantizationMatrixInt);
249                final float[] quantizationMatrixFloat = new float[64];
250                for (int j = 0; j < 64; j++) {
251                    quantizationMatrixFloat[j] = quantizationMatrixInt[j];
252                }
253                Dct.scaleDequantizationMatrix(quantizationMatrixFloat);
254                scaledQuantizationTables[table.destinationIdentifier] = quantizationMatrixFloat;
255            }
256        } else if (marker == JpegConstants.DHT_MARKER) {
257            final DhtSegment dhtSegment = new DhtSegment(marker, segmentData);
258            for (int i = 0; i < dhtSegment.huffmanTables.size(); i++) {
259                final DhtSegment.HuffmanTable table = dhtSegment.huffmanTables.get(i);
260                DhtSegment.HuffmanTable[] tables;
261                if (table.tableClass == 0) {
262                    tables = huffmanDCTables;
263                } else if (table.tableClass == 1) {
264                    tables = huffmanACTables;
265                } else {
266                    throw new ImageReadException("Invalid huffman table class "
267                            + table.tableClass);
268                }
269                if (0 > table.destinationIdentifier
270                        || table.destinationIdentifier >= tables.length) {
271                    throw new ImageReadException(
272                            "Invalid huffman table identifier "
273                                    + table.destinationIdentifier);
274                }
275                tables[table.destinationIdentifier] = table;
276            }
277        }
278        return true;
279    }
280
281    private void rescaleMCU(final Block[] dataUnits, final int hSize, final int vSize, final Block[] ret) {
282        for (int i = 0; i < dataUnits.length; i++) {
283            final Block dataUnit = dataUnits[i];
284            if (dataUnit.width == hSize && dataUnit.height == vSize) {
285                System.arraycopy(dataUnit.samples, 0, ret[i].samples, 0, hSize
286                        * vSize);
287            } else {
288                final int hScale = hSize / dataUnit.width;
289                final int vScale = vSize / dataUnit.height;
290                if (hScale == 2 && vScale == 2) {
291                    int srcRowOffset = 0;
292                    int dstRowOffset = 0;
293                    for (int y = 0; y < dataUnit.height; y++) {
294                        for (int x = 0; x < hSize; x++) {
295                            final int sample = dataUnit.samples[srcRowOffset + (x >> 1)];
296                            ret[i].samples[dstRowOffset + x] = sample;
297                            ret[i].samples[dstRowOffset + hSize + x] = sample;
298                        }
299                        srcRowOffset += dataUnit.width;
300                        dstRowOffset += 2 * hSize;
301                    }
302                } else {
303                    // FIXME: optimize
304                    int dstRowOffset = 0;
305                    for (int y = 0; y < vSize; y++) {
306                        for (int x = 0; x < hSize; x++) {
307                            ret[i].samples[dstRowOffset + x] = dataUnit.samples[(y / vScale)
308                                    * dataUnit.width + (x / hScale)];
309                        }
310                        dstRowOffset += hSize;
311                    }
312                }
313            }
314        }
315    }
316
317    private Block[] allocateMCUMemory() throws ImageReadException {
318        final Block[] mcu = new Block[sosSegment.numberOfComponents];
319        for (int i = 0; i < sosSegment.numberOfComponents; i++) {
320            final SosSegment.Component scanComponent = sosSegment.getComponents(i);
321            SofnSegment.Component frameComponent = null;
322            for (int j = 0; j < sofnSegment.numberOfComponents; j++) {
323                if (sofnSegment.getComponents(j).componentIdentifier == scanComponent.scanComponentSelector) {
324                    frameComponent = sofnSegment.getComponents(j);
325                    break;
326                }
327            }
328            if (frameComponent == null) {
329                throw new ImageReadException("Invalid component");
330            }
331            final Block fullBlock = new Block(
332                    8 * frameComponent.horizontalSamplingFactor,
333                    8 * frameComponent.verticalSamplingFactor);
334            mcu[i] = fullBlock;
335        }
336        return mcu;
337    }
338
339    private void readMCU(final JpegInputStream is, final int[] preds, final Block[] mcu)
340            throws IOException, ImageReadException {
341        for (int i = 0; i < sosSegment.numberOfComponents; i++) {
342            final SosSegment.Component scanComponent = sosSegment.getComponents(i);
343            SofnSegment.Component frameComponent = null;
344            for (int j = 0; j < sofnSegment.numberOfComponents; j++) {
345                if (sofnSegment.getComponents(j).componentIdentifier == scanComponent.scanComponentSelector) {
346                    frameComponent = sofnSegment.getComponents(j);
347                    break;
348                }
349            }
350            if (frameComponent == null) {
351                throw new ImageReadException("Invalid component");
352            }
353            final Block fullBlock = mcu[i];
354            for (int y = 0; y < frameComponent.verticalSamplingFactor; y++) {
355                for (int x = 0; x < frameComponent.horizontalSamplingFactor; x++) {
356                    Arrays.fill(zz, 0);
357                    // page 104 of T.81
358                    final int t = decode(
359                            is,
360                            huffmanDCTables[scanComponent.dcCodingTableSelector]);
361                    int diff = receive(t, is);
362                    diff = extend(diff, t);
363                    zz[0] = preds[i] + diff;
364                    preds[i] = zz[0];
365
366                    // "Decode_AC_coefficients", figure F.13, page 106 of T.81
367                    int k = 1;
368                    while (true) {
369                        final int rs = decode(
370                                is,
371                                huffmanACTables[scanComponent.acCodingTableSelector]);
372                        final int ssss = rs & 0xf;
373                        final int rrrr = rs >> 4;
374                        final int r = rrrr;
375
376                        if (ssss == 0) {
377                            if (r == 15) {
378                                k += 16;
379                            } else {
380                                break;
381                            }
382                        } else {
383                            k += r;
384
385                            // "Decode_ZZ(k)", figure F.14, page 107 of T.81
386                            zz[k] = receive(ssss, is);
387                            zz[k] = extend(zz[k], ssss);
388
389                            if (k == 63) {
390                                break;
391                            } else {
392                                k++;
393                            }
394                        }
395                    }
396
397                    final int shift = (1 << (sofnSegment.precision - 1));
398                    final int max = (1 << sofnSegment.precision) - 1;
399
400                    final float[] scaledQuantizationTable = scaledQuantizationTables[frameComponent.quantTabDestSelector];
401                    ZigZag.zigZagToBlock(zz, blockInt);
402                    for (int j = 0; j < 64; j++) {
403                        block[j] = blockInt[j] * scaledQuantizationTable[j];
404                    }
405                    Dct.inverseDCT8x8(block);
406
407                    int dstRowOffset = 8 * y * 8
408                            * frameComponent.horizontalSamplingFactor + 8 * x;
409                    int srcNext = 0;
410                    for (int yy = 0; yy < 8; yy++) {
411                        for (int xx = 0; xx < 8; xx++) {
412                            float sample = block[srcNext++];
413                            sample += shift;
414                            int result;
415                            if (sample < 0) {
416                                result = 0;
417                            } else if (sample > max) {
418                                result = max;
419                            } else {
420                                result = fastRound(sample);
421                            }
422                            fullBlock.samples[dstRowOffset + xx] = result;
423                        }
424                        dstRowOffset += 8 * frameComponent.horizontalSamplingFactor;
425                    }
426                }
427            }
428        }
429    }
430
431    /**
432     * Returns an array of JpegInputStream where each field contains the JpegInputStream
433     * for one interval.
434     * @param scanPayload array to read intervals from
435     * @return JpegInputStreams for all intervals, at least one stream is always provided
436     */
437    static JpegInputStream[] splitByRstMarkers(final int[] scanPayload) {
438        final List<Integer> intervalStarts = getIntervalStartPositions(scanPayload);
439        // get number of intervals in payload to init an array of appropriate length
440        final int intervalCount = intervalStarts.size();
441        JpegInputStream[] streams = new JpegInputStream[intervalCount];
442        for (int i = 0; i < intervalCount; i++) {
443            int from = intervalStarts.get(i);
444            int to;
445            if (i < intervalCount - 1) {
446                // because each restart marker needs two bytes the end of
447                // this interval is two bytes before the next interval starts
448                to = intervalStarts.get(i + 1) - 2;
449            } else { // the last interval ends with the array
450                to = scanPayload.length;
451            }
452            int[] interval = Arrays.copyOfRange(scanPayload, from, to);
453            streams[i] = new JpegInputStream(interval);
454        }
455        return streams;
456    }
457
458    /**
459     * Returns the positions of where each interval in the provided array starts. The number
460     * of start positions is also the count of intervals while the number of restart markers
461     * found is equal to the number of start positions minus one (because restart markers
462     * are between intervals).
463     *
464     * @param scanPayload array to examine
465     * @return the start positions
466     */
467    static List<Integer> getIntervalStartPositions(final int[] scanPayload) {
468        final List<Integer> intervalStarts = new ArrayList<Integer>();
469        intervalStarts.add(0);
470        boolean foundFF = false;
471        boolean foundD0toD7 = false;
472        int pos = 0;
473        while (pos < scanPayload.length) {
474            if (foundFF) {
475                // found 0xFF D0 .. 0xFF D7 => RST marker
476                if (scanPayload[pos] >= (0xff & JpegConstants.RST0_MARKER) &&
477                    scanPayload[pos] <= (0xff & JpegConstants.RST7_MARKER)) {
478                    foundD0toD7 = true;
479                } else { // found 0xFF followed by something else => no RST marker
480                    foundFF = false;
481                }
482            }
483
484            if (scanPayload[pos] == 0xFF) {
485                foundFF = true;
486            }
487
488            // true if one of the RST markers was found
489            if (foundFF && foundD0toD7) {
490                // we need to add the position after the current position because
491                // we had already read 0xFF and are now at 0xDn
492                intervalStarts.add(pos + 1);
493                foundFF = foundD0toD7 = false;
494            }
495            pos++;
496        }
497        return intervalStarts;
498    }
499
500    private static int fastRound(final float x) {
501        return (int) (x + 0.5f);
502    }
503
504    private int extend(int v, final int t) {
505        // "EXTEND", section F.2.2.1, figure F.12, page 105 of T.81
506        int vt = (1 << (t - 1));
507        if (v < vt) {
508            vt = (-1 << t) + 1;
509            v += vt;
510        }
511        return v;
512    }
513
514    private int receive(final int ssss, final JpegInputStream is) throws IOException,
515            ImageReadException {
516        // "RECEIVE", section F.2.2.4, figure F.17, page 110 of T.81
517        int i = 0;
518        int v = 0;
519        while (i != ssss) {
520            i++;
521            v = (v << 1) + is.nextBit();
522        }
523        return v;
524    }
525
526    private int decode(final JpegInputStream is, final DhtSegment.HuffmanTable huffmanTable)
527            throws IOException, ImageReadException {
528        // "DECODE", section F.2.2.3, figure F.16, page 109 of T.81
529        int i = 1;
530        int code = is.nextBit();
531        while (code > huffmanTable.getMaxCode(i)) {
532            i++;
533            code = (code << 1) | is.nextBit();
534        }
535        int j = huffmanTable.getValPtr(i);
536        j += code - huffmanTable.getMinCode(i);
537        return huffmanTable.getHuffVal(j);
538    }
539
540    public BufferedImage decode(final ByteSource byteSource) throws IOException,
541            ImageReadException {
542        final JpegUtils jpegUtils = new JpegUtils();
543        jpegUtils.traverseJFIF(byteSource, this);
544        if (imageReadException != null) {
545            throw imageReadException;
546        }
547        if (ioException != null) {
548            throw ioException;
549        }
550        return image;
551    }
552}