repository_name (string) | func_path_in_repository (string) | func_name (string) | whole_func_string (string) | language (string) | func_code_string (string) | func_code_tokens (sequence) | func_documentation_string (string) | func_documentation_tokens (sequence) | split_name (string) | func_code_url (string) |
---|---|---|---|---|---|---|---|---|---|---|
"ageitgey/face_recognition" | "examples/face_recognition_knn.py" | "train" | "def train(train_dir, model_save_path=None, n_neighbors=None, knn_algo='ball_tree', verbose=False):
"""
Trains a k-nearest neighbors classifier for face recognition.
:param train_dir: directory that contains a sub-directory for each known person, with its name.
(View in source code to see train_dir example tree structure)
Structure:
<train_dir>/
├── <person1>/
│   ├── <somename1>.jpeg
│   ├── <somename2>.jpeg
│   ├── ...
├── <person2>/
│   ├── <somename1>.jpeg
│   └── <somename2>.jpeg
└── ...
:param model_save_path: (optional) path to save model on disk
:param n_neighbors: (optional) number of neighbors to weigh in classification. Chosen automatically if not specified
:param knn_algo: (optional) underlying data structure to support knn. Default is ball_tree
:param verbose: verbosity of training
:return: returns knn classifier that was trained on the given data.
"""
X = []
y = []
# Loop through each person in the training set
for class_dir in os.listdir(train_dir):
if not os.path.isdir(os.path.join(train_dir, class_dir)):
continue
# Loop through each training image for the current person
for img_path in image_files_in_folder(os.path.join(train_dir, class_dir)):
image = face_recognition.load_image_file(img_path)
face_bounding_boxes = face_recognition.face_locations(image)
if len(face_bounding_boxes) != 1:
# If there are no people (or too many people) in a training image, skip the image.
if verbose:
print("Image {} not suitable for training: {}".format(img_path, "Didn't find a face" if len(face_bounding_boxes) < 1 else "Found more than one face"))
else:
# Add face encoding for current image to the training set
X.append(face_recognition.face_encodings(image, known_face_locations=face_bounding_boxes)[0])
y.append(class_dir)
# Determine how many neighbors to use for weighting in the KNN classifier
if n_neighbors is None:
n_neighbors = int(round(math.sqrt(len(X))))
if verbose:
print("Chose n_neighbors automatically:", n_neighbors)
# Create and train the KNN classifier
knn_clf = neighbors.KNeighborsClassifier(n_neighbors=n_neighbors, algorithm=knn_algo, weights='distance')
knn_clf.fit(X, y)
# Save the trained KNN classifier
if model_save_path is not None:
with open(model_save_path, 'wb') as f:
pickle.dump(knn_clf, f)
return knn_clf" | "python" | "def train(train_dir, model_save_path=None, n_neighbors=None, knn_algo='ball_tree', verbose=False):
"""
Trains a k-nearest neighbors classifier for face recognition.
:param train_dir: directory that contains a sub-directory for each known person, with its name.
(View in source code to see train_dir example tree structure)
Structure:
<train_dir>/
├── <person1>/
│   ├── <somename1>.jpeg
│   ├── <somename2>.jpeg
│   ├── ...
├── <person2>/
│   ├── <somename1>.jpeg
│   └── <somename2>.jpeg
└── ...
:param model_save_path: (optional) path to save model on disk
:param n_neighbors: (optional) number of neighbors to weigh in classification. Chosen automatically if not specified
:param knn_algo: (optional) underlying data structure to support knn. Default is ball_tree
:param verbose: verbosity of training
:return: returns knn classifier that was trained on the given data.
"""
X = []
y = []
# Loop through each person in the training set
for class_dir in os.listdir(train_dir):
if not os.path.isdir(os.path.join(train_dir, class_dir)):
continue
# Loop through each training image for the current person
for img_path in image_files_in_folder(os.path.join(train_dir, class_dir)):
image = face_recognition.load_image_file(img_path)
face_bounding_boxes = face_recognition.face_locations(image)
if len(face_bounding_boxes) != 1:
# If there are no people (or too many people) in a training image, skip the image.
if verbose:
print("Image {} not suitable for training: {}".format(img_path, "Didn't find a face" if len(face_bounding_boxes) < 1 else "Found more than one face"))
else:
# Add face encoding for current image to the training set
X.append(face_recognition.face_encodings(image, known_face_locations=face_bounding_boxes)[0])
y.append(class_dir)
# Determine how many neighbors to use for weighting in the KNN classifier
if n_neighbors is None:
n_neighbors = int(round(math.sqrt(len(X))))
if verbose:
print("Chose n_neighbors automatically:", n_neighbors)
# Create and train the KNN classifier
knn_clf = neighbors.KNeighborsClassifier(n_neighbors=n_neighbors, algorithm=knn_algo, weights='distance')
knn_clf.fit(X, y)
# Save the trained KNN classifier
if model_save_path is not None:
with open(model_save_path, 'wb') as f:
pickle.dump(knn_clf, f)
return knn_clf" | [
"def",
"train",
"(",
"train_dir",
",",
"model_save_path",
"=",
"None",
",",
"n_neighbors",
"=",
"None",
",",
"knn_algo",
"=",
"'ball_tree'",
",",
"verbose",
"=",
"False",
")",
":",
"X",
"=",
"[",
"]",
"y",
"=",
"[",
"]",
"# Loop through each person in the training set",
"for",
"class_dir",
"in",
"os",
".",
"listdir",
"(",
"train_dir",
")",
":",
"if",
"not",
"os",
".",
"path",
".",
"isdir",
"(",
"os",
".",
"path",
".",
"join",
"(",
"train_dir",
",",
"class_dir",
")",
")",
":",
"continue",
"# Loop through each training image for the current person",
"for",
"img_path",
"in",
"image_files_in_folder",
"(",
"os",
".",
"path",
".",
"join",
"(",
"train_dir",
",",
"class_dir",
")",
")",
":",
"image",
"=",
"face_recognition",
".",
"load_image_file",
"(",
"img_path",
")",
"face_bounding_boxes",
"=",
"face_recognition",
".",
"face_locations",
"(",
"image",
")",
"if",
"len",
"(",
"face_bounding_boxes",
")",
"!=",
"1",
":",
"# If there are no people (or too many people) in a training image, skip the image.",
"if",
"verbose",
":",
"print",
"(",
"\"Image {} not suitable for training: {}\"",
".",
"format",
"(",
"img_path",
",",
"\"Didn't find a face\"",
"if",
"len",
"(",
"face_bounding_boxes",
")",
"<",
"1",
"else",
"\"Found more than one face\"",
")",
")",
"else",
":",
"# Add face encoding for current image to the training set",
"X",
".",
"append",
"(",
"face_recognition",
".",
"face_encodings",
"(",
"image",
",",
"known_face_locations",
"=",
"face_bounding_boxes",
")",
"[",
"0",
"]",
")",
"y",
".",
"append",
"(",
"class_dir",
")",
"# Determine how many neighbors to use for weighting in the KNN classifier",
"if",
"n_neighbors",
"is",
"None",
":",
"n_neighbors",
"=",
"int",
"(",
"round",
"(",
"math",
".",
"sqrt",
"(",
"len",
"(",
"X",
")",
")",
")",
")",
"if",
"verbose",
":",
"print",
"(",
"\"Chose n_neighbors automatically:\"",
",",
"n_neighbors",
")",
"# Create and train the KNN classifier",
"knn_clf",
"=",
"neighbors",
".",
"KNeighborsClassifier",
"(",
"n_neighbors",
"=",
"n_neighbors",
",",
"algorithm",
"=",
"knn_algo",
",",
"weights",
"=",
"'distance'",
")",
"knn_clf",
".",
"fit",
"(",
"X",
",",
"y",
")",
"# Save the trained KNN classifier",
"if",
"model_save_path",
"is",
"not",
"None",
":",
"with",
"open",
"(",
"model_save_path",
",",
"'wb'",
")",
"as",
"f",
":",
"pickle",
".",
"dump",
"(",
"knn_clf",
",",
"f",
")",
"return",
"knn_clf"
] | "Trains a k-nearest neighbors classifier for face recognition.
:param train_dir: directory that contains a sub-directory for each known person, with its name.
(View in source code to see train_dir example tree structure)
Structure:
<train_dir>/
├── <person1>/
│   ├── <somename1>.jpeg
│   ├── <somename2>.jpeg
│   ├── ...
├── <person2>/
│   ├── <somename1>.jpeg
│   └── <somename2>.jpeg
└── ...
:param model_save_path: (optional) path to save model on disk
:param n_neighbors: (optional) number of neighbors to weigh in classification. Chosen automatically if not specified
:param knn_algo: (optional) underlying data structure to support knn. Default is ball_tree
:param verbose: verbosity of training
:return: returns knn classifier that was trained on the given data." | [
"Trains",
"a",
"k",
"-",
"nearest",
"neighbors",
"classifier",
"for",
"face",
"recognition",
"."
] | "train" | "https://github.com/ageitgey/face_recognition/blob/c96b010c02f15e8eeb0f71308c641179ac1f19bb/examples/face_recognition_knn.py#L46-L108" |
"ageitgey/face_recognition" | "examples/face_recognition_knn.py" | "predict" | "def predict(X_img_path, knn_clf=None, model_path=None, distance_threshold=0.6):
"""
Recognizes faces in given image using a trained KNN classifier
:param X_img_path: path to image to be recognized
:param knn_clf: (optional) a knn classifier object. If not specified, model_path must be specified.
:param model_path: (optional) path to a pickled knn classifier. If not specified, knn_clf must be specified.
:param distance_threshold: (optional) distance threshold for face classification. the larger it is, the more chance
of mis-classifying an unknown person as a known one.
:return: a list of names and face locations for the recognized faces in the image: [(name, bounding box), ...].
For faces of unrecognized persons, the name 'unknown' will be returned.
"""
if not os.path.isfile(X_img_path) or os.path.splitext(X_img_path)[1][1:] not in ALLOWED_EXTENSIONS:
raise Exception("Invalid image path: {}".format(X_img_path))
if knn_clf is None and model_path is None:
raise Exception("Must supply knn classifier either through knn_clf or model_path")
# Load a trained KNN model (if one was passed in)
if knn_clf is None:
with open(model_path, 'rb') as f:
knn_clf = pickle.load(f)
# Load image file and find face locations
X_img = face_recognition.load_image_file(X_img_path)
X_face_locations = face_recognition.face_locations(X_img)
# If no faces are found in the image, return an empty result.
if len(X_face_locations) == 0:
return []
# Find encodings for faces in the test image
faces_encodings = face_recognition.face_encodings(X_img, known_face_locations=X_face_locations)
# Use the KNN model to find the best matches for the test face
closest_distances = knn_clf.kneighbors(faces_encodings, n_neighbors=1)
are_matches = [closest_distances[0][i][0] <= distance_threshold for i in range(len(X_face_locations))]
# Predict classes and remove classifications that aren't within the threshold
return [(pred, loc) if rec else ("unknown", loc) for pred, loc, rec in zip(knn_clf.predict(faces_encodings), X_face_locations, are_matches)]" | "python" | "def predict(X_img_path, knn_clf=None, model_path=None, distance_threshold=0.6):
"""
Recognizes faces in given image using a trained KNN classifier
:param X_img_path: path to image to be recognized
:param knn_clf: (optional) a knn classifier object. If not specified, model_path must be specified.
:param model_path: (optional) path to a pickled knn classifier. If not specified, knn_clf must be specified.
:param distance_threshold: (optional) distance threshold for face classification. the larger it is, the more chance
of mis-classifying an unknown person as a known one.
:return: a list of names and face locations for the recognized faces in the image: [(name, bounding box), ...].
For faces of unrecognized persons, the name 'unknown' will be returned.
"""
if not os.path.isfile(X_img_path) or os.path.splitext(X_img_path)[1][1:] not in ALLOWED_EXTENSIONS:
raise Exception("Invalid image path: {}".format(X_img_path))
if knn_clf is None and model_path is None:
raise Exception("Must supply knn classifier either through knn_clf or model_path")
# Load a trained KNN model (if one was passed in)
if knn_clf is None:
with open(model_path, 'rb') as f:
knn_clf = pickle.load(f)
# Load image file and find face locations
X_img = face_recognition.load_image_file(X_img_path)
X_face_locations = face_recognition.face_locations(X_img)
# If no faces are found in the image, return an empty result.
if len(X_face_locations) == 0:
return []
# Find encodings for faces in the test image
faces_encodings = face_recognition.face_encodings(X_img, known_face_locations=X_face_locations)
# Use the KNN model to find the best matches for the test face
closest_distances = knn_clf.kneighbors(faces_encodings, n_neighbors=1)
are_matches = [closest_distances[0][i][0] <= distance_threshold for i in range(len(X_face_locations))]
# Predict classes and remove classifications that aren't within the threshold
return [(pred, loc) if rec else ("unknown", loc) for pred, loc, rec in zip(knn_clf.predict(faces_encodings), X_face_locations, are_matches)]" | [
"def",
"predict",
"(",
"X_img_path",
",",
"knn_clf",
"=",
"None",
",",
"model_path",
"=",
"None",
",",
"distance_threshold",
"=",
"0.6",
")",
":",
"if",
"not",
"os",
".",
"path",
".",
"isfile",
"(",
"X_img_path",
")",
"or",
"os",
".",
"path",
".",
"splitext",
"(",
"X_img_path",
")",
"[",
"1",
"]",
"[",
"1",
":",
"]",
"not",
"in",
"ALLOWED_EXTENSIONS",
":",
"raise",
"Exception",
"(",
"\"Invalid image path: {}\"",
".",
"format",
"(",
"X_img_path",
")",
")",
"if",
"knn_clf",
"is",
"None",
"and",
"model_path",
"is",
"None",
":",
"raise",
"Exception",
"(",
"\"Must supply knn classifier either thourgh knn_clf or model_path\"",
")",
"# Load a trained KNN model (if one was passed in)",
"if",
"knn_clf",
"is",
"None",
":",
"with",
"open",
"(",
"model_path",
",",
"'rb'",
")",
"as",
"f",
":",
"knn_clf",
"=",
"pickle",
".",
"load",
"(",
"f",
")",
"# Load image file and find face locations",
"X_img",
"=",
"face_recognition",
".",
"load_image_file",
"(",
"X_img_path",
")",
"X_face_locations",
"=",
"face_recognition",
".",
"face_locations",
"(",
"X_img",
")",
"# If no faces are found in the image, return an empty result.",
"if",
"len",
"(",
"X_face_locations",
")",
"==",
"0",
":",
"return",
"[",
"]",
"# Find encodings for faces in the test iamge",
"faces_encodings",
"=",
"face_recognition",
".",
"face_encodings",
"(",
"X_img",
",",
"known_face_locations",
"=",
"X_face_locations",
")",
"# Use the KNN model to find the best matches for the test face",
"closest_distances",
"=",
"knn_clf",
".",
"kneighbors",
"(",
"faces_encodings",
",",
"n_neighbors",
"=",
"1",
")",
"are_matches",
"=",
"[",
"closest_distances",
"[",
"0",
"]",
"[",
"i",
"]",
"[",
"0",
"]",
"<=",
"distance_threshold",
"for",
"i",
"in",
"range",
"(",
"len",
"(",
"X_face_locations",
")",
")",
"]",
"# Predict classes and remove classifications that aren't within the threshold",
"return",
"[",
"(",
"pred",
",",
"loc",
")",
"if",
"rec",
"else",
"(",
"\"unknown\"",
",",
"loc",
")",
"for",
"pred",
",",
"loc",
",",
"rec",
"in",
"zip",
"(",
"knn_clf",
".",
"predict",
"(",
"faces_encodings",
")",
",",
"X_face_locations",
",",
"are_matches",
")",
"]"
] | "Recognizes faces in given image using a trained KNN classifier
:param X_img_path: path to image to be recognized
:param knn_clf: (optional) a knn classifier object. If not specified, model_path must be specified.
:param model_path: (optional) path to a pickled knn classifier. If not specified, knn_clf must be specified.
:param distance_threshold: (optional) distance threshold for face classification. the larger it is, the more chance
of mis-classifying an unknown person as a known one.
:return: a list of names and face locations for the recognized faces in the image: [(name, bounding box), ...].
For faces of unrecognized persons, the name 'unknown' will be returned." | [
"Recognizes",
"faces",
"in",
"given",
"image",
"using",
"a",
"trained",
"KNN",
"classifier"
] | "train" | "https://github.com/ageitgey/face_recognition/blob/c96b010c02f15e8eeb0f71308c641179ac1f19bb/examples/face_recognition_knn.py#L111-L150" |
"ageitgey/face_recognition" | "examples/face_recognition_knn.py" | "show_prediction_labels_on_image" | "def show_prediction_labels_on_image(img_path, predictions):
"""
Shows the face recognition results visually.
:param img_path: path to image to be recognized
:param predictions: results of the predict function
:return:
"""
pil_image = Image.open(img_path).convert("RGB")
draw = ImageDraw.Draw(pil_image)
for name, (top, right, bottom, left) in predictions:
# Draw a box around the face using the Pillow module
draw.rectangle(((left, top), (right, bottom)), outline=(0, 0, 255))
# There's a bug in Pillow where it blows up with non-UTF-8 text
# when using the default bitmap font
name = name.encode("UTF-8")
# Draw a label with a name below the face
text_width, text_height = draw.textsize(name)
draw.rectangle(((left, bottom - text_height - 10), (right, bottom)), fill=(0, 0, 255), outline=(0, 0, 255))
draw.text((left + 6, bottom - text_height - 5), name, fill=(255, 255, 255, 255))
# Remove the drawing library from memory as per the Pillow docs
del draw
# Display the resulting image
pil_image.show()" | "python" | "def show_prediction_labels_on_image(img_path, predictions):
"""
Shows the face recognition results visually.
:param img_path: path to image to be recognized
:param predictions: results of the predict function
:return:
"""
pil_image = Image.open(img_path).convert("RGB")
draw = ImageDraw.Draw(pil_image)
for name, (top, right, bottom, left) in predictions:
# Draw a box around the face using the Pillow module
draw.rectangle(((left, top), (right, bottom)), outline=(0, 0, 255))
# There's a bug in Pillow where it blows up with non-UTF-8 text
# when using the default bitmap font
name = name.encode("UTF-8")
# Draw a label with a name below the face
text_width, text_height = draw.textsize(name)
draw.rectangle(((left, bottom - text_height - 10), (right, bottom)), fill=(0, 0, 255), outline=(0, 0, 255))
draw.text((left + 6, bottom - text_height - 5), name, fill=(255, 255, 255, 255))
# Remove the drawing library from memory as per the Pillow docs
del draw
# Display the resulting image
pil_image.show()" | [
"def",
"show_prediction_labels_on_image",
"(",
"img_path",
",",
"predictions",
")",
":",
"pil_image",
"=",
"Image",
".",
"open",
"(",
"img_path",
")",
".",
"convert",
"(",
"\"RGB\"",
")",
"draw",
"=",
"ImageDraw",
".",
"Draw",
"(",
"pil_image",
")",
"for",
"name",
",",
"(",
"top",
",",
"right",
",",
"bottom",
",",
"left",
")",
"in",
"predictions",
":",
"# Draw a box around the face using the Pillow module",
"draw",
".",
"rectangle",
"(",
"(",
"(",
"left",
",",
"top",
")",
",",
"(",
"right",
",",
"bottom",
")",
")",
",",
"outline",
"=",
"(",
"0",
",",
"0",
",",
"255",
")",
")",
"# There's a bug in Pillow where it blows up with non-UTF-8 text",
"# when using the default bitmap font",
"name",
"=",
"name",
".",
"encode",
"(",
"\"UTF-8\"",
")",
"# Draw a label with a name below the face",
"text_width",
",",
"text_height",
"=",
"draw",
".",
"textsize",
"(",
"name",
")",
"draw",
".",
"rectangle",
"(",
"(",
"(",
"left",
",",
"bottom",
"-",
"text_height",
"-",
"10",
")",
",",
"(",
"right",
",",
"bottom",
")",
")",
",",
"fill",
"=",
"(",
"0",
",",
"0",
",",
"255",
")",
",",
"outline",
"=",
"(",
"0",
",",
"0",
",",
"255",
")",
")",
"draw",
".",
"text",
"(",
"(",
"left",
"+",
"6",
",",
"bottom",
"-",
"text_height",
"-",
"5",
")",
",",
"name",
",",
"fill",
"=",
"(",
"255",
",",
"255",
",",
"255",
",",
"255",
")",
")",
"# Remove the drawing library from memory as per the Pillow docs",
"del",
"draw",
"# Display the resulting image",
"pil_image",
".",
"show",
"(",
")"
] | "Shows the face recognition results visually.
:param img_path: path to image to be recognized
:param predictions: results of the predict function
:return:" | [
"Shows",
"the",
"face",
"recognition",
"results",
"visually",
"."
] | "train" | "https://github.com/ageitgey/face_recognition/blob/c96b010c02f15e8eeb0f71308c641179ac1f19bb/examples/face_recognition_knn.py#L153-L181" |
"ageitgey/face_recognition" | "face_recognition/api.py" | "_rect_to_css" | "def _rect_to_css(rect):
"""
Convert a dlib 'rect' object to a plain tuple in (top, right, bottom, left) order
:param rect: a dlib 'rect' object
:return: a plain tuple representation of the rect in (top, right, bottom, left) order
"""
return rect.top(), rect.right(), rect.bottom(), rect.left()" | "python" | "def _rect_to_css(rect):
"""
Convert a dlib 'rect' object to a plain tuple in (top, right, bottom, left) order
:param rect: a dlib 'rect' object
:return: a plain tuple representation of the rect in (top, right, bottom, left) order
"""
return rect.top(), rect.right(), rect.bottom(), rect.left()" | [
"def",
"_rect_to_css",
"(",
"rect",
")",
":",
"return",
"rect",
".",
"top",
"(",
")",
",",
"rect",
".",
"right",
"(",
")",
",",
"rect",
".",
"bottom",
"(",
")",
",",
"rect",
".",
"left",
"(",
")"
] | "Convert a dlib 'rect' object to a plain tuple in (top, right, bottom, left) order
:param rect: a dlib 'rect' object
:return: a plain tuple representation of the rect in (top, right, bottom, left) order" | [
"Convert",
"a",
"dlib",
"rect",
"object",
"to",
"a",
"plain",
"tuple",
"in",
"(",
"top",
"right",
"bottom",
"left",
")",
"order"
] | "train" | "https://github.com/ageitgey/face_recognition/blob/c96b010c02f15e8eeb0f71308c641179ac1f19bb/face_recognition/api.py#L32-L39" |
"ageitgey/face_recognition" | "face_recognition/api.py" | "_trim_css_to_bounds" | "def _trim_css_to_bounds(css, image_shape):
"""
Make sure a tuple in (top, right, bottom, left) order is within the bounds of the image.
:param css: plain tuple representation of the rect in (top, right, bottom, left) order
:param image_shape: numpy shape of the image array
:return: a trimmed plain tuple representation of the rect in (top, right, bottom, left) order
"""
return max(css[0], 0), min(css[1], image_shape[1]), min(css[2], image_shape[0]), max(css[3], 0)" | "python" | "def _trim_css_to_bounds(css, image_shape):
"""
Make sure a tuple in (top, right, bottom, left) order is within the bounds of the image.
:param css: plain tuple representation of the rect in (top, right, bottom, left) order
:param image_shape: numpy shape of the image array
:return: a trimmed plain tuple representation of the rect in (top, right, bottom, left) order
"""
return max(css[0], 0), min(css[1], image_shape[1]), min(css[2], image_shape[0]), max(css[3], 0)" | [
"def",
"_trim_css_to_bounds",
"(",
"css",
",",
"image_shape",
")",
":",
"return",
"max",
"(",
"css",
"[",
"0",
"]",
",",
"0",
")",
",",
"min",
"(",
"css",
"[",
"1",
"]",
",",
"image_shape",
"[",
"1",
"]",
")",
",",
"min",
"(",
"css",
"[",
"2",
"]",
",",
"image_shape",
"[",
"0",
"]",
")",
",",
"max",
"(",
"css",
"[",
"3",
"]",
",",
"0",
")"
] | "Make sure a tuple in (top, right, bottom, left) order is within the bounds of the image.
:param css: plain tuple representation of the rect in (top, right, bottom, left) order
:param image_shape: numpy shape of the image array
:return: a trimmed plain tuple representation of the rect in (top, right, bottom, left) order" | [
"Make",
"sure",
"a",
"tuple",
"in",
"(",
"top",
"right",
"bottom",
"left",
")",
"order",
"is",
"within",
"the",
"bounds",
"of",
"the",
"image",
"."
] | "train" | "https://github.com/ageitgey/face_recognition/blob/c96b010c02f15e8eeb0f71308c641179ac1f19bb/face_recognition/api.py#L52-L60" |
"ageitgey/face_recognition" | "face_recognition/api.py" | "face_distance" | "def face_distance(face_encodings, face_to_compare):
"""
Given a list of face encodings, compare them to a known face encoding and get a euclidean distance
for each comparison face. The distance tells you how similar the faces are.
:param face_encodings: List of face encodings to compare
:param face_to_compare: A face encoding to compare against
:return: A numpy ndarray with the distance for each face in the same order as the 'face_encodings' array
"""
if len(face_encodings) == 0:
return np.empty((0))
return np.linalg.norm(face_encodings - face_to_compare, axis=1)" | "python" | "def face_distance(face_encodings, face_to_compare):
"""
Given a list of face encodings, compare them to a known face encoding and get a euclidean distance
for each comparison face. The distance tells you how similar the faces are.
:param face_encodings: List of face encodings to compare
:param face_to_compare: A face encoding to compare against
:return: A numpy ndarray with the distance for each face in the same order as the 'face_encodings' array
"""
if len(face_encodings) == 0:
return np.empty((0))
return np.linalg.norm(face_encodings - face_to_compare, axis=1)" | [
"def",
"face_distance",
"(",
"face_encodings",
",",
"face_to_compare",
")",
":",
"if",
"len",
"(",
"face_encodings",
")",
"==",
"0",
":",
"return",
"np",
".",
"empty",
"(",
"(",
"0",
")",
")",
"return",
"np",
".",
"linalg",
".",
"norm",
"(",
"face_encodings",
"-",
"face_to_compare",
",",
"axis",
"=",
"1",
")"
] | "Given a list of face encodings, compare them to a known face encoding and get a euclidean distance
for each comparison face. The distance tells you how similar the faces are.
:param face_encodings: List of face encodings to compare
:param face_to_compare: A face encoding to compare against
:return: A numpy ndarray with the distance for each face in the same order as the 'face_encodings' array" | [
"Given",
"a",
"list",
"of",
"face",
"encodings",
"compare",
"them",
"to",
"a",
"known",
"face",
"encoding",
"and",
"get",
"a",
"euclidean",
"distance",
"for",
"each",
"comparison",
"face",
".",
"The",
"distance",
"tells",
"you",
"how",
"similar",
"the",
"faces",
"are",
"."
] | "train" | "https://github.com/ageitgey/face_recognition/blob/c96b010c02f15e8eeb0f71308c641179ac1f19bb/face_recognition/api.py#L63-L75" |
"ageitgey/face_recognition" | "face_recognition/api.py" | "load_image_file" | "def load_image_file(file, mode='RGB'):
"""
Loads an image file (.jpg, .png, etc) into a numpy array
:param file: image file name or file object to load
:param mode: format to convert the image to. Only 'RGB' (8-bit RGB, 3 channels) and 'L' (black and white) are supported.
:return: image contents as numpy array
"""
im = PIL.Image.open(file)
if mode:
im = im.convert(mode)
return np.array(im)" | "python" | "def load_image_file(file, mode='RGB'):
"""
Loads an image file (.jpg, .png, etc) into a numpy array
:param file: image file name or file object to load
:param mode: format to convert the image to. Only 'RGB' (8-bit RGB, 3 channels) and 'L' (black and white) are supported.
:return: image contents as numpy array
"""
im = PIL.Image.open(file)
if mode:
im = im.convert(mode)
return np.array(im)" | [
"def",
"load_image_file",
"(",
"file",
",",
"mode",
"=",
"'RGB'",
")",
":",
"im",
"=",
"PIL",
".",
"Image",
".",
"open",
"(",
"file",
")",
"if",
"mode",
":",
"im",
"=",
"im",
".",
"convert",
"(",
"mode",
")",
"return",
"np",
".",
"array",
"(",
"im",
")"
] | "Loads an image file (.jpg, .png, etc) into a numpy array
:param file: image file name or file object to load
:param mode: format to convert the image to. Only 'RGB' (8-bit RGB, 3 channels) and 'L' (black and white) are supported.
:return: image contents as numpy array" | [
"Loads",
"an",
"image",
"file",
"(",
".",
"jpg",
".",
"png",
"etc",
")",
"into",
"a",
"numpy",
"array"
] | "train" | "https://github.com/ageitgey/face_recognition/blob/c96b010c02f15e8eeb0f71308c641179ac1f19bb/face_recognition/api.py#L78-L89" |
"ageitgey/face_recognition" | "face_recognition/api.py" | "_raw_face_locations" | "def _raw_face_locations(img, number_of_times_to_upsample=1, model="hog"):
"""
Returns an array of bounding boxes of human faces in an image
:param img: An image (as a numpy array)
:param number_of_times_to_upsample: How many times to upsample the image looking for faces. Higher numbers find smaller faces.
:param model: Which face detection model to use. "hog" is less accurate but faster on CPUs. "cnn" is a more accurate
deep-learning model which is GPU/CUDA accelerated (if available). The default is "hog".
:return: A list of dlib 'rect' objects of found face locations
"""
if model == "cnn":
return cnn_face_detector(img, number_of_times_to_upsample)
else:
return face_detector(img, number_of_times_to_upsample)" | "python" | "def _raw_face_locations(img, number_of_times_to_upsample=1, model="hog"):
"""
Returns an array of bounding boxes of human faces in an image
:param img: An image (as a numpy array)
:param number_of_times_to_upsample: How many times to upsample the image looking for faces. Higher numbers find smaller faces.
:param model: Which face detection model to use. "hog" is less accurate but faster on CPUs. "cnn" is a more accurate
deep-learning model which is GPU/CUDA accelerated (if available). The default is "hog".
:return: A list of dlib 'rect' objects of found face locations
"""
if model == "cnn":
return cnn_face_detector(img, number_of_times_to_upsample)
else:
return face_detector(img, number_of_times_to_upsample)" | [
"def",
"_raw_face_locations",
"(",
"img",
",",
"number_of_times_to_upsample",
"=",
"1",
",",
"model",
"=",
"\"hog\"",
")",
":",
"if",
"model",
"==",
"\"cnn\"",
":",
"return",
"cnn_face_detector",
"(",
"img",
",",
"number_of_times_to_upsample",
")",
"else",
":",
"return",
"face_detector",
"(",
"img",
",",
"number_of_times_to_upsample",
")"
] | "Returns an array of bounding boxes of human faces in an image
:param img: An image (as a numpy array)
:param number_of_times_to_upsample: How many times to upsample the image looking for faces. Higher numbers find smaller faces.
:param model: Which face detection model to use. "hog" is less accurate but faster on CPUs. "cnn" is a more accurate
deep-learning model which is GPU/CUDA accelerated (if available). The default is "hog".
:return: A list of dlib 'rect' objects of found face locations" | [
"Returns",
"an",
"array",
"of",
"bounding",
"boxes",
"of",
"human",
"faces",
"in",
"a",
"image"
] | "train" | "https://github.com/ageitgey/face_recognition/blob/c96b010c02f15e8eeb0f71308c641179ac1f19bb/face_recognition/api.py#L92-L105" |
"ageitgey/face_recognition" | "face_recognition/api.py" | "face_locations" | "def face_locations(img, number_of_times_to_upsample=1, model="hog"):
"""
Returns an array of bounding boxes of human faces in an image
:param img: An image (as a numpy array)
:param number_of_times_to_upsample: How many times to upsample the image looking for faces. Higher numbers find smaller faces.
:param model: Which face detection model to use. "hog" is less accurate but faster on CPUs. "cnn" is a more accurate
deep-learning model which is GPU/CUDA accelerated (if available). The default is "hog".
:return: A list of tuples of found face locations in css (top, right, bottom, left) order
"""
if model == "cnn":
return [_trim_css_to_bounds(_rect_to_css(face.rect), img.shape) for face in _raw_face_locations(img, number_of_times_to_upsample, "cnn")]
else:
return [_trim_css_to_bounds(_rect_to_css(face), img.shape) for face in _raw_face_locations(img, number_of_times_to_upsample, model)]" | "python" | "def face_locations(img, number_of_times_to_upsample=1, model="hog"):
"""
Returns an array of bounding boxes of human faces in an image
:param img: An image (as a numpy array)
:param number_of_times_to_upsample: How many times to upsample the image looking for faces. Higher numbers find smaller faces.
:param model: Which face detection model to use. "hog" is less accurate but faster on CPUs. "cnn" is a more accurate
deep-learning model which is GPU/CUDA accelerated (if available). The default is "hog".
:return: A list of tuples of found face locations in css (top, right, bottom, left) order
"""
if model == "cnn":
return [_trim_css_to_bounds(_rect_to_css(face.rect), img.shape) for face in _raw_face_locations(img, number_of_times_to_upsample, "cnn")]
else:
return [_trim_css_to_bounds(_rect_to_css(face), img.shape) for face in _raw_face_locations(img, number_of_times_to_upsample, model)]" | [
"def",
"face_locations",
"(",
"img",
",",
"number_of_times_to_upsample",
"=",
"1",
",",
"model",
"=",
"\"hog\"",
")",
":",
"if",
"model",
"==",
"\"cnn\"",
":",
"return",
"[",
"_trim_css_to_bounds",
"(",
"_rect_to_css",
"(",
"face",
".",
"rect",
")",
",",
"img",
".",
"shape",
")",
"for",
"face",
"in",
"_raw_face_locations",
"(",
"img",
",",
"number_of_times_to_upsample",
",",
"\"cnn\"",
")",
"]",
"else",
":",
"return",
"[",
"_trim_css_to_bounds",
"(",
"_rect_to_css",
"(",
"face",
")",
",",
"img",
".",
"shape",
")",
"for",
"face",
"in",
"_raw_face_locations",
"(",
"img",
",",
"number_of_times_to_upsample",
",",
"model",
")",
"]"
] | "Returns an array of bounding boxes of human faces in a image
:param img: An image (as a numpy array)
:param number_of_times_to_upsample: How many times to upsample the image looking for faces. Higher numbers find smaller faces.
:param model: Which face detection model to use. "hog" is less accurate but faster on CPUs. "cnn" is a more accurate
deep-learning model which is GPU/CUDA accelerated (if available). The default is "hog".
:return: A list of tuples of found face locations in css (top, right, bottom, left) order" | [
"Returns",
"an",
"array",
"of",
"bounding",
"boxes",
"of",
"human",
"faces",
"in",
"a",
"image"
] | "train" | "https://github.com/ageitgey/face_recognition/blob/c96b010c02f15e8eeb0f71308c641179ac1f19bb/face_recognition/api.py#L108-L121" |
"ageitgey/face_recognition" | "face_recognition/api.py" | "batch_face_locations" | "def batch_face_locations(images, number_of_times_to_upsample=1, batch_size=128):
"""
Returns an 2d array of bounding boxes of human faces in a image using the cnn face detector
If you are using a GPU, this can give you much faster results since the GPU
can process batches of images at once. If you aren't using a GPU, you don't need this function.
:param img: A list of images (each as a numpy array)
:param number_of_times_to_upsample: How many times to upsample the image looking for faces. Higher numbers find smaller faces.
:param batch_size: How many images to include in each GPU processing batch.
:return: A list of tuples of found face locations in css (top, right, bottom, left) order
"""
def convert_cnn_detections_to_css(detections):
return [_trim_css_to_bounds(_rect_to_css(face.rect), images[0].shape) for face in detections]
raw_detections_batched = _raw_face_locations_batched(images, number_of_times_to_upsample, batch_size)
return list(map(convert_cnn_detections_to_css, raw_detections_batched))" | "python" | "def batch_face_locations(images, number_of_times_to_upsample=1, batch_size=128):
"""
Returns an 2d array of bounding boxes of human faces in a image using the cnn face detector
If you are using a GPU, this can give you much faster results since the GPU
can process batches of images at once. If you aren't using a GPU, you don't need this function.
:param img: A list of images (each as a numpy array)
:param number_of_times_to_upsample: How many times to upsample the image looking for faces. Higher numbers find smaller faces.
:param batch_size: How many images to include in each GPU processing batch.
:return: A list of tuples of found face locations in css (top, right, bottom, left) order
"""
def convert_cnn_detections_to_css(detections):
return [_trim_css_to_bounds(_rect_to_css(face.rect), images[0].shape) for face in detections]
raw_detections_batched = _raw_face_locations_batched(images, number_of_times_to_upsample, batch_size)
return list(map(convert_cnn_detections_to_css, raw_detections_batched))" | [
"def",
"batch_face_locations",
"(",
"images",
",",
"number_of_times_to_upsample",
"=",
"1",
",",
"batch_size",
"=",
"128",
")",
":",
"def",
"convert_cnn_detections_to_css",
"(",
"detections",
")",
":",
"return",
"[",
"_trim_css_to_bounds",
"(",
"_rect_to_css",
"(",
"face",
".",
"rect",
")",
",",
"images",
"[",
"0",
"]",
".",
"shape",
")",
"for",
"face",
"in",
"detections",
"]",
"raw_detections_batched",
"=",
"_raw_face_locations_batched",
"(",
"images",
",",
"number_of_times_to_upsample",
",",
"batch_size",
")",
"return",
"list",
"(",
"map",
"(",
"convert_cnn_detections_to_css",
",",
"raw_detections_batched",
")",
")"
] | "Returns an 2d array of bounding boxes of human faces in a image using the cnn face detector
If you are using a GPU, this can give you much faster results since the GPU
can process batches of images at once. If you aren't using a GPU, you don't need this function.
:param img: A list of images (each as a numpy array)
:param number_of_times_to_upsample: How many times to upsample the image looking for faces. Higher numbers find smaller faces.
:param batch_size: How many images to include in each GPU processing batch.
:return: A list of tuples of found face locations in css (top, right, bottom, left) order" | [
"Returns",
"an",
"2d",
"array",
"of",
"bounding",
"boxes",
"of",
"human",
"faces",
"in",
"a",
"image",
"using",
"the",
"cnn",
"face",
"detector",
"If",
"you",
"are",
"using",
"a",
"GPU",
"this",
"can",
"give",
"you",
"much",
"faster",
"results",
"since",
"the",
"GPU",
"can",
"process",
"batches",
"of",
"images",
"at",
"once",
".",
"If",
"you",
"aren",
"t",
"using",
"a",
"GPU",
"you",
"don",
"t",
"need",
"this",
"function",
"."
] | "train" | "https://github.com/ageitgey/face_recognition/blob/c96b010c02f15e8eeb0f71308c641179ac1f19bb/face_recognition/api.py#L135-L151" |
"ageitgey/face_recognition" | "face_recognition/api.py" | "face_landmarks" | "def face_landmarks(face_image, face_locations=None, model="large"):
"""
Given an image, returns a dict of face feature locations (eyes, nose, etc) for each face in the image
:param face_image: image to search
:param face_locations: Optionally provide a list of face locations to check.
:param model: Optional - which model to use. "large" (default) or "small" which only returns 5 points but is faster.
:return: A list of dicts of face feature locations (eyes, nose, etc)
"""
landmarks = _raw_face_landmarks(face_image, face_locations, model)
landmarks_as_tuples = [[(p.x, p.y) for p in landmark.parts()] for landmark in landmarks]
# For a definition of each point index, see https://cdn-images-1.medium.com/max/1600/1*AbEg31EgkbXSQehuNJBlWg.png
if model == 'large':
return [{
"chin": points[0:17],
"left_eyebrow": points[17:22],
"right_eyebrow": points[22:27],
"nose_bridge": points[27:31],
"nose_tip": points[31:36],
"left_eye": points[36:42],
"right_eye": points[42:48],
"top_lip": points[48:55] + [points[64]] + [points[63]] + [points[62]] + [points[61]] + [points[60]],
"bottom_lip": points[54:60] + [points[48]] + [points[60]] + [points[67]] + [points[66]] + [points[65]] + [points[64]]
} for points in landmarks_as_tuples]
elif model == 'small':
return [{
"nose_tip": [points[4]],
"left_eye": points[2:4],
"right_eye": points[0:2],
} for points in landmarks_as_tuples]
else:
raise ValueError("Invalid landmarks model type. Supported models are ['small', 'large'].")" | "python" | "def face_landmarks(face_image, face_locations=None, model="large"):
"""
Given an image, returns a dict of face feature locations (eyes, nose, etc) for each face in the image
:param face_image: image to search
:param face_locations: Optionally provide a list of face locations to check.
:param model: Optional - which model to use. "large" (default) or "small" which only returns 5 points but is faster.
:return: A list of dicts of face feature locations (eyes, nose, etc)
"""
landmarks = _raw_face_landmarks(face_image, face_locations, model)
landmarks_as_tuples = [[(p.x, p.y) for p in landmark.parts()] for landmark in landmarks]
# For a definition of each point index, see https://cdn-images-1.medium.com/max/1600/1*AbEg31EgkbXSQehuNJBlWg.png
if model == 'large':
return [{
"chin": points[0:17],
"left_eyebrow": points[17:22],
"right_eyebrow": points[22:27],
"nose_bridge": points[27:31],
"nose_tip": points[31:36],
"left_eye": points[36:42],
"right_eye": points[42:48],
"top_lip": points[48:55] + [points[64]] + [points[63]] + [points[62]] + [points[61]] + [points[60]],
"bottom_lip": points[54:60] + [points[48]] + [points[60]] + [points[67]] + [points[66]] + [points[65]] + [points[64]]
} for points in landmarks_as_tuples]
elif model == 'small':
return [{
"nose_tip": [points[4]],
"left_eye": points[2:4],
"right_eye": points[0:2],
} for points in landmarks_as_tuples]
else:
raise ValueError("Invalid landmarks model type. Supported models are ['small', 'large'].")" | [
"def",
"face_landmarks",
"(",
"face_image",
",",
"face_locations",
"=",
"None",
",",
"model",
"=",
"\"large\"",
")",
":",
"landmarks",
"=",
"_raw_face_landmarks",
"(",
"face_image",
",",
"face_locations",
",",
"model",
")",
"landmarks_as_tuples",
"=",
"[",
"[",
"(",
"p",
".",
"x",
",",
"p",
".",
"y",
")",
"for",
"p",
"in",
"landmark",
".",
"parts",
"(",
")",
"]",
"for",
"landmark",
"in",
"landmarks",
"]",
"# For a definition of each point index, see https://cdn-images-1.medium.com/max/1600/1*AbEg31EgkbXSQehuNJBlWg.png",
"if",
"model",
"==",
"'large'",
":",
"return",
"[",
"{",
"\"chin\"",
":",
"points",
"[",
"0",
":",
"17",
"]",
",",
"\"left_eyebrow\"",
":",
"points",
"[",
"17",
":",
"22",
"]",
",",
"\"right_eyebrow\"",
":",
"points",
"[",
"22",
":",
"27",
"]",
",",
"\"nose_bridge\"",
":",
"points",
"[",
"27",
":",
"31",
"]",
",",
"\"nose_tip\"",
":",
"points",
"[",
"31",
":",
"36",
"]",
",",
"\"left_eye\"",
":",
"points",
"[",
"36",
":",
"42",
"]",
",",
"\"right_eye\"",
":",
"points",
"[",
"42",
":",
"48",
"]",
",",
"\"top_lip\"",
":",
"points",
"[",
"48",
":",
"55",
"]",
"+",
"[",
"points",
"[",
"64",
"]",
"]",
"+",
"[",
"points",
"[",
"63",
"]",
"]",
"+",
"[",
"points",
"[",
"62",
"]",
"]",
"+",
"[",
"points",
"[",
"61",
"]",
"]",
"+",
"[",
"points",
"[",
"60",
"]",
"]",
",",
"\"bottom_lip\"",
":",
"points",
"[",
"54",
":",
"60",
"]",
"+",
"[",
"points",
"[",
"48",
"]",
"]",
"+",
"[",
"points",
"[",
"60",
"]",
"]",
"+",
"[",
"points",
"[",
"67",
"]",
"]",
"+",
"[",
"points",
"[",
"66",
"]",
"]",
"+",
"[",
"points",
"[",
"65",
"]",
"]",
"+",
"[",
"points",
"[",
"64",
"]",
"]",
"}",
"for",
"points",
"in",
"landmarks_as_tuples",
"]",
"elif",
"model",
"==",
"'small'",
":",
"return",
"[",
"{",
"\"nose_tip\"",
":",
"[",
"points",
"[",
"4",
"]",
"]",
",",
"\"left_eye\"",
":",
"points",
"[",
"2",
":",
"4",
"]",
",",
"\"right_eye\"",
":",
"points",
"[",
"0",
":",
"2",
"]",
",",
"}",
"for",
"points",
"in",
"landmarks_as_tuples",
"]",
"else",
":",
"raise",
"ValueError",
"(",
"\"Invalid landmarks model type. Supported models are ['small', 'large'].\"",
")"
] | "Given an image, returns a dict of face feature locations (eyes, nose, etc) for each face in the image
:param face_image: image to search
:param face_locations: Optionally provide a list of face locations to check.
:param model: Optional - which model to use. "large" (default) or "small" which only returns 5 points but is faster.
:return: A list of dicts of face feature locations (eyes, nose, etc)" | [
"Given",
"an",
"image",
"returns",
"a",
"dict",
"of",
"face",
"feature",
"locations",
"(",
"eyes",
"nose",
"etc",
")",
"for",
"each",
"face",
"in",
"the",
"image"
] | "train" | "https://github.com/ageitgey/face_recognition/blob/c96b010c02f15e8eeb0f71308c641179ac1f19bb/face_recognition/api.py#L168-L200" |
"ageitgey/face_recognition" | "face_recognition/api.py" | "face_encodings" | "def face_encodings(face_image, known_face_locations=None, num_jitters=1):
"""
Given an image, return the 128-dimension face encoding for each face in the image.
:param face_image: The image that contains one or more faces
:param known_face_locations: Optional - the bounding boxes of each face if you already know them.
:param num_jitters: How many times to re-sample the face when calculating encoding. Higher is more accurate, but slower (i.e. 100 is 100x slower)
:return: A list of 128-dimensional face encodings (one for each face in the image)
"""
raw_landmarks = _raw_face_landmarks(face_image, known_face_locations, model="small")
return [np.array(face_encoder.compute_face_descriptor(face_image, raw_landmark_set, num_jitters)) for raw_landmark_set in raw_landmarks]" | "python" | "def face_encodings(face_image, known_face_locations=None, num_jitters=1):
"""
Given an image, return the 128-dimension face encoding for each face in the image.
:param face_image: The image that contains one or more faces
:param known_face_locations: Optional - the bounding boxes of each face if you already know them.
:param num_jitters: How many times to re-sample the face when calculating encoding. Higher is more accurate, but slower (i.e. 100 is 100x slower)
:return: A list of 128-dimensional face encodings (one for each face in the image)
"""
raw_landmarks = _raw_face_landmarks(face_image, known_face_locations, model="small")
return [np.array(face_encoder.compute_face_descriptor(face_image, raw_landmark_set, num_jitters)) for raw_landmark_set in raw_landmarks]" | [
"def",
"face_encodings",
"(",
"face_image",
",",
"known_face_locations",
"=",
"None",
",",
"num_jitters",
"=",
"1",
")",
":",
"raw_landmarks",
"=",
"_raw_face_landmarks",
"(",
"face_image",
",",
"known_face_locations",
",",
"model",
"=",
"\"small\"",
")",
"return",
"[",
"np",
".",
"array",
"(",
"face_encoder",
".",
"compute_face_descriptor",
"(",
"face_image",
",",
"raw_landmark_set",
",",
"num_jitters",
")",
")",
"for",
"raw_landmark_set",
"in",
"raw_landmarks",
"]"
] | "Given an image, return the 128-dimension face encoding for each face in the image.
:param face_image: The image that contains one or more faces
:param known_face_locations: Optional - the bounding boxes of each face if you already know them.
:param num_jitters: How many times to re-sample the face when calculating encoding. Higher is more accurate, but slower (i.e. 100 is 100x slower)
:return: A list of 128-dimensional face encodings (one for each face in the image)" | [
"Given",
"an",
"image",
"return",
"the",
"128",
"-",
"dimension",
"face",
"encoding",
"for",
"each",
"face",
"in",
"the",
"image",
"."
] | "train" | "https://github.com/ageitgey/face_recognition/blob/c96b010c02f15e8eeb0f71308c641179ac1f19bb/face_recognition/api.py#L203-L213" |
"apache/spark" | "python/pyspark/sql/types.py" | "_parse_datatype_string" | "def _parse_datatype_string(s):
"""
Parses the given data type string to a :class:`DataType`. The data type string format equals
to :class:`DataType.simpleString`, except that top level struct type can omit
the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead
of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name
for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a DDL-formatted
string and case-insensitive strings.
>>> _parse_datatype_string("int ")
IntegerType
>>> _parse_datatype_string("INT ")
IntegerType
>>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ")
StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
>>> _parse_datatype_string("a DOUBLE, b STRING")
StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))
>>> _parse_datatype_string("a: array< short>")
StructType(List(StructField(a,ArrayType(ShortType,true),true)))
>>> _parse_datatype_string(" map<string , string > ")
MapType(StringType,StringType,true)
>>> # Error cases
>>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ParseException:...
>>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ParseException:...
>>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ParseException:...
>>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ParseException:...
"""
sc = SparkContext._active_spark_context
def from_ddl_schema(type_str):
return _parse_datatype_json_string(
sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json())
def from_ddl_datatype(type_str):
return _parse_datatype_json_string(
sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json())
try:
# DDL format, "fieldname datatype, fieldname datatype".
return from_ddl_schema(s)
except Exception as e:
try:
# For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
return from_ddl_datatype(s)
except:
try:
# For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
return from_ddl_datatype("struct<%s>" % s.strip())
except:
raise e" | "python" | "def _parse_datatype_string(s):
"""
Parses the given data type string to a :class:`DataType`. The data type string format equals
to :class:`DataType.simpleString`, except that top level struct type can omit
the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead
of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name
for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a DDL-formatted
string and case-insensitive strings.
>>> _parse_datatype_string("int ")
IntegerType
>>> _parse_datatype_string("INT ")
IntegerType
>>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ")
StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
>>> _parse_datatype_string("a DOUBLE, b STRING")
StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))
>>> _parse_datatype_string("a: array< short>")
StructType(List(StructField(a,ArrayType(ShortType,true),true)))
>>> _parse_datatype_string(" map<string , string > ")
MapType(StringType,StringType,true)
>>> # Error cases
>>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ParseException:...
>>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ParseException:...
>>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ParseException:...
>>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ParseException:...
"""
sc = SparkContext._active_spark_context
def from_ddl_schema(type_str):
return _parse_datatype_json_string(
sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json())
def from_ddl_datatype(type_str):
return _parse_datatype_json_string(
sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json())
try:
# DDL format, "fieldname datatype, fieldname datatype".
return from_ddl_schema(s)
except Exception as e:
try:
# For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
return from_ddl_datatype(s)
except:
try:
# For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
return from_ddl_datatype("struct<%s>" % s.strip())
except:
raise e" | [
"def",
"_parse_datatype_string",
"(",
"s",
")",
":",
"sc",
"=",
"SparkContext",
".",
"_active_spark_context",
"def",
"from_ddl_schema",
"(",
"type_str",
")",
":",
"return",
"_parse_datatype_json_string",
"(",
"sc",
".",
"_jvm",
".",
"org",
".",
"apache",
".",
"spark",
".",
"sql",
".",
"types",
".",
"StructType",
".",
"fromDDL",
"(",
"type_str",
")",
".",
"json",
"(",
")",
")",
"def",
"from_ddl_datatype",
"(",
"type_str",
")",
":",
"return",
"_parse_datatype_json_string",
"(",
"sc",
".",
"_jvm",
".",
"org",
".",
"apache",
".",
"spark",
".",
"sql",
".",
"api",
".",
"python",
".",
"PythonSQLUtils",
".",
"parseDataType",
"(",
"type_str",
")",
".",
"json",
"(",
")",
")",
"try",
":",
"# DDL format, \"fieldname datatype, fieldname datatype\".",
"return",
"from_ddl_schema",
"(",
"s",
")",
"except",
"Exception",
"as",
"e",
":",
"try",
":",
"# For backwards compatibility, \"integer\", \"struct<fieldname: datatype>\" and etc.",
"return",
"from_ddl_datatype",
"(",
"s",
")",
"except",
":",
"try",
":",
"# For backwards compatibility, \"fieldname: datatype, fieldname: datatype\" case.",
"return",
"from_ddl_datatype",
"(",
"\"struct<%s>\"",
"%",
"s",
".",
"strip",
"(",
")",
")",
"except",
":",
"raise",
"e"
] | "Parses the given data type string to a :class:`DataType`. The data type string format equals
to :class:`DataType.simpleString`, except that top level struct type can omit
the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead
of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name
for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a DDL-formatted
string and case-insensitive strings.
>>> _parse_datatype_string("int ")
IntegerType
>>> _parse_datatype_string("INT ")
IntegerType
>>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ")
StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
>>> _parse_datatype_string("a DOUBLE, b STRING")
StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))
>>> _parse_datatype_string("a: array< short>")
StructType(List(StructField(a,ArrayType(ShortType,true),true)))
>>> _parse_datatype_string(" map<string , string > ")
MapType(StringType,StringType,true)
>>> # Error cases
>>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ParseException:...
>>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ParseException:...
>>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ParseException:...
>>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ParseException:..." | [
"Parses",
"the",
"given",
"data",
"type",
"string",
"to",
"a",
":",
"class",
":",
"DataType",
".",
"The",
"data",
"type",
"string",
"format",
"equals",
"to",
":",
"class",
":",
"DataType",
".",
"simpleString",
"except",
"that",
"top",
"level",
"struct",
"type",
"can",
"omit",
"the",
"struct<",
">",
"and",
"atomic",
"types",
"use",
"typeName",
"()",
"as",
"their",
"format",
"e",
".",
"g",
".",
"use",
"byte",
"instead",
"of",
"tinyint",
"for",
":",
"class",
":",
"ByteType",
".",
"We",
"can",
"also",
"use",
"int",
"as",
"a",
"short",
"name",
"for",
":",
"class",
":",
"IntegerType",
".",
"Since",
"Spark",
"2",
".",
"3",
"this",
"also",
"supports",
"a",
"schema",
"in",
"a",
"DDL",
"-",
"formatted",
"string",
"and",
"case",
"-",
"insensitive",
"strings",
"."
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L758-L820" |
"apache/spark" | "python/pyspark/sql/types.py" | "_int_size_to_type" | "def _int_size_to_type(size):
"""
Return the Catalyst datatype from the size of integers.
"""
if size <= 8:
return ByteType
if size <= 16:
return ShortType
if size <= 32:
return IntegerType
if size <= 64:
return LongType" | "python" | "def _int_size_to_type(size):
"""
Return the Catalyst datatype from the size of integers.
"""
if size <= 8:
return ByteType
if size <= 16:
return ShortType
if size <= 32:
return IntegerType
if size <= 64:
return LongType" | [
"def",
"_int_size_to_type",
"(",
"size",
")",
":",
"if",
"size",
"<=",
"8",
":",
"return",
"ByteType",
"if",
"size",
"<=",
"16",
":",
"return",
"ShortType",
"if",
"size",
"<=",
"32",
":",
"return",
"IntegerType",
"if",
"size",
"<=",
"64",
":",
"return",
"LongType"
] | "Return the Catalyst datatype from the size of integers." | [
"Return",
"the",
"Catalyst",
"datatype",
"from",
"the",
"size",
"of",
"integers",
"."
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L944-L955" |
"apache/spark" | "python/pyspark/sql/types.py" | "_infer_type" | "def _infer_type(obj):
"""Infer the DataType from obj
"""
if obj is None:
return NullType()
if hasattr(obj, '__UDT__'):
return obj.__UDT__
dataType = _type_mappings.get(type(obj))
if dataType is DecimalType:
# the precision and scale of `obj` may be different from row to row.
return DecimalType(38, 18)
elif dataType is not None:
return dataType()
if isinstance(obj, dict):
for key, value in obj.items():
if key is not None and value is not None:
return MapType(_infer_type(key), _infer_type(value), True)
return MapType(NullType(), NullType(), True)
elif isinstance(obj, list):
for v in obj:
if v is not None:
return ArrayType(_infer_type(obj[0]), True)
return ArrayType(NullType(), True)
elif isinstance(obj, array):
if obj.typecode in _array_type_mappings:
return ArrayType(_array_type_mappings[obj.typecode](), False)
else:
raise TypeError("not supported type: array(%s)" % obj.typecode)
else:
try:
return _infer_schema(obj)
except TypeError:
raise TypeError("not supported type: %s" % type(obj))" | "python" | "def _infer_type(obj):
"""Infer the DataType from obj
"""
if obj is None:
return NullType()
if hasattr(obj, '__UDT__'):
return obj.__UDT__
dataType = _type_mappings.get(type(obj))
if dataType is DecimalType:
# the precision and scale of `obj` may be different from row to row.
return DecimalType(38, 18)
elif dataType is not None:
return dataType()
if isinstance(obj, dict):
for key, value in obj.items():
if key is not None and value is not None:
return MapType(_infer_type(key), _infer_type(value), True)
return MapType(NullType(), NullType(), True)
elif isinstance(obj, list):
for v in obj:
if v is not None:
return ArrayType(_infer_type(obj[0]), True)
return ArrayType(NullType(), True)
elif isinstance(obj, array):
if obj.typecode in _array_type_mappings:
return ArrayType(_array_type_mappings[obj.typecode](), False)
else:
raise TypeError("not supported type: array(%s)" % obj.typecode)
else:
try:
return _infer_schema(obj)
except TypeError:
raise TypeError("not supported type: %s" % type(obj))" | [
"def",
"_infer_type",
"(",
"obj",
")",
":",
"if",
"obj",
"is",
"None",
":",
"return",
"NullType",
"(",
")",
"if",
"hasattr",
"(",
"obj",
",",
"'__UDT__'",
")",
":",
"return",
"obj",
".",
"__UDT__",
"dataType",
"=",
"_type_mappings",
".",
"get",
"(",
"type",
"(",
"obj",
")",
")",
"if",
"dataType",
"is",
"DecimalType",
":",
"# the precision and scale of `obj` may be different from row to row.",
"return",
"DecimalType",
"(",
"38",
",",
"18",
")",
"elif",
"dataType",
"is",
"not",
"None",
":",
"return",
"dataType",
"(",
")",
"if",
"isinstance",
"(",
"obj",
",",
"dict",
")",
":",
"for",
"key",
",",
"value",
"in",
"obj",
".",
"items",
"(",
")",
":",
"if",
"key",
"is",
"not",
"None",
"and",
"value",
"is",
"not",
"None",
":",
"return",
"MapType",
"(",
"_infer_type",
"(",
"key",
")",
",",
"_infer_type",
"(",
"value",
")",
",",
"True",
")",
"return",
"MapType",
"(",
"NullType",
"(",
")",
",",
"NullType",
"(",
")",
",",
"True",
")",
"elif",
"isinstance",
"(",
"obj",
",",
"list",
")",
":",
"for",
"v",
"in",
"obj",
":",
"if",
"v",
"is",
"not",
"None",
":",
"return",
"ArrayType",
"(",
"_infer_type",
"(",
"obj",
"[",
"0",
"]",
")",
",",
"True",
")",
"return",
"ArrayType",
"(",
"NullType",
"(",
")",
",",
"True",
")",
"elif",
"isinstance",
"(",
"obj",
",",
"array",
")",
":",
"if",
"obj",
".",
"typecode",
"in",
"_array_type_mappings",
":",
"return",
"ArrayType",
"(",
"_array_type_mappings",
"[",
"obj",
".",
"typecode",
"]",
"(",
")",
",",
"False",
")",
"else",
":",
"raise",
"TypeError",
"(",
"\"not supported type: array(%s)\"",
"%",
"obj",
".",
"typecode",
")",
"else",
":",
"try",
":",
"return",
"_infer_schema",
"(",
"obj",
")",
"except",
"TypeError",
":",
"raise",
"TypeError",
"(",
"\"not supported type: %s\"",
"%",
"type",
"(",
"obj",
")",
")"
] | "Infer the DataType from obj" | [
"Infer",
"the",
"DataType",
"from",
"obj"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1003-L1038" |
"apache/spark" | "python/pyspark/sql/types.py" | "_infer_schema" | "def _infer_schema(row, names=None):
"""Infer the schema from dict/namedtuple/object"""
if isinstance(row, dict):
items = sorted(row.items())
elif isinstance(row, (tuple, list)):
if hasattr(row, "__fields__"): # Row
items = zip(row.__fields__, tuple(row))
elif hasattr(row, "_fields"): # namedtuple
items = zip(row._fields, tuple(row))
else:
if names is None:
names = ['_%d' % i for i in range(1, len(row) + 1)]
elif len(names) < len(row):
names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1))
items = zip(names, row)
elif hasattr(row, "__dict__"): # object
items = sorted(row.__dict__.items())
else:
raise TypeError("Can not infer schema for type: %s" % type(row))
fields = [StructField(k, _infer_type(v), True) for k, v in items]
return StructType(fields)" | "python" | "def _infer_schema(row, names=None):
"""Infer the schema from dict/namedtuple/object"""
if isinstance(row, dict):
items = sorted(row.items())
elif isinstance(row, (tuple, list)):
if hasattr(row, "__fields__"): # Row
items = zip(row.__fields__, tuple(row))
elif hasattr(row, "_fields"): # namedtuple
items = zip(row._fields, tuple(row))
else:
if names is None:
names = ['_%d' % i for i in range(1, len(row) + 1)]
elif len(names) < len(row):
names.extend('_%d' % i for i in range(len(names) + 1, len(row) + 1))
items = zip(names, row)
elif hasattr(row, "__dict__"): # object
items = sorted(row.__dict__.items())
else:
raise TypeError("Can not infer schema for type: %s" % type(row))
fields = [StructField(k, _infer_type(v), True) for k, v in items]
return StructType(fields)" | [
"def",
"_infer_schema",
"(",
"row",
",",
"names",
"=",
"None",
")",
":",
"if",
"isinstance",
"(",
"row",
",",
"dict",
")",
":",
"items",
"=",
"sorted",
"(",
"row",
".",
"items",
"(",
")",
")",
"elif",
"isinstance",
"(",
"row",
",",
"(",
"tuple",
",",
"list",
")",
")",
":",
"if",
"hasattr",
"(",
"row",
",",
"\"__fields__\"",
")",
":",
"# Row",
"items",
"=",
"zip",
"(",
"row",
".",
"__fields__",
",",
"tuple",
"(",
"row",
")",
")",
"elif",
"hasattr",
"(",
"row",
",",
"\"_fields\"",
")",
":",
"# namedtuple",
"items",
"=",
"zip",
"(",
"row",
".",
"_fields",
",",
"tuple",
"(",
"row",
")",
")",
"else",
":",
"if",
"names",
"is",
"None",
":",
"names",
"=",
"[",
"'_%d'",
"%",
"i",
"for",
"i",
"in",
"range",
"(",
"1",
",",
"len",
"(",
"row",
")",
"+",
"1",
")",
"]",
"elif",
"len",
"(",
"names",
")",
"<",
"len",
"(",
"row",
")",
":",
"names",
".",
"extend",
"(",
"'_%d'",
"%",
"i",
"for",
"i",
"in",
"range",
"(",
"len",
"(",
"names",
")",
"+",
"1",
",",
"len",
"(",
"row",
")",
"+",
"1",
")",
")",
"items",
"=",
"zip",
"(",
"names",
",",
"row",
")",
"elif",
"hasattr",
"(",
"row",
",",
"\"__dict__\"",
")",
":",
"# object",
"items",
"=",
"sorted",
"(",
"row",
".",
"__dict__",
".",
"items",
"(",
")",
")",
"else",
":",
"raise",
"TypeError",
"(",
"\"Can not infer schema for type: %s\"",
"%",
"type",
"(",
"row",
")",
")",
"fields",
"=",
"[",
"StructField",
"(",
"k",
",",
"_infer_type",
"(",
"v",
")",
",",
"True",
")",
"for",
"k",
",",
"v",
"in",
"items",
"]",
"return",
"StructType",
"(",
"fields",
")"
] | "Infer the schema from dict/namedtuple/object" | [
"Infer",
"the",
"schema",
"from",
"dict",
"/",
"namedtuple",
"/",
"object"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1041-L1065" |
"apache/spark" | "python/pyspark/sql/types.py" | "_has_nulltype" | "def _has_nulltype(dt):
""" Return whether there is NullType in `dt` or not """
if isinstance(dt, StructType):
return any(_has_nulltype(f.dataType) for f in dt.fields)
elif isinstance(dt, ArrayType):
return _has_nulltype((dt.elementType))
elif isinstance(dt, MapType):
return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
else:
return isinstance(dt, NullType)" | "python" | "def _has_nulltype(dt):
""" Return whether there is NullType in `dt` or not """
if isinstance(dt, StructType):
return any(_has_nulltype(f.dataType) for f in dt.fields)
elif isinstance(dt, ArrayType):
return _has_nulltype((dt.elementType))
elif isinstance(dt, MapType):
return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
else:
return isinstance(dt, NullType)" | [
"def",
"_has_nulltype",
"(",
"dt",
")",
":",
"if",
"isinstance",
"(",
"dt",
",",
"StructType",
")",
":",
"return",
"any",
"(",
"_has_nulltype",
"(",
"f",
".",
"dataType",
")",
"for",
"f",
"in",
"dt",
".",
"fields",
")",
"elif",
"isinstance",
"(",
"dt",
",",
"ArrayType",
")",
":",
"return",
"_has_nulltype",
"(",
"(",
"dt",
".",
"elementType",
")",
")",
"elif",
"isinstance",
"(",
"dt",
",",
"MapType",
")",
":",
"return",
"_has_nulltype",
"(",
"dt",
".",
"keyType",
")",
"or",
"_has_nulltype",
"(",
"dt",
".",
"valueType",
")",
"else",
":",
"return",
"isinstance",
"(",
"dt",
",",
"NullType",
")"
] | "Return whether there is NullType in `dt` or not" | [
"Return",
"whether",
"there",
"is",
"NullType",
"in",
"dt",
"or",
"not"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1068-L1077" |
"apache/spark" | "python/pyspark/sql/types.py" | "_create_converter" | "def _create_converter(dataType):
    """Create a converter to drop the names of fields in obj """
    if not _need_converter(dataType):
        return lambda x: x
    if isinstance(dataType, ArrayType):
        conv = _create_converter(dataType.elementType)
        return lambda row: [conv(v) for v in row]
    elif isinstance(dataType, MapType):
        kconv = _create_converter(dataType.keyType)
        vconv = _create_converter(dataType.valueType)
        return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items())
    elif isinstance(dataType, NullType):
        return lambda x: None
    elif not isinstance(dataType, StructType):
        return lambda x: x
    # dataType must be StructType
    names = [f.name for f in dataType.fields]
    converters = [_create_converter(f.dataType) for f in dataType.fields]
    convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)
    def convert_struct(obj):
        if obj is None:
            return
        if isinstance(obj, (tuple, list)):
            if convert_fields:
                return tuple(conv(v) for v, conv in zip(obj, converters))
            else:
                return tuple(obj)
        if isinstance(obj, dict):
            d = obj
        elif hasattr(obj, "__dict__"): # object
            d = obj.__dict__
        else:
            raise TypeError("Unexpected obj type: %s" % type(obj))
        if convert_fields:
            return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
        else:
            return tuple([d.get(name) for name in names])
    return convert_struct" | "python" | "def _create_converter(dataType):
    """Create a converter to drop the names of fields in obj """
    if not _need_converter(dataType):
        return lambda x: x
    if isinstance(dataType, ArrayType):
        conv = _create_converter(dataType.elementType)
        return lambda row: [conv(v) for v in row]
    elif isinstance(dataType, MapType):
        kconv = _create_converter(dataType.keyType)
        vconv = _create_converter(dataType.valueType)
        return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items())
    elif isinstance(dataType, NullType):
        return lambda x: None
    elif not isinstance(dataType, StructType):
        return lambda x: x
    # dataType must be StructType
    names = [f.name for f in dataType.fields]
    converters = [_create_converter(f.dataType) for f in dataType.fields]
    convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)
    def convert_struct(obj):
        if obj is None:
            return
        if isinstance(obj, (tuple, list)):
            if convert_fields:
                return tuple(conv(v) for v, conv in zip(obj, converters))
            else:
                return tuple(obj)
        if isinstance(obj, dict):
            d = obj
        elif hasattr(obj, "__dict__"): # object
            d = obj.__dict__
        else:
            raise TypeError("Unexpected obj type: %s" % type(obj))
        if convert_fields:
            return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
        else:
            return tuple([d.get(name) for name in names])
    return convert_struct" | [
"def",
"_create_converter",
"(",
"dataType",
")",
":",
"if",
"not",
"_need_converter",
"(",
"dataType",
")",
":",
"return",
"lambda",
"x",
":",
"x",
"if",
"isinstance",
"(",
"dataType",
",",
"ArrayType",
")",
":",
"conv",
"=",
"_create_converter",
"(",
"dataType",
".",
"elementType",
")",
"return",
"lambda",
"row",
":",
"[",
"conv",
"(",
"v",
")",
"for",
"v",
"in",
"row",
"]",
"elif",
"isinstance",
"(",
"dataType",
",",
"MapType",
")",
":",
"kconv",
"=",
"_create_converter",
"(",
"dataType",
".",
"keyType",
")",
"vconv",
"=",
"_create_converter",
"(",
"dataType",
".",
"valueType",
")",
"return",
"lambda",
"row",
":",
"dict",
"(",
"(",
"kconv",
"(",
"k",
")",
",",
"vconv",
"(",
"v",
")",
")",
"for",
"k",
",",
"v",
"in",
"row",
".",
"items",
"(",
")",
")",
"elif",
"isinstance",
"(",
"dataType",
",",
"NullType",
")",
":",
"return",
"lambda",
"x",
":",
"None",
"elif",
"not",
"isinstance",
"(",
"dataType",
",",
"StructType",
")",
":",
"return",
"lambda",
"x",
":",
"x",
"# dataType must be StructType",
"names",
"=",
"[",
"f",
".",
"name",
"for",
"f",
"in",
"dataType",
".",
"fields",
"]",
"converters",
"=",
"[",
"_create_converter",
"(",
"f",
".",
"dataType",
")",
"for",
"f",
"in",
"dataType",
".",
"fields",
"]",
"convert_fields",
"=",
"any",
"(",
"_need_converter",
"(",
"f",
".",
"dataType",
")",
"for",
"f",
"in",
"dataType",
".",
"fields",
")",
"def",
"convert_struct",
"(",
"obj",
")",
":",
"if",
"obj",
"is",
"None",
":",
"return",
"if",
"isinstance",
"(",
"obj",
",",
"(",
"tuple",
",",
"list",
")",
")",
":",
"if",
"convert_fields",
":",
"return",
"tuple",
"(",
"conv",
"(",
"v",
")",
"for",
"v",
",",
"conv",
"in",
"zip",
"(",
"obj",
",",
"converters",
")",
")",
"else",
":",
"return",
"tuple",
"(",
"obj",
")",
"if",
"isinstance",
"(",
"obj",
",",
"dict",
")",
":",
"d",
"=",
"obj",
"elif",
"hasattr",
"(",
"obj",
",",
"\"__dict__\"",
")",
":",
"# object",
"d",
"=",
"obj",
".",
"__dict__",
"else",
":",
"raise",
"TypeError",
"(",
"\"Unexpected obj type: %s\"",
"%",
"type",
"(",
"obj",
")",
")",
"if",
"convert_fields",
":",
"return",
"tuple",
"(",
"[",
"conv",
"(",
"d",
".",
"get",
"(",
"name",
")",
")",
"for",
"name",
",",
"conv",
"in",
"zip",
"(",
"names",
",",
"converters",
")",
"]",
")",
"else",
":",
"return",
"tuple",
"(",
"[",
"d",
".",
"get",
"(",
"name",
")",
"for",
"name",
"in",
"names",
"]",
")",
"return",
"convert_struct"
] | "Create a converter to drop the names of fields in obj" | [
"Create",
"a",
"converter",
"to",
"drop",
"the",
"names",
"of",
"fields",
"in",
"obj"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1133-L1180" |
"apache/spark" | "python/pyspark/sql/types.py" | "_make_type_verifier" | "def _make_type_verifier(dataType, nullable=True, name=None):
    """
    Make a verifier that checks the type of obj against dataType and raises a TypeError if they do
    not match.
    This verifier also checks the value of obj against datatype and raises a ValueError if it's not
    within the allowed range, e.g. using 128 as ByteType will overflow. Note that, Python float is
    not checked, so it will become infinity when cast to Java float if it overflows.
    >>> _make_type_verifier(StructType([]))(None)
    >>> _make_type_verifier(StringType())("")
    >>> _make_type_verifier(LongType())(0)
    >>> _make_type_verifier(ArrayType(ShortType()))(list(range(3)))
    >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    TypeError:...
    >>> _make_type_verifier(MapType(StringType(), IntegerType()))({})
    >>> _make_type_verifier(StructType([]))(())
    >>> _make_type_verifier(StructType([]))([])
    >>> _make_type_verifier(StructType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> # Check if numeric values are within the allowed range.
    >>> _make_type_verifier(ByteType())(12)
    >>> _make_type_verifier(ByteType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(ByteType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(
    ...     ArrayType(ShortType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(MapType(StringType(), IntegerType()))({None: 1})
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False)
    >>> _make_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    """
    if name is None:
        new_msg = lambda msg: msg
        new_name = lambda n: "field %s" % n
    else:
        new_msg = lambda msg: "%s: %s" % (name, msg)
        new_name = lambda n: "field %s in %s" % (n, name)
    def verify_nullability(obj):
        if obj is None:
            if nullable:
                return True
            else:
                raise ValueError(new_msg("This field is not nullable, but got None"))
        else:
            return False
    _type = type(dataType)
    def assert_acceptable_types(obj):
        assert _type in _acceptable_types, \
            new_msg("unknown datatype: %s for object %r" % (dataType, obj))
    def verify_acceptable_types(obj):
        # subclass of them can not be fromInternal in JVM
        if type(obj) not in _acceptable_types[_type]:
            raise TypeError(new_msg("%s can not accept object %r in type %s"
                                    % (dataType, obj, type(obj))))
    if isinstance(dataType, StringType):
        # StringType can work with any types
        verify_value = lambda _: _
    elif isinstance(dataType, UserDefinedType):
        verifier = _make_type_verifier(dataType.sqlType(), name=name)
        def verify_udf(obj):
            if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
                raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType)))
            verifier(dataType.toInternal(obj))
        verify_value = verify_udf
    elif isinstance(dataType, ByteType):
        def verify_byte(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -128 or obj > 127:
                raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj))
        verify_value = verify_byte
    elif isinstance(dataType, ShortType):
        def verify_short(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -32768 or obj > 32767:
                raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj))
        verify_value = verify_short
    elif isinstance(dataType, IntegerType):
        def verify_integer(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -2147483648 or obj > 2147483647:
                raise ValueError(
                    new_msg("object of IntegerType out of range, got: %s" % obj))
        verify_value = verify_integer
    elif isinstance(dataType, ArrayType):
        element_verifier = _make_type_verifier(
            dataType.elementType, dataType.containsNull, name="element in array %s" % name)
        def verify_array(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            for i in obj:
                element_verifier(i)
        verify_value = verify_array
    elif isinstance(dataType, MapType):
        key_verifier = _make_type_verifier(dataType.keyType, False, name="key of map %s" % name)
        value_verifier = _make_type_verifier(
            dataType.valueType, dataType.valueContainsNull, name="value of map %s" % name)
        def verify_map(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            for k, v in obj.items():
                key_verifier(k)
                value_verifier(v)
        verify_value = verify_map
    elif isinstance(dataType, StructType):
        verifiers = []
        for f in dataType.fields:
            verifier = _make_type_verifier(f.dataType, f.nullable, name=new_name(f.name))
            verifiers.append((f.name, verifier))
        def verify_struct(obj):
            assert_acceptable_types(obj)
            if isinstance(obj, dict):
                for f, verifier in verifiers:
                    verifier(obj.get(f))
            elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
                # the order in obj could be different than dataType.fields
                for f, verifier in verifiers:
                    verifier(obj[f])
            elif isinstance(obj, (tuple, list)):
                if len(obj) != len(verifiers):
                    raise ValueError(
                        new_msg("Length of object (%d) does not match with "
                                "length of fields (%d)" % (len(obj), len(verifiers))))
                for v, (_, verifier) in zip(obj, verifiers):
                    verifier(v)
            elif hasattr(obj, "__dict__"):
                d = obj.__dict__
                for f, verifier in verifiers:
                    verifier(d.get(f))
            else:
                raise TypeError(new_msg("StructType can not accept object %r in type %s"
                                        % (obj, type(obj))))
        verify_value = verify_struct
    else:
        def verify_default(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
        verify_value = verify_default
    def verify(obj):
        if not verify_nullability(obj):
            verify_value(obj)
    return verify" | "python" | "def _make_type_verifier(dataType, nullable=True, name=None):
    """
    Make a verifier that checks the type of obj against dataType and raises a TypeError if they do
    not match.
    This verifier also checks the value of obj against datatype and raises a ValueError if it's not
    within the allowed range, e.g. using 128 as ByteType will overflow. Note that, Python float is
    not checked, so it will become infinity when cast to Java float if it overflows.
    >>> _make_type_verifier(StructType([]))(None)
    >>> _make_type_verifier(StringType())("")
    >>> _make_type_verifier(LongType())(0)
    >>> _make_type_verifier(ArrayType(ShortType()))(list(range(3)))
    >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    TypeError:...
    >>> _make_type_verifier(MapType(StringType(), IntegerType()))({})
    >>> _make_type_verifier(StructType([]))(())
    >>> _make_type_verifier(StructType([]))([])
    >>> _make_type_verifier(StructType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> # Check if numeric values are within the allowed range.
    >>> _make_type_verifier(ByteType())(12)
    >>> _make_type_verifier(ByteType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(ByteType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(
    ...     ArrayType(ShortType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _make_type_verifier(MapType(StringType(), IntegerType()))({None: 1})
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False)
    >>> _make_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    """
    if name is None:
        new_msg = lambda msg: msg
        new_name = lambda n: "field %s" % n
    else:
        new_msg = lambda msg: "%s: %s" % (name, msg)
        new_name = lambda n: "field %s in %s" % (n, name)
    def verify_nullability(obj):
        if obj is None:
            if nullable:
                return True
            else:
                raise ValueError(new_msg("This field is not nullable, but got None"))
        else:
            return False
    _type = type(dataType)
    def assert_acceptable_types(obj):
        assert _type in _acceptable_types, \
            new_msg("unknown datatype: %s for object %r" % (dataType, obj))
    def verify_acceptable_types(obj):
        # subclass of them can not be fromInternal in JVM
        if type(obj) not in _acceptable_types[_type]:
            raise TypeError(new_msg("%s can not accept object %r in type %s"
                                    % (dataType, obj, type(obj))))
    if isinstance(dataType, StringType):
        # StringType can work with any types
        verify_value = lambda _: _
    elif isinstance(dataType, UserDefinedType):
        verifier = _make_type_verifier(dataType.sqlType(), name=name)
        def verify_udf(obj):
            if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
                raise ValueError(new_msg("%r is not an instance of type %r" % (obj, dataType)))
            verifier(dataType.toInternal(obj))
        verify_value = verify_udf
    elif isinstance(dataType, ByteType):
        def verify_byte(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -128 or obj > 127:
                raise ValueError(new_msg("object of ByteType out of range, got: %s" % obj))
        verify_value = verify_byte
    elif isinstance(dataType, ShortType):
        def verify_short(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -32768 or obj > 32767:
                raise ValueError(new_msg("object of ShortType out of range, got: %s" % obj))
        verify_value = verify_short
    elif isinstance(dataType, IntegerType):
        def verify_integer(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            if obj < -2147483648 or obj > 2147483647:
                raise ValueError(
                    new_msg("object of IntegerType out of range, got: %s" % obj))
        verify_value = verify_integer
    elif isinstance(dataType, ArrayType):
        element_verifier = _make_type_verifier(
            dataType.elementType, dataType.containsNull, name="element in array %s" % name)
        def verify_array(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            for i in obj:
                element_verifier(i)
        verify_value = verify_array
    elif isinstance(dataType, MapType):
        key_verifier = _make_type_verifier(dataType.keyType, False, name="key of map %s" % name)
        value_verifier = _make_type_verifier(
            dataType.valueType, dataType.valueContainsNull, name="value of map %s" % name)
        def verify_map(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
            for k, v in obj.items():
                key_verifier(k)
                value_verifier(v)
        verify_value = verify_map
    elif isinstance(dataType, StructType):
        verifiers = []
        for f in dataType.fields:
            verifier = _make_type_verifier(f.dataType, f.nullable, name=new_name(f.name))
            verifiers.append((f.name, verifier))
        def verify_struct(obj):
            assert_acceptable_types(obj)
            if isinstance(obj, dict):
                for f, verifier in verifiers:
                    verifier(obj.get(f))
            elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
                # the order in obj could be different than dataType.fields
                for f, verifier in verifiers:
                    verifier(obj[f])
            elif isinstance(obj, (tuple, list)):
                if len(obj) != len(verifiers):
                    raise ValueError(
                        new_msg("Length of object (%d) does not match with "
                                "length of fields (%d)" % (len(obj), len(verifiers))))
                for v, (_, verifier) in zip(obj, verifiers):
                    verifier(v)
            elif hasattr(obj, "__dict__"):
                d = obj.__dict__
                for f, verifier in verifiers:
                    verifier(d.get(f))
            else:
                raise TypeError(new_msg("StructType can not accept object %r in type %s"
                                        % (obj, type(obj))))
        verify_value = verify_struct
    else:
        def verify_default(obj):
            assert_acceptable_types(obj)
            verify_acceptable_types(obj)
        verify_value = verify_default
    def verify(obj):
        if not verify_nullability(obj):
            verify_value(obj)
    return verify" | [
"def",
"_make_type_verifier",
"(",
"dataType",
",",
"nullable",
"=",
"True",
",",
"name",
"=",
"None",
")",
":",
"if",
"name",
"is",
"None",
":",
"new_msg",
"=",
"lambda",
"msg",
":",
"msg",
"new_name",
"=",
"lambda",
"n",
":",
"\"field %s\"",
"%",
"n",
"else",
":",
"new_msg",
"=",
"lambda",
"msg",
":",
"\"%s: %s\"",
"%",
"(",
"name",
",",
"msg",
")",
"new_name",
"=",
"lambda",
"n",
":",
"\"field %s in %s\"",
"%",
"(",
"n",
",",
"name",
")",
"def",
"verify_nullability",
"(",
"obj",
")",
":",
"if",
"obj",
"is",
"None",
":",
"if",
"nullable",
":",
"return",
"True",
"else",
":",
"raise",
"ValueError",
"(",
"new_msg",
"(",
"\"This field is not nullable, but got None\"",
")",
")",
"else",
":",
"return",
"False",
"_type",
"=",
"type",
"(",
"dataType",
")",
"def",
"assert_acceptable_types",
"(",
"obj",
")",
":",
"assert",
"_type",
"in",
"_acceptable_types",
",",
"new_msg",
"(",
"\"unknown datatype: %s for object %r\"",
"%",
"(",
"dataType",
",",
"obj",
")",
")",
"def",
"verify_acceptable_types",
"(",
"obj",
")",
":",
"# subclass of them can not be fromInternal in JVM",
"if",
"type",
"(",
"obj",
")",
"not",
"in",
"_acceptable_types",
"[",
"_type",
"]",
":",
"raise",
"TypeError",
"(",
"new_msg",
"(",
"\"%s can not accept object %r in type %s\"",
"%",
"(",
"dataType",
",",
"obj",
",",
"type",
"(",
"obj",
")",
")",
")",
")",
"if",
"isinstance",
"(",
"dataType",
",",
"StringType",
")",
":",
"# StringType can work with any types",
"verify_value",
"=",
"lambda",
"_",
":",
"_",
"elif",
"isinstance",
"(",
"dataType",
",",
"UserDefinedType",
")",
":",
"verifier",
"=",
"_make_type_verifier",
"(",
"dataType",
".",
"sqlType",
"(",
")",
",",
"name",
"=",
"name",
")",
"def",
"verify_udf",
"(",
"obj",
")",
":",
"if",
"not",
"(",
"hasattr",
"(",
"obj",
",",
"'__UDT__'",
")",
"and",
"obj",
".",
"__UDT__",
"==",
"dataType",
")",
":",
"raise",
"ValueError",
"(",
"new_msg",
"(",
"\"%r is not an instance of type %r\"",
"%",
"(",
"obj",
",",
"dataType",
")",
")",
")",
"verifier",
"(",
"dataType",
".",
"toInternal",
"(",
"obj",
")",
")",
"verify_value",
"=",
"verify_udf",
"elif",
"isinstance",
"(",
"dataType",
",",
"ByteType",
")",
":",
"def",
"verify_byte",
"(",
"obj",
")",
":",
"assert_acceptable_types",
"(",
"obj",
")",
"verify_acceptable_types",
"(",
"obj",
")",
"if",
"obj",
"<",
"-",
"128",
"or",
"obj",
">",
"127",
":",
"raise",
"ValueError",
"(",
"new_msg",
"(",
"\"object of ByteType out of range, got: %s\"",
"%",
"obj",
")",
")",
"verify_value",
"=",
"verify_byte",
"elif",
"isinstance",
"(",
"dataType",
",",
"ShortType",
")",
":",
"def",
"verify_short",
"(",
"obj",
")",
":",
"assert_acceptable_types",
"(",
"obj",
")",
"verify_acceptable_types",
"(",
"obj",
")",
"if",
"obj",
"<",
"-",
"32768",
"or",
"obj",
">",
"32767",
":",
"raise",
"ValueError",
"(",
"new_msg",
"(",
"\"object of ShortType out of range, got: %s\"",
"%",
"obj",
")",
")",
"verify_value",
"=",
"verify_short",
"elif",
"isinstance",
"(",
"dataType",
",",
"IntegerType",
")",
":",
"def",
"verify_integer",
"(",
"obj",
")",
":",
"assert_acceptable_types",
"(",
"obj",
")",
"verify_acceptable_types",
"(",
"obj",
")",
"if",
"obj",
"<",
"-",
"2147483648",
"or",
"obj",
">",
"2147483647",
":",
"raise",
"ValueError",
"(",
"new_msg",
"(",
"\"object of IntegerType out of range, got: %s\"",
"%",
"obj",
")",
")",
"verify_value",
"=",
"verify_integer",
"elif",
"isinstance",
"(",
"dataType",
",",
"ArrayType",
")",
":",
"element_verifier",
"=",
"_make_type_verifier",
"(",
"dataType",
".",
"elementType",
",",
"dataType",
".",
"containsNull",
",",
"name",
"=",
"\"element in array %s\"",
"%",
"name",
")",
"def",
"verify_array",
"(",
"obj",
")",
":",
"assert_acceptable_types",
"(",
"obj",
")",
"verify_acceptable_types",
"(",
"obj",
")",
"for",
"i",
"in",
"obj",
":",
"element_verifier",
"(",
"i",
")",
"verify_value",
"=",
"verify_array",
"elif",
"isinstance",
"(",
"dataType",
",",
"MapType",
")",
":",
"key_verifier",
"=",
"_make_type_verifier",
"(",
"dataType",
".",
"keyType",
",",
"False",
",",
"name",
"=",
"\"key of map %s\"",
"%",
"name",
")",
"value_verifier",
"=",
"_make_type_verifier",
"(",
"dataType",
".",
"valueType",
",",
"dataType",
".",
"valueContainsNull",
",",
"name",
"=",
"\"value of map %s\"",
"%",
"name",
")",
"def",
"verify_map",
"(",
"obj",
")",
":",
"assert_acceptable_types",
"(",
"obj",
")",
"verify_acceptable_types",
"(",
"obj",
")",
"for",
"k",
",",
"v",
"in",
"obj",
".",
"items",
"(",
")",
":",
"key_verifier",
"(",
"k",
")",
"value_verifier",
"(",
"v",
")",
"verify_value",
"=",
"verify_map",
"elif",
"isinstance",
"(",
"dataType",
",",
"StructType",
")",
":",
"verifiers",
"=",
"[",
"]",
"for",
"f",
"in",
"dataType",
".",
"fields",
":",
"verifier",
"=",
"_make_type_verifier",
"(",
"f",
".",
"dataType",
",",
"f",
".",
"nullable",
",",
"name",
"=",
"new_name",
"(",
"f",
".",
"name",
")",
")",
"verifiers",
".",
"append",
"(",
"(",
"f",
".",
"name",
",",
"verifier",
")",
")",
"def",
"verify_struct",
"(",
"obj",
")",
":",
"assert_acceptable_types",
"(",
"obj",
")",
"if",
"isinstance",
"(",
"obj",
",",
"dict",
")",
":",
"for",
"f",
",",
"verifier",
"in",
"verifiers",
":",
"verifier",
"(",
"obj",
".",
"get",
"(",
"f",
")",
")",
"elif",
"isinstance",
"(",
"obj",
",",
"Row",
")",
"and",
"getattr",
"(",
"obj",
",",
"\"__from_dict__\"",
",",
"False",
")",
":",
"# the order in obj could be different than dataType.fields",
"for",
"f",
",",
"verifier",
"in",
"verifiers",
":",
"verifier",
"(",
"obj",
"[",
"f",
"]",
")",
"elif",
"isinstance",
"(",
"obj",
",",
"(",
"tuple",
",",
"list",
")",
")",
":",
"if",
"len",
"(",
"obj",
")",
"!=",
"len",
"(",
"verifiers",
")",
":",
"raise",
"ValueError",
"(",
"new_msg",
"(",
"\"Length of object (%d) does not match with \"",
"\"length of fields (%d)\"",
"%",
"(",
"len",
"(",
"obj",
")",
",",
"len",
"(",
"verifiers",
")",
")",
")",
")",
"for",
"v",
",",
"(",
"_",
",",
"verifier",
")",
"in",
"zip",
"(",
"obj",
",",
"verifiers",
")",
":",
"verifier",
"(",
"v",
")",
"elif",
"hasattr",
"(",
"obj",
",",
"\"__dict__\"",
")",
":",
"d",
"=",
"obj",
".",
"__dict__",
"for",
"f",
",",
"verifier",
"in",
"verifiers",
":",
"verifier",
"(",
"d",
".",
"get",
"(",
"f",
")",
")",
"else",
":",
"raise",
"TypeError",
"(",
"new_msg",
"(",
"\"StructType can not accept object %r in type %s\"",
"%",
"(",
"obj",
",",
"type",
"(",
"obj",
")",
")",
")",
")",
"verify_value",
"=",
"verify_struct",
"else",
":",
"def",
"verify_default",
"(",
"obj",
")",
":",
"assert_acceptable_types",
"(",
"obj",
")",
"verify_acceptable_types",
"(",
"obj",
")",
"verify_value",
"=",
"verify_default",
"def",
"verify",
"(",
"obj",
")",
":",
"if",
"not",
"verify_nullability",
"(",
"obj",
")",
":",
"verify_value",
"(",
"obj",
")",
"return",
"verify"
] | "Make a verifier that checks the type of obj against dataType and raises a TypeError if they do
not match.
This verifier also checks the value of obj against datatype and raises a ValueError if it's not
within the allowed range, e.g. using 128 as ByteType will overflow. Note that, Python float is
not checked, so it will become infinity when cast to Java float if it overflows.
>>> _make_type_verifier(StructType([]))(None)
>>> _make_type_verifier(StringType())("")
>>> _make_type_verifier(LongType())(0)
>>> _make_type_verifier(ArrayType(ShortType()))(list(range(3)))
>>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TypeError:...
>>> _make_type_verifier(MapType(StringType(), IntegerType()))({})
>>> _make_type_verifier(StructType([]))(())
>>> _make_type_verifier(StructType([]))([])
>>> _make_type_verifier(StructType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
>>> # Check if numeric values are within the allowed range.
>>> _make_type_verifier(ByteType())(12)
>>> _make_type_verifier(ByteType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
>>> _make_type_verifier(ByteType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
>>> _make_type_verifier(
... ArrayType(ShortType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
>>> _make_type_verifier(MapType(StringType(), IntegerType()))({None: 1})
Traceback (most recent call last):
...
ValueError:...
>>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False)
>>> _make_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:..." | [
"Make",
"a",
"verifier",
"that",
"checks",
"the",
"type",
"of",
"obj",
"against",
"dataType",
"and",
"raises",
"a",
"TypeError",
"if",
"they",
"do",
"not",
"match",
"."
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1202-L1391" |
"apache/spark" | "python/pyspark/sql/types.py" | "to_arrow_type" | "def to_arrow_type(dt):
    """ Convert Spark data type to pyarrow type
    """
    import pyarrow as pa
    if type(dt) == BooleanType:
        arrow_type = pa.bool_()
    elif type(dt) == ByteType:
        arrow_type = pa.int8()
    elif type(dt) == ShortType:
        arrow_type = pa.int16()
    elif type(dt) == IntegerType:
        arrow_type = pa.int32()
    elif type(dt) == LongType:
        arrow_type = pa.int64()
    elif type(dt) == FloatType:
        arrow_type = pa.float32()
    elif type(dt) == DoubleType:
        arrow_type = pa.float64()
    elif type(dt) == DecimalType:
        arrow_type = pa.decimal128(dt.precision, dt.scale)
    elif type(dt) == StringType:
        arrow_type = pa.string()
    elif type(dt) == BinaryType:
        arrow_type = pa.binary()
    elif type(dt) == DateType:
        arrow_type = pa.date32()
    elif type(dt) == TimestampType:
        # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
        arrow_type = pa.timestamp('us', tz='UTC')
    elif type(dt) == ArrayType:
        if type(dt.elementType) in [StructType, TimestampType]:
            raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
        arrow_type = pa.list_(to_arrow_type(dt.elementType))
    elif type(dt) == StructType:
        if any(type(field.dataType) == StructType for field in dt):
            raise TypeError("Nested StructType not supported in conversion to Arrow")
        fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
                  for field in dt]
        arrow_type = pa.struct(fields)
    else:
        raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
    return arrow_type" | "python" | "def to_arrow_type(dt):
    """ Convert Spark data type to pyarrow type
    """
    import pyarrow as pa
    if type(dt) == BooleanType:
        arrow_type = pa.bool_()
    elif type(dt) == ByteType:
        arrow_type = pa.int8()
    elif type(dt) == ShortType:
        arrow_type = pa.int16()
    elif type(dt) == IntegerType:
        arrow_type = pa.int32()
    elif type(dt) == LongType:
        arrow_type = pa.int64()
    elif type(dt) == FloatType:
        arrow_type = pa.float32()
    elif type(dt) == DoubleType:
        arrow_type = pa.float64()
    elif type(dt) == DecimalType:
        arrow_type = pa.decimal128(dt.precision, dt.scale)
    elif type(dt) == StringType:
        arrow_type = pa.string()
    elif type(dt) == BinaryType:
        arrow_type = pa.binary()
    elif type(dt) == DateType:
        arrow_type = pa.date32()
    elif type(dt) == TimestampType:
        # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
        arrow_type = pa.timestamp('us', tz='UTC')
    elif type(dt) == ArrayType:
        if type(dt.elementType) in [StructType, TimestampType]:
            raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
        arrow_type = pa.list_(to_arrow_type(dt.elementType))
    elif type(dt) == StructType:
        if any(type(field.dataType) == StructType for field in dt):
            raise TypeError("Nested StructType not supported in conversion to Arrow")
        fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
                  for field in dt]
        arrow_type = pa.struct(fields)
    else:
        raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
    return arrow_type" | [
"def",
"to_arrow_type",
"(",
"dt",
")",
":",
"import",
"pyarrow",
"as",
"pa",
"if",
"type",
"(",
"dt",
")",
"==",
"BooleanType",
":",
"arrow_type",
"=",
"pa",
".",
"bool_",
"(",
")",
"elif",
"type",
"(",
"dt",
")",
"==",
"ByteType",
":",
"arrow_type",
"=",
"pa",
".",
"int8",
"(",
")",
"elif",
"type",
"(",
"dt",
")",
"==",
"ShortType",
":",
"arrow_type",
"=",
"pa",
".",
"int16",
"(",
")",
"elif",
"type",
"(",
"dt",
")",
"==",
"IntegerType",
":",
"arrow_type",
"=",
"pa",
".",
"int32",
"(",
")",
"elif",
"type",
"(",
"dt",
")",
"==",
"LongType",
":",
"arrow_type",
"=",
"pa",
".",
"int64",
"(",
")",
"elif",
"type",
"(",
"dt",
")",
"==",
"FloatType",
":",
"arrow_type",
"=",
"pa",
".",
"float32",
"(",
")",
"elif",
"type",
"(",
"dt",
")",
"==",
"DoubleType",
":",
"arrow_type",
"=",
"pa",
".",
"float64",
"(",
")",
"elif",
"type",
"(",
"dt",
")",
"==",
"DecimalType",
":",
"arrow_type",
"=",
"pa",
".",
"decimal128",
"(",
"dt",
".",
"precision",
",",
"dt",
".",
"scale",
")",
"elif",
"type",
"(",
"dt",
")",
"==",
"StringType",
":",
"arrow_type",
"=",
"pa",
".",
"string",
"(",
")",
"elif",
"type",
"(",
"dt",
")",
"==",
"BinaryType",
":",
"arrow_type",
"=",
"pa",
".",
"binary",
"(",
")",
"elif",
"type",
"(",
"dt",
")",
"==",
"DateType",
":",
"arrow_type",
"=",
"pa",
".",
"date32",
"(",
")",
"elif",
"type",
"(",
"dt",
")",
"==",
"TimestampType",
":",
"# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read",
"arrow_type",
"=",
"pa",
".",
"timestamp",
"(",
"'us'",
",",
"tz",
"=",
"'UTC'",
")",
"elif",
"type",
"(",
"dt",
")",
"==",
"ArrayType",
":",
"if",
"type",
"(",
"dt",
".",
"elementType",
")",
"in",
"[",
"StructType",
",",
"TimestampType",
"]",
":",
"raise",
"TypeError",
"(",
"\"Unsupported type in conversion to Arrow: \"",
"+",
"str",
"(",
"dt",
")",
")",
"arrow_type",
"=",
"pa",
".",
"list_",
"(",
"to_arrow_type",
"(",
"dt",
".",
"elementType",
")",
")",
"elif",
"type",
"(",
"dt",
")",
"==",
"StructType",
":",
"if",
"any",
"(",
"type",
"(",
"field",
".",
"dataType",
")",
"==",
"StructType",
"for",
"field",
"in",
"dt",
")",
":",
"raise",
"TypeError",
"(",
"\"Nested StructType not supported in conversion to Arrow\"",
")",
"fields",
"=",
"[",
"pa",
".",
"field",
"(",
"field",
".",
"name",
",",
"to_arrow_type",
"(",
"field",
".",
"dataType",
")",
",",
"nullable",
"=",
"field",
".",
"nullable",
")",
"for",
"field",
"in",
"dt",
"]",
"arrow_type",
"=",
"pa",
".",
"struct",
"(",
"fields",
")",
"else",
":",
"raise",
"TypeError",
"(",
"\"Unsupported type in conversion to Arrow: \"",
"+",
"str",
"(",
"dt",
")",
")",
"return",
"arrow_type"
] | "Convert Spark data type to pyarrow type" | [
"Convert",
"Spark",
"data",
"type",
"to",
"pyarrow",
"type"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1581-L1622" |
"apache/spark" | "python/pyspark/sql/types.py" | "to_arrow_schema" | "def to_arrow_schema(schema):
    """ Convert a schema from Spark to Arrow
    """
    import pyarrow as pa
    fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
              for field in schema]
    return pa.schema(fields)" | "python" | "def to_arrow_schema(schema):
    """ Convert a schema from Spark to Arrow
    """
    import pyarrow as pa
    fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
              for field in schema]
    return pa.schema(fields)" | [
"def",
"to_arrow_schema",
"(",
"schema",
")",
":",
"import",
"pyarrow",
"as",
"pa",
"fields",
"=",
"[",
"pa",
".",
"field",
"(",
"field",
".",
"name",
",",
"to_arrow_type",
"(",
"field",
".",
"dataType",
")",
",",
"nullable",
"=",
"field",
".",
"nullable",
")",
"for",
"field",
"in",
"schema",
"]",
"return",
"pa",
".",
"schema",
"(",
"fields",
")"
] | "Convert a schema from Spark to Arrow" | [
"Convert",
"a",
"schema",
"from",
"Spark",
"to",
"Arrow"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1625-L1631" |
"apache/spark" | "python/pyspark/sql/types.py" | "from_arrow_type" | "def from_arrow_type(at):
""" Convert pyarrow type to Spark data type.
"""
import pyarrow.types as types
if types.is_boolean(at):
spark_type = BooleanType()
elif types.is_int8(at):
spark_type = ByteType()
elif types.is_int16(at):
spark_type = ShortType()
elif types.is_int32(at):
spark_type = IntegerType()
elif types.is_int64(at):
spark_type = LongType()
elif types.is_float32(at):
spark_type = FloatType()
elif types.is_float64(at):
spark_type = DoubleType()
elif types.is_decimal(at):
spark_type = DecimalType(precision=at.precision, scale=at.scale)
elif types.is_string(at):
spark_type = StringType()
elif types.is_binary(at):
spark_type = BinaryType()
elif types.is_date32(at):
spark_type = DateType()
elif types.is_timestamp(at):
spark_type = TimestampType()
elif types.is_list(at):
if types.is_timestamp(at.value_type):
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
spark_type = ArrayType(from_arrow_type(at.value_type))
elif types.is_struct(at):
if any(types.is_struct(field.type) for field in at):
raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
return StructType(
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
for field in at])
else:
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
return spark_type" | "python" | "def from_arrow_type(at):
""" Convert pyarrow type to Spark data type.
"""
import pyarrow.types as types
if types.is_boolean(at):
spark_type = BooleanType()
elif types.is_int8(at):
spark_type = ByteType()
elif types.is_int16(at):
spark_type = ShortType()
elif types.is_int32(at):
spark_type = IntegerType()
elif types.is_int64(at):
spark_type = LongType()
elif types.is_float32(at):
spark_type = FloatType()
elif types.is_float64(at):
spark_type = DoubleType()
elif types.is_decimal(at):
spark_type = DecimalType(precision=at.precision, scale=at.scale)
elif types.is_string(at):
spark_type = StringType()
elif types.is_binary(at):
spark_type = BinaryType()
elif types.is_date32(at):
spark_type = DateType()
elif types.is_timestamp(at):
spark_type = TimestampType()
elif types.is_list(at):
if types.is_timestamp(at.value_type):
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
spark_type = ArrayType(from_arrow_type(at.value_type))
elif types.is_struct(at):
if any(types.is_struct(field.type) for field in at):
raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
return StructType(
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
for field in at])
else:
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
return spark_type" | [
"def",
"from_arrow_type",
"(",
"at",
")",
":",
"import",
"pyarrow",
".",
"types",
"as",
"types",
"if",
"types",
".",
"is_boolean",
"(",
"at",
")",
":",
"spark_type",
"=",
"BooleanType",
"(",
")",
"elif",
"types",
".",
"is_int8",
"(",
"at",
")",
":",
"spark_type",
"=",
"ByteType",
"(",
")",
"elif",
"types",
".",
"is_int16",
"(",
"at",
")",
":",
"spark_type",
"=",
"ShortType",
"(",
")",
"elif",
"types",
".",
"is_int32",
"(",
"at",
")",
":",
"spark_type",
"=",
"IntegerType",
"(",
")",
"elif",
"types",
".",
"is_int64",
"(",
"at",
")",
":",
"spark_type",
"=",
"LongType",
"(",
")",
"elif",
"types",
".",
"is_float32",
"(",
"at",
")",
":",
"spark_type",
"=",
"FloatType",
"(",
")",
"elif",
"types",
".",
"is_float64",
"(",
"at",
")",
":",
"spark_type",
"=",
"DoubleType",
"(",
")",
"elif",
"types",
".",
"is_decimal",
"(",
"at",
")",
":",
"spark_type",
"=",
"DecimalType",
"(",
"precision",
"=",
"at",
".",
"precision",
",",
"scale",
"=",
"at",
".",
"scale",
")",
"elif",
"types",
".",
"is_string",
"(",
"at",
")",
":",
"spark_type",
"=",
"StringType",
"(",
")",
"elif",
"types",
".",
"is_binary",
"(",
"at",
")",
":",
"spark_type",
"=",
"BinaryType",
"(",
")",
"elif",
"types",
".",
"is_date32",
"(",
"at",
")",
":",
"spark_type",
"=",
"DateType",
"(",
")",
"elif",
"types",
".",
"is_timestamp",
"(",
"at",
")",
":",
"spark_type",
"=",
"TimestampType",
"(",
")",
"elif",
"types",
".",
"is_list",
"(",
"at",
")",
":",
"if",
"types",
".",
"is_timestamp",
"(",
"at",
".",
"value_type",
")",
":",
"raise",
"TypeError",
"(",
"\"Unsupported type in conversion from Arrow: \"",
"+",
"str",
"(",
"at",
")",
")",
"spark_type",
"=",
"ArrayType",
"(",
"from_arrow_type",
"(",
"at",
".",
"value_type",
")",
")",
"elif",
"types",
".",
"is_struct",
"(",
"at",
")",
":",
"if",
"any",
"(",
"types",
".",
"is_struct",
"(",
"field",
".",
"type",
")",
"for",
"field",
"in",
"at",
")",
":",
"raise",
"TypeError",
"(",
"\"Nested StructType not supported in conversion from Arrow: \"",
"+",
"str",
"(",
"at",
")",
")",
"return",
"StructType",
"(",
"[",
"StructField",
"(",
"field",
".",
"name",
",",
"from_arrow_type",
"(",
"field",
".",
"type",
")",
",",
"nullable",
"=",
"field",
".",
"nullable",
")",
"for",
"field",
"in",
"at",
"]",
")",
"else",
":",
"raise",
"TypeError",
"(",
"\"Unsupported type in conversion from Arrow: \"",
"+",
"str",
"(",
"at",
")",
")",
"return",
"spark_type"
] | "Convert pyarrow type to Spark data type." | [
"Convert",
"pyarrow",
"type",
"to",
"Spark",
"data",
"type",
"."
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1634-L1674" |
"apache/spark" | "python/pyspark/sql/types.py" | "from_arrow_schema" | "def from_arrow_schema(arrow_schema):
""" Convert schema from Arrow to Spark.
"""
return StructType(
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
for field in arrow_schema])" | "python" | "def from_arrow_schema(arrow_schema):
""" Convert schema from Arrow to Spark.
"""
return StructType(
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
for field in arrow_schema])" | [
"def",
"from_arrow_schema",
"(",
"arrow_schema",
")",
":",
"return",
"StructType",
"(",
"[",
"StructField",
"(",
"field",
".",
"name",
",",
"from_arrow_type",
"(",
"field",
".",
"type",
")",
",",
"nullable",
"=",
"field",
".",
"nullable",
")",
"for",
"field",
"in",
"arrow_schema",
"]",
")"
] | "Convert schema from Arrow to Spark." | [
"Convert",
"schema",
"from",
"Arrow",
"to",
"Spark",
"."
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1677-L1682" |
"apache/spark" | "python/pyspark/sql/types.py" | "_check_series_localize_timestamps" | "def _check_series_localize_timestamps(s, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
If the input series is not a timestamp series, then the same series is returned. If the input
series is a timestamp series, then a converted series is returned.
:param s: pandas.Series
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.Series that have been converted to tz-naive
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
from pandas.api.types import is_datetime64tz_dtype
tz = timezone or _get_local_timezone()
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert(tz).dt.tz_localize(None)
else:
return s" | "python" | "def _check_series_localize_timestamps(s, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
If the input series is not a timestamp series, then the same series is returned. If the input
series is a timestamp series, then a converted series is returned.
:param s: pandas.Series
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.Series that have been converted to tz-naive
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
from pandas.api.types import is_datetime64tz_dtype
tz = timezone or _get_local_timezone()
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert(tz).dt.tz_localize(None)
else:
return s" | [
"def",
"_check_series_localize_timestamps",
"(",
"s",
",",
"timezone",
")",
":",
"from",
"pyspark",
".",
"sql",
".",
"utils",
"import",
"require_minimum_pandas_version",
"require_minimum_pandas_version",
"(",
")",
"from",
"pandas",
".",
"api",
".",
"types",
"import",
"is_datetime64tz_dtype",
"tz",
"=",
"timezone",
"or",
"_get_local_timezone",
"(",
")",
"# TODO: handle nested timestamps, such as ArrayType(TimestampType())?",
"if",
"is_datetime64tz_dtype",
"(",
"s",
".",
"dtype",
")",
":",
"return",
"s",
".",
"dt",
".",
"tz_convert",
"(",
"tz",
")",
".",
"dt",
".",
"tz_localize",
"(",
"None",
")",
"else",
":",
"return",
"s"
] | "Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
If the input series is not a timestamp series, then the same series is returned. If the input
series is a timestamp series, then a converted series is returned.
:param s: pandas.Series
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.Series that have been converted to tz-naive" | [
"Convert",
"timezone",
"aware",
"timestamps",
"to",
"timezone",
"-",
"naive",
"in",
"the",
"specified",
"timezone",
"or",
"local",
"timezone",
"."
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1700-L1720" |
"apache/spark" | "python/pyspark/sql/types.py" | "_check_dataframe_localize_timestamps" | "def _check_dataframe_localize_timestamps(pdf, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
:param pdf: pandas.DataFrame
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.DataFrame where any timezone aware columns have been converted to tz-naive
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
for column, series in pdf.iteritems():
pdf[column] = _check_series_localize_timestamps(series, timezone)
return pdf" | "python" | "def _check_dataframe_localize_timestamps(pdf, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
:param pdf: pandas.DataFrame
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.DataFrame where any timezone aware columns have been converted to tz-naive
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
for column, series in pdf.iteritems():
pdf[column] = _check_series_localize_timestamps(series, timezone)
return pdf" | [
"def",
"_check_dataframe_localize_timestamps",
"(",
"pdf",
",",
"timezone",
")",
":",
"from",
"pyspark",
".",
"sql",
".",
"utils",
"import",
"require_minimum_pandas_version",
"require_minimum_pandas_version",
"(",
")",
"for",
"column",
",",
"series",
"in",
"pdf",
".",
"iteritems",
"(",
")",
":",
"pdf",
"[",
"column",
"]",
"=",
"_check_series_localize_timestamps",
"(",
"series",
",",
"timezone",
")",
"return",
"pdf"
] | "Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
:param pdf: pandas.DataFrame
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.DataFrame where any timezone aware columns have been converted to tz-naive" | [
"Convert",
"timezone",
"aware",
"timestamps",
"to",
"timezone",
"-",
"naive",
"in",
"the",
"specified",
"timezone",
"or",
"local",
"timezone"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1723-L1736" |
"apache/spark" | "python/pyspark/sql/types.py" | "_check_series_convert_timestamps_internal" | "def _check_series_convert_timestamps_internal(s, timezone):
"""
Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for
Spark internal storage
:param s: a pandas.Series
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64_dtype(s.dtype):
# When localizing a tz-naive timestamp with tz_localize, the result is ambiguous if the
# timestamp falls in the hour when the clock is adjusted backward due to
# daylight saving time (dst).
# E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to
# 2015-11-01 1:00 from dst-time to standard time, and therefore, when localizing
# a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either
# dst time (2015-11-01 1:30-0400) or standard time (2015-11-01 1:30-0500).
#
# Here we explicitly choose to use standard time. This matches the default behavior of
# pytz.
#
# Here is some code to help understand this behavior:
# >>> import datetime
# >>> import pandas as pd
# >>> import pytz
# >>>
# >>> t = datetime.datetime(2015, 11, 1, 1, 30)
# >>> ts = pd.Series([t])
# >>> tz = pytz.timezone('America/New_York')
# >>>
# >>> ts.dt.tz_localize(tz, ambiguous=True)
# 0 2015-11-01 01:30:00-04:00
# dtype: datetime64[ns, America/New_York]
# >>>
# >>> ts.dt.tz_localize(tz, ambiguous=False)
# 0 2015-11-01 01:30:00-05:00
# dtype: datetime64[ns, America/New_York]
# >>>
# >>> str(tz.localize(t))
# '2015-11-01 01:30:00-05:00'
tz = timezone or _get_local_timezone()
return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC')
elif is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert('UTC')
else:
return s" | "python" | "def _check_series_convert_timestamps_internal(s, timezone):
"""
Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for
Spark internal storage
:param s: a pandas.Series
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64_dtype(s.dtype):
# When localizing a tz-naive timestamp with tz_localize, the result is ambiguous if the
# timestamp falls in the hour when the clock is adjusted backward due to
# daylight saving time (dst).
# E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to
# 2015-11-01 1:00 from dst-time to standard time, and therefore, when localizing
# a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either
# dst time (2015-11-01 1:30-0400) or standard time (2015-11-01 1:30-0500).
#
# Here we explicitly choose to use standard time. This matches the default behavior of
# pytz.
#
# Here is some code to help understand this behavior:
# >>> import datetime
# >>> import pandas as pd
# >>> import pytz
# >>>
# >>> t = datetime.datetime(2015, 11, 1, 1, 30)
# >>> ts = pd.Series([t])
# >>> tz = pytz.timezone('America/New_York')
# >>>
# >>> ts.dt.tz_localize(tz, ambiguous=True)
# 0 2015-11-01 01:30:00-04:00
# dtype: datetime64[ns, America/New_York]
# >>>
# >>> ts.dt.tz_localize(tz, ambiguous=False)
# 0 2015-11-01 01:30:00-05:00
# dtype: datetime64[ns, America/New_York]
# >>>
# >>> str(tz.localize(t))
# '2015-11-01 01:30:00-05:00'
tz = timezone or _get_local_timezone()
return s.dt.tz_localize(tz, ambiguous=False).dt.tz_convert('UTC')
elif is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert('UTC')
else:
return s" | [
"def",
"_check_series_convert_timestamps_internal",
"(",
"s",
",",
"timezone",
")",
":",
"from",
"pyspark",
".",
"sql",
".",
"utils",
"import",
"require_minimum_pandas_version",
"require_minimum_pandas_version",
"(",
")",
"from",
"pandas",
".",
"api",
".",
"types",
"import",
"is_datetime64_dtype",
",",
"is_datetime64tz_dtype",
"# TODO: handle nested timestamps, such as ArrayType(TimestampType())?",
"if",
"is_datetime64_dtype",
"(",
"s",
".",
"dtype",
")",
":",
"# When localizing a tz-naive timestamp with tz_localize, the result is ambiguous if the",
"# timestamp falls in the hour when the clock is adjusted backward due to",
"# daylight saving time (dst).",
"# E.g., for America/New_York, the clock is adjusted backward on 2015-11-01 2:00 to",
"# 2015-11-01 1:00 from dst-time to standard time, and therefore, when localizing",
"# a tz-naive timestamp 2015-11-01 1:30 with America/New_York timezone, it can be either",
"# dst time (2015-11-01 1:30-0400) or standard time (2015-11-01 1:30-0500).",
"#",
"# Here we explicitly choose to use standard time. This matches the default behavior of",
"# pytz.",
"#",
"# Here is some code to help understand this behavior:",
"# >>> import datetime",
"# >>> import pandas as pd",
"# >>> import pytz",
"# >>>",
"# >>> t = datetime.datetime(2015, 11, 1, 1, 30)",
"# >>> ts = pd.Series([t])",
"# >>> tz = pytz.timezone('America/New_York')",
"# >>>",
"# >>> ts.dt.tz_localize(tz, ambiguous=True)",
"# 0 2015-11-01 01:30:00-04:00",
"# dtype: datetime64[ns, America/New_York]",
"# >>>",
"# >>> ts.dt.tz_localize(tz, ambiguous=False)",
"# 0 2015-11-01 01:30:00-05:00",
"# dtype: datetime64[ns, America/New_York]",
"# >>>",
"# >>> str(tz.localize(t))",
"# '2015-11-01 01:30:00-05:00'",
"tz",
"=",
"timezone",
"or",
"_get_local_timezone",
"(",
")",
"return",
"s",
".",
"dt",
".",
"tz_localize",
"(",
"tz",
",",
"ambiguous",
"=",
"False",
")",
".",
"dt",
".",
"tz_convert",
"(",
"'UTC'",
")",
"elif",
"is_datetime64tz_dtype",
"(",
"s",
".",
"dtype",
")",
":",
"return",
"s",
".",
"dt",
".",
"tz_convert",
"(",
"'UTC'",
")",
"else",
":",
"return",
"s"
] | "Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for
Spark internal storage
:param s: a pandas.Series
:param timezone: the timezone to convert. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone" | [
"Convert",
"a",
"tz",
"-",
"naive",
"timestamp",
"in",
"the",
"specified",
"timezone",
"or",
"local",
"timezone",
"to",
"UTC",
"normalized",
"for",
"Spark",
"internal",
"storage"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1739-L1789" |
"apache/spark" | "python/pyspark/sql/types.py" | "_check_series_convert_timestamps_localize" | "def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
"""
Convert timestamp to timezone-naive in the specified timezone or local timezone
:param s: a pandas.Series
:param from_timezone: the timezone to convert from. if None then use local timezone
:param to_timezone: the timezone to convert to. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
import pandas as pd
from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
from_tz = from_timezone or _get_local_timezone()
to_tz = to_timezone or _get_local_timezone()
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert(to_tz).dt.tz_localize(None)
elif is_datetime64_dtype(s.dtype) and from_tz != to_tz:
# `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT.
return s.apply(
lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None)
if ts is not pd.NaT else pd.NaT)
else:
return s" | "python" | "def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
"""
Convert timestamp to timezone-naive in the specified timezone or local timezone
:param s: a pandas.Series
:param from_timezone: the timezone to convert from. if None then use local timezone
:param to_timezone: the timezone to convert to. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
"""
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
import pandas as pd
from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype
from_tz = from_timezone or _get_local_timezone()
to_tz = to_timezone or _get_local_timezone()
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if is_datetime64tz_dtype(s.dtype):
return s.dt.tz_convert(to_tz).dt.tz_localize(None)
elif is_datetime64_dtype(s.dtype) and from_tz != to_tz:
# `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT.
return s.apply(
lambda ts: ts.tz_localize(from_tz, ambiguous=False).tz_convert(to_tz).tz_localize(None)
if ts is not pd.NaT else pd.NaT)
else:
return s" | [
"def",
"_check_series_convert_timestamps_localize",
"(",
"s",
",",
"from_timezone",
",",
"to_timezone",
")",
":",
"from",
"pyspark",
".",
"sql",
".",
"utils",
"import",
"require_minimum_pandas_version",
"require_minimum_pandas_version",
"(",
")",
"import",
"pandas",
"as",
"pd",
"from",
"pandas",
".",
"api",
".",
"types",
"import",
"is_datetime64tz_dtype",
",",
"is_datetime64_dtype",
"from_tz",
"=",
"from_timezone",
"or",
"_get_local_timezone",
"(",
")",
"to_tz",
"=",
"to_timezone",
"or",
"_get_local_timezone",
"(",
")",
"# TODO: handle nested timestamps, such as ArrayType(TimestampType())?",
"if",
"is_datetime64tz_dtype",
"(",
"s",
".",
"dtype",
")",
":",
"return",
"s",
".",
"dt",
".",
"tz_convert",
"(",
"to_tz",
")",
".",
"dt",
".",
"tz_localize",
"(",
"None",
")",
"elif",
"is_datetime64_dtype",
"(",
"s",
".",
"dtype",
")",
"and",
"from_tz",
"!=",
"to_tz",
":",
"# `s.dt.tz_localize('tzlocal()')` doesn't work properly when including NaT.",
"return",
"s",
".",
"apply",
"(",
"lambda",
"ts",
":",
"ts",
".",
"tz_localize",
"(",
"from_tz",
",",
"ambiguous",
"=",
"False",
")",
".",
"tz_convert",
"(",
"to_tz",
")",
".",
"tz_localize",
"(",
"None",
")",
"if",
"ts",
"is",
"not",
"pd",
".",
"NaT",
"else",
"pd",
".",
"NaT",
")",
"else",
":",
"return",
"s"
] | "Convert timestamp to timezone-naive in the specified timezone or local timezone
:param s: a pandas.Series
:param from_timezone: the timezone to convert from. if None then use local timezone
:param to_timezone: the timezone to convert to. if None then use local timezone
:return pandas.Series where if it is a timestamp, has been converted to tz-naive" | [
"Convert",
"timestamp",
"to",
"timezone",
"-",
"naive",
"in",
"the",
"specified",
"timezone",
"or",
"local",
"timezone"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1792-L1817" |
"apache/spark" | "python/pyspark/sql/types.py" | "StructType.add" | "def add(self, field, data_type=None, nullable=True, metadata=None):
"""
Construct a StructType by adding new elements to it to define the schema. The method accepts
either:
a) A single parameter which is a StructField object.
b) Between 2 and 4 parameters as (name, data_type, nullable (optional),
metadata (optional)). The data_type parameter may be either a String or a
DataType object.
>>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
>>> struct2 = StructType([StructField("f1", StringType(), True), \\
... StructField("f2", StringType(), True, None)])
>>> struct1 == struct2
True
>>> struct1 = StructType().add(StructField("f1", StringType(), True))
>>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
>>> struct1 = StructType().add("f1", "string", True)
>>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
:param field: Either the name of the field or a StructField object
:param data_type: If present, the DataType of the StructField to create
:param nullable: Whether the field to add should be nullable (default True)
:param metadata: Any additional metadata (default None)
:return: a new updated StructType
"""
if isinstance(field, StructField):
self.fields.append(field)
self.names.append(field.name)
else:
if isinstance(field, str) and data_type is None:
raise ValueError("Must specify DataType if passing name of struct_field to create.")
if isinstance(data_type, str):
data_type_f = _parse_datatype_json_value(data_type)
else:
data_type_f = data_type
self.fields.append(StructField(field, data_type_f, nullable, metadata))
self.names.append(field)
# Precalculated list of fields that need conversion with fromInternal/toInternal functions
self._needConversion = [f.needConversion() for f in self]
self._needSerializeAnyField = any(self._needConversion)
return self" | "python" | "def add(self, field, data_type=None, nullable=True, metadata=None):
"""
Construct a StructType by adding new elements to it to define the schema. The method accepts
either:
a) A single parameter which is a StructField object.
b) Between 2 and 4 parameters as (name, data_type, nullable (optional),
metadata (optional)). The data_type parameter may be either a String or a
DataType object.
>>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
>>> struct2 = StructType([StructField("f1", StringType(), True), \\
... StructField("f2", StringType(), True, None)])
>>> struct1 == struct2
True
>>> struct1 = StructType().add(StructField("f1", StringType(), True))
>>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
>>> struct1 = StructType().add("f1", "string", True)
>>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
:param field: Either the name of the field or a StructField object
:param data_type: If present, the DataType of the StructField to create
:param nullable: Whether the field to add should be nullable (default True)
:param metadata: Any additional metadata (default None)
:return: a new updated StructType
"""
if isinstance(field, StructField):
self.fields.append(field)
self.names.append(field.name)
else:
if isinstance(field, str) and data_type is None:
raise ValueError("Must specify DataType if passing name of struct_field to create.")
if isinstance(data_type, str):
data_type_f = _parse_datatype_json_value(data_type)
else:
data_type_f = data_type
self.fields.append(StructField(field, data_type_f, nullable, metadata))
self.names.append(field)
# Precalculated list of fields that need conversion with fromInternal/toInternal functions
self._needConversion = [f.needConversion() for f in self]
self._needSerializeAnyField = any(self._needConversion)
return self" | [
"def",
"add",
"(",
"self",
",",
"field",
",",
"data_type",
"=",
"None",
",",
"nullable",
"=",
"True",
",",
"metadata",
"=",
"None",
")",
":",
"if",
"isinstance",
"(",
"field",
",",
"StructField",
")",
":",
"self",
".",
"fields",
".",
"append",
"(",
"field",
")",
"self",
".",
"names",
".",
"append",
"(",
"field",
".",
"name",
")",
"else",
":",
"if",
"isinstance",
"(",
"field",
",",
"str",
")",
"and",
"data_type",
"is",
"None",
":",
"raise",
"ValueError",
"(",
"\"Must specify DataType if passing name of struct_field to create.\"",
")",
"if",
"isinstance",
"(",
"data_type",
",",
"str",
")",
":",
"data_type_f",
"=",
"_parse_datatype_json_value",
"(",
"data_type",
")",
"else",
":",
"data_type_f",
"=",
"data_type",
"self",
".",
"fields",
".",
"append",
"(",
"StructField",
"(",
"field",
",",
"data_type_f",
",",
"nullable",
",",
"metadata",
")",
")",
"self",
".",
"names",
".",
"append",
"(",
"field",
")",
"# Precalculated list of fields that need conversion with fromInternal/toInternal functions",
"self",
".",
"_needConversion",
"=",
"[",
"f",
".",
"needConversion",
"(",
")",
"for",
"f",
"in",
"self",
"]",
"self",
".",
"_needSerializeAnyField",
"=",
"any",
"(",
"self",
".",
"_needConversion",
")",
"return",
"self"
] | "Construct a StructType by adding new elements to it to define the schema. The method accepts
either:
a) A single parameter which is a StructField object.
b) Between 2 and 4 parameters as (name, data_type, nullable (optional),
metadata (optional)). The data_type parameter may be either a String or a
DataType object.
>>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
>>> struct2 = StructType([StructField("f1", StringType(), True), \\
... StructField("f2", StringType(), True, None)])
>>> struct1 == struct2
True
>>> struct1 = StructType().add(StructField("f1", StringType(), True))
>>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
>>> struct1 = StructType().add("f1", "string", True)
>>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
:param field: Either the name of the field or a StructField object
:param data_type: If present, the DataType of the StructField to create
:param nullable: Whether the field to add should be nullable (default True)
:param metadata: Any additional metadata (default None)
:return: a new updated StructType" | [
"Construct",
"a",
"StructType",
"by",
"adding",
"new",
"elements",
"to",
"it",
"to",
"define",
"the",
"schema",
".",
"The",
"method",
"accepts",
"either",
":"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L491-L537" |
"apache/spark" | "python/pyspark/sql/types.py" | "UserDefinedType._cachedSqlType" | "def _cachedSqlType(cls):
"""
Cache the sqlType() into class, because it's heavily used in `toInternal`.
"""
if not hasattr(cls, "_cached_sql_type"):
cls._cached_sql_type = cls.sqlType()
return cls._cached_sql_type" | "python" | "def _cachedSqlType(cls):
"""
Cache the sqlType() into class, because it's heavily used in `toInternal`.
"""
if not hasattr(cls, "_cached_sql_type"):
cls._cached_sql_type = cls.sqlType()
return cls._cached_sql_type" | [
"def",
"_cachedSqlType",
"(",
"cls",
")",
":",
"if",
"not",
"hasattr",
"(",
"cls",
",",
"\"_cached_sql_type\"",
")",
":",
"cls",
".",
"_cached_sql_type",
"=",
"cls",
".",
"sqlType",
"(",
")",
"return",
"cls",
".",
"_cached_sql_type"
] | "Cache the sqlType() into class, because it's heavily used in `toInternal`." | [
"Cache",
"the",
"sqlType",
"()",
"into",
"class",
"because",
"it",
"s",
"heavily",
"used",
"in",
"toInternal",
"."
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L675-L681" |
"apache/spark" | "python/pyspark/sql/types.py" | "Row.asDict" | "def asDict(self, recursive=False):
"""
Return as an dict
:param recursive: turns the nested Row as dict (default: False).
>>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}
True
>>> row = Row(key=1, value=Row(name='a', age=2))
>>> row.asDict() == {'key': 1, 'value': Row(age=2, name='a')}
True
>>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
True
"""
if not hasattr(self, "__fields__"):
raise TypeError("Cannot convert a Row class into dict")
if recursive:
def conv(obj):
if isinstance(obj, Row):
return obj.asDict(True)
elif isinstance(obj, list):
return [conv(o) for o in obj]
elif isinstance(obj, dict):
return dict((k, conv(v)) for k, v in obj.items())
else:
return obj
return dict(zip(self.__fields__, (conv(o) for o in self)))
else:
return dict(zip(self.__fields__, self))" | "python" | "def asDict(self, recursive=False):
"""
Return as an dict
:param recursive: turns the nested Row as dict (default: False).
>>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}
True
>>> row = Row(key=1, value=Row(name='a', age=2))
>>> row.asDict() == {'key': 1, 'value': Row(age=2, name='a')}
True
>>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
True
"""
if not hasattr(self, "__fields__"):
raise TypeError("Cannot convert a Row class into dict")
if recursive:
def conv(obj):
if isinstance(obj, Row):
return obj.asDict(True)
elif isinstance(obj, list):
return [conv(o) for o in obj]
elif isinstance(obj, dict):
return dict((k, conv(v)) for k, v in obj.items())
else:
return obj
return dict(zip(self.__fields__, (conv(o) for o in self)))
else:
return dict(zip(self.__fields__, self))" | [
"def",
"asDict",
"(",
"self",
",",
"recursive",
"=",
"False",
")",
":",
"if",
"not",
"hasattr",
"(",
"self",
",",
"\"__fields__\"",
")",
":",
"raise",
"TypeError",
"(",
"\"Cannot convert a Row class into dict\"",
")",
"if",
"recursive",
":",
"def",
"conv",
"(",
"obj",
")",
":",
"if",
"isinstance",
"(",
"obj",
",",
"Row",
")",
":",
"return",
"obj",
".",
"asDict",
"(",
"True",
")",
"elif",
"isinstance",
"(",
"obj",
",",
"list",
")",
":",
"return",
"[",
"conv",
"(",
"o",
")",
"for",
"o",
"in",
"obj",
"]",
"elif",
"isinstance",
"(",
"obj",
",",
"dict",
")",
":",
"return",
"dict",
"(",
"(",
"k",
",",
"conv",
"(",
"v",
")",
")",
"for",
"k",
",",
"v",
"in",
"obj",
".",
"items",
"(",
")",
")",
"else",
":",
"return",
"obj",
"return",
"dict",
"(",
"zip",
"(",
"self",
".",
"__fields__",
",",
"(",
"conv",
"(",
"o",
")",
"for",
"o",
"in",
"self",
")",
")",
")",
"else",
":",
"return",
"dict",
"(",
"zip",
"(",
"self",
".",
"__fields__",
",",
"self",
")",
")"
] | "Return as an dict
:param recursive: turns the nested Row as dict (default: False).
>>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}
True
>>> row = Row(key=1, value=Row(name='a', age=2))
>>> row.asDict() == {'key': 1, 'value': Row(age=2, name='a')}
True
>>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
True" | [
"Return",
"as",
"an",
"dict"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/sql/types.py#L1463-L1492" |
"apache/spark" | "python/pyspark/ml/regression.py" | "LinearRegressionModel.summary" | "def summary(self):
"""
Gets summary (e.g. residuals, mse, r-squared ) of model on
training set. An exception is thrown if
`trainingSummary is None`.
"""
if self.hasSummary:
return LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)" | "python" | "def summary(self):
"""
Gets summary (e.g. residuals, mse, r-squared ) of model on
training set. An exception is thrown if
`trainingSummary is None`.
"""
if self.hasSummary:
return LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)" | [
"def",
"summary",
"(",
"self",
")",
":",
"if",
"self",
".",
"hasSummary",
":",
"return",
"LinearRegressionTrainingSummary",
"(",
"super",
"(",
"LinearRegressionModel",
",",
"self",
")",
".",
"summary",
")",
"else",
":",
"raise",
"RuntimeError",
"(",
"\"No training summary available for this %s\"",
"%",
"self",
".",
"__class__",
".",
"__name__",
")"
] | "Gets summary (e.g. residuals, mse, r-squared ) of model on
training set. An exception is thrown if
`trainingSummary is None`." | [
"Gets",
"summary",
"(",
"e",
".",
"g",
".",
"residuals",
"mse",
"r",
"-",
"squared",
")",
"of",
"model",
"on",
"training",
"set",
".",
"An",
"exception",
"is",
"thrown",
"if",
"trainingSummary",
"is",
"None",
"."
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/ml/regression.py#L198-L208" |
"apache/spark" | "python/pyspark/ml/regression.py" | "LinearRegressionModel.evaluate" | "def evaluate(self, dataset):
"""
Evaluates the model on a test dataset.
:param dataset:
Test dataset to evaluate model on, where dataset is an
instance of :py:class:`pyspark.sql.DataFrame`
"""
if not isinstance(dataset, DataFrame):
raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
java_lr_summary = self._call_java("evaluate", dataset)
return LinearRegressionSummary(java_lr_summary)" | "python" | "def evaluate(self, dataset):
"""
Evaluates the model on a test dataset.
:param dataset:
Test dataset to evaluate model on, where dataset is an
instance of :py:class:`pyspark.sql.DataFrame`
"""
if not isinstance(dataset, DataFrame):
raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
java_lr_summary = self._call_java("evaluate", dataset)
return LinearRegressionSummary(java_lr_summary)" | [
"def",
"evaluate",
"(",
"self",
",",
"dataset",
")",
":",
"if",
"not",
"isinstance",
"(",
"dataset",
",",
"DataFrame",
")",
":",
"raise",
"ValueError",
"(",
"\"dataset must be a DataFrame but got %s.\"",
"%",
"type",
"(",
"dataset",
")",
")",
"java_lr_summary",
"=",
"self",
".",
"_call_java",
"(",
"\"evaluate\"",
",",
"dataset",
")",
"return",
"LinearRegressionSummary",
"(",
"java_lr_summary",
")"
] | "Evaluates the model on a test dataset.
:param dataset:
Test dataset to evaluate model on, where dataset is an
instance of :py:class:`pyspark.sql.DataFrame`" | [
"Evaluates",
"the",
"model",
"on",
"a",
"test",
"dataset",
"."
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/ml/regression.py#L211-L222" |
"apache/spark" | "python/pyspark/ml/regression.py" | "GeneralizedLinearRegressionModel.summary" | "def summary(self):
"""
Gets summary (e.g. residuals, deviance, pValues) of model on
training set. An exception is thrown if
`trainingSummary is None`.
"""
if self.hasSummary:
return GeneralizedLinearRegressionTrainingSummary(
super(GeneralizedLinearRegressionModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)" | "python" | "def summary(self):
"""
Gets summary (e.g. residuals, deviance, pValues) of model on
training set. An exception is thrown if
`trainingSummary is None`.
"""
if self.hasSummary:
return GeneralizedLinearRegressionTrainingSummary(
super(GeneralizedLinearRegressionModel, self).summary)
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)" | [
"def",
"summary",
"(",
"self",
")",
":",
"if",
"self",
".",
"hasSummary",
":",
"return",
"GeneralizedLinearRegressionTrainingSummary",
"(",
"super",
"(",
"GeneralizedLinearRegressionModel",
",",
"self",
")",
".",
"summary",
")",
"else",
":",
"raise",
"RuntimeError",
"(",
"\"No training summary available for this %s\"",
"%",
"self",
".",
"__class__",
".",
"__name__",
")"
] | "Gets summary (e.g. residuals, deviance, pValues) of model on
training set. An exception is thrown if
`trainingSummary is None`." | [
"Gets",
"summary",
"(",
"e",
".",
"g",
".",
"residuals",
"deviance",
"pValues",
")",
"of",
"model",
"on",
"training",
"set",
".",
"An",
"exception",
"is",
"thrown",
"if",
"trainingSummary",
"is",
"None",
"."
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/ml/regression.py#L1679-L1690" |
"apache/spark" | "python/pyspark/ml/regression.py" | "GeneralizedLinearRegressionModel.evaluate" | "def evaluate(self, dataset):
"""
Evaluates the model on a test dataset.
:param dataset:
Test dataset to evaluate model on, where dataset is an
instance of :py:class:`pyspark.sql.DataFrame`
"""
if not isinstance(dataset, DataFrame):
raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
java_glr_summary = self._call_java("evaluate", dataset)
return GeneralizedLinearRegressionSummary(java_glr_summary)" | "python" | "def evaluate(self, dataset):
"""
Evaluates the model on a test dataset.
:param dataset:
Test dataset to evaluate model on, where dataset is an
instance of :py:class:`pyspark.sql.DataFrame`
"""
if not isinstance(dataset, DataFrame):
raise ValueError("dataset must be a DataFrame but got %s." % type(dataset))
java_glr_summary = self._call_java("evaluate", dataset)
return GeneralizedLinearRegressionSummary(java_glr_summary)" | [
"def",
"evaluate",
"(",
"self",
",",
"dataset",
")",
":",
"if",
"not",
"isinstance",
"(",
"dataset",
",",
"DataFrame",
")",
":",
"raise",
"ValueError",
"(",
"\"dataset must be a DataFrame but got %s.\"",
"%",
"type",
"(",
"dataset",
")",
")",
"java_glr_summary",
"=",
"self",
".",
"_call_java",
"(",
"\"evaluate\"",
",",
"dataset",
")",
"return",
"GeneralizedLinearRegressionSummary",
"(",
"java_glr_summary",
")"
] | "Evaluates the model on a test dataset.
:param dataset:
Test dataset to evaluate model on, where dataset is an
instance of :py:class:`pyspark.sql.DataFrame`" | [
"Evaluates",
"the",
"model",
"on",
"a",
"test",
"dataset",
"."
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/ml/regression.py#L1693-L1704" |
"apache/spark" | "python/pyspark/shuffle.py" | "_get_local_dirs" | "def _get_local_dirs(sub):
""" Get all the directories """
path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp")
dirs = path.split(",")
if len(dirs) > 1:
# different order in different processes and instances
rnd = random.Random(os.getpid() + id(dirs))
random.shuffle(dirs, rnd.random)
return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]" | "python" | "def _get_local_dirs(sub):
""" Get all the directories """
path = os.environ.get("SPARK_LOCAL_DIRS", "/tmp")
dirs = path.split(",")
if len(dirs) > 1:
# different order in different processes and instances
rnd = random.Random(os.getpid() + id(dirs))
random.shuffle(dirs, rnd.random)
return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]" | [
"def",
"_get_local_dirs",
"(",
"sub",
")",
":",
"path",
"=",
"os",
".",
"environ",
".",
"get",
"(",
"\"SPARK_LOCAL_DIRS\"",
",",
"\"/tmp\"",
")",
"dirs",
"=",
"path",
".",
"split",
"(",
"\",\"",
")",
"if",
"len",
"(",
"dirs",
")",
">",
"1",
":",
"# different order in different processes and instances",
"rnd",
"=",
"random",
".",
"Random",
"(",
"os",
".",
"getpid",
"(",
")",
"+",
"id",
"(",
"dirs",
")",
")",
"random",
".",
"shuffle",
"(",
"dirs",
",",
"rnd",
".",
"random",
")",
"return",
"[",
"os",
".",
"path",
".",
"join",
"(",
"d",
",",
"\"python\"",
",",
"str",
"(",
"os",
".",
"getpid",
"(",
")",
")",
",",
"sub",
")",
"for",
"d",
"in",
"dirs",
"]"
] | "Get all the directories" | [
"Get",
"all",
"the",
"directories"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/shuffle.py#L71-L79" |
"apache/spark" | "python/pyspark/shuffle.py" | "ExternalMerger._get_spill_dir" | "def _get_spill_dir(self, n):
""" Choose one directory for spill by number n """
return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))" | "python" | "def _get_spill_dir(self, n):
""" Choose one directory for spill by number n """
return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))" | [
"def",
"_get_spill_dir",
"(",
"self",
",",
"n",
")",
":",
"return",
"os",
".",
"path",
".",
"join",
"(",
"self",
".",
"localdirs",
"[",
"n",
"%",
"len",
"(",
"self",
".",
"localdirs",
")",
"]",
",",
"str",
"(",
"n",
")",
")"
] | "Choose one directory for spill by number n" | [
"Choose",
"one",
"directory",
"for",
"spill",
"by",
"number",
"n"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/shuffle.py#L219-L221" |
"apache/spark" | "python/pyspark/shuffle.py" | "ExternalMerger.mergeValues" | "def mergeValues(self, iterator):
""" Combine the items by creator and combiner """
# speedup attribute lookup
creator, comb = self.agg.createCombiner, self.agg.mergeValue
c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch
limit = self.memory_limit
for k, v in iterator:
d = pdata[hfun(k)] if pdata else data
d[k] = comb(d[k], v) if k in d else creator(v)
c += 1
if c >= batch:
if get_used_memory() >= limit:
self._spill()
limit = self._next_limit()
batch /= 2
c = 0
else:
batch *= 1.5
if get_used_memory() >= limit:
self._spill()" | "python" | "def mergeValues(self, iterator):
""" Combine the items by creator and combiner """
# speedup attribute lookup
creator, comb = self.agg.createCombiner, self.agg.mergeValue
c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch
limit = self.memory_limit
for k, v in iterator:
d = pdata[hfun(k)] if pdata else data
d[k] = comb(d[k], v) if k in d else creator(v)
c += 1
if c >= batch:
if get_used_memory() >= limit:
self._spill()
limit = self._next_limit()
batch /= 2
c = 0
else:
batch *= 1.5
if get_used_memory() >= limit:
self._spill()" | [
"def",
"mergeValues",
"(",
"self",
",",
"iterator",
")",
":",
"# speedup attribute lookup",
"creator",
",",
"comb",
"=",
"self",
".",
"agg",
".",
"createCombiner",
",",
"self",
".",
"agg",
".",
"mergeValue",
"c",
",",
"data",
",",
"pdata",
",",
"hfun",
",",
"batch",
"=",
"0",
",",
"self",
".",
"data",
",",
"self",
".",
"pdata",
",",
"self",
".",
"_partition",
",",
"self",
".",
"batch",
"limit",
"=",
"self",
".",
"memory_limit",
"for",
"k",
",",
"v",
"in",
"iterator",
":",
"d",
"=",
"pdata",
"[",
"hfun",
"(",
"k",
")",
"]",
"if",
"pdata",
"else",
"data",
"d",
"[",
"k",
"]",
"=",
"comb",
"(",
"d",
"[",
"k",
"]",
",",
"v",
")",
"if",
"k",
"in",
"d",
"else",
"creator",
"(",
"v",
")",
"c",
"+=",
"1",
"if",
"c",
">=",
"batch",
":",
"if",
"get_used_memory",
"(",
")",
">=",
"limit",
":",
"self",
".",
"_spill",
"(",
")",
"limit",
"=",
"self",
".",
"_next_limit",
"(",
")",
"batch",
"/=",
"2",
"c",
"=",
"0",
"else",
":",
"batch",
"*=",
"1.5",
"if",
"get_used_memory",
"(",
")",
">=",
"limit",
":",
"self",
".",
"_spill",
"(",
")"
] | "Combine the items by creator and combiner" | [
"Combine",
"the",
"items",
"by",
"creator",
"and",
"combiner"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/shuffle.py#L231-L253" |
"apache/spark" | "python/pyspark/shuffle.py" | "ExternalMerger.mergeCombiners" | "def mergeCombiners(self, iterator, limit=None):
""" Merge (K,V) pair by mergeCombiner """
if limit is None:
limit = self.memory_limit
# speedup attribute lookup
comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size
c, data, pdata, batch = 0, self.data, self.pdata, self.batch
for k, v in iterator:
d = pdata[hfun(k)] if pdata else data
d[k] = comb(d[k], v) if k in d else v
if not limit:
continue
c += objsize(v)
if c > batch:
if get_used_memory() > limit:
self._spill()
limit = self._next_limit()
batch /= 2
c = 0
else:
batch *= 1.5
if limit and get_used_memory() >= limit:
self._spill()" | "python" | "def mergeCombiners(self, iterator, limit=None):
""" Merge (K,V) pair by mergeCombiner """
if limit is None:
limit = self.memory_limit
# speedup attribute lookup
comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size
c, data, pdata, batch = 0, self.data, self.pdata, self.batch
for k, v in iterator:
d = pdata[hfun(k)] if pdata else data
d[k] = comb(d[k], v) if k in d else v
if not limit:
continue
c += objsize(v)
if c > batch:
if get_used_memory() > limit:
self._spill()
limit = self._next_limit()
batch /= 2
c = 0
else:
batch *= 1.5
if limit and get_used_memory() >= limit:
self._spill()" | [
"def",
"mergeCombiners",
"(",
"self",
",",
"iterator",
",",
"limit",
"=",
"None",
")",
":",
"if",
"limit",
"is",
"None",
":",
"limit",
"=",
"self",
".",
"memory_limit",
"# speedup attribute lookup",
"comb",
",",
"hfun",
",",
"objsize",
"=",
"self",
".",
"agg",
".",
"mergeCombiners",
",",
"self",
".",
"_partition",
",",
"self",
".",
"_object_size",
"c",
",",
"data",
",",
"pdata",
",",
"batch",
"=",
"0",
",",
"self",
".",
"data",
",",
"self",
".",
"pdata",
",",
"self",
".",
"batch",
"for",
"k",
",",
"v",
"in",
"iterator",
":",
"d",
"=",
"pdata",
"[",
"hfun",
"(",
"k",
")",
"]",
"if",
"pdata",
"else",
"data",
"d",
"[",
"k",
"]",
"=",
"comb",
"(",
"d",
"[",
"k",
"]",
",",
"v",
")",
"if",
"k",
"in",
"d",
"else",
"v",
"if",
"not",
"limit",
":",
"continue",
"c",
"+=",
"objsize",
"(",
"v",
")",
"if",
"c",
">",
"batch",
":",
"if",
"get_used_memory",
"(",
")",
">",
"limit",
":",
"self",
".",
"_spill",
"(",
")",
"limit",
"=",
"self",
".",
"_next_limit",
"(",
")",
"batch",
"/=",
"2",
"c",
"=",
"0",
"else",
":",
"batch",
"*=",
"1.5",
"if",
"limit",
"and",
"get_used_memory",
"(",
")",
">=",
"limit",
":",
"self",
".",
"_spill",
"(",
")"
] | "Merge (K,V) pair by mergeCombiner" | [
"Merge",
"(",
"K",
"V",
")",
"pair",
"by",
"mergeCombiner"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/shuffle.py#L265-L289" |
"apache/spark" | "python/pyspark/shuffle.py" | "ExternalMerger._spill" | "def _spill(self):
"""
dump already partitioned data into disks.
It will dump the data in batch for better performance.
"""
global MemoryBytesSpilled, DiskBytesSpilled
path = self._get_spill_dir(self.spills)
if not os.path.exists(path):
os.makedirs(path)
used_memory = get_used_memory()
if not self.pdata:
# The data has not been partitioned, it will iterator the
# dataset once, write them into different files, has no
# additional memory. It only called when the memory goes
# above limit at the first time.
# open all the files for writing
streams = [open(os.path.join(path, str(i)), 'wb')
for i in range(self.partitions)]
for k, v in self.data.items():
h = self._partition(k)
# put one item in batch, make it compatible with load_stream
# it will increase the memory if dump them in batch
self.serializer.dump_stream([(k, v)], streams[h])
for s in streams:
DiskBytesSpilled += s.tell()
s.close()
self.data.clear()
self.pdata.extend([{} for i in range(self.partitions)])
else:
for i in range(self.partitions):
p = os.path.join(path, str(i))
with open(p, "wb") as f:
# dump items in batch
self.serializer.dump_stream(iter(self.pdata[i].items()), f)
self.pdata[i].clear()
DiskBytesSpilled += os.path.getsize(p)
self.spills += 1
gc.collect() # release the memory as much as possible
MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20" | "python" | "def _spill(self):
"""
dump already partitioned data into disks.
It will dump the data in batch for better performance.
"""
global MemoryBytesSpilled, DiskBytesSpilled
path = self._get_spill_dir(self.spills)
if not os.path.exists(path):
os.makedirs(path)
used_memory = get_used_memory()
if not self.pdata:
# The data has not been partitioned, it will iterator the
# dataset once, write them into different files, has no
# additional memory. It only called when the memory goes
# above limit at the first time.
# open all the files for writing
streams = [open(os.path.join(path, str(i)), 'wb')
for i in range(self.partitions)]
for k, v in self.data.items():
h = self._partition(k)
# put one item in batch, make it compatible with load_stream
# it will increase the memory if dump them in batch
self.serializer.dump_stream([(k, v)], streams[h])
for s in streams:
DiskBytesSpilled += s.tell()
s.close()
self.data.clear()
self.pdata.extend([{} for i in range(self.partitions)])
else:
for i in range(self.partitions):
p = os.path.join(path, str(i))
with open(p, "wb") as f:
# dump items in batch
self.serializer.dump_stream(iter(self.pdata[i].items()), f)
self.pdata[i].clear()
DiskBytesSpilled += os.path.getsize(p)
self.spills += 1
gc.collect() # release the memory as much as possible
MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20" | [
"def",
"_spill",
"(",
"self",
")",
":",
"global",
"MemoryBytesSpilled",
",",
"DiskBytesSpilled",
"path",
"=",
"self",
".",
"_get_spill_dir",
"(",
"self",
".",
"spills",
")",
"if",
"not",
"os",
".",
"path",
".",
"exists",
"(",
"path",
")",
":",
"os",
".",
"makedirs",
"(",
"path",
")",
"used_memory",
"=",
"get_used_memory",
"(",
")",
"if",
"not",
"self",
".",
"pdata",
":",
"# The data has not been partitioned, it will iterator the",
"# dataset once, write them into different files, has no",
"# additional memory. It only called when the memory goes",
"# above limit at the first time.",
"# open all the files for writing",
"streams",
"=",
"[",
"open",
"(",
"os",
".",
"path",
".",
"join",
"(",
"path",
",",
"str",
"(",
"i",
")",
")",
",",
"'wb'",
")",
"for",
"i",
"in",
"range",
"(",
"self",
".",
"partitions",
")",
"]",
"for",
"k",
",",
"v",
"in",
"self",
".",
"data",
".",
"items",
"(",
")",
":",
"h",
"=",
"self",
".",
"_partition",
"(",
"k",
")",
"# put one item in batch, make it compatible with load_stream",
"# it will increase the memory if dump them in batch",
"self",
".",
"serializer",
".",
"dump_stream",
"(",
"[",
"(",
"k",
",",
"v",
")",
"]",
",",
"streams",
"[",
"h",
"]",
")",
"for",
"s",
"in",
"streams",
":",
"DiskBytesSpilled",
"+=",
"s",
".",
"tell",
"(",
")",
"s",
".",
"close",
"(",
")",
"self",
".",
"data",
".",
"clear",
"(",
")",
"self",
".",
"pdata",
".",
"extend",
"(",
"[",
"{",
"}",
"for",
"i",
"in",
"range",
"(",
"self",
".",
"partitions",
")",
"]",
")",
"else",
":",
"for",
"i",
"in",
"range",
"(",
"self",
".",
"partitions",
")",
":",
"p",
"=",
"os",
".",
"path",
".",
"join",
"(",
"path",
",",
"str",
"(",
"i",
")",
")",
"with",
"open",
"(",
"p",
",",
"\"wb\"",
")",
"as",
"f",
":",
"# dump items in batch",
"self",
".",
"serializer",
".",
"dump_stream",
"(",
"iter",
"(",
"self",
".",
"pdata",
"[",
"i",
"]",
".",
"items",
"(",
")",
")",
",",
"f",
")",
"self",
".",
"pdata",
"[",
"i",
"]",
".",
"clear",
"(",
")",
"DiskBytesSpilled",
"+=",
"os",
".",
"path",
".",
"getsize",
"(",
"p",
")",
"self",
".",
"spills",
"+=",
"1",
"gc",
".",
"collect",
"(",
")",
"# release the memory as much as possible",
"MemoryBytesSpilled",
"+=",
"max",
"(",
"used_memory",
"-",
"get_used_memory",
"(",
")",
",",
"0",
")",
"<<",
"20"
] | "dump already partitioned data into disks.
It will dump the data in batch for better performance." | [
"dump",
"already",
"partitioned",
"data",
"into",
"disks",
"."
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/shuffle.py#L291-L337" |
"apache/spark" | "python/pyspark/shuffle.py" | "ExternalMerger.items" | "def items(self):
""" Return all merged items as iterator """
if not self.pdata and not self.spills:
return iter(self.data.items())
return self._external_items()" | "python" | "def items(self):
""" Return all merged items as iterator """
if not self.pdata and not self.spills:
return iter(self.data.items())
return self._external_items()" | [
"def",
"items",
"(",
"self",
")",
":",
"if",
"not",
"self",
".",
"pdata",
"and",
"not",
"self",
".",
"spills",
":",
"return",
"iter",
"(",
"self",
".",
"data",
".",
"items",
"(",
")",
")",
"return",
"self",
".",
"_external_items",
"(",
")"
] | "Return all merged items as iterator" | [
"Return",
"all",
"merged",
"items",
"as",
"iterator"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/shuffle.py#L339-L343" |
"apache/spark" | "python/pyspark/shuffle.py" | "ExternalMerger._external_items" | "def _external_items(self):
""" Return all partitioned items as iterator """
assert not self.data
if any(self.pdata):
self._spill()
# disable partitioning and spilling when merge combiners from disk
self.pdata = []
try:
for i in range(self.partitions):
for v in self._merged_items(i):
yield v
self.data.clear()
# remove the merged partition
for j in range(self.spills):
path = self._get_spill_dir(j)
os.remove(os.path.join(path, str(i)))
finally:
self._cleanup()" | "python" | "def _external_items(self):
""" Return all partitioned items as iterator """
assert not self.data
if any(self.pdata):
self._spill()
# disable partitioning and spilling when merge combiners from disk
self.pdata = []
try:
for i in range(self.partitions):
for v in self._merged_items(i):
yield v
self.data.clear()
# remove the merged partition
for j in range(self.spills):
path = self._get_spill_dir(j)
os.remove(os.path.join(path, str(i)))
finally:
self._cleanup()" | [
"def",
"_external_items",
"(",
"self",
")",
":",
"assert",
"not",
"self",
".",
"data",
"if",
"any",
"(",
"self",
".",
"pdata",
")",
":",
"self",
".",
"_spill",
"(",
")",
"# disable partitioning and spilling when merge combiners from disk",
"self",
".",
"pdata",
"=",
"[",
"]",
"try",
":",
"for",
"i",
"in",
"range",
"(",
"self",
".",
"partitions",
")",
":",
"for",
"v",
"in",
"self",
".",
"_merged_items",
"(",
"i",
")",
":",
"yield",
"v",
"self",
".",
"data",
".",
"clear",
"(",
")",
"# remove the merged partition",
"for",
"j",
"in",
"range",
"(",
"self",
".",
"spills",
")",
":",
"path",
"=",
"self",
".",
"_get_spill_dir",
"(",
"j",
")",
"os",
".",
"remove",
"(",
"os",
".",
"path",
".",
"join",
"(",
"path",
",",
"str",
"(",
"i",
")",
")",
")",
"finally",
":",
"self",
".",
"_cleanup",
"(",
")"
] | "Return all partitioned items as iterator" | [
"Return",
"all",
"partitioned",
"items",
"as",
"iterator"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/shuffle.py#L345-L364" |
"apache/spark" | "python/pyspark/shuffle.py" | "ExternalMerger._recursive_merged_items" | "def _recursive_merged_items(self, index):
"""
merge the partitioned items and return the as iterator
If one partition can not be fit in memory, then them will be
partitioned and merged recursively.
"""
subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs]
m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs,
self.scale * self.partitions, self.partitions, self.batch)
m.pdata = [{} for _ in range(self.partitions)]
limit = self._next_limit()
for j in range(self.spills):
path = self._get_spill_dir(j)
p = os.path.join(path, str(index))
with open(p, 'rb') as f:
m.mergeCombiners(self.serializer.load_stream(f), 0)
if get_used_memory() > limit:
m._spill()
limit = self._next_limit()
return m._external_items()" | "python" | "def _recursive_merged_items(self, index):
"""
merge the partitioned items and return the as iterator
If one partition can not be fit in memory, then them will be
partitioned and merged recursively.
"""
subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs]
m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs,
self.scale * self.partitions, self.partitions, self.batch)
m.pdata = [{} for _ in range(self.partitions)]
limit = self._next_limit()
for j in range(self.spills):
path = self._get_spill_dir(j)
p = os.path.join(path, str(index))
with open(p, 'rb') as f:
m.mergeCombiners(self.serializer.load_stream(f), 0)
if get_used_memory() > limit:
m._spill()
limit = self._next_limit()
return m._external_items()" | [
"def",
"_recursive_merged_items",
"(",
"self",
",",
"index",
")",
":",
"subdirs",
"=",
"[",
"os",
".",
"path",
".",
"join",
"(",
"d",
",",
"\"parts\"",
",",
"str",
"(",
"index",
")",
")",
"for",
"d",
"in",
"self",
".",
"localdirs",
"]",
"m",
"=",
"ExternalMerger",
"(",
"self",
".",
"agg",
",",
"self",
".",
"memory_limit",
",",
"self",
".",
"serializer",
",",
"subdirs",
",",
"self",
".",
"scale",
"*",
"self",
".",
"partitions",
",",
"self",
".",
"partitions",
",",
"self",
".",
"batch",
")",
"m",
".",
"pdata",
"=",
"[",
"{",
"}",
"for",
"_",
"in",
"range",
"(",
"self",
".",
"partitions",
")",
"]",
"limit",
"=",
"self",
".",
"_next_limit",
"(",
")",
"for",
"j",
"in",
"range",
"(",
"self",
".",
"spills",
")",
":",
"path",
"=",
"self",
".",
"_get_spill_dir",
"(",
"j",
")",
"p",
"=",
"os",
".",
"path",
".",
"join",
"(",
"path",
",",
"str",
"(",
"index",
")",
")",
"with",
"open",
"(",
"p",
",",
"'rb'",
")",
"as",
"f",
":",
"m",
".",
"mergeCombiners",
"(",
"self",
".",
"serializer",
".",
"load_stream",
"(",
"f",
")",
",",
"0",
")",
"if",
"get_used_memory",
"(",
")",
">",
"limit",
":",
"m",
".",
"_spill",
"(",
")",
"limit",
"=",
"self",
".",
"_next_limit",
"(",
")",
"return",
"m",
".",
"_external_items",
"(",
")"
] | "merge the partitioned items and return the as iterator
If one partition can not be fit in memory, then them will be
partitioned and merged recursively." | [
"merge",
"the",
"partitioned",
"items",
"and",
"return",
"the",
"as",
"iterator"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/shuffle.py#L386-L409" |
"apache/spark" | "python/pyspark/shuffle.py" | "ExternalSorter._get_path" | "def _get_path(self, n):
""" Choose one directory for spill by number n """
d = self.local_dirs[n % len(self.local_dirs)]
if not os.path.exists(d):
os.makedirs(d)
return os.path.join(d, str(n))" | "python" | "def _get_path(self, n):
""" Choose one directory for spill by number n """
d = self.local_dirs[n % len(self.local_dirs)]
if not os.path.exists(d):
os.makedirs(d)
return os.path.join(d, str(n))" | [
"def",
"_get_path",
"(",
"self",
",",
"n",
")",
":",
"d",
"=",
"self",
".",
"local_dirs",
"[",
"n",
"%",
"len",
"(",
"self",
".",
"local_dirs",
")",
"]",
"if",
"not",
"os",
".",
"path",
".",
"exists",
"(",
"d",
")",
":",
"os",
".",
"makedirs",
"(",
"d",
")",
"return",
"os",
".",
"path",
".",
"join",
"(",
"d",
",",
"str",
"(",
"n",
")",
")"
] | "Choose one directory for spill by number n" | [
"Choose",
"one",
"directory",
"for",
"spill",
"by",
"number",
"n"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/shuffle.py#L440-L445" |
"apache/spark" | "python/pyspark/shuffle.py" | "ExternalSorter.sorted" | "def sorted(self, iterator, key=None, reverse=False):
"""
Sort the elements in iterator, do external sort when the memory
goes above the limit.
"""
global MemoryBytesSpilled, DiskBytesSpilled
batch, limit = 100, self._next_limit()
chunks, current_chunk = [], []
iterator = iter(iterator)
while True:
# pick elements in batch
chunk = list(itertools.islice(iterator, batch))
current_chunk.extend(chunk)
if len(chunk) < batch:
break
used_memory = get_used_memory()
if used_memory > limit:
# sort them inplace will save memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
with open(path, 'wb') as f:
self.serializer.dump_stream(current_chunk, f)
def load(f):
for v in self.serializer.load_stream(f):
yield v
# close the file explicit once we consume all the items
# to avoid ResourceWarning in Python3
f.close()
chunks.append(load(open(path, 'rb')))
current_chunk = []
MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20
DiskBytesSpilled += os.path.getsize(path)
os.unlink(path) # data will be deleted after close
elif not chunks:
batch = min(int(batch * 1.5), 10000)
current_chunk.sort(key=key, reverse=reverse)
if not chunks:
return current_chunk
if current_chunk:
chunks.append(iter(current_chunk))
return heapq.merge(chunks, key=key, reverse=reverse)" | "python" | "def sorted(self, iterator, key=None, reverse=False):
"""
Sort the elements in iterator, do external sort when the memory
goes above the limit.
"""
global MemoryBytesSpilled, DiskBytesSpilled
batch, limit = 100, self._next_limit()
chunks, current_chunk = [], []
iterator = iter(iterator)
while True:
# pick elements in batch
chunk = list(itertools.islice(iterator, batch))
current_chunk.extend(chunk)
if len(chunk) < batch:
break
used_memory = get_used_memory()
if used_memory > limit:
# sort them inplace will save memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
with open(path, 'wb') as f:
self.serializer.dump_stream(current_chunk, f)
def load(f):
for v in self.serializer.load_stream(f):
yield v
# close the file explicit once we consume all the items
# to avoid ResourceWarning in Python3
f.close()
chunks.append(load(open(path, 'rb')))
current_chunk = []
MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20
DiskBytesSpilled += os.path.getsize(path)
os.unlink(path) # data will be deleted after close
elif not chunks:
batch = min(int(batch * 1.5), 10000)
current_chunk.sort(key=key, reverse=reverse)
if not chunks:
return current_chunk
if current_chunk:
chunks.append(iter(current_chunk))
return heapq.merge(chunks, key=key, reverse=reverse)" | [
"def",
"sorted",
"(",
"self",
",",
"iterator",
",",
"key",
"=",
"None",
",",
"reverse",
"=",
"False",
")",
":",
"global",
"MemoryBytesSpilled",
",",
"DiskBytesSpilled",
"batch",
",",
"limit",
"=",
"100",
",",
"self",
".",
"_next_limit",
"(",
")",
"chunks",
",",
"current_chunk",
"=",
"[",
"]",
",",
"[",
"]",
"iterator",
"=",
"iter",
"(",
"iterator",
")",
"while",
"True",
":",
"# pick elements in batch",
"chunk",
"=",
"list",
"(",
"itertools",
".",
"islice",
"(",
"iterator",
",",
"batch",
")",
")",
"current_chunk",
".",
"extend",
"(",
"chunk",
")",
"if",
"len",
"(",
"chunk",
")",
"<",
"batch",
":",
"break",
"used_memory",
"=",
"get_used_memory",
"(",
")",
"if",
"used_memory",
">",
"limit",
":",
"# sort them inplace will save memory",
"current_chunk",
".",
"sort",
"(",
"key",
"=",
"key",
",",
"reverse",
"=",
"reverse",
")",
"path",
"=",
"self",
".",
"_get_path",
"(",
"len",
"(",
"chunks",
")",
")",
"with",
"open",
"(",
"path",
",",
"'wb'",
")",
"as",
"f",
":",
"self",
".",
"serializer",
".",
"dump_stream",
"(",
"current_chunk",
",",
"f",
")",
"def",
"load",
"(",
"f",
")",
":",
"for",
"v",
"in",
"self",
".",
"serializer",
".",
"load_stream",
"(",
"f",
")",
":",
"yield",
"v",
"# close the file explicit once we consume all the items",
"# to avoid ResourceWarning in Python3",
"f",
".",
"close",
"(",
")",
"chunks",
".",
"append",
"(",
"load",
"(",
"open",
"(",
"path",
",",
"'rb'",
")",
")",
")",
"current_chunk",
"=",
"[",
"]",
"MemoryBytesSpilled",
"+=",
"max",
"(",
"used_memory",
"-",
"get_used_memory",
"(",
")",
",",
"0",
")",
"<<",
"20",
"DiskBytesSpilled",
"+=",
"os",
".",
"path",
".",
"getsize",
"(",
"path",
")",
"os",
".",
"unlink",
"(",
"path",
")",
"# data will be deleted after close",
"elif",
"not",
"chunks",
":",
"batch",
"=",
"min",
"(",
"int",
"(",
"batch",
"*",
"1.5",
")",
",",
"10000",
")",
"current_chunk",
".",
"sort",
"(",
"key",
"=",
"key",
",",
"reverse",
"=",
"reverse",
")",
"if",
"not",
"chunks",
":",
"return",
"current_chunk",
"if",
"current_chunk",
":",
"chunks",
".",
"append",
"(",
"iter",
"(",
"current_chunk",
")",
")",
"return",
"heapq",
".",
"merge",
"(",
"chunks",
",",
"key",
"=",
"key",
",",
"reverse",
"=",
"reverse",
")"
] | "Sort the elements in iterator, do external sort when the memory
goes above the limit." | [
"Sort",
"the",
"elements",
"in",
"iterator",
"do",
"external",
"sort",
"when",
"the",
"memory",
"goes",
"above",
"the",
"limit",
"."
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/shuffle.py#L455-L501" |
"apache/spark" | "python/pyspark/shuffle.py" | "ExternalList._spill" | "def _spill(self):
""" dump the values into disk """
global MemoryBytesSpilled, DiskBytesSpilled
if self._file is None:
self._open_file()
used_memory = get_used_memory()
pos = self._file.tell()
self._ser.dump_stream(self.values, self._file)
self.values = []
gc.collect()
DiskBytesSpilled += self._file.tell() - pos
MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20" | "python" | "def _spill(self):
""" dump the values into disk """
global MemoryBytesSpilled, DiskBytesSpilled
if self._file is None:
self._open_file()
used_memory = get_used_memory()
pos = self._file.tell()
self._ser.dump_stream(self.values, self._file)
self.values = []
gc.collect()
DiskBytesSpilled += self._file.tell() - pos
MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20" | [
"def",
"_spill",
"(",
"self",
")",
":",
"global",
"MemoryBytesSpilled",
",",
"DiskBytesSpilled",
"if",
"self",
".",
"_file",
"is",
"None",
":",
"self",
".",
"_open_file",
"(",
")",
"used_memory",
"=",
"get_used_memory",
"(",
")",
"pos",
"=",
"self",
".",
"_file",
".",
"tell",
"(",
")",
"self",
".",
"_ser",
".",
"dump_stream",
"(",
"self",
".",
"values",
",",
"self",
".",
"_file",
")",
"self",
".",
"values",
"=",
"[",
"]",
"gc",
".",
"collect",
"(",
")",
"DiskBytesSpilled",
"+=",
"self",
".",
"_file",
".",
"tell",
"(",
")",
"-",
"pos",
"MemoryBytesSpilled",
"+=",
"max",
"(",
"used_memory",
"-",
"get_used_memory",
"(",
")",
",",
"0",
")",
"<<",
"20"
] | "dump the values into disk" | [
"dump",
"the",
"values",
"into",
"disk"
] | "train" | "https://github.com/apache/spark/blob/618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/shuffle.py#L590-L602" |
Dataset Card for CodeSearchNet corpus
Dataset Summary
CodeSearchNet corpus is a dataset of 2 million (comment, code) pairs from open-source libraries hosted on GitHub. It contains code and documentation for several programming languages.
CodeSearchNet corpus was gathered to support the CodeSearchNet challenge, to explore the problem of code retrieval using natural language.
Supported Tasks and Leaderboards
language-modeling
: The dataset can be used to train language models for programming languages.
Languages
- Go programming language
- Java programming language
- JavaScript programming language
- PHP programming language
- Python programming language
- Ruby programming language
Dataset Structure
Data Instances
A data point consists of a function's code along with its documentation. Each data point also contains metadata on the function, such as the repository it was extracted from.
{
'id': '0',
'repository_name': 'organisation/repository',
'func_path_in_repository': 'src/path/to/file.py',
'func_name': 'func',
'whole_func_string': 'def func(args):\n"""Docstring"""\n [...]',
'language': 'python',
'func_code_string': '[...]',
'func_code_tokens': ['def', 'func', '(', 'args', ')', ...],
'func_documentation_string': 'Docstring',
'func_documentation_string_tokens': ['Docstring'],
'split_name': 'train',
'func_code_url': 'https://github.com/<org>/<repo>/blob/<hash>/src/path/to/file.py#L111-L150'
}
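The `func_code_url` field packs several pieces of information into one string: the repository, the commit hash, the file path, and the line range of the function. A minimal sketch of how it could be split apart is shown below. It assumes every URL follows the GitHub permalink layout of the example record; the helper name `parse_func_code_url` is hypothetical and not part of the dataset or any library.

```python
import re

# Assumption: every func_code_url follows the GitHub "blob" permalink layout
# shown in the example record (org/repo/blob/<hash>/<path>#L<start>-L<end>).
_URL_RE = re.compile(
    r"^https://github\.com/(?P<org>[^/]+)/(?P<repo>[^/]+)/blob/"
    r"(?P<commit>[^/]+)/(?P<path>.+)#L(?P<start>\d+)-L(?P<end>\d+)$"
)


def parse_func_code_url(url):
    """Split a func_code_url into organisation, repo, commit, path and line range."""
    match = _URL_RE.match(url)
    if match is None:
        raise ValueError("unexpected func_code_url layout: " + url)
    parts = match.groupdict()
    # Line numbers are more useful as integers than as strings.
    parts["start"] = int(parts["start"])
    parts["end"] = int(parts["end"])
    return parts


# URL taken from one of the rows previewed above.
info = parse_func_code_url(
    "https://github.com/apache/spark/blob/"
    "618d6bff71073c8c93501ab7392c3cc579730f0b/python/pyspark/shuffle.py#L590-L602"
)
print(info["org"], info["repo"], info["path"], info["start"], info["end"])
```

Because the URL pins a commit hash, the extracted path and line range refer to the file as it was at collection time, not necessarily to the current state of the repository.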
Data Fields
id
: Arbitrary number
repository_name
: Name of the GitHub repository
func_path_in_repository
: Path to the file which holds the function in the repository
func_name
: Name of the function in the file
whole_func_string
: Code + documentation of the function
language
: Programming language in which the function is written
func_code_string
: Function code
func_code_tokens
: Tokens yielded by Treesitter
func_documentation_string
: Function documentation
func_documentation_string_tokens
: Tokens yielded by Treesitter
split_name
: Name of the split to which the example belongs (one of train, test or valid)
func_code_url
: URL to the function code on GitHub
Data Splits
Three splits are available:
- train
- test
- valid
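Since each record carries its own `split_name`, the split membership can be checked or tallied directly from the records. The sketch below is illustrative only: the helper `count_by_split` and the sample records are hypothetical, and the only assumption taken from the card is that `split_name` is one of train, test or valid.

```python
from collections import Counter

# The three split names listed in this card.
VALID_SPLITS = {"train", "test", "valid"}


def count_by_split(records):
    """Count records per split, rejecting any unknown split name."""
    counts = Counter()
    for record in records:
        split = record["split_name"]
        if split not in VALID_SPLITS:
            raise ValueError("unknown split_name: " + repr(split))
        counts[split] += 1
    return counts


# Hypothetical records; only the field used here is filled in.
sample = [
    {"split_name": "train"},
    {"split_name": "train"},
    {"split_name": "valid"},
]
print(count_by_split(sample))
```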
Dataset Creation
Curation Rationale
[More Information Needed]
Source Data
Initial Data Collection and Normalization
All information can be retrieved in the original technical review
Corpus collection:
Corpus has been collected from publicly available open-source non-fork GitHub repositories, using libraries.io to identify all projects which are used by at least one other project, and sorting them by "popularity" as indicated by the number of stars and forks.
Then, any projects that do not have a license or whose license does not explicitly permit the re-distribution of parts of the project were removed. Treesitter, GitHub's universal parser, was then used to tokenize all Go, Java, JavaScript, Python, PHP and Ruby functions (or methods); where available, their respective documentation text was tokenized using a heuristic regular expression.
Corpus filtering:
Functions without documentation are removed from the corpus. This yields a set of pairs ($c_i$, $d_i$) where $c_i$ is some function documented by $d_i$. Pairs ($c_i$, $d_i$) are passed through the following preprocessing tasks:
- Documentation $d_i$ is truncated to the first full paragraph to remove in-depth discussion of function arguments and return values
- Pairs in which $d_i$ is shorter than three tokens are removed
- Functions $c_i$ whose implementation is shorter than three lines are removed
- Functions whose name contains the substring "test" are removed
- Constructors and standard extension methods (e.g. __str__ in Python or toString in Java) are removed
- Duplicates and near-duplicate functions are removed, in order to keep only one version of the function
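The filtering rules above can be sketched in a few lines of Python. This is a rough illustration under stated assumptions, not the pipeline actually used to build the corpus: tokens are approximated by whitespace splitting rather than Treesitter, a "paragraph" is taken to end at the first blank line, the line count includes the signature line, the "test" check is case-sensitive, the set `STANDARD_METHODS` is a small hypothetical Python-only sample, and only exact duplicates are dropped. All function names are hypothetical.

```python
# Assumption: a tiny Python-only sample of constructors and standard extension
# methods; the card itself only names __str__ and toString as examples.
STANDARD_METHODS = {"__init__", "__str__", "__repr__"}


def first_paragraph(doc):
    """Truncate documentation to the text before the first blank line."""
    return doc.strip().split("\n\n", 1)[0].strip()


def keep_pair(func_name, code, doc):
    """Return True if a (code, documentation) pair survives the filters."""
    doc = first_paragraph(doc)
    # Documentation shorter than three tokens (whitespace tokens here).
    if len(doc.split()) < 3:
        return False
    # Implementation shorter than three lines (signature line included).
    if len(code.strip().splitlines()) < 3:
        return False
    # Function name contains the substring "test".
    if "test" in func_name:
        return False
    # Constructors and standard extension methods.
    if func_name in STANDARD_METHODS:
        return False
    return True


def deduplicate(pairs):
    """Keep the first of each exact duplicate; near-duplicates are not handled."""
    seen = set()
    for func_name, code, doc in pairs:
        if code in seen:
            continue
        seen.add(code)
        yield func_name, code, doc


code = "def _spill(self):\n    pos = self._file.tell()\n    self.values = []\n    gc.collect()"
pairs = [
    ("_spill", code, "dump the values into disk"),
    ("_spill", code, "dump the values into disk"),
    ("test_spill", code, "check that spilling works"),
]
kept = [p for p in deduplicate(pairs) if keep_pair(*p)]
print([name for name, _, _ in kept])
```

Near-duplicate detection needs a similarity measure over token sequences, which is deliberately left out of this sketch.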
Who are the source language producers?
Open-source contributors produced the code and documentation.
The dataset was gathered and preprocessed automatically.
Annotations
Annotation process
[More Information Needed]
Who are the annotators?
[More Information Needed]
Personal and Sensitive Information
[More Information Needed]
Considerations for Using the Data
Social Impact of Dataset
[More Information Needed]
Discussion of Biases
[More Information Needed]
Other Known Limitations
[More Information Needed]
Additional Information
Dataset Curators
[More Information Needed]
Licensing Information
Each example in the dataset is extracted from a GitHub repository, and each repository has its own license. Example-wise license information is not (yet) included in this dataset: you will need to determine yourself which license applies to the code.
Citation Information
@article{husain2019codesearchnet, title={{CodeSearchNet} challenge: Evaluating the state of semantic code search}, author={Husain, Hamel and Wu, Ho-Hsiang and Gazit, Tiferet and Allamanis, Miltiadis and Brockschmidt, Marc}, journal={arXiv preprint arXiv:1909.09436}, year={2019} }
Contributions
Thanks to @SBrandeis for adding this dataset.