@ -429,6 +429,12 @@ struct whisper_context {
int32_t exp_n_audio_ctx ; // 0 - use default
int32_t exp_n_audio_ctx ; // 0 - use default
} ;
} ;
template < typename T >
static void read_safe ( std : : ifstream & fin , T & dest )
{
fin . read ( ( char * ) & dest , sizeof ( T ) ) ;
}
// load the model from a ggml file
// load the model from a ggml file
//
//
// file format:
// file format:
@ -455,7 +461,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// verify magic
// verify magic
{
{
uint32_t magic ;
uint32_t magic ;
fin. read ( ( char * ) & magic , sizeof ( magic ) ) ;
read_safe( fin , magic ) ;
if ( magic ! = 0x67676d6c ) {
if ( magic ! = 0x67676d6c ) {
fprintf ( stderr , " %s: invalid model file '%s' (bad magic) \n " , __func__ , fname . c_str ( ) ) ;
fprintf ( stderr , " %s: invalid model file '%s' (bad magic) \n " , __func__ , fname . c_str ( ) ) ;
return false ;
return false ;
@ -466,17 +472,17 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
{
{
auto & hparams = model . hparams ;
auto & hparams = model . hparams ;
fin. read ( ( char * ) & hparams . n_vocab , sizeof ( hparams . n_vocab ) ) ;
read_safe( fin , hparams . n_vocab ) ;
fin. read ( ( char * ) & hparams . n_audio_ctx , sizeof ( hparams . n_audio_ctx ) ) ;
read_safe( fin , hparams . n_audio_ctx ) ;
fin. read ( ( char * ) & hparams . n_audio_state , sizeof ( hparams . n_audio_state ) ) ;
read_safe( fin , hparams . n_audio_state ) ;
fin. read ( ( char * ) & hparams . n_audio_head , sizeof ( hparams . n_audio_head ) ) ;
read_safe( fin , hparams . n_audio_head ) ;
fin. read ( ( char * ) & hparams . n_audio_layer , sizeof ( hparams . n_audio_layer ) ) ;
read_safe( fin , hparams . n_audio_layer ) ;
fin. read ( ( char * ) & hparams . n_text_ctx , sizeof ( hparams . n_text_ctx ) ) ;
read_safe( fin , hparams . n_text_ctx ) ;
fin. read ( ( char * ) & hparams . n_text_state , sizeof ( hparams . n_text_state ) ) ;
read_safe( fin , hparams . n_text_state ) ;
fin. read ( ( char * ) & hparams . n_text_head , sizeof ( hparams . n_text_head ) ) ;
read_safe( fin , hparams . n_text_head ) ;
fin. read ( ( char * ) & hparams . n_text_layer , sizeof ( hparams . n_text_layer ) ) ;
read_safe( fin , hparams . n_text_layer ) ;
fin. read ( ( char * ) & hparams . n_mels , sizeof ( hparams . n_mels ) ) ;
read_safe( fin , hparams . n_mels ) ;
fin. read ( ( char * ) & hparams . f16 , sizeof ( hparams . f16 ) ) ;
read_safe( fin , hparams . f16 ) ;
assert ( hparams . n_text_state = = hparams . n_audio_state ) ;
assert ( hparams . n_text_state = = hparams . n_audio_state ) ;
@ -524,8 +530,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
{
{
auto & filters = wctx . model . filters ;
auto & filters = wctx . model . filters ;
fin. read ( ( char * ) & filters . n_mel , sizeof ( filters . n_mel ) ) ;
read_safe( fin , filters . n_mel ) ;
fin. read ( ( char * ) & filters . n_fft , sizeof ( filters . n_fft ) ) ;
read_safe( fin , filters . n_fft ) ;
filters . data . resize ( filters . n_mel * filters . n_fft ) ;
filters . data . resize ( filters . n_mel * filters . n_fft ) ;
fin . read ( ( char * ) filters . data . data ( ) , filters . data . size ( ) * sizeof ( float ) ) ;
fin . read ( ( char * ) filters . data . data ( ) , filters . data . size ( ) * sizeof ( float ) ) ;
@ -534,7 +540,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
// load vocab
// load vocab
{
{
int32_t n_vocab = 0 ;
int32_t n_vocab = 0 ;
fin. read ( ( char * ) & n_vocab , sizeof ( n_vocab ) ) ;
read_safe( fin , n_vocab ) ;
//if (n_vocab != model.hparams.n_vocab) {
//if (n_vocab != model.hparams.n_vocab) {
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
@ -545,10 +551,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
std : : string word ;
std : : string word ;
for ( int i = 0 ; i < n_vocab ; i + + ) {
for ( int i = 0 ; i < n_vocab ; i + + ) {
uint32_t len ;
uint32_t len ;
fin. read ( ( char * ) & len , sizeof ( len ) ) ;
read_safe( fin , len ) ;
word . resize ( len ) ;
std : : vector < char > tmp ( len ) ; // create a buffer
fin . read ( ( char * ) word . data ( ) , len ) ;
fin . read ( & tmp [ 0 ] , tmp . size ( ) ) ; // read to buffer
word . assign ( & tmp [ 0 ] , tmp . size ( ) ) ;
vocab . token_to_id [ word ] = i ;
vocab . token_to_id [ word ] = i ;
vocab . id_to_token [ i ] = word ;
vocab . id_to_token [ i ] = word ;
@ -998,9 +1005,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
int32_t length ;
int32_t length ;
int32_t ftype ;
int32_t ftype ;
fin. read ( reinterpret_cast < char * > ( & n_dims ) , sizeof ( n_dims ) ) ;
read_safe( fin , n_dims ) ;
fin. read ( reinterpret_cast < char * > ( & length ) , sizeof ( length ) ) ;
read_safe( fin , length ) ;
fin. read ( reinterpret_cast < char * > ( & ftype ) , sizeof ( ftype ) ) ;
read_safe( fin , ftype ) ;
if ( fin . eof ( ) ) {
if ( fin . eof ( ) ) {
break ;
break ;
@ -1009,12 +1016,14 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
int32_t nelements = 1 ;
int32_t nelements = 1 ;
int32_t ne [ 3 ] = { 1 , 1 , 1 } ;
int32_t ne [ 3 ] = { 1 , 1 , 1 } ;
for ( int i = 0 ; i < n_dims ; + + i ) {
for ( int i = 0 ; i < n_dims ; + + i ) {
fin. read ( reinterpret_cast < char * > ( & ne [ i ] ) , sizeof ( ne [ i ] ) ) ;
read_safe( fin , ne [ i ] ) ;
nelements * = ne [ i ] ;
nelements * = ne [ i ] ;
}
}
std : : string name ( length , 0 ) ;
std : : string name ;
fin . read ( & name [ 0 ] , length ) ;
std : : vector < char > tmp ( length ) ; // create a buffer
fin . read ( & tmp [ 0 ] , tmp . size ( ) ) ; // read to buffer
name . assign ( & tmp [ 0 ] , tmp . size ( ) ) ;
if ( model . tensors . find ( name . data ( ) ) = = model . tensors . end ( ) ) {
if ( model . tensors . find ( name . data ( ) ) = = model . tensors . end ( ) ) {
fprintf ( stderr , " %s: unknown tensor '%s' in model file \n " , __func__ , name . data ( ) ) ;
fprintf ( stderr , " %s: unknown tensor '%s' in model file \n " , __func__ , name . data ( ) ) ;