Crypto++
zinflate.cpp
1 // zinflate.cpp - written and placed in the public domain by Wei Dai
2 
3 // This is a complete reimplementation of the DEFLATE decompression algorithm.
4 // It should not be affected by any security vulnerabilities in the zlib
5 // compression library. In particular it is not affected by the double free bug
6 // (http://www.kb.cert.org/vuls/id/368819).
7 
8 #include "pch.h"
9 #include "zinflate.h"
10 
11 NAMESPACE_BEGIN(CryptoPP)
12 
14 {
15  inline bool operator()(CryptoPP::HuffmanDecoder::code_t lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs)
16  {return lhs < rhs.code;}
17  // needed for MSVC .NET 2005
18  inline bool operator()(const CryptoPP::HuffmanDecoder::CodeInfo &lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs)
19  {return lhs.code < rhs.code;}
20 };
21 
22 inline bool LowFirstBitReader::FillBuffer(unsigned int length)
23 {
24  while (m_bitsBuffered < length)
25  {
26  byte b;
27  if (!m_store.Get(b))
28  return false;
29  m_buffer |= (unsigned long)b << m_bitsBuffered;
30  m_bitsBuffered += 8;
31  }
32  assert(m_bitsBuffered <= sizeof(unsigned long)*8);
33  return true;
34 }
35 
36 inline unsigned long LowFirstBitReader::PeekBits(unsigned int length)
37 {
38  bool result = FillBuffer(length);
39  assert(result);
40  return m_buffer & (((unsigned long)1 << length) - 1);
41 }
42 
43 inline void LowFirstBitReader::SkipBits(unsigned int length)
44 {
45  assert(m_bitsBuffered >= length);
46  m_buffer >>= length;
47  m_bitsBuffered -= length;
48 }
49 
50 inline unsigned long LowFirstBitReader::GetBits(unsigned int length)
51 {
52  unsigned long result = PeekBits(length);
53  SkipBits(length);
54  return result;
55 }
56 
57 inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits)
58 {
59  return code << (MAX_CODE_BITS - codeBits);
60 }
61 
62 void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes)
63 {
64  // the Huffman codes are represented in 3 ways in this code:
65  //
66  // 1. most significant code bit (i.e. top of code tree) in the least significant bit position
67  // 2. most significant code bit (i.e. top of code tree) in the most significant bit position
68  // 3. most significant code bit (i.e. top of code tree) in n-th least significant bit position,
69  // where n is the maximum code length for this code tree
70  //
71  // (1) is the way the codes come in from the deflate stream
72  // (2) is used to sort codes so they can be binary searched
73  // (3) is used in this function to compute codes from code lengths
74  //
75  // a code in representation (2) is called "normalized" here
76  // The BitReverse() function is used to convert between (1) and (2)
77  // The NormalizeCode() function is used to convert from (3) to (2)
78 
79  if (nCodes == 0)
80  throw Err("null code");
81 
82  m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes);
83 
84  if (m_maxCodeBits > MAX_CODE_BITS)
85  throw Err("code length exceeds maximum");
86 
87  if (m_maxCodeBits == 0)
88  throw Err("null code");
89 
90  // count number of codes of each length
91  SecBlockWithHint<unsigned int, 15+1> blCount(m_maxCodeBits+1);
92  std::fill(blCount.begin(), blCount.end(), 0);
93  unsigned int i;
94  for (i=0; i<nCodes; i++)
95  blCount[codeBits[i]]++;
96 
97  // compute the starting code of each length
98  code_t code = 0;
99  SecBlockWithHint<code_t, 15+1> nextCode(m_maxCodeBits+1);
100  nextCode[1] = 0;
101  for (i=2; i<=m_maxCodeBits; i++)
102  {
103  // compute this while checking for overflow: code = (code + blCount[i-1]) << 1
104  if (code > code + blCount[i-1])
105  throw Err("codes oversubscribed");
106  code += blCount[i-1];
107  if (code > (code << 1))
108  throw Err("codes oversubscribed");
109  code <<= 1;
110  nextCode[i] = code;
111  }
112 
113  if (code > (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
114  throw Err("codes oversubscribed");
115  else if (m_maxCodeBits != 1 && code < (1 << m_maxCodeBits) - blCount[m_maxCodeBits])
116  throw Err("codes incomplete");
117 
118  // compute a vector of <code, length, value> triples sorted by code
119  m_codeToValue.resize(nCodes - blCount[0]);
120  unsigned int j=0;
121  for (i=0; i<nCodes; i++)
122  {
123  unsigned int len = codeBits[i];
124  if (len != 0)
125  {
126  code = NormalizeCode(nextCode[len]++, len);
127  m_codeToValue[j].code = code;
128  m_codeToValue[j].len = len;
129  m_codeToValue[j].value = i;
130  j++;
131  }
132  }
133  std::sort(m_codeToValue.begin(), m_codeToValue.end());
134 
135  // initialize the decoding cache
136  m_cacheBits = STDMIN(9U, m_maxCodeBits);
137  m_cacheMask = (1 << m_cacheBits) - 1;
138  m_normalizedCacheMask = NormalizeCode(m_cacheMask, m_cacheBits);
139  assert(m_normalizedCacheMask == BitReverse(m_cacheMask));
140 
141  if (m_cache.size() != size_t(1) << m_cacheBits)
142  m_cache.resize(1 << m_cacheBits);
143 
144  for (i=0; i<m_cache.size(); i++)
145  m_cache[i].type = 0;
146 }
147 
148 void HuffmanDecoder::FillCacheEntry(LookupEntry &entry, code_t normalizedCode) const
149 {
150  normalizedCode &= m_normalizedCacheMask;
151  const CodeInfo &codeInfo = *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode, CodeLessThan())-1);
152  if (codeInfo.len <= m_cacheBits)
153  {
154  entry.type = 1;
155  entry.value = codeInfo.value;
156  entry.len = codeInfo.len;
157  }
158  else
159  {
160  entry.begin = &codeInfo;
161  const CodeInfo *last = & *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode + ~m_normalizedCacheMask, CodeLessThan())-1);
162  if (codeInfo.len == last->len)
163  {
164  entry.type = 2;
165  entry.len = codeInfo.len;
166  }
167  else
168  {
169  entry.type = 3;
170  entry.end = last+1;
171  }
172  }
173 }
174 
175 inline unsigned int HuffmanDecoder::Decode(code_t code, /* out */ value_t &value) const
176 {
177  assert(m_codeToValue.size() > 0);
178  LookupEntry &entry = m_cache[code & m_cacheMask];
179 
180  code_t normalizedCode;
181  if (entry.type != 1)
182  normalizedCode = BitReverse(code);
183 
184  if (entry.type == 0)
185  FillCacheEntry(entry, normalizedCode);
186 
187  if (entry.type == 1)
188  {
189  value = entry.value;
190  return entry.len;
191  }
192  else
193  {
194  const CodeInfo &codeInfo = (entry.type == 2)
195  ? entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))]
196  : *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan())-1);
197  value = codeInfo.value;
198  return codeInfo.len;
199  }
200 }
201 
202 bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const
203 {
204  reader.FillBuffer(m_maxCodeBits);
205  unsigned int codeBits = Decode(reader.PeekBuffer(), value);
206  if (codeBits > reader.BitsBuffered())
207  return false;
208  reader.SkipBits(codeBits);
209  return true;
210 }
211 
212 // *************************************************************
213 
214 Inflator::Inflator(BufferedTransformation *attachment, bool repeat, int propagation)
215  : AutoSignaling<Filter>(propagation)
216  , m_state(PRE_STREAM), m_repeat(repeat), m_reader(m_inQueue)
217 {
218  Detach(attachment);
219 }
220 
221 void Inflator::IsolatedInitialize(const NameValuePairs &parameters)
222 {
223  m_state = PRE_STREAM;
224  parameters.GetValue("Repeat", m_repeat);
225  m_inQueue.Clear();
226  m_reader.SkipBits(m_reader.BitsBuffered());
227 }
228 
229 void Inflator::OutputByte(byte b)
230 {
231  m_window[m_current++] = b;
232  if (m_current == m_window.size())
233  {
234  ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush);
235  m_lastFlush = 0;
236  m_current = 0;
237  m_wrappedAround = true;
238  }
239 }
240 
241 void Inflator::OutputString(const byte *string, size_t length)
242 {
243  while (length)
244  {
245  size_t len = UnsignedMin(length, m_window.size() - m_current);
246  memcpy(m_window + m_current, string, len);
247  m_current += len;
248  if (m_current == m_window.size())
249  {
250  ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush);
251  m_lastFlush = 0;
252  m_current = 0;
253  m_wrappedAround = true;
254  }
255  string += len;
256  length -= len;
257  }
258 }
259 
260 void Inflator::OutputPast(unsigned int length, unsigned int distance)
261 {
262  size_t start;
263  if (distance <= m_current)
264  start = m_current - distance;
265  else if (m_wrappedAround && distance <= m_window.size())
266  start = m_current + m_window.size() - distance;
267  else
268  throw BadBlockErr();
269 
270  if (start + length > m_window.size())
271  {
272  for (; start < m_window.size(); start++, length--)
273  OutputByte(m_window[start]);
274  start = 0;
275  }
276 
277  if (start + length > m_current || m_current + length >= m_window.size())
278  {
279  while (length--)
280  OutputByte(m_window[start++]);
281  }
282  else
283  {
284  memcpy(m_window + m_current, m_window + start, length);
285  m_current += length;
286  }
287 }
288 
289 size_t Inflator::Put2(const byte *inString, size_t length, int messageEnd, bool blocking)
290 {
291  if (!blocking)
292  throw BlockingInputOnly("Inflator");
293 
294  LazyPutter lp(m_inQueue, inString, length);
295  ProcessInput(messageEnd != 0);
296 
297  if (messageEnd)
298  if (!(m_state == PRE_STREAM || m_state == AFTER_END))
299  throw UnexpectedEndErr();
300 
301  Output(0, NULL, 0, messageEnd, blocking);
302  return 0;
303 }
304 
305 bool Inflator::IsolatedFlush(bool hardFlush, bool blocking)
306 {
307  if (!blocking)
308  throw BlockingInputOnly("Inflator");
309 
310  if (hardFlush)
311  ProcessInput(true);
312  FlushOutput();
313 
314  return false;
315 }
316 
317 void Inflator::ProcessInput(bool flush)
318 {
319  while (true)
320  {
321  switch (m_state)
322  {
323  case PRE_STREAM:
324  if (!flush && m_inQueue.CurrentSize() < MaxPrestreamHeaderSize())
325  return;
326  ProcessPrestreamHeader();
327  m_state = WAIT_HEADER;
328  m_wrappedAround = false;
329  m_current = 0;
330  m_lastFlush = 0;
331  m_window.New(1 << GetLog2WindowSize());
332  break;
333  case WAIT_HEADER:
334  {
335  // maximum number of bytes before actual compressed data starts
336  const size_t MAX_HEADER_SIZE = BitsToBytes(3+5+5+4+19*7+286*15+19*15);
337  if (m_inQueue.CurrentSize() < (flush ? 1 : MAX_HEADER_SIZE))
338  return;
339  DecodeHeader();
340  break;
341  }
342  case DECODING_BODY:
343  if (!DecodeBody())
344  return;
345  break;
346  case POST_STREAM:
347  if (!flush && m_inQueue.CurrentSize() < MaxPoststreamTailSize())
348  return;
349  ProcessPoststreamTail();
350  m_state = m_repeat ? PRE_STREAM : AFTER_END;
351  Output(0, NULL, 0, GetAutoSignalPropagation(), true); // TODO: non-blocking
352  if (m_inQueue.IsEmpty())
353  return;
354  break;
355  case AFTER_END:
356  m_inQueue.TransferTo(*AttachedTransformation());
357  return;
358  }
359  }
360 }
361 
362 void Inflator::DecodeHeader()
363 {
364  if (!m_reader.FillBuffer(3))
365  throw UnexpectedEndErr();
366  m_eof = m_reader.GetBits(1) != 0;
367  m_blockType = (byte)m_reader.GetBits(2);
368  switch (m_blockType)
369  {
370  case 0: // stored
371  {
372  m_reader.SkipBits(m_reader.BitsBuffered() % 8);
373  if (!m_reader.FillBuffer(32))
374  throw UnexpectedEndErr();
375  m_storedLen = (word16)m_reader.GetBits(16);
376  word16 nlen = (word16)m_reader.GetBits(16);
377  if (nlen != (word16)~m_storedLen)
378  throw BadBlockErr();
379  break;
380  }
381  case 1: // fixed codes
382  m_nextDecode = LITERAL;
383  break;
384  case 2: // dynamic codes
385  {
386  if (!m_reader.FillBuffer(5+5+4))
387  throw UnexpectedEndErr();
388  unsigned int hlit = m_reader.GetBits(5);
389  unsigned int hdist = m_reader.GetBits(5);
390  unsigned int hclen = m_reader.GetBits(4);
391 
393  unsigned int i;
394  static const unsigned int border[] = { // Order of the bit length code lengths
395  16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15};
396  std::fill(codeLengths.begin(), codeLengths+19, 0);
397  for (i=0; i<hclen+4; i++)
398  codeLengths[border[i]] = m_reader.GetBits(3);
399 
400  try
401  {
402  HuffmanDecoder codeLengthDecoder(codeLengths, 19);
403  for (i = 0; i < hlit+257+hdist+1; )
404  {
405  unsigned int k, count, repeater;
406  bool result = codeLengthDecoder.Decode(m_reader, k);
407  if (!result)
408  throw UnexpectedEndErr();
409  if (k <= 15)
410  {
411  count = 1;
412  repeater = k;
413  }
414  else switch (k)
415  {
416  case 16:
417  if (!m_reader.FillBuffer(2))
418  throw UnexpectedEndErr();
419  count = 3 + m_reader.GetBits(2);
420  if (i == 0)
421  throw BadBlockErr();
422  repeater = codeLengths[i-1];
423  break;
424  case 17:
425  if (!m_reader.FillBuffer(3))
426  throw UnexpectedEndErr();
427  count = 3 + m_reader.GetBits(3);
428  repeater = 0;
429  break;
430  case 18:
431  if (!m_reader.FillBuffer(7))
432  throw UnexpectedEndErr();
433  count = 11 + m_reader.GetBits(7);
434  repeater = 0;
435  break;
436  }
437  if (i + count > hlit+257+hdist+1)
438  throw BadBlockErr();
439  std::fill(codeLengths + i, codeLengths + i + count, repeater);
440  i += count;
441  }
442  m_dynamicLiteralDecoder.Initialize(codeLengths, hlit+257);
443  if (hdist == 0 && codeLengths[hlit+257] == 0)
444  {
445  if (hlit != 0) // a single zero distance code length means all literals
446  throw BadBlockErr();
447  }
448  else
449  m_dynamicDistanceDecoder.Initialize(codeLengths+hlit+257, hdist+1);
450  m_nextDecode = LITERAL;
451  }
452  catch (HuffmanDecoder::Err &)
453  {
454  throw BadBlockErr();
455  }
456  break;
457  }
458  default:
459  throw BadBlockErr(); // reserved block type
460  }
461  m_state = DECODING_BODY;
462 }
463 
464 bool Inflator::DecodeBody()
465 {
466  bool blockEnd = false;
467  switch (m_blockType)
468  {
469  case 0: // stored
470  assert(m_reader.BitsBuffered() == 0);
471  while (!m_inQueue.IsEmpty() && !blockEnd)
472  {
473  size_t size;
474  const byte *block = m_inQueue.Spy(size);
475  size = UnsignedMin(m_storedLen, size);
476  OutputString(block, size);
477  m_inQueue.Skip(size);
478  m_storedLen -= (word16)size;
479  if (m_storedLen == 0)
480  blockEnd = true;
481  }
482  break;
483  case 1: // fixed codes
484  case 2: // dynamic codes
485  static const unsigned int lengthStarts[] = {
486  3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
487  35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258};
488  static const unsigned int lengthExtraBits[] = {
489  0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
490  3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0};
491  static const unsigned int distanceStarts[] = {
492  1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
493  257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145,
494  8193, 12289, 16385, 24577};
495  static const unsigned int distanceExtraBits[] = {
496  0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
497  7, 7, 8, 8, 9, 9, 10, 10, 11, 11,
498  12, 12, 13, 13};
499 
500  const HuffmanDecoder& literalDecoder = GetLiteralDecoder();
501  const HuffmanDecoder& distanceDecoder = GetDistanceDecoder();
502 
503  switch (m_nextDecode)
504  {
505  case LITERAL:
506  while (true)
507  {
508  if (!literalDecoder.Decode(m_reader, m_literal))
509  {
510  m_nextDecode = LITERAL;
511  break;
512  }
513  if (m_literal < 256)
514  OutputByte((byte)m_literal);
515  else if (m_literal == 256) // end of block
516  {
517  blockEnd = true;
518  break;
519  }
520  else
521  {
522  if (m_literal > 285)
523  throw BadBlockErr();
524  unsigned int bits;
525  case LENGTH_BITS:
526  bits = lengthExtraBits[m_literal-257];
527  if (!m_reader.FillBuffer(bits))
528  {
529  m_nextDecode = LENGTH_BITS;
530  break;
531  }
532  m_literal = m_reader.GetBits(bits) + lengthStarts[m_literal-257];
533  case DISTANCE:
534  if (!distanceDecoder.Decode(m_reader, m_distance))
535  {
536  m_nextDecode = DISTANCE;
537  break;
538  }
539  case DISTANCE_BITS:
540  bits = distanceExtraBits[m_distance];
541  if (!m_reader.FillBuffer(bits))
542  {
543  m_nextDecode = DISTANCE_BITS;
544  break;
545  }
546  m_distance = m_reader.GetBits(bits) + distanceStarts[m_distance];
547  OutputPast(m_literal, m_distance);
548  }
549  }
550  }
551  }
552  if (blockEnd)
553  {
554  if (m_eof)
555  {
556  FlushOutput();
557  m_reader.SkipBits(m_reader.BitsBuffered()%8);
558  if (m_reader.BitsBuffered())
559  {
560  // undo too much lookahead
561  SecBlockWithHint<byte, 4> buffer(m_reader.BitsBuffered() / 8);
562  for (unsigned int i=0; i<buffer.size(); i++)
563  buffer[i] = (byte)m_reader.GetBits(8);
564  m_inQueue.Unget(buffer, buffer.size());
565  }
566  m_state = POST_STREAM;
567  }
568  else
569  m_state = WAIT_HEADER;
570  }
571  return blockEnd;
572 }
573 
574 void Inflator::FlushOutput()
575 {
576  if (m_state != PRE_STREAM)
577  {
578  assert(m_current >= m_lastFlush);
579  ProcessDecompressedData(m_window + m_lastFlush, m_current - m_lastFlush);
580  m_lastFlush = m_current;
581  }
582 }
583 
585 {
586  HuffmanDecoder * operator()() const
587  {
588  unsigned int codeLengths[288];
589  std::fill(codeLengths + 0, codeLengths + 144, 8);
590  std::fill(codeLengths + 144, codeLengths + 256, 9);
591  std::fill(codeLengths + 256, codeLengths + 280, 7);
592  std::fill(codeLengths + 280, codeLengths + 288, 8);
593  std::auto_ptr<HuffmanDecoder> pDecoder(new HuffmanDecoder);
594  pDecoder->Initialize(codeLengths, 288);
595  return pDecoder.release();
596  }
597 };
598 
600 {
601  HuffmanDecoder * operator()() const
602  {
603  unsigned int codeLengths[32];
604  std::fill(codeLengths + 0, codeLengths + 32, 5);
605  std::auto_ptr<HuffmanDecoder> pDecoder(new HuffmanDecoder);
606  pDecoder->Initialize(codeLengths, 32);
607  return pDecoder.release();
608  }
609 };
610 
611 const HuffmanDecoder& Inflator::GetLiteralDecoder() const
612 {
613  return m_blockType == 1 ? Singleton<HuffmanDecoder, NewFixedLiteralDecoder>().Ref() : m_dynamicLiteralDecoder;
614 }
615 
616 const HuffmanDecoder& Inflator::GetDistanceDecoder() const
617 {
618  return m_blockType == 1 ? Singleton<HuffmanDecoder, NewFixedDistanceDecoder>().Ref() : m_dynamicDistanceDecoder;
619 }
620 
621 NAMESPACE_END
Inflator(BufferedTransformation *attachment=NULL, bool repeat=false, int autoSignalPropagation=-1)
Definition: zinflate.cpp:214
use this to make sure LazyPut is finalized in event of exception
Definition: queue.h:112
a SecBlock that preallocates size S statically, and uses the heap when this size is exceeded ...
Definition: secblock.h:435
void New(size_type newSize)
change size, without preserving contents
Definition: secblock.h:361
interface for buffered transformations
Definition: cryptlib.h:771
bool GetValue(const char *name, T &value) const
get a named value, returns true if the name exists
Definition: cryptlib.h:262
lword TransferTo(BufferedTransformation &target, lword transferMax=LWORD_MAX, const std::string &channel=DEFAULT_CHANNEL)
move transferMax bytes of the buffered output to target as input
Definition: cryptlib.h:900
void Detach(BufferedTransformation *newAttachment=NULL)
delete the current attachment chain and replace it with newAttachment
Definition: filters.cpp:40
BufferedTransformation * AttachedTransformation()
returns the object immediately attached to this object or NULL for no attachment
Definition: filters.cpp:26
virtual lword Skip(lword skipMax=LWORD_MAX)
discard skipMax bytes from the output buffer
Definition: cryptlib.cpp:443
Huffman Decoder.
Definition: zinflate.h:32
a SecBlock with fixed size, allocated statically
Definition: secblock.h:422
thrown by objects that have not implemented nonblocking input processing
Definition: cryptlib.h:821
provides an implementation of BufferedTransformation's attachment interface
Definition: filters.h:17
size_t Put2(const byte *inString, size_t length, int messageEnd, bool blocking)
input multiple bytes for blocking or non-blocking processing
Definition: zinflate.cpp:289
BufferedTransformation & Ref()
return a reference to this object, useful for passing a temporary object to a function that takes a non-const reference
Definition: cryptlib.h:780
virtual size_t Get(byte &outByte)
try to retrieve a single byte
Definition: cryptlib.cpp:405
interface for retrieving values given their names
Definition: cryptlib.h:225