monotone

monotone Mtn Source Tree

Root/cryptopp/zinflate.cpp

1// zinflate.cpp - written and placed in the public domain by Wei Dai
2
3// This is a complete reimplementation of the DEFLATE decompression algorithm.
4// It should not be affected by any security vulnerabilities in the zlib
5// compression library. In particular it is not affected by the double free bug
6// (http://www.kb.cert.org/vuls/id/368819).
7
8#include "pch.h"
9#include "zinflate.h"
10
11NAMESPACE_BEGIN(CryptoPP)
12
13struct CodeLessThan
14{
15inline bool operator()(const CryptoPP::HuffmanDecoder::code_t lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs)
16{return lhs < rhs.code;}
17};
18
19inline bool LowFirstBitReader::FillBuffer(unsigned int length)
20{
21while (m_bitsBuffered < length)
22{
23byte b;
24if (!m_store.Get(b))
25return false;
26m_buffer |= (unsigned long)b << m_bitsBuffered;
27m_bitsBuffered += 8;
28}
29assert(m_bitsBuffered <= sizeof(unsigned long)*8);
30return true;
31}
32
33inline unsigned long LowFirstBitReader::PeekBits(unsigned int length)
34{
35bool result = FillBuffer(length);
36assert(result);
37return m_buffer & (((unsigned long)1 << length) - 1);
38}
39
40inline void LowFirstBitReader::SkipBits(unsigned int length)
41{
42assert(m_bitsBuffered >= length);
43m_buffer >>= length;
44m_bitsBuffered -= length;
45}
46
47inline unsigned long LowFirstBitReader::GetBits(unsigned int length)
48{
49unsigned long result = PeekBits(length);
50SkipBits(length);
51return result;
52}
53
54inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits)
55{
56return code << (MAX_CODE_BITS - codeBits);
57}
58
59void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes)
60{
61// the Huffman codes are represented in 3 ways in this code:
62//
63// 1. most significant code bit (i.e. top of code tree) in the least significant bit position
64// 2. most significant code bit (i.e. top of code tree) in the most significant bit position
65// 3. most significant code bit (i.e. top of code tree) in n-th least significant bit position,
66// where n is the maximum code length for this code tree
67//
68// (1) is the way the codes come in from the deflate stream
69// (2) is used to sort codes so they can be binary searched
70// (3) is used in this function to compute codes from code lengths
71//
72// a code in representation (2) is called "normalized" here
73// The BitReverse() function is used to convert between (1) and (2)
74// The NormalizeCode() function is used to convert from (3) to (2)
75
76if (nCodes == 0)
77throw Err("null code");
78
79m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes);
80
81if (m_maxCodeBits > MAX_CODE_BITS)
82throw Err("code length exceeds maximum");
83
84if (m_maxCodeBits == 0)
85throw Err("null code");
86
87// count number of codes of each length
88SecBlockWithHint<unsigned int, 15+1> blCount(m_maxCodeBits+1);
89std::fill(blCount.begin(), blCount.end(), 0);
90unsigned int i;
91for (i=0; i<nCodes; i++)
92blCount[codeBits[i]]++;
93
94// compute the starting code of each length
95code_t code = 0;
96SecBlockWithHint<code_t, 15+1> nextCode(m_maxCodeBits+1);
97nextCode[1] = 0;
98for (i=2; i<=m_maxCodeBits; i++)
99{
100// compute this while checking for overflow: code = (code + blCount[i-1]) << 1
101if (code > code + blCount[i-1])
102throw Err("codes oversubscribed");
103code += blCount[i-1];
104if (code > (code << 1))
105throw Err("codes oversubscribed");
106code <<= 1;
107nextCode[i] = code;
108}
109
110if (code > (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
111throw Err("codes oversubscribed");
112else if (m_maxCodeBits != 1 && code < (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
113throw Err("codes incomplete");
114
115// compute a vector of <code, length, value> triples sorted by code
116m_codeToValue.resize(nCodes - blCount[0]);
117unsigned int j=0;
118for (i=0; i<nCodes; i++)
119{
120unsigned int len = codeBits[i];
121if (len != 0)
122{
123code = NormalizeCode(nextCode[len]++, len);
124m_codeToValue[j].code = code;
125m_codeToValue[j].len = len;
126m_codeToValue[j].value = i;
127j++;
128}
129}
130std::sort(m_codeToValue.begin(), m_codeToValue.end());
131
132// initialize the decoding cache
133m_cacheBits = STDMIN(9U, m_maxCodeBits);
134m_cacheMask = (1 << m_cacheBits) - 1;
135m_normalizedCacheMask = NormalizeCode(m_cacheMask, m_cacheBits);
136assert(m_normalizedCacheMask == BitReverse(m_cacheMask));
137
138if (m_cache.size() != 1 << m_cacheBits)
139m_cache.resize(1 << m_cacheBits);
140
141for (i=0; i<m_cache.size(); i++)
142m_cache[i].type = 0;
143}
144
145void HuffmanDecoder::FillCacheEntry(LookupEntry &entry, code_t normalizedCode) const
146{
147normalizedCode &= m_normalizedCacheMask;
148const CodeInfo &codeInfo = *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode, CodeLessThan())-1);
149if (codeInfo.len <= m_cacheBits)
150{
151entry.type = 1;
152entry.value = codeInfo.value;
153entry.len = codeInfo.len;
154}
155else
156{
157entry.begin = &codeInfo;
158const CodeInfo *last = & *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode + ~m_normalizedCacheMask, CodeLessThan())-1);
159if (codeInfo.len == last->len)
160{
161entry.type = 2;
162entry.len = codeInfo.len;
163}
164else
165{
166entry.type = 3;
167entry.end = last+1;
168}
169}
170}
171
172inline unsigned int HuffmanDecoder::Decode(code_t code, /* out */ value_t &value) const
173{
174assert(m_codeToValue.size() > 0);
175LookupEntry &entry = m_cache[code & m_cacheMask];
176
177code_t normalizedCode;
178if (entry.type != 1)
179normalizedCode = BitReverse(code);
180
181if (entry.type == 0)
182FillCacheEntry(entry, normalizedCode);
183
184if (entry.type == 1)
185{
186value = entry.value;
187return entry.len;
188}
189else
190{
191const CodeInfo &codeInfo = (entry.type == 2)
192? entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))]
193: *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan())-1);
194value = codeInfo.value;
195return codeInfo.len;
196}
197}
198
199bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const
200{
201reader.FillBuffer(m_maxCodeBits);
202unsigned int codeBits = Decode(reader.PeekBuffer(), value);
203if (codeBits > reader.BitsBuffered())
204return false;
205reader.SkipBits(codeBits);
206return true;
207}
208
209// *************************************************************
210
211Inflator::Inflator(BufferedTransformation *attachment, bool repeat, int propagation)
212: AutoSignaling<Filter>(attachment, propagation)
213, m_state(PRE_STREAM), m_repeat(repeat)
214, m_decodersInitializedWithFixedCodes(false), m_reader(m_inQueue)
215{
216}
217
218void Inflator::IsolatedInitialize(const NameValuePairs &parameters)
219{
220m_state = PRE_STREAM;
221parameters.GetValue("Repeat", m_repeat);
222m_inQueue.Clear();
223m_reader.SkipBits(m_reader.BitsBuffered());
224}
225
226inline void Inflator::OutputByte(byte b)
227{
228m_window[m_current++] = b;
229if (m_current == m_window.size())
230{
231ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush);
232m_lastFlush = 0;
233m_current = 0;
234}
235if (m_maxDistance < m_window.size())
236m_maxDistance++;
237}
238
239void Inflator::OutputString(const byte *string, unsigned int length)
240{
241while (length--)
242OutputByte(*string++);
243}
244
245void Inflator::OutputPast(unsigned int length, unsigned int distance)
246{
247if (distance > m_maxDistance)
248throw BadBlockErr();
249unsigned int start;
250if (m_current > distance)
251start = m_current - distance;
252else
253start = m_current + m_window.size() - distance;
254
255if (start + length > m_window.size())
256{
257for (; start < m_window.size(); start++, length--)
258OutputByte(m_window[start]);
259start = 0;
260}
261
262if (start + length > m_current || m_current + length >= m_window.size())
263{
264while (length--)
265OutputByte(m_window[start++]);
266}
267else
268{
269memcpy(m_window + m_current, m_window + start, length);
270m_current += length;
271m_maxDistance = STDMIN((unsigned int)m_window.size(), m_maxDistance + length);
272}
273}
274
275unsigned int Inflator::Put2(const byte *inString, unsigned int length, int messageEnd, bool blocking)
276{
277if (!blocking)
278throw BlockingInputOnly("Inflator");
279
280LazyPutter lp(m_inQueue, inString, length);
281ProcessInput(messageEnd != 0);
282
283if (messageEnd)
284if (!(m_state == PRE_STREAM || m_state == AFTER_END))
285throw UnexpectedEndErr();
286
287Output(0, NULL, 0, messageEnd, blocking);
288return 0;
289}
290
291bool Inflator::IsolatedFlush(bool hardFlush, bool blocking)
292{
293if (!blocking)
294throw BlockingInputOnly("Inflator");
295
296if (hardFlush)
297ProcessInput(true);
298FlushOutput();
299
300return false;
301}
302
303void Inflator::ProcessInput(bool flush)
304{
305while (true)
306{
307if (m_inQueue.IsEmpty())
308return;
309
310switch (m_state)
311{
312case PRE_STREAM:
313if (!flush && m_inQueue.CurrentSize() < MaxPrestreamHeaderSize())
314return;
315ProcessPrestreamHeader();
316m_state = WAIT_HEADER;
317m_maxDistance = 0;
318m_current = 0;
319m_lastFlush = 0;
320m_window.New(1 << GetLog2WindowSize());
321break;
322case WAIT_HEADER:
323{
324// maximum number of bytes before actual compressed data starts
325const unsigned int MAX_HEADER_SIZE = BitsToBytes(3+5+5+4+19*7+286*15+19*15);
326if (m_inQueue.CurrentSize() < (flush ? 1 : MAX_HEADER_SIZE))
327return;
328DecodeHeader();
329break;
330}
331case DECODING_BODY:
332if (!DecodeBody())
333return;
334break;
335case POST_STREAM:
336if (!flush && m_inQueue.CurrentSize() < MaxPoststreamTailSize())
337return;
338ProcessPoststreamTail();
339m_state = m_repeat ? PRE_STREAM : AFTER_END;
340Output(0, NULL, 0, GetAutoSignalPropagation(), true);// TODO: non-blocking
341break;
342case AFTER_END:
343m_inQueue.TransferTo(*AttachedTransformation());
344return;
345}
346}
347}
348
349void Inflator::DecodeHeader()
350{
351if (!m_reader.FillBuffer(3))
352throw UnexpectedEndErr();
353m_eof = m_reader.GetBits(1) != 0;
354m_blockType = (byte)m_reader.GetBits(2);
355switch (m_blockType)
356{
357case 0:// stored
358{
359m_reader.SkipBits(m_reader.BitsBuffered() % 8);
360if (!m_reader.FillBuffer(32))
361throw UnexpectedEndErr();
362m_storedLen = (word16)m_reader.GetBits(16);
363word16 nlen = (word16)m_reader.GetBits(16);
364if (nlen != (word16)~m_storedLen)
365throw BadBlockErr();
366break;
367}
368case 1:// fixed codes
369if (!m_decodersInitializedWithFixedCodes)
370{
371unsigned int codeLengths[288];
372std::fill(codeLengths + 0, codeLengths + 144, 8);
373std::fill(codeLengths + 144, codeLengths + 256, 9);
374std::fill(codeLengths + 256, codeLengths + 280, 7);
375std::fill(codeLengths + 280, codeLengths + 288, 8);
376m_literalDecoder.Initialize(codeLengths, 288);
377std::fill(codeLengths + 0, codeLengths + 32, 5);
378m_distanceDecoder.Initialize(codeLengths, 32);
379m_decodersInitializedWithFixedCodes = true;
380}
381m_nextDecode = LITERAL;
382break;
383case 2:// dynamic codes
384{
385m_decodersInitializedWithFixedCodes = false;
386if (!m_reader.FillBuffer(5+5+4))
387throw UnexpectedEndErr();
388unsigned int hlit = m_reader.GetBits(5);
389unsigned int hdist = m_reader.GetBits(5);
390unsigned int hclen = m_reader.GetBits(4);
391
392FixedSizeSecBlock<unsigned int, 286+32> codeLengths;
393unsigned int i;
394static const unsigned int border[] = { // Order of the bit length code lengths
39516, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
396std::fill(codeLengths.begin(), codeLengths+19, 0);
397for (i=0; i<hclen+4; i++)
398codeLengths[border[i]] = m_reader.GetBits(3);
399
400try
401{
402HuffmanDecoder codeLengthDecoder(codeLengths, 19);
403for (i = 0; i < hlit+257+hdist+1; )
404{
405unsigned int k, count, repeater;
406bool result = codeLengthDecoder.Decode(m_reader, k);
407if (!result)
408throw UnexpectedEndErr();
409if (k <= 15)
410{
411count = 1;
412repeater = k;
413}
414else switch (k)
415{
416case 16:
417if (!m_reader.FillBuffer(2))
418throw UnexpectedEndErr();
419count = 3 + m_reader.GetBits(2);
420if (i == 0)
421throw BadBlockErr();
422repeater = codeLengths[i-1];
423break;
424case 17:
425if (!m_reader.FillBuffer(3))
426throw UnexpectedEndErr();
427count = 3 + m_reader.GetBits(3);
428repeater = 0;
429break;
430case 18:
431if (!m_reader.FillBuffer(7))
432throw UnexpectedEndErr();
433count = 11 + m_reader.GetBits(7);
434repeater = 0;
435break;
436}
437if (i + count > hlit+257+hdist+1)
438throw BadBlockErr();
439std::fill(codeLengths + i, codeLengths + i + count, repeater);
440i += count;
441}
442m_literalDecoder.Initialize(codeLengths, hlit+257);
443if (hdist == 0 && codeLengths[hlit+257] == 0)
444{
445if (hlit != 0)// a single zero distance code length means all literals
446throw BadBlockErr();
447}
448else
449m_distanceDecoder.Initialize(codeLengths+hlit+257, hdist+1);
450m_nextDecode = LITERAL;
451}
452catch (HuffmanDecoder::Err &)
453{
454throw BadBlockErr();
455}
456break;
457}
458default:
459throw BadBlockErr();// reserved block type
460}
461m_state = DECODING_BODY;
462}
463
464bool Inflator::DecodeBody()
465{
466bool blockEnd = false;
467switch (m_blockType)
468{
469case 0:// stored
470assert(m_reader.BitsBuffered() == 0);
471while (!m_inQueue.IsEmpty() && !blockEnd)
472{
473unsigned int size;
474const byte *block = m_inQueue.Spy(size);
475size = STDMIN(size, (unsigned int)m_storedLen);
476OutputString(block, size);
477m_inQueue.Skip(size);
478m_storedLen -= size;
479if (m_storedLen == 0)
480blockEnd = true;
481}
482break;
483case 1:// fixed codes
484case 2:// dynamic codes
485static const unsigned int lengthStarts[] = {
4863, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
48735, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258};
488static const unsigned int lengthExtraBits[] = {
4890, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
4903, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0};
491static const unsigned int distanceStarts[] = {
4921, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
493257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145,
4948193, 12289, 16385, 24577};
495static const unsigned int distanceExtraBits[] = {
4960, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
4977, 7, 8, 8, 9, 9, 10, 10, 11, 11,
49812, 12, 13, 13};
499
500switch (m_nextDecode)
501{
502while (true)
503{
504case LITERAL:
505if (!m_literalDecoder.Decode(m_reader, m_literal))
506{
507m_nextDecode = LITERAL;
508break;
509}
510if (m_literal < 256)
511OutputByte((byte)m_literal);
512else if (m_literal == 256)// end of block
513{
514blockEnd = true;
515break;
516}
517else
518{
519if (m_literal > 285)
520throw BadBlockErr();
521unsigned int bits;
522case LENGTH_BITS:
523bits = lengthExtraBits[m_literal-257];
524if (!m_reader.FillBuffer(bits))
525{
526m_nextDecode = LENGTH_BITS;
527break;
528}
529m_literal = m_reader.GetBits(bits) + lengthStarts[m_literal-257];
530case DISTANCE:
531if (!m_distanceDecoder.Decode(m_reader, m_distance))
532{
533m_nextDecode = DISTANCE;
534break;
535}
536case DISTANCE_BITS:
537bits = distanceExtraBits[m_distance];
538if (!m_reader.FillBuffer(bits))
539{
540m_nextDecode = DISTANCE_BITS;
541break;
542}
543m_distance = m_reader.GetBits(bits) + distanceStarts[m_distance];
544OutputPast(m_literal, m_distance);
545}
546}
547}
548}
549if (blockEnd)
550{
551if (m_eof)
552{
553FlushOutput();
554m_reader.SkipBits(m_reader.BitsBuffered()%8);
555if (m_reader.BitsBuffered())
556{
557// undo too much lookahead
558SecBlockWithHint<byte, 4> buffer(m_reader.BitsBuffered() / 8);
559for (unsigned int i=0; i<buffer.size(); i++)
560buffer[i] = (byte)m_reader.GetBits(8);
561m_inQueue.Unget(buffer, buffer.size());
562}
563m_state = POST_STREAM;
564}
565else
566m_state = WAIT_HEADER;
567}
568return blockEnd;
569}
570
571void Inflator::FlushOutput()
572{
573if (m_state != PRE_STREAM)
574{
575assert(m_current >= m_lastFlush);
576ProcessDecompressedData(m_window + m_lastFlush, m_current - m_lastFlush);
577m_lastFlush = m_current;
578}
579}
580
581NAMESPACE_END

Archive Download this file

Branches

Tags

Quick Links:     www.monotone.ca    -     Downloads    -     Documentation    -     Wiki    -     Code Forge    -     Build Status