Training
Framework
For model training we used the following frameworks: Transformers and SpanMarker. SpanMarker is a framework for training span-based models. It is built on top of the Transformers library, internally uses the Trainer API from Transformers, and implements the idea of Packed Levitated Marker for Entity and Relation Extraction (D. Ye et al., 2022). So, generally speaking, it is like AutoModelForTokenClassification from the Transformers library, but with an additional marker-based technique.
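To give a feel for the API, a trained SpanMarker model can be loaded and queried in a couple of lines (a minimal sketch; the checkpoint path is a placeholder, not one of our released models):

from span_marker import SpanMarkerModel

# Load a trained SpanMarker checkpoint (placeholder path) and run inference on raw text.
model = SpanMarkerModel.from_pretrained("models/my-span-marker-checkpoint")
entities = model.predict("Tom is my name.")
print(entities)  # predicted spans with their labels and confidence scores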
Training script
The training script is simple:
import argparse
import json
import os
from typing import Optional
import torch
import gc
from span_marker import SpanMarkerModel
from datasets import load_dataset, DatasetDict
from transformers import TrainingArguments, AutoConfig
from span_marker import Trainer
def load_ner_dataset(dataset_name: str, dataset_config_name: Optional[str] = None) -> DatasetDict:
    if dataset_config_name is None:
        dataset = load_dataset(dataset_name)
    else:
        dataset = load_dataset(dataset_name, dataset_config_name)

    def check_if_empty(l):
        """Check if a token list is empty or contains only empty strings."""
        if isinstance(l, list):
            if len(l) == 0:
                return True
            elif len(l) == 1 and l[0] == "":
                return True
            if all([e == "" or e == "\u200b" or e == '\u200b\u200b' for e in l]):
                return True
            if any([e is None for e in l]):
                return True
        return False

    def remove_empty(record):
        """Remove empty tokens and their NER tags, if any."""
        tokens = record["tokens"]
        ner_tags = record["ner_tags"]
        new_tokens = []
        new_ner_tags = []
        for token, ner_tag in zip(tokens, ner_tags):
            # TODO: improve normalization for Arabic
            token = token.replace("\u200b", "").replace("\ufeff", "").strip()
            if token.strip():
                new_tokens.append(token)
                new_ner_tags.append(ner_tag)
        record["tokens"] = new_tokens
        record["ner_tags"] = new_ner_tags
        return record

    dataset["train"] = dataset["train"].filter(lambda x: not check_if_empty(x["tokens"]))
    dataset["validation"] = dataset["validation"].filter(lambda x: not check_if_empty(x["tokens"]))
    dataset["test"] = dataset["test"].filter(lambda x: not check_if_empty(x["tokens"]))

    dataset["train"] = dataset["train"].map(remove_empty)
    dataset["validation"] = dataset["validation"].map(remove_empty)
    dataset["test"] = dataset["test"].map(remove_empty)
    return dataset
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="xlm-roberta-base")
    parser.add_argument("--dataset_name", type=str)
    parser.add_argument("--dataset_config_name", type=str, default=None)
    parser.add_argument("--entity_max_length", type=int, default=150)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
    parser.add_argument("--per_device_train_batch_size", type=int, default=8)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=8)
    parser.add_argument("--evaluation_strategy", type=str, default="steps")
    parser.add_argument("--save_strategy", type=str, default="steps")
    parser.add_argument("--save_steps", type=int, default=1000)
    parser.add_argument("--eval_steps", type=int, default=1000)
    parser.add_argument("--save_total_limit", type=int, default=3)
    args = parser.parse_args()

    # Free any cached GPU memory before training.
    torch.cuda.empty_cache()
    gc.collect()

    # Use the dataset config name (if given) or the dataset name as the output prefix.
    prefix = args.dataset_config_name if args.dataset_config_name else args.dataset_name
    dataset = load_ner_dataset(args.dataset_name, args.dataset_config_name)
    labels = dataset["train"].features["ner_tags"].feature.names

    encoder_id = args.model_name
    config = AutoConfig.from_pretrained(encoder_id)
    model = SpanMarkerModel.from_pretrained(
        encoder_id,
        labels=labels,
        model_max_length=config.max_position_embeddings,
        entity_max_length=args.entity_max_length,
    )

    os.makedirs(f"models/{prefix}", exist_ok=True)
    training_args = TrainingArguments(
        output_dir=f"models/{prefix}",
        learning_rate=args.lr,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        num_train_epochs=args.epochs,
        evaluation_strategy=args.evaluation_strategy,
        save_strategy=args.save_strategy,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        push_to_hub=False,
        logging_steps=50,
        fp16=True,
        warmup_ratio=args.warmup_ratio,
        load_best_model_at_end=True,
        save_total_limit=args.save_total_limit,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
    )
    trainer.train()
    trainer.save_model(f"models/{prefix}")

    metrics = trainer.evaluate()
    with open(f"models/{prefix}/metrics.json", "w") as f:
        f.write(json.dumps(metrics, indent=2))
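Assuming the script above is saved as train.py (the filename and the dataset name below are only an illustration, not a command we necessarily used), a training run looks like:

python train.py --model_name xlm-roberta-base --dataset_name conll2003 --epochs 3 --lr 1e-5

The script expects a dataset with train, validation and test splits and a ner_tags feature; the trained model and its evaluation metrics are written to models/<dataset or config name>.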
SpanMarker details
Explanation
Consider the following example input sentence: "Tom is my name.", which tokenizes to the following seven tokens using the standard RoBERTa tokenizer (including the special start and end tokens): <s>, Tom, is, my, name, ., </s>.
In addition to these tokens, we also have position IDs, which tell the RoBERTa encoder where in the text each of these tokens exists. In the above example, the position IDs are:
$$
\begin{matrix}
[2 & 3 & 4 & 5 & 6 & 7 & 8]
\end{matrix}
$$
For this example, we consider a maximum token length of 16 (note: this is unreasonably low; 256 or 512 would be more sensible in real scenarios). The SpanMarker codebase pads using 0's, so the padded tokens (input_ids) are:
$$
\begin{matrix}
[0 & 1560 & 16 & 127 & 766 & 4 & 2 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0]
\end{matrix}
$$
And the position IDs now become:
$$
\begin{matrix}
[2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10 & 11 & 12 & 13 & 14 & 15 & 16 & 17]
\end{matrix}
$$
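The tokens and input IDs above can be sanity-checked directly with the tokenizer (a rough sketch, separate from the training script; exact IDs depend on the checkpoint vocabulary, and note that SpanMarker pads with 0 rather than with the tokenizer's own pad token):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
encoding = tokenizer("Tom is my name.")
print(tokenizer.convert_ids_to_tokens(encoding["input_ids"]))  # the 7 tokens, including <s> and </s>
print(encoding["input_ids"])                                   # the unpadded input IDs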
Markers
Crucially, the model takes advantage of "markers". These are special tokens that mark the start and the end of a span within the input sentence. Every single legal span within the input sentence corresponds with exactly one set of two markers.
In this example, we recognize that the text consists of 5 words, including the dot. If we assume that an entity span is somewhere between 1 and 8 words, then this example has these 15 legal spans, indexed from 0 onwards:
$$
\begin{matrix}
(0, 0) & (0, 1) & (0, 2) & (0, 3) & (0, 4) & (1, 1) & (1, 2) & (1, 3) \\
(1, 4) & (2, 2) & (2, 3) & (2, 4) & (3, 3) & (3, 4) & (4, 4)
\end{matrix}
$$
If one of the entities that we are interested in is "Person", then the (0, 0) span has label "Person", while all other spans have label "NIL" (or "O" for "outside"). In total, we have 15 spans, each consisting of a start and an end index.
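The span enumeration itself is easy to reproduce (an illustration of the indexing scheme, not SpanMarker's internal code):

words = ["Tom", "is", "my", "name", "."]
max_span_length = 8  # assumed maximum entity length, in words

# Every (start, end) pair with an inclusive end index and at most
# max_span_length words forms one legal span.
spans = [
    (start, end)
    for start in range(len(words))
    for end in range(start, min(start + max_span_length, len(words)))
]
print(len(spans))  # 15
print(spans)       # (0, 0), (0, 1), ..., (4, 4)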
SpanMarker uses the special tokens "<start>" with ID 50261 and "<end>" with ID 50262 to represent the starts and ends of spans, respectively. These marker tokens are appended to the padded text tokens (input_ids), tripling its size. Then, the position IDs are also updated to virtually position these markers within the text.
Note that the appended position IDs correspond exactly with the 15 legal spans. For example, the (1, 2) span is now represented using the tokens 50261 and 50262 with position IDs 3 and 4, as shown in bold in the following vectors.
This results in the following input ID vector:
$$
\begin{matrix}
[0 & 1560 & 16 & 127 & 766 & 4 & 2 & 0 \\
0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\
50261 & 50261 & 50261 & 50261 & 50261 & 50261 & \textbf{50261} & 50261 \\
50261 & 50261 & 50261 & 50261 & 50261 & 50261 & 50261 & 0 \\
50262 & 50262 & 50262 & 50262 & 50262 & 50262 & \textbf{50262} & 50262 \\
50262 & 50262 & 50262 & 50262 & 50262 & 50262 & 50262 & 0]
\end{matrix}
$$
And this position ID vector:
$$
\begin{matrix}
[2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 \\
10 & 11 & 12 & 13 & 14 & 15 & 16 & 17 \\
2 & 2 & 2 & 2 & 2 & 3 & \textbf{3} & 3 \\
3 & 4 & 4 & 4 & 5 & 5 & 6 & 0 \\
2 & 3 & 4 & 5 & 6 & 3 & \textbf{4} & 5 \\
6 & 4 & 5 & 6 & 5 & 6 & 6 & 0]
\end{matrix}
$$
Whenever the number of legal spans is less than the maximum token length, as in our example, the position IDs are padded with 0 so that the input ID and position ID vectors remain equal in length.
Furthermore, if the number of legal spans is larger than the maximum token length, then we create multiple vectors, each covering a subset of the spans. This is done both during training and during inference.
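A rough sketch of that chunking (illustrative only, not the library's actual batching code):

max_tokens = 16
# Legal spans of a longer, 30-word sentence with the same 8-word span limit.
spans = [(s, e) for s in range(30) for e in range(s, min(s + 8, 30))]

# Split the spans into groups of at most max_tokens; each group gets its own
# input ID / position ID / attention mask triple during training and inference.
chunks = [spans[i:i + max_tokens] for i in range(0, len(spans), max_tokens)]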
Attention
The final piece of this complex puzzle is the attention mask matrix. This matrix gives us control over which tokens can attend to which other tokens. Generally, an attention mask vector is used, but using a matrix, we can specify a one-directional attention. This proves very useful for our situation.
See Figure 10 for the attention matrix used for our example. The block of black in the top left allows for self-attention among the text tokens, which is important for the RoBERTa encoder to correctly encode the text. Note the asymmetry: the sections of black on the left side allow the start and end markers to attend to the text tokens, while the opposite is not possible. Furthermore, the four diagonal sections allow the markers to attend to themselves and to their corresponding complementary marker.
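The structure of that matrix can be sketched as follows (an illustrative reconstruction of the description above, not SpanMarker's actual implementation):

import torch

max_tokens = 16
num_text = 7    # <s> Tom is my name . </s>
num_spans = 15  # legal spans in the running example

# One row/column per position: text tokens, then start markers, then end markers.
mask = torch.zeros(3 * max_tokens, 3 * max_tokens, dtype=torch.bool)

# Text tokens attend to each other (the block in the top-left corner).
mask[:num_text, :num_text] = True

for i in range(num_spans):
    start_row = max_tokens + i    # row of the i-th start marker
    end_row = 2 * max_tokens + i  # row of the i-th end marker
    # Markers attend to the text tokens; the text tokens do not attend back.
    mask[start_row, :num_text] = True
    mask[end_row, :num_text] = True
    # Markers attend to themselves and to their complementary marker
    # (the four diagonal sections).
    mask[start_row, start_row] = True
    mask[start_row, end_row] = True
    mask[end_row, end_row] = True
    mask[end_row, start_row] = True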
Encoder
With the input IDs, position IDs and attention mask matrix prepared, they can be fed through the pretrained RoBERTa encoder. This encoder returns an embedding for each input ID, resulting in an output of shape [3 * max_tokens, embedding_size]. SpanMarker then computes a large embedding by concatenating the following vectors for each of the spans in the sample:
1. The embedding of the start marker.
2. The embedding of the end marker.
This results in a feature vector of shape [max_tokens, 2 * embedding_size]. This feature vector is fed through a linear layer that maps it down from 2 * embedding_size to num_labels. The resulting [max_tokens, num_labels] matrix is then used in a cross-entropy loss, comparing it against the labels for each of the spans in the sample.
In our example, the feature vector for the (0, 0) span would then be compared against the gold label "Person", while all other feature vectors are compared against "NIL" or "O", or rather the integer representation of those labels.
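The head described above can be sketched in a few lines of PyTorch (names and shapes are ours for illustration, not SpanMarker's actual modules):

import torch
import torch.nn as nn

max_tokens, embedding_size, num_labels = 16, 768, 2

encoder_output = torch.randn(3 * max_tokens, embedding_size)  # stand-in for the RoBERTa output
start_embeddings = encoder_output[max_tokens:2 * max_tokens]   # embeddings of the start markers
end_embeddings = encoder_output[2 * max_tokens:]               # embeddings of the end markers

# Concatenate start and end marker embeddings per span: [max_tokens, 2 * embedding_size].
features = torch.cat([start_embeddings, end_embeddings], dim=-1)

# Map down to the label space: [max_tokens, num_labels].
classifier = nn.Linear(2 * embedding_size, num_labels)
logits = classifier(features)

# Gold labels per span: everything is "O"/NIL except the (0, 0) span, labelled "Person".
labels = torch.zeros(max_tokens, dtype=torch.long)
labels[0] = 1
loss = nn.functional.cross_entropy(logits, labels)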