
/**********************************************************//**
 **
 ** @file util/bitstream.h
 **
 ** Copyright (C) 2008  Xpace, LLC.  All rights reserved
 **
 **************************************************************/


#ifndef XPACE_BITSTREAM_H
#define XPACE_BITSTREAM_H

#include <algorithm>
#include <vector>
#include <cassert>

#include "base/types.h"
#include "base/exception.h"
#include "util/bitops.h"

namespace Xpace
{
  // ================================== IN-MEMORY BIT STREAM ====

  class XPACE_EXPORT MemBitStream
  {
  public :
    // with a fixed buffer; if buf = 0, buffer is empty
    MemBitStream
      (byte *buf,
       size_t bufLen);   // max 2^29

    // with a dynamically-resizable buffer; if buf = 0, allocate a buffer 
    MemBitStream
      (std::vector<byte>* buf = 0);

    ~MemBitStream
      ();

    void setBuf
      (byte *buf,
       size_t bufLen);       // max 2^29

    byte* getBuf
      ();

    size_t getByteLen
      ()
    const;
    size_t getBitLen
      ()
    const;

    bool atEnd
      ()
    const;

    size_t getBitPos
      ()
      const
    { 
      return bit_pos; 
    };

    size_t getBytePos
      ()
    const
    { 
      return (bit_pos + 7) >> 3; 
    };
   
    // bit offsets
    // return false on error
    bool seek    
      (size_t off);

    void clear
      ();

    bool addBit
      (bool bit);

    /// add bits to stream
    /// @param numBits number of bits to add
    /// @param bits add low numBits from this 
    /// @return true if successful, false if bits can't be written (buffer or file full)
    bool addBits
      (size_t numBits,       
       uint bits);

    /// add bits to stream
    /// @param numBits number of bits to add
    /// @param bits add low numBits from this 
    /// @return true if successful, false if bits can't be written (buffer or file full)
    bool addBits
      (size_t numBits,       
       uint64 bits);

    /// add bits to stream
    /// @param numBits number of bits to add
    /// @param bits add sign bit plus low numBits - 1 from this 
    /// @return true if successful, false if bits can't be written (buffer or file full)
    bool addBits
      (size_t numBits,
       int bits)
    { 
      return addBit(bits < 0) && addBits(numBits - 1, uint(abs(bits)));
    }

    /// add bits to stream
    /// @param numBits number of bits to add
    /// @param bits add sign bit plus low numBits - 1 from this 
    /// @return true if successful, false if bits can't be written (buffer or file full)
    bool addBits
      (size_t numBits,       
       int64 bits)
    { 
      return (bits < 0) 
        ? (addBit(true) && addBits(numBits, uint64(-bits)))
        : (addBit(false) && addBits(numBits, uint64(bits)));
    }

    /// add bits to stream
    /// @param numBits number of bits to add
    /// @param bits add low numBits from this
    bool addBits
      (size_t numBits,       
       const byte *bits);   

    /// get a single bit from the stream
    /// @return the bit
    bool getBit
      ();

    /// get bits from stream
    /// @param numBits bits to get
    /// @return bits read
    template <typename T>
    T getBits
      (size_t numBits);

    /// for very small values only
    static uint lenTriInt
      (uint n);
    bool addTriInt
      (uint n);
    uint getTriInt    // return -1 at end of list    
      ();

    /// general variable ints

    /// thrown on get if stream is corrupt
    class Corrupt : public Exception
    {
    public:
      /// @param pos the position in the stream of the corruption
      Corrupt
        (uint64 pos) :
          Exception("BitStream corrupt at pos \"%1\".")
        {
          addParam(String().setNum(pos));
        }
    };

    /// how many bits to store a uint
    /// @param n the uint
    static size_t lenVar
      (uint n);

    /// how many bits to store an int
    /// @param n the int
    static size_t lenVar
      (int n);

    /// add a uint to the stream
    /// @param n the uint to add
    /// @return true if successful, false if bits can't be written (buffer or file full)
    /// N.B. buffer must be cleared
    bool addVar  
      (uint n);

    /// how many bits to store a uint64
    /// @param n the uint64
    static size_t lenVar
      (uint64 n);

    /// add a uint64 to the stream
    /// @param n uint64 to add
    /// @return true if successful, false if bits can't be written (buffer or file full)
    /// N.B. buffer must be cleared
    bool addVar  
      (uint64 n);

    /// add an int to the stream
    /// @param n the int to add
    /// @return true if successful, false if bits can't be written (buffer or file full)
    /// N.B. buffer must be cleared
    bool addVar
      (int n);

    /// add a uint64 to the stream
    /// @param n uint64 to add
    /// @return true if successful, false if bits can't be written (buffer or file full)
    /// N.B. buffer must be cleared
    bool addVar
      (int64 n);

    /// fill the remainder of current byte with 1's
    void fillByte
      ();

    /// get a number (T = int, uint, int64, uint64) from the stream
    /// @param ok fillin false iff internally inconsistent or end-of-file
    /// @return the number
    template <typename T>
    T getVar
      (bool* ok);

    /// get a number (T = int, uint, int64, uint64) from the stream
    /// @return the number
    /// @throw Corrupt if internally inconsistent or end-of-file
    template <typename T>
    T getVar
      ();
 
    #ifndef NDEBUG
    static
    void test
      ();
    #endif

  private :
    byte* buffer;
    size_t bit_pos;
    size_t bit_len;
    size_t byte_len;

    std::vector<byte>* vect;
    bool own_vect;
    void resize_vect
      ();

    enum buf_status
    {
      buf_ok,
      buf_by_byte,
      buf_end
    };
    template <typename T>
    buf_status check_buf
      (size_t byt,
       size_t numBits);

    template <typename T>
    void add_bits_by_byte
      (size_t byt,
       size_t shift,
       size_t numBits,
       T bits);

    uint get_ubits
      (size_t num_bits); 
    uint64 get_ubits64
      (size_t num_bits);
    int get_bits
      (size_t num_bits); 
    int64 get_bits64
      (size_t num_bits);

    uint get_bits_no_check
      (size_t numBits);

    // var int access - fillin ok = false on end-of-buffer
    uint get_var_uint          
      (bool* ok);
    uint64 get_var_uint64
      (bool* ok);
    int get_var_int          
      (bool* ok);
    int64 get_var_int64
      (bool* ok);
    // var int access -throw corrupt on end-of-buffer
    uint get_var_uint          
      ();
    uint64 get_var_uint64
      ();
    int get_var_int          
      ();
    int64 get_var_int64
      ();
  };

  // ================================== FILE BIT STREAM (READ) ==

  #if defined XPACE_FILE_H
  /// Read-only bit stream over a File.  Bits are LSB-first within each
  /// byte.  Positions are tracked in bits; bit_pos is relative to the start
  /// of the file, and the stream covers [start, end) within it.
  class XPACE_EXPORT FileBitStream
  {
  public :
    FileBitStream
      ();

    /// Skip to the next byte boundary (no-op if already aligned).
    void finishByte
      ();

    virtual ~FileBitStream
      () 
    {};

    /// Attach to a file and seek to the start position.
    /// @param f the file
    /// @param start start position in bits from the beginning of the file;
    ///        defaults to the current file position
    /// @param bitLen stream length in bits; defaults to file size * 8 - start
    /// @return false if the file length can't be determined or the seek fails
    bool setFile
      (File *f,
       uint64 start = File::errorPosition,    // start position (bits); defaults to current
       uint64 bitLen = 0);                    // defaults to file size * 8 - start
      
    /// @return stream length in whole bytes (bit length / 8, rounded down)
    uint64 getByteLen
      ()
    const;
    /// @return stream length in bits
    uint64 getBitLen
      ()
      const;

    /// @return true once the position reaches the end of the stream
    bool atEnd
      ()
      const;

    // relative to stream, not whole file
    uint64 getBitPos
      ()
      const;
    uint64 getBytePos
      ()
      const;

    const File* getFile
      ()
      const;

    // relative to file, not stream
    uint64 getFileBitPos
      ()
      const;
    uint64 getFileBytePos
      ()
      const;

   
    // bit offsets, relative to the stream start
    // return false on error
    // NOTE(review): only seek(uint64) is defined in this header; seek(uint)
    // may be defined elsewhere - TODO confirm, otherwise calls will not link.
    bool seek   
      (uint off);
    bool seek
      (uint64 off);

    bool getBit
      ();

    template <typename T>
    T getBits
      (size_t num_bits);

    /// @return the value, or uint(-1) at end of stream
    uint getTriInt
      ();

    // T = int, uint, int64, uint64
    // fillin false on end-of-file or at bitLen
    template <typename T>
    T getVar
      (bool* ok);
    // throw corrupt on end-of-file or at bitLen
    template <typename T>
    T getVar
      ();

  private :
    File* file;
    uint64 start;     // first bit of the stream, relative to file
    uint64 bit_len;   // stream length in bits
    uint64 end;       // start + bit_len
    uint64 bit_pos;   // relative to file 

    uint get_ubits
      (size_t num_bits); 
    uint64 get_ubits64
      (size_t num_bits);
    int get_bits
      (size_t num_bits); 
    int64 get_bits64
      (size_t num_bits);

    // var int access - fillin ok = false on end-of-stream
    uint get_var_uint          
      (bool* ok);
    uint64 get_var_uint64
      (bool* ok);
    int get_var_int          
      (bool* ok);
    int64 get_var_int64
      (bool* ok);

    // var int access - throw on end-of-stream / corruption
    uint get_var_uint          
      ();
    uint64 get_var_uint64
      ();
    int get_var_int          
      ();
    int64 get_var_int64
      ();
  };
  #endif

  // ================================== ENCODING ============

  /// Thresholds for choosing how many groups a variable-length integer
  /// needs: a value n uses the smallest g with n <= table[g].
  /// @tparam T the integer type being encoded
  template <typename T>
  class encoding
  {
  public :
    /// @param base number of bits per group
    encoding
      (uint base);

    /// @return bits per group, as passed to the constructor
    uint base
      ()
      const
    {
      return this->b;
    }

    /// @return threshold for i groups (index is not range-checked)
    T operator[]
      (uint i)
      const
    {
      const T threshold(n[i]);
      return threshold;
    }

  private :
    const uint b;
    std::vector<T> n;
  };


  /// Offsets for decoding variable-length integers: a value stored in g
  /// groups is (stored bits) + table[g].
  /// @tparam T the integer type being decoded
  template <typename T>
  class decoding
  {
  public :
    /// @param base number of bits per group
    decoding
      (uint base);

    /// @return bits per group, as passed to the constructor
    uint base
      ()
      const
    {
      return this->b;
    }

    /// @return offset for i groups (index is not range-checked)
    T operator[]
      (uint i)
      const
    {
      const T offset(n[i]);
      return offset;
    }

  private :
    const uint b;
    std::vector<T> n;
  };

  /// Shared lookup tables used by the bit streams; defined outside this header.
  namespace BSS
  {
    // used as "& bitMask32[n]" to keep the low n bits (presumably 33 entries) - confirm
    extern const uint32 bitMask32[];
    // 64-bit counterpart of bitMask32 (presumably 65 entries) - confirm
    extern const uint64 bitMask64[];
    extern const uint bitsSet[];
    // fillHigh[k]: ORed into a byte by fillByte() with k bits used; presumably sets bits k..7
    extern uint8 fillHigh[];
    // tri-int code length / code bits for n < 48 (see addTriInt, lenTriInt)
    extern uint size3[];
    extern uint encoding3[];
    // var-int group thresholds (encoding) and offsets (decoding) for uint / uint64
    extern encoding<uint> encoding4;
    extern encoding<uint64> encoding4_64; 
    extern decoding<uint> decoding4;
    extern decoding<uint64> decoding4_64; 
  }

  // ============================================================
  // ================================== INLINES =================
  // ============================================================

  /// Build the group thresholds for groups of `base` bits.
  /// Entry g is the largest value that fits in g groups, i.e. one less than
  /// the decoding offset for g + 1 groups; the last entry is a T(-1)
  /// sentinel so every value finds a group count.
  template <typename T>
  encoding<T>::encoding
    (uint base) :
      b(base)
  {
    n.push_back(0);
    // g groups hold 2^(g*b) values, so each threshold grows by 2^(g*b).
    // The first step was previously hard-coded to 4, correct only for b == 2.
    for (T add(T(1) << b); add; add <<= b)
      n.push_back(n.back() + add);
    n.push_back(T(-1));
  }


  /// Build the decoding offsets for groups of `base` bits: entry g is the
  /// smallest value stored with g groups (0, 1, 1 + 2^b, ...), stopping
  /// when the step overflows to zero.
  template <typename T>
  decoding<T>::decoding
    (uint base) :
      b(base)
  {
    T total(0);
    n.push_back(total);
    T step(1);
    while (step)
    {
      total += step;
      n.push_back(total);
      step <<= b;
    }
  }


  // ================================== MEMBITSTREAM ============

  inline
  bool MemBitStream::getBit
    ()
  {
    if (bit_pos >= bit_len)
      return false;

    bool ret(!!(buffer[bit_pos >> 3] & (uint8(1) << (bit_pos & 7))));
    ++bit_pos;
    return ret;
  }

  inline
  bool MemBitStream::addBit
    (bool bit)
  {
    if (bit_pos >= bit_len)
    {
      if (vect)
        resize_vect();
      else
        return false;
    }

    if (bit)
      buffer[bit_pos >> 3] |= uint8(1) << (bit_pos & 7);
    ++bit_pos;
    return true;
  }

  /// Decide how a write of numBits at byte `byt` can be done.
  /// @return buf_ok if a whole T can be accessed at byt (a vector-backed
  ///         buffer is resized first when needed); otherwise buf_by_byte if
  ///         the bits still fit in the stream, else buf_end
  template <typename T>
  MemBitStream::buf_status MemBitStream::check_buf
    (size_t byt,
     size_t numBits)
  {
    const bool word_fits(byt + sizeof(T) <= byte_len);
    if (!word_fits && vect)
      resize_vect();
    if (word_fits || vect)
      return buf_ok;

    // fixed buffer near its end
    if (bit_pos + numBits > bit_len)
      return buf_end;
    return buf_by_byte;
  }
  
  /// OR the low numBits of `bits` into the buffer one byte at a time,
  /// starting at bit `shift` of byte `byt`.  Used near the end of a fixed
  /// buffer, where a whole-word access would overrun it.
  /// @param byt first byte to write
  /// @param shift bit offset within that byte (0..7)
  /// @param numBits number of bits to write (at most the width of T)
  /// @param bits source; bits above numBits are ignored
  template <typename T>
  inline
  void MemBitStream::add_bits_by_byte
    (size_t byt,
     size_t shift,
     size_t numBits,
     T bits)
  {
    const size_t width(sizeof(T) * 8);
    assert(numBits <= width);
    if (numBits == 0)
      return;   // nothing to write; also avoids touching a byte past the end

    // Drop bits above numBits so they cannot leak into the final byte.
    if (numBits < width)
      bits &= static_cast<T>((static_cast<T>(1) << numBits) - 1);

    // Output byte i holds value bits starting at 8*i - shift.  Every shift
    // count below is < width, and the bytes are built arithmetically so no
    // particular host byte order is assumed.
    const size_t count((numBits + shift + 7) >> 3);
    byte* dest(&buffer[byt]);
    dest[0] |= static_cast<byte>(bits << shift);
    for (size_t i = 1; i < count; ++i)
      dest[i] |= static_cast<byte>(bits >> (8 * i - shift));
  }

  inline
  bool MemBitStream::addBits
    (size_t numBits,
     uint bits)
  {
    size_t byt(bit_pos >> 3);
    size_t shift(bit_pos & 7);

    if (numBits + shift > sizeof(uint) * 8)
    {
      switch (check_buf<uint64>(byt, numBits))
      {
        case buf_ok:
        {
          uint64 n((static_cast<uint64>(bits) & BSS::bitMask32[numBits]) << shift); 
          *reinterpret_cast<uint64*>(&buffer[byt]) |= n;
          break;
        }
        case buf_by_byte:
          add_bits_by_byte(byt, shift, numBits, bits);
          break;
        case buf_end:
          return false;
      }
    }
    else
    {
      switch (check_buf<uint>(byt, numBits))
      {
        case buf_ok:
        {
          uint n((bits & BSS::bitMask32[numBits]) << shift); 
          *reinterpret_cast<uint*>(&buffer[byt]) |= n;
          break;
        }
        case buf_by_byte:
          add_bits_by_byte(byt, shift, numBits, bits);
          break;
        case buf_end:
          return false;
      }
    }  

    bit_pos += numBits;
    return true;
  }

  inline
  bool MemBitStream::addBits
    (size_t numBits,
     uint64 bits)
  {
    size_t byt(bit_pos >> 3);
    size_t bit(bit_pos & 7);

    switch (check_buf<uint64>(byt, numBits))    
    {
      case buf_ok :
      {
        uint64 n = (bits & BSS::bitMask64[numBits]) << bit;

        // low bits
        *reinterpret_cast<uint64*>(&buffer[byt]) |= n;
        if (numBits + bit > (sizeof(uint64) << 3))
        {
          // add high bits
          uint8 n(static_cast<uint8>(bits >> (sizeof(uint64) * 8 - bit)));
          buffer[byt + sizeof(uint64) - 1] |= n;
        }
      }
      break;

      case buf_by_byte:
        add_bits_by_byte(byt, bit, numBits, bits);
        break;
      case buf_end:
        return false;
    }  
   
    bit_pos += numBits;
    return true;
  }

  /// Append numBits from a byte array, low bits first (byte 0 bit 0 first).
  /// Only the ceil(numBits / 8) bytes that carry data are read.
  /// @return false if the bits can't be written (buffer full)
  inline
  bool MemBitStream::addBits
    (size_t numBits,
     const byte* bits)
  {
    const size_t chunk_bits(sizeof(uint) << 3);
    while (numBits > 0)
    {
      const size_t n(std::min(numBits, chunk_bits));

      // Assemble the chunk byte by byte (little-endian, as the former
      // whole-uint load was on the supported hosts), so the final partial
      // chunk never reads past the caller's array.
      const size_t n_bytes((n + 7) >> 3);
      uint chunk(0);
      for (size_t i = 0; i < n_bytes; ++i)
        chunk |= static_cast<uint>(bits[i]) << (i << 3);
      if (n < chunk_bits)
        chunk &= (uint(1) << n) - 1;

      if (!addBits(n, chunk))
        return false;
      numBits -= n;
      bits += sizeof(uint);
    }
    return true;
  }

  /// OR BSS::fillHigh[k] into the current byte, where k is the number of
  /// bits already used in it (presumably setting the unused high bits).
  /// The position is not advanced; nothing happens on a byte boundary.
  inline
  void MemBitStream::fillByte
    ()
  {
    const size_t used(bit_pos & 7);
    if (used == 0)
      return;
    buffer[bit_pos >> 3] |= BSS::fillHigh[used];
  }

  /// Read numBits (0..32) unsigned bits at the current position.
  /// @return the bits; 0 (position unchanged) when fewer than a word of
  ///         buffer remains and the read would pass the end of the stream
  inline
  uint MemBitStream::get_ubits
    (size_t numBits)
  {
    const size_t byt(bit_pos >> 3);
    const size_t bit(bit_pos & 7);
    const bool whole_word(byt + sizeof(uint) <= byte_len);

    if (!whole_word && bit_pos + numBits > bit_len)
      return 0;

    uint bits;
    if (whole_word)
    {
      bits = *reinterpret_cast<const uint*>(&buffer[byt]) >> bit;
      // A field straddling the word needs low bits of the following byte,
      // which exists only if the buffer extends past this word.
      if (numBits + bit > 32 && byt + sizeof(uint) < byte_len)
        bits |= (buffer[byt + 4] & BSS::bitMask32[bit]) << (32 - bit);
    }
    else
    {
      // Fewer than sizeof(uint) bytes remain: assemble them one at a time.
      // (The old code loaded buffer[byte_len - 4], which underflows for
      // buffers shorter than 4 bytes.)
      uint64 acc(0);
      for (size_t i = byt; i < byte_len; ++i)
        acc |= static_cast<uint64>(buffer[i]) << ((i - byt) << 3);
      bits = static_cast<uint>(acc >> bit);
    }

    bits &= BSS::bitMask32[numBits];
    bit_pos += numBits;
    return bits;
  }

  /// Read a signed value: one sign bit, then a (numBits - 1)-bit magnitude.
  inline
  int MemBitStream::get_bits
    (size_t numBits)
  {
    const bool negative(getBit());
    const int magnitude(get_ubits(numBits - 1));
    if (negative)
      return -magnitude;
    return magnitude;
  }

  /// Read num_bits (0..64) unsigned bits at the current position.
  /// @return the bits; 0 (position unchanged) if they extend past the buffer
  inline
  uint64 MemBitStream::get_ubits64
    (size_t num_bits)
  {
    if (((bit_pos + num_bits + 7) >> 3) > byte_len)
      return 0;

    const size_t byt(bit_pos >> 3);
    const size_t bit(bit_pos & 7);

    uint64 bits;
    if (byt + sizeof(uint64) <= byte_len)
    {
      bits = *reinterpret_cast<const uint64*>(&buffer[byt]) >> bit;
      // The length check above guarantees the ninth byte exists here.
      if (num_bits + bit > 64)
        bits |= (static_cast<uint64>(buffer[byt + 8]) & BSS::bitMask64[bit])
                << (64 - bit);
    }
    else
    {
      // Fewer than 8 bytes remain: a whole-word load would read past the
      // buffer, so assemble the remaining bytes one at a time.
      uint64 acc(0);
      for (size_t i = byt; i < byte_len; ++i)
        acc |= static_cast<uint64>(buffer[i]) << ((i - byt) << 3);
      bits = acc >> bit;
    }

    bits &= BSS::bitMask64[num_bits];
    bit_pos += num_bits;
    return bits;
  }

  /// Read a signed 64-bit value: one sign bit, then (numBits - 1) magnitude bits.
  inline
  int64 MemBitStream::get_bits64
    (size_t numBits)
  {
    const bool negative(getBit());
    const int64 magnitude(get_ubits64(numBits - 1));
    if (negative)
      return -magnitude;
    return magnitude;
  }

  /// getBits<uint>: up to 32 unsigned bits
  template <>
  inline
  uint MemBitStream::getBits<uint>
    (size_t numBits)
  {
    const uint value(get_ubits(numBits));
    return value;
  }

  /// getBits<int>: sign bit plus (numBits - 1) magnitude bits
  template <>
  inline
  int MemBitStream::getBits<int>
    (size_t numBits)
  {
    const int value(get_bits(numBits));
    return value;
  }

  /// getBits<uint64>: up to 64 unsigned bits
  template <>
  inline
  uint64 MemBitStream::getBits<uint64>
    (size_t numBits)
  {
    const uint64 value(get_ubits64(numBits));
    return value;
  }

  #if defined __LP64__
  /// getBits<size_t>: 64-bit size_t read as uint64
  template <>
  inline
  size_t MemBitStream::getBits<size_t>
    (size_t numBits)
  {
    return static_cast<size_t>(get_ubits64(numBits));
  }
  #endif

  /// getBits<int64>: sign bit plus (numBits - 1) magnitude bits
  template <>
  inline
  int64 MemBitStream::getBits<int64>
    (size_t numBits)
  {
    const int64 value(get_bits64(numBits));
    return value;
  }

  inline
  byte* MemBitStream::getBuf
    ()
  {
    return buffer;
  }

  inline
  size_t MemBitStream::getByteLen
    ()
  const
  {
    return byte_len;
  }

  inline
  size_t MemBitStream::getBitLen
      ()
    const
  {
    return bit_len;
  }

  inline
  bool MemBitStream::seek
    (size_t off)
  {
    if (off > bit_len)
      return false;
    bit_pos = off;
    return true;
  }

  inline
  bool MemBitStream::atEnd
    ()
    const
  {
    return (bit_pos >= bit_len);
  }

   
  /// Decode a tri-int written by addTriInt(): d leading 1 bits, then a
  /// terminating pair.  Odd d gives 3*(d/2)+1.  Even d gives 3*(d/2) or
  /// 3*(d/2)+2, depending on the bit after the terminating 0.
  /// @return the value, or uint(-1) (position set to the end) at end of stream
  FORCEINLINE
  uint MemBitStream::getTriInt
    ()
  {
    // Debug trace near the end of the stream.  bit_len < 8 never traces,
    // as with the former unsigned comparison.  The arguments are cast to
    // match the %u conversions.
    if (bit_len >= 8 && bit_pos > bit_len - 8)
      Trace::Msg("MemBitStream::getTriInt", "len %u  off %u",
                 uint(bit_len), uint(bit_pos));

    const size_t byt(bit_pos >> 3);
    const size_t bit(bit_pos & 7);
    const size_t word_bits(sizeof(uint) * 8);

    // Fast path: a whole word is available at byt.
    if (byt + sizeof(uint) < byte_len)
    {
      const uint word(*reinterpret_cast<const uint*>(&buffer[byt]) >> bit);
      uint d(0);
      while (d < word_bits - bit && ((word >> d) & 1))
        ++d;

      // Decode here only if the terminating pair lies inside this word.
      // Otherwise the zeros shifted in at the top would be mistaken for a
      // terminator, so fall through to the bit-by-bit path.
      if (bit + d + 2 <= word_bits)
      {
        if (d & 1)
        {
          bit_pos += d + 1;
          return (d / 2) * 3 + 1;
        }
        bit_pos += d + 2;
        return ((word >> (d + 1)) & 1)
          ? (d / 2) * 3 + 2
          : (d / 2) * 3;
      }
    }

    // Careful path: never look at bits past bit_len.
    if (bit_pos >= bit_len)
    {
      bit_pos = bit_len;
      return uint(-1);
    }
    const size_t stop(bit_len - bit_pos);

    // find next zero bit
    size_t d(0);
    while (true)
    {
      if (d == stop)
      {
        // no zero bit -> end of buffer
        bit_pos = bit_len;
        return uint(-1);
      }
      const size_t p(bit_pos + d);
      if (!(buffer[p >> 3] & (1u << (p & 7))))
        break;
      ++d;
    }

    if (d & 1)
    {
      bit_pos += d + 1;
      return uint((d / 2) * 3 + 1);
    }

    // Second bit of the terminating pair; treat a pair cut off by the end
    // of the buffer as end of stream instead of reading past it.
    const size_t p(bit_pos + d + 1);
    if ((p >> 3) >= byte_len)
    {
      bit_pos = bit_len;
      return uint(-1);
    }
    bit_pos += d + 2;
    return (buffer[p >> 3] & (1u << (p & 7)))
      ? uint((d / 2) * 3 + 2)
      : uint((d / 2) * 3);
  }

  /// Read numBits (0..32) unsigned bits without checking the stream length.
  /// Bits beyond the buffer read as 0; no access is made outside the buffer.
  FORCEINLINE
  uint MemBitStream::get_bits_no_check
    (size_t numBits)
  {
    const size_t byt(bit_pos >> 3);
    const size_t bit(bit_pos & 7);
    const byte* buf(buffer);

    uint bits;
    if (byt + sizeof(uint) <= byte_len)
    {
      bits = *reinterpret_cast<const uint*>(&buf[byt]) >> bit;
      // the straddling byte is only read if it is inside the buffer
      if (numBits + bit > 32 && byt + sizeof(uint) < byte_len)
        bits |= (buf[byt + 4] & BSS::bitMask32[bit]) << (32 - bit);
    }
    else
    {
      // Assemble the remaining (< sizeof(uint)) bytes individually.  The
      // old buf[byte_len - 4] load underflowed on short buffers, and its
      // shift could reach 32 or more.
      uint64 acc(0);
      for (size_t i = byt; i < byte_len; ++i)
        acc |= static_cast<uint64>(buf[i]) << ((i - byt) << 3);
      bits = static_cast<uint>(acc >> bit);
    }

    bits &= BSS::bitMask32[numBits];
    bit_pos += numBits;
    return bits;
  }

  // ================================== MEMBITSTREAM::GETVAR ====

  /// Read a variable-length uint: a tri-int group count g, then
  /// g * base bits added to BSS::decoding4[g].
  /// @param ok set false at end of stream or if g is too large for a uint
  FORCEINLINE
  uint MemBitStream::get_var_uint
    (bool* ok)
  {
    const int groups(getTriInt());
    if (groups < 0)
    {
      // getTriInt() returned uint(-1): end of stream
      *ok = false;
      return 0;
    }
    if (groups == 0)
    {
      // a genuine zero, valid even when it is the last item in the stream
      *ok = true;
      return 0;
    }

    const uint64 num_bits(static_cast<uint64>(groups) * BSS::decoding4.base());
    if (num_bits > sizeof(uint) * 8)
    {
      // more groups than a uint holds: internally inconsistent
      *ok = false;
      return 0;
    }

    *ok = true;
    return get_bits_no_check(static_cast<size_t>(num_bits)) +
           BSS::decoding4[groups];
  }

  /// Read a variable-length int: magnitude, then a sign bit if nonzero
  /// (the layout written by addVar(int)).
  FORCEINLINE
  int MemBitStream::get_var_int
    (bool *ok)
  {
    const int n(get_var_uint(ok));
    const bool sign(n && getBit());
    return sign ? -n : n;
  }

  /// 64-bit counterpart of get_var_uint(bool*).
  inline
  uint64 MemBitStream::get_var_uint64
    (bool *ok)
  {
    const int groups(getTriInt());
    if (groups < 0)
    {
      *ok = false;
      return 0;
    }
    if (groups == 0)
    {
      *ok = true;
      return 0;
    }

    const uint64 num_bits(static_cast<uint64>(groups) * BSS::decoding4_64.base());
    if (num_bits > sizeof(uint64) * 8)
    {
      *ok = false;
      return 0;
    }

    *ok = true;
    return get_ubits64(static_cast<size_t>(num_bits)) +
           BSS::decoding4_64[groups];
  }

  /// Read a variable-length int64: magnitude, then a sign bit if nonzero.
  /// Previously the sign was read first, which disagreed with addVar(int64)
  /// and with the throwing get_var_int64().
  FORCEINLINE
  int64 MemBitStream::get_var_int64
    (bool* ok)
  {
    const int64 n(get_var_uint64(ok));
    return (n && getBit()) ? -n : n;
  }

  /// Throwing variant of get_var_uint(bool*).
  /// @throw Corrupt at end of stream or on an impossible group count
  FORCEINLINE
  uint MemBitStream::get_var_uint
    ()
  {
    const int groups(getTriInt());
    if (groups < 1)
    {
      if (groups == 0)
        return 0;
      throw MemBitStream::Corrupt(bit_pos);
    }

    const uint64 num_bits(static_cast<uint64>(groups) * BSS::decoding4.base());
    if (num_bits > sizeof(uint) * 8)
      throw MemBitStream::Corrupt(bit_pos);

    return get_bits_no_check(static_cast<size_t>(num_bits)) +
           BSS::decoding4[groups];
  }

  FORCEINLINE
  int MemBitStream::get_var_int
    ()
  {
    const int n(get_var_uint());
    return (n && getBit()) ? -n : n;
  }

  /// Throwing variant of get_var_uint64(bool*).
  /// @throw Corrupt at end of stream or on an impossible group count
  inline
  uint64 MemBitStream::get_var_uint64
    ()
  {
    const int groups(getTriInt());
    if (groups < 1)
    {
      if (groups == 0)
        return 0;
      throw MemBitStream::Corrupt(bit_pos);
    }

    const uint64 num_bits(static_cast<uint64>(groups) * BSS::decoding4_64.base());
    if (num_bits > sizeof(uint64) * 8)
      throw MemBitStream::Corrupt(bit_pos);

    return get_ubits64(static_cast<size_t>(num_bits)) +
           BSS::decoding4_64[groups];
  }

  FORCEINLINE
  int64 MemBitStream::get_var_int64
    ()
  {
    const int64 n(get_var_uint64());
    return (n && getBit()) ? -n : n;
  }

  // ================================== MEMBITSTREAM::LEN... ====

  // static
  /// @return bits addTriInt(n) writes: table size for n < 48, otherwise
  ///         one 2-bit "11" pair per 3 plus a final 2-bit pair
  inline
  uint MemBitStream::lenTriInt
    (uint n)
  {
    return (n < 48) ? BSS::size3[n] : 2 + (n / 3) * 2;
  }

  // static
  /// @return bits addVar(uint) writes for n
  inline
  size_t MemBitStream::lenVar
    (uint n)
  {
    uint groups(0);
    while (n > BSS::encoding4[groups])
      ++groups;
    return lenTriInt(groups) + groups * BSS::encoding4.base();
  }

  // static
  /// @return bits addVar(int) writes for _n (magnitude plus a sign bit if nonzero)
  inline
  size_t MemBitStream::lenVar
    (int _n)
  {
    // unsigned negation keeps INT_MIN well defined (abs(INT_MIN) is UB)
    const uint n(_n < 0 ? 0u - uint(_n) : uint(_n));
    uint groups(0);
    while (n > BSS::encoding4[groups])
      ++groups;
    return lenTriInt(groups) + groups * BSS::encoding4.base() + (n != 0);
  }

  //static
  /// @return bits addVar(uint64) writes for n
  inline
  size_t MemBitStream::lenVar
    (uint64 n)
  {
    uint groups(0);
    while (n > BSS::encoding4_64[groups])
      ++groups;
    // Use the 64-bit table's base, as addVar(uint64) does.
    return lenTriInt(groups) + groups * BSS::encoding4_64.base();
  }

  // ================================== MEMBITSTREAM::ADD... ====

  inline 
  bool MemBitStream::addTriInt
    (uint n)
  {
    if (n < 48)
      return addBits(BSS::size3[n], BSS::encoding3[n]);

    while (n > 2)
    {
      if (!addBits(2, (uint) 3))
        return false;
      n -= 3;
    }
    return addBits(2, n);
  }

  inline
  bool MemBitStream::addVar
    (uint64 n)
  {
    uint groups;
    for (groups = 0; n > BSS::encoding4_64[groups]; groups++)
      ;

    return (addTriInt(groups) &&
            addBits(groups * BSS::encoding4_64.base(), 
                    n - BSS::decoding4_64[groups]));
  }

  inline
  bool MemBitStream::addVar
    (uint n)
  {
    uint groups;
    for (groups = 0; n > BSS::encoding4[groups]; groups++)
      ;

    return (addTriInt(groups) &&
            addBits(groups * BSS::encoding4.base(), n - BSS::decoding4[groups]));
  }

  inline
  bool MemBitStream::addVar
    (int n)
  {
    return addVar(uint(abs(n))) && (!n || addBit(n < 0));
  }

  inline
  bool MemBitStream::addVar
    (int64 n)
  {
    if (n < 0) 
      return addVar(uint64(-n)) && addBit(true);

    return addVar(uint64(n)) && (!n || addBit(false));
  }

  template <>
  inline
  uint32 MemBitStream::getVar<uint>
    (bool* ok)
  {
    return get_var_uint(ok);
  }

  template <>
  inline
  uint64 MemBitStream::getVar<uint64>
    (bool* ok)
  {
    return get_var_uint64(ok);
  }

  template <>
  inline
  int32 MemBitStream::getVar<int>
    (bool *ok)
  {
    return get_var_int(ok);
  }

  template <>
  inline
  int64 MemBitStream::getVar<int64>
    (bool *ok)
  {
    return get_var_int64(ok);
  }

  template <>
  inline
  uint32 MemBitStream::getVar<uint>
    ()
  {
    return get_var_uint();
  }

  template <>
  inline
  uint64 MemBitStream::getVar<uint64>
    ()
  {
    return get_var_uint64();
  }

  template <>
  inline
  int32 MemBitStream::getVar<int>
    ()
  {
    return get_var_int();
  }

  template <>
  inline
  int64 MemBitStream::getVar<int64>
    ()
  {
    return get_var_int64();
  }

  #if defined __LP64__
  template <>
  inline
  size_t MemBitStream::getVar<size_t>
    ()
  {
    return size_t(get_var_uint64());
  }
  #endif

  #if defined XPACE_FILE_H || defined DOCUMENTATION
  // ================================== FILE BIT STREAM =========

  /// Construct an unattached stream; call setFile() before reading.
  /// Members start at zero so atEnd() and the position getters return
  /// defined values before setFile() (previously they were uninitialized).
  inline
  FileBitStream::FileBitStream
    () :
      file(0),
      start(0),
      bit_len(0),
      end(0),
      bit_pos(0)
  {
  }

  /// Advance to the next byte boundary; the file cursor moves with it.
  inline
  void FileBitStream::finishByte
    ()
  {
    const bool mid_byte((bit_pos & 7) != 0);
    if (mid_byte)
    {
      file->seekRel(1);
      bit_pos = (bit_pos | 7) + 1;
    }
    assert(file->getPos() == bit_pos >> 3);
  }


  /// Attach to a file.
  /// @param s start position in bits; File::errorPosition = current position
  /// @param bitLen stream length in bits; 0 = to the end of the file
  /// @return false if the length is unavailable or the seek fails
  inline
  bool FileBitStream::setFile
    (File *f,
     uint64 s,
     uint64 bitLen)
  {
    file = f;
    start = (s == File::errorPosition) ? (f->getPos() << 3) : s;
    bit_pos = start;

    const File::Position file_len(file->getLength());
    if (file_len == File::errorPosition)
      return false;

    bit_len = (bitLen != 0) ? bitLen : (file_len << 3) - start;
    end = start + bit_len;
    return file->seek(bit_pos >> 3);
  }

  /// @return stream length in whole bytes (rounded down)
  inline
  uint64 FileBitStream::getByteLen
    ()
  const
  {
    return this->bit_len >> 3;
  }

  /// @return stream length in bits
  inline
  uint64 FileBitStream::getBitLen
    ()
  const
  {
    return this->bit_len;
  }

  /// @return true once the position has reached the end of the stream
  inline
  bool FileBitStream::atEnd
    ()
  const
  {
    return !(bit_pos < end);
  }

  /// @return position in bits relative to the stream start
  inline
  uint64 FileBitStream::getBitPos
    ()
  const
  {
    return this->bit_pos - this->start;
  }

  /// @return bytes consumed from the stream start (rounded up)
  inline
  uint64 FileBitStream::getBytePos
    ()
  const
  {
    const uint64 bits_read(getBitPos());
    return (bits_read + 7) >> 3;
  }

  inline const File* FileBitStream::getFile
    ()
    const
  {
    return this->file;
  }

  /// @return position in bits relative to the start of the file
  inline
  uint64 FileBitStream::getFileBitPos
    ()
  const
  {
    return this->bit_pos;
  }

  /// @return file bytes touched so far (rounded up)
  inline
  uint64 FileBitStream::getFileBytePos
    ()
  const
  {
    const uint64 bits_read(getFileBitPos());
    return (bits_read + 7) >> 3;
  }

  /// Move to bit offset `off` from the stream start.
  /// @return false if past the end or the file seek fails
  inline
  bool FileBitStream::seek
    (uint64 off)
  {
    const uint64 target(start + off);
    if (target > end)
      return false;
    if (!file->seek(target >> 3))
      return false;
    bit_pos = target;
    return true;
  }

  inline
  bool FileBitStream::getBit
    ()
  {
    uint ret(*(file->get<byte>()) & (1 << (bit_pos++ & 7)));
    if (bit_pos & 7)
      file->seekRel(-1);
    return !!ret;
  }

  inline
  uint FileBitStream::get_ubits
    (size_t numBits)
  {
    uint64 bit(bit_pos & 7);
    uint64 e(numBits + bit);
    uint64 bytes((e + 7) >> 3);

    uint bits;
    byte* b;
    if (file->get(&b, size_t(bytes)) != bytes)
      throw File_Cant_Read(*file);

    if (e & 7)
      file->seekRel(-1);

    bits = *reinterpret_cast<uint*>(b);
    bits >>= bit;

    if (numBits + bit > 32)
      bits |= (*(b + 4) & BSS::bitMask32[bit]) << (32 - bit);

    bits &= BSS::bitMask32[numBits];

    bit_pos += numBits;
    return bits;
  }

  inline
  int FileBitStream::get_bits
    (size_t numBits)
  {
    bool sign(getBit());
    int n(get_ubits(numBits - 1));
    return sign ? -n : n;
  }

  inline
  uint64 FileBitStream::get_ubits64
    (size_t numBits)
  {
    uint bit(static_cast<uint>(bit_pos & 7));
    size_t e(numBits + bit);
    size_t bytes((e + 7) >> 3);

    uint64 bits;
    byte* b;
    if (file->get(&b, bytes) != bytes)
      throw File_Cant_Read(*file);
    if (e & 7)
      file->seekRel(-1);

    bits = *reinterpret_cast<uint64*>(b);
    bits >>= bit;

    if (numBits + bit > 64)
      bits |= (*(b + 8) & BSS::bitMask64[bit]) << (64 - bit);

    bits &= BSS::bitMask64[numBits];

    bit_pos += numBits;
    return bits;
  }


  inline
  int64 FileBitStream::get_bits64
    (size_t numBits)
  {
    bool sign(getBit());
    int64 n(get_ubits64(numBits - 1));
    return sign ? -n : n;
  }

  /// getBits<uint>: up to 32 unsigned bits
  template <>
  inline
  uint FileBitStream::getBits<uint>
    (size_t numBits)
  {
    const uint value(get_ubits(numBits));
    return value;
  }

  /// getBits<int>: sign bit plus (numBits - 1) magnitude bits
  template <>
  inline
  int FileBitStream::getBits<int>
    (size_t numBits)
  {
    const int value(get_bits(numBits));
    return value;
  }

  /// getBits<uint64>: up to 64 unsigned bits
  template <>
  inline
  uint64 FileBitStream::getBits<uint64>
    (size_t numBits)
  {
    const uint64 value(get_ubits64(numBits));
    return value;
  }

  #if defined __LP64__
  template <>
  inline
  size_t FileBitStream::getBits<size_t>
    (size_t numBits)
  {
    return static_cast<size_t>(get_ubits64(numBits));
  }
  #endif

  /// getBits<int64>: sign bit plus (numBits - 1) magnitude bits
  template <>
  inline
  int64 FileBitStream::getBits<int64>
    (size_t numBits)
  {
    const int64 value(get_bits64(numBits));
    return value;
  }

  /// Decode a tri-int at the current position (layout as written by
  /// MemBitStream::addTriInt: d leading 1 bits, then a terminating pair).
  /// Expects the file cursor to be on byte bit_pos >> 3.
  /// @return the value, or uint(-1) (bit_pos set to end) at end of stream
  /// @throw File_Corrupt on read failure, or an all-ones word in the fast path
  inline
  uint FileBitStream::getTriInt
    ()
  {
    uint64 bit(bit_pos & 7);
    uint b;
    size_t d;

    if ((end - bit_pos >= sizeof(uint) * 8) &&
      (file->bufRemains() > sizeof(uint)))        // @todo: combine tests
    {
      // the usual case: a whole word of stream and file buffer is available
      b = *file->get<uint>();
      b >>= bit;
      if (~b == 0)
        throw File_Corrupt(*file);
      // d = index of the first 0 bit (rightBit presumably returns the lowest set bit)
      // NOTE(review): `b >>= bit` shifts zeros in at the top, so a run of
      // ones reaching the end of the word is taken to end at a fake zero.
      d = rightBit(~b);
      // leave the cursor on the byte holding the first bit after the code
      file->seekRel(((bit + ((d + 2) & ~1)) >> 3) - sizeof(uint));
    }
    else
    {
      d = 0;
      // near the end of a buffer
      b = uint(-1);
      size_t bytes(std::min(sizeof(uint), file->bufRemains()));
      if (file->read(&b, bytes) != bytes)
        throw File_Corrupt(*file);

      // all remaining bits are ones: pull in the rest of the word
      if ((b >> bit) == (uint(-1) >> bit))
      {
        if (!file->read(reinterpret_cast<byte*>(&b) + bytes, sizeof(uint) - bytes))
          throw File_Corrupt(*file);
        bytes = sizeof(uint);
      }

      // force bits past the end of the stream to 1 so they are never a terminator
      if (bit + end - bit_pos < sizeof(uint) * 8)
        b |= ~BSS::bitMask32[bit + end - bit_pos];

      b >>= bit;
      if ((0 == ~b) || 
          ((d = rightBit(~b)), bit_pos + d >= end))
        // this is okay - just means end of stream
      {
        bit_pos = end;
        return uint(-1);
      }

      // if the first zero bit is the last bit in the buffer and 
      // the first of the pair of bits representing the digit
      // we need to read the next buffer to see what the high bit is
      // NOTE(review): when bytes == sizeof(uint) this reads one byte into
      // &b + 4, i.e. past the end of `b` - TODO confirm and fix.
      if ((bit + d) == ((bytes * 8) - 1) && !(d & 1))
      {
        b <<= bit;
        file->read((reinterpret_cast<byte*>(&b)) + bytes, 1);
        file->seekRel(-1);
        b >>= bit;
      }
      else
        file->seekRel(((bit + ((d + 2) & ~1)) >> 3) - bytes);
    }

    if (d & 1)
    {
      // we're on the second of the pair
      bit_pos += d + 1;
      return static_cast<uint>((d / 2) * 3 + 1);
    }
    else
    {
      // we're on the first of the pair; the next bit selects digit 0 or 2
      // NOTE(review): for d == 31, `m << 1` overflows the 32-bit uint.
      bit_pos += d + 2;
      uint m(1 << d);
       if (b & m << 1)
        return static_cast<uint>((d / 2) * 3 + 2);
      else
        return static_cast<uint>((d / 2) * 3);
    }
  }

  // ================================== FILEBITSTREAM::GETVAR...

  /// Read a variable-length uint (tri-int group count, then group bits).
  /// @param ok set false if the seek fails or at end of stream
  /// @throw MemBitStream::Corrupt if the group count exceeds a uint
  inline
  uint FileBitStream::get_var_uint
    (bool* ok)
  {
    if (!file->seek(bit_pos >> 3))
    {
      *ok = false;
      return 0;
    }

    int groups(getTriInt());
    if (groups < 1)
    {
      *ok = (groups == 0);
      return 0;
    }

    assert(file->getPos() == (bit_pos >> 3));
    if (groups > 16)
      throw MemBitStream::Corrupt(bit_pos);

    uint num_bits(groups * BSS::decoding4.base());

    *ok = true;
    return get_ubits(num_bits) + BSS::decoding4[groups];
  }

  /// Read a variable-length int: magnitude, then a sign bit only if the
  /// magnitude is nonzero, matching MemBitStream::addVar(int).  Previously
  /// the sign was read first, which cannot decode that layout.
  inline
  int FileBitStream::get_var_int
    (bool* ok)
  {
    const int n(get_var_uint(ok));
    return (n && getBit()) ? -n : n;
  }

  /// 64-bit counterpart of get_var_uint(bool*).
  inline
  uint64 FileBitStream::get_var_uint64
    (bool* ok)
  {
    if (!file->seek(bit_pos >> 3))
    {
      *ok = false;
      return 0;
    }

    int groups(getTriInt());
    if (groups < 1)
    {
      *ok = (groups == 0);
      return 0;
    }

    assert(file->getPos() == (bit_pos >> 3));
    if (groups > 32)
      throw MemBitStream::Corrupt(bit_pos);

    *ok = true;
    return get_ubits64(groups * BSS::decoding4_64.base()) +
           BSS::decoding4_64[groups];
  }

  /// Read a variable-length int64 (magnitude, then sign if nonzero).
  /// Previously decoded through the 32-bit get_var_uint().
  inline
  int64 FileBitStream::get_var_int64
    (bool* ok)
  {
    const int64 n(get_var_uint64(ok));
    return (n && getBit()) ? -n : n;
  }

  /// Throwing variant of get_var_uint(bool*).
  /// @throw File_Cant_Read if the seek fails; MemBitStream::Corrupt at end
  ///        of stream or on an impossible group count
  inline
  uint FileBitStream::get_var_uint
    ()
  {
    if (!file->seek(bit_pos >> 3))
      throw(File_Cant_Read(*file));

    int groups(getTriInt());
    if (groups < 1)
    {
      if (groups == 0)
        return 0;
      throw MemBitStream::Corrupt(bit_pos);
    }

    assert(file->getPos() == (bit_pos >> 3));
    if (groups > 16)
      throw MemBitStream::Corrupt(bit_pos);

    uint num_bits(groups * BSS::decoding4.base());

    return get_ubits(num_bits) + BSS::decoding4[groups];
  }

  inline
  int FileBitStream::get_var_int
    ()
  {
    const int n(get_var_uint());
    return (n && getBit()) ? -n : n;
  }

  /// Throwing variant of get_var_uint64(bool*).
  inline
  uint64 FileBitStream::get_var_uint64
    ()
  {
    if (!file->seek(bit_pos >> 3))
      throw File_Cant_Read(*file);

    int groups(getTriInt());
    if (groups < 1)
    {
      if (groups == 0)
        return 0;
      throw MemBitStream::Corrupt(bit_pos);
    }

    assert(file->getPos() == (bit_pos >> 3));
    if (groups > 32)
      throw MemBitStream::Corrupt(bit_pos);

    return get_ubits64(groups * BSS::decoding4_64.base()) +
           BSS::decoding4_64[groups];
  }

  /// Throwing int64 reader; uses the 64-bit magnitude (was get_var_uint()).
  inline
  int64 FileBitStream::get_var_int64
    ()
  {
    const int64 n(get_var_uint64());
    return (n && getBit()) ? -n : n;
  }

  // getVar<T>(bool* ok): non-throwing readers; ok is filled in by the
  // get_var_* helper.

  template <>
  inline
  uint FileBitStream::getVar<uint>
    (bool* ok)
  {
    return this->get_var_uint(ok);
  }

  template <>
  inline
  uint64 FileBitStream::getVar<uint64>
    (bool* ok)
  {
    return this->get_var_uint64(ok);
  }

  #if defined __LP64__
  template <>
  inline
  size_t FileBitStream::getVar<size_t>
    (bool* ok)
  {
    return static_cast<size_t>(this->get_var_uint64(ok));
  }
  #endif

  template <>
  inline
  int FileBitStream::getVar<int>
    (bool* ok)
  {
    return this->get_var_int(ok);
  }

  template <>
  inline
  int64 FileBitStream::getVar<int64>
    (bool* ok)
  {
    return this->get_var_int64(ok);
  }

  // getVar<T>(): throwing readers.

  template <>
  inline
  uint FileBitStream::getVar<uint>
    ()
  {
    return this->get_var_uint();
  }

  #if defined __LP64__
  template <>
  inline
  size_t FileBitStream::getVar<size_t>
    ()
  {
    return static_cast<size_t>(this->get_var_uint64());
  }
  #endif

  template <>
  inline
  uint64 FileBitStream::getVar<uint64>
    ()
  {
    return this->get_var_uint64();
  }

  template <>
  inline
  int FileBitStream::getVar<int>
    ()
  {
    return this->get_var_int();
  }

  template <>
  inline
  int64 FileBitStream::getVar<int64>
    ()
  {
    return this->get_var_int64();
  }

  #endif // defined XPACE_FILE_H || defined DOCUMENTATION

}

#endif
