1美元训练BERT,教你如何薅谷歌TPU羊毛|附Colab代码

1美元训练BERT,教你如何薅谷歌TPU羊毛|附Colab代码
晓查 发自 凹非寺
量子位 出品 | 公众号 QbitAI





BERT是谷歌去年推出的NLP模型,一经推出就在各项测试中碾压竞争对手,而且BERT是开源的。只可惜训练BERT的价格实在太高,让人望而却步。

之前需要用64个TPU训练4天才能完成,后来谷歌用并行计算优化了到只需一个多小时,但是需要的TPU数量陡增,达到了惊人的1024个。

那么总共要多少钱呢?谷歌云TPU的使用价格是每个每小时6.5美元,训练完成训练完整个模型需要近4万美元,简直就是天价。

现在,有个羊毛告诉你,在培养基上有人找到了薅谷歌羊毛的办法,只需1美元就能训练BERT,模型还能留存在你的谷歌云盘中,留作以后使用。

准备工作

为了薅谷歌的羊毛,您需要一个Google云存储(Google Cloud Storage)空间。按照Google云TPU快速入门指南,创建Google云平台(Google Cloud Platform)帐户和Google云存储账户。新的谷歌云平台用户可获得300美元的免费赠送金额。





在TPUv2上预训练BERT-Base模型大约需要54小时.Google Colab并非设计用于执行长时间运行的作业,它会每8小时左右中断一次训练过程。对于不间断的训练,请考虑使用付费的不间断使用TPUv2的方法

也就是说,使用Colab TPU,你可以在以1美元的价格在谷云盘上存储模型和数据,以几乎可忽略成本从头开始预训练BERT模型。

以下是整个过程的代码下面的代码,可以在Colab Jupyter环境中运行。

设置训练环境

首先,安装训练模型所需的包.Jupyter允许使用直接从笔记本执行的bash命令 '!':

!pip install sentencepiece!git clone https://github.com/google-research/bert

导入包并在谷歌云中授权:

import osimport sysimport jsonimport nltkimport randomimport loggingimport tensorflow as tfimport sentencepiece as spmfrom glob import globfrom google.colab import auth, drivefrom tensorflow.keras.utils import Progbarsys.path.append("bert")from bert import modeling, optimization, tokenizationfrom bert.run_pretraining import input_fn_builder, model_fn_builderauth.authenticate_user()# configure logginglog = logging.getLogger('tensorflow')log.setLevel(logging.INFO)# create formatter and add it to the handlersformatter = logging.Formatter('%(asctime)s : %(message)s')sh = logging.StreamHandler()sh.setLevel(logging.INFO)sh.setFormatter(formatter)log.handlers = [sh]if 'COLAB_TPU_ADDR' in os.environ: log.info("Using TPU runtime") USE_TPU = True TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR'] with tf.Session(TPU_ADDRESS) as session: log.info('TPU address is ' + TPU_ADDRESS) # Upload credentials to TPU. with open('/content/adc.json', 'r') as f: auth_info = json.load(f) tf.contrib.cloud.configure_gcs(session, credentials=auth_info)else: log.warning('Not connected to TPU runtime') USE_TPU = False

下载原始文本数据

接下来从网络上获取文本数据语料库。在本次实验中,我们使用OpenSubtitles数据集,该数据集包括65种语言

与更常用的文本数据集(如维基百科)不同,它不需要任何复杂的预处理,提供预格式化,一行一个句子。

AVAILABLE = {'af','ar','bg','bn','br','bs','ca','cs', 'da','de','el','en','eo','es','et','eu', 'fa','fi','fr','gl','he','hi','hr','hu', 'hy','id','is','it','ja','ka','kk','ko', 'lt','lv','mk','ml','ms','nl','no','pl', 'pt','pt_br','ro','ru','si','sk','sl','sq', 'sr','sv','ta','te','th','tl','tr','uk', 'ur','vi','ze_en','ze_zh','zh','zh_cn', 'zh_en','zh_tw','zh_zh'}LANG_CODE = "en" #@param {type:"string"}assert LANG_CODE in AVAILABLE, "Invalid language code selected"!wget http://opus.nlpl.eu/download.php?f=OpenSubtitles/v2016/mono/OpenSubtitles.raw.'$LANG_CODE'.gz -O dataset.txt.gz!gzip -d dataset.txt.gz!tail dataset.txt

你可以通过设置代码随意选择你需要的语言。出于演示目的,代码只默认使用整个语料库的一小部分。在实际训练模型时,请务必取消选中DEMO_MODE复选框,使用大100倍的数据集。

当然,100M数据足以训练出相当不错的BERT基础模型。

DEMO_MODE = True #@param {type:"boolean"}if DEMO_MODE: CORPUS_SIZE = 1000000else: CORPUS_SIZE = 100000000 #@param {type: "integer"}!(head -n $CORPUS_SIZE dataset.txt) > subdataset.txt!mv subdataset.txt dataset.txt

预处理文本数据

我们下载的原始文本数据包含标点符号,大写字母和非UTF符号,我们将在继续下一步之前将其删除。在推理期间,我们将对新数据应用相同的过程。

如果你需要不同的预处理方式(例如在推理期间预期会出现大写字母或标点符号),请修改以下代码以满足你的需求。

regex_tokenizer = nltk.RegexpTokenizer("\w+")def normalize_text(text): # lowercase text text = str(text).lower() # remove non-UTF text = text.encode("utf-8", "ignore").decode() # remove punktuation symbols text = " ".join(regex_tokenizer.tokenize(text)) return textdef count_lines(filename): count = 0 with open(filename) as fi: for line in fi: count += 1 return count

现在让我们预处理整个数据集:

RAW_DATA_FPATH = "dataset.txt" #@param {type: "string"}PRC_DATA_FPATH = "proc_dataset.txt" #@param {type: "string"}# apply normalization to the dataset# this will take a minute or twototal_lines = count_lines(RAW_DATA_FPATH)bar = Progbar(total_lines)with open(RAW_DATA_FPATH,encoding="utf-8") as fi: with open(PRC_DATA_FPATH, "w",encoding="utf-8") as fo: for l in fi: fo.write(normalize_text(l)+"\n") bar.add(1)

构建词汇表

下一步,我们将训练模型学习一个新的词汇表,用于表示我们的数据集。

BERT文件使用WordPiece分词器,在开源中不可用。我们将在单字模式下使用SentencePiece分词器。虽然它与BERT不直接兼容,但是通过一个小的处理方法,可以使它工作。

SentencePiece需要相当多的运行内存,因此在Colab中的运行完整数据集会导致内核崩溃。

为避免这种情况,我们将随机对数据集的一小部分进行子采样,构建词汇表。另一个选择是使用更大内存的机器来执行此步骤

此外,SentencePiece默认情况下将BOS和EOS控制符号添加到词汇表中。我们通过将其索引设置为-1来禁用它们。

VOC_SIZE的典型值介于32000和128000之间。如果想要更新词汇表,并在预训练阶段结束后对模型进行微调,我们会保留NUM_PLACEHOLDERS个令牌。

MODEL_PREFIX = "tokenizer" #@param {type: "string"}VOC_SIZE = 32000 #@param {type:"integer"}SUBSAMPLE_SIZE = 12800000 #@param {type:"integer"}NUM_PLACEHOLDERS = 256 #@param {type:"integer"}SPM_COMMAND = ('--input={} --model_prefix={} ' '--vocab_size={} --input_sentence_size={} ' '--shuffle_input_sentence=true '  '--bos_id=-1 --eos_id=-1').format( PRC_DATA_FPATH, MODEL_PREFIX,  VOC_SIZE - NUM_PLACEHOLDERS, SUBSAMPLE_SIZE)spm.SentencePieceTrainer.Train(SPM_COMMAND)

现在,让我们看看如何让SentencePiece在BERT模型上工作。

下面是使用来自官方的预训练英语BERT基础模型的WordPiece词汇表标记的语句。

>>> wordpiece.tokenize("Colorless geothermal substations are generating furiously")['color', '##less', 'geo', '##thermal', 'sub', '##station', '##s', 'are', 'generating', 'furiously']

WordPiece标记器在“##”的单词中间预置了出现的子字。在单词开头出现的子词不变。如果子词出现在单词的开头和中间,则两个版本(带和不带” ##')都会添加到词汇表中。

SentencePiece创建了两个文件:tokenizer.model和tokenizer.vocab让我们来看看它学到的词汇:

def read_sentencepiece_vocab(filepath): voc = [] with open(filepath, encoding='utf-8') as fi: for line in fi: voc.append(line.split("\t")[0]) # skip the first <unk> token voc = voc[1:] return vocsnt_vocab = read_sentencepiece_vocab("{}.vocab".format(MODEL_PREFIX))print("Learnt vocab size: {}".format(len(snt_vocab)))print("Sample tokens: {}".format(random.sample(snt_vocab, 10)))

运行结果:

Learnt vocab size: 31743 Sample tokens: ['▁cafe', '▁slippery', 'xious', '▁resonate', '▁terrier', '▁feat', '▁frequencies', 'ainty', '▁punning', 'modern']

SentencePiece与WordPiece的运行结果完全相反从文档中可以看出:SentencePiece首先使用元符号“_”将空格转义为空格,如下所示:


Hello_World。

然后文本被分段为小块:


[Hello] [_Wor] [ld] [.]

在空格之后出现的子词(也是大多数词开头的子词)前面加上“_”,而其他子词不变。这排除了仅出现在句子开头而不是其他地方的子词。然而,这些案件应该非常罕见。

因此,为了获得类似于WordPiece的词汇表,我们需要执行一个简单的转换,从包含它的标记中删除“_”,并将“##”添加到不包含它的标记中。

我们还添加了一些BERT架构所需的特殊控制符号。按照惯例,我们把它们放在词汇的开头。

另外,我们在词汇表中添加了一些占位符标记。

如果你希望使用新的用于特定任务的令牌来更新预先训练的模型,那么这些方法是很有用的。

在这种情况下,占位符标记被替换为新的令牌,重新生成预训练数据,并且对新数据进行微调。

def parse_sentencepiece_token(token): if token.startswith("▁"): return token[1:] else: return "##" + tokenbert_vocab = list(map(parse_sentencepiece_token, snt_vocab))ctrl_symbols = ["[PAD]","[UNK]","[CLS]","[SEP]","[MASK]"]bert_vocab = ctrl_symbols + bert_vocabbert_vocab += ["[UNUSED_{}]".format(i) for i in range(VOC_SIZE - len(bert_vocab))]print(len(bert_vocab))

最后,我们将获得的词汇表写入文件。

VOC_FNAME = "vocab.txt" #@param {type:"string"}with open(VOC_FNAME, "w") as fo: for token in bert_vocab: fo.write(token+"\n")

现在,让我们看看新词汇在实践中是如何运作的:

>>> testcase = "Colorless geothermal substations are generating furiously">>> bert_tokenizer = tokenization.FullTokenizer(VOC_FNAME)>>> bert_tokenizer.tokenize(testcase)['color',  '##less',  'geo',  '##ther',  '##mal',  'sub',  '##station',  '##s',  'are',  'generat',  '##ing',  'furious',  '##ly']

创建分片预训练数据(生成预训练数据)

通过手头的词汇表,我们可以为BERT模型生成预训练数据。

由于我们的数据集可能非常大,我们将其拆分为碎片:

mkdir ./shardssplit -a 4 -l 256000 -d $PRC_DATA_FPATH ./shards/shard_

现在,对于每个部分,我们需要从BERT仓库调用create_pretraining_data.py脚本,需要使用xargs的命令。

在开始生成之前,我们需要设置一些参数传递给脚本。你可以从自述文件中找到有关它们含义的更多信息。

MAX_SEQ_LENGTH = 128 #@param {type:"integer"}MASKED_LM_PROB = 0.15 #@paramMAX_PREDICTIONS = 20 #@param {type:"integer"}DO_LOWER_CASE = True #@param {type:"boolean"}PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}# controls how many parallel processes xargs can createPROCESSES = 2 #@param {type:"integer"}

运行此操作可能需要相当长的时间,具体取决于数据集的大小。

XARGS_CMD = ("ls ./shards/ | " "xargs -n 1 -P {} -I{} " "python3 bert/create_pretraining_data.py " "--input_file=./shards/{} " "--output_file={}/{}.tfrecord " "--vocab_file={} " "--do_lower_case={} " "--max_predictions_per_seq={} " "--max_seq_length={} " "--masked_lm_prob={} " "--random_seed=34 " "--dupe_factor=5")XARGS_CMD = XARGS_CMD.format(PROCESSES, '{}', '{}', PRETRAINING_DIR, '{}',  VOC_FNAME, DO_LOWER_CASE,  MAX_PREDICTIONS, MAX_SEQ_LENGTH, MASKED_LM_PROB)tf.gfile.MkDir(PRETRAINING_DIR)!$XARGS_CMD

为数据和模型设置GCS存储,将数据和模型存储到云端

为了保留来之不易的训练模型,我们会将其保留在谷歌云存储中。

在谷歌云存储中创建两个目录,一个用于数据,一个用于模型。在模型目录中,我们将放置模型词汇表和配置文件。

在继续操作之前,请配置BUCKET_NAME变量,否则将无法训练模型。

BUCKET_NAME = "bert_resourses" #@param {type:"string"}MODEL_DIR = "bert_model" #@param {type:"string"}tf.gfile.MkDir(MODEL_DIR)if not BUCKET_NAME: log.warning("WARNING: BUCKET_NAME is not set. " "You will not be able to train the model.")

下面是BERT基的超参数配置示例:

# use this for BERT-basebert_base_config = { "attention_probs_dropout_prob": 0.1,  "directionality": "bidi",  "hidden_act": "gelu",  "hidden_dropout_prob": 0.1,  "hidden_size": 768,  "initializer_range": 0.02,  "intermediate_size": 3072,  "max_position_embeddings": 512,  "num_attention_heads": 12,  "num_hidden_layers": 12,  "pooler_fc_size": 768,  "pooler_num_attention_heads": 12,  "pooler_num_fc_layers": 3,  "pooler_size_per_head": 128,  "pooler_type": "first_token_transform",  "type_vocab_size": 2,  "vocab_size": VOC_SIZE}with open("{}/bert_config.json".format(MODEL_DIR), "w") as fo: json.dump(bert_base_config, fo, indent=2)with open("{}/{}".format(MODEL_DIR, VOC_FNAME), "w") as fo: for token in bert_vocab: fo.write(token+"\n")

现在,我们已准备好将模型和数据存储到谷歌云当中:

if BUCKET_NAME: !gsutil -m cp -r $MODEL_DIR $PRETRAINING_DIR gs://$BUCKET_NAME

在云TPU上训练模型

注意,之前步骤中的某些参数在此处不用改变。请确保在整个实验中设置的参数完全相同。

BUCKET_NAME = "bert_resourses" #@param {type:"string"}MODEL_DIR = "bert_model" #@param {type:"string"}PRETRAINING_DIR = "pretraining_data" #@param {type:"string"}VOC_FNAME = "vocab.txt" #@param {type:"string"}# Input data pipeline configTRAIN_BATCH_SIZE = 128 #@param {type:"integer"}MAX_PREDICTIONS = 20 #@param {type:"integer"}MAX_SEQ_LENGTH = 128 #@param {type:"integer"}MASKED_LM_PROB = 0.15 #@param# Training procedure configEVAL_BATCH_SIZE = 64LEARNING_RATE = 2e-5TRAIN_STEPS = 1000000 #@param {type:"integer"}SAVE_CHECKPOINTS_STEPS = 2500 #@param {type:"integer"}NUM_TPU_CORES = 8if BUCKET_NAME: BUCKET_PATH = "gs://{}".format(BUCKET_NAME)else: BUCKET_PATH = "."BERT_GCS_DIR = "{}/{}".format(BUCKET_PATH, MODEL_DIR)DATA_GCS_DIR = "{}/{}".format(BUCKET_PATH, PRETRAINING_DIR)VOCAB_FILE = os.path.join(BERT_GCS_DIR, VOC_FNAME)CONFIG_FILE = os.path.join(BERT_GCS_DIR, "bert_config.json")INIT_CHECKPOINT = tf.train.latest_checkpoint(BERT_GCS_DIR)bert_config = modeling.BertConfig.from_json_file(CONFIG_FILE)input_files = tf.gfile.Glob(os.path.join(DATA_GCS_DIR,'*tfrecord'))log.info("Using checkpoint: {}".format(INIT_CHECKPOINT))log.info("Using {} data shards".format(len(input_files)))

准备训练运行配置,建立评估器和输入函数启动BERT!

model_fn = model_fn_builder( bert_config=bert_config, init_checkpoint=INIT_CHECKPOINT, learning_rate=LEARNING_RATE, num_train_steps=TRAIN_STEPS, num_warmup_steps=10, use_tpu=USE_TPU, use_one_hot_embeddings=True)tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)run_config = tf.contrib.tpu.RunConfig( cluster=tpu_cluster_resolver, model_dir=BERT_GCS_DIR, save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS, tpu_config=tf.contrib.tpu.TPUConfig( iterations_per_loop=SAVE_CHECKPOINTS_STEPS, num_shards=NUM_TPU_CORES, per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))estimator = tf.contrib.tpu.TPUEstimator( use_tpu=USE_TPU, model_fn=model_fn, config=run_config, train_batch_size=TRAIN_BATCH_SIZE, eval_batch_size=EVAL_BATCH_SIZE)train_input_fn = input_fn_builder( input_files=input_files, max_seq_length=MAX_SEQ_LENGTH, max_predictions_per_seq=MAX_PREDICTIONS, is_training=True)

执行!

estimator.train(input_fn=train_input_fn, max_steps=TRAIN_STEPS)

最后,使用默认参数训练模型需要100万步,约54小时的运行时间。如果内核由于某种原因重新启动,可以从断点处继续训练。

以上就是是在云TPU上从头开始预训练BERT的指南。

下一步

好的,我们已经训练好了模型,接下来可以做什么?

如图1所示,使用预训练的模型作为通用的自然语言理解模块;

2,针对某些特定的分类任务微调模型;

3,使用BERT作为构建块,去创建另一个深度学习模型。

传送门

原文地址

https ://http://towardsdatascience.com/pre-training-bert-from-scratch-with-cloud-tpu-6e2f71028379

Colab代码:

https ://http://colab.research.google.com/drive/1nVn6AFpQSzXBt8_ywfx6XR8ZfQXlKGAz

—完—
量子位 · QbitAI
վ'ᴗ' ի 追踪AI技术和产品新动态
戳右上角「+关注」获取最新资讯↗↗
如果喜欢,请分享or点赞吧~比心❤

免责声明:本网信息来自于互联网,目的在于传递更多信息,并不代表本网赞同其观点。其原创性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容、文字的真实性、完整性、及时性本站不作任何保证或承诺,并请自行核实相关内容。本站不承担此类作品侵权行为的直接责任及连带责任。如若本网有任何内容侵犯您的权益,请及时联系我们,本站将会在24小时内处理完毕。
相关文章
返回顶部