Skip to content

Commit

Permalink
Update to enable parallel training via SuperScaler (#186)
Browse files Browse the repository at this point in the history
* Update to enable parallel training via SuperScaler

* headers update

* headers update

* clean

* format fix

* remove dxcompute

* format fix

* NNFusion fixed by Ze Gao, mnist example added

* nit

* nit fix: python -> python3 for SuperScaler

* fix rocm tests
  • Loading branch information
lynex authored Dec 25, 2020
1 parent 392811e commit 7c96540
Show file tree
Hide file tree
Showing 47 changed files with 1,641 additions and 72 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ endif()
# Include roots shared across targets:
#   - the project's own sources,
#   - the bundled third-party code,
#   - the build tree's src/ directory, which holds the generated proto headers.
set(GLOBAL_INCLUDE_PATH
    ${CMAKE_CURRENT_SOURCE_DIR}/src
    ${CMAKE_CURRENT_SOURCE_DIR}/thirdparty
    ${PROJECT_BINARY_DIR}/src
)

include_directories(
Expand Down
2 changes: 1 addition & 1 deletion maint/script/apply_code_style.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ pushd "${PWD}/../../" > /dev/null
| xargs "${CLANG_FORMAT}" -i -style=file 2>&1 \
| grep -v "Is a directory"
echo "Done."
popd > /dev/null
popd > /dev/null
2 changes: 1 addition & 1 deletion src/nnfusion/common/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


add_subdirectory(serialize)
set(SRC
languageunit.cpp
type_info.cpp
Expand Down
19 changes: 19 additions & 0 deletions src/nnfusion/common/serialize/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# Builds nnfusion_serialize: a static library made of the C++ sources that
# protoc generates from the .proto schemas in this directory.
#
# Generation runs at configure time (execute_process), so the *.pb.h headers
# already exist in the build tree before any target starts compiling.

if(NOT PROTOBUF_PROTOC_EXECUTABLE)
    message(FATAL_ERROR
        "PROTOBUF_PROTOC_EXECUTABLE is not set; cannot generate the nnfusion_serialize sources")
endif()

# protoc is an external tool, so hand it the schema directory in native form.
file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR} PROTOSRC_PATH)

# Start from an explicitly empty list so that a source-list variable inherited
# from a parent directory scope can never leak into this library.
set(NNFUSION_SERIALIZE_GENERATED_SRC)

foreach(item pbtypes attr_value tensor_shape node_def graph_def)
    execute_process(
        COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
                --proto_path=${PROTOSRC_PATH}
                --cpp_out=${CMAKE_CURRENT_BINARY_DIR}
                ${item}.proto
        RESULT_VARIABLE protoc_result
    )
    # Stop right here on failure instead of surfacing later as a confusing
    # "cannot find source file" error from add_library().
    if(NOT "${protoc_result}" STREQUAL "0")
        message(FATAL_ERROR
            "protoc failed while generating C++ code for ${item}.proto (result: ${protoc_result})")
    endif()

    # Re-run CMake (and therefore protoc) whenever the schema is edited.
    set_property(DIRECTORY APPEND PROPERTY
        CMAKE_CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${item}.proto)

    # protoc writes into the binary dir; reference the outputs by absolute path
    # rather than relying on relative-path lookup.
    list(APPEND NNFUSION_SERIALIZE_GENERATED_SRC
        ${CMAKE_CURRENT_BINARY_DIR}/${item}.pb.h
        ${CMAKE_CURRENT_BINARY_DIR}/${item}.pb.cc)
endforeach()

add_library(nnfusion_serialize STATIC ${NNFUSION_SERIALIZE_GENERATED_SRC})

# The generated sources include one another by bare file name.
target_include_directories(nnfusion_serialize PRIVATE
    ${CMAKE_CURRENT_BINARY_DIR}
)
target_include_directories(nnfusion_serialize SYSTEM PUBLIC
    ${GLOBAL_INCLUDE_PATH}
)

# Portable replacement for passing "-fPIC" by hand.
set_target_properties(nnfusion_serialize PROPERTIES POSITION_INDEPENDENT_CODE ON)

# Plain (keyword-less) signature retained on purpose: CMake forbids mixing the
# plain and keyword signatures on one target, and other call sites are not
# visible from this file.
target_link_libraries(nnfusion_serialize ${Protobuf_LIBRARIES})
58 changes: 58 additions & 0 deletions src/nnfusion/common/serialize/attr_value.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
syntax = "proto3";

package nnfusion.serialize;

// import "tensor.proto";
import "tensor_shape.proto";
import "pbtypes.proto";

option cc_enable_arenas = true;

// Value of one attribute used to configure an Op.
// Only the member of `value` that matches the attribute's type may be filled;
// the trailing comment on each field names the attr type it stands for.
message AttrValue {
  // Holder for list-typed attributes, one repeated field per element type.
  message ListValue {
    repeated bytes s = 2;                         // "list(string)"
    repeated int64 i = 3 [packed = true];         // "list(int)"
    repeated float f = 4 [packed = true];         // "list(float)"
    repeated bool b = 5 [packed = true];          // "list(bool)"
    repeated PBType type = 6 [packed = true];     // "list(type)"
    repeated TensorShapeProto shape = 7;          // "list(shape)"

    // Disabled members (kept for reference):
    // repeated TensorProto tensor = 8;           // "list(tensor)"
    // repeated NameAttrList func = 9;            // "list(attr)"
  }

  oneof value {
    bytes s = 2;                 // "string"
    int64 i = 3;                 // "int"
    float f = 4;                 // "float"
    bool b = 5;                  // "bool"
    PBType type = 6;             // "type"
    TensorShapeProto shape = 7;  // "shape"
    // TensorProto tensor = 8;   // "tensor" (disabled)
    ListValue list = 1;          // any "list(...)"

    // Disabled: function-valued attribute. func.name is the name of a
    // function or of a primitive op; func.attr.first names an attr defined
    // for that function and func.attr.second is that attr's value in the
    // instantiation.
    // NameAttrList func = 10;

    // Disabled: placeholder used only by nodes defined inside a function,
    // meaning the attr value is supplied when the function is instantiated.
    // Example: node "N" in function "FN" has attr "A" with
    // placeholder = "foo"; instantiating FN with attr "foo" set to "bar"
    // gives the instantiated N's attr A the value "bar".
    // string placeholder = 9;
  }
}

// A list of attr names and their values. The whole list is attached
// with a string name. E.g., MatMul[T=float].
// message NameAttrList {
// string name = 1;
// map<string, AttrValue> attr = 2;
// }
13 changes: 13 additions & 0 deletions src/nnfusion/common/serialize/graph_def.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
syntax = "proto3";

package nnfusion.serialize;

import "node_def.proto";

option cc_enable_arenas = true;

// A graph of operations: the list of its nodes plus a version number.
message GraphDef {
  // The operations that make up the graph.
  repeated NodeDef node = 1;

  // Version number stored alongside the node list.
  int32 version = 2;
}
62 changes: 62 additions & 0 deletions src/nnfusion/common/serialize/node_def.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
syntax = "proto3";

package nnfusion.serialize;

import "attr_value.proto";

option cc_enable_arenas = true;

// One operator instance inside a GraphDef.
message NodeDef {
  // Name of this operator; used for naming inputs, logging, visualization,
  // etc. It is unique within a single GraphDef and must match the regexp
  // "[A-Za-z0-9.][A-Za-z0-9_>./]*".
  string name = 1;

  // Name of the operation. Custom parameters may be carried in `attr`.
  // Names that start with an underscore are reserved for internal use.
  string op = 2;

  // Inputs, each written as "node:src_output": "node" is a node name and
  // "src_output" selects which of that node's output tensors is consumed.
  // The ":0" suffix may be dropped when "src_output" is 0. Control inputs,
  // written "^node", may optionally follow the regular inputs.
  repeated string input = 3;

  // Disabled field: a (possibly partial) specification of the device this
  // node should be placed on. Expected syntax:
  //
  //   DEVICE_SPEC  ::= PARTIAL_SPEC
  //
  //   PARTIAL_SPEC ::= ("/" CONSTRAINT) *
  //   CONSTRAINT   ::= ("job:" JOB_NAME)
  //                  | ("replica:" [1-9][0-9]*)
  //                  | ("task:" [1-9][0-9]*)
  //                  | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") )
  //
  // Valid values include:
  //   * "/job:worker/replica:0/task:1/device:GPU:3"  (full specification)
  //   * "/job:worker/device:GPU:3"                   (partial specification)
  //   * ""                                           (no specification)
  //
  // When the constraints do not resolve to a single device (or the field is
  // empty or absent), the runtime attempts to pick a device automatically.
  // string device = 4;

  // Operation-specific configuration fixed at graph-construction time.
  // This should list every attr defined in the corresponding OpDef, even
  // those whose value equals the default: that lets the default change later
  // and makes a NodeDef easier to read in isolation. An attr that has a
  // default and is missing here takes that default.
  // Keys must match the regexp "[a-z][a-z0-9_]+" and be one of the names in
  // the corresponding OpDef's attr field; each value's type must match the
  // type field of that OpDef attr.
  // TODO(josh11b): Add some examples here showing best practices.
  map<string, AttrValue> attr = 5;
}
40 changes: 40 additions & 0 deletions src/nnfusion/common/serialize/pbtypes.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
syntax = "proto3";

//package nnfusion;
package nnfusion.serialize;

option cc_enable_arenas = true;

// (== suppress_warning documentation-presence ==)
// LINT.IfChange
// Element data types that can be recorded in the serialized graph.
enum PBType {
  // Not a legal type; signals that a type field has not been set.
  DT_INVALID = 0;

  // Types every computation device is expected to be able to support.
  DT_BOOL   = 1;
  DT_CHAR   = 2;
  DT_FLOAT  = 3;
  DT_DOUBLE = 4;
  DT_INT8   = 5;
  DT_INT16  = 6;
  DT_INT32  = 7;
  DT_INT64  = 8;
  DT_UINT8  = 9;
  DT_UINT16 = 10;
  DT_UINT32 = 11;
  DT_UINT64 = 12;
}

// For identifying the underlying type of a variant. For variants, the types
// listed here are a subset of the types in the variant type registry,
// corresponding to commonly used variants which must occasionally be
// special-cased.
// enum SpecializedType {
// // Invalid/unknown specialized type.
// ST_INVALID = 0;
// // "tensorflow::TensorList" in the variant type registry.
// ST_TENSOR_LIST = 1;
// }
42 changes: 42 additions & 0 deletions src/nnfusion/common/serialize/tensor_shape.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Protocol buffer representing the shape of tensors.

syntax = "proto3";
option cc_enable_arenas = true;

package nnfusion.serialize;

// Shape of a tensor, expressed as an ordered list of dimensions.
message TensorShapeProto {
  // A single dimension of the tensor.
  message Dim {
    // Extent of the tensor along this dimension. Must be >= -1, where -1 is
    // reserved to mean "unknown". Some wrappers around TensorShapeProto may
    // fail at runtime when they deserialize a shape containing a -1 size.
    int64 size = 1;

    // Optional name for this dimension.
    string name = 2;
  }

  // The dimensions, e.g. {"input", 30}, {"output", 40} for a 30 x 40 2D
  // tensor. A size of -1 marks a dimension of unknown size; names are
  // optional.
  //
  // Entry order is significant because it describes the in-memory layout of
  // the values: the first entry is the outermost dimension and the last entry
  // is the innermost one, matching the layout of RowMajor Eigen tensors.
  //
  // "unknown_rank" must be false whenever "dim.size()" > 0.
  repeated Dim dim = 2;

  // True when the number of dimensions is not known; in that case
  // "dim.size()" must be 0.
  bool unknown_rank = 3;
}
33 changes: 33 additions & 0 deletions src/nnfusion/common/type/element_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,39 @@ bool element::Type::nnfusion_element_type_to_dtype_string(const element::Type& n
return true;
}

// Translates an nnfusion element type into the corresponding serialized
// PBType enum value.
//
// ng_et: the element type to translate.
// dtype: receives the matching PBType when a mapping exists; it is not
//        written otherwise.
// Returns true when a mapping was found, false for any other element type.
bool element::Type::nnfusion_element_type_to_pbtype(const element::Type& ng_et,
                                                    nnfusion::serialize::PBType& dtype)
{
    using nnfusion::serialize::PBType;

    // One row per supported element type, probed from top to bottom.
    struct Mapping
    {
        const element::Type& source;
        PBType target;
    };
    const Mapping known_types[] = {
        {element::boolean, PBType::DT_BOOL},
        {element::character, PBType::DT_CHAR},
        {element::f32, PBType::DT_FLOAT},
        {element::f64, PBType::DT_DOUBLE},
        {element::i8, PBType::DT_INT8},
        {element::i16, PBType::DT_INT16},
        {element::i32, PBType::DT_INT32},
        {element::i64, PBType::DT_INT64},
        {element::u8, PBType::DT_UINT8},
        {element::u16, PBType::DT_UINT16},
        {element::u32, PBType::DT_UINT32},
        {element::u64, PBType::DT_UINT64},
    };

    for (const Mapping& candidate : known_types)
    {
        if (ng_et == candidate.source)
        {
            dtype = candidate.target;
            return true;
        }
    }

    // No entry matched: report failure without touching dtype.
    return false;
}

element::Type::Type(
size_t bitwidth, bool is_real, bool is_signed, bool is_quantized, const std::string& cname)
: m_bitwidth{bitwidth}
Expand Down
3 changes: 3 additions & 0 deletions src/nnfusion/common/type/element_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <vector>

#include "half/include/half.hpp"
#include "nnfusion/common/serialize/pbtypes.pb.h"
#include "nnfusion/common/type/bfloat16.hpp"
#include "nnfusion/util/errors.hpp"

Expand Down Expand Up @@ -80,6 +81,8 @@ namespace nnfusion
bool operator<(const Type& other) const;
friend std::ostream& operator<<(std::ostream&, const Type&);
static std::vector<const Type*> get_known_types();
// Maps an element type to the matching nnfusion::serialize::PBType value.
// On success writes the result to `dtype` and returns true; returns false,
// leaving `dtype` unmodified, when `ng_et` has no PBType counterpart.
static bool nnfusion_element_type_to_pbtype(const Type& ng_et,
nnfusion::serialize::PBType& dtype);
static bool nnfusion_element_type_to_dtype_string(const Type& ng_et,
std::string& dtype);

Expand Down
1 change: 1 addition & 0 deletions src/nnfusion/core/graph/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ add_library(nnfusion_graph STATIC ${SRC})
# Add the shared include roots to nnfusion_graph as SYSTEM include
# directories; PUBLIC also propagates them to targets that link against it.
target_include_directories(nnfusion_graph SYSTEM PUBLIC
${GLOBAL_INCLUDE_PATH}
)
# Link against the nnfusion_serialize library.
target_link_libraries(nnfusion_graph nnfusion_serialize)
# Compile this library's own sources with -fPIC (PRIVATE: not propagated).
target_compile_options(nnfusion_graph PRIVATE "-fPIC")
Loading

0 comments on commit 7c96540

Please sign in to comment.