Godep is not supposed to used in package.

This commit is contained in:
Chih-Wei Chang 2015-07-02 20:50:03 +08:00
parent 2a172b0a22
commit e404f82893
128 changed files with 0 additions and 38982 deletions

View file

@ -1,26 +0,0 @@
{
"ImportPath": "github.com/lazywei/go-opencv/gocv",
"GoVersion": "go1.4.1",
"Deps": [
{
"ImportPath": "github.com/davecgh/go-spew/spew",
"Rev": "1aaf839fb07e099361e445273993ccd9adc21b07"
},
{
"ImportPath": "github.com/gonum/blas",
"Rev": "22132bfa8c9d291d8c11a6a817e4da1fa1c35c39"
},
{
"ImportPath": "github.com/gonum/internal/asm",
"Rev": "9988c755e4ebb6828adce026d571114b6ee26a6b"
},
{
"ImportPath": "github.com/gonum/matrix/mat64",
"Rev": "7c0d216f456e1c5fe498437367134ccdf6b35ded"
},
{
"ImportPath": "github.com/stretchr/testify/assert",
"Rev": "f0b02af48e5ee31c78b949e9ed67c37e08d1a897"
}
]
}

5
gocv/Godeps/Readme generated
View file

@ -1,5 +0,0 @@
This directory tree is generated automatically by godep.
Please do not edit.
See https://github.com/tools/godep for more information.

2
gocv/Godeps/_workspace/.gitignore generated vendored
View file

@ -1,2 +0,0 @@
/pkg
/bin

View file

@ -1,371 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"fmt"
"io"
"reflect"
"sort"
"strconv"
"unsafe"
)
const (
// ptrSize is the size of a pointer on the current arch.
ptrSize = unsafe.Sizeof((*byte)(nil))
)
var (
// offsetPtr, offsetScalar, and offsetFlag are the offsets for the
// internal reflect.Value fields. These values are valid before golang
// commit ecccf07e7f9d which changed the format. The are also valid
// after commit 82f48826c6c7 which changed the format again to mirror
// the original format. Code in the init function updates these offsets
// as necessary.
offsetPtr = uintptr(ptrSize)
offsetScalar = uintptr(0)
offsetFlag = uintptr(ptrSize * 2)
// flagKindWidth and flagKindShift indicate various bits that the
// reflect package uses internally to track kind information.
//
// flagRO indicates whether or not the value field of a reflect.Value is
// read-only.
//
// flagIndir indicates whether the value field of a reflect.Value is
// the actual data or a pointer to the data.
//
// These values are valid before golang commit 90a7c3c86944 which
// changed their positions. Code in the init function updates these
// flags as necessary.
flagKindWidth = uintptr(5)
flagKindShift = uintptr(flagKindWidth - 1)
flagRO = uintptr(1 << 0)
flagIndir = uintptr(1 << 1)
)
func init() {
// Older versions of reflect.Value stored small integers directly in the
// ptr field (which is named val in the older versions). Versions
// between commits ecccf07e7f9d and 82f48826c6c7 added a new field named
// scalar for this purpose which unfortunately came before the flag
// field, so the offset of the flag field is different for those
// versions.
//
// This code constructs a new reflect.Value from a known small integer
// and checks if the size of the reflect.Value struct indicates it has
// the scalar field. When it does, the offsets are updated accordingly.
vv := reflect.ValueOf(0xf00)
if unsafe.Sizeof(vv) == (ptrSize * 4) {
offsetScalar = ptrSize * 2
offsetFlag = ptrSize * 3
}
// Commit 90a7c3c86944 changed the flag positions such that the low
// order bits are the kind. This code extracts the kind from the flags
// field and ensures it's the correct type. When it's not, the flag
// order has been changed to the newer format, so the flags are updated
// accordingly.
upf := unsafe.Pointer(uintptr(unsafe.Pointer(&vv)) + offsetFlag)
upfv := *(*uintptr)(upf)
flagKindMask := uintptr((1<<flagKindWidth - 1) << flagKindShift)
if (upfv&flagKindMask)>>flagKindShift != uintptr(reflect.Int) {
flagKindShift = 0
flagRO = 1 << 5
flagIndir = 1 << 6
}
}
// unsafeReflectValue converts the passed reflect.Value into a one that bypasses
// the typical safety restrictions preventing access to unaddressable and
// unexported data. It works by digging the raw pointer to the underlying
// value out of the protected value and generating a new unprotected (unsafe)
// reflect.Value to it.
//
// This allows us to check for implementations of the Stringer and error
// interfaces to be used for pretty printing ordinarily unaddressable and
// inaccessible values such as unexported struct fields.
func unsafeReflectValue(v reflect.Value) (rv reflect.Value) {
indirects := 1
vt := v.Type()
upv := unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetPtr)
rvf := *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetFlag))
if rvf&flagIndir != 0 {
vt = reflect.PtrTo(v.Type())
indirects++
} else if offsetScalar != 0 {
// The value is in the scalar field when it's not one of the
// reference types.
switch vt.Kind() {
case reflect.Uintptr:
case reflect.Chan:
case reflect.Func:
case reflect.Map:
case reflect.Ptr:
case reflect.UnsafePointer:
default:
upv = unsafe.Pointer(uintptr(unsafe.Pointer(&v)) +
offsetScalar)
}
}
pv := reflect.NewAt(vt, upv)
rv = pv
for i := 0; i < indirects; i++ {
rv = rv.Elem()
}
return rv
}
// Some constants in the form of bytes to avoid string overhead. This mirrors
// the technique used in the fmt package.
var (
panicBytes = []byte("(PANIC=")
plusBytes = []byte("+")
iBytes = []byte("i")
trueBytes = []byte("true")
falseBytes = []byte("false")
interfaceBytes = []byte("(interface {})")
commaNewlineBytes = []byte(",\n")
newlineBytes = []byte("\n")
openBraceBytes = []byte("{")
openBraceNewlineBytes = []byte("{\n")
closeBraceBytes = []byte("}")
asteriskBytes = []byte("*")
colonBytes = []byte(":")
colonSpaceBytes = []byte(": ")
openParenBytes = []byte("(")
closeParenBytes = []byte(")")
spaceBytes = []byte(" ")
pointerChainBytes = []byte("->")
nilAngleBytes = []byte("<nil>")
maxNewlineBytes = []byte("<max depth reached>\n")
maxShortBytes = []byte("<max>")
circularBytes = []byte("<already shown>")
circularShortBytes = []byte("<shown>")
invalidAngleBytes = []byte("<invalid>")
openBracketBytes = []byte("[")
closeBracketBytes = []byte("]")
percentBytes = []byte("%")
precisionBytes = []byte(".")
openAngleBytes = []byte("<")
closeAngleBytes = []byte(">")
openMapBytes = []byte("map[")
closeMapBytes = []byte("]")
lenEqualsBytes = []byte("len=")
capEqualsBytes = []byte("cap=")
)
// hexDigits is used to map a decimal value to a hex digit.
var hexDigits = "0123456789abcdef"
// catchPanic handles any panics that might occur during the handleMethods
// calls.
func catchPanic(w io.Writer, v reflect.Value) {
if err := recover(); err != nil {
w.Write(panicBytes)
fmt.Fprintf(w, "%v", err)
w.Write(closeParenBytes)
}
}
// handleMethods attempts to call the Error and String methods on the underlying
// type the passed reflect.Value represents and outputes the result to Writer w.
//
// It handles panics in any called methods by catching and displaying the error
// as the formatted value.
func handleMethods(cs *ConfigState, w io.Writer, v reflect.Value) (handled bool) {
// We need an interface to check if the type implements the error or
// Stringer interface. However, the reflect package won't give us an
// interface on certain things like unexported struct fields in order
// to enforce visibility rules. We use unsafe to bypass these restrictions
// since this package does not mutate the values.
if !v.CanInterface() {
v = unsafeReflectValue(v)
}
// Choose whether or not to do error and Stringer interface lookups against
// the base type or a pointer to the base type depending on settings.
// Technically calling one of these methods with a pointer receiver can
// mutate the value, however, types which choose to satisify an error or
// Stringer interface with a pointer receiver should not be mutating their
// state inside these interface methods.
var viface interface{}
if !cs.DisablePointerMethods {
if !v.CanAddr() {
v = unsafeReflectValue(v)
}
viface = v.Addr().Interface()
} else {
if v.CanAddr() {
v = v.Addr()
}
viface = v.Interface()
}
// Is it an error or Stringer?
switch iface := viface.(type) {
case error:
defer catchPanic(w, v)
if cs.ContinueOnMethod {
w.Write(openParenBytes)
w.Write([]byte(iface.Error()))
w.Write(closeParenBytes)
w.Write(spaceBytes)
return false
}
w.Write([]byte(iface.Error()))
return true
case fmt.Stringer:
defer catchPanic(w, v)
if cs.ContinueOnMethod {
w.Write(openParenBytes)
w.Write([]byte(iface.String()))
w.Write(closeParenBytes)
w.Write(spaceBytes)
return false
}
w.Write([]byte(iface.String()))
return true
}
return false
}
// printBool outputs a boolean value as true or false to Writer w.
func printBool(w io.Writer, val bool) {
if val {
w.Write(trueBytes)
} else {
w.Write(falseBytes)
}
}
// printInt outputs a signed integer value to Writer w.
func printInt(w io.Writer, val int64, base int) {
w.Write([]byte(strconv.FormatInt(val, base)))
}
// printUint outputs an unsigned integer value to Writer w.
func printUint(w io.Writer, val uint64, base int) {
w.Write([]byte(strconv.FormatUint(val, base)))
}
// printFloat outputs a floating point value using the specified precision,
// which is expected to be 32 or 64bit, to Writer w.
func printFloat(w io.Writer, val float64, precision int) {
w.Write([]byte(strconv.FormatFloat(val, 'g', -1, precision)))
}
// printComplex outputs a complex value using the specified float precision
// for the real and imaginary parts to Writer w.
func printComplex(w io.Writer, c complex128, floatPrecision int) {
r := real(c)
w.Write(openParenBytes)
w.Write([]byte(strconv.FormatFloat(r, 'g', -1, floatPrecision)))
i := imag(c)
if i >= 0 {
w.Write(plusBytes)
}
w.Write([]byte(strconv.FormatFloat(i, 'g', -1, floatPrecision)))
w.Write(iBytes)
w.Write(closeParenBytes)
}
// printHexPtr outputs a uintptr formatted as hexidecimal with a leading '0x'
// prefix to Writer w.
func printHexPtr(w io.Writer, p uintptr) {
// Null pointer.
num := uint64(p)
if num == 0 {
w.Write(nilAngleBytes)
return
}
// Max uint64 is 16 bytes in hex + 2 bytes for '0x' prefix
buf := make([]byte, 18)
// It's simpler to construct the hex string right to left.
base := uint64(16)
i := len(buf) - 1
for num >= base {
buf[i] = hexDigits[num%base]
num /= base
i--
}
buf[i] = hexDigits[num]
// Add '0x' prefix.
i--
buf[i] = 'x'
i--
buf[i] = '0'
// Strip unused leading bytes.
buf = buf[i:]
w.Write(buf)
}
// valuesSorter implements sort.Interface to allow a slice of reflect.Value
// elements to be sorted.
type valuesSorter struct {
values []reflect.Value
}
// Len returns the number of values in the slice. It is part of the
// sort.Interface implementation.
func (s *valuesSorter) Len() int {
return len(s.values)
}
// Swap swaps the values at the passed indices. It is part of the
// sort.Interface implementation.
func (s *valuesSorter) Swap(i, j int) {
s.values[i], s.values[j] = s.values[j], s.values[i]
}
// Less returns whether the value at index i should sort before the
// value at index j. It is part of the sort.Interface implementation.
func (s *valuesSorter) Less(i, j int) bool {
switch s.values[i].Kind() {
case reflect.Bool:
return !s.values[i].Bool() && s.values[j].Bool()
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
return s.values[i].Int() < s.values[j].Int()
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return s.values[i].Uint() < s.values[j].Uint()
case reflect.Float32, reflect.Float64:
return s.values[i].Float() < s.values[j].Float()
case reflect.String:
return s.values[i].String() < s.values[j].String()
case reflect.Uintptr:
return s.values[i].Uint() < s.values[j].Uint()
}
return s.values[i].String() < s.values[j].String()
}
// sortValues is a generic sort function for native types: int, uint, bool,
// string and uintptr. Other inputs are sorted according to their
// Value.String() value to ensure display stability.
func sortValues(values []reflect.Value) {
if len(values) == 0 {
return
}
sort.Sort(&valuesSorter{values})
}

View file

@ -1,192 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew_test
import (
"fmt"
"github.com/davecgh/go-spew/spew"
"reflect"
"testing"
)
// custom type to test Stinger interface on non-pointer receiver.
type stringer string
// String implements the Stringer interface for testing invocation of custom
// stringers on types with non-pointer receivers.
func (s stringer) String() string {
return "stringer " + string(s)
}
// custom type to test Stinger interface on pointer receiver.
type pstringer string
// String implements the Stringer interface for testing invocation of custom
// stringers on types with only pointer receivers.
func (s *pstringer) String() string {
return "stringer " + string(*s)
}
// xref1 and xref2 are cross referencing structs for testing circular reference
// detection.
type xref1 struct {
ps2 *xref2
}
type xref2 struct {
ps1 *xref1
}
// indirCir1, indirCir2, and indirCir3 are used to generate an indirect circular
// reference for testing detection.
type indirCir1 struct {
ps2 *indirCir2
}
type indirCir2 struct {
ps3 *indirCir3
}
type indirCir3 struct {
ps1 *indirCir1
}
// embed is used to test embedded structures.
type embed struct {
a string
}
// embedwrap is used to test embedded structures.
type embedwrap struct {
*embed
e *embed
}
// panicer is used to intentionally cause a panic for testing spew properly
// handles them
type panicer int
func (p panicer) String() string {
panic("test panic")
}
// customError is used to test custom error interface invocation.
type customError int
func (e customError) Error() string {
return fmt.Sprintf("error: %d", int(e))
}
// stringizeWants converts a slice of wanted test output into a format suitable
// for a test error message.
func stringizeWants(wants []string) string {
s := ""
for i, want := range wants {
if i > 0 {
s += fmt.Sprintf("want%d: %s", i+1, want)
} else {
s += "want: " + want
}
}
return s
}
// testFailed returns whether or not a test failed by checking if the result
// of the test is in the slice of wanted strings.
func testFailed(result string, wants []string) bool {
for _, want := range wants {
if result == want {
return false
}
}
return true
}
// TestSortValues ensures the sort functionality for relect.Value based sorting
// works as intended.
func TestSortValues(t *testing.T) {
getInterfaces := func(values []reflect.Value) []interface{} {
interfaces := []interface{}{}
for _, v := range values {
interfaces = append(interfaces, v.Interface())
}
return interfaces
}
v := reflect.ValueOf
a := v("a")
b := v("b")
c := v("c")
embedA := v(embed{"a"})
embedB := v(embed{"b"})
embedC := v(embed{"c"})
tests := []struct {
input []reflect.Value
expected []reflect.Value
}{
// No values.
{
[]reflect.Value{},
[]reflect.Value{},
},
// Bools.
{
[]reflect.Value{v(false), v(true), v(false)},
[]reflect.Value{v(false), v(false), v(true)},
},
// Ints.
{
[]reflect.Value{v(2), v(1), v(3)},
[]reflect.Value{v(1), v(2), v(3)},
},
// Uints.
{
[]reflect.Value{v(uint8(2)), v(uint8(1)), v(uint8(3))},
[]reflect.Value{v(uint8(1)), v(uint8(2)), v(uint8(3))},
},
// Floats.
{
[]reflect.Value{v(2.0), v(1.0), v(3.0)},
[]reflect.Value{v(1.0), v(2.0), v(3.0)},
},
// Strings.
{
[]reflect.Value{b, a, c},
[]reflect.Value{a, b, c},
},
// Uintptrs.
{
[]reflect.Value{v(uintptr(2)), v(uintptr(1)), v(uintptr(3))},
[]reflect.Value{v(uintptr(1)), v(uintptr(2)), v(uintptr(3))},
},
// Invalid.
{
[]reflect.Value{embedB, embedA, embedC},
[]reflect.Value{embedB, embedA, embedC},
},
}
for _, test := range tests {
spew.SortValues(test.input)
// reflect.DeepEqual cannot really make sense of reflect.Value,
// probably because of all the pointer tricks. For instance,
// v(2.0) != v(2.0) on a 32-bits system. Turn them into interface{}
// instead.
input := getInterfaces(test.input)
expected := getInterfaces(test.expected)
if !reflect.DeepEqual(input, expected) {
t.Errorf("Sort mismatch:\n %v != %v", input, expected)
}
}
}

View file

@ -1,288 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"bytes"
"fmt"
"io"
"os"
)
// ConfigState houses the configuration options used by spew to format and
// display values. There is a global instance, Config, that is used to control
// all top-level Formatter and Dump functionality. Each ConfigState instance
// provides methods equivalent to the top-level functions.
//
// The zero value for ConfigState provides no indentation. You would typically
// want to set it to a space or a tab.
//
// Alternatively, you can use NewDefaultConfig to get a ConfigState instance
// with default settings. See the documentation of NewDefaultConfig for default
// values.
type ConfigState struct {
// Indent specifies the string to use for each indentation level. The
// global config instance that all top-level functions use set this to a
// single space by default. If you would like more indentation, you might
// set this to a tab with "\t" or perhaps two spaces with " ".
Indent string
// MaxDepth controls the maximum number of levels to descend into nested
// data structures. The default, 0, means there is no limit.
//
// NOTE: Circular data structures are properly detected, so it is not
// necessary to set this value unless you specifically want to limit deeply
// nested data structures.
MaxDepth int
// DisableMethods specifies whether or not error and Stringer interfaces are
// invoked for types that implement them.
DisableMethods bool
// DisablePointerMethods specifies whether or not to check for and invoke
// error and Stringer interfaces on types which only accept a pointer
// receiver when the current type is not a pointer.
//
// NOTE: This might be an unsafe action since calling one of these methods
// with a pointer receiver could technically mutate the value, however,
// in practice, types which choose to satisify an error or Stringer
// interface with a pointer receiver should not be mutating their state
// inside these interface methods.
DisablePointerMethods bool
// ContinueOnMethod specifies whether or not recursion should continue once
// a custom error or Stringer interface is invoked. The default, false,
// means it will print the results of invoking the custom error or Stringer
// interface and return immediately instead of continuing to recurse into
// the internals of the data type.
//
// NOTE: This flag does not have any effect if method invocation is disabled
// via the DisableMethods or DisablePointerMethods options.
ContinueOnMethod bool
// SortKeys specifies map keys should be sorted before being printed. Use
// this to have a more deterministic, diffable output. Note that only
// native types (bool, int, uint, floats, uintptr and string) are supported
// with other types sorted according to the reflect.Value.String() output
// which guarantees display stability.
SortKeys bool
}
// Config is the active configuration of the top-level functions.
// The configuration can be changed by modifying the contents of spew.Config.
var Config = ConfigState{Indent: " "}
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the formatted string as a value that satisfies error. See NewFormatter
// for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Errorf(format, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Errorf(format string, a ...interface{}) (err error) {
return fmt.Errorf(format, c.convertArgs(a)...)
}
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprint(w, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Fprint(w io.Writer, a ...interface{}) (n int, err error) {
return fmt.Fprint(w, c.convertArgs(a)...)
}
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprintf(w, format, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
return fmt.Fprintf(w, format, c.convertArgs(a)...)
}
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
// passed with a Formatter interface returned by c.NewFormatter. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprintln(w, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
return fmt.Fprintln(w, c.convertArgs(a)...)
}
// Print is a wrapper for fmt.Print that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Print(c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Print(a ...interface{}) (n int, err error) {
return fmt.Print(c.convertArgs(a)...)
}
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Printf(format, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Printf(format string, a ...interface{}) (n int, err error) {
return fmt.Printf(format, c.convertArgs(a)...)
}
// Println is a wrapper for fmt.Println that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Println(c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Println(a ...interface{}) (n int, err error) {
return fmt.Println(c.convertArgs(a)...)
}
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprint(c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Sprint(a ...interface{}) string {
return fmt.Sprint(c.convertArgs(a)...)
}
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
// passed with a Formatter interface returned by c.NewFormatter. It returns
// the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprintf(format, c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Sprintf(format string, a ...interface{}) string {
return fmt.Sprintf(format, c.convertArgs(a)...)
}
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
// were passed with a Formatter interface returned by c.NewFormatter. It
// returns the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprintln(c.NewFormatter(a), c.NewFormatter(b))
func (c *ConfigState) Sprintln(a ...interface{}) string {
return fmt.Sprintln(c.convertArgs(a)...)
}
/*
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
interface. As a result, it integrates cleanly with standard fmt package
printing functions. The formatter is useful for inline printing of smaller data
types similar to the standard %v format specifier.
The custom formatter only responds to the %v (most compact), %+v (adds pointer
addresses), %#v (adds types), and %#+v (adds types and pointer addresses) verb
combinations. Any other verbs such as %x and %q will be sent to the the
standard fmt package for formatting. In addition, the custom formatter ignores
the width and precision arguments (however they will still work on the format
specifiers not handled by the custom formatter).
Typically this function shouldn't be called directly. It is much easier to make
use of the custom formatter by calling one of the convenience functions such as
c.Printf, c.Println, or c.Printf.
*/
func (c *ConfigState) NewFormatter(v interface{}) fmt.Formatter {
return newFormatter(c, v)
}
// Fdump formats and displays the passed arguments to io.Writer w. It formats
// exactly the same as Dump.
func (c *ConfigState) Fdump(w io.Writer, a ...interface{}) {
fdump(c, w, a...)
}
/*
Dump displays the passed parameters to standard out with newlines, customizable
indentation, and additional debug information such as complete types and all
pointer addresses used to indirect to the final value. It provides the
following features over the built-in printing facilities provided by the fmt
package:
* Pointers are dereferenced and followed
* Circular data structures are detected and handled properly
* Custom Stringer/error interfaces are optionally invoked, including
on unexported types
* Custom types which only implement the Stringer/error interfaces via
a pointer receiver are optionally invoked when passing non-pointer
variables
* Byte arrays and slices are dumped like the hexdump -C command which
includes offsets, byte values in hex, and ASCII output
The configuration options are controlled by modifying the public members
of c. See ConfigState for options documentation.
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
get the formatted result as a string.
*/
func (c *ConfigState) Dump(a ...interface{}) {
fdump(c, os.Stdout, a...)
}
// Sdump returns a string with the passed arguments formatted exactly the same
// as Dump.
func (c *ConfigState) Sdump(a ...interface{}) string {
var buf bytes.Buffer
fdump(c, &buf, a...)
return buf.String()
}
// convertArgs accepts a slice of arguments and returns a slice of the same
// length with each argument converted to a spew Formatter interface using
// the ConfigState associated with s.
func (c *ConfigState) convertArgs(args []interface{}) (formatters []interface{}) {
formatters = make([]interface{}, len(args))
for index, arg := range args {
formatters[index] = newFormatter(c, arg)
}
return formatters
}
// NewDefaultConfig returns a ConfigState with the following default settings.
//
// Indent: " "
// MaxDepth: 0
// DisableMethods: false
// DisablePointerMethods: false
// ContinueOnMethod: false
// SortKeys: false
func NewDefaultConfig() *ConfigState {
return &ConfigState{Indent: " "}
}

View file

@ -1,196 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
/*
Package spew implements a deep pretty printer for Go data structures to aid in
debugging.
A quick overview of the additional features spew provides over the built-in
printing facilities for Go data types are as follows:
* Pointers are dereferenced and followed
* Circular data structures are detected and handled properly
* Custom Stringer/error interfaces are optionally invoked, including
on unexported types
* Custom types which only implement the Stringer/error interfaces via
a pointer receiver are optionally invoked when passing non-pointer
variables
* Byte arrays and slices are dumped like the hexdump -C command which
includes offsets, byte values in hex, and ASCII output (only when using
Dump style)
There are two different approaches spew allows for dumping Go data structures:
* Dump style which prints with newlines, customizable indentation,
and additional debug information such as types and all pointer addresses
used to indirect to the final value
* A custom Formatter interface that integrates cleanly with the standard fmt
package and replaces %v, %+v, %#v, and %#+v to provide inline printing
similar to the default %v while providing the additional functionality
outlined above and passing unsupported format verbs such as %x and %q
along to fmt
Quick Start
This section demonstrates how to quickly get started with spew. See the
sections below for further details on formatting and configuration options.
To dump a variable with full newlines, indentation, type, and pointer
information use Dump, Fdump, or Sdump:
spew.Dump(myVar1, myVar2, ...)
spew.Fdump(someWriter, myVar1, myVar2, ...)
str := spew.Sdump(myVar1, myVar2, ...)
Alternatively, if you would prefer to use format strings with a compacted inline
printing style, use the convenience wrappers Printf, Fprintf, etc with
%v (most compact), %+v (adds pointer addresses), %#v (adds types), or
%#+v (adds types and pointer addresses):
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
spew.Fprintf(someWriter, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
spew.Fprintf(someWriter, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
Configuration Options
Configuration of spew is handled by fields in the ConfigState type. For
convenience, all of the top-level functions use a global state available
via the spew.Config global.
It is also possible to create a ConfigState instance that provides methods
equivalent to the top-level functions. This allows concurrent configuration
options. See the ConfigState documentation for more details.
The following configuration options are available:
* Indent
String to use for each indentation level for Dump functions.
It is a single space by default. A popular alternative is "\t".
* MaxDepth
Maximum number of levels to descend into nested data structures.
There is no limit by default.
* DisableMethods
Disables invocation of error and Stringer interface methods.
Method invocation is enabled by default.
* DisablePointerMethods
Disables invocation of error and Stringer interface methods on types
which only accept pointer receivers from non-pointer variables.
Pointer method invocation is enabled by default.
* ContinueOnMethod
Enables recursion into types after invoking error and Stringer interface
methods. Recursion after method invocation is disabled by default.
* SortKeys
Specifies map keys should be sorted before being printed. Use
this to have a more deterministic, diffable output. Note that
only native types (bool, int, uint, floats, uintptr and string)
are supported with other types sorted according to the
reflect.Value.String() output which guarantees display stability.
Natural map order is used by default.
Dump Usage
Simply call spew.Dump with a list of variables you want to dump:
spew.Dump(myVar1, myVar2, ...)
You may also call spew.Fdump if you would prefer to output to an arbitrary
io.Writer. For example, to dump to standard error:
spew.Fdump(os.Stderr, myVar1, myVar2, ...)
A third option is to call spew.Sdump to get the formatted output as a string:
str := spew.Sdump(myVar1, myVar2, ...)
Sample Dump Output
See the Dump example for details on the setup of the types and variables being
shown here.
(main.Foo) {
unexportedField: (*main.Bar)(0xf84002e210)({
flag: (main.Flag) flagTwo,
data: (uintptr) <nil>
}),
ExportedField: (map[interface {}]interface {}) (len=1) {
(string) (len=3) "one": (bool) true
}
}
Byte (and uint8) arrays and slices are displayed uniquely like the hexdump -C
command as shown.
([]uint8) (len=32 cap=32) {
00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... |
00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0|
00000020 31 32 |12|
}
Custom Formatter
Spew provides a custom formatter that implements the fmt.Formatter interface
so that it integrates cleanly with standard fmt package printing functions. The
formatter is useful for inline printing of smaller data types similar to the
standard %v format specifier.
The custom formatter only responds to the %v (most compact), %+v (adds pointer
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
combinations. Any other verbs such as %x and %q will be sent to the the
standard fmt package for formatting. In addition, the custom formatter ignores
the width and precision arguments (however they will still work on the format
specifiers not handled by the custom formatter).
Custom Formatter Usage
The simplest way to make use of the spew custom formatter is to call one of the
convenience functions such as spew.Printf, spew.Println, or spew.Printf. The
functions have syntax you are most likely already familiar with:
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
spew.Println(myVar, myVar2)
spew.Fprintf(os.Stderr, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
spew.Fprintf(os.Stderr, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
See the Index for the full list convenience functions.
Sample Formatter Output
Double pointer to a uint8:
%v: <**>5
%+v: <**>(0xf8400420d0->0xf8400420c8)5
%#v: (**uint8)5
%#+v: (**uint8)(0xf8400420d0->0xf8400420c8)5
Pointer to circular struct with a uint8 field and a pointer to itself:
%v: <*>{1 <*><shown>}
%+v: <*>(0xf84003e260){ui8:1 c:<*>(0xf84003e260)<shown>}
%#v: (*main.circular){ui8:(uint8)1 c:(*main.circular)<shown>}
%#+v: (*main.circular)(0xf84003e260){ui8:(uint8)1 c:(*main.circular)(0xf84003e260)<shown>}
See the Printf example for details on the setup of variables being shown
here.
Errors
Since it is possible for custom Stringer/error interfaces to panic, spew
detects them and handles them internally by printing the panic information
inline with the output. Since spew is intended to provide deep pretty printing
capabilities on structures, it intentionally does not return any errors.
*/
package spew

View file

@ -1,506 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"bytes"
"encoding/hex"
"fmt"
"io"
"os"
"reflect"
"regexp"
"strconv"
"strings"
)
var (
// uint8Type is a reflect.Type representing a uint8. It is used to
// convert cgo types to uint8 slices for hexdumping.
uint8Type = reflect.TypeOf(uint8(0))
// cCharRE is a regular expression that matches a cgo char.
// It is used to detect character arrays to hexdump them.
cCharRE = regexp.MustCompile("^.*\\._Ctype_char$")
// cUnsignedCharRE is a regular expression that matches a cgo unsigned
// char. It is used to detect unsigned character arrays to hexdump
// them.
cUnsignedCharRE = regexp.MustCompile("^.*\\._Ctype_unsignedchar$")
// cUint8tCharRE is a regular expression that matches a cgo uint8_t.
// It is used to detect uint8_t arrays to hexdump them.
cUint8tCharRE = regexp.MustCompile("^.*\\._Ctype_uint8_t$")
)
// dumpState contains information about the state of a dump operation.
type dumpState struct {
w io.Writer
depth int
pointers map[uintptr]int
ignoreNextType bool
ignoreNextIndent bool
cs *ConfigState
}
// indent performs indentation according to the depth level and cs.Indent
// option.
func (d *dumpState) indent() {
if d.ignoreNextIndent {
d.ignoreNextIndent = false
return
}
d.w.Write(bytes.Repeat([]byte(d.cs.Indent), d.depth))
}
// unpackValue returns values inside of non-nil interfaces when possible.
// This is useful for data types like structs, arrays, slices, and maps which
// can contain varying types packed inside an interface.
func (d *dumpState) unpackValue(v reflect.Value) reflect.Value {
if v.Kind() == reflect.Interface && !v.IsNil() {
v = v.Elem()
}
return v
}
// dumpPtr handles formatting of pointers by indirecting them as necessary.
func (d *dumpState) dumpPtr(v reflect.Value) {
// Remove pointers at or below the current depth from map used to detect
// circular refs.
for k, depth := range d.pointers {
if depth >= d.depth {
delete(d.pointers, k)
}
}
// Keep list of all dereferenced pointers to show later.
pointerChain := make([]uintptr, 0)
// Figure out how many levels of indirection there are by dereferencing
// pointers and unpacking interfaces down the chain while detecting circular
// references.
nilFound := false
cycleFound := false
indirects := 0
ve := v
for ve.Kind() == reflect.Ptr {
if ve.IsNil() {
nilFound = true
break
}
indirects++
addr := ve.Pointer()
pointerChain = append(pointerChain, addr)
if pd, ok := d.pointers[addr]; ok && pd < d.depth {
cycleFound = true
indirects--
break
}
d.pointers[addr] = d.depth
ve = ve.Elem()
if ve.Kind() == reflect.Interface {
if ve.IsNil() {
nilFound = true
break
}
ve = ve.Elem()
}
}
// Display type information.
d.w.Write(openParenBytes)
d.w.Write(bytes.Repeat(asteriskBytes, indirects))
d.w.Write([]byte(ve.Type().String()))
d.w.Write(closeParenBytes)
// Display pointer information.
if len(pointerChain) > 0 {
d.w.Write(openParenBytes)
for i, addr := range pointerChain {
if i > 0 {
d.w.Write(pointerChainBytes)
}
printHexPtr(d.w, addr)
}
d.w.Write(closeParenBytes)
}
// Display dereferenced value.
d.w.Write(openParenBytes)
switch {
case nilFound == true:
d.w.Write(nilAngleBytes)
case cycleFound == true:
d.w.Write(circularBytes)
default:
d.ignoreNextType = true
d.dump(ve)
}
d.w.Write(closeParenBytes)
}
// dumpSlice handles formatting of arrays and slices. Byte (uint8 under
// reflection) arrays and slices are dumped in hexdump -C fashion.
func (d *dumpState) dumpSlice(v reflect.Value) {
// Determine whether this type should be hex dumped or not. Also,
// for types which should be hexdumped, try to use the underlying data
// first, then fall back to trying to convert them to a uint8 slice.
var buf []uint8
doConvert := false
doHexDump := false
numEntries := v.Len()
if numEntries > 0 {
vt := v.Index(0).Type()
vts := vt.String()
switch {
// C types that need to be converted.
case cCharRE.MatchString(vts):
fallthrough
case cUnsignedCharRE.MatchString(vts):
fallthrough
case cUint8tCharRE.MatchString(vts):
doConvert = true
// Try to use existing uint8 slices and fall back to converting
// and copying if that fails.
case vt.Kind() == reflect.Uint8:
// We need an addressable interface to convert the type back
// into a byte slice. However, the reflect package won't give
// us an interface on certain things like unexported struct
// fields in order to enforce visibility rules. We use unsafe
// to bypass these restrictions since this package does not
// mutate the values.
vs := v
if !vs.CanInterface() || !vs.CanAddr() {
vs = unsafeReflectValue(vs)
}
vs = vs.Slice(0, numEntries)
// Use the existing uint8 slice if it can be type
// asserted.
iface := vs.Interface()
if slice, ok := iface.([]uint8); ok {
buf = slice
doHexDump = true
break
}
// The underlying data needs to be converted if it can't
// be type asserted to a uint8 slice.
doConvert = true
}
// Copy and convert the underlying type if needed.
if doConvert && vt.ConvertibleTo(uint8Type) {
// Convert and copy each element into a uint8 byte
// slice.
buf = make([]uint8, numEntries)
for i := 0; i < numEntries; i++ {
vv := v.Index(i)
buf[i] = uint8(vv.Convert(uint8Type).Uint())
}
doHexDump = true
}
}
// Hexdump the entire slice as needed.
if doHexDump {
indent := strings.Repeat(d.cs.Indent, d.depth)
str := indent + hex.Dump(buf)
str = strings.Replace(str, "\n", "\n"+indent, -1)
str = strings.TrimRight(str, d.cs.Indent)
d.w.Write([]byte(str))
return
}
// Recursively call dump for each item.
for i := 0; i < numEntries; i++ {
d.dump(d.unpackValue(v.Index(i)))
if i < (numEntries - 1) {
d.w.Write(commaNewlineBytes)
} else {
d.w.Write(newlineBytes)
}
}
}
// dump is the main workhorse for dumping a value. It uses the passed reflect
// value to figure out what kind of object we are dealing with and formats it
// appropriately. It is a recursive function, however circular data structures
// are detected and handled properly.
func (d *dumpState) dump(v reflect.Value) {
// Handle invalid reflect values immediately.
kind := v.Kind()
if kind == reflect.Invalid {
d.w.Write(invalidAngleBytes)
return
}
// Handle pointers specially.
if kind == reflect.Ptr {
d.indent()
d.dumpPtr(v)
return
}
// Print type information unless already handled elsewhere.
if !d.ignoreNextType {
d.indent()
d.w.Write(openParenBytes)
d.w.Write([]byte(v.Type().String()))
d.w.Write(closeParenBytes)
d.w.Write(spaceBytes)
}
d.ignoreNextType = false
// Display length and capacity if the built-in len and cap functions
// work with the value's kind and the len/cap itself is non-zero.
valueLen, valueCap := 0, 0
switch v.Kind() {
case reflect.Array, reflect.Slice, reflect.Chan:
valueLen, valueCap = v.Len(), v.Cap()
case reflect.Map, reflect.String:
valueLen = v.Len()
}
if valueLen != 0 || valueCap != 0 {
d.w.Write(openParenBytes)
if valueLen != 0 {
d.w.Write(lenEqualsBytes)
printInt(d.w, int64(valueLen), 10)
}
if valueCap != 0 {
if valueLen != 0 {
d.w.Write(spaceBytes)
}
d.w.Write(capEqualsBytes)
printInt(d.w, int64(valueCap), 10)
}
d.w.Write(closeParenBytes)
d.w.Write(spaceBytes)
}
// Call Stringer/error interfaces if they exist and the handle methods flag
// is enabled
if !d.cs.DisableMethods {
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
if handled := handleMethods(d.cs, d.w, v); handled {
return
}
}
}
switch kind {
case reflect.Invalid:
// Do nothing. We should never get here since invalid has already
// been handled above.
case reflect.Bool:
printBool(d.w, v.Bool())
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
printInt(d.w, v.Int(), 10)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
printUint(d.w, v.Uint(), 10)
case reflect.Float32:
printFloat(d.w, v.Float(), 32)
case reflect.Float64:
printFloat(d.w, v.Float(), 64)
case reflect.Complex64:
printComplex(d.w, v.Complex(), 32)
case reflect.Complex128:
printComplex(d.w, v.Complex(), 64)
case reflect.Slice:
if v.IsNil() {
d.w.Write(nilAngleBytes)
break
}
fallthrough
case reflect.Array:
d.w.Write(openBraceNewlineBytes)
d.depth++
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
d.indent()
d.w.Write(maxNewlineBytes)
} else {
d.dumpSlice(v)
}
d.depth--
d.indent()
d.w.Write(closeBraceBytes)
case reflect.String:
d.w.Write([]byte(strconv.Quote(v.String())))
case reflect.Interface:
// The only time we should get here is for nil interfaces due to
// unpackValue calls.
if v.IsNil() {
d.w.Write(nilAngleBytes)
}
case reflect.Ptr:
// Do nothing. We should never get here since pointers have already
// been handled above.
case reflect.Map:
// nil maps should be indicated as different than empty maps
if v.IsNil() {
d.w.Write(nilAngleBytes)
break
}
d.w.Write(openBraceNewlineBytes)
d.depth++
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
d.indent()
d.w.Write(maxNewlineBytes)
} else {
numEntries := v.Len()
keys := v.MapKeys()
if d.cs.SortKeys {
sortValues(keys)
}
for i, key := range keys {
d.dump(d.unpackValue(key))
d.w.Write(colonSpaceBytes)
d.ignoreNextIndent = true
d.dump(d.unpackValue(v.MapIndex(key)))
if i < (numEntries - 1) {
d.w.Write(commaNewlineBytes)
} else {
d.w.Write(newlineBytes)
}
}
}
d.depth--
d.indent()
d.w.Write(closeBraceBytes)
case reflect.Struct:
d.w.Write(openBraceNewlineBytes)
d.depth++
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
d.indent()
d.w.Write(maxNewlineBytes)
} else {
vt := v.Type()
numFields := v.NumField()
for i := 0; i < numFields; i++ {
d.indent()
vtf := vt.Field(i)
d.w.Write([]byte(vtf.Name))
d.w.Write(colonSpaceBytes)
d.ignoreNextIndent = true
d.dump(d.unpackValue(v.Field(i)))
if i < (numFields - 1) {
d.w.Write(commaNewlineBytes)
} else {
d.w.Write(newlineBytes)
}
}
}
d.depth--
d.indent()
d.w.Write(closeBraceBytes)
case reflect.Uintptr:
printHexPtr(d.w, uintptr(v.Uint()))
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
printHexPtr(d.w, v.Pointer())
// There were not any other types at the time this code was written, but
// fall back to letting the default fmt package handle it in case any new
// types are added.
default:
if v.CanInterface() {
fmt.Fprintf(d.w, "%v", v.Interface())
} else {
fmt.Fprintf(d.w, "%v", v.String())
}
}
}
// fdump is a helper function to consolidate the logic from the various public
// methods which take varying writers and config states.
func fdump(cs *ConfigState, w io.Writer, a ...interface{}) {
for _, arg := range a {
if arg == nil {
w.Write(interfaceBytes)
w.Write(spaceBytes)
w.Write(nilAngleBytes)
w.Write(newlineBytes)
continue
}
d := dumpState{w: w, cs: cs}
d.pointers = make(map[uintptr]int)
d.dump(reflect.ValueOf(arg))
d.w.Write(newlineBytes)
}
}
// Fdump formats and displays the passed arguments to io.Writer w. It formats
// exactly the same as Dump.
func Fdump(w io.Writer, a ...interface{}) {
fdump(&Config, w, a...)
}
// Sdump returns a string with the passed arguments formatted exactly the same
// as Dump.
func Sdump(a ...interface{}) string {
var buf bytes.Buffer
fdump(&Config, &buf, a...)
return buf.String()
}
/*
Dump displays the passed parameters to standard out with newlines, customizable
indentation, and additional debug information such as complete types and all
pointer addresses used to indirect to the final value. It provides the
following features over the built-in printing facilities provided by the fmt
package:
* Pointers are dereferenced and followed
* Circular data structures are detected and handled properly
* Custom Stringer/error interfaces are optionally invoked, including
on unexported types
* Custom types which only implement the Stringer/error interfaces via
a pointer receiver are optionally invoked when passing non-pointer
variables
* Byte arrays and slices are dumped like the hexdump -C command which
includes offsets, byte values in hex, and ASCII output
The configuration options are controlled by an exported package global,
spew.Config. See ConfigState for options documentation.
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
get the formatted result as a string.
*/
func Dump(a ...interface{}) {
fdump(&Config, os.Stdout, a...)
}

View file

@ -1,986 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
/*
Test Summary:
NOTE: For each test, a nil pointer, a single pointer and double pointer to the
base test element are also tested to ensure proper indirection across all types.
- Max int8, int16, int32, int64, int
- Max uint8, uint16, uint32, uint64, uint
- Boolean true and false
- Standard complex64 and complex128
- Array containing standard ints
- Array containing type with custom formatter on pointer receiver only
- Array containing interfaces
- Array containing bytes
- Slice containing standard float32 values
- Slice containing type with custom formatter on pointer receiver only
- Slice containing interfaces
- Slice containing bytes
- Nil slice
- Standard string
- Nil interface
- Sub-interface
- Map with string keys and int vals
- Map with custom formatter type on pointer receiver only keys and vals
- Map with interface keys and values
- Map with nil interface value
- Struct with primitives
- Struct that contains another struct
- Struct that contains custom type with Stringer pointer interface via both
exported and unexported fields
- Struct that contains embedded struct and field to same struct
- Uintptr to 0 (null pointer)
- Uintptr address of real variable
- Unsafe.Pointer to 0 (null pointer)
- Unsafe.Pointer to address of real variable
- Nil channel
- Standard int channel
- Function with no params and no returns
- Function with param and no returns
- Function with multiple params and multiple returns
- Struct that is circular through self referencing
- Structs that are circular through cross referencing
- Structs that are indirectly circular
- Type that panics in its Stringer interface
*/
package spew_test
import (
"bytes"
"fmt"
"github.com/davecgh/go-spew/spew"
"testing"
"unsafe"
)
// dumpTest is used to describe a test to be perfomed against the Dump method.
type dumpTest struct {
in interface{}
wants []string
}
// dumpTests houses all of the tests to be performed against the Dump method.
var dumpTests = make([]dumpTest, 0)
// addDumpTest is a helper method to append the passed input and desired result
// to dumpTests
func addDumpTest(in interface{}, wants ...string) {
test := dumpTest{in, wants}
dumpTests = append(dumpTests, test)
}
func addIntDumpTests() {
// Max int8.
v := int8(127)
nv := (*int8)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "int8"
vs := "127"
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
// Max int16.
v2 := int16(32767)
nv2 := (*int16)(nil)
pv2 := &v2
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "int16"
v2s := "32767"
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
// Max int32.
v3 := int32(2147483647)
nv3 := (*int32)(nil)
pv3 := &v3
v3Addr := fmt.Sprintf("%p", pv3)
pv3Addr := fmt.Sprintf("%p", &pv3)
v3t := "int32"
v3s := "2147483647"
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s+")\n")
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s+")\n")
addDumpTest(nv3, "(*"+v3t+")(<nil>)\n")
// Max int64.
v4 := int64(9223372036854775807)
nv4 := (*int64)(nil)
pv4 := &v4
v4Addr := fmt.Sprintf("%p", pv4)
pv4Addr := fmt.Sprintf("%p", &pv4)
v4t := "int64"
v4s := "9223372036854775807"
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
addDumpTest(pv4, "(*"+v4t+")("+v4Addr+")("+v4s+")\n")
addDumpTest(&pv4, "(**"+v4t+")("+pv4Addr+"->"+v4Addr+")("+v4s+")\n")
addDumpTest(nv4, "(*"+v4t+")(<nil>)\n")
// Max int.
v5 := int(2147483647)
nv5 := (*int)(nil)
pv5 := &v5
v5Addr := fmt.Sprintf("%p", pv5)
pv5Addr := fmt.Sprintf("%p", &pv5)
v5t := "int"
v5s := "2147483647"
addDumpTest(v5, "("+v5t+") "+v5s+"\n")
addDumpTest(pv5, "(*"+v5t+")("+v5Addr+")("+v5s+")\n")
addDumpTest(&pv5, "(**"+v5t+")("+pv5Addr+"->"+v5Addr+")("+v5s+")\n")
addDumpTest(nv5, "(*"+v5t+")(<nil>)\n")
}
func addUintDumpTests() {
// Max uint8.
v := uint8(255)
nv := (*uint8)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "uint8"
vs := "255"
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
// Max uint16.
v2 := uint16(65535)
nv2 := (*uint16)(nil)
pv2 := &v2
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "uint16"
v2s := "65535"
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
// Max uint32.
v3 := uint32(4294967295)
nv3 := (*uint32)(nil)
pv3 := &v3
v3Addr := fmt.Sprintf("%p", pv3)
pv3Addr := fmt.Sprintf("%p", &pv3)
v3t := "uint32"
v3s := "4294967295"
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s+")\n")
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s+")\n")
addDumpTest(nv3, "(*"+v3t+")(<nil>)\n")
// Max uint64.
v4 := uint64(18446744073709551615)
nv4 := (*uint64)(nil)
pv4 := &v4
v4Addr := fmt.Sprintf("%p", pv4)
pv4Addr := fmt.Sprintf("%p", &pv4)
v4t := "uint64"
v4s := "18446744073709551615"
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
addDumpTest(pv4, "(*"+v4t+")("+v4Addr+")("+v4s+")\n")
addDumpTest(&pv4, "(**"+v4t+")("+pv4Addr+"->"+v4Addr+")("+v4s+")\n")
addDumpTest(nv4, "(*"+v4t+")(<nil>)\n")
// Max uint.
v5 := uint(4294967295)
nv5 := (*uint)(nil)
pv5 := &v5
v5Addr := fmt.Sprintf("%p", pv5)
pv5Addr := fmt.Sprintf("%p", &pv5)
v5t := "uint"
v5s := "4294967295"
addDumpTest(v5, "("+v5t+") "+v5s+"\n")
addDumpTest(pv5, "(*"+v5t+")("+v5Addr+")("+v5s+")\n")
addDumpTest(&pv5, "(**"+v5t+")("+pv5Addr+"->"+v5Addr+")("+v5s+")\n")
addDumpTest(nv5, "(*"+v5t+")(<nil>)\n")
}
func addBoolDumpTests() {
// Boolean true.
v := bool(true)
nv := (*bool)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "bool"
vs := "true"
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
// Boolean false.
v2 := bool(false)
pv2 := &v2
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "bool"
v2s := "false"
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
}
func addFloatDumpTests() {
// Standard float32.
v := float32(3.1415)
nv := (*float32)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "float32"
vs := "3.1415"
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
// Standard float64.
v2 := float64(3.1415926)
nv2 := (*float64)(nil)
pv2 := &v2
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "float64"
v2s := "3.1415926"
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
}
func addComplexDumpTests() {
// Standard complex64.
v := complex(float32(6), -2)
nv := (*complex64)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "complex64"
vs := "(6-2i)"
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
// Standard complex128.
v2 := complex(float64(-6), 2)
nv2 := (*complex128)(nil)
pv2 := &v2
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "complex128"
v2s := "(-6+2i)"
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
}
func addArrayDumpTests() {
// Array containing standard ints.
v := [3]int{1, 2, 3}
vLen := fmt.Sprintf("%d", len(v))
vCap := fmt.Sprintf("%d", cap(v))
nv := (*[3]int)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "int"
vs := "(len=" + vLen + " cap=" + vCap + ") {\n (" + vt + ") 1,\n (" +
vt + ") 2,\n (" + vt + ") 3\n}"
addDumpTest(v, "([3]"+vt+") "+vs+"\n")
addDumpTest(pv, "(*[3]"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**[3]"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*[3]"+vt+")(<nil>)\n")
// Array containing type with custom formatter on pointer receiver only.
v2i0 := pstringer("1")
v2i1 := pstringer("2")
v2i2 := pstringer("3")
v2 := [3]pstringer{v2i0, v2i1, v2i2}
v2i0Len := fmt.Sprintf("%d", len(v2i0))
v2i1Len := fmt.Sprintf("%d", len(v2i1))
v2i2Len := fmt.Sprintf("%d", len(v2i2))
v2Len := fmt.Sprintf("%d", len(v2))
v2Cap := fmt.Sprintf("%d", cap(v2))
nv2 := (*[3]pstringer)(nil)
pv2 := &v2
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "spew_test.pstringer"
v2s := "(len=" + v2Len + " cap=" + v2Cap + ") {\n (" + v2t + ") (len=" +
v2i0Len + ") stringer 1,\n (" + v2t + ") (len=" + v2i1Len +
") stringer 2,\n (" + v2t + ") (len=" + v2i2Len + ") " +
"stringer 3\n}"
addDumpTest(v2, "([3]"+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*[3]"+v2t+")("+v2Addr+")("+v2s+")\n")
addDumpTest(&pv2, "(**[3]"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
addDumpTest(nv2, "(*[3]"+v2t+")(<nil>)\n")
// Array containing interfaces.
v3i0 := "one"
v3 := [3]interface{}{v3i0, int(2), uint(3)}
v3i0Len := fmt.Sprintf("%d", len(v3i0))
v3Len := fmt.Sprintf("%d", len(v3))
v3Cap := fmt.Sprintf("%d", cap(v3))
nv3 := (*[3]interface{})(nil)
pv3 := &v3
v3Addr := fmt.Sprintf("%p", pv3)
pv3Addr := fmt.Sprintf("%p", &pv3)
v3t := "[3]interface {}"
v3t2 := "string"
v3t3 := "int"
v3t4 := "uint"
v3s := "(len=" + v3Len + " cap=" + v3Cap + ") {\n (" + v3t2 + ") " +
"(len=" + v3i0Len + ") \"one\",\n (" + v3t3 + ") 2,\n (" +
v3t4 + ") 3\n}"
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s+")\n")
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s+")\n")
addDumpTest(nv3, "(*"+v3t+")(<nil>)\n")
// Array containing bytes.
v4 := [34]byte{
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
0x31, 0x32,
}
v4Len := fmt.Sprintf("%d", len(v4))
v4Cap := fmt.Sprintf("%d", cap(v4))
nv4 := (*[34]byte)(nil)
pv4 := &v4
v4Addr := fmt.Sprintf("%p", pv4)
pv4Addr := fmt.Sprintf("%p", &pv4)
v4t := "[34]uint8"
v4s := "(len=" + v4Len + " cap=" + v4Cap + ") " +
"{\n 00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20" +
" |............... |\n" +
" 00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30" +
" |!\"#$%&'()*+,-./0|\n" +
" 00000020 31 32 " +
" |12|\n}"
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
addDumpTest(pv4, "(*"+v4t+")("+v4Addr+")("+v4s+")\n")
addDumpTest(&pv4, "(**"+v4t+")("+pv4Addr+"->"+v4Addr+")("+v4s+")\n")
addDumpTest(nv4, "(*"+v4t+")(<nil>)\n")
}
func addSliceDumpTests() {
// Slice containing standard float32 values.
v := []float32{3.14, 6.28, 12.56}
vLen := fmt.Sprintf("%d", len(v))
vCap := fmt.Sprintf("%d", cap(v))
nv := (*[]float32)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "float32"
vs := "(len=" + vLen + " cap=" + vCap + ") {\n (" + vt + ") 3.14,\n (" +
vt + ") 6.28,\n (" + vt + ") 12.56\n}"
addDumpTest(v, "([]"+vt+") "+vs+"\n")
addDumpTest(pv, "(*[]"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**[]"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*[]"+vt+")(<nil>)\n")
// Slice containing type with custom formatter on pointer receiver only.
v2i0 := pstringer("1")
v2i1 := pstringer("2")
v2i2 := pstringer("3")
v2 := []pstringer{v2i0, v2i1, v2i2}
v2i0Len := fmt.Sprintf("%d", len(v2i0))
v2i1Len := fmt.Sprintf("%d", len(v2i1))
v2i2Len := fmt.Sprintf("%d", len(v2i2))
v2Len := fmt.Sprintf("%d", len(v2))
v2Cap := fmt.Sprintf("%d", cap(v2))
nv2 := (*[]pstringer)(nil)
pv2 := &v2
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "spew_test.pstringer"
v2s := "(len=" + v2Len + " cap=" + v2Cap + ") {\n (" + v2t + ") (len=" +
v2i0Len + ") stringer 1,\n (" + v2t + ") (len=" + v2i1Len +
") stringer 2,\n (" + v2t + ") (len=" + v2i2Len + ") " +
"stringer 3\n}"
addDumpTest(v2, "([]"+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*[]"+v2t+")("+v2Addr+")("+v2s+")\n")
addDumpTest(&pv2, "(**[]"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
addDumpTest(nv2, "(*[]"+v2t+")(<nil>)\n")
// Slice containing interfaces.
v3i0 := "one"
v3 := []interface{}{v3i0, int(2), uint(3), nil}
v3i0Len := fmt.Sprintf("%d", len(v3i0))
v3Len := fmt.Sprintf("%d", len(v3))
v3Cap := fmt.Sprintf("%d", cap(v3))
nv3 := (*[]interface{})(nil)
pv3 := &v3
v3Addr := fmt.Sprintf("%p", pv3)
pv3Addr := fmt.Sprintf("%p", &pv3)
v3t := "[]interface {}"
v3t2 := "string"
v3t3 := "int"
v3t4 := "uint"
v3t5 := "interface {}"
v3s := "(len=" + v3Len + " cap=" + v3Cap + ") {\n (" + v3t2 + ") " +
"(len=" + v3i0Len + ") \"one\",\n (" + v3t3 + ") 2,\n (" +
v3t4 + ") 3,\n (" + v3t5 + ") <nil>\n}"
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s+")\n")
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s+")\n")
addDumpTest(nv3, "(*"+v3t+")(<nil>)\n")
// Slice containing bytes.
v4 := []byte{
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
0x31, 0x32,
}
v4Len := fmt.Sprintf("%d", len(v4))
v4Cap := fmt.Sprintf("%d", cap(v4))
nv4 := (*[]byte)(nil)
pv4 := &v4
v4Addr := fmt.Sprintf("%p", pv4)
pv4Addr := fmt.Sprintf("%p", &pv4)
v4t := "[]uint8"
v4s := "(len=" + v4Len + " cap=" + v4Cap + ") " +
"{\n 00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20" +
" |............... |\n" +
" 00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30" +
" |!\"#$%&'()*+,-./0|\n" +
" 00000020 31 32 " +
" |12|\n}"
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
addDumpTest(pv4, "(*"+v4t+")("+v4Addr+")("+v4s+")\n")
addDumpTest(&pv4, "(**"+v4t+")("+pv4Addr+"->"+v4Addr+")("+v4s+")\n")
addDumpTest(nv4, "(*"+v4t+")(<nil>)\n")
// Nil slice.
v5 := []int(nil)
nv5 := (*[]int)(nil)
pv5 := &v5
v5Addr := fmt.Sprintf("%p", pv5)
pv5Addr := fmt.Sprintf("%p", &pv5)
v5t := "[]int"
v5s := "<nil>"
addDumpTest(v5, "("+v5t+") "+v5s+"\n")
addDumpTest(pv5, "(*"+v5t+")("+v5Addr+")("+v5s+")\n")
addDumpTest(&pv5, "(**"+v5t+")("+pv5Addr+"->"+v5Addr+")("+v5s+")\n")
addDumpTest(nv5, "(*"+v5t+")(<nil>)\n")
}
func addStringDumpTests() {
// Standard string.
v := "test"
vLen := fmt.Sprintf("%d", len(v))
nv := (*string)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "string"
vs := "(len=" + vLen + ") \"test\""
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
}
func addInterfaceDumpTests() {
// Nil interface.
var v interface{}
nv := (*interface{})(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "interface {}"
vs := "<nil>"
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
// Sub-interface.
v2 := interface{}(uint16(65535))
pv2 := &v2
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "uint16"
v2s := "65535"
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
}
func addMapDumpTests() {
// Map with string keys and int vals.
k := "one"
kk := "two"
m := map[string]int{k: 1, kk: 2}
klen := fmt.Sprintf("%d", len(k)) // not kLen to shut golint up
kkLen := fmt.Sprintf("%d", len(kk))
mLen := fmt.Sprintf("%d", len(m))
nilMap := map[string]int(nil)
nm := (*map[string]int)(nil)
pm := &m
mAddr := fmt.Sprintf("%p", pm)
pmAddr := fmt.Sprintf("%p", &pm)
mt := "map[string]int"
mt1 := "string"
mt2 := "int"
ms := "(len=" + mLen + ") {\n (" + mt1 + ") (len=" + klen + ") " +
"\"one\": (" + mt2 + ") 1,\n (" + mt1 + ") (len=" + kkLen +
") \"two\": (" + mt2 + ") 2\n}"
ms2 := "(len=" + mLen + ") {\n (" + mt1 + ") (len=" + kkLen + ") " +
"\"two\": (" + mt2 + ") 2,\n (" + mt1 + ") (len=" + klen +
") \"one\": (" + mt2 + ") 1\n}"
addDumpTest(m, "("+mt+") "+ms+"\n", "("+mt+") "+ms2+"\n")
addDumpTest(pm, "(*"+mt+")("+mAddr+")("+ms+")\n",
"(*"+mt+")("+mAddr+")("+ms2+")\n")
addDumpTest(&pm, "(**"+mt+")("+pmAddr+"->"+mAddr+")("+ms+")\n",
"(**"+mt+")("+pmAddr+"->"+mAddr+")("+ms2+")\n")
addDumpTest(nm, "(*"+mt+")(<nil>)\n")
addDumpTest(nilMap, "("+mt+") <nil>\n")
// Map with custom formatter type on pointer receiver only keys and vals.
k2 := pstringer("one")
v2 := pstringer("1")
m2 := map[pstringer]pstringer{k2: v2}
k2Len := fmt.Sprintf("%d", len(k2))
v2Len := fmt.Sprintf("%d", len(v2))
m2Len := fmt.Sprintf("%d", len(m2))
nilMap2 := map[pstringer]pstringer(nil)
nm2 := (*map[pstringer]pstringer)(nil)
pm2 := &m2
m2Addr := fmt.Sprintf("%p", pm2)
pm2Addr := fmt.Sprintf("%p", &pm2)
m2t := "map[spew_test.pstringer]spew_test.pstringer"
m2t1 := "spew_test.pstringer"
m2t2 := "spew_test.pstringer"
m2s := "(len=" + m2Len + ") {\n (" + m2t1 + ") (len=" + k2Len + ") " +
"stringer one: (" + m2t2 + ") (len=" + v2Len + ") stringer 1\n}"
addDumpTest(m2, "("+m2t+") "+m2s+"\n")
addDumpTest(pm2, "(*"+m2t+")("+m2Addr+")("+m2s+")\n")
addDumpTest(&pm2, "(**"+m2t+")("+pm2Addr+"->"+m2Addr+")("+m2s+")\n")
addDumpTest(nm2, "(*"+m2t+")(<nil>)\n")
addDumpTest(nilMap2, "("+m2t+") <nil>\n")
// Map with interface keys and values.
k3 := "one"
k3Len := fmt.Sprintf("%d", len(k3))
m3 := map[interface{}]interface{}{k3: 1}
m3Len := fmt.Sprintf("%d", len(m3))
nilMap3 := map[interface{}]interface{}(nil)
nm3 := (*map[interface{}]interface{})(nil)
pm3 := &m3
m3Addr := fmt.Sprintf("%p", pm3)
pm3Addr := fmt.Sprintf("%p", &pm3)
m3t := "map[interface {}]interface {}"
m3t1 := "string"
m3t2 := "int"
m3s := "(len=" + m3Len + ") {\n (" + m3t1 + ") (len=" + k3Len + ") " +
"\"one\": (" + m3t2 + ") 1\n}"
addDumpTest(m3, "("+m3t+") "+m3s+"\n")
addDumpTest(pm3, "(*"+m3t+")("+m3Addr+")("+m3s+")\n")
addDumpTest(&pm3, "(**"+m3t+")("+pm3Addr+"->"+m3Addr+")("+m3s+")\n")
addDumpTest(nm3, "(*"+m3t+")(<nil>)\n")
addDumpTest(nilMap3, "("+m3t+") <nil>\n")
// Map with nil interface value.
k4 := "nil"
k4Len := fmt.Sprintf("%d", len(k4))
m4 := map[string]interface{}{k4: nil}
m4Len := fmt.Sprintf("%d", len(m4))
nilMap4 := map[string]interface{}(nil)
nm4 := (*map[string]interface{})(nil)
pm4 := &m4
m4Addr := fmt.Sprintf("%p", pm4)
pm4Addr := fmt.Sprintf("%p", &pm4)
m4t := "map[string]interface {}"
m4t1 := "string"
m4t2 := "interface {}"
m4s := "(len=" + m4Len + ") {\n (" + m4t1 + ") (len=" + k4Len + ")" +
" \"nil\": (" + m4t2 + ") <nil>\n}"
addDumpTest(m4, "("+m4t+") "+m4s+"\n")
addDumpTest(pm4, "(*"+m4t+")("+m4Addr+")("+m4s+")\n")
addDumpTest(&pm4, "(**"+m4t+")("+pm4Addr+"->"+m4Addr+")("+m4s+")\n")
addDumpTest(nm4, "(*"+m4t+")(<nil>)\n")
addDumpTest(nilMap4, "("+m4t+") <nil>\n")
}
func addStructDumpTests() {
// Struct with primitives.
type s1 struct {
a int8
b uint8
}
v := s1{127, 255}
nv := (*s1)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "spew_test.s1"
vt2 := "int8"
vt3 := "uint8"
vs := "{\n a: (" + vt2 + ") 127,\n b: (" + vt3 + ") 255\n}"
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
// Struct that contains another struct.
type s2 struct {
s1 s1
b bool
}
v2 := s2{s1{127, 255}, true}
nv2 := (*s2)(nil)
pv2 := &v2
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "spew_test.s2"
v2t2 := "spew_test.s1"
v2t3 := "int8"
v2t4 := "uint8"
v2t5 := "bool"
v2s := "{\n s1: (" + v2t2 + ") {\n a: (" + v2t3 + ") 127,\n b: (" +
v2t4 + ") 255\n },\n b: (" + v2t5 + ") true\n}"
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
// Struct that contains custom type with Stringer pointer interface via both
// exported and unexported fields.
type s3 struct {
s pstringer
S pstringer
}
v3 := s3{"test", "test2"}
nv3 := (*s3)(nil)
pv3 := &v3
v3Addr := fmt.Sprintf("%p", pv3)
pv3Addr := fmt.Sprintf("%p", &pv3)
v3t := "spew_test.s3"
v3t2 := "spew_test.pstringer"
v3s := "{\n s: (" + v3t2 + ") (len=4) stringer test,\n S: (" + v3t2 +
") (len=5) stringer test2\n}"
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s+")\n")
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s+")\n")
addDumpTest(nv3, "(*"+v3t+")(<nil>)\n")
// Struct that contains embedded struct and field to same struct.
e := embed{"embedstr"}
eLen := fmt.Sprintf("%d", len("embedstr"))
v4 := embedwrap{embed: &e, e: &e}
nv4 := (*embedwrap)(nil)
pv4 := &v4
eAddr := fmt.Sprintf("%p", &e)
v4Addr := fmt.Sprintf("%p", pv4)
pv4Addr := fmt.Sprintf("%p", &pv4)
v4t := "spew_test.embedwrap"
v4t2 := "spew_test.embed"
v4t3 := "string"
v4s := "{\n embed: (*" + v4t2 + ")(" + eAddr + ")({\n a: (" + v4t3 +
") (len=" + eLen + ") \"embedstr\"\n }),\n e: (*" + v4t2 +
")(" + eAddr + ")({\n a: (" + v4t3 + ") (len=" + eLen + ")" +
" \"embedstr\"\n })\n}"
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
addDumpTest(pv4, "(*"+v4t+")("+v4Addr+")("+v4s+")\n")
addDumpTest(&pv4, "(**"+v4t+")("+pv4Addr+"->"+v4Addr+")("+v4s+")\n")
addDumpTest(nv4, "(*"+v4t+")(<nil>)\n")
}
func addUintptrDumpTests() {
// Null pointer.
v := uintptr(0)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "uintptr"
vs := "<nil>"
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
// Address of real variable.
i := 1
v2 := uintptr(unsafe.Pointer(&i))
nv2 := (*uintptr)(nil)
pv2 := &v2
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "uintptr"
v2s := fmt.Sprintf("%p", &i)
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
}
func addUnsafePointerDumpTests() {
// Null pointer.
v := unsafe.Pointer(uintptr(0))
nv := (*unsafe.Pointer)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "unsafe.Pointer"
vs := "<nil>"
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
// Address of real variable.
i := 1
v2 := unsafe.Pointer(&i)
pv2 := &v2
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "unsafe.Pointer"
v2s := fmt.Sprintf("%p", &i)
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
}
func addChanDumpTests() {
// Nil channel.
var v chan int
pv := &v
nv := (*chan int)(nil)
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "chan int"
vs := "<nil>"
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
// Real channel.
v2 := make(chan int)
pv2 := &v2
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "chan int"
v2s := fmt.Sprintf("%p", v2)
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
}
func addFuncDumpTests() {
// Function with no params and no returns.
v := addIntDumpTests
nv := (*func())(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "func()"
vs := fmt.Sprintf("%p", v)
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
// Function with param and no returns.
v2 := TestDump
nv2 := (*func(*testing.T))(nil)
pv2 := &v2
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "func(*testing.T)"
v2s := fmt.Sprintf("%p", v2)
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s+")\n")
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s+")\n")
addDumpTest(nv2, "(*"+v2t+")(<nil>)\n")
// Function with multiple params and multiple returns.
var v3 = func(i int, s string) (b bool, err error) {
return true, nil
}
nv3 := (*func(int, string) (bool, error))(nil)
pv3 := &v3
v3Addr := fmt.Sprintf("%p", pv3)
pv3Addr := fmt.Sprintf("%p", &pv3)
v3t := "func(int, string) (bool, error)"
v3s := fmt.Sprintf("%p", v3)
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s+")\n")
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s+")\n")
addDumpTest(nv3, "(*"+v3t+")(<nil>)\n")
}
func addCircularDumpTests() {
// Struct that is circular through self referencing.
type circular struct {
c *circular
}
v := circular{nil}
v.c = &v
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "spew_test.circular"
vs := "{\n c: (*" + vt + ")(" + vAddr + ")({\n c: (*" + vt + ")(" +
vAddr + ")(<already shown>)\n })\n}"
vs2 := "{\n c: (*" + vt + ")(" + vAddr + ")(<already shown>)\n}"
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs2+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs2+")\n")
// Structs that are circular through cross referencing.
v2 := xref1{nil}
ts2 := xref2{&v2}
v2.ps2 = &ts2
pv2 := &v2
ts2Addr := fmt.Sprintf("%p", &ts2)
v2Addr := fmt.Sprintf("%p", pv2)
pv2Addr := fmt.Sprintf("%p", &pv2)
v2t := "spew_test.xref1"
v2t2 := "spew_test.xref2"
v2s := "{\n ps2: (*" + v2t2 + ")(" + ts2Addr + ")({\n ps1: (*" + v2t +
")(" + v2Addr + ")({\n ps2: (*" + v2t2 + ")(" + ts2Addr +
")(<already shown>)\n })\n })\n}"
v2s2 := "{\n ps2: (*" + v2t2 + ")(" + ts2Addr + ")({\n ps1: (*" + v2t +
")(" + v2Addr + ")(<already shown>)\n })\n}"
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
addDumpTest(pv2, "(*"+v2t+")("+v2Addr+")("+v2s2+")\n")
addDumpTest(&pv2, "(**"+v2t+")("+pv2Addr+"->"+v2Addr+")("+v2s2+")\n")
// Structs that are indirectly circular.
v3 := indirCir1{nil}
tic2 := indirCir2{nil}
tic3 := indirCir3{&v3}
tic2.ps3 = &tic3
v3.ps2 = &tic2
pv3 := &v3
tic2Addr := fmt.Sprintf("%p", &tic2)
tic3Addr := fmt.Sprintf("%p", &tic3)
v3Addr := fmt.Sprintf("%p", pv3)
pv3Addr := fmt.Sprintf("%p", &pv3)
v3t := "spew_test.indirCir1"
v3t2 := "spew_test.indirCir2"
v3t3 := "spew_test.indirCir3"
v3s := "{\n ps2: (*" + v3t2 + ")(" + tic2Addr + ")({\n ps3: (*" + v3t3 +
")(" + tic3Addr + ")({\n ps1: (*" + v3t + ")(" + v3Addr +
")({\n ps2: (*" + v3t2 + ")(" + tic2Addr +
")(<already shown>)\n })\n })\n })\n}"
v3s2 := "{\n ps2: (*" + v3t2 + ")(" + tic2Addr + ")({\n ps3: (*" + v3t3 +
")(" + tic3Addr + ")({\n ps1: (*" + v3t + ")(" + v3Addr +
")(<already shown>)\n })\n })\n}"
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
addDumpTest(pv3, "(*"+v3t+")("+v3Addr+")("+v3s2+")\n")
addDumpTest(&pv3, "(**"+v3t+")("+pv3Addr+"->"+v3Addr+")("+v3s2+")\n")
}
func addPanicDumpTests() {
// Type that panics in its Stringer interface.
v := panicer(127)
nv := (*panicer)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "spew_test.panicer"
vs := "(PANIC=test panic)127"
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
}
func addErrorDumpTests() {
// Type that has a custom Error interface.
v := customError(127)
nv := (*customError)(nil)
pv := &v
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "spew_test.customError"
vs := "error: 127"
addDumpTest(v, "("+vt+") "+vs+"\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+")("+vs+")\n")
addDumpTest(nv, "(*"+vt+")(<nil>)\n")
}
// TestDump executes all of the tests described by dumpTests.
func TestDump(t *testing.T) {
// Setup tests.
addIntDumpTests()
addUintDumpTests()
addBoolDumpTests()
addFloatDumpTests()
addComplexDumpTests()
addArrayDumpTests()
addSliceDumpTests()
addStringDumpTests()
addInterfaceDumpTests()
addMapDumpTests()
addStructDumpTests()
addUintptrDumpTests()
addUnsafePointerDumpTests()
addChanDumpTests()
addFuncDumpTests()
addCircularDumpTests()
addPanicDumpTests()
addErrorDumpTests()
addCgoDumpTests()
t.Logf("Running %d tests", len(dumpTests))
for i, test := range dumpTests {
buf := new(bytes.Buffer)
spew.Fdump(buf, test.in)
s := buf.String()
if testFailed(s, test.wants) {
t.Errorf("Dump #%d\n got: %s %s", i, s, stringizeWants(test.wants))
continue
}
}
}
func TestDumpSortedKeys(t *testing.T) {
cfg := spew.ConfigState{SortKeys: true}
s := cfg.Sdump(map[int]string{1: "1", 3: "3", 2: "2"})
expected := `(map[int]string) (len=3) {
(int) 1: (string) (len=1) "1",
(int) 2: (string) (len=1) "2",
(int) 3: (string) (len=1) "3"
}
`
if s != expected {
t.Errorf("Sorted keys mismatch:\n %v %v", s, expected)
}
}

View file

@ -1,97 +0,0 @@
// Copyright (c) 2013 Dave Collins <dave@davec.name>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
// NOTE: Due to the following build constraints, this file will only be compiled
// when both cgo is supported and "-tags testcgo" is added to the go test
// command line. This means the cgo tests are only added (and hence run) when
// specifially requested. This configuration is used because spew itself
// does not require cgo to run even though it does handle certain cgo types
// specially. Rather than forcing all clients to require cgo and an external
// C compiler just to run the tests, this scheme makes them optional.
// +build cgo,testcgo
package spew_test
import (
"fmt"
"github.com/davecgh/go-spew/spew/testdata"
)
func addCgoDumpTests() {
// C char pointer.
v := testdata.GetCgoCharPointer()
nv := testdata.GetCgoNullCharPointer()
pv := &v
vcAddr := fmt.Sprintf("%p", v)
vAddr := fmt.Sprintf("%p", pv)
pvAddr := fmt.Sprintf("%p", &pv)
vt := "*testdata._Ctype_char"
vs := "116"
addDumpTest(v, "("+vt+")("+vcAddr+")("+vs+")\n")
addDumpTest(pv, "(*"+vt+")("+vAddr+"->"+vcAddr+")("+vs+")\n")
addDumpTest(&pv, "(**"+vt+")("+pvAddr+"->"+vAddr+"->"+vcAddr+")("+vs+")\n")
addDumpTest(nv, "("+vt+")(<nil>)\n")
// C char array.
v2, v2l, v2c := testdata.GetCgoCharArray()
v2Len := fmt.Sprintf("%d", v2l)
v2Cap := fmt.Sprintf("%d", v2c)
v2t := "[6]testdata._Ctype_char"
v2s := "(len=" + v2Len + " cap=" + v2Cap + ") " +
"{\n 00000000 74 65 73 74 32 00 " +
" |test2.|\n}"
addDumpTest(v2, "("+v2t+") "+v2s+"\n")
// C unsigned char array.
v3, v3l, v3c := testdata.GetCgoUnsignedCharArray()
v3Len := fmt.Sprintf("%d", v3l)
v3Cap := fmt.Sprintf("%d", v3c)
v3t := "[6]testdata._Ctype_unsignedchar"
v3s := "(len=" + v3Len + " cap=" + v3Cap + ") " +
"{\n 00000000 74 65 73 74 33 00 " +
" |test3.|\n}"
addDumpTest(v3, "("+v3t+") "+v3s+"\n")
// C signed char array.
v4, v4l, v4c := testdata.GetCgoSignedCharArray()
v4Len := fmt.Sprintf("%d", v4l)
v4Cap := fmt.Sprintf("%d", v4c)
v4t := "[6]testdata._Ctype_schar"
v4t2 := "testdata._Ctype_schar"
v4s := "(len=" + v4Len + " cap=" + v4Cap + ") " +
"{\n (" + v4t2 + ") 116,\n (" + v4t2 + ") 101,\n (" + v4t2 +
") 115,\n (" + v4t2 + ") 116,\n (" + v4t2 + ") 52,\n (" + v4t2 +
") 0\n}"
addDumpTest(v4, "("+v4t+") "+v4s+"\n")
// C uint8_t array.
v5, v5l, v5c := testdata.GetCgoUint8tArray()
v5Len := fmt.Sprintf("%d", v5l)
v5Cap := fmt.Sprintf("%d", v5c)
v5t := "[6]testdata._Ctype_uint8_t"
v5s := "(len=" + v5Len + " cap=" + v5Cap + ") " +
"{\n 00000000 74 65 73 74 35 00 " +
" |test5.|\n}"
addDumpTest(v5, "("+v5t+") "+v5s+"\n")
// C typedefed unsigned char array.
v6, v6l, v6c := testdata.GetCgoTypdefedUnsignedCharArray()
v6Len := fmt.Sprintf("%d", v6l)
v6Cap := fmt.Sprintf("%d", v6c)
v6t := "[6]testdata._Ctype_custom_uchar_t"
v6s := "(len=" + v6Len + " cap=" + v6Cap + ") " +
"{\n 00000000 74 65 73 74 36 00 " +
" |test6.|\n}"
addDumpTest(v6, "("+v6t+") "+v6s+"\n")
}

View file

@ -1,26 +0,0 @@
// Copyright (c) 2013 Dave Collins <dave@davec.name>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
// NOTE: Due to the following build constraints, this file will only be compiled
// when either cgo is not supported or "-tags testcgo" is not added to the go
// test command line. This file intentionally does not setup any cgo tests in
// this scenario.
// +build !cgo !testcgo
package spew_test
func addCgoDumpTests() {
// Don't add any tests for cgo since this file is only compiled when
// there should not be any cgo tests.
}

View file

@ -1,230 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew_test
import (
"fmt"
"github.com/davecgh/go-spew/spew"
)
type Flag int
const (
flagOne Flag = iota
flagTwo
)
var flagStrings = map[Flag]string{
flagOne: "flagOne",
flagTwo: "flagTwo",
}
func (f Flag) String() string {
if s, ok := flagStrings[f]; ok {
return s
}
return fmt.Sprintf("Unknown flag (%d)", int(f))
}
type Bar struct {
flag Flag
data uintptr
}
type Foo struct {
unexportedField Bar
ExportedField map[interface{}]interface{}
}
// This example demonstrates how to use Dump to dump variables to stdout.
func ExampleDump() {
// The following package level declarations are assumed for this example:
/*
type Flag int
const (
flagOne Flag = iota
flagTwo
)
var flagStrings = map[Flag]string{
flagOne: "flagOne",
flagTwo: "flagTwo",
}
func (f Flag) String() string {
if s, ok := flagStrings[f]; ok {
return s
}
return fmt.Sprintf("Unknown flag (%d)", int(f))
}
type Bar struct {
flag Flag
data uintptr
}
type Foo struct {
unexportedField Bar
ExportedField map[interface{}]interface{}
}
*/
// Setup some sample data structures for the example.
bar := Bar{Flag(flagTwo), uintptr(0)}
s1 := Foo{bar, map[interface{}]interface{}{"one": true}}
f := Flag(5)
b := []byte{
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28,
0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30,
0x31, 0x32,
}
// Dump!
spew.Dump(s1, f, b)
// Output:
// (spew_test.Foo) {
// unexportedField: (spew_test.Bar) {
// flag: (spew_test.Flag) flagTwo,
// data: (uintptr) <nil>
// },
// ExportedField: (map[interface {}]interface {}) (len=1) {
// (string) (len=3) "one": (bool) true
// }
// }
// (spew_test.Flag) Unknown flag (5)
// ([]uint8) (len=34 cap=34) {
// 00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... |
// 00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0|
// 00000020 31 32 |12|
// }
//
}
// This example demonstrates how to use Printf to display a variable with a
// format string and inline formatting.
func ExamplePrintf() {
// Create a double pointer to a uint 8.
ui8 := uint8(5)
pui8 := &ui8
ppui8 := &pui8
// Create a circular data type.
type circular struct {
ui8 uint8
c *circular
}
c := circular{ui8: 1}
c.c = &c
// Print!
spew.Printf("ppui8: %v\n", ppui8)
spew.Printf("circular: %v\n", c)
// Output:
// ppui8: <**>5
// circular: {1 <*>{1 <*><shown>}}
}
// This example demonstrates how to use a ConfigState.
func ExampleConfigState() {
// Modify the indent level of the ConfigState only. The global
// configuration is not modified.
scs := spew.ConfigState{Indent: "\t"}
// Output using the ConfigState instance.
v := map[string]int{"one": 1}
scs.Printf("v: %v\n", v)
scs.Dump(v)
// Output:
// v: map[one:1]
// (map[string]int) (len=1) {
// (string) (len=3) "one": (int) 1
// }
}
// This example demonstrates how to use ConfigState.Dump to dump variables to
// stdout
func ExampleConfigState_Dump() {
// See the top-level Dump example for details on the types used in this
// example.
// Create two ConfigState instances with different indentation.
scs := spew.ConfigState{Indent: "\t"}
scs2 := spew.ConfigState{Indent: " "}
// Setup some sample data structures for the example.
bar := Bar{Flag(flagTwo), uintptr(0)}
s1 := Foo{bar, map[interface{}]interface{}{"one": true}}
// Dump using the ConfigState instances.
scs.Dump(s1)
scs2.Dump(s1)
// Output:
// (spew_test.Foo) {
// unexportedField: (spew_test.Bar) {
// flag: (spew_test.Flag) flagTwo,
// data: (uintptr) <nil>
// },
// ExportedField: (map[interface {}]interface {}) (len=1) {
// (string) (len=3) "one": (bool) true
// }
// }
// (spew_test.Foo) {
// unexportedField: (spew_test.Bar) {
// flag: (spew_test.Flag) flagTwo,
// data: (uintptr) <nil>
// },
// ExportedField: (map[interface {}]interface {}) (len=1) {
// (string) (len=3) "one": (bool) true
// }
// }
//
}
// This example demonstrates how to use ConfigState.Printf to display a variable
// with a format string and inline formatting.
func ExampleConfigState_Printf() {
// See the top-level Dump example for details on the types used in this
// example.
// Create two ConfigState instances and modify the method handling of the
// first ConfigState only.
scs := spew.NewDefaultConfig()
scs2 := spew.NewDefaultConfig()
scs.DisableMethods = true
// Alternatively
// scs := spew.ConfigState{Indent: " ", DisableMethods: true}
// scs2 := spew.ConfigState{Indent: " "}
// This is of type Flag which implements a Stringer and has raw value 1.
f := flagTwo
// Dump using the ConfigState instances.
scs.Printf("f: %v\n", f)
scs2.Printf("f: %v\n", f)
// Output:
// f: 1
// f: flagTwo
}

View file

@ -1,419 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"bytes"
"fmt"
"reflect"
"strconv"
"strings"
)
// supportedFlags is a list of all the character flags supported by fmt package.
const supportedFlags = "0-+# "
// formatState implements the fmt.Formatter interface and contains information
// about the state of a formatting operation. The NewFormatter function can
// be used to get a new Formatter which can be used directly as arguments
// in standard fmt package printing calls.
type formatState struct {
value interface{}
fs fmt.State
depth int
pointers map[uintptr]int
ignoreNextType bool
cs *ConfigState
}
// buildDefaultFormat recreates the original format string without precision
// and width information to pass in to fmt.Sprintf in the case of an
// unrecognized type. Unless new types are added to the language, this
// function won't ever be called.
func (f *formatState) buildDefaultFormat() (format string) {
buf := bytes.NewBuffer(percentBytes)
for _, flag := range supportedFlags {
if f.fs.Flag(int(flag)) {
buf.WriteRune(flag)
}
}
buf.WriteRune('v')
format = buf.String()
return format
}
// constructOrigFormat recreates the original format string including precision
// and width information to pass along to the standard fmt package. This allows
// automatic deferral of all format strings this package doesn't support.
func (f *formatState) constructOrigFormat(verb rune) (format string) {
buf := bytes.NewBuffer(percentBytes)
for _, flag := range supportedFlags {
if f.fs.Flag(int(flag)) {
buf.WriteRune(flag)
}
}
if width, ok := f.fs.Width(); ok {
buf.WriteString(strconv.Itoa(width))
}
if precision, ok := f.fs.Precision(); ok {
buf.Write(precisionBytes)
buf.WriteString(strconv.Itoa(precision))
}
buf.WriteRune(verb)
format = buf.String()
return format
}
// unpackValue returns values inside of non-nil interfaces when possible and
// ensures that types for values which have been unpacked from an interface
// are displayed when the show types flag is also set.
// This is useful for data types like structs, arrays, slices, and maps which
// can contain varying types packed inside an interface.
func (f *formatState) unpackValue(v reflect.Value) reflect.Value {
if v.Kind() == reflect.Interface {
f.ignoreNextType = false
if !v.IsNil() {
v = v.Elem()
}
}
return v
}
// formatPtr handles formatting of pointers by indirecting them as necessary.
func (f *formatState) formatPtr(v reflect.Value) {
// Display nil if top level pointer is nil.
showTypes := f.fs.Flag('#')
if v.IsNil() && (!showTypes || f.ignoreNextType) {
f.fs.Write(nilAngleBytes)
return
}
// Remove pointers at or below the current depth from map used to detect
// circular refs.
for k, depth := range f.pointers {
if depth >= f.depth {
delete(f.pointers, k)
}
}
// Keep list of all dereferenced pointers to possibly show later.
pointerChain := make([]uintptr, 0)
// Figure out how many levels of indirection there are by derferencing
// pointers and unpacking interfaces down the chain while detecting circular
// references.
nilFound := false
cycleFound := false
indirects := 0
ve := v
for ve.Kind() == reflect.Ptr {
if ve.IsNil() {
nilFound = true
break
}
indirects++
addr := ve.Pointer()
pointerChain = append(pointerChain, addr)
if pd, ok := f.pointers[addr]; ok && pd < f.depth {
cycleFound = true
indirects--
break
}
f.pointers[addr] = f.depth
ve = ve.Elem()
if ve.Kind() == reflect.Interface {
if ve.IsNil() {
nilFound = true
break
}
ve = ve.Elem()
}
}
// Display type or indirection level depending on flags.
if showTypes && !f.ignoreNextType {
f.fs.Write(openParenBytes)
f.fs.Write(bytes.Repeat(asteriskBytes, indirects))
f.fs.Write([]byte(ve.Type().String()))
f.fs.Write(closeParenBytes)
} else {
if nilFound || cycleFound {
indirects += strings.Count(ve.Type().String(), "*")
}
f.fs.Write(openAngleBytes)
f.fs.Write([]byte(strings.Repeat("*", indirects)))
f.fs.Write(closeAngleBytes)
}
// Display pointer information depending on flags.
if f.fs.Flag('+') && (len(pointerChain) > 0) {
f.fs.Write(openParenBytes)
for i, addr := range pointerChain {
if i > 0 {
f.fs.Write(pointerChainBytes)
}
printHexPtr(f.fs, addr)
}
f.fs.Write(closeParenBytes)
}
// Display dereferenced value.
switch {
case nilFound == true:
f.fs.Write(nilAngleBytes)
case cycleFound == true:
f.fs.Write(circularShortBytes)
default:
f.ignoreNextType = true
f.format(ve)
}
}
// format is the main workhorse for providing the Formatter interface. It
// uses the passed reflect value to figure out what kind of object we are
// dealing with and formats it appropriately. It is a recursive function,
// however circular data structures are detected and handled properly.
func (f *formatState) format(v reflect.Value) {
// Handle invalid reflect values immediately.
kind := v.Kind()
if kind == reflect.Invalid {
f.fs.Write(invalidAngleBytes)
return
}
// Handle pointers specially.
if kind == reflect.Ptr {
f.formatPtr(v)
return
}
// Print type information unless already handled elsewhere.
if !f.ignoreNextType && f.fs.Flag('#') {
f.fs.Write(openParenBytes)
f.fs.Write([]byte(v.Type().String()))
f.fs.Write(closeParenBytes)
}
f.ignoreNextType = false
// Call Stringer/error interfaces if they exist and the handle methods
// flag is enabled.
if !f.cs.DisableMethods {
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
if handled := handleMethods(f.cs, f.fs, v); handled {
return
}
}
}
switch kind {
case reflect.Invalid:
// Do nothing. We should never get here since invalid has already
// been handled above.
case reflect.Bool:
printBool(f.fs, v.Bool())
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
printInt(f.fs, v.Int(), 10)
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
printUint(f.fs, v.Uint(), 10)
case reflect.Float32:
printFloat(f.fs, v.Float(), 32)
case reflect.Float64:
printFloat(f.fs, v.Float(), 64)
case reflect.Complex64:
printComplex(f.fs, v.Complex(), 32)
case reflect.Complex128:
printComplex(f.fs, v.Complex(), 64)
case reflect.Slice:
if v.IsNil() {
f.fs.Write(nilAngleBytes)
break
}
fallthrough
case reflect.Array:
f.fs.Write(openBracketBytes)
f.depth++
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
f.fs.Write(maxShortBytes)
} else {
numEntries := v.Len()
for i := 0; i < numEntries; i++ {
if i > 0 {
f.fs.Write(spaceBytes)
}
f.ignoreNextType = true
f.format(f.unpackValue(v.Index(i)))
}
}
f.depth--
f.fs.Write(closeBracketBytes)
case reflect.String:
f.fs.Write([]byte(v.String()))
case reflect.Interface:
// The only time we should get here is for nil interfaces due to
// unpackValue calls.
if v.IsNil() {
f.fs.Write(nilAngleBytes)
}
case reflect.Ptr:
// Do nothing. We should never get here since pointers have already
// been handled above.
case reflect.Map:
// nil maps should be indicated as different than empty maps
if v.IsNil() {
f.fs.Write(nilAngleBytes)
break
}
f.fs.Write(openMapBytes)
f.depth++
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
f.fs.Write(maxShortBytes)
} else {
keys := v.MapKeys()
if f.cs.SortKeys {
sortValues(keys)
}
for i, key := range keys {
if i > 0 {
f.fs.Write(spaceBytes)
}
f.ignoreNextType = true
f.format(f.unpackValue(key))
f.fs.Write(colonBytes)
f.ignoreNextType = true
f.format(f.unpackValue(v.MapIndex(key)))
}
}
f.depth--
f.fs.Write(closeMapBytes)
case reflect.Struct:
numFields := v.NumField()
f.fs.Write(openBraceBytes)
f.depth++
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
f.fs.Write(maxShortBytes)
} else {
vt := v.Type()
for i := 0; i < numFields; i++ {
if i > 0 {
f.fs.Write(spaceBytes)
}
vtf := vt.Field(i)
if f.fs.Flag('+') || f.fs.Flag('#') {
f.fs.Write([]byte(vtf.Name))
f.fs.Write(colonBytes)
}
f.format(f.unpackValue(v.Field(i)))
}
}
f.depth--
f.fs.Write(closeBraceBytes)
case reflect.Uintptr:
printHexPtr(f.fs, uintptr(v.Uint()))
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
printHexPtr(f.fs, v.Pointer())
// There were not any other types at the time this code was written, but
// fall back to letting the default fmt package handle it if any get added.
default:
format := f.buildDefaultFormat()
if v.CanInterface() {
fmt.Fprintf(f.fs, format, v.Interface())
} else {
fmt.Fprintf(f.fs, format, v.String())
}
}
}
// Format satisfies the fmt.Formatter interface. See NewFormatter for usage
// details.
func (f *formatState) Format(fs fmt.State, verb rune) {
f.fs = fs
// Use standard formatting for verbs that are not v.
if verb != 'v' {
format := f.constructOrigFormat(verb)
fmt.Fprintf(fs, format, f.value)
return
}
if f.value == nil {
if fs.Flag('#') {
fs.Write(interfaceBytes)
}
fs.Write(nilAngleBytes)
return
}
f.format(reflect.ValueOf(f.value))
}
// newFormatter is a helper function to consolidate the logic from the various
// public methods which take varying config states.
func newFormatter(cs *ConfigState, v interface{}) fmt.Formatter {
fs := &formatState{value: v, cs: cs}
fs.pointers = make(map[uintptr]int)
return fs
}
/*
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
interface. As a result, it integrates cleanly with standard fmt package
printing functions. The formatter is useful for inline printing of smaller data
types similar to the standard %v format specifier.
The custom formatter only responds to the %v (most compact), %+v (adds pointer
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
combinations. Any other verbs such as %x and %q will be sent to the the
standard fmt package for formatting. In addition, the custom formatter ignores
the width and precision arguments (however they will still work on the format
specifiers not handled by the custom formatter).
Typically this function shouldn't be called directly. It is much easier to make
use of the custom formatter by calling one of the convenience functions such as
Printf, Println, or Fprintf.
*/
func NewFormatter(v interface{}) fmt.Formatter {
return newFormatter(&Config, v)
}

File diff suppressed because it is too large Load diff

View file

@ -1,156 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
/*
This test file is part of the spew package rather than than the spew_test
package because it needs access to internals to properly test certain cases
which are not possible via the public interface since they should never happen.
*/
package spew
import (
"bytes"
"reflect"
"testing"
"unsafe"
)
// dummyFmtState implements a fake fmt.State to use for testing invalid
// reflect.Value handling. This is necessary because the fmt package catches
// invalid values before invoking the formatter on them.
type dummyFmtState struct {
bytes.Buffer
}
func (dfs *dummyFmtState) Flag(f int) bool {
if f == int('+') {
return true
}
return false
}
func (dfs *dummyFmtState) Precision() (int, bool) {
return 0, false
}
func (dfs *dummyFmtState) Width() (int, bool) {
return 0, false
}
// TestInvalidReflectValue ensures the dump and formatter code handles an
// invalid reflect value properly. This needs access to internal state since it
// should never happen in real code and therefore can't be tested via the public
// API.
func TestInvalidReflectValue(t *testing.T) {
i := 1
// Dump invalid reflect value.
v := new(reflect.Value)
buf := new(bytes.Buffer)
d := dumpState{w: buf, cs: &Config}
d.dump(*v)
s := buf.String()
want := "<invalid>"
if s != want {
t.Errorf("InvalidReflectValue #%d\n got: %s want: %s", i, s, want)
}
i++
// Formatter invalid reflect value.
buf2 := new(dummyFmtState)
f := formatState{value: *v, cs: &Config, fs: buf2}
f.format(*v)
s = buf2.String()
want = "<invalid>"
if s != want {
t.Errorf("InvalidReflectValue #%d got: %s want: %s", i, s, want)
}
}
// changeKind uses unsafe to intentionally change the kind of a reflect.Value to
// the maximum kind value which does not exist. This is needed to test the
// fallback code which punts to the standard fmt library for new types that
// might get added to the language.
func changeKind(v *reflect.Value, readOnly bool) {
rvf := (*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(v)) + offsetFlag))
*rvf = *rvf | ((1<<flagKindWidth - 1) << flagKindShift)
if readOnly {
*rvf |= flagRO
} else {
*rvf &= ^uintptr(flagRO)
}
}
// TestAddedReflectValue tests functionaly of the dump and formatter code which
// falls back to the standard fmt library for new types that might get added to
// the language.
func TestAddedReflectValue(t *testing.T) {
i := 1
// Dump using a reflect.Value that is exported.
v := reflect.ValueOf(int8(5))
changeKind(&v, false)
buf := new(bytes.Buffer)
d := dumpState{w: buf, cs: &Config}
d.dump(v)
s := buf.String()
want := "(int8) 5"
if s != want {
t.Errorf("TestAddedReflectValue #%d\n got: %s want: %s", i, s, want)
}
i++
// Dump using a reflect.Value that is not exported.
changeKind(&v, true)
buf.Reset()
d.dump(v)
s = buf.String()
want = "(int8) <int8 Value>"
if s != want {
t.Errorf("TestAddedReflectValue #%d\n got: %s want: %s", i, s, want)
}
i++
// Formatter using a reflect.Value that is exported.
changeKind(&v, false)
buf2 := new(dummyFmtState)
f := formatState{value: v, cs: &Config, fs: buf2}
f.format(v)
s = buf2.String()
want = "5"
if s != want {
t.Errorf("TestAddedReflectValue #%d got: %s want: %s", i, s, want)
}
i++
// Formatter using a reflect.Value that is not exported.
changeKind(&v, true)
buf2.Reset()
f = formatState{value: v, cs: &Config, fs: buf2}
f.format(v)
s = buf2.String()
want = "<int8 Value>"
if s != want {
t.Errorf("TestAddedReflectValue #%d got: %s want: %s", i, s, want)
}
}
// SortValues makes the internal sortValues function available to the test
// package.
func SortValues(values []reflect.Value) {
sortValues(values)
}

View file

@ -1,148 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew
import (
"fmt"
"io"
)
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the formatted string as a value that satisfies error. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Errorf(format, spew.NewFormatter(a), spew.NewFormatter(b))
func Errorf(format string, a ...interface{}) (err error) {
return fmt.Errorf(format, convertArgs(a)...)
}
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprint(w, spew.NewFormatter(a), spew.NewFormatter(b))
func Fprint(w io.Writer, a ...interface{}) (n int, err error) {
return fmt.Fprint(w, convertArgs(a)...)
}
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprintf(w, format, spew.NewFormatter(a), spew.NewFormatter(b))
func Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
return fmt.Fprintf(w, format, convertArgs(a)...)
}
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
// passed with a default Formatter interface returned by NewFormatter. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Fprintln(w, spew.NewFormatter(a), spew.NewFormatter(b))
func Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
return fmt.Fprintln(w, convertArgs(a)...)
}
// Print is a wrapper for fmt.Print that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Print(spew.NewFormatter(a), spew.NewFormatter(b))
func Print(a ...interface{}) (n int, err error) {
return fmt.Print(convertArgs(a)...)
}
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Printf(format, spew.NewFormatter(a), spew.NewFormatter(b))
func Printf(format string, a ...interface{}) (n int, err error) {
return fmt.Printf(format, convertArgs(a)...)
}
// Println is a wrapper for fmt.Println that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the number of bytes written and any write error encountered. See
// NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Println(spew.NewFormatter(a), spew.NewFormatter(b))
func Println(a ...interface{}) (n int, err error) {
return fmt.Println(convertArgs(a)...)
}
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprint(spew.NewFormatter(a), spew.NewFormatter(b))
func Sprint(a ...interface{}) string {
return fmt.Sprint(convertArgs(a)...)
}
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
// passed with a default Formatter interface returned by NewFormatter. It
// returns the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprintf(format, spew.NewFormatter(a), spew.NewFormatter(b))
func Sprintf(format string, a ...interface{}) string {
return fmt.Sprintf(format, convertArgs(a)...)
}
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
// were passed with a default Formatter interface returned by NewFormatter. It
// returns the resulting string. See NewFormatter for formatting details.
//
// This function is shorthand for the following syntax:
//
// fmt.Sprintln(spew.NewFormatter(a), spew.NewFormatter(b))
func Sprintln(a ...interface{}) string {
return fmt.Sprintln(convertArgs(a)...)
}
// convertArgs accepts a slice of arguments and returns a slice of the same
// length with each argument converted to a default spew Formatter interface.
func convertArgs(args []interface{}) (formatters []interface{}) {
formatters = make([]interface{}, len(args))
for index, arg := range args {
formatters[index] = NewFormatter(arg)
}
return formatters
}

View file

@ -1,308 +0,0 @@
/*
* Copyright (c) 2013 Dave Collins <dave@davec.name>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
package spew_test
import (
"bytes"
"fmt"
"github.com/davecgh/go-spew/spew"
"io/ioutil"
"os"
"testing"
)
// spewFunc is used to identify which public function of the spew package or
// ConfigState a test applies to.
type spewFunc int
const (
fCSFdump spewFunc = iota
fCSFprint
fCSFprintf
fCSFprintln
fCSPrint
fCSPrintln
fCSSdump
fCSSprint
fCSSprintf
fCSSprintln
fCSErrorf
fCSNewFormatter
fErrorf
fFprint
fFprintln
fPrint
fPrintln
fSdump
fSprint
fSprintf
fSprintln
)
// Map of spewFunc values to names for pretty printing.
var spewFuncStrings = map[spewFunc]string{
fCSFdump: "ConfigState.Fdump",
fCSFprint: "ConfigState.Fprint",
fCSFprintf: "ConfigState.Fprintf",
fCSFprintln: "ConfigState.Fprintln",
fCSSdump: "ConfigState.Sdump",
fCSPrint: "ConfigState.Print",
fCSPrintln: "ConfigState.Println",
fCSSprint: "ConfigState.Sprint",
fCSSprintf: "ConfigState.Sprintf",
fCSSprintln: "ConfigState.Sprintln",
fCSErrorf: "ConfigState.Errorf",
fCSNewFormatter: "ConfigState.NewFormatter",
fErrorf: "spew.Errorf",
fFprint: "spew.Fprint",
fFprintln: "spew.Fprintln",
fPrint: "spew.Print",
fPrintln: "spew.Println",
fSdump: "spew.Sdump",
fSprint: "spew.Sprint",
fSprintf: "spew.Sprintf",
fSprintln: "spew.Sprintln",
}
func (f spewFunc) String() string {
if s, ok := spewFuncStrings[f]; ok {
return s
}
return fmt.Sprintf("Unknown spewFunc (%d)", int(f))
}
// spewTest is used to describe a test to be performed against the public
// functions of the spew package or ConfigState.
type spewTest struct {
cs *spew.ConfigState
f spewFunc
format string
in interface{}
want string
}
// spewTests houses the tests to be performed against the public functions of
// the spew package and ConfigState.
//
// These tests are only intended to ensure the public functions are exercised
// and are intentionally not exhaustive of types. The exhaustive type
// tests are handled in the dump and format tests.
var spewTests []spewTest
// redirStdout is a helper function to return the standard output from f as a
// byte slice.
func redirStdout(f func()) ([]byte, error) {
tempFile, err := ioutil.TempFile("", "ss-test")
if err != nil {
return nil, err
}
fileName := tempFile.Name()
defer os.Remove(fileName) // Ignore error
origStdout := os.Stdout
os.Stdout = tempFile
f()
os.Stdout = origStdout
tempFile.Close()
return ioutil.ReadFile(fileName)
}
func initSpewTests() {
// Config states with various settings.
scsDefault := spew.NewDefaultConfig()
scsNoMethods := &spew.ConfigState{Indent: " ", DisableMethods: true}
scsNoPmethods := &spew.ConfigState{Indent: " ", DisablePointerMethods: true}
scsMaxDepth := &spew.ConfigState{Indent: " ", MaxDepth: 1}
scsContinue := &spew.ConfigState{Indent: " ", ContinueOnMethod: true}
// Variables for tests on types which implement Stringer interface with and
// without a pointer receiver.
ts := stringer("test")
tps := pstringer("test")
// depthTester is used to test max depth handling for structs, array, slices
// and maps.
type depthTester struct {
ic indirCir1
arr [1]string
slice []string
m map[string]int
}
dt := depthTester{indirCir1{nil}, [1]string{"arr"}, []string{"slice"},
map[string]int{"one": 1}}
// Variable for tests on types which implement error interface.
te := customError(10)
spewTests = []spewTest{
{scsDefault, fCSFdump, "", int8(127), "(int8) 127\n"},
{scsDefault, fCSFprint, "", int16(32767), "32767"},
{scsDefault, fCSFprintf, "%v", int32(2147483647), "2147483647"},
{scsDefault, fCSFprintln, "", int(2147483647), "2147483647\n"},
{scsDefault, fCSPrint, "", int64(9223372036854775807), "9223372036854775807"},
{scsDefault, fCSPrintln, "", uint8(255), "255\n"},
{scsDefault, fCSSdump, "", uint8(64), "(uint8) 64\n"},
{scsDefault, fCSSprint, "", complex(1, 2), "(1+2i)"},
{scsDefault, fCSSprintf, "%v", complex(float32(3), 4), "(3+4i)"},
{scsDefault, fCSSprintln, "", complex(float64(5), 6), "(5+6i)\n"},
{scsDefault, fCSErrorf, "%#v", uint16(65535), "(uint16)65535"},
{scsDefault, fCSNewFormatter, "%v", uint32(4294967295), "4294967295"},
{scsDefault, fErrorf, "%v", uint64(18446744073709551615), "18446744073709551615"},
{scsDefault, fFprint, "", float32(3.14), "3.14"},
{scsDefault, fFprintln, "", float64(6.28), "6.28\n"},
{scsDefault, fPrint, "", true, "true"},
{scsDefault, fPrintln, "", false, "false\n"},
{scsDefault, fSdump, "", complex(-10, -20), "(complex128) (-10-20i)\n"},
{scsDefault, fSprint, "", complex(-1, -2), "(-1-2i)"},
{scsDefault, fSprintf, "%v", complex(float32(-3), -4), "(-3-4i)"},
{scsDefault, fSprintln, "", complex(float64(-5), -6), "(-5-6i)\n"},
{scsNoMethods, fCSFprint, "", ts, "test"},
{scsNoMethods, fCSFprint, "", &ts, "<*>test"},
{scsNoMethods, fCSFprint, "", tps, "test"},
{scsNoMethods, fCSFprint, "", &tps, "<*>test"},
{scsNoPmethods, fCSFprint, "", ts, "stringer test"},
{scsNoPmethods, fCSFprint, "", &ts, "<*>stringer test"},
{scsNoPmethods, fCSFprint, "", tps, "test"},
{scsNoPmethods, fCSFprint, "", &tps, "<*>stringer test"},
{scsMaxDepth, fCSFprint, "", dt, "{{<max>} [<max>] [<max>] map[<max>]}"},
{scsMaxDepth, fCSFdump, "", dt, "(spew_test.depthTester) {\n" +
" ic: (spew_test.indirCir1) {\n <max depth reached>\n },\n" +
" arr: ([1]string) (len=1 cap=1) {\n <max depth reached>\n },\n" +
" slice: ([]string) (len=1 cap=1) {\n <max depth reached>\n },\n" +
" m: (map[string]int) (len=1) {\n <max depth reached>\n }\n}\n"},
{scsContinue, fCSFprint, "", ts, "(stringer test) test"},
{scsContinue, fCSFdump, "", ts, "(spew_test.stringer) " +
"(len=4) (stringer test) \"test\"\n"},
{scsContinue, fCSFprint, "", te, "(error: 10) 10"},
{scsContinue, fCSFdump, "", te, "(spew_test.customError) " +
"(error: 10) 10\n"},
}
}
// TestSpew executes all of the tests described by spewTests.
func TestSpew(t *testing.T) {
initSpewTests()
t.Logf("Running %d tests", len(spewTests))
for i, test := range spewTests {
buf := new(bytes.Buffer)
switch test.f {
case fCSFdump:
test.cs.Fdump(buf, test.in)
case fCSFprint:
test.cs.Fprint(buf, test.in)
case fCSFprintf:
test.cs.Fprintf(buf, test.format, test.in)
case fCSFprintln:
test.cs.Fprintln(buf, test.in)
case fCSPrint:
b, err := redirStdout(func() { test.cs.Print(test.in) })
if err != nil {
t.Errorf("%v #%d %v", test.f, i, err)
continue
}
buf.Write(b)
case fCSPrintln:
b, err := redirStdout(func() { test.cs.Println(test.in) })
if err != nil {
t.Errorf("%v #%d %v", test.f, i, err)
continue
}
buf.Write(b)
case fCSSdump:
str := test.cs.Sdump(test.in)
buf.WriteString(str)
case fCSSprint:
str := test.cs.Sprint(test.in)
buf.WriteString(str)
case fCSSprintf:
str := test.cs.Sprintf(test.format, test.in)
buf.WriteString(str)
case fCSSprintln:
str := test.cs.Sprintln(test.in)
buf.WriteString(str)
case fCSErrorf:
err := test.cs.Errorf(test.format, test.in)
buf.WriteString(err.Error())
case fCSNewFormatter:
fmt.Fprintf(buf, test.format, test.cs.NewFormatter(test.in))
case fErrorf:
err := spew.Errorf(test.format, test.in)
buf.WriteString(err.Error())
case fFprint:
spew.Fprint(buf, test.in)
case fFprintln:
spew.Fprintln(buf, test.in)
case fPrint:
b, err := redirStdout(func() { spew.Print(test.in) })
if err != nil {
t.Errorf("%v #%d %v", test.f, i, err)
continue
}
buf.Write(b)
case fPrintln:
b, err := redirStdout(func() { spew.Println(test.in) })
if err != nil {
t.Errorf("%v #%d %v", test.f, i, err)
continue
}
buf.Write(b)
case fSdump:
str := spew.Sdump(test.in)
buf.WriteString(str)
case fSprint:
str := spew.Sprint(test.in)
buf.WriteString(str)
case fSprintf:
str := spew.Sprintf(test.format, test.in)
buf.WriteString(str)
case fSprintln:
str := spew.Sprintln(test.in)
buf.WriteString(str)
default:
t.Errorf("%v #%d unrecognized function", test.f, i)
continue
}
s := buf.String()
if test.want != s {
t.Errorf("ConfigState #%d\n got: %s want: %s", i, s, test.want)
continue
}
}
}

View file

@ -1,82 +0,0 @@
// Copyright (c) 2013 Dave Collins <dave@davec.name>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
// NOTE: Due to the following build constraints, this file will only be compiled
// when both cgo is supported and "-tags testcgo" is added to the go test
// command line. This code should really only be in the dumpcgo_test.go file,
// but unfortunately Go will not allow cgo in test files, so this is a
// workaround to allow cgo types to be tested. This configuration is used
// because spew itself does not require cgo to run even though it does handle
// certain cgo types specially. Rather than forcing all clients to require cgo
// and an external C compiler just to run the tests, this scheme makes them
// optional.
// +build cgo,testcgo
package testdata
/*
#include <stdint.h>
typedef unsigned char custom_uchar_t;
char *ncp = 0;
char *cp = "test";
char ca[6] = {'t', 'e', 's', 't', '2', '\0'};
unsigned char uca[6] = {'t', 'e', 's', 't', '3', '\0'};
signed char sca[6] = {'t', 'e', 's', 't', '4', '\0'};
uint8_t ui8ta[6] = {'t', 'e', 's', 't', '5', '\0'};
custom_uchar_t tuca[6] = {'t', 'e', 's', 't', '6', '\0'};
*/
import "C"
// GetCgoNullCharPointer returns a null char pointer via cgo. This is only
// used for tests.
func GetCgoNullCharPointer() interface{} {
return C.ncp
}
// GetCgoCharPointer returns a char pointer via cgo. This is only used for
// tests.
func GetCgoCharPointer() interface{} {
return C.cp
}
// GetCgoCharArray returns a char array via cgo and the array's len and cap.
// This is only used for tests.
func GetCgoCharArray() (interface{}, int, int) {
return C.ca, len(C.ca), cap(C.ca)
}
// GetCgoUnsignedCharArray returns an unsigned char array via cgo and the
// array's len and cap. This is only used for tests.
func GetCgoUnsignedCharArray() (interface{}, int, int) {
return C.uca, len(C.uca), cap(C.uca)
}
// GetCgoSignedCharArray returns a signed char array via cgo and the array's len
// and cap. This is only used for tests.
func GetCgoSignedCharArray() (interface{}, int, int) {
return C.sca, len(C.sca), cap(C.sca)
}
// GetCgoUint8tArray returns a uint8_t array via cgo and the array's len and
// cap. This is only used for tests.
func GetCgoUint8tArray() (interface{}, int, int) {
return C.ui8ta, len(C.ui8ta), cap(C.ui8ta)
}
// GetCgoTypdefedUnsignedCharArray returns a typedefed unsigned char array via
// cgo and the array's len and cap. This is only used for tests.
func GetCgoTypdefedUnsignedCharArray() (interface{}, int, int) {
return C.tuca, len(C.tuca), cap(C.tuca)
}

View file

@ -1,43 +0,0 @@
#!/bin/bash
# based on http://stackoverflow.com/questions/21126011/is-it-possible-to-post-coverage-for-multiple-packages-to-coveralls
# with script found at https://github.com/gopns/gopns/blob/master/test-coverage.sh
echo "mode: set" > acc.out
returnval=`go test -v -coverprofile=profile.out`
echo ${returnval}
if [[ ${returnval} != *FAIL* ]]
then
if [ -f profile.out ]
then
cat profile.out | grep -v "mode: set" >> acc.out
fi
else
exit 1
fi
for Dir in $(find ./* -maxdepth 10 -type d );
do
if ls $Dir/*.go &> /dev/null;
then
echo $Dir
returnval=`go test -v -coverprofile=profile.out $Dir`
echo ${returnval}
if [[ ${returnval} != *FAIL* ]]
then
if [ -f profile.out ]
then
cat profile.out | grep -v "mode: set" >> acc.out
fi
else
exit 1
fi
fi
done
if [ -n "$COVERALLS_TOKEN" ]
then
$HOME/gopath/bin/goveralls -coverprofile=acc.out -service=travis-ci -repotoken $COVERALLS_TOKEN
fi
rm -rf ./profile.out
rm -rf ./acc.out

View file

@ -1,68 +0,0 @@
language: go
env:
global:
- secure: "gewG9b13l2/JJkag584f/e7vbH+CN5sE/v5IxJLI24vVBsta0L/rUiRN5e/NRXiyNDT4X2XV6R6BLED8VaUo3vDSWHBFtRAuwbMswxRcjDuIGph53zTNukhEwbFThEhZO5vO9T1tECXK1D8ktgQjmqwQ171InUy2loLFWloUTF4=" # at some point, when testing on osx
matrix:
- BLAS_LIB=OpenBLAS && GOARCH=amd64
#- BLAS_LIB=native && GOCROSS=386 && GO386=387 # GOCROSS will be renamed GOARCH to avoid gvm (?) from fiddling with it
# at some point, when travis allows builds on darwin
#- BLAS_LIB=Accellerate
# at some point, when the issue with drotgm is resolved
#- BLAS_LIB=ATLAS
go:
- 1.3.3
- 1.4.1
- tip
before_install:
- sudo apt-get update -qq
- if ! go get code.google.com/p/go.tools/cmd/cover; then go get golang.org/x/tools/cmd/cover; fi
- go get github.com/mattn/goveralls
install:
- if [[ "$BLAS_LIB" == "ATLAS" ]]; then sudo apt-get install -qq libatlas-base-dev; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then sudo apt-get install -qq gfortran; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then pushd ~; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then sudo git clone --depth=1 git://github.com/xianyi/OpenBLAS; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then pushd OpenBLAS; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then sudo make FC=gfortran &> /dev/null; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then sudo make PREFIX=/usr install; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then popd; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then curl http://www.netlib.org/blas/blast-forum/cblas.tgz | tar -zx; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then pushd CBLAS; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then sudo mv Makefile.LINUX Makefile.in; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then sudo BLLIB=/usr/lib/libopenblas.a make alllib; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then sudo mv lib/cblas_LINUX.a /usr/lib/libcblas.a; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then popd; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then popd; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then export CGO_LDFLAGS="-L/usr/lib -lopenblas"; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then go get github.com/gonum/blas; fi # get rid of this when the fork is merged
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then pushd cgo; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then go install -v -x; fi
- if [[ "$BLAS_LIB" == "OpenBLAS" ]]; then popd; fi
- if [[ "$GOCROSS" == "386" && "$TRAVIS_GO_VERSION" != "tip" ]]; then export GOARCH=386; fi
- if [[ "$GOARCH" == "386" ]]; then gvm cross linux 386; fi
script:
- go version
- go env
- env
- if [[ "$BLAS_LIB" == "native" ]]; then pushd native; fi
- go get -d -t -v ./...
- go test -x -v ./...
- diff <(gofmt -d .) <("")
- if [[ $TRAVIS_SECURE_ENV_VARS = "true" ]]; then bash -c "${TRAVIS_BUILD_DIR}/.test-coverage.sh"; fi
after_failure: failure
notifications:
email:
recipients:
- jonathan.lawlor@gmail.com
on_success: change
on_failure: always

View file

@ -1,95 +0,0 @@
# Gonum BLAS [![Build Status](https://travis-ci.org/gonum/blas.svg?branch=master)](https://travis-ci.org/gonum/blas) [![Coverage Status](https://img.shields.io/coveralls/gonum/blas.svg)](https://coveralls.io/r/gonum/blas)
A collection of packages to provide BLAS functionality for the [Go programming
language](http://golang.org)
## Installation
```sh
go get github.com/gonum/blas
```
### BLAS C-bindings
If you want to use OpenBLAS, install it in any directory:
```sh
git clone https://github.com/xianyi/OpenBLAS
cd OpenBLAS
make
```
The blas/cgo package provides bindings to C-backed BLAS packages. blas/cgo needs the `CGO_LDFLAGS`
environment variable to point to the blas installation. More information can be found in the
[cgo command documentation](http://golang.org/cmd/cgo/).
Then install the blas/cgo package:
```sh
CGO_LDFLAGS="-L/path/to/OpenBLAS -lopenblas" go install github.com/gonum/blas/cgo
```
For Windows you can download binary packages for OpenBLAS at
[SourceForge](http://sourceforge.net/projects/openblas/files/).
If you want to use a different BLAS package such as the Intel MKL you can
adjust the `CGO_LDFLAGS` variable:
```sh
CGO_LDFLAGS="-lmkl_rt" go install github.com/gonum/blas/cgo
```
On OS X the easiest solution is to use the libraries provided by the system:
```sh
CGO_LDFLAGS="-framework Accelerate" go install github.com/gonum/blas/cgo
```
## Packages
### blas
Defines [BLAS API](http://www.netlib.org/blas/blast-forum/cinterface.pdf) split in several interfaces
### blas/native
Go implementation of the BLAS API (incomplete, implements the float64 API)
### blas/cgo
Binding to a C implementation of the cblas interface (e.g. ATLAS, OpenBLAS, Intel MKL)
The recommended (free) option for good performance on both Linux and Darwin is OpenBLAS.
### blas/blas64
Wrapper for an implementation of the double precision real (i.e., `float64`) part
of the blas API
```Go
package main
import (
"fmt"
"github.com/gonum/blas/blas64"
)
func main() {
v := blas64.Vector{Inc: 1, Data: []float64{1, 1, 1}}
fmt.Println("v has length:", blas64.Nrm2(len(v.Data), v))
}
```
### blas/cblas128
Wrapper for an implementation of the double precision complex (i.e., `complex128`)
part of the blas API
Currently blas/cblas128 requires blas/cgo.
## Issues
If you find any bugs, feel free to file an issue on the github issue tracker.
Discussions on API changes, added features, code review, or similar requests
are preferred on the [gonum-dev Google Group](https://groups.google.com/forum/#!forum/gonum-dev).
## License
Please see [github.com/gonum/license](https://github.com/gonum/license) for general
license information, contributors, authors, etc on the Gonum suite of packages.

View file

@ -1,388 +0,0 @@
// Copyright ©2013 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package blas provides interfaces for the BLAS linear algebra standard.
All methods must perform appropriate parameter checking and panic if
provided parameters that do not conform to the requirements specified
by the BLAS standard.
Quick Reference Guide to the BLAS from http://www.netlib.org/lapack/lug/node145.html
This version is modified to remove the "order" option. All matrix operations are
on row-order matrices.
Level 1 BLAS
dim scalar vector vector scalars 5-element prefixes
struct
_rotg ( a, b ) S, D
_rotmg( d1, d2, a, b ) S, D
_rot ( n, x, incX, y, incY, c, s ) S, D
_rotm ( n, x, incX, y, incY, param ) S, D
_swap ( n, x, incX, y, incY ) S, D, C, Z
_scal ( n, alpha, x, incX ) S, D, C, Z, Cs, Zd
_copy ( n, x, incX, y, incY ) S, D, C, Z
_axpy ( n, alpha, x, incX, y, incY ) S, D, C, Z
_dot ( n, x, incX, y, incY ) S, D, Ds
_dotu ( n, x, incX, y, incY ) C, Z
_dotc ( n, x, incX, y, incY ) C, Z
__dot ( n, alpha, x, incX, y, incY ) Sds
_nrm2 ( n, x, incX ) S, D, Sc, Dz
_asum ( n, x, incX ) S, D, Sc, Dz
I_amax( n, x, incX ) s, d, c, z
Level 2 BLAS
options dim b-width scalar matrix vector scalar vector prefixes
_gemv ( trans, m, n, alpha, a, lda, x, incX, beta, y, incY ) S, D, C, Z
_gbmv ( trans, m, n, kL, kU, alpha, a, lda, x, incX, beta, y, incY ) S, D, C, Z
_hemv ( uplo, n, alpha, a, lda, x, incX, beta, y, incY ) C, Z
_hbmv ( uplo, n, k, alpha, a, lda, x, incX, beta, y, incY ) C, Z
_hpmv ( uplo, n, alpha, ap, x, incX, beta, y, incY ) C, Z
_symv ( uplo, n, alpha, a, lda, x, incX, beta, y, incY ) S, D
_sbmv ( uplo, n, k, alpha, a, lda, x, incX, beta, y, incY ) S, D
_spmv ( uplo, n, alpha, ap, x, incX, beta, y, incY ) S, D
_trmv ( uplo, trans, diag, n, a, lda, x, incX ) S, D, C, Z
_tbmv ( uplo, trans, diag, n, k, a, lda, x, incX ) S, D, C, Z
_tpmv ( uplo, trans, diag, n, ap, x, incX ) S, D, C, Z
_trsv ( uplo, trans, diag, n, a, lda, x, incX ) S, D, C, Z
_tbsv ( uplo, trans, diag, n, k, a, lda, x, incX ) S, D, C, Z
_tpsv ( uplo, trans, diag, n, ap, x, incX ) S, D, C, Z
options dim scalar vector vector matrix prefixes
_ger ( m, n, alpha, x, incX, y, incY, a, lda ) S, D
_geru ( m, n, alpha, x, incX, y, incY, a, lda ) C, Z
_gerc ( m, n, alpha, x, incX, y, incY, a, lda ) C, Z
_her ( uplo, n, alpha, x, incX, a, lda ) C, Z
_hpr ( uplo, n, alpha, x, incX, ap ) C, Z
_her2 ( uplo, n, alpha, x, incX, y, incY, a, lda ) C, Z
_hpr2 ( uplo, n, alpha, x, incX, y, incY, ap ) C, Z
_syr ( uplo, n, alpha, x, incX, a, lda ) S, D
_spr ( uplo, n, alpha, x, incX, ap ) S, D
_syr2 ( uplo, n, alpha, x, incX, y, incY, a, lda ) S, D
_spr2 ( uplo, n, alpha, x, incX, y, incY, ap ) S, D
Level 3 BLAS
options dim scalar matrix matrix scalar matrix prefixes
_gemm ( transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ) S, D, C, Z
_symm ( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc ) S, D, C, Z
_hemm ( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc ) C, Z
_syrk ( uplo, trans, n, k, alpha, a, lda, beta, c, ldc ) S, D, C, Z
_herk ( uplo, trans, n, k, alpha, a, lda, beta, c, ldc ) C, Z
_syr2k( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc ) S, D, C, Z
_her2k( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc ) C, Z
_trmm ( side, uplo, transA, diag, m, n, alpha, a, lda, b, ldb ) S, D, C, Z
_trsm ( side, uplo, transA, diag, m, n, alpha, a, lda, b, ldb ) S, D, C, Z
Meaning of prefixes
S - float32 C - complex64
D - float64 Z - complex128
Matrix types
GE - GEneral GB - General Band
SY - SYmmetric SB - Symmetric Band SP - Symmetric Packed
HE - HErmitian HB - Hermitian Band HP - Hermitian Packed
TR - TRiangular TB - Triangular Band TP - Triangular Packed
Options
trans = NoTrans, Trans, ConjTrans
uplo = Upper, Lower
diag = Nonunit, Unit
side = Left, Right (A or op(A) on the left, or A or op(A) on the right)
For real matrices, Trans and ConjTrans have the same meaning.
For Hermitian matrices, trans = Trans is not allowed.
For complex symmetric matrices, trans = ConjTrans is not allowed.
*/
package blas
// Flag constants indicate Givens transformation H matrix state.
type Flag int
const (
Identity Flag = iota - 2 // H is the identity matrix; no rotation is needed.
Rescaling // H specifies rescaling.
OffDiagonal // Off-diagonal elements of H are units.
Diagonal // Diagonal elements of H are units.
)
// SrotmParams contains Givens transformation parameters returned
// by the Float32 Srotm method.
type SrotmParams struct {
Flag
H [4]float32 // Column-major 2 by 2 matrix.
}
// DrotmParams contains Givens transformation parameters returned
// by the Float64 Drotm method.
type DrotmParams struct {
Flag
H [4]float64 // Column-major 2 by 2 matrix.
}
// Transpose is used to specify the transposition operation for a
// routine.
type Transpose int
const (
NoTrans Transpose = 111 + iota
Trans
ConjTrans
)
// Uplo is used to specify whether the matrix is an upper or lower
// triangular matrix.
type Uplo int
const (
All Uplo = 120 + iota
Upper
Lower
)
// Diag is used to specify whether the matrix is a unit or non-unit
// triangular matrix.
type Diag int
const (
NonUnit Diag = 131 + iota
Unit
)
// Side is used to specify from which side a multiplication operation
// is performed.
type Side int
const (
Left Side = 141 + iota
Right
)
// Float32 implements the single precision real BLAS routines.
type Float32 interface {
Float32Level1
Float32Level2
Float32Level3
}
// Float32Level1 implements the single precision real BLAS Level 1 routines.
type Float32Level1 interface {
Sdsdot(n int, alpha float32, x []float32, incX int, y []float32, incY int) float32
Dsdot(n int, x []float32, incX int, y []float32, incY int) float64
Sdot(n int, x []float32, incX int, y []float32, incY int) float32
Snrm2(n int, x []float32, incX int) float32
Sasum(n int, x []float32, incX int) float32
Isamax(n int, x []float32, incX int) int
Sswap(n int, x []float32, incX int, y []float32, incY int)
Scopy(n int, x []float32, incX int, y []float32, incY int)
Saxpy(n int, alpha float32, x []float32, incX int, y []float32, incY int)
Srotg(a, b float32) (c, s, r, z float32)
Srotmg(d1, d2, b1, b2 float32) (p SrotmParams, rd1, rd2, rb1 float32)
Srot(n int, x []float32, incX int, y []float32, incY int, c, s float32)
Srotm(n int, x []float32, incX int, y []float32, incY int, p SrotmParams)
Sscal(n int, alpha float32, x []float32, incX int)
}
// Float32Level2 implements the single precision real BLAS Level 2 routines.
type Float32Level2 interface {
Sgemv(tA Transpose, m, n int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int)
Sgbmv(tA Transpose, m, n, kL, kU int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int)
Strmv(ul Uplo, tA Transpose, d Diag, n int, a []float32, lda int, x []float32, incX int)
Stbmv(ul Uplo, tA Transpose, d Diag, n, k int, a []float32, lda int, x []float32, incX int)
Stpmv(ul Uplo, tA Transpose, d Diag, n int, ap []float32, x []float32, incX int)
Strsv(ul Uplo, tA Transpose, d Diag, n int, a []float32, lda int, x []float32, incX int)
Stbsv(ul Uplo, tA Transpose, d Diag, n, k int, a []float32, lda int, x []float32, incX int)
Stpsv(ul Uplo, tA Transpose, d Diag, n int, ap []float32, x []float32, incX int)
Ssymv(ul Uplo, n int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int)
Ssbmv(ul Uplo, n, k int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int)
Sspmv(ul Uplo, n int, alpha float32, ap []float32, x []float32, incX int, beta float32, y []float32, incY int)
Sger(m, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32, lda int)
Ssyr(ul Uplo, n int, alpha float32, x []float32, incX int, a []float32, lda int)
Sspr(ul Uplo, n int, alpha float32, x []float32, incX int, ap []float32)
Ssyr2(ul Uplo, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32, lda int)
Sspr2(ul Uplo, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32)
}
// Float32Level3 implements the single precision real BLAS Level 3 routines.
type Float32Level3 interface {
Sgemm(tA, tB Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int)
Ssymm(s Side, ul Uplo, m, n int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int)
Ssyrk(ul Uplo, t Transpose, n, k int, alpha float32, a []float32, lda int, beta float32, c []float32, ldc int)
Ssyr2k(ul Uplo, t Transpose, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int)
Strmm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha float32, a []float32, lda int, b []float32, ldb int)
Strsm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha float32, a []float32, lda int, b []float32, ldb int)
}
// Float64 implements the single precision real BLAS routines.
type Float64 interface {
Float64Level1
Float64Level2
Float64Level3
}
// Float64Level1 implements the double precision real BLAS Level 1 routines.
type Float64Level1 interface {
Ddot(n int, x []float64, incX int, y []float64, incY int) float64
Dnrm2(n int, x []float64, incX int) float64
Dasum(n int, x []float64, incX int) float64
Idamax(n int, x []float64, incX int) int
Dswap(n int, x []float64, incX int, y []float64, incY int)
Dcopy(n int, x []float64, incX int, y []float64, incY int)
Daxpy(n int, alpha float64, x []float64, incX int, y []float64, incY int)
Drotg(a, b float64) (c, s, r, z float64)
Drotmg(d1, d2, b1, b2 float64) (p DrotmParams, rd1, rd2, rb1 float64)
Drot(n int, x []float64, incX int, y []float64, incY int, c float64, s float64)
Drotm(n int, x []float64, incX int, y []float64, incY int, p DrotmParams)
Dscal(n int, alpha float64, x []float64, incX int)
}
// Float64Level2 implements the double precision real BLAS Level 2 routines.
type Float64Level2 interface {
Dgemv(tA Transpose, m, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
Dgbmv(tA Transpose, m, n, kL, kU int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
Dtrmv(ul Uplo, tA Transpose, d Diag, n int, a []float64, lda int, x []float64, incX int)
Dtbmv(ul Uplo, tA Transpose, d Diag, n, k int, a []float64, lda int, x []float64, incX int)
Dtpmv(ul Uplo, tA Transpose, d Diag, n int, ap []float64, x []float64, incX int)
Dtrsv(ul Uplo, tA Transpose, d Diag, n int, a []float64, lda int, x []float64, incX int)
Dtbsv(ul Uplo, tA Transpose, d Diag, n, k int, a []float64, lda int, x []float64, incX int)
Dtpsv(ul Uplo, tA Transpose, d Diag, n int, ap []float64, x []float64, incX int)
Dsymv(ul Uplo, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
Dsbmv(ul Uplo, n, k int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
Dspmv(ul Uplo, n int, alpha float64, ap []float64, x []float64, incX int, beta float64, y []float64, incY int)
Dger(m, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int)
Dsyr(ul Uplo, n int, alpha float64, x []float64, incX int, a []float64, lda int)
Dspr(ul Uplo, n int, alpha float64, x []float64, incX int, ap []float64)
Dsyr2(ul Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int)
Dspr2(ul Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64)
}
// Float64Level3 implements the double precision real BLAS Level 3 routines.
type Float64Level3 interface {
Dgemm(tA, tB Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
Dsymm(s Side, ul Uplo, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
Dsyrk(ul Uplo, t Transpose, n, k int, alpha float64, a []float64, lda int, beta float64, c []float64, ldc int)
Dsyr2k(ul Uplo, t Transpose, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
Dtrmm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int)
Dtrsm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int)
}
// Complex64 implements the single precision complex BLAS routines.
type Complex64 interface {
Complex64Level1
Complex64Level2
Complex64Level3
}
// Complex64Level1 implements the single precision complex BLAS Level 1 routines.
type Complex64Level1 interface {
Cdotu(n int, x []complex64, incX int, y []complex64, incY int) (dotu complex64)
Cdotc(n int, x []complex64, incX int, y []complex64, incY int) (dotc complex64)
Scnrm2(n int, x []complex64, incX int) float32
Scasum(n int, x []complex64, incX int) float32
Icamax(n int, x []complex64, incX int) int
Cswap(n int, x []complex64, incX int, y []complex64, incY int)
Ccopy(n int, x []complex64, incX int, y []complex64, incY int)
Caxpy(n int, alpha complex64, x []complex64, incX int, y []complex64, incY int)
Cscal(n int, alpha complex64, x []complex64, incX int)
Csscal(n int, alpha float32, x []complex64, incX int)
}
// Complex64Level2 implements the single precision complex BLAS routines Level 2 routines.
type Complex64Level2 interface {
Cgemv(tA Transpose, m, n int, alpha complex64, a []complex64, lda int, x []complex64, incX int, beta complex64, y []complex64, incY int)
Cgbmv(tA Transpose, m, n, kL, kU int, alpha complex64, a []complex64, lda int, x []complex64, incX int, beta complex64, y []complex64, incY int)
Ctrmv(ul Uplo, tA Transpose, d Diag, n int, a []complex64, lda int, x []complex64, incX int)
Ctbmv(ul Uplo, tA Transpose, d Diag, n, k int, a []complex64, lda int, x []complex64, incX int)
Ctpmv(ul Uplo, tA Transpose, d Diag, n int, ap []complex64, x []complex64, incX int)
Ctrsv(ul Uplo, tA Transpose, d Diag, n int, a []complex64, lda int, x []complex64, incX int)
Ctbsv(ul Uplo, tA Transpose, d Diag, n, k int, a []complex64, lda int, x []complex64, incX int)
Ctpsv(ul Uplo, tA Transpose, d Diag, n int, ap []complex64, x []complex64, incX int)
Chemv(ul Uplo, n int, alpha complex64, a []complex64, lda int, x []complex64, incX int, beta complex64, y []complex64, incY int)
Chbmv(ul Uplo, n, k int, alpha complex64, a []complex64, lda int, x []complex64, incX int, beta complex64, y []complex64, incY int)
Chpmv(ul Uplo, n int, alpha complex64, ap []complex64, x []complex64, incX int, beta complex64, y []complex64, incY int)
Cgeru(m, n int, alpha complex64, x []complex64, incX int, y []complex64, incY int, a []complex64, lda int)
Cgerc(m, n int, alpha complex64, x []complex64, incX int, y []complex64, incY int, a []complex64, lda int)
Cher(ul Uplo, n int, alpha float32, x []complex64, incX int, a []complex64, lda int)
Chpr(ul Uplo, n int, alpha float32, x []complex64, incX int, a []complex64)
Cher2(ul Uplo, n int, alpha complex64, x []complex64, incX int, y []complex64, incY int, a []complex64, lda int)
Chpr2(ul Uplo, n int, alpha complex64, x []complex64, incX int, y []complex64, incY int, ap []complex64)
}
// Complex64Level3 implements the single precision complex BLAS Level 3 routines.
type Complex64Level3 interface {
Cgemm(tA, tB Transpose, m, n, k int, alpha complex64, a []complex64, lda int, b []complex64, ldb int, beta complex64, c []complex64, ldc int)
Csymm(s Side, ul Uplo, m, n int, alpha complex64, a []complex64, lda int, b []complex64, ldb int, beta complex64, c []complex64, ldc int)
Csyrk(ul Uplo, t Transpose, n, k int, alpha complex64, a []complex64, lda int, beta complex64, c []complex64, ldc int)
Csyr2k(ul Uplo, t Transpose, n, k int, alpha complex64, a []complex64, lda int, b []complex64, ldb int, beta complex64, c []complex64, ldc int)
Ctrmm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha complex64, a []complex64, lda int, b []complex64, ldb int)
Ctrsm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha complex64, a []complex64, lda int, b []complex64, ldb int)
Chemm(s Side, ul Uplo, m, n int, alpha complex64, a []complex64, lda int, b []complex64, ldb int, beta complex64, c []complex64, ldc int)
Cherk(ul Uplo, t Transpose, n, k int, alpha float32, a []complex64, lda int, beta float32, c []complex64, ldc int)
Cher2k(ul Uplo, t Transpose, n, k int, alpha complex64, a []complex64, lda int, b []complex64, ldb int, beta float32, c []complex64, ldc int)
}
// Complex128 implements the double precision complex BLAS routines.
type Complex128 interface {
Complex128Level1
Complex128Level2
Complex128Level3
}
// Complex128Level1 implements the double precision complex BLAS Level 1 routines.
type Complex128Level1 interface {
Zdotu(n int, x []complex128, incX int, y []complex128, incY int) (dotu complex128)
Zdotc(n int, x []complex128, incX int, y []complex128, incY int) (dotc complex128)
Dznrm2(n int, x []complex128, incX int) float64
Dzasum(n int, x []complex128, incX int) float64
Izamax(n int, x []complex128, incX int) int
Zswap(n int, x []complex128, incX int, y []complex128, incY int)
Zcopy(n int, x []complex128, incX int, y []complex128, incY int)
Zaxpy(n int, alpha complex128, x []complex128, incX int, y []complex128, incY int)
Zscal(n int, alpha complex128, x []complex128, incX int)
Zdscal(n int, alpha float64, x []complex128, incX int)
}
// Complex128Level2 implements the double precision complex BLAS Level 2 routines.
type Complex128Level2 interface {
Zgemv(tA Transpose, m, n int, alpha complex128, a []complex128, lda int, x []complex128, incX int, beta complex128, y []complex128, incY int)
Zgbmv(tA Transpose, m, n int, kL int, kU int, alpha complex128, a []complex128, lda int, x []complex128, incX int, beta complex128, y []complex128, incY int)
Ztrmv(ul Uplo, tA Transpose, d Diag, n int, a []complex128, lda int, x []complex128, incX int)
Ztbmv(ul Uplo, tA Transpose, d Diag, n, k int, a []complex128, lda int, x []complex128, incX int)
Ztpmv(ul Uplo, tA Transpose, d Diag, n int, ap []complex128, x []complex128, incX int)
Ztrsv(ul Uplo, tA Transpose, d Diag, n int, a []complex128, lda int, x []complex128, incX int)
Ztbsv(ul Uplo, tA Transpose, d Diag, n, k int, a []complex128, lda int, x []complex128, incX int)
Ztpsv(ul Uplo, tA Transpose, d Diag, n int, ap []complex128, x []complex128, incX int)
Zhemv(ul Uplo, n int, alpha complex128, a []complex128, lda int, x []complex128, incX int, beta complex128, y []complex128, incY int)
Zhbmv(ul Uplo, n, k int, alpha complex128, a []complex128, lda int, x []complex128, incX int, beta complex128, y []complex128, incY int)
Zhpmv(ul Uplo, n int, alpha complex128, ap []complex128, x []complex128, incX int, beta complex128, y []complex128, incY int)
Zgeru(m, n int, alpha complex128, x []complex128, incX int, y []complex128, incY int, a []complex128, lda int)
Zgerc(m, n int, alpha complex128, x []complex128, incX int, y []complex128, incY int, a []complex128, lda int)
Zher(ul Uplo, n int, alpha float64, x []complex128, incX int, a []complex128, lda int)
Zhpr(ul Uplo, n int, alpha float64, x []complex128, incX int, a []complex128)
Zher2(ul Uplo, n int, alpha complex128, x []complex128, incX int, y []complex128, incY int, a []complex128, lda int)
Zhpr2(ul Uplo, n int, alpha complex128, x []complex128, incX int, y []complex128, incY int, ap []complex128)
}
// Complex128Level3 implements the double precision complex BLAS Level 3 routines.
type Complex128Level3 interface {
Zgemm(tA, tB Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int)
Zsymm(s Side, ul Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int)
Zsyrk(ul Uplo, t Transpose, n, k int, alpha complex128, a []complex128, lda int, beta complex128, c []complex128, ldc int)
Zsyr2k(ul Uplo, t Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int)
Ztrmm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int)
Ztrsm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int)
Zhemm(s Side, ul Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int)
Zherk(ul Uplo, t Transpose, n, k int, alpha float64, a []complex128, lda int, beta float64, c []complex128, ldc int)
Zher2k(ul Uplo, t Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta float64, c []complex128, ldc int)
}

View file

@ -1,286 +0,0 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package blas64 provides a simple interface to the float64 BLAS API.
package blas64
import (
"github.com/gonum/blas"
"github.com/gonum/blas/native"
)
var blas64 blas.Float64 = native.Implementation{}
// Use sets the BLAS float64 implementation to be used by subsequent BLAS calls.
// The default implementation is native.Implementation.
func Use(b blas.Float64) {
blas64 = b
}
// Implementation returns the current BLAS float64 implementation.
//
// Implementation allows direct calls to the current the BLAS float64 implementation
// giving finer control of parameters.
func Implementation() blas.Float64 {
return blas64
}
// Vector represents a vector with an associated element increment.
type Vector struct {
Inc int
Data []float64
}
// General represents a matrix using the conventional storage scheme.
type General struct {
Rows, Cols int
Stride int
Data []float64
}
// Band represents a band matrix using the band storage scheme.
type Band struct {
Rows, Cols int
KL, KU int
Stride int
Data []float64
}
// Triangular represents a triangular matrix using the conventional storage scheme.
type Triangular struct {
N int
Stride int
Data []float64
Uplo blas.Uplo
Diag blas.Diag
}
// TriangularBand represents a triangular matrix using the band storage scheme.
type TriangularBand struct {
N, K int
Stride int
Data []float64
Uplo blas.Uplo
Diag blas.Diag
}
// TriangularPacked represents a triangular matrix using the packed storage scheme.
type TriangularPacked struct {
N int
Data []float64
Uplo blas.Uplo
Diag blas.Diag
}
// Symmetric represents a symmetric matrix using the conventional storage scheme.
type Symmetric struct {
N int
Stride int
Data []float64
Uplo blas.Uplo
}
// SymmetricBand represents a symmetric matrix using the band storage scheme.
type SymmetricBand struct {
N, K int
Stride int
Data []float64
Uplo blas.Uplo
}
// SymmetricPacked represents a symmetric matrix using the packed storage scheme.
type SymmetricPacked struct {
N int
Data []float64
Uplo blas.Uplo
}
// Level 1
const negInc = "blas64: negative vector increment"
func Dot(n int, x, y Vector) float64 {
return blas64.Ddot(n, x.Data, x.Inc, y.Data, y.Inc)
}
// Nrm2 will panic if the vector increment is negative.
func Nrm2(n int, x Vector) float64 {
if x.Inc < 0 {
panic(negInc)
}
return blas64.Dnrm2(n, x.Data, x.Inc)
}
// Asum will panic if the vector increment is negative.
func Asum(n int, x Vector) float64 {
if x.Inc < 0 {
panic(negInc)
}
return blas64.Dasum(n, x.Data, x.Inc)
}
// Iamax will panic if the vector increment is negative.
func Iamax(n int, x Vector) int {
if x.Inc < 0 {
panic(negInc)
}
return blas64.Idamax(n, x.Data, x.Inc)
}
func Swap(n int, x, y Vector) {
blas64.Dswap(n, x.Data, x.Inc, y.Data, y.Inc)
}
func Copy(n int, x, y Vector) {
blas64.Dcopy(n, x.Data, x.Inc, y.Data, y.Inc)
}
func Axpy(n int, alpha float64, x, y Vector) {
blas64.Daxpy(n, alpha, x.Data, x.Inc, y.Data, y.Inc)
}
func Rotg(a, b float64) (c, s, r, z float64) {
return blas64.Drotg(a, b)
}
func Rotmg(d1, d2, b1, b2 float64) (p blas.DrotmParams, rd1, rd2, rb1 float64) {
return blas64.Drotmg(d1, d2, b1, b2)
}
func Rot(n int, x, y Vector, c, s float64) {
blas64.Drot(n, x.Data, x.Inc, y.Data, y.Inc, c, s)
}
func Rotm(n int, x, y Vector, p blas.DrotmParams) {
blas64.Drotm(n, x.Data, x.Inc, y.Data, y.Inc, p)
}
// Scal will panic if the vector increment is negative
func Scal(n int, alpha float64, x Vector) {
if x.Inc < 0 {
panic(negInc)
}
blas64.Dscal(n, alpha, x.Data, x.Inc)
}
// Level 2
func Gemv(tA blas.Transpose, alpha float64, a General, x Vector, beta float64, y Vector) {
blas64.Dgemv(tA, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
func Gbmv(tA blas.Transpose, alpha float64, a Band, x Vector, beta float64, y Vector) {
blas64.Dgbmv(tA, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
func Trmv(tA blas.Transpose, a Triangular, x Vector) {
blas64.Dtrmv(a.Uplo, tA, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
}
func Tbmv(tA blas.Transpose, a TriangularBand, x Vector) {
blas64.Dtbmv(a.Uplo, tA, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
}
func Tpmv(tA blas.Transpose, a TriangularPacked, x Vector) {
blas64.Dtpmv(a.Uplo, tA, a.Diag, a.N, a.Data, x.Data, x.Inc)
}
func Trsv(tA blas.Transpose, a Triangular, x Vector) {
blas64.Dtrsv(a.Uplo, tA, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
}
func Tbsv(tA blas.Transpose, a TriangularBand, x Vector) {
blas64.Dtbsv(a.Uplo, tA, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
}
func Tpsv(tA blas.Transpose, a TriangularPacked, x Vector) {
blas64.Dtpsv(a.Uplo, tA, a.Diag, a.N, a.Data, x.Data, x.Inc)
}
func Symv(alpha float64, a Symmetric, x Vector, beta float64, y Vector) {
blas64.Dsymv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
func Sbmv(alpha float64, a SymmetricBand, x Vector, beta float64, y Vector) {
blas64.Dsbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
func Spmv(alpha float64, a SymmetricPacked, x Vector, beta float64, y Vector) {
blas64.Dspmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc)
}
func Ger(alpha float64, x, y Vector, a General) {
blas64.Dger(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
}
func Syr(alpha float64, x Vector, a Symmetric) {
blas64.Dsyr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride)
}
func Spr(alpha float64, x Vector, a SymmetricPacked) {
blas64.Dspr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data)
}
func Syr2(alpha float64, x, y Vector, a Symmetric) {
blas64.Dsyr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
}
func Spr2(alpha float64, x, y Vector, a SymmetricPacked) {
blas64.Dspr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data)
}
// Level 3
func Gemm(tA, tB blas.Transpose, alpha float64, a, b General, beta float64, c General) {
var m, n, k int
if tA == blas.NoTrans {
m, k = a.Rows, a.Cols
} else {
m, k = a.Cols, a.Rows
}
if tB == blas.NoTrans {
n = b.Cols
} else {
n = b.Rows
}
blas64.Dgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
}
func Symm(s blas.Side, alpha float64, a Symmetric, b General, beta float64, c General) {
var m, n int
if s == blas.Left {
m, n = a.N, b.Cols
} else {
m, n = b.Rows, a.N
}
blas64.Dsymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
}
func Syrk(t blas.Transpose, alpha float64, a General, beta float64, c Symmetric) {
var n, k int
if t == blas.NoTrans {
n, k = a.Rows, a.Cols
} else {
n, k = a.Cols, a.Rows
}
blas64.Dsyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride)
}
func Syr2k(t blas.Transpose, alpha float64, a, b General, beta float64, c Symmetric) {
var n, k int
if t == blas.NoTrans {
n, k = a.Rows, a.Cols
} else {
n, k = a.Cols, a.Rows
}
blas64.Dsyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
}
func Trmm(s blas.Side, tA blas.Transpose, alpha float64, a Triangular, b General) {
blas64.Dtrmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
}
func Trsm(s blas.Side, tA blas.Transpose, alpha float64, a Triangular, b General) {
blas64.Dtrsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
}

View file

@ -1,327 +0,0 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package cblas128 provides a simple interface to the complex128 BLAS API.
package cblas128
import (
"github.com/gonum/blas"
"github.com/gonum/blas/cgo"
)
// TODO(kortschak): Change this and the comment below to native.Implementation
// when blas/native covers the complex BLAS API.
var cblas128 blas.Complex128 = cgo.Implementation{}
// Use sets the BLAS complex128 implementation to be used by subsequent BLAS calls.
// The default implementation is cgo.Implementation.
func Use(b blas.Complex128) {
cblas128 = b
}
// Implementation returns the current BLAS complex128 implementation.
//
// Implementation allows direct calls to the current the BLAS complex128 implementation
// giving finer control of parameters.
func Implementation() blas.Complex128 {
return cblas128
}
// Vector represents a vector with an associated element increment.
type Vector struct {
Inc int
Data []complex128
}
// General represents a matrix using the conventional storage scheme.
type General struct {
Rows, Cols int
Stride int
Data []complex128
}
// Band represents a band matrix using the band storage scheme.
type Band struct {
Rows, Cols int
KL, KU int
Stride int
Data []complex128
}
// Triangular represents a triangular matrix using the conventional storage scheme.
type Triangular struct {
N int
Stride int
Data []complex128
Uplo blas.Uplo
Diag blas.Diag
}
// TriangularBand represents a triangular matrix using the band storage scheme.
type TriangularBand struct {
N, K int
Stride int
Data []complex128
Uplo blas.Uplo
Diag blas.Diag
}
// TriangularPacked represents a triangular matrix using the packed storage scheme.
type TriangularPacked struct {
N int
Data []complex128
Uplo blas.Uplo
Diag blas.Diag
}
// Symmetric represents a symmetric matrix using the conventional storage scheme.
type Symmetric struct {
N int
Stride int
Data []complex128
Uplo blas.Uplo
}
// SymmetricBand represents a symmetric matrix using the band storage scheme.
type SymmetricBand struct {
N, K int
Stride int
Data []complex128
Uplo blas.Uplo
}
// SymmetricPacked represents a symmetric matrix using the packed storage scheme.
type SymmetricPacked struct {
N int
Data []complex128
Uplo blas.Uplo
}
// Hermitian represents an Hermitian matrix using the conventional storage scheme.
type Hermitian Symmetric
// HermitianBand represents an Hermitian matrix using the band storage scheme.
type HermitianBand SymmetricBand
// HermitianPacked represents an Hermitian matrix using the packed storage scheme.
type HermitianPacked SymmetricPacked
// Level 1
const negInc = "cblas128: negative vector increment"
func Dotu(n int, x, y Vector) complex128 {
return cblas128.Zdotu(n, x.Data, x.Inc, y.Data, y.Inc)
}
func Dotc(n int, x, y Vector) complex128 {
return cblas128.Zdotc(n, x.Data, x.Inc, y.Data, y.Inc)
}
// Nrm2 will panic if the vector increment is negative.
func Nrm2(n int, x Vector) float64 {
if x.Inc < 0 {
panic(negInc)
}
return cblas128.Dznrm2(n, x.Data, x.Inc)
}
// Asum will panic if the vector increment is negative.
func Asum(n int, x Vector) float64 {
if x.Inc < 0 {
panic(negInc)
}
return cblas128.Dzasum(n, x.Data, x.Inc)
}
// Iamax will panic if the vector increment is negative.
func Iamax(n int, x Vector) int {
if x.Inc < 0 {
panic(negInc)
}
return cblas128.Izamax(n, x.Data, x.Inc)
}
func Swap(n int, x, y Vector) {
cblas128.Zswap(n, x.Data, x.Inc, y.Data, y.Inc)
}
func Copy(n int, x, y Vector) {
cblas128.Zcopy(n, x.Data, x.Inc, y.Data, y.Inc)
}
func Axpy(n int, alpha complex128, x, y Vector) {
cblas128.Zaxpy(n, alpha, x.Data, x.Inc, y.Data, y.Inc)
}
// Scal will panic if the vector increment is negative
func Scal(n int, alpha complex128, x Vector) {
if x.Inc < 0 {
panic(negInc)
}
cblas128.Zscal(n, alpha, x.Data, x.Inc)
}
// Dscal will panic if the vector increment is negative
func Dscal(n int, alpha float64, x Vector) {
if x.Inc < 0 {
panic(negInc)
}
cblas128.Zdscal(n, alpha, x.Data, x.Inc)
}
// Level 2
func Gemv(tA blas.Transpose, alpha complex128, a General, x Vector, beta complex128, y Vector) {
cblas128.Zgemv(tA, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
func Gbmv(tA blas.Transpose, alpha complex128, a Band, x Vector, beta complex128, y Vector) {
cblas128.Zgbmv(tA, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
func Trmv(tA blas.Transpose, a Triangular, x Vector) {
cblas128.Ztrmv(a.Uplo, tA, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
}
func Tbmv(tA blas.Transpose, a TriangularBand, x Vector) {
cblas128.Ztbmv(a.Uplo, tA, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
}
func Tpmv(tA blas.Transpose, a TriangularPacked, x Vector) {
cblas128.Ztpmv(a.Uplo, tA, a.Diag, a.N, a.Data, x.Data, x.Inc)
}
func Trsv(tA blas.Transpose, a Triangular, x Vector) {
cblas128.Ztrsv(a.Uplo, tA, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
}
func Tbsv(tA blas.Transpose, a TriangularBand, x Vector) {
cblas128.Ztbsv(a.Uplo, tA, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
}
func Tpsv(tA blas.Transpose, a TriangularPacked, x Vector) {
cblas128.Ztpsv(a.Uplo, tA, a.Diag, a.N, a.Data, x.Data, x.Inc)
}
func Hemv(alpha complex128, a Hermitian, x Vector, beta complex128, y Vector) {
cblas128.Zhemv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
func Hbmv(alpha complex128, a HermitianBand, x Vector, beta complex128, y Vector) {
cblas128.Zhbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
func Hpmv(alpha complex128, a HermitianPacked, x Vector, beta complex128, y Vector) {
cblas128.Zhpmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc)
}
func Geru(alpha complex128, x, y Vector, a General) {
cblas128.Zgeru(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
}
func Gerc(alpha complex128, x, y Vector, a General) {
cblas128.Zgerc(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
}
func Her(alpha float64, x Vector, a Hermitian) {
cblas128.Zher(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride)
}
func Hpr(alpha float64, x Vector, a HermitianPacked) {
cblas128.Zhpr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data)
}
func Her2(alpha complex128, x, y Vector, a Hermitian) {
cblas128.Zher2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
}
func Hpr2(alpha complex128, x, y Vector, a HermitianPacked) {
cblas128.Zhpr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data)
}
// Level 3
func Gemm(tA, tB blas.Transpose, alpha complex128, a, b General, beta complex128, c General) {
var m, n, k int
if tA == blas.NoTrans {
m, k = a.Rows, a.Cols
} else {
m, k = a.Cols, a.Rows
}
if tB == blas.NoTrans {
n = b.Cols
} else {
n = b.Rows
}
cblas128.Zgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
}
func Symm(s blas.Side, alpha complex128, a Symmetric, b General, beta complex128, c General) {
var m, n int
if s == blas.Left {
m, n = a.N, b.Cols
} else {
m, n = b.Rows, a.N
}
cblas128.Zsymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
}
func Syrk(t blas.Transpose, alpha complex128, a General, beta complex128, c Symmetric) {
var n, k int
if t == blas.NoTrans {
n, k = a.Rows, a.Cols
} else {
n, k = a.Cols, a.Rows
}
cblas128.Zsyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride)
}
func Syr2k(t blas.Transpose, alpha complex128, a, b General, beta complex128, c Symmetric) {
var n, k int
if t == blas.NoTrans {
n, k = a.Rows, a.Cols
} else {
n, k = a.Cols, a.Rows
}
cblas128.Zsyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
}
func Trmm(s blas.Side, tA blas.Transpose, alpha complex128, a Triangular, b General) {
cblas128.Ztrmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
}
func Trsm(s blas.Side, tA blas.Transpose, alpha complex128, a Triangular, b General) {
cblas128.Ztrsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
}
func Hemm(s blas.Side, alpha complex128, a Hermitian, b General, beta complex128, c General) {
var m, n int
if s == blas.Left {
m, n = a.N, b.Cols
} else {
m, n = b.Rows, a.N
}
cblas128.Zhemm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
}
func Herk(t blas.Transpose, alpha float64, a General, beta float64, c Hermitian) {
var n, k int
if t == blas.NoTrans {
n, k = a.Rows, a.Cols
} else {
n, k = a.Cols, a.Rows
}
cblas128.Zherk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride)
}
func Her2k(t blas.Transpose, alpha complex128, a, b General, beta float64, c Hermitian) {
var n, k int
if t == blas.NoTrans {
n, k = a.Rows, a.Cols
} else {
n, k = a.Cols, a.Rows
}
cblas128.Zher2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
}

View file

@ -1,18 +0,0 @@
package cgo
import (
"github.com/gonum/blas"
"github.com/gonum/blas/testblas"
)
const (
Sm = testblas.SmallMat
Med = testblas.MediumMat
Lg = testblas.LargeMat
Hg = testblas.HugeMat
)
const (
T = blas.Trans
NT = blas.NoTrans
)

File diff suppressed because it is too large Load diff

View file

@ -1,596 +0,0 @@
#ifndef CBLAS_H
#ifndef CBLAS_ENUM_DEFINED_H
#define CBLAS_ENUM_DEFINED_H
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102 };
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113,
AtlasConj=114};
enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
#endif
#ifndef CBLAS_ENUM_ONLY
#define CBLAS_H
#define CBLAS_INDEX int
int cblas_errprn(int ierr, int info, char *form, ...);
/*
* ===========================================================================
* Prototypes for level 1 BLAS functions (complex are recast as routines)
* ===========================================================================
*/
float cblas_sdsdot(const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY);
double cblas_dsdot(const int N, const float *X, const int incX, const float *Y,
const int incY);
float cblas_sdot(const int N, const float *X, const int incX,
const float *Y, const int incY);
double cblas_ddot(const int N, const double *X, const int incX,
const double *Y, const int incY);
/*
* Functions having prefixes Z and C only
*/
void cblas_cdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_cdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
void cblas_zdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_zdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
/*
* Functions having prefixes S D SC DZ
*/
float cblas_snrm2(const int N, const float *X, const int incX);
float cblas_sasum(const int N, const float *X, const int incX);
double cblas_dnrm2(const int N, const double *X, const int incX);
double cblas_dasum(const int N, const double *X, const int incX);
float cblas_scnrm2(const int N, const void *X, const int incX);
float cblas_scasum(const int N, const void *X, const int incX);
double cblas_dznrm2(const int N, const void *X, const int incX);
double cblas_dzasum(const int N, const void *X, const int incX);
/*
* Functions having standard 4 prefixes (S D C Z)
*/
CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX);
CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX);
CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX);
CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 1 BLAS routines
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (s, d, c, z)
*/
void cblas_sswap(const int N, float *X, const int incX,
float *Y, const int incY);
void cblas_scopy(const int N, const float *X, const int incX,
float *Y, const int incY);
void cblas_saxpy(const int N, const float alpha, const float *X,
const int incX, float *Y, const int incY);
void catlas_saxpby(const int N, const float alpha, const float *X,
const int incX, const float beta, float *Y, const int incY);
void catlas_sset
(const int N, const float alpha, float *X, const int incX);
void cblas_dswap(const int N, double *X, const int incX,
double *Y, const int incY);
void cblas_dcopy(const int N, const double *X, const int incX,
double *Y, const int incY);
void cblas_daxpy(const int N, const double alpha, const double *X,
const int incX, double *Y, const int incY);
void catlas_daxpby(const int N, const double alpha, const double *X,
const int incX, const double beta, double *Y, const int incY);
void catlas_dset
(const int N, const double alpha, double *X, const int incX);
void cblas_cswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_ccopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_caxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void catlas_caxpby(const int N, const void *alpha, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void catlas_cset
(const int N, const void *alpha, void *X, const int incX);
void cblas_zswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_zcopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_zaxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void catlas_zaxpby(const int N, const void *alpha, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void catlas_zset
(const int N, const void *alpha, void *X, const int incX);
/*
* Routines with S and D prefix only
*/
void cblas_srotg(float *a, float *b, float *c, float *s);
void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P);
void cblas_srot(const int N, float *X, const int incX,
float *Y, const int incY, const float c, const float s);
void cblas_srotm(const int N, float *X, const int incX,
float *Y, const int incY, const float *P);
void cblas_drotg(double *a, double *b, double *c, double *s);
void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P);
void cblas_drot(const int N, double *X, const int incX,
double *Y, const int incY, const double c, const double s);
void cblas_drotm(const int N, double *X, const int incX,
double *Y, const int incY, const double *P);
/*
* Routines with S D C Z CS and ZD prefixes
*/
void cblas_sscal(const int N, const float alpha, float *X, const int incX);
void cblas_dscal(const int N, const double alpha, double *X, const int incX);
void cblas_cscal(const int N, const void *alpha, void *X, const int incX);
void cblas_zscal(const int N, const void *alpha, void *X, const int incX);
void cblas_csscal(const int N, const float alpha, void *X, const int incX);
void cblas_zdscal(const int N, const double alpha, void *X, const int incX);
/*
* Extra reference routines provided by ATLAS, but not mandated by the standard
*/
void cblas_crotg(void *a, void *b, void *c, void *s);
void cblas_zrotg(void *a, void *b, void *c, void *s);
void cblas_csrot(const int N, void *X, const int incX, void *Y, const int incY,
const float c, const float s);
void cblas_zdrot(const int N, void *X, const int incX, void *Y, const int incY,
const double c, const double s);
/*
* ===========================================================================
* Prototypes for level 2 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *X, const int incX, const float beta,
float *Y, const int incY);
void cblas_sgbmv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const float alpha,
const float *A, const int lda, const float *X,
const int incX, const float beta, float *Y, const int incY);
void cblas_strmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda,
float *X, const int incX);
void cblas_stbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_strsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda, float *X,
const int incX);
void cblas_stbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_dgemv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *X, const int incX, const double beta,
double *Y, const int incY);
void cblas_dgbmv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const double alpha,
const double *A, const int lda, const double *X,
const int incX, const double beta, double *Y, const int incY);
void cblas_dtrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda,
double *X, const int incX);
void cblas_dtbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_dtrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda, double *X,
const int incX);
void cblas_dtbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_cgemv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_cgbmv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ctrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ctbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ctrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ctbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_zgemv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_zgbmv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ztrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ztbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ztrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ztbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
/*
* Routines with S and D prefixes only
*/
void cblas_ssymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_ssbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *Ap,
const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sger(const enum CBLAS_ORDER Order, const int M, const int N,
const float alpha, const float *X, const int incX,
const float *Y, const int incY, float *A, const int lda);
void cblas_ssyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *A, const int lda);
void cblas_sspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *Ap);
void cblas_ssyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A,
const int lda);
void cblas_sspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *Ap);
void cblas_dsymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dsbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *Ap,
const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dger(const enum CBLAS_ORDER Order, const int M, const int N,
const double alpha, const double *X, const int incX,
const double *Y, const int incY, double *A, const int lda);
void cblas_dsyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *A, const int lda);
void cblas_dspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *Ap);
void cblas_dsyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A,
const int lda);
void cblas_dspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *Ap);
/*
* Routines with C and Z prefixes only
*/
void cblas_chemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_cgeru(const enum CBLAS_ORDER Order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cgerc(const enum CBLAS_ORDER Order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_chpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X,
const int incX, void *Ap);
void cblas_cher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_chpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
void cblas_zhemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zgeru(const enum CBLAS_ORDER Order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zgerc(const enum CBLAS_ORDER Order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_zhpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X,
const int incX, void *Ap);
void cblas_zher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zhpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
/*
* ===========================================================================
* Prototypes for level 3 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const float alpha, const float *A,
const int lda, const float *B, const int ldb,
const float beta, float *C, const int ldc);
void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float beta, float *C, const int ldc);
void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const double alpha, const double *A,
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc);
void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double beta, double *C, const int ldc);
void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
/*
* Routines with prefixes C and Z only
*/
void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const void *A, const int lda,
const float beta, void *C, const int ldc);
void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const float beta,
void *C, const int ldc);
void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const void *A, const int lda,
const double beta, void *C, const int ldc);
void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const double beta,
void *C, const int ldc);
int cblas_errprn(int ierr, int info, char *form, ...);
#endif /* end #ifdef CBLAS_ENUM_ONLY */
#endif

View file

@ -1,47 +0,0 @@
package cgo
import (
"testing"
"github.com/gonum/blas/testblas"
)
func BenchmarkDgemmSmSmSm(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Sm, Sm, Sm, NT, NT)
}
func BenchmarkDgemmMedMedMed(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, NT, NT)
}
func BenchmarkDgemmMedLgMed(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Lg, Med, NT, NT)
}
func BenchmarkDgemmLgLgLg(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Lg, Lg, Lg, NT, NT)
}
func BenchmarkDgemmLgSmLg(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Lg, Sm, Lg, NT, NT)
}
func BenchmarkDgemmLgLgSm(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Lg, Lg, Sm, NT, NT)
}
func BenchmarkDgemmHgHgSm(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Hg, Hg, Sm, NT, NT)
}
func BenchmarkDgemmMedMedMedTNT(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, T, NT)
}
func BenchmarkDgemmMedMedMedNTT(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, NT, T)
}
func BenchmarkDgemmMedMedMedTT(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, T, T)
}

View file

@ -1,87 +0,0 @@
package cgo
import (
"testing"
"github.com/gonum/blas/testblas"
)
func BenchmarkDgemvSmSmNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Sm, 1, 1)
}
func BenchmarkDgemvSmSmNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Sm, 2, 3)
}
func BenchmarkDgemvSmSmTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Sm, 1, 1)
}
func BenchmarkDgemvSmSmTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Sm, 2, 3)
}
func BenchmarkDgemvMedMedNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Med, Med, 1, 1)
}
func BenchmarkDgemvMedMedNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Med, Med, 2, 3)
}
func BenchmarkDgemvMedMedTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Med, Med, 1, 1)
}
func BenchmarkDgemvMedMedTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Med, Med, 2, 3)
}
func BenchmarkDgemvLgLgNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Lg, 1, 1)
}
func BenchmarkDgemvLgLgNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Lg, 2, 3)
}
func BenchmarkDgemvLgLgTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Lg, 1, 1)
}
func BenchmarkDgemvLgLgTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Lg, 2, 3)
}
func BenchmarkDgemvLgSmNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Sm, 1, 1)
}
func BenchmarkDgemvLgSmNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Sm, 2, 3)
}
func BenchmarkDgemvLgSmTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Sm, 1, 1)
}
func BenchmarkDgemvLgSmTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Sm, 2, 3)
}
func BenchmarkDgemvSmLgNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Lg, 1, 1)
}
func BenchmarkDgemvSmLgNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Lg, 2, 3)
}
func BenchmarkDgemvSmLgTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Lg, 1, 1)
}
func BenchmarkDgemvSmLgTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Lg, 2, 3)
}

View file

@ -1,47 +0,0 @@
package cgo
import (
"testing"
"github.com/gonum/blas/testblas"
)
func BenchmarkDgerSmSmInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Sm, 1, 1)
}
func BenchmarkDgerSmSmIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Sm, 2, 3)
}
func BenchmarkDgerMedMedInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Med, Med, 1, 1)
}
func BenchmarkDgerMedMedIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Med, Med, 2, 3)
}
func BenchmarkDgerLgLgInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Lg, 1, 1)
}
func BenchmarkDgerLgLgIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Lg, 2, 3)
}
func BenchmarkDgerLgSmInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Sm, 1, 1)
}
func BenchmarkDgerLgSmIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Sm, 2, 3)
}
func BenchmarkDgerSmLgInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Lg, 1, 1)
}
func BenchmarkDgerSmLgIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Lg, 2, 3)
}

View file

@ -1,607 +0,0 @@
#!/usr/bin/env perl
# Copyright ©2014 The Gonum Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
use strict;
use warnings;
my $excludeComplex = 0;
my $excludeAtlas = 1;
my $cblasHeader = "cblas.h";
open(my $cblas, "<", $cblasHeader) or die;
open(my $goblas, ">", "blas.go") or die;
my %done = ("cblas_errprn" => 1,
"cblas_srotg" => 1,
"cblas_srotmg" => 1,
"cblas_srotm" => 1,
"cblas_drotg" => 1,
"cblas_drotmg" => 1,
"cblas_drotm" => 1,
"cblas_crotg" => 1,
"cblas_zrotg" => 1,
"cblas_cdotu_sub" => 1,
"cblas_cdotc_sub" => 1,
"cblas_zdotu_sub" => 1,
"cblas_zdotc_sub" => 1,
);
if ($excludeAtlas) {
$done{'cblas_csrot'} = 1;
$done{'cblas_zdrot'} = 1;
}
printf $goblas <<EOH;
// Do not manually edit this file. It was created by the genBlas.pl script from ${cblasHeader}.
// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package cgo provides bindings to a C BLAS library.
package cgo
/*
#cgo CFLAGS: -g -O2
#include "${cblasHeader}"
*/
import "C"
import (
"unsafe"
"github.com/gonum/blas"
)
// Type check assertions:
var (
_ blas.Float32 = Implementation{}
_ blas.Float64 = Implementation{}
_ blas.Complex64 = Implementation{}
_ blas.Complex128 = Implementation{}
)
// Type order is used to specify the matrix storage format. We still interact with
// an API that allows client calls to specify order, so this is here to document that fact.
type order int
const (
rowMajor order = 101 + iota
)
func max(a, b int) int {
if a > b {
return a
}
return b
}
type Implementation struct{}
// Special cases...
type srotmParams struct {
flag float32
h [4]float32
}
type drotmParams struct {
flag float64
h [4]float64
}
func (Implementation) Srotg(a float32, b float32) (c float32, s float32, r float32, z float32) {
C.cblas_srotg((*C.float)(&a), (*C.float)(&b), (*C.float)(&c), (*C.float)(&s))
return c, s, a, b
}
func (Implementation) Srotmg(d1 float32, d2 float32, b1 float32, b2 float32) (p blas.SrotmParams, rd1 float32, rd2 float32, rb1 float32) {
var pi srotmParams
C.cblas_srotmg((*C.float)(&d1), (*C.float)(&d2), (*C.float)(&b1), C.float(b2), (*C.float)(unsafe.Pointer(&pi)))
return blas.SrotmParams{Flag: blas.Flag(pi.flag), H: pi.h}, d1, d2, b1
}
func (Implementation) Srotm(n int, x []float32, incX int, y []float32, incY int, p blas.SrotmParams) {
if n < 0 {
panic("blas: n < 0")
}
if incX == 0 {
panic("blas: zero x index increment")
}
if incY == 0 {
panic("blas: zero y index increment")
}
if (n-1)*incX >= len(x) {
panic("blas: x index out of range")
}
if (n-1)*incY >= len(y) {
panic("blas: y index out of range")
}
if p.Flag < blas.Identity || p.Flag > blas.Diagonal {
panic("blas: illegal blas.Flag value")
}
pi := srotmParams{
flag: float32(p.Flag),
h: p.H,
}
C.cblas_srotm(C.int(n), (*C.float)(&x[0]), C.int(incX), (*C.float)(&y[0]), C.int(incY), (*C.float)(unsafe.Pointer(&pi)))
}
func (Implementation) Drotg(a float64, b float64) (c float64, s float64, r float64, z float64) {
C.cblas_drotg((*C.double)(&a), (*C.double)(&b), (*C.double)(&c), (*C.double)(&s))
return c, s, a, b
}
func (Implementation) Drotmg(d1 float64, d2 float64, b1 float64, b2 float64) (p blas.DrotmParams, rd1 float64, rd2 float64, rb1 float64) {
var pi drotmParams
C.cblas_drotmg((*C.double)(&d1), (*C.double)(&d2), (*C.double)(&b1), C.double(b2), (*C.double)(unsafe.Pointer(&pi)))
return blas.DrotmParams{Flag: blas.Flag(pi.flag), H: pi.h}, d1, d2, b1
}
func (Implementation) Drotm(n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams) {
if n < 0 {
panic("blas: n < 0")
}
if incX == 0 {
panic("blas: zero x index increment")
}
if incY == 0 {
panic("blas: zero y index increment")
}
if (n-1)*incX >= len(x) {
panic("blas: x index out of range")
}
if (n-1)*incY >= len(y) {
panic("blas: y index out of range")
}
if p.Flag < blas.Identity || p.Flag > blas.Diagonal {
panic("blas: illegal blas.Flag value")
}
pi := drotmParams{
flag: float64(p.Flag),
h: p.H,
}
C.cblas_drotm(C.int(n), (*C.double)(&x[0]), C.int(incX), (*C.double)(&y[0]), C.int(incY), (*C.double)(unsafe.Pointer(&pi)))
}
func (Implementation) Cdotu(n int, x []complex64, incX int, y []complex64, incY int) (dotu complex64) {
if n < 0 {
panic("blas: n < 0")
}
if incX == 0 {
panic("blas: zero x index increment")
}
if incY == 0 {
panic("blas: zero y index increment")
}
if (n-1)*incX >= len(x) {
panic("blas: x index out of range")
}
if (n-1)*incY >= len(y) {
panic("blas: y index out of range")
}
C.cblas_cdotu_sub(C.int(n), unsafe.Pointer(&x[0]), C.int(incX), unsafe.Pointer(&y[0]), C.int(incY), unsafe.Pointer(&dotu))
return dotu
}
func (Implementation) Cdotc(n int, x []complex64, incX int, y []complex64, incY int) (dotc complex64) {
if n < 0 {
panic("blas: n < 0")
}
if incX == 0 {
panic("blas: zero x index increment")
}
if incY == 0 {
panic("blas: zero y index increment")
}
if (n-1)*incX >= len(x) {
panic("blas: x index out of range")
}
if (n-1)*incY >= len(y) {
panic("blas: y index out of range")
}
C.cblas_cdotc_sub(C.int(n), unsafe.Pointer(&x[0]), C.int(incX), unsafe.Pointer(&y[0]), C.int(incY), unsafe.Pointer(&dotc))
return dotc
}
func (Implementation) Zdotu(n int, x []complex128, incX int, y []complex128, incY int) (dotu complex128) {
if n < 0 {
panic("blas: n < 0")
}
if incX == 0 {
panic("blas: zero x index increment")
}
if incY == 0 {
panic("blas: zero y index increment")
}
if (n-1)*incX >= len(x) {
panic("blas: x index out of range")
}
if (n-1)*incY >= len(y) {
panic("blas: y index out of range")
}
C.cblas_zdotu_sub(C.int(n), unsafe.Pointer(&x[0]), C.int(incX), unsafe.Pointer(&y[0]), C.int(incY), unsafe.Pointer(&dotu))
return dotu
}
func (Implementation) Zdotc(n int, x []complex128, incX int, y []complex128, incY int) (dotc complex128) {
if n < 0 {
panic("blas: n < 0")
}
if incX == 0 {
panic("blas: zero x index increment")
}
if incY == 0 {
panic("blas: zero y index increment")
}
if (n-1)*incX >= len(x) {
panic("blas: x index out of range")
}
if (n-1)*incY >= len(y) {
panic("blas: y index out of range")
}
C.cblas_zdotc_sub(C.int(n), unsafe.Pointer(&x[0]), C.int(incX), unsafe.Pointer(&y[0]), C.int(incY), unsafe.Pointer(&dotc))
return dotc
}
EOH
printf $goblas <<EOH unless $excludeAtlas;
func (Implementation) Crotg(a complex64, b complex64) (c complex64, s complex64, r complex64, z complex64) {
C.cblas_srotg(unsafe.Pointer(&a), unsafe.Pointer(&b), unsafe.Pointer(&c), unsafe.Pointer(&s))
return c, s, a, b
}
func (Implementation) Zrotg(a complex128, b complex128) (c complex128, s complex128, r complex128, z complex128) {
C.cblas_drotg(unsafe.Pointer(&a), unsafe.Pointer(&b), unsafe.Pointer(&c), unsafe.Pointer(&s))
return c, s, a, b
}
EOH
print $goblas "\n";
$/ = undef;
my $header = <$cblas>;
# horrible munging of text...
$header =~ s/#[^\n\r]*//g; # delete cpp lines
$header =~ s/\n +([^\n\r]*)/\n$1/g; # remove starting space
$header =~ s/(?:\n ?\n)+/\n/g; # delete empty lines
$header =~ s! ((['"]) (?: \\. | .)*? \2) | # skip quoted strings
/\* .*? \*/ | # delete C comments
// [^\n\r]* # delete C++ comments just in case
! $1 || ' ' # change comments to a single space
!xseg; # ignore white space, treat as single line
# evaluate result, repeat globally
$header =~ s/([^;])\n/$1/g; # join prototypes into single lines
$header =~ s/, +/,/g;
$header =~ s/ +/ /g;
$header =~ s/ +}/}/g;
$header =~ s/\n+//;
$/ = "\n";
my @lines = split ";\n", $header;
our %retConv = (
"int" => "int ",
"float" => "float32 ",
"double" => "float64 ",
"CBLAS_INDEX" => "int ",
"void" => ""
);
foreach my $line (@lines) {
process($line);
}
close($goblas);
`go fmt .`;
sub process {
my $line = shift;
chomp $line;
if (not $line =~ m/^enum/) {
processProto($line);
}
}
sub processProto {
my $proto = shift;
my ($func, $paramList) = split /[()]/, $proto;
(my $ret, $func) = split ' ', $func;
if ($done{$func} or $excludeComplex && $func =~ m/_[isd]?[zc]/ or $excludeAtlas && $func =~ m/^catlas_/) {
return
}
$done{$func} = 1;
my $GoRet = $retConv{$ret};
my $complexType = $func;
$complexType =~ s/.*_[isd]?([zc]).*/$1/;
print $goblas "func (Implementation) ".Gofunc($func)."(".processParamToGo($func, $paramList, $complexType).") ".$GoRet."{\n";
print $goblas processParamToChecks($func, $paramList);
print $goblas "\t";
if ($ret ne 'void') {
chop($GoRet);
print $goblas "return ".$GoRet."(";
}
print $goblas "C.$func(".processParamToC($func, $paramList).")";
if ($ret ne 'void') {
print $goblas ")";
}
print $goblas "\n}\n";
}
sub Gofunc {
my $fnName = shift;
$fnName =~ s/_sub//;
my ($pack, $func, $tail) = split '_', $fnName;
if ($pack eq 'cblas') {
$pack = "";
} else {
$pack = substr $pack, 1;
}
return ucfirst $pack . ucfirst $func . ucfirst $tail if $tail;
return ucfirst $pack . ucfirst $func;
}
sub processParamToGo {
my $func = shift;
my $paramList = shift;
my $complexType = shift;
my @processed;
my @params = split ',', $paramList;
my $skip = 0;
foreach my $param (@params) {
my @parts = split /[ *]/, $param;
my $var = lcfirst $parts[scalar @parts - 1];
$param =~ m/^(?:const )?int/ && do {
push @processed, $var." int"; next;
};
$param =~ m/^(?:const )?void/ && do {
my $type;
if ($var eq "alpha" || $var eq "beta") {
$type = " ";
} else {
$type = " []";
}
if ($complexType eq 'c') {
push @processed, $var.$type."complex64"; next;
} elsif ($complexType eq 'z') {
push @processed, $var.$type."complex128"; next;
} else {
die "unexpected complex type for '$func' - '$complexType'";
}
};
$param =~ m/^(?:const )?char \*/ && do {
push @processed, $var." *byte"; next;
};
$param =~ m/^(?:const )?float \*/ && do {
push @processed, $var." []float32"; next;
};
$param =~ m/^(?:const )?double \*/ && do {
push @processed, $var." []float64"; next;
};
$param =~ m/^(?:const )?float/ && do {
push @processed, $var." float32"; next;
};
$param =~ m/^(?:const )?double/ && do {
push @processed, $var." float64"; next;
};
$param =~ m/^const enum/ && do {
$var eq "order" && $skip++;
$var =~ /trans/ && do {
$var =~ s/trans([AB]?)/t$1/;
push @processed, $var." blas.Transpose"; next;
};
$var eq "uplo" && do {
$var = "ul";
push @processed, $var." blas.Uplo"; next;
};
$var eq "diag" && do {
$var = "d";
push @processed, $var." blas.Diag"; next;
};
$var eq "side" && do {
$var = "s";
push @processed, $var." blas.Side"; next;
};
};
}
die "missed Go parameters from '$func', '$paramList'" if scalar @processed+$skip != scalar @params;
return join ", ", @processed;
}
sub processParamToChecks {
my $func = shift;
my $paramList = shift;
my @processed;
my @params = split ',', $paramList;
my %arrayArgs;
my %scalarArgs;
foreach my $param (@params) {
my @parts = split /[ *]/, $param;
my $var = lcfirst $parts[scalar @parts - 1];
$param =~ m/^(?:const )?int \*[a-zA-Z]/ && do {
$scalarArgs{$var} = 1; next;
};
$param =~ m/^(?:const )?void \*[a-zA-Z]/ && do {
if ($var ne "alpha" && $var ne "beta") {
$arrayArgs{$var} = 1;
}
next;
};
$param =~ m/^(?:const )?(?:float|double) \*[a-zA-Z]/ && do {
$arrayArgs{$var} = 1; next;
};
$param =~ m/^(?:const )?(?:int|float|double) [a-zA-Z]/ && do {
$scalarArgs{$var} = 1; next;
};
$param =~ m/^const enum [a-zA-Z]/ && do {
$var eq "order" && do {
$scalarArgs{'o'} = 1;
};
$var =~ /trans/ && do {
$var =~ s/trans([AB]?)/t$1/;
$scalarArgs{$var} = 1;
if ($func =~ m/cblas_[cz]h/) {
push @processed, "if $var != blas.NoTrans && $var != blas.ConjTrans { panic(\"blas: illegal transpose\") }"; next;
} elsif ($func =~ m/cblas_[cz]s/) {
push @processed, "if $var != blas.NoTrans && $var != blas.Trans { panic(\"blas: illegal transpose\") }"; next;
} else {
push @processed, "if $var != blas.NoTrans && $var != blas.Trans && $var != blas.ConjTrans { panic(\"blas: illegal transpose\") }"; next;
}
};
$var eq "uplo" && do {
push @processed, "if ul != blas.Upper && ul != blas.Lower { panic(\"blas: illegal triangle\") }"; next;
};
$var eq "diag" && do {
push @processed, "if d != blas.NonUnit && d != blas.Unit { panic(\"blas: illegal diagonal\") }"; next;
};
$var eq "side" && do {
$scalarArgs{'s'} = 1;
push @processed, "if s != blas.Left && s != blas.Right { panic(\"blas: illegal side\") }"; next;
};
};
}
# shape checks
foreach my $ref ('m', 'n', 'k', 'kL', 'kU') {
push @processed, "if $ref < 0 { panic(\"blas: $ref < 0\") }" if $scalarArgs{$ref};
}
if ($arrayArgs{'ap'}) {
push @processed, "if n*(n + 1)/2 > len(ap) { panic(\"blas: index of ap out of range\") }"
}
push @processed, "if incX == 0 { panic(\"blas: zero x index increment\") }" if $scalarArgs{'incX'};
push @processed, "if incY == 0 { panic(\"blas: zero y index increment\") }" if $scalarArgs{'incY'};
if ($func =~ m/cblas_[sdcz]g[eb]mv/) {
push @processed, "var lenX, lenY int";
push @processed, "if tA == blas.NoTrans { lenX, lenY = n, m } else { lenX, lenY = m, n }";
push @processed, "if (incX > 0 && (lenX-1)*incX >= len(x)) || (incX < 0 && (1-lenX)*incX >= len(x)) { panic(\"blas: x index out of range\") }";
push @processed, "if (incY > 0 && (lenY-1)*incY >= len(y)) || (incY < 0 && (1-lenY)*incY >= len(y)) { panic(\"blas: y index out of range\") }";
} elsif ($scalarArgs{'m'}) {
push @processed, "if (incX > 0 && (m-1)*incX >= len(x)) || (incX < 0 && (1-m)*incX >= len(x)) { panic(\"blas: x index out of range\") }" if $scalarArgs{'incX'};
push @processed, "if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) { panic(\"blas: y index out of range\") }" if $scalarArgs{'incY'};
} elsif ($func =~ m/cblas_[sdcz]s?scal/) {
push @processed, "if incX < 0 { return }";
push @processed, "if incX > 0 && (n-1)*incX >= len(x) { panic(\"blas: x index out of range\") }";
} elsif ($func =~ m/cblas_i[sdcz]amax/) {
push @processed, "if n == 0 || incX < 0 { return -1 }";
push @processed, "if incX > 0 && (n-1)*incX >= len(x) { panic(\"blas: x index out of range\") }";
} elsif ($func =~ m/cblas_[sdz][cz]?(?:asum|nrm2)/) {
push @processed, "if incX < 0 { return 0 }";
push @processed, "if incX > 0 && (n-1)*incX >= len(x) { panic(\"blas: x index out of range\") }";
} else {
push @processed, "if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) { panic(\"blas: x index out of range\") }" if $scalarArgs{'incX'};
push @processed, "if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) { panic(\"blas: y index out of range\") }" if $scalarArgs{'incY'};
}
if ($arrayArgs{'a'} && $scalarArgs{'s'}) {
push @processed, "var k int";
push @processed, "if s == blas.Left { k = m } else { k = n }";
push @processed, "if lda*(k-1)+k > len(a) || lda < max(1, k) { panic(\"blas: index of a out of range\") }";
push @processed, "if ldb*(m-1)+n > len(b) || ldb < max(1, n) { panic(\"blas: index of b out of range\") }";
if ($arrayArgs{'c'}) {
push @processed, "if ldc*(m-1)+n > len(c) || ldc < max(1, n) { panic(\"blas: index of c out of range\") }";
}
}
if (not $func =~ m/(?:mm|r2?k)$/) {
if ($arrayArgs{'a'} && !$scalarArgs{'s'}) {
if (($scalarArgs{'kL'} && $scalarArgs{'kU'}) || $scalarArgs{'m'}) {
if ($scalarArgs{'kL'} && $scalarArgs{'kU'}) {
push @processed, "if lda*(m-1)+kL+kU+1 > len(a) || lda < kL+kU+1 { panic(\"blas: index of a out of range\") }";
} else {
push @processed, "if lda*(m-1)+n > len(a) || lda < max(1, n) { panic(\"blas: index of a out of range\") }";
}
} else {
if ($scalarArgs{'k'}) {
push @processed, "if lda*(n-1)+k+1 > len(a) || lda < k+1 { panic(\"blas: index of a out of range\") }";
} else {
push @processed, "if lda*(n-1)+n > len(a) || lda < max(1, n) { panic(\"blas: index of a out of range\") }";
}
}
}
} else {
if ($scalarArgs{'t'}) {
push @processed, "var row, col int";
push @processed, "if t == blas.NoTrans { row, col = n, k } else { row, col = k, n }";
foreach my $ref ('a', 'b') {
if ($arrayArgs{$ref}) {
push @processed, "if ld${ref}*(row-1)+col > len(${ref}) || ld${ref} < max(1, col) { panic(\"blas: index of ${ref} out of range\") }";
}
}
if ($arrayArgs{'c'}) {
push @processed, "if ldc*(n-1)+n > len(c) || ldc < max(1, n) { panic(\"blas: index of c out of range\") }";
}
}
if ($scalarArgs{'tA'} && $scalarArgs{'tB'}) {
push @processed, "var rowA, colA, rowB, colB int";
push @processed, "if tA == blas.NoTrans { rowA, colA = m, k } else { rowA, colA = k, m }";
push @processed, "if tB == blas.NoTrans { rowB, colB = k, n } else { rowB, colB = n, k }";
push @processed, "if lda*(rowA-1)+colA > len(a) || lda < max(1, colA) { panic(\"blas: index of a out of range\") }";
push @processed, "if ldb*(rowB-1)+colB > len(b) || ldb < max(1, colB) { panic(\"blas: index of b out of range\") }";
push @processed, "if ldc*(m-1)+n > len(c) || ldc < max(1, n) { panic(\"blas: index of c out of range\") }";
}
}
my $checks = join "\n", @processed;
$checks .= "\n" if scalar @processed > 0;
return $checks
}
sub processParamToC {
my $func = shift;
my $paramList = shift;
my @processed;
my @params = split ',', $paramList;
foreach my $param (@params) {
my @parts = split /[ *]/, $param;
my $var = lcfirst $parts[scalar @parts - 1];
$param =~ m/^(?:const )?int \*[a-zA-Z]/ && do {
push @processed, "(*C.int)(&".$var.")"; next;
};
$param =~ m/^(?:const )?void \*[a-zA-Z]/ && do {
my $type;
if ($var eq "alpha" || $var eq "beta") {
$type = "";
} else {
$type = "[0]";
}
push @processed, "unsafe.Pointer(&".$var.$type.")"; next;
};
$param =~ m/^(?:const )?char \*[a-zA-Z]/ && do {
push @processed, "(*C.char)(&".$var.")"; next;
};
$param =~ m/^(?:const )?float \*[a-zA-Z]/ && do {
push @processed, "(*C.float)(&".$var."[0])"; next;
};
$param =~ m/^(?:const )?double \*[a-zA-Z]/ && do {
push @processed, "(*C.double)(&".$var."[0])"; next;
};
$param =~ m/^(?:const )?int [a-zA-Z]/ && do {
push @processed, "C.int(".$var.")"; next;
};
$param =~ m/^(?:const )float [a-zA-Z]/ && do {
push @processed, "C.float(".$var.")"; next;
};
$param =~ m/^(?:const )double [a-zA-Z]/ && do {
push @processed, "C.double(".$var.")"; next;
};
$param =~ m/^const enum [a-zA-Z]/ && do {
$var eq "order" && do {
push @processed, "C.enum_$parts[scalar @parts - 2](rowMajor)"; next;
};
$var =~ /trans/ && do {
$var =~ s/trans([AB]?)/t$1/;
push @processed, "C.enum_$parts[scalar @parts - 2](".$var.")"; next;
};
$var eq "uplo" && do {
$var = "ul";
push @processed, "C.enum_$parts[scalar @parts - 2](".$var.")"; next;
};
$var eq "diag" && do {
$var = "d";
push @processed, "C.enum_$parts[scalar @parts - 2](".$var.")"; next;
};
$var eq "side" && do {
$var = "s";
push @processed, "C.enum_$parts[scalar @parts - 2](".$var.")"; next;
};
};
}
die "missed C parameters from '$func', '$paramList'" if scalar @processed != scalar @params;
return join ", ", @processed;
}

File diff suppressed because it is too large Load diff

View file

@ -1,57 +0,0 @@
package cgo
import (
"testing"
"github.com/gonum/blas/testblas"
)
var impl Implementation
func TestDasum(t *testing.T) {
testblas.DasumTest(t, impl)
}
func TestDaxpy(t *testing.T) {
testblas.DaxpyTest(t, impl)
}
func TestDdot(t *testing.T) {
testblas.DdotTest(t, impl)
}
func TestDnrm2(t *testing.T) {
testblas.Dnrm2Test(t, impl)
}
func TestIdamax(t *testing.T) {
testblas.IdamaxTest(t, impl)
}
func TestDswap(t *testing.T) {
testblas.DswapTest(t, impl)
}
func TestDcopy(t *testing.T) {
testblas.DcopyTest(t, impl)
}
func TestDrotg(t *testing.T) {
testblas.DrotgTest(t, impl)
}
func TestDrotmg(t *testing.T) {
testblas.DrotmgTest(t, impl)
}
func TestDrot(t *testing.T) {
testblas.DrotTest(t, impl)
}
func TestDrotm(t *testing.T) {
testblas.DrotmTest(t, impl)
}
func TestDscal(t *testing.T) {
testblas.DscalTest(t, impl)
}

View file

@ -1,75 +0,0 @@
package cgo
import (
"testing"
"github.com/gonum/blas/testblas"
)
func TestDgemv(t *testing.T) {
testblas.DgemvTest(t, impl)
}
func TestDger(t *testing.T) {
testblas.DgerTest(t, impl)
}
func TestDtbmv(t *testing.T) {
testblas.DtbmvTest(t, impl)
}
func TestDtxmv(t *testing.T) {
testblas.DtxmvTest(t, impl)
}
func TestDgbmv(t *testing.T) {
testblas.DgbmvTest(t, impl)
}
func TestDtbsv(t *testing.T) {
testblas.DtbsvTest(t, impl)
}
func TestDsbmv(t *testing.T) {
testblas.DsbmvTest(t, impl)
}
func TestDtrsv(t *testing.T) {
testblas.DtrsvTest(t, impl)
}
func TestDsyr(t *testing.T) {
testblas.DsyrTest(t, impl)
}
func TestDsymv(t *testing.T) {
testblas.DsymvTest(t, impl)
}
func TestDtrmv(t *testing.T) {
testblas.DtrmvTest(t, impl)
}
func TestDsyr2(t *testing.T) {
testblas.Dsyr2Test(t, impl)
}
func TestDspr2(t *testing.T) {
testblas.Dspr2Test(t, impl)
}
func TestDspr(t *testing.T) {
testblas.DsprTest(t, impl)
}
func TestDspmv(t *testing.T) {
testblas.DspmvTest(t, impl)
}
func TestDtpsv(t *testing.T) {
testblas.DtpsvTest(t, impl)
}
func TestDtmpv(t *testing.T) {
testblas.DtpmvTest(t, impl)
}

View file

@ -1,31 +0,0 @@
package cgo
import (
"testing"
"github.com/gonum/blas/testblas"
)
func TestDgemm(t *testing.T) {
testblas.TestDgemm(t, impl)
}
func TestDsymm(t *testing.T) {
testblas.DsymmTest(t, impl)
}
func TestDtrsm(t *testing.T) {
testblas.DtrsmTest(t, impl)
}
func TestDsyrk(t *testing.T) {
testblas.DsyrkTest(t, impl)
}
func TestDsyr2k(t *testing.T) {
testblas.Dsyr2kTest(t, impl)
}
func TestDtrmm(t *testing.T) {
testblas.DtrmmTest(t, impl)
}

View file

@ -1,18 +0,0 @@
package native
import (
"github.com/gonum/blas"
"github.com/gonum/blas/testblas"
)
const (
Sm = testblas.SmallMat
Med = testblas.MediumMat
Lg = testblas.LargeMat
Hg = testblas.HugeMat
)
const (
T = blas.Trans
NT = blas.NoTrans
)

View file

@ -1,400 +0,0 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"fmt"
"runtime"
"sync"
"github.com/gonum/blas"
"github.com/gonum/internal/asm"
)
const (
blockSize = 64 // b x b matrix
minParBlock = 4 // minimum number of blocks needed to go parallel
buffMul = 4 // how big is the buffer relative to the number of workers
)
// Dgemm computes
// C = beta * C + alpha * A * B.
// tA and tB specify whether A or B are transposed. A, B, and C are n×n dense
// matrices.
func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
var amat, bmat, cmat general
if tA == blas.Trans {
amat = general{
data: a,
rows: k,
cols: m,
stride: lda,
}
} else {
amat = general{
data: a,
rows: m,
cols: k,
stride: lda,
}
}
err := amat.check('a')
if err != nil {
panic(err.Error())
}
if tB == blas.Trans {
bmat = general{
data: b,
rows: n,
cols: k,
stride: ldb,
}
} else {
bmat = general{
data: b,
rows: k,
cols: n,
stride: ldb,
}
}
err = bmat.check('b')
if err != nil {
panic(err.Error())
}
cmat = general{
data: c,
rows: m,
cols: n,
stride: ldc,
}
err = cmat.check('c')
if err != nil {
panic(err.Error())
}
if tA != blas.Trans && tA != blas.NoTrans {
panic(badTranspose)
}
if tB != blas.Trans && tB != blas.NoTrans {
panic(badTranspose)
}
// scale c
if beta != 1 {
if beta == 0 {
for i := 0; i < m; i++ {
ctmp := cmat.data[i*cmat.stride : i*cmat.stride+cmat.cols]
for j := range ctmp {
ctmp[j] = 0
}
}
} else {
for i := 0; i < m; i++ {
ctmp := cmat.data[i*cmat.stride : i*cmat.stride+cmat.cols]
for j := range ctmp {
ctmp[j] *= beta
}
}
}
}
dgemmParallel(tA, tB, amat, bmat, cmat, alpha)
}
func dgemmParallel(tA, tB blas.Transpose, a, b, c general, alpha float64) {
// dgemmParallel computes a parallel matrix multiplication by partitioning
// a and b into sub-blocks, and updating c with the multiplication of the sub-block
// In all cases,
// A = [ A_11 A_12 ... A_1j
// A_21 A_22 ... A_2j
// ...
// A_i1 A_i2 ... A_ij]
//
// and same for B. All of the submatrix sizes are blockSize*blockSize except
// at the edges.
// In all cases, there is one dimension for each matrix along which
// C must be updated sequentially.
// Cij = \sum_k Aik Bki, (A * B)
// Cij = \sum_k Aki Bkj, (A^T * B)
// Cij = \sum_k Aik Bjk, (A * B^T)
// Cij = \sum_k Aki Bjk, (A^T * B^T)
//
// This code computes one {i, j} block sequentially along the k dimension,
// and computes all of the {i, j} blocks concurrently. This
// partitioning allows Cij to be updated in-place without race-conditions.
// Instead of launching a goroutine for each possible concurrent computation,
// a number of worker goroutines are created and channels are used to pass
// available and completed cases.
//
// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
// multiplies, though this code does not copy matrices to attempt to eliminate
// cache misses.
aTrans := tA == blas.Trans
bTrans := tB == blas.Trans
maxKLen, parBlocks := computeNumBlocks(a, b, aTrans, bTrans)
if parBlocks < minParBlock {
// The matrix multiplication is small in the dimensions where it can be
// computed concurrently. Just do it in serial.
dgemmSerial(tA, tB, a, b, c, alpha)
return
}
nWorkers := runtime.GOMAXPROCS(0)
if parBlocks < nWorkers {
nWorkers = parBlocks
}
// There is a tradeoff between the workers having to wait for work
// and a large buffer making operations slow.
buf := buffMul * nWorkers
if buf > parBlocks {
buf = parBlocks
}
sendChan := make(chan subMul, buf)
// Launch workers. A worker receives an {i, j} submatrix of c, and computes
// A_ik B_ki (or the transposed version) storing the result in c_ij. When the
// channel is finally closed, it signals to the waitgroup that it has finished
// computing.
var wg sync.WaitGroup
for i := 0; i < nWorkers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
// Make local copies of otherwise global variables to reduce shared memory.
// This has a noticable effect on benchmarks in some cases.
alpha := alpha
aTrans := aTrans
bTrans := bTrans
crows := c.rows
ccols := c.cols
for sub := range sendChan {
i := sub.i
j := sub.j
leni := blockSize
if i+leni > crows {
leni = crows - i
}
lenj := blockSize
if j+lenj > ccols {
lenj = ccols - j
}
cSub := c.view(i, j, leni, lenj)
// Compute A_ik B_kj for all k
for k := 0; k < maxKLen; k += blockSize {
lenk := blockSize
if k+lenk > maxKLen {
lenk = maxKLen - k
}
var aSub, bSub general
if aTrans {
aSub = a.view(k, i, lenk, leni)
} else {
aSub = a.view(i, k, leni, lenk)
}
if bTrans {
bSub = b.view(j, k, lenj, lenk)
} else {
bSub = b.view(k, j, lenk, lenj)
}
dgemmSerial(tA, tB, aSub, bSub, cSub, alpha)
}
}
}()
}
// Send out all of the {i, j} subblocks for computation.
for i := 0; i < c.rows; i += blockSize {
for j := 0; j < c.cols; j += blockSize {
sendChan <- subMul{
i: i,
j: j,
}
}
}
close(sendChan)
wg.Wait()
}
type subMul struct {
i, j int // index of block
}
// computeNumBlocks says how many blocks there are to compute. maxKLen says the length of the
// k dimension, parBlocks is the number of blocks that could be computed in parallel
// (the submatrices in i and j). expect is the full number of blocks that will be computed.
func computeNumBlocks(a, b general, aTrans, bTrans bool) (maxKLen, parBlocks int) {
aRowBlocks := a.rows / blockSize
if a.rows%blockSize != 0 {
aRowBlocks++
}
aColBlocks := a.cols / blockSize
if a.cols%blockSize != 0 {
aColBlocks++
}
bRowBlocks := b.rows / blockSize
if b.rows%blockSize != 0 {
bRowBlocks++
}
bColBlocks := b.cols / blockSize
if b.cols%blockSize != 0 {
bColBlocks++
}
switch {
case !aTrans && !bTrans:
// Cij = \sum_k Aik Bki
maxKLen = a.cols
parBlocks = aRowBlocks * bColBlocks
case aTrans && !bTrans:
// Cij = \sum_k Aki Bkj
maxKLen = a.rows
parBlocks = aColBlocks * bColBlocks
case !aTrans && bTrans:
// Cij = \sum_k Aik Bjk
maxKLen = a.cols
parBlocks = aRowBlocks * bRowBlocks
case aTrans && bTrans:
// Cij = \sum_k Aki Bjk
maxKLen = a.rows
parBlocks = aColBlocks * bRowBlocks
}
return
}
// dgemmSerial is serial matrix multiply
func dgemmSerial(tA, tB blas.Transpose, a, b, c general, alpha float64) {
switch {
case tA == blas.NoTrans && tB == blas.NoTrans:
dgemmSerialNotNot(a, b, c, alpha)
return
case tA == blas.Trans && tB == blas.NoTrans:
dgemmSerialTransNot(a, b, c, alpha)
return
case tA == blas.NoTrans && tB == blas.Trans:
dgemmSerialNotTrans(a, b, c, alpha)
return
case tA == blas.Trans && tB == blas.Trans:
dgemmSerialTransTrans(a, b, c, alpha)
return
default:
panic("unreachable")
}
}
// dgemmSerial where neither a nor b are transposed
func dgemmSerialNotNot(a, b, c general, alpha float64) {
if debug {
if a.cols != b.rows {
panic("inner dimension mismatch")
}
if a.rows != c.rows {
panic("outer dimension mismatch")
}
if b.cols != c.cols {
panic("outer dimension mismatch")
}
}
// This style is used instead of the literal [i*stride +j]) is used because
// approximately 5 times faster as of go 1.3.
for i := 0; i < a.rows; i++ {
ctmp := c.data[i*c.stride : i*c.stride+c.cols]
for l, v := range a.data[i*a.stride : i*a.stride+a.cols] {
tmp := alpha * v
if tmp != 0 {
asm.DaxpyUnitary(tmp, b.data[l*b.stride:l*b.stride+b.cols], ctmp, ctmp)
}
}
}
}
// dgemmSerial where neither a is transposed and b is not
func dgemmSerialTransNot(a, b, c general, alpha float64) {
if debug {
if a.rows != b.rows {
fmt.Println(a.rows, b.rows)
panic("inner dimension mismatch")
}
if a.cols != c.rows {
panic("outer dimension mismatch")
}
if b.cols != c.cols {
panic("outer dimension mismatch")
}
}
// This style is used instead of the literal [i*stride +j]) is used because
// approximately 5 times faster as of go 1.3.
for l := 0; l < a.rows; l++ {
btmp := b.data[l*b.stride : l*b.stride+b.cols]
for i, v := range a.data[l*a.stride : l*a.stride+a.cols] {
tmp := alpha * v
ctmp := c.data[i*c.stride : i*c.stride+c.cols]
if tmp != 0 {
asm.DaxpyUnitary(tmp, btmp, ctmp, ctmp)
}
}
}
}
// dgemmSerial where neither a is not transposed and b is
func dgemmSerialNotTrans(a, b, c general, alpha float64) {
if debug {
if a.cols != b.cols {
panic("inner dimension mismatch")
}
if a.rows != c.rows {
panic("outer dimension mismatch")
}
if b.rows != c.cols {
panic("outer dimension mismatch")
}
}
// This style is used instead of the literal [i*stride +j]) is used because
// approximately 5 times faster as of go 1.3.
for i := 0; i < a.rows; i++ {
atmp := a.data[i*a.stride : i*a.stride+a.cols]
ctmp := c.data[i*c.stride : i*c.stride+c.cols]
for j := 0; j < b.rows; j++ {
ctmp[j] += alpha * asm.DdotUnitary(atmp, b.data[j*b.stride:j*b.stride+b.cols])
}
}
}
// dgemmSerial where both are transposed
func dgemmSerialTransTrans(a, b, c general, alpha float64) {
if debug {
if a.rows != b.cols {
panic("inner dimension mismatch")
}
if a.cols != c.rows {
panic("outer dimension mismatch")
}
if b.rows != c.cols {
panic("outer dimension mismatch")
}
}
// This style is used instead of the literal [i*stride +j]) is used because
// approximately 5 times faster as of go 1.3.
for l := 0; l < a.rows; l++ {
for i, v := range a.data[l*a.stride : l*a.stride+a.cols] {
ctmp := c.data[i*c.stride : i*c.stride+c.cols]
if v != 0 {
tmp := alpha * v
if tmp != 0 {
asm.DaxpyInc(tmp, b.data[l:], ctmp, uintptr(b.rows), uintptr(b.stride), 1, 0, 0)
}
}
}
}
}

View file

@ -1,47 +0,0 @@
package native
import (
"testing"
"github.com/gonum/blas/testblas"
)
func BenchmarkDgemmSmSmSm(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Sm, Sm, Sm, NT, NT)
}
func BenchmarkDgemmMedMedMed(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, NT, NT)
}
func BenchmarkDgemmMedLgMed(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Lg, Med, NT, NT)
}
func BenchmarkDgemmLgLgLg(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Lg, Lg, Lg, NT, NT)
}
func BenchmarkDgemmLgSmLg(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Lg, Sm, Lg, NT, NT)
}
func BenchmarkDgemmLgLgSm(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Lg, Lg, Sm, NT, NT)
}
func BenchmarkDgemmHgHgSm(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Hg, Hg, Sm, NT, NT)
}
func BenchmarkDgemmMedMedMedTNT(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, T, NT)
}
func BenchmarkDgemmMedMedMedNTT(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, NT, T)
}
func BenchmarkDgemmMedMedMedTT(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, T, T)
}

View file

@ -1,87 +0,0 @@
package native
import (
"testing"
"github.com/gonum/blas/testblas"
)
func BenchmarkDgemvSmSmNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Sm, 1, 1)
}
func BenchmarkDgemvSmSmNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Sm, 2, 3)
}
func BenchmarkDgemvSmSmTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Sm, 1, 1)
}
func BenchmarkDgemvSmSmTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Sm, 2, 3)
}
func BenchmarkDgemvMedMedNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Med, Med, 1, 1)
}
func BenchmarkDgemvMedMedNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Med, Med, 2, 3)
}
func BenchmarkDgemvMedMedTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Med, Med, 1, 1)
}
func BenchmarkDgemvMedMedTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Med, Med, 2, 3)
}
func BenchmarkDgemvLgLgNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Lg, 1, 1)
}
func BenchmarkDgemvLgLgNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Lg, 2, 3)
}
func BenchmarkDgemvLgLgTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Lg, 1, 1)
}
func BenchmarkDgemvLgLgTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Lg, 2, 3)
}
func BenchmarkDgemvLgSmNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Sm, 1, 1)
}
func BenchmarkDgemvLgSmNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Sm, 2, 3)
}
func BenchmarkDgemvLgSmTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Sm, 1, 1)
}
func BenchmarkDgemvLgSmTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Sm, 2, 3)
}
func BenchmarkDgemvSmLgNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Lg, 1, 1)
}
func BenchmarkDgemvSmLgNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Lg, 2, 3)
}
func BenchmarkDgemvSmLgTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Lg, 1, 1)
}
func BenchmarkDgemvSmLgTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Lg, 2, 3)
}

View file

@ -1,47 +0,0 @@
package native
import (
"testing"
"github.com/gonum/blas/testblas"
)
func BenchmarkDgerSmSmInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Sm, 1, 1)
}
func BenchmarkDgerSmSmIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Sm, 2, 3)
}
func BenchmarkDgerMedMedInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Med, Med, 1, 1)
}
func BenchmarkDgerMedMedIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Med, Med, 2, 3)
}
func BenchmarkDgerLgLgInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Lg, 1, 1)
}
func BenchmarkDgerLgLgIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Lg, 2, 3)
}
func BenchmarkDgerLgSmInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Sm, 1, 1)
}
func BenchmarkDgerLgSmIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Sm, 2, 3)
}
func BenchmarkDgerSmLgInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Lg, 1, 1)
}
func BenchmarkDgerSmLgIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Lg, 2, 3)
}

View file

@ -1,82 +0,0 @@
/*
Package native is a Go implementation of the BLAS API. This implementation
panics when the input arguments are invalid as per the standard, for example
if a vector increment is zero. Please note that the treatment of NaN values
is not specified, and differs among the BLAS implementations.
github.com/gonum/blas/blas64 provides helpful wrapper functions to the BLAS
interface. The rest of this text describes the layout of the data for the input types.
Please note that in the function documentation, x[i] refers to the i^th element
of the vector, which will be different from the i^th element of the slice if
incX != 1.
See http://www.netlib.org/lapack/explore-html/d4/de1/_l_i_c_e_n_s_e_source.html
for more license information.
Vector arguments are effectively strided slices. They have two input arguments,
a number of elements, n, and an increment, incX. The increment specifies the
distance between elements of the vector. The actual Go slice may be longer
than necessary.
The increment may be positive or negative, except in functions with only
a single vector argument where the increment may only be positive. If the increment
is negative, s[0] is the last element in the slice. Note that this is not the same
as counting backward from the end of the slice, as len(s) may be longer than
necessary. So, for example, if n = 5 and incX = 3, the elements of s are
[0 * * 1 * * 2 * * 3 * * 4 * * * ...]
where elements are never accessed. If incX = -3, the same elements are
accessed, just in reverse order (4, 3, 2, 1, 0).
Dense matrices are specified by a number of rows, a number of columns, and a stride.
The stride specifies the number of entries in the slice between the first element
of successive rows. The stride must be at least as large as the number of columns
but may be longer.
[a00 ... a0n a0* ... a1stride-1 a21 ... amn am* ... amstride-1]
Thus, dense[i*ld + j] refers to the {i, j}th element of the matrix.
Symmetric and triangular matrices (non-packed) are stored identically to Dense,
except that only elements in one triangle of the matrix are accessed.
Packed symmetric and packed triangular matrices are laid out with the entries
condensed such that all of the unreferenced elements are removed. So, the upper triangular
matrix
[
1 2 3
0 4 5
0 0 6
]
and the lower-triangular matrix
[
1 0 0
2 3 0
4 5 6
]
will both be compacted as [1 2 3 4 5 6]. The (i, j) element of the original
dense matrix can be found at element i*n - (i-1)*i/2 + j for upper triangular,
and at element i * (i+1) /2 + j for lower triangular.
Banded matrices are laid out in a compact format, constructed by removing the
zeros in the rows and aligning the diagonals. For example, the matrix
[
1 2 3 0 0 0
4 5 6 7 0 0
0 8 9 10 11 0
0 0 12 13 14 15
0 0 0 16 17 18
0 0 0 0 19 20
]
implicitly becomes ( entries are never accessed)
[
* 1 2 3
4 5 6 7
8 9 10 11
12 13 14 15
16 17 18 *
19 20 * *
]
which is given to the BLAS routine is [ 1 2 3 4 ...].
See http://www.crest.iu.edu/research/mtl/reference/html/banded.html
for more information
*/
package native

View file

@ -1,159 +0,0 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"errors"
"fmt"
"math"
)
const (
debug = false
)
func newGeneral(r, c int) general {
return general{
data: make([]float64, r*c),
rows: r,
cols: c,
stride: c,
}
}
type general struct {
data []float64
rows, cols int
stride int
}
// adds element-wise into receiver. rows and columns must match
func (g general) add(h general) {
if debug {
if g.rows != h.rows {
panic("blas: row size mismatch")
}
if g.cols != h.cols {
panic("blas: col size mismatch")
}
}
for i := 0; i < g.rows; i++ {
gtmp := g.data[i*g.stride : i*g.stride+g.cols]
for j, v := range h.data[i*h.stride : i*h.stride+h.cols] {
gtmp[j] += v
}
}
}
// at returns the value at the ith row and jth column. For speed reasons, the
// rows and columns are not bounds checked.
func (g general) at(i, j int) float64 {
if debug {
if i < 0 || i >= g.rows {
panic("blas: row out of bounds")
}
if j < 0 || j >= g.cols {
panic("blas: col out of bounds")
}
}
return g.data[i*g.stride+j]
}
func (g general) check(c byte) error {
if g.rows < 0 {
return errors.New("blas: rows < 0")
}
if g.cols < 0 {
return errors.New("blas: cols < 0")
}
if g.stride < 1 {
return errors.New("blas: stride < 1")
}
if g.stride < g.cols {
return errors.New("blas: illegal stride")
}
if (g.rows-1)*g.stride+g.cols > len(g.data) {
return fmt.Errorf("blas: index of %c out of range", c)
}
return nil
}
func (g general) clone() general {
data := make([]float64, len(g.data))
copy(data, g.data)
return general{
data: data,
rows: g.rows,
cols: g.cols,
stride: g.stride,
}
}
// assumes they are the same size
func (g general) copy(h general) {
if debug {
if g.rows != h.rows {
panic("blas: row mismatch")
}
if g.cols != h.cols {
panic("blas: col mismatch")
}
}
for k := 0; k < g.rows; k++ {
copy(g.data[k*g.stride:(k+1)*g.stride], h.data[k*h.stride:(k+1)*h.stride])
}
}
func (g general) equal(a general) bool {
if g.rows != a.rows || g.cols != a.cols || g.stride != a.stride {
return false
}
for i, v := range g.data {
if a.data[i] != v {
return false
}
}
return true
}
/*
// print is to aid debugging. Commented out to avoid fmt import
func (g general) print() {
fmt.Println("r = ", g.rows, "c = ", g.cols, "stride: ", g.stride)
for i := 0; i < g.rows; i++ {
fmt.Println(g.data[i*g.stride : (i+1)*g.stride])
}
}
*/
func (g general) view(i, j, r, c int) general {
if debug {
if i < 0 || i+r > g.rows {
panic("blas: row out of bounds")
}
if j < 0 || j+c > g.cols {
panic("blas: col out of bounds")
}
}
return general{
data: g.data[i*g.stride+j : (i+r-1)*g.stride+j+c],
rows: r,
cols: c,
stride: g.stride,
}
}
func (g general) equalWithinAbs(a general, tol float64) bool {
if g.rows != a.rows || g.cols != a.cols || g.stride != a.stride {
return false
}
for i, v := range g.data {
if math.Abs(a.data[i]-v) > tol {
return false
}
}
return true
}

View file

@ -1,600 +0,0 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"math"
"github.com/gonum/blas"
"github.com/gonum/internal/asm"
)
type Implementation struct{}
var _ blas.Float64Level1 = Implementation{}
const (
negativeN = "blas: n < 0"
zeroIncX = "blas: zero x index increment"
zeroIncY = "blas: zero y index increment"
badLenX = "blas: x index out of range"
badLenY = "blas: y index out of range"
)
// Ddot computes the dot product of the two vectors
// \sum_i x[i]*y[i]
func (Implementation) Ddot(n int, x []float64, incX int, y []float64, incY int) float64 {
if n < 0 {
panic(negativeN)
}
if incX == 0 {
panic(zeroIncX)
}
if incY == 0 {
panic(zeroIncY)
}
if incX == 1 && incY == 1 {
if len(x) < n {
panic(badLenX)
}
if len(y) < n {
panic(badLenY)
}
return asm.DdotUnitary(x[:n], y)
}
var ix, iy int
if incX < 0 {
ix = (-n + 1) * incX
}
if incY < 0 {
iy = (-n + 1) * incY
}
if ix >= len(x) || ix+(n-1)*incX >= len(x) {
panic(badLenX)
}
if iy >= len(y) || iy+(n-1)*incY >= len(y) {
panic(badLenY)
}
return asm.DdotInc(x, y, uintptr(n), uintptr(incX), uintptr(incY), uintptr(ix), uintptr(iy))
}
// Dnrm2 computes the Euclidean norm of a vector,
// sqrt(\sum_i x[i] * x[i]).
// This function returns 0 if incX is negative.
func (Implementation) Dnrm2(n int, x []float64, incX int) float64 {
if incX < 1 {
if incX == 0 {
panic(zeroIncX)
}
return 0
}
if n < 2 {
if n == 1 {
return math.Abs(x[0])
}
if n == 0 {
return 0
}
if n < 1 {
panic(negativeN)
}
}
scale := 0.0
sumSquares := 1.0
if incX == 1 {
x = x[:n]
for _, v := range x {
absxi := math.Abs(v)
if scale < absxi {
sumSquares = 1 + sumSquares*(scale/absxi)*(scale/absxi)
scale = absxi
} else {
sumSquares = sumSquares + (absxi/scale)*(absxi/scale)
}
}
return scale * math.Sqrt(sumSquares)
}
for ix := 0; ix < n*incX; ix += incX {
val := x[ix]
if val == 0 {
continue
}
absxi := math.Abs(val)
if scale < absxi {
sumSquares = 1 + sumSquares*(scale/absxi)*(scale/absxi)
scale = absxi
} else {
sumSquares = sumSquares + (absxi/scale)*(absxi/scale)
}
}
return scale * math.Sqrt(sumSquares)
}
// Dasum computes the sum of the absolute values of the elements of x.
// \sum_i |x[i]|
// Dasum returns 0 if incX is negative.
func (Implementation) Dasum(n int, x []float64, incX int) float64 {
var sum float64
if n < 0 {
panic(negativeN)
}
if incX < 1 {
if incX == 0 {
panic(zeroIncX)
}
return 0
}
if incX == 1 {
x = x[:n]
for _, v := range x {
sum += math.Abs(v)
}
return sum
}
for i := 0; i < n; i++ {
sum += math.Abs(x[i*incX])
}
return sum
}
// Idamax returns the index of the largest element of x. If there are multiple
// such indices the earliest is returned. Idamax returns -1 if incX is negative or if
// n == 0.
func (Implementation) Idamax(n int, x []float64, incX int) int {
if incX < 1 {
if incX == 0 {
panic(zeroIncX)
}
return -1
}
if n < 2 {
if n == 1 {
return 0
}
if n == 0 {
return -1 // Netlib returns invalid index when n == 0
}
if n < 1 {
panic(negativeN)
}
}
idx := 0
max := math.Abs(x[0])
if incX == 1 {
for i, v := range x {
absV := math.Abs(v)
if absV > max {
max = absV
idx = i
}
}
}
ix := incX
for i := 1; i < n; i++ {
v := x[ix]
absV := math.Abs(v)
if absV > max {
max = absV
idx = i
}
ix += incX
}
return idx
}
// Dswap exchanges the elements of two vectors.
// x[i], y[i] = y[i], x[i] for all i
func (Implementation) Dswap(n int, x []float64, incX int, y []float64, incY int) {
if n < 1 {
if n == 0 {
return
}
panic(negativeN)
}
if incX == 0 {
panic(zeroIncX)
}
if incY == 0 {
panic(zeroIncY)
}
if incX == 1 && incY == 1 {
x = x[:n]
for i, v := range x {
x[i], y[i] = y[i], v
}
return
}
var ix, iy int
if incX < 0 {
ix = (-n + 1) * incX
}
if incY < 0 {
iy = (-n + 1) * incY
}
for i := 0; i < n; i++ {
x[ix], y[iy] = y[iy], x[ix]
ix += incX
iy += incY
}
}
// Dcopy copies the elements of x into the elements of y.
// y[i] = x[i] for all i
func (Implementation) Dcopy(n int, x []float64, incX int, y []float64, incY int) {
if n < 1 {
if n == 0 {
return
}
panic(negativeN)
}
if incX == 0 {
panic(zeroIncX)
}
if incY == 0 {
panic(zeroIncY)
}
if incX == 1 && incY == 1 {
copy(y[:n], x[:n])
return
}
var ix, iy int
if incX < 0 {
ix = (-n + 1) * incX
}
if incY < 0 {
iy = (-n + 1) * incY
}
for i := 0; i < n; i++ {
y[iy] = x[ix]
ix += incX
iy += incY
}
}
// Daxpy adds alpha times x to y
// y[i] += alpha * x[i] for all i
func (Implementation) Daxpy(n int, alpha float64, x []float64, incX int, y []float64, incY int) {
if n < 1 {
if n == 0 {
return
}
panic(negativeN)
}
if incX == 0 {
panic(zeroIncX)
}
if incY == 0 {
panic(zeroIncY)
}
if alpha == 0 {
return
}
if incX == 1 && incY == 1 {
if len(x) < n {
panic(badLenX)
}
if len(y) < n {
panic(badLenY)
}
asm.DaxpyUnitary(alpha, x[:n], y, y)
return
}
var ix, iy int
if incX < 0 {
ix = (-n + 1) * incX
}
if incY < 0 {
iy = (-n + 1) * incY
}
if ix >= len(x) || ix+(n-1)*incX >= len(x) {
panic(badLenX)
}
if iy >= len(y) || iy+(n-1)*incY >= len(y) {
panic(badLenY)
}
asm.DaxpyInc(alpha, x, y, uintptr(n), uintptr(incX), uintptr(incY), uintptr(ix), uintptr(iy))
}
// Drotg computes the plane rotation
// _ _ _ _ _ _
// | c s | | a | | r |
// | -s c | * | b | = | 0 |
// ‾ ‾ ‾ ‾ ‾ ‾
// where
// r = ±(a^2 + b^2)
// c = a/r, the cosine of the plane rotation
// s = b/r, the sine of the plane rotation
//
// NOTE: There is a discrepancy between the refence implementation and the BLAS
// technical manual regarding the sign for r when a or b are zero.
// Drotg agrees with the definition in the manual and other
// common BLAS implementations.
func (Implementation) Drotg(a, b float64) (c, s, r, z float64) {
if b == 0 && a == 0 {
return 1, 0, a, 0
}
absA := math.Abs(a)
absB := math.Abs(b)
aGTb := absA > absB
r = math.Hypot(a, b)
if aGTb {
r = math.Copysign(r, a)
} else {
r = math.Copysign(r, b)
}
c = a / r
s = b / r
if aGTb {
z = s
} else if c != 0 { // r == 0 case handled above
z = 1 / c
} else {
z = 1
}
return
}
// Drotmg computes the modified Givens rotation. See
// http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html
// for more details.
func (Implementation) Drotmg(d1, d2, x1, y1 float64) (p blas.DrotmParams, rd1, rd2, rx1 float64) {
var p1, p2, q1, q2, u float64
gam := 4096.0
gamsq := 16777216.0
rgamsq := 5.9604645e-8
if d1 < 0 {
p.Flag = blas.Rescaling
return
}
p2 = d2 * y1
if p2 == 0 {
p.Flag = blas.Identity
rd1 = d1
rd2 = d2
rx1 = x1
return
}
p1 = d1 * x1
q2 = p2 * y1
q1 = p1 * x1
absQ1 := math.Abs(q1)
absQ2 := math.Abs(q2)
if absQ1 < absQ2 && q2 < 0 {
p.Flag = blas.Rescaling
return
}
if d1 == 0 {
p.Flag = blas.Diagonal
p.H[0] = p1 / p2
p.H[3] = x1 / y1
u = 1 + p.H[0]*p.H[3]
rd1, rd2 = d2/u, d1/u
rx1 = y1 / u
return
}
// Now we know that d1 != 0, and d2 != 0. If d2 == 0, it would be caught
// when p2 == 0, and if d1 == 0, then it is caught above
if absQ1 > absQ2 {
p.H[1] = -y1 / x1
p.H[2] = p2 / p1
u = 1 - p.H[2]*p.H[1]
rd1 = d1
rd2 = d2
rx1 = x1
p.Flag = blas.OffDiagonal
// u must be greater than zero because |q1| > |q2|, so check from netlib
// is unnecessary
// This is left in for ease of comparison with complex routines
//if u > 0 {
rd1 /= u
rd2 /= u
rx1 *= u
//}
} else {
p.Flag = blas.Diagonal
p.H[0] = p1 / p2
p.H[3] = x1 / y1
u = 1 + p.H[0]*p.H[3]
rd1 = d2 / u
rd2 = d1 / u
rx1 = y1 * u
}
for rd1 <= rgamsq || rd1 >= gamsq {
if p.Flag == blas.OffDiagonal {
p.H[0] = 1
p.H[3] = 1
p.Flag = blas.Rescaling
} else if p.Flag == blas.Diagonal {
p.H[1] = -1
p.H[2] = 1
p.Flag = blas.Rescaling
}
if rd1 <= rgamsq {
rd1 *= gam * gam
rx1 /= gam
p.H[0] /= gam
p.H[2] /= gam
} else {
rd1 /= gam * gam
rx1 *= gam
p.H[0] *= gam
p.H[2] *= gam
}
}
for math.Abs(rd2) <= rgamsq || math.Abs(rd2) >= gamsq {
if p.Flag == blas.OffDiagonal {
p.H[0] = 1
p.H[3] = 1
p.Flag = blas.Rescaling
} else if p.Flag == blas.Diagonal {
p.H[1] = -1
p.H[2] = 1
p.Flag = blas.Rescaling
}
if math.Abs(rd2) <= rgamsq {
rd2 *= gam * gam
p.H[1] /= gam
p.H[3] /= gam
} else {
rd2 /= gam * gam
p.H[1] *= gam
p.H[3] *= gam
}
}
return
}
// Drot applies a plane transformation.
// x[i] = c * x[i] + s * y[i]
// y[i] = c * y[i] - s * x[i]
func (Implementation) Drot(n int, x []float64, incX int, y []float64, incY int, c float64, s float64) {
if n < 1 {
if n == 0 {
return
}
panic(negativeN)
}
if incX == 0 {
panic(zeroIncX)
}
if incY == 0 {
panic(zeroIncY)
}
if incX == 1 && incY == 1 {
x = x[:n]
for i, vx := range x {
vy := y[i]
x[i], y[i] = c*vx+s*vy, c*vy-s*vx
}
return
}
var ix, iy int
if incX < 0 {
ix = (-n + 1) * incX
}
if incY < 0 {
iy = (-n + 1) * incY
}
for i := 0; i < n; i++ {
vx := x[ix]
vy := y[iy]
x[ix], y[iy] = c*vx+s*vy, c*vy-s*vx
ix += incX
iy += incY
}
}
// Drotm applies the modified Givens rotation to the 2⨉n matrix.
func (Implementation) Drotm(n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams) {
if n <= 0 {
if n == 0 {
return
}
panic(negativeN)
}
if incX == 0 {
panic(zeroIncX)
}
if incY == 0 {
panic(zeroIncY)
}
var h11, h12, h21, h22 float64
var ix, iy int
switch p.Flag {
case blas.Identity:
return
case blas.Rescaling:
h11 = p.H[0]
h12 = p.H[2]
h21 = p.H[1]
h22 = p.H[3]
case blas.OffDiagonal:
h11 = 1
h12 = p.H[2]
h21 = p.H[1]
h22 = 1
case blas.Diagonal:
h11 = p.H[0]
h12 = 1
h21 = -1
h22 = p.H[3]
}
if incX < 0 {
ix = (-n + 1) * incX
}
if incY < 0 {
iy = (-n + 1) * incY
}
if incX == 1 && incY == 1 {
x = x[:n]
for i, vx := range x {
vy := y[i]
x[i], y[i] = vx*h11+vy*h12, vx*h21+vy*h22
}
return
}
for i := 0; i < n; i++ {
vx := x[ix]
vy := y[iy]
x[ix], y[iy] = vx*h11+vy*h12, vx*h21+vy*h22
ix += incX
iy += incY
}
return
}
// Dscal scales x by alpha.
// x[i] *= alpha
// Dscal has no effect if incX < 0.
func (Implementation) Dscal(n int, alpha float64, x []float64, incX int) {
if incX < 1 {
if incX == 0 {
panic(zeroIncX)
}
return
}
if n < 1 {
if n == 0 {
return
}
if n < 1 {
panic(negativeN)
}
}
if alpha == 0 {
if incX == 1 {
x = x[:n]
for i := range x {
x[i] = 0
}
}
for ix := 0; ix < n*incX; ix += incX {
x[ix] = 0
}
}
if incX == 1 {
x = x[:n]
for i := range x {
x[i] *= alpha
}
return
}
for ix := 0; ix < n*incX; ix += incX {
x[ix] *= alpha
}
return
}

File diff suppressed because it is too large Load diff

View file

@ -1,57 +0,0 @@
package native
import (
"testing"
"github.com/gonum/blas/testblas"
)
var impl Implementation
func TestDasum(t *testing.T) {
testblas.DasumTest(t, impl)
}
func TestDaxpy(t *testing.T) {
testblas.DaxpyTest(t, impl)
}
func TestDdot(t *testing.T) {
testblas.DdotTest(t, impl)
}
func TestDnrm2(t *testing.T) {
testblas.Dnrm2Test(t, impl)
}
func TestIdamax(t *testing.T) {
testblas.IdamaxTest(t, impl)
}
func TestDswap(t *testing.T) {
testblas.DswapTest(t, impl)
}
func TestDcopy(t *testing.T) {
testblas.DcopyTest(t, impl)
}
func TestDrotg(t *testing.T) {
testblas.DrotgTest(t, impl)
}
func TestDrotmg(t *testing.T) {
testblas.DrotmgTest(t, impl)
}
func TestDrot(t *testing.T) {
testblas.DrotTest(t, impl)
}
func TestDrotm(t *testing.T) {
testblas.DrotmTest(t, impl)
}
func TestDscal(t *testing.T) {
testblas.DscalTest(t, impl)
}

File diff suppressed because it is too large Load diff

View file

@ -1,75 +0,0 @@
package native
import (
"testing"
"github.com/gonum/blas/testblas"
)
func TestDgemv(t *testing.T) {
testblas.DgemvTest(t, impl)
}
func TestDger(t *testing.T) {
testblas.DgerTest(t, impl)
}
func TestDtxmv(t *testing.T) {
testblas.DtxmvTest(t, impl)
}
func TestDgbmv(t *testing.T) {
testblas.DgbmvTest(t, impl)
}
func TestDtbsv(t *testing.T) {
testblas.DtbsvTest(t, impl)
}
func TestDsbmv(t *testing.T) {
testblas.DsbmvTest(t, impl)
}
func TestDtbmv(t *testing.T) {
testblas.DtbmvTest(t, impl)
}
func TestDtrsv(t *testing.T) {
testblas.DtrsvTest(t, impl)
}
func TestDtrmv(t *testing.T) {
testblas.DtrmvTest(t, impl)
}
func TestDsymv(t *testing.T) {
testblas.DsymvTest(t, impl)
}
func TestDsyr(t *testing.T) {
testblas.DsyrTest(t, impl)
}
func TestDsyr2(t *testing.T) {
testblas.Dsyr2Test(t, impl)
}
func TestDspr2(t *testing.T) {
testblas.Dspr2Test(t, impl)
}
func TestDspr(t *testing.T) {
testblas.DsprTest(t, impl)
}
func TestDspmv(t *testing.T) {
testblas.DspmvTest(t, impl)
}
func TestDtpsv(t *testing.T) {
testblas.DtpsvTest(t, impl)
}
func TestDtpmv(t *testing.T) {
testblas.DtpmvTest(t, impl)
}

View file

@ -1,815 +0,0 @@
package native
import (
"github.com/gonum/blas"
"github.com/gonum/internal/asm"
)
var _ blas.Float64Level3 = Implementation{}
// Dtrsm solves
// A * X = alpha * B if tA == blas.NoTrans, side == blas.Left
// A^T * X = alpha * B if tA == blas.Trans, side == blas.Left
// X * A = alpha * B if tA == blas.NoTrans, side == blas.Right
// X * A^T = alpha * B if tA == blas.Trans, side == blas.Right
// where A is an n×n triangular matrix, x is an m×n matrix, and alpha is a
// scalar.
//
// At entry to the function, X contains the values of B, and the result is
// stored in place into X.
//
// No check is made that A is invertible.
func (Implementation) Dtrsm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int) {
if s != blas.Left && s != blas.Right {
panic(badSide)
}
if ul != blas.Lower && ul != blas.Upper {
panic(badUplo)
}
if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
panic(badTranspose)
}
if d != blas.NonUnit && d != blas.Unit {
panic(badDiag)
}
if m < 0 {
panic(mLT0)
}
if n < 0 {
panic(nLT0)
}
if ldb < n {
panic(badLdB)
}
if s == blas.Left {
if lda < m {
panic(badLdA)
}
} else {
if lda < n {
panic(badLdA)
}
}
if m == 0 || n == 0 {
return
}
if alpha == 0 {
for i := 0; i < m; i++ {
btmp := b[i*ldb : i*ldb+n]
for j := range btmp {
btmp[j] = 0
}
}
return
}
nonUnit := d == blas.NonUnit
if s == blas.Left {
if tA == blas.NoTrans {
if ul == blas.Upper {
for i := m - 1; i >= 0; i-- {
btmp := b[i*ldb : i*ldb+n]
if alpha != 1 {
for j := range btmp {
btmp[j] *= alpha
}
}
for ka, va := range a[i*lda+i+1 : i*lda+m] {
k := ka + i + 1
if va != 0 {
asm.DaxpyUnitary(-va, b[k*ldb:k*ldb+n], btmp, btmp)
}
}
if nonUnit {
tmp := 1 / a[i*lda+i]
for j := 0; j < n; j++ {
btmp[j] *= tmp
}
}
}
return
}
for i := 0; i < m; i++ {
btmp := b[i*ldb : i*ldb+n]
if alpha != 1 {
for j := 0; j < n; j++ {
btmp[j] *= alpha
}
}
for k, va := range a[i*lda : i*lda+i] {
if va != 0 {
asm.DaxpyUnitary(-va, b[k*ldb:k*ldb+n], btmp, btmp)
}
}
if nonUnit {
tmp := 1 / a[i*lda+i]
for j := 0; j < n; j++ {
btmp[j] *= tmp
}
}
}
return
}
// Cases where a is transposed
if ul == blas.Upper {
for k := 0; k < m; k++ {
btmpk := b[k*ldb : k*ldb+n]
if nonUnit {
tmp := 1 / a[k*lda+k]
for j := 0; j < n; j++ {
btmpk[j] *= tmp
}
}
for ia, va := range a[k*lda+k+1 : k*lda+m] {
i := ia + k + 1
if va != 0 {
btmp := b[i*ldb : i*ldb+n]
asm.DaxpyUnitary(-va, btmpk, btmp, btmp)
}
}
if alpha != 1 {
for j := 0; j < n; j++ {
btmpk[j] *= alpha
}
}
}
return
}
for k := m - 1; k >= 0; k-- {
btmpk := b[k*ldb : k*ldb+n]
if nonUnit {
tmp := 1 / a[k*lda+k]
for j := 0; j < n; j++ {
btmpk[j] *= tmp
}
}
for i, va := range a[k*lda : k*lda+k] {
if va != 0 {
btmp := b[i*ldb : i*ldb+n]
asm.DaxpyUnitary(-va, btmpk, btmp, btmp)
}
}
if alpha != 1 {
for j := 0; j < n; j++ {
btmpk[j] *= alpha
}
}
}
return
}
// Cases where a is to the right of X.
if tA == blas.NoTrans {
if ul == blas.Upper {
for i := 0; i < m; i++ {
btmp := b[i*ldb : i*ldb+n]
if alpha != 1 {
for j := 0; j < n; j++ {
btmp[j] *= alpha
}
}
for k, vb := range btmp {
if vb != 0 {
if btmp[k] != 0 {
if nonUnit {
btmp[k] /= a[k*lda+k]
}
btmpk := btmp[k+1 : n]
asm.DaxpyUnitary(-btmp[k], a[k*lda+k+1:k*lda+n], btmpk, btmpk)
}
}
}
}
return
}
for i := 0; i < m; i++ {
btmp := b[i*lda : i*lda+n]
if alpha != 1 {
for j := 0; j < n; j++ {
btmp[j] *= alpha
}
}
for k := n - 1; k >= 0; k-- {
if btmp[k] != 0 {
if nonUnit {
btmp[k] /= a[k*lda+k]
}
asm.DaxpyUnitary(-btmp[k], a[k*lda:k*lda+k], btmp, btmp)
}
}
}
return
}
// Cases where a is transposed.
if ul == blas.Upper {
for i := 0; i < m; i++ {
btmp := b[i*lda : i*lda+n]
for j := n - 1; j >= 0; j-- {
tmp := alpha*btmp[j] - asm.DdotUnitary(a[j*lda+j+1:j*lda+n], btmp[j+1:])
if nonUnit {
tmp /= a[j*lda+j]
}
btmp[j] = tmp
}
}
return
}
for i := 0; i < m; i++ {
btmp := b[i*lda : i*lda+n]
for j := 0; j < n; j++ {
tmp := alpha*btmp[j] - asm.DdotUnitary(a[j*lda:j*lda+j], btmp)
if nonUnit {
tmp /= a[j*lda+j]
}
btmp[j] = tmp
}
}
}
// Dsymm performs one of
// C = alpha * A * B + beta * C if side == blas.Left
// C = alpha * B * A + beta * C if side == blas.Right
// where A is an n×n symmetric matrix, B and C are m×n matrices, and alpha
// is a scalar.
func (Implementation) Dsymm(s blas.Side, ul blas.Uplo, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
if s != blas.Right && s != blas.Left {
panic("goblas: bad side")
}
if ul != blas.Lower && ul != blas.Upper {
panic(badUplo)
}
if m < 0 {
panic(mLT0)
}
if n < 0 {
panic(nLT0)
}
if (lda < m && s == blas.Left) || (lda < n && s == blas.Right) {
panic(badLdA)
}
if ldb < n {
panic(badLdB)
}
if ldc < n {
panic(badLdC)
}
if m == 0 || n == 0 {
return
}
if alpha == 0 && beta == 1 {
return
}
if alpha == 0 {
if beta == 0 {
for i := 0; i < m; i++ {
ctmp := c[i*ldc : i*ldc+n]
for j := range ctmp {
ctmp[j] = 0
}
}
return
}
for i := 0; i < m; i++ {
ctmp := c[i*ldc : i*ldc+n]
for j := 0; j < n; j++ {
ctmp[j] *= beta
}
}
return
}
isUpper := ul == blas.Upper
if s == blas.Left {
for i := 0; i < m; i++ {
atmp := alpha * a[i*lda+i]
btmp := b[i*ldb : i*ldb+n]
ctmp := c[i*ldc : i*ldc+n]
for j, v := range btmp {
ctmp[j] *= beta
ctmp[j] += atmp * v
}
for k := 0; k < i; k++ {
var atmp float64
if isUpper {
atmp = a[k*lda+i]
} else {
atmp = a[i*lda+k]
}
atmp *= alpha
ctmp := c[i*ldc : i*ldc+n]
asm.DaxpyUnitary(atmp, b[k*ldb:k*ldb+n], ctmp, ctmp)
}
for k := i + 1; k < m; k++ {
var atmp float64
if isUpper {
atmp = a[i*lda+k]
} else {
atmp = a[k*lda+i]
}
atmp *= alpha
ctmp := c[i*ldc : i*ldc+n]
asm.DaxpyUnitary(atmp, b[k*ldb:k*ldb+n], ctmp, ctmp)
}
}
return
}
if isUpper {
for i := 0; i < m; i++ {
for j := n - 1; j >= 0; j-- {
tmp := alpha * b[i*ldb+j]
var tmp2 float64
atmp := a[j*lda+j+1 : j*lda+n]
btmp := b[i*ldb+j+1 : i*ldb+n]
ctmp := c[i*ldc+j+1 : i*ldc+n]
for k, v := range atmp {
ctmp[k] += tmp * v
tmp2 += btmp[k] * v
}
c[i*ldc+j] *= beta
c[i*ldc+j] += tmp*a[j*lda+j] + alpha*tmp2
}
}
return
}
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
tmp := alpha * b[i*ldb+j]
var tmp2 float64
atmp := a[j*lda : j*lda+j]
btmp := b[i*ldb : i*ldb+j]
ctmp := c[i*ldc : i*ldc+j]
for k, v := range atmp {
ctmp[k] += tmp * v
tmp2 += btmp[k] * v
}
c[i*ldc+j] *= beta
c[i*ldc+j] += tmp*a[j*lda+j] + alpha*tmp2
}
}
}
// Dsyrk performs the symmetric rank-k operation
// C = alpha * A * A^T + beta*C
// C is an n×n symmetric matrix. A is an n×k matrix if tA == blas.NoTrans, and
// a k×n matrix otherwise. alpha and beta are scalars.
func (Implementation) Dsyrk(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float64, a []float64, lda int, beta float64, c []float64, ldc int) {
if ul != blas.Lower && ul != blas.Upper {
panic(badUplo)
}
if tA != blas.Trans && tA != blas.NoTrans && tA != blas.ConjTrans {
panic(badTranspose)
}
if n < 0 {
panic(nLT0)
}
if k < 0 {
panic(kLT0)
}
if ldc < n {
panic(badLdC)
}
if tA == blas.Trans {
if lda < n {
panic(badLdA)
}
} else {
if lda < k {
panic(badLdA)
}
}
if alpha == 0 {
if beta == 0 {
if ul == blas.Upper {
for i := 0; i < n; i++ {
ctmp := c[i*ldc+i : i*ldc+n]
for j := range ctmp {
ctmp[j] = 0
}
}
return
}
for i := 0; i < n; i++ {
ctmp := c[i*ldc : i*ldc+i+1]
for j := range ctmp {
ctmp[j] = 0
}
}
return
}
if ul == blas.Upper {
for i := 0; i < n; i++ {
ctmp := c[i*ldc+i : i*ldc+n]
for j := range ctmp {
ctmp[j] *= beta
}
}
return
}
for i := 0; i < n; i++ {
ctmp := c[i*ldc : i*ldc+i+1]
for j := range ctmp {
ctmp[j] *= beta
}
}
return
}
if tA == blas.NoTrans {
if ul == blas.Upper {
for i := 0; i < n; i++ {
ctmp := c[i*ldc+i : i*ldc+n]
atmp := a[i*lda : i*lda+k]
for jc, vc := range ctmp {
j := jc + i
ctmp[jc] = vc*beta + alpha*asm.DdotUnitary(atmp, a[j*lda:j*lda+k])
}
}
return
}
for i := 0; i < n; i++ {
atmp := a[i*lda : i*lda+k]
for j, vc := range c[i*ldc : i*ldc+i+1] {
c[i*ldc+j] = vc*beta + alpha*asm.DdotUnitary(a[j*lda:j*lda+k], atmp)
}
}
return
}
// Cases where a is transposed.
if ul == blas.Upper {
for i := 0; i < n; i++ {
ctmp := c[i*ldc+i : i*ldc+n]
if beta != 1 {
for j := range ctmp {
ctmp[j] *= beta
}
}
for l := 0; l < k; l++ {
tmp := alpha * a[l*lda+i]
if tmp != 0 {
asm.DaxpyUnitary(tmp, a[l*lda+i:l*lda+n], ctmp, ctmp)
}
}
}
return
}
for i := 0; i < n; i++ {
ctmp := c[i*ldc : i*ldc+i+1]
if beta != 0 {
for j := range ctmp {
ctmp[j] *= beta
}
}
for l := 0; l < k; l++ {
tmp := alpha * a[l*lda+i]
if tmp != 0 {
asm.DaxpyUnitary(tmp, a[l*lda:l*lda+i+1], ctmp, ctmp)
}
}
}
}
// Dsyr2k performs the symmetric rank 2k operation
// C = alpha * A * B^T + alpha * B * A^T + beta * C
// where C is an n×n symmetric matrix. A and B are n×k matrices if
// tA == NoTrans and k×n otherwise. alpha and beta are scalars.
func (Implementation) Dsyr2k(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
if ul != blas.Lower && ul != blas.Upper {
panic(badUplo)
}
if tA != blas.Trans && tA != blas.NoTrans && tA != blas.ConjTrans {
panic(badTranspose)
}
if n < 0 {
panic(nLT0)
}
if k < 0 {
panic(kLT0)
}
if ldc < n {
panic(badLdC)
}
if tA == blas.Trans {
if lda < n {
panic(badLdA)
}
if ldb < n {
panic(badLdB)
}
} else {
if lda < k {
panic(badLdA)
}
if ldb < k {
panic(badLdB)
}
}
if alpha == 0 {
if beta == 0 {
if ul == blas.Upper {
for i := 0; i < n; i++ {
ctmp := c[i*ldc+i : i*ldc+n]
for j := range ctmp {
ctmp[j] = 0
}
}
return
}
for i := 0; i < n; i++ {
ctmp := c[i*ldc : i*ldc+i+1]
for j := range ctmp {
ctmp[j] = 0
}
}
return
}
if ul == blas.Upper {
for i := 0; i < n; i++ {
ctmp := c[i*ldc+i : i*ldc+n]
for j := range ctmp {
ctmp[j] *= beta
}
}
return
}
for i := 0; i < n; i++ {
ctmp := c[i*ldc : i*ldc+i+1]
for j := range ctmp {
ctmp[j] *= beta
}
}
return
}
if tA == blas.NoTrans {
if ul == blas.Upper {
for i := 0; i < n; i++ {
atmp := a[i*lda : i*lda+k]
btmp := b[i*lda : i*lda+k]
ctmp := c[i*ldc+i : i*ldc+n]
for jc := range ctmp {
j := i + jc
var tmp1, tmp2 float64
binner := b[j*ldb : j*ldb+k]
for l, v := range a[j*lda : j*lda+k] {
tmp1 += v * btmp[l]
tmp2 += atmp[l] * binner[l]
}
ctmp[jc] *= beta
ctmp[jc] += alpha * (tmp1 + tmp2)
}
}
return
}
for i := 0; i < n; i++ {
atmp := a[i*lda : i*lda+k]
btmp := b[i*lda : i*lda+k]
ctmp := c[i*ldc : i*ldc+i+1]
for j := 0; j <= i; j++ {
var tmp1, tmp2 float64
binner := b[j*ldb : j*ldb+k]
for l, v := range a[j*lda : j*lda+k] {
tmp1 += v * btmp[l]
tmp2 += atmp[l] * binner[l]
}
ctmp[j] *= beta
ctmp[j] += alpha * (tmp1 + tmp2)
}
}
return
}
if ul == blas.Upper {
for i := 0; i < n; i++ {
ctmp := c[i*ldc+i : i*ldc+n]
if beta != 1 {
for j := range ctmp {
ctmp[j] *= beta
}
}
for l := 0; l < k; l++ {
tmp1 := alpha * b[l*lda+i]
tmp2 := alpha * a[l*lda+i]
btmp := b[l*ldb+i : l*ldb+n]
if tmp1 != 0 || tmp2 != 0 {
for j, v := range a[l*lda+i : l*lda+n] {
ctmp[j] += v*tmp1 + btmp[j]*tmp2
}
}
}
}
return
}
for i := 0; i < n; i++ {
ctmp := c[i*ldc : i*ldc+i+1]
if beta != 1 {
for j := range ctmp {
ctmp[j] *= beta
}
}
for l := 0; l < k; l++ {
tmp1 := alpha * b[l*lda+i]
tmp2 := alpha * a[l*lda+i]
btmp := b[l*ldb : l*ldb+i+1]
if tmp1 != 0 || tmp2 != 0 {
for j, v := range a[l*lda : l*lda+i+1] {
ctmp[j] += v*tmp1 + btmp[j]*tmp2
}
}
}
}
}
// Dtrmm performs
// B = alpha * A * B if tA == blas.NoTrans and side == blas.Left
// B = alpha * A^T * B if tA == blas.Trans and side == blas.Left
// B = alpha * B * A if tA == blas.NoTrans and side == blas.Right
// B = alpha * B * A^T if tA == blas.Trans and side == blas.Right
// where A is an n×n triangular matrix, and B is an m×n matrix.
func (Implementation) Dtrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int) {
if s != blas.Left && s != blas.Right {
panic(badSide)
}
if ul != blas.Lower && ul != blas.Upper {
panic(badUplo)
}
if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
panic(badTranspose)
}
if d != blas.NonUnit && d != blas.Unit {
panic(badDiag)
}
if m < 0 {
panic(mLT0)
}
if n < 0 {
panic(nLT0)
}
if ldb < n {
panic(badLdB)
}
if s == blas.Left {
if lda < m {
panic(badLdA)
}
} else {
if lda < n {
panic(badLdA)
}
}
if alpha == 0 {
for i := 0; i < m; i++ {
btmp := b[i*ldb : i*ldb+n]
for j := range btmp {
btmp[j] = 0
}
}
return
}
nonUnit := d == blas.NonUnit
if s == blas.Left {
if tA == blas.NoTrans {
if ul == blas.Upper {
for i := 0; i < m; i++ {
tmp := alpha
if nonUnit {
tmp *= a[i*lda+i]
}
btmp := b[i*ldb : i*ldb+n]
for j := range btmp {
btmp[j] *= tmp
}
for ka, va := range a[i*lda+i+1 : i*lda+m] {
k := ka + i + 1
tmp := alpha * va
if tmp != 0 {
asm.DaxpyUnitary(tmp, b[k*ldb:k*ldb+n], btmp, btmp)
}
}
}
return
}
for i := m - 1; i >= 0; i-- {
tmp := alpha
if nonUnit {
tmp *= a[i*lda+i]
}
btmp := b[i*ldb : i*ldb+n]
for j := range btmp {
btmp[j] *= tmp
}
for k, va := range a[i*lda : i*lda+i] {
tmp := alpha * va
if tmp != 0 {
asm.DaxpyUnitary(tmp, b[k*ldb:k*ldb+n], btmp, btmp)
}
}
}
return
}
// Cases where a is transposed.
if ul == blas.Upper {
for k := m - 1; k >= 0; k-- {
btmpk := b[k*ldb : k*ldb+n]
for ia, va := range a[k*lda+k+1 : k*lda+m] {
i := ia + k + 1
btmp := b[i*ldb : i*ldb+n]
tmp := alpha * va
if tmp != 0 {
asm.DaxpyUnitary(tmp, btmpk, btmp, btmp)
}
}
tmp := alpha
if nonUnit {
tmp *= a[k*lda+k]
}
if tmp != 1 {
for j := 0; j < n; j++ {
btmpk[j] *= tmp
}
}
}
return
}
for k := 0; k < m; k++ {
btmpk := b[k*ldb : k*ldb+n]
for i, va := range a[k*lda : k*lda+k] {
btmp := b[i*ldb : i*ldb+n]
tmp := alpha * va
if tmp != 0 {
asm.DaxpyUnitary(tmp, btmpk, btmp, btmp)
}
}
tmp := alpha
if nonUnit {
tmp *= a[k*lda+k]
}
if tmp != 1 {
for j := 0; j < n; j++ {
btmpk[j] *= tmp
}
}
}
return
}
// Cases where a is on the right
if tA == blas.NoTrans {
if ul == blas.Upper {
for i := 0; i < m; i++ {
btmp := b[i*ldb : i*ldb+n]
for k := n - 1; k >= 0; k-- {
tmp := alpha * btmp[k]
if tmp != 0 {
btmp[k] = tmp
if nonUnit {
btmp[k] *= a[k*lda+k]
}
for ja, v := range a[k*lda+k+1 : k*lda+n] {
j := ja + k + 1
btmp[j] += tmp * v
}
}
}
}
return
}
for i := 0; i < m; i++ {
btmp := b[i*ldb : i*ldb+n]
for k := 0; k < n; k++ {
tmp := alpha * btmp[k]
if tmp != 0 {
btmp[k] = tmp
if nonUnit {
btmp[k] *= a[k*lda+k]
}
asm.DaxpyUnitary(tmp, a[k*lda:k*lda+k], btmp, btmp)
}
}
}
return
}
// Cases where a is transposed.
if ul == blas.Upper {
for i := 0; i < m; i++ {
btmp := b[i*lda : i*lda+n]
for j, vb := range btmp {
tmp := vb
if nonUnit {
tmp *= a[j*lda+j]
}
tmp += asm.DdotUnitary(a[j*lda+j+1:j*lda+n], btmp[j+1:n])
btmp[j] = alpha * tmp
}
}
return
}
for i := 0; i < m; i++ {
btmp := b[i*lda : i*lda+n]
for j := n - 1; j >= 0; j-- {
tmp := btmp[j]
if nonUnit {
tmp *= a[j*lda+j]
}
tmp += asm.DdotUnitary(a[j*lda:j*lda+j], btmp[:j])
btmp[j] = alpha * tmp
}
}
}

View file

@ -1,31 +0,0 @@
package native
import (
"testing"
"github.com/gonum/blas/testblas"
)
func TestDgemm(t *testing.T) {
testblas.TestDgemm(t, impl)
}
func TestDsymm(t *testing.T) {
testblas.DsymmTest(t, impl)
}
func TestDtrsm(t *testing.T) {
testblas.DtrsmTest(t, impl)
}
func TestDsyrk(t *testing.T) {
testblas.DsyrkTest(t, impl)
}
func TestDsyr2k(t *testing.T) {
testblas.Dsyr2kTest(t, impl)
}
func TestDtrmm(t *testing.T) {
testblas.DtrmmTest(t, impl)
}

View file

@ -1,178 +0,0 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"math/rand"
"testing"
"github.com/gonum/blas"
)
func TestDgemmParallel(t *testing.T) {
for i, test := range []struct {
m int
n int
k int
alpha float64
tA blas.Transpose
tB blas.Transpose
}{
{
m: 3,
n: 4,
k: 2,
alpha: 2.5,
tA: blas.NoTrans,
tB: blas.NoTrans,
},
{
m: blockSize*2 + 5,
n: 3,
k: 2,
alpha: 2.5,
tA: blas.NoTrans,
tB: blas.NoTrans,
},
{
m: 3,
n: blockSize * 2,
k: 2,
alpha: 2.5,
tA: blas.NoTrans,
tB: blas.NoTrans,
},
{
m: 2,
n: 3,
k: blockSize*3 - 2,
alpha: 2.5,
tA: blas.NoTrans,
tB: blas.NoTrans,
},
{
m: blockSize * minParBlock,
n: 3,
k: 2,
alpha: 2.5,
tA: blas.NoTrans,
tB: blas.NoTrans,
},
{
m: 3,
n: blockSize * minParBlock,
k: 2,
alpha: 2.5,
tA: blas.NoTrans,
tB: blas.NoTrans,
},
{
m: 2,
n: 3,
k: blockSize * minParBlock,
alpha: 2.5,
tA: blas.NoTrans,
tB: blas.NoTrans,
},
{
m: blockSize*minParBlock + 1,
n: blockSize * minParBlock,
k: 3,
alpha: 2.5,
tA: blas.NoTrans,
tB: blas.NoTrans,
},
{
m: 3,
n: blockSize*minParBlock + 2,
k: blockSize * 3,
alpha: 2.5,
tA: blas.NoTrans,
tB: blas.NoTrans,
},
{
m: blockSize * minParBlock,
n: 3,
k: blockSize * minParBlock,
alpha: 2.5,
tA: blas.NoTrans,
tB: blas.NoTrans,
},
{
m: blockSize * minParBlock,
n: blockSize * minParBlock,
k: blockSize * 3,
alpha: 2.5,
tA: blas.NoTrans,
tB: blas.NoTrans,
},
{
m: blockSize + blockSize/2,
n: blockSize + blockSize/2,
k: blockSize + blockSize/2,
alpha: 2.5,
tA: blas.NoTrans,
tB: blas.NoTrans,
},
} {
testMatchParallelSerial(t, i, blas.NoTrans, blas.NoTrans, test.m, test.n, test.k, test.alpha)
testMatchParallelSerial(t, i, blas.Trans, blas.NoTrans, test.m, test.n, test.k, test.alpha)
testMatchParallelSerial(t, i, blas.NoTrans, blas.Trans, test.m, test.n, test.k, test.alpha)
testMatchParallelSerial(t, i, blas.Trans, blas.Trans, test.m, test.n, test.k, test.alpha)
}
}
func testMatchParallelSerial(t *testing.T, i int, tA, tB blas.Transpose, m, n, k int, alpha float64) {
var (
rowA, colA int
rowB, colB int
)
if tA == blas.NoTrans {
rowA = m
colA = k
} else {
rowA = k
colA = m
}
if tB == blas.NoTrans {
rowB = k
colB = n
} else {
rowB = n
colB = k
}
a := randmat(rowA, colA, colA)
b := randmat(rowB, colB, colB)
c := randmat(m, n, n)
aClone := a.clone()
bClone := b.clone()
cClone := c.clone()
dgemmSerial(tA, tB, a, b, cClone, alpha)
dgemmParallel(tA, tB, a, b, c, alpha)
if !a.equal(aClone) {
t.Errorf("Case %v: a changed during call to dgemmParallel", i)
}
if !b.equal(bClone) {
t.Errorf("Case %v: b changed during call to dgemmParallel", i)
}
if !c.equalWithinAbs(cClone, 1e-12) {
t.Errorf("Case %v: answer not equal parallel and serial", i)
}
}
func randmat(r, c, stride int) general {
data := make([]float64, r*stride+c)
for i := range data {
data[i] = rand.Float64()
}
return general{
data: data,
rows: r,
cols: c,
stride: stride,
}
}

View file

@ -1,284 +0,0 @@
// Copyright 2014 The Gonum Authors. All rights reserved.
// Use of this code is governed by a BSD-style
// license that can be found in the LICENSE file
// Script for automatic code generation of the benchmark routines
package main
import (
"fmt"
"os"
"path/filepath"
"strconv"
)
var gopath string
var copyrightnotice = []byte(`// Copyright 2014 The Gonum Authors. All rights reserved.
// Use of this code is governed by a BSD-style
// license that can be found in the LICENSE file`)
var autogen = []byte(`// This file is autogenerated by github.com/gonum/blas/testblas/benchautogen/autogen_bench_level1double.go`)
var imports = []byte(`import(
"math/rand"
"testing"
"github.com/gonum/blas"
)`)
var randomSliceFunction = []byte(`func randomSlice(l, idx int) ([]float64) {
if idx < 0{
idx = -idx
}
s := make([]float64, l * idx)
for i := range s {
s[i] = rand.Float64()
}
return s
}`)
const (
posInc1 = 5
posInc2 = 3
negInc1 = -3
negInc2 = -4
)
var level1Sizes = []struct {
lower string
upper string
camel string
size int
}{
{
lower: "small",
upper: "SMALL_SLICE",
camel: "Small",
size: 10,
},
{
lower: "medium",
upper: "MEDIUM_SLICE",
camel: "Medium",
size: 1000,
},
{
lower: "large",
upper: "LARGE_SLICE",
camel: "Large",
size: 100000,
},
{
lower: "huge",
upper: "HUGE_SLICE",
camel: "Huge",
size: 10000000,
},
}
type level1functionStruct struct {
camel string
sig string
call string
extraSetup string
oneInput bool
extraName string // if have a couple different cases for the same function
}
var level1Functions = []level1functionStruct{
{
camel: "Ddot",
sig: "n int, x []float64, incX int, y []float64, incY int",
call: "n, x, incX, y, incY",
oneInput: false,
},
{
camel: "Dnrm2",
sig: "n int, x []float64, incX int",
call: "n, x, incX",
oneInput: true,
},
{
camel: "Dasum",
sig: "n int, x []float64, incX int",
call: "n, x, incX",
oneInput: true,
},
{
camel: "Idamax",
sig: "n int, x []float64, incX int",
call: "n, x, incX",
oneInput: true,
},
{
camel: "Dswap",
sig: "n int, x []float64, incX int, y []float64, incY int",
call: "n, x, incX, y, incY",
oneInput: false,
},
{
camel: "Dcopy",
sig: "n int, x []float64, incX int, y []float64, incY int",
call: "n, x, incX, y, incY",
oneInput: false,
},
{
camel: "Daxpy",
sig: "n int, alpha float64, x []float64, incX int, y []float64, incY int",
call: "n, alpha, x, incX, y, incY",
extraSetup: "alpha := 2.4",
oneInput: false,
},
{
camel: "Drot",
sig: "n int, x []float64, incX int, y []float64, incY int, c, s float64",
call: "n, x, incX, y, incY, c, s",
extraSetup: "c := 0.89725836967\ns:= 0.44150585279",
oneInput: false,
},
{
camel: "Drotm",
sig: "n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams",
call: "n, x, incX, y, incY, p",
extraSetup: "p := blas.DrotmParams{Flag: blas.OffDiagonal, H: [4]float64{0, -0.625, 0.9375,0}}",
oneInput: false,
extraName: "OffDia",
},
{
camel: "Drotm",
sig: "n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams",
call: "n, x, incX, y, incY, p",
extraSetup: "p := blas.DrotmParams{Flag: blas.OffDiagonal, H: [4]float64{5.0 / 12, 0, 0, 0.625}}",
oneInput: false,
extraName: "Dia",
},
{
camel: "Drotm",
sig: "n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams",
call: "n, x, incX, y, incY, p",
extraSetup: "p := blas.DrotmParams{Flag: blas.OffDiagonal, H: [4]float64{4096, -3584, 1792, 4096}}",
oneInput: false,
extraName: "Resc",
},
{
camel: "Dscal",
sig: "n int, alpha float64, x []float64, incX int",
call: "n, alpha, x, incX",
extraSetup: "alpha := 2.4",
oneInput: true,
},
}
func init() {
gopath = os.Getenv("GOPATH")
if gopath == "" {
panic("gopath not set")
}
}
func main() {
blasPath := filepath.Join(gopath, "src", "github.com", "gonum", "blas")
pkgs := []struct{ name string }{{name: "native"}, {name: "cgo"}}
for _, pkg := range pkgs {
err := level1(filepath.Join(blasPath, pkg.name), pkg.name)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
}
}
func printHeader(f *os.File, name string) error {
if _, err := f.Write([]byte(copyrightnotice)); err != nil {
return err
}
f.WriteString("\n\n")
f.Write(autogen)
f.WriteString("\n\n")
f.WriteString("package " + name)
f.WriteString("\n\n")
f.Write(imports)
f.WriteString("\n\n")
return nil
}
// Generate the benchmark scripts for level1
func level1(benchPath string, pkgname string) error {
// Generate level 1 benchmarks
level1Filepath := filepath.Join(benchPath, "level1doubleBench_auto_test.go")
f, err := os.Create(level1Filepath)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
defer f.Close()
printHeader(f, pkgname)
// Print all of the constants
f.WriteString("const (\n")
f.WriteString("\tposInc1 = " + strconv.Itoa(posInc1) + "\n")
f.WriteString("\tposInc2 = " + strconv.Itoa(posInc2) + "\n")
f.WriteString("\tnegInc1 = " + strconv.Itoa(negInc1) + "\n")
f.WriteString("\tnegInc2 = " + strconv.Itoa(negInc2) + "\n")
for _, con := range level1Sizes {
f.WriteString("\t" + con.upper + " = " + strconv.Itoa(con.size) + "\n")
}
f.WriteString(")\n")
f.WriteString("\n")
// Write the randomSlice function
f.Write(randomSliceFunction)
f.WriteString("\n\n")
// Start writing the benchmarks
for _, fun := range level1Functions {
writeLevel1Benchmark(fun, f)
f.WriteString("\n/* ------------------ */ \n")
}
return nil
}
func writeLevel1Benchmark(fun level1functionStruct, f *os.File) {
// First, write the base benchmark file
f.WriteString("func benchmark" + fun.camel + fun.extraName + "(b *testing.B, ")
f.WriteString(fun.sig)
f.WriteString(") {\n")
f.WriteString("b.ResetTimer()\n")
f.WriteString("for i := 0; i < b.N; i++{\n")
f.WriteString("\timpl." + fun.camel + "(")
f.WriteString(fun.call)
f.WriteString(")\n}\n}\n")
f.WriteString("\n")
// Write all of the benchmarks to call it
for _, sz := range level1Sizes {
lambda := func(incX, incY, name string, twoInput bool) {
f.WriteString("func Benchmark" + fun.camel + fun.extraName + sz.camel + name + "(b *testing.B){\n")
f.WriteString("n := " + sz.upper + "\n")
f.WriteString("incX := " + incX + "\n")
f.WriteString("x := randomSlice(n, incX)\n")
if twoInput {
f.WriteString("incY := " + incY + "\n")
f.WriteString("y := randomSlice(n, incY)\n")
}
f.WriteString(fun.extraSetup + "\n")
f.WriteString("benchmark" + fun.camel + fun.extraName + "(b, " + fun.call + ")\n")
f.WriteString("}\n\n")
}
if fun.oneInput {
lambda("1", "", "UnitaryInc", false)
lambda("posInc1", "", "PosInc", false)
} else {
lambda("1", "1", "BothUnitary", true)
lambda("posInc1", "1", "IncUni", true)
lambda("1", "negInc1", "UniInc", true)
lambda("posInc1", "negInc1", "BothInc", true)
}
}
}

View file

@ -1,8 +0,0 @@
package testblas
const (
SmallMat = 10
MediumMat = 100
LargeMat = 1000
HugeMat = 10000
)

View file

@ -1,234 +0,0 @@
package testblas
import (
"math"
"testing"
"github.com/gonum/blas"
)
// throwPanic will throw unexpected panics if true, or will just report them as errors if false
const throwPanic = true
func dTolEqual(a, b float64) bool {
if math.IsNaN(a) && math.IsNaN(b) {
return true
}
m := math.Max(math.Abs(a), math.Abs(b))
if m > 1 {
a /= m
b /= m
}
if math.Abs(a-b) < 1e-14 {
return true
}
return false
}
func dSliceTolEqual(a, b []float64) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !dTolEqual(a[i], b[i]) {
return false
}
}
return true
}
func dStridedSliceTolEqual(n int, a []float64, inca int, b []float64, incb int) bool {
ia := 0
ib := 0
if inca <= 0 {
ia = -(n - 1) * inca
}
if incb <= 0 {
ib = -(n - 1) * incb
}
for i := 0; i < n; i++ {
if !dTolEqual(a[ia], b[ib]) {
return false
}
ia += inca
ib += incb
}
return true
}
func dSliceEqual(a, b []float64) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !(a[i] == b[i]) {
return false
}
}
return true
}
func dCopyTwoTmp(x, xTmp, y, yTmp []float64) {
if len(x) != len(xTmp) {
panic("x size mismatch")
}
if len(y) != len(yTmp) {
panic("y size mismatch")
}
for i, val := range x {
xTmp[i] = val
}
for i, val := range y {
yTmp[i] = val
}
}
// returns true if the function panics
func panics(f func()) (b bool) {
defer func() {
err := recover()
if err != nil {
b = true
}
}()
f()
return
}
func testpanics(f func(), name string, t *testing.T) {
b := panics(f)
if !b {
t.Errorf("%v should panic and does not", name)
}
}
func sliceOfSliceCopy(a [][]float64) [][]float64 {
n := make([][]float64, len(a))
for i := range a {
n[i] = make([]float64, len(a[i]))
copy(n[i], a[i])
}
return n
}
func sliceCopy(a []float64) []float64 {
n := make([]float64, len(a))
copy(n, a)
return n
}
func flatten(a [][]float64) []float64 {
if len(a) == 0 {
return nil
}
m := len(a)
n := len(a[0])
s := make([]float64, m*n)
for i := 0; i < m; i++ {
for j := 0; j < n; j++ {
s[i*n+j] = a[i][j]
}
}
return s
}
func unflatten(a []float64, m, n int) [][]float64 {
s := make([][]float64, m)
for i := 0; i < m; i++ {
s[i] = make([]float64, n)
for j := 0; j < n; j++ {
s[i][j] = a[i*n+j]
}
}
return s
}
// flattenTriangular turns the upper or lower triangle of a dense slice of slice
// into a single slice with packed storage. a must be a square matrix.
func flattenTriangular(a [][]float64, ul blas.Uplo) []float64 {
m := len(a)
aFlat := make([]float64, m*(m+1)/2)
var k int
if ul == blas.Upper {
for i := 0; i < m; i++ {
k += copy(aFlat[k:], a[i][i:])
}
return aFlat
}
for i := 0; i < m; i++ {
k += copy(aFlat[k:], a[i][:i+1])
}
return aFlat
}
// flattenBanded turns a dense banded slice of slice into the compact banded matrix format
func flattenBanded(a [][]float64, ku, kl int) []float64 {
m := len(a)
n := len(a[0])
if ku < 0 || kl < 0 {
panic("testblas: negative band length")
}
nRows := m
nCols := (ku + kl + 1)
aflat := make([]float64, nRows*nCols)
for i := range aflat {
aflat[i] = math.NaN()
}
// loop over the rows, and then the bands
// elements in the ith row stay in the ith row
// order in bands is kept
for i := 0; i < nRows; i++ {
min := -kl
if i-kl < 0 {
min = -i
}
max := ku
if i+ku >= n {
max = n - i - 1
}
for j := min; j <= max; j++ {
col := kl + j
aflat[i*nCols+col] = a[i][i+j]
}
}
return aflat
}
// makeIncremented takes a slice with inc == 1 and makes an incremented version
// and adds extra values on the end
func makeIncremented(x []float64, inc int, extra int) []float64 {
if inc == 0 {
panic("zero inc")
}
absinc := inc
if absinc < 0 {
absinc = -inc
}
xcopy := make([]float64, len(x))
if inc > 0 {
copy(xcopy, x)
} else {
for i := 0; i < len(x); i++ {
xcopy[i] = x[len(x)-i-1]
}
}
// don't use NaN because it makes comparison hard
// Do use a weird unique value for easier debugging
counter := 100.0
var xnew []float64
for i, v := range xcopy {
xnew = append(xnew, v)
if i != len(x)-1 {
for j := 0; j < absinc-1; j++ {
xnew = append(xnew, counter)
counter++
}
}
}
for i := 0; i < extra; i++ {
xnew = append(xnew, counter)
counter++
}
return xnew
}

View file

@ -1,187 +0,0 @@
package testblas
import (
"math"
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
func TestFlattenBanded(t *testing.T) {
for i, test := range []struct {
dense [][]float64
ku int
kl int
condensed [][]float64
}{
{
dense: [][]float64{{3}},
ku: 0,
kl: 0,
condensed: [][]float64{{3}},
},
{
dense: [][]float64{
{3, 4, 0},
},
ku: 1,
kl: 0,
condensed: [][]float64{
{3, 4},
},
},
{
dense: [][]float64{
{3, 4, 0, 0, 0},
},
ku: 1,
kl: 0,
condensed: [][]float64{
{3, 4},
},
},
{
dense: [][]float64{
{3, 4, 0},
{0, 5, 8},
{0, 0, 2},
{0, 0, 0},
{0, 0, 0},
},
ku: 1,
kl: 0,
condensed: [][]float64{
{3, 4},
{5, 8},
{2, math.NaN()},
{math.NaN(), math.NaN()},
{math.NaN(), math.NaN()},
},
},
{
dense: [][]float64{
{3, 4, 6},
{0, 5, 8},
{0, 0, 2},
{0, 0, 0},
{0, 0, 0},
},
ku: 2,
kl: 0,
condensed: [][]float64{
{3, 4, 6},
{5, 8, math.NaN()},
{2, math.NaN(), math.NaN()},
{math.NaN(), math.NaN(), math.NaN()},
{math.NaN(), math.NaN(), math.NaN()},
},
},
{
dense: [][]float64{
{3, 4, 6},
{1, 5, 8},
{0, 6, 2},
{0, 0, 7},
{0, 0, 0},
},
ku: 2,
kl: 1,
condensed: [][]float64{
{math.NaN(), 3, 4, 6},
{1, 5, 8, math.NaN()},
{6, 2, math.NaN(), math.NaN()},
{7, math.NaN(), math.NaN(), math.NaN()},
{math.NaN(), math.NaN(), math.NaN(), math.NaN()},
},
},
{
dense: [][]float64{
{1, 2, 0},
{3, 4, 5},
{6, 7, 8},
{0, 9, 10},
{0, 0, 11},
},
ku: 1,
kl: 2,
condensed: [][]float64{
{math.NaN(), math.NaN(), 1, 2},
{math.NaN(), 3, 4, 5},
{6, 7, 8, math.NaN()},
{9, 10, math.NaN(), math.NaN()},
{11, math.NaN(), math.NaN(), math.NaN()},
},
},
{
dense: [][]float64{
{1, 0, 0},
{3, 4, 0},
{6, 7, 8},
{0, 9, 10},
{0, 0, 11},
},
ku: 0,
kl: 2,
condensed: [][]float64{
{math.NaN(), math.NaN(), 1},
{math.NaN(), 3, 4},
{6, 7, 8},
{9, 10, math.NaN()},
{11, math.NaN(), math.NaN()},
},
},
{
dense: [][]float64{
{1, 0, 0, 0, 0},
{3, 4, 0, 0, 0},
{1, 3, 5, 0, 0},
},
ku: 0,
kl: 2,
condensed: [][]float64{
{math.NaN(), math.NaN(), 1},
{math.NaN(), 3, 4},
{1, 3, 5},
},
},
} {
condensed := flattenBanded(test.dense, test.ku, test.kl)
correct := flatten(test.condensed)
if !floats.Same(condensed, correct) {
t.Errorf("Case %v mismatch. Want %v, got %v.", i, correct, condensed)
}
}
}
func TestFlattenTriangular(t *testing.T) {
for i, test := range []struct {
a [][]float64
ans []float64
ul blas.Uplo
}{
{
a: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 6},
},
ul: blas.Upper,
ans: []float64{1, 2, 3, 4, 5, 6},
},
{
a: [][]float64{
{1, 0, 0},
{2, 3, 0},
{4, 5, 6},
},
ul: blas.Lower,
ans: []float64{1, 2, 3, 4, 5, 6},
},
} {
a := flattenTriangular(test.a, test.ul)
if !floats.Equal(a, test.ans) {
t.Errorf("Case %v. Want %v, got %v.", i, test.ans, a)
}
}
}

View file

@ -1,94 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
type Dgbmver interface {
Dgbmv(tA blas.Transpose, m, n, kL, kU int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
}
func DgbmvTest(t *testing.T, blasser Dgbmver) {
for i, test := range []struct {
tA blas.Transpose
m, n int
kL, kU int
alpha float64
a [][]float64
lda int
x []float64
beta float64
y []float64
ans []float64
}{
{
tA: blas.NoTrans,
m: 9,
n: 6,
lda: 4,
kL: 2,
kU: 1,
alpha: 3.0,
beta: 2.0,
a: [][]float64{
{5, 3, 0, 0, 0, 0},
{-1, 2, 9, 0, 0, 0},
{4, 8, 3, 6, 0, 0},
{0, -1, 8, 2, 1, 0},
{0, 0, 9, 9, 9, 5},
{0, 0, 0, 2, -3, 2},
{0, 0, 0, 0, 1, 5},
{0, 0, 0, 0, 0, 6},
{0, 0, 0, 0, 0, 0},
},
x: []float64{1, 2, 3, 4, 5, 6},
y: []float64{-1, -2, -3, -4, -5, -6, -7, -8, -9},
ans: []float64{31, 86, 153, 97, 404, 3, 91, 92, -18},
},
{
tA: blas.Trans,
m: 9,
n: 6,
lda: 4,
kL: 2,
kU: 1,
alpha: 3.0,
beta: 2.0,
a: [][]float64{
{5, 3, 0, 0, 0, 0},
{-1, 2, 9, 0, 0, 0},
{4, 8, 3, 6, 0, 0},
{0, -1, 8, 2, 1, 0},
{0, 0, 9, 9, 9, 5},
{0, 0, 0, 2, -3, 2},
{0, 0, 0, 0, 1, 5},
{0, 0, 0, 0, 0, 6},
{0, 0, 0, 0, 0, 0},
},
x: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9},
y: []float64{-1, -2, -3, -4, -5, -6},
ans: []float64{43, 77, 306, 241, 104, 348},
},
} {
extra := 3
aFlat := flattenBanded(test.a, test.kU, test.kL)
incTest := func(incX, incY, extra int) {
xnew := makeIncremented(test.x, incX, extra)
ynew := makeIncremented(test.y, incY, extra)
ans := makeIncremented(test.ans, incY, extra)
blasser.Dgbmv(test.tA, test.m, test.n, test.kL, test.kU, test.alpha, aFlat, test.lda, xnew, incX, test.beta, ynew, incY)
if !dSliceTolEqual(ans, ynew) {
t.Errorf("Case %v: Want %v, got %v", i, ans, ynew)
}
}
incTest(1, 1, extra)
incTest(1, 3, extra)
incTest(1, -3, extra)
incTest(2, 3, extra)
incTest(2, -3, extra)
incTest(3, 2, extra)
incTest(-3, 2, extra)
}
}

View file

@ -1,252 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
type Dgemmer interface {
Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
}
type DgemmCase struct {
isATrans bool
m, n, k int
alpha, beta float64
a [][]float64
aTrans [][]float64 // transpose of a
b [][]float64
c [][]float64
ans [][]float64
}
var DgemmCases = []DgemmCase{
{
m: 4,
n: 3,
k: 2,
isATrans: false,
alpha: 2,
beta: 0.5,
a: [][]float64{
{1, 2},
{4, 5},
{7, 8},
{10, 11},
},
b: [][]float64{
{1, 5, 6},
{5, -8, 8},
},
c: [][]float64{
{4, 8, -9},
{12, 16, -8},
{1, 5, 15},
{-3, -4, 7},
},
ans: [][]float64{
{24, -18, 39.5},
{64, -32, 124},
{94.5, -55.5, 219.5},
{128.5, -78, 299.5},
},
},
{
m: 4,
n: 2,
k: 3,
isATrans: false,
alpha: 2,
beta: 0.5,
a: [][]float64{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
{10, 11, 12},
},
b: [][]float64{
{1, 5},
{5, -8},
{6, 2},
},
c: [][]float64{
{4, 8},
{12, 16},
{1, 5},
{-3, -4},
},
ans: [][]float64{
{60, -6},
{136, -8},
{202.5, -19.5},
{272.5, -30},
},
},
{
m: 3,
n: 2,
k: 4,
isATrans: false,
alpha: 2,
beta: 0.5,
a: [][]float64{
{1, 2, 3, 4},
{4, 5, 6, 7},
{8, 9, 10, 11},
},
b: [][]float64{
{1, 5},
{5, -8},
{6, 2},
{8, 10},
},
c: [][]float64{
{4, 8},
{12, 16},
{9, -10},
},
ans: [][]float64{
{124, 74},
{248, 132},
{406.5, 191},
},
},
{
m: 3,
n: 4,
k: 2,
isATrans: false,
alpha: 2,
beta: 0.5,
a: [][]float64{
{1, 2},
{4, 5},
{8, 9},
},
b: [][]float64{
{1, 5, 2, 1},
{5, -8, 2, 1},
},
c: [][]float64{
{4, 8, 2, 2},
{12, 16, 8, 9},
{9, -10, 10, 10},
},
ans: [][]float64{
{24, -18, 13, 7},
{64, -32, 40, 22.5},
{110.5, -69, 73, 39},
},
},
{
m: 2,
n: 4,
k: 3,
isATrans: false,
alpha: 2,
beta: 0.5,
a: [][]float64{
{1, 2, 3},
{4, 5, 6},
},
b: [][]float64{
{1, 5, 8, 8},
{5, -8, 9, 10},
{6, 2, -3, 2},
},
c: [][]float64{
{4, 8, 7, 8},
{12, 16, -2, 6},
},
ans: [][]float64{
{60, -6, 37.5, 72},
{136, -8, 117, 191},
},
},
{
m: 2,
n: 3,
k: 4,
isATrans: false,
alpha: 2,
beta: 0.5,
a: [][]float64{
{1, 2, 3, 4},
{4, 5, 6, 7},
},
b: [][]float64{
{1, 5, 8},
{5, -8, 9},
{6, 2, -3},
{8, 10, 2},
},
c: [][]float64{
{4, 8, 1},
{12, 16, 6},
},
ans: [][]float64{
{124, 74, 50.5},
{248, 132, 149},
},
},
}
// assumes [][]float64 is actually a matrix
func transpose(a [][]float64) [][]float64 {
b := make([][]float64, len(a[0]))
for i := range b {
b[i] = make([]float64, len(a))
for j := range b[i] {
b[i][j] = a[j][i]
}
}
return b
}
func TestDgemm(t *testing.T, blasser Dgemmer) {
for i, test := range DgemmCases {
// Test that it passes row major
dgemmcomp(i, "RowMajorNoTrans", t, blasser, blas.NoTrans, blas.NoTrans,
test.m, test.n, test.k, test.alpha, test.beta, test.a, test.b, test.c, test.ans)
// Try with A transposed
dgemmcomp(i, "RowMajorTransA", t, blasser, blas.Trans, blas.NoTrans,
test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), test.b, test.c, test.ans)
// Try with B transposed
dgemmcomp(i, "RowMajorTransB", t, blasser, blas.NoTrans, blas.Trans,
test.m, test.n, test.k, test.alpha, test.beta, test.a, transpose(test.b), test.c, test.ans)
// Try with both transposed
dgemmcomp(i, "RowMajorTransBoth", t, blasser, blas.Trans, blas.Trans,
test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), transpose(test.b), test.c, test.ans)
}
}
func dgemmcomp(i int, name string, t *testing.T, blasser Dgemmer, tA, tB blas.Transpose, m, n, k int,
alpha, beta float64, a [][]float64, b [][]float64, c [][]float64, ans [][]float64) {
aFlat := flatten(a)
aCopy := flatten(a)
bFlat := flatten(b)
bCopy := flatten(b)
cFlat := flatten(c)
ansFlat := flatten(ans)
lda := len(a[0])
ldb := len(b[0])
ldc := len(c[0])
// Compute the matrix multiplication
blasser.Dgemm(tA, tB, m, n, k, alpha, aFlat, lda, bFlat, ldb, beta, cFlat, ldc)
if !dSliceEqual(aFlat, aCopy) {
t.Errorf("Test %v case %v: a changed during call to Dgemm", i, name)
}
if !dSliceEqual(bFlat, bCopy) {
t.Errorf("Test %v case %v: b changed during call to Dgemm", i, name)
}
if !dSliceTolEqual(ansFlat, cFlat) {
t.Errorf("Test %v case %v: answer mismatch. Expected %v, Found %v", i, name, ansFlat, cFlat)
}
// TODO: Need to add a sub-slice test where don't use up full matrix
}

View file

@ -1,39 +0,0 @@
package testblas
import (
"math/rand"
"testing"
"github.com/gonum/blas"
)
func DgemmBenchmark(b *testing.B, dgemm Dgemmer, m, n, k int, tA, tB blas.Transpose) {
a := make([]float64, m*k)
for i := range a {
a[i] = rand.Float64()
}
bv := make([]float64, k*n)
for i := range bv {
bv[i] = rand.Float64()
}
c := make([]float64, m*n)
for i := range c {
c[i] = rand.Float64()
}
var lda, ldb int
if tA == blas.Trans {
lda = m
} else {
lda = k
}
if tB == blas.Trans {
ldb = k
} else {
ldb = n
}
ldc := n
b.ResetTimer()
for i := 0; i < b.N; i++ {
dgemm.Dgemm(tA, tB, m, n, k, 3.0, a, lda, bv, ldb, 1.0, c, ldc)
}
}

View file

@ -1,600 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
type DgemvCase struct {
Name string
m int
n int
A [][]float64
tA blas.Transpose
x []float64
incX int
y []float64
incY int
xCopy []float64
yCopy []float64
Subcases []DgemvSubcase
}
type DgemvSubcase struct {
mulXNeg1 bool
mulYNeg1 bool
alpha float64
beta float64
ans []float64
}
var DgemvCases = []DgemvCase{
{
Name: "M_gt_N_Inc1_NoTrans",
tA: blas.NoTrans,
m: 5,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
{1, 1, 2},
{9, 2, 5},
},
incX: 1,
incY: 1,
x: []float64{1, 2, 3},
y: []float64{7, 8, 9, 10, 11},
Subcases: []DgemvSubcase{
{
alpha: 0,
beta: 0,
ans: []float64{0, 0, 0, 0, 0},
},
{
alpha: 0,
beta: 1,
ans: []float64{7, 8, 9, 10, 11},
},
{
alpha: 1,
beta: 0,
ans: []float64{40.8, 43.9, 33, 9, 28},
},
{
alpha: 8,
beta: -6,
ans: []float64{284.4, 303.2, 210, 12, 158},
},
},
},
{
Name: "M_gt_N_Inc1_Trans",
tA: blas.Trans,
m: 5,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
{1, 1, 2},
{9, 2, 5},
},
incX: 1,
incY: 1,
x: []float64{1, 2, 3, -4, 5},
y: []float64{7, 8, 9},
Subcases: []DgemvSubcase{
{
alpha: 0,
beta: 0,
ans: []float64{0, 0, 0},
},
{
alpha: 0,
beta: 1,
ans: []float64{7, 8, 9},
},
{
alpha: 1,
beta: 0,
ans: []float64{94.3, 40.2, 52.3},
},
{
alpha: 8,
beta: -6,
ans: []float64{712.4, 273.6, 364.4},
},
},
},
{
Name: "M_eq_N_Inc1_NoTrans",
tA: blas.NoTrans,
m: 3,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
},
incX: 1,
incY: 1,
x: []float64{1, 2, 3},
y: []float64{7, 2, 2},
Subcases: []DgemvSubcase{
{
alpha: 0,
beta: 0,
ans: []float64{0, 0, 0},
},
{
alpha: 0,
beta: 1,
ans: []float64{7, 2, 2},
},
{
alpha: 1,
beta: 0,
ans: []float64{40.8, 43.9, 33},
},
{
alpha: 8,
beta: -6,
ans: []float64{40.8*8 - 6*7, 43.9*8 - 6*2, 33*8 - 6*2},
},
},
},
{
Name: "M_eq_N_Inc1_Trans",
tA: blas.Trans,
m: 3,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
},
incX: 1,
incY: 1,
x: []float64{1, 2, 3},
y: []float64{7, 2, 2},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{384.4, 261.6, 270.4},
},
},
},
{
Name: "M_lt_N_Inc1_NoTrans",
tA: blas.NoTrans,
m: 3,
n: 5,
A: [][]float64{
{4.1, 6.2, 8.1, 10, 7},
{9.6, 3.5, 9.1, -2, 9},
{10, 7, 3, 1, -5},
},
incX: 1,
incY: 1,
x: []float64{1, 2, 3, -7.6, 8.1},
y: []float64{7, 2, 2},
Subcases: []DgemvSubcase{
{
alpha: 0,
beta: 0,
ans: []float64{0, 0, 0},
},
{
alpha: 0,
beta: 1,
ans: []float64{7, 2, 2},
},
{
alpha: 1,
beta: 0,
ans: []float64{21.5, 132, -15.1},
},
{
alpha: 8,
beta: -6,
ans: []float64{21.5*8 - 6*7, 132*8 - 6*2, -15.1*8 - 6*2},
},
},
},
{
Name: "M_lt_N_Inc1_Trans",
tA: blas.Trans,
m: 3,
n: 5,
A: [][]float64{
{4.1, 6.2, 8.1, 10, 7},
{9.6, 3.5, 9.1, -2, 9},
{10, 7, 3, 1, -5},
},
incX: 1,
incY: 1,
x: []float64{1, 2, 3},
y: []float64{7, 2, 2, -3, 5},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{384.4, 261.6, 270.4, 90, 50},
},
},
},
{
Name: "M_gt_N_IncNot1_NoTrans",
tA: blas.NoTrans,
m: 5,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
{1, 1, 2},
{9, 2, 5},
},
incX: 2,
incY: 3,
x: []float64{1, 15, 2, 150, 3},
y: []float64{7, 2, 6, 8, -4, -5, 9, 1, 1, 10, 19, 22, 11},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{284.4, 2, 6, 303.2, -4, -5, 210, 1, 1, 12, 19, 22, 158},
},
{
mulXNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{220.4, 2, 6, 311.2, -4, -5, 322, 1, 1, -4, 19, 22, 222},
},
{
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{182, 2, 6, 24, -4, -5, 210, 1, 1, 291.2, 19, 22, 260.4},
},
{
mulXNeg1: true,
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{246, 2, 6, 8, -4, -5, 322, 1, 1, 299.2, 19, 22, 196.4},
},
},
},
{
Name: "M_gt_N_IncNot1_Trans",
tA: blas.Trans,
m: 5,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
{1, 1, 2},
{9, 2, 5},
},
incX: 2,
incY: 3,
x: []float64{1, 15, 2, 150, 3, 8, -3, 6, 5},
y: []float64{7, 2, 6, 8, -4, -5, 9},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{720.4, 2, 6, 281.6, -4, -5, 380.4},
},
{
mulXNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{219.6, 2, 6, 316, -4, -5, 195.6},
},
{
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{392.4, 2, 6, 281.6, -4, -5, 708.4},
},
{
mulXNeg1: true,
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{207.6, 2, 6, 316, -4, -5, 207.6},
},
},
},
{
Name: "M_eq_N_IncNot1_NoTrans",
tA: blas.NoTrans,
m: 3,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
},
incX: 2,
incY: 3,
x: []float64{1, 15, 2, 150, 3},
y: []float64{7, 2, 6, 8, -4, -5, 9},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{284.4, 2, 6, 303.2, -4, -5, 210},
},
{
mulXNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{220.4, 2, 6, 311.2, -4, -5, 322},
},
{
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{222, 2, 6, 303.2, -4, -5, 272.4},
},
{
mulXNeg1: true,
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{334, 2, 6, 311.2, -4, -5, 208.4},
},
},
},
{
Name: "M_eq_N_IncNot1_Trans",
tA: blas.Trans,
m: 3,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
},
incX: 2,
incY: 3,
x: []float64{1, 15, 2, 150, 3},
y: []float64{7, 2, 6, 8, -4, -5, 9},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{384.4, 2, 6, 225.6, -4, -5, 228.4},
},
{
mulXNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{290, 2, 6, 212.8, -4, -5, 310},
},
{
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{240.4, 2, 6, 225.6, -4, -5, 372.4},
},
{
mulXNeg1: true,
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{322, 2, 6, 212.8, -4, -5, 278},
},
},
},
{
Name: "M_lt_N_IncNot1_NoTrans",
tA: blas.NoTrans,
m: 3,
n: 5,
A: [][]float64{
{4.1, 6.2, 8.1, 10, 11},
{9.6, 3.5, 9.1, -3, -2},
{10, 7, 3, -7, -4},
},
incX: 2,
incY: 3,
x: []float64{1, 15, 2, 150, 3, -2, -4, 8, -9},
y: []float64{7, 2, 6, 8, -4, -5, 9},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{-827.6, 2, 6, 543.2, -4, -5, 722},
},
{
mulXNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{-93.2, 2, 6, -696.8, -4, -5, -1070},
},
{
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{734, 2, 6, 543.2, -4, -5, -839.6},
},
{
mulXNeg1: true,
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{-1058, 2, 6, -696.8, -4, -5, -105.2},
},
},
},
{
Name: "M_lt_N_IncNot1_Trans",
tA: blas.Trans,
m: 3,
n: 5,
A: [][]float64{
{4.1, 6.2, 8.1, 10, 11},
{9.6, 3.5, 9.1, -3, -2},
{10, 7, 3, -7, -4},
},
incX: 2,
incY: 3,
x: []float64{1, 15, 2, 150, 3},
y: []float64{7, 2, 6, 8, -4, -5, 9, -4, -1, -9, 1, 1, 2},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{384.4, 2, 6, 225.6, -4, -5, 228.4, -4, -1, -82, 1, 1, -52},
},
{
mulXNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{290, 2, 6, 212.8, -4, -5, 310, -4, -1, 190, 1, 1, 188},
},
{
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{-82, 2, 6, -184, -4, -5, 228.4, -4, -1, 327.6, 1, 1, 414.4},
},
{
mulXNeg1: true,
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{158, 2, 6, 88, -4, -5, 310, -4, -1, 314.8, 1, 1, 320},
},
},
},
// TODO: A can be longer than mxn. Add cases where it is longer
// TODO: x and y can also be longer. Add tests for these
// TODO: Add tests for dimension mismatch
// TODO: Add places with a "submatrix view", where lda != m
}
type Dgemver interface {
Dgemv(tA blas.Transpose, m, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
}
func DgemvTest(t *testing.T, blasser Dgemver) {
for _, test := range DgemvCases {
for i, cas := range test.Subcases {
// Test that it passes with row-major
dgemvcomp(t, test, cas, i, blasser)
// Test the bad inputs
dgemvbad(t, test, cas, i, blasser)
}
}
}
func dgemvcomp(t *testing.T, test DgemvCase, cas DgemvSubcase, i int, blasser Dgemver) {
x := sliceCopy(test.x)
y := sliceCopy(test.y)
a := sliceOfSliceCopy(test.A)
aFlat := flatten(a)
lda := test.n
incX := test.incX
if cas.mulXNeg1 {
incX *= -1
}
incY := test.incY
if cas.mulYNeg1 {
incY *= -1
}
f := func() {
blasser.Dgemv(test.tA, test.m, test.n, cas.alpha, aFlat, lda, x, incX, cas.beta, y, incY)
}
if panics(f) {
t.Errorf("Test %v case %v: unexpected panic", test.Name, i)
if throwPanic {
blasser.Dgemv(test.tA, test.m, test.n, cas.alpha, aFlat, lda, x, incX, cas.beta, y, incY)
}
return
}
// Check that x and a are unchanged
if !dSliceEqual(x, test.x) {
t.Errorf("Test %v, case %v: x modified during call", test.Name, i)
}
aFlat2 := flatten(sliceOfSliceCopy(test.A))
if !dSliceEqual(aFlat2, aFlat) {
t.Errorf("Test %v, case %v: a modified during call", test.Name, i)
}
// Check that the answer matches
if !dSliceTolEqual(cas.ans, y) {
t.Errorf("Test %v, case %v: answer mismatch: Expected %v, Found %v", test.Name, i, cas.ans, y)
}
}
func dgemvbad(t *testing.T, test DgemvCase, cas DgemvSubcase, i int, blasser Dgemver) {
x := sliceCopy(test.x)
y := sliceCopy(test.y)
a := sliceOfSliceCopy(test.A)
aFlatRow := flatten(a)
ldaRow := test.n
f := func() {
blasser.Dgemv(312, test.m, test.n, cas.alpha, aFlatRow, ldaRow, x, test.incX, cas.beta, y, test.incY)
}
if !panics(f) {
t.Errorf("Test %v case %v: no panic for bad transpose", test.Name, i)
}
f = func() {
blasser.Dgemv(test.tA, -2, test.n, cas.alpha, aFlatRow, ldaRow, x, test.incX, cas.beta, y, test.incY)
}
if !panics(f) {
t.Errorf("Test %v case %v: no panic for m negative", test.Name, i)
}
f = func() {
blasser.Dgemv(test.tA, test.m, -4, cas.alpha, aFlatRow, ldaRow, x, test.incX, cas.beta, y, test.incY)
}
if !panics(f) {
t.Errorf("Test %v case %v: no panic for n negative", test.Name, i)
}
f = func() {
blasser.Dgemv(test.tA, test.m, test.n, cas.alpha, aFlatRow, ldaRow, x, 0, cas.beta, y, test.incY)
}
if !panics(f) {
t.Errorf("Test %v case %v: no panic for incX zero", test.Name, i)
}
f = func() {
blasser.Dgemv(test.tA, test.m, test.n, cas.alpha, aFlatRow, ldaRow, x, test.incX, cas.beta, y, 0)
}
if !panics(f) {
t.Errorf("Test %v case %v: no panic for incY zero", test.Name, i)
}
f = func() {
blasser.Dgemv(test.tA, test.m, test.n, cas.alpha, aFlatRow, ldaRow-1, x, test.incX, cas.beta, y, test.incY)
}
if !panics(f) {
t.Errorf("Test %v case %v: no panic for lda too small row major", test.Name, i)
}
}

View file

@ -1,164 +0,0 @@
package testblas
import "testing"
type Dgerer interface {
Dger(m, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int)
}
func DgerTest(t *testing.T, blasser Dgerer) {
for _, test := range []struct {
name string
a [][]float64
m int
n int
x []float64
y []float64
incX int
incY int
ansAlphaEq1 []float64
trueAns [][]float64
}{
{
name: "M gt N inc 1",
m: 5,
n: 3,
a: [][]float64{
{1.3, 2.4, 3.5},
{2.6, 2.8, 3.3},
{-1.3, -4.3, -9.7},
{8, 9, -10},
{-12, -14, -6},
},
x: []float64{-2, -3, 0, 1, 2},
y: []float64{-1.1, 5, 0},
incX: 1,
incY: 1,
trueAns: [][]float64{{3.5, -7.6, 3.5}, {5.9, -12.2, 3.3}, {-1.3, -4.3, -9.7}, {6.9, 14, -10}, {-14.2, -4, -6}},
},
{
name: "M eq N inc 1",
m: 3,
n: 3,
a: [][]float64{
{1.3, 2.4, 3.5},
{2.6, 2.8, 3.3},
{-1.3, -4.3, -9.7},
},
x: []float64{-2, -3, 0},
y: []float64{-1.1, 5, 0},
incX: 1,
incY: 1,
trueAns: [][]float64{{3.5, -7.6, 3.5}, {5.9, -12.2, 3.3}, {-1.3, -4.3, -9.7}},
},
{
name: "M lt N inc 1",
m: 3,
n: 6,
a: [][]float64{
{1.3, 2.4, 3.5, 4.8, 1.11, -9},
{2.6, 2.8, 3.3, -3.4, 6.2, -8.7},
{-1.3, -4.3, -9.7, -3.1, 8.9, 8.9},
},
x: []float64{-2, -3, 0},
y: []float64{-1.1, 5, 0, 9, 19, 22},
incX: 1,
incY: 1,
trueAns: [][]float64{{3.5, -7.6, 3.5, -13.2, -36.89, -53}, {5.9, -12.2, 3.3, -30.4, -50.8, -74.7}, {-1.3, -4.3, -9.7, -3.1, 8.9, 8.9}},
},
{
name: "M gt N inc not 1",
m: 5,
n: 3,
a: [][]float64{
{1.3, 2.4, 3.5},
{2.6, 2.8, 3.3},
{-1.3, -4.3, -9.7},
{8, 9, -10},
{-12, -14, -6},
},
x: []float64{-2, -3, 0, 1, 2, 6, 0, 9, 7},
y: []float64{-1.1, 5, 0, 8, 7, -5, 7},
incX: 2,
incY: 3,
trueAns: [][]float64{{3.5, -13.6, -10.5}, {2.6, 2.8, 3.3}, {-3.5, 11.7, 4.3}, {8, 9, -10}, {-19.700000000000003, 42, 43}},
},
{
name: "M eq N inc not 1",
m: 3,
n: 3,
a: [][]float64{
{1.3, 2.4, 3.5},
{2.6, 2.8, 3.3},
{-1.3, -4.3, -9.7},
},
x: []float64{-2, -3, 0, 8, 7, -9, 7, -6, 12, 6, 6, 6, -11},
y: []float64{-1.1, 5, 0, 0, 9, 8, 6},
incX: 4,
incY: 3,
trueAns: [][]float64{{3.5, 2.4, -8.5}, {-5.1, 2.8, 45.3}, {-14.5, -4.3, 62.3}},
},
{
name: "M lt N inc not 1",
m: 3,
n: 6,
a: [][]float64{
{1.3, 2.4, 3.5, 4.8, 1.11, -9},
{2.6, 2.8, 3.3, -3.4, 6.2, -8.7},
{-1.3, -4.3, -9.7, -3.1, 8.9, 8.9},
},
x: []float64{-2, -3, 0, 0, 8, 0, 9, -3},
y: []float64{-1.1, 5, 0, 9, 19, 22, 11, -8.11, -9.22, 9.87, 7},
incX: 3,
incY: 2,
trueAns: [][]float64{{3.5, 2.4, -34.5, -17.2, 19.55, -23}, {2.6, 2.8, 3.3, -3.4, 6.2, -8.7}, {-11.2, -4.3, 161.3, 95.9, -74.08, 71.9}},
},
} {
// TODO: Add tests where a is longer
// TODO: Add panic tests
// TODO: Add negative increment tests
x := sliceCopy(test.x)
y := sliceCopy(test.y)
a := sliceOfSliceCopy(test.a)
// Test with row major
alpha := 1.0
aFlat := flatten(a)
blasser.Dger(test.m, test.n, alpha, x, test.incX, y, test.incY, aFlat, test.n)
ans := unflatten(aFlat, test.m, test.n)
dgercomp(t, x, test.x, y, test.y, ans, test.trueAns, test.name+" row maj")
// Test with different alpha
alpha = 4.0
aFlat = flatten(a)
blasser.Dger(test.m, test.n, alpha, x, test.incX, y, test.incY, aFlat, test.n)
ans = unflatten(aFlat, test.m, test.n)
trueCopy := sliceOfSliceCopy(test.trueAns)
for i := range trueCopy {
for j := range trueCopy[i] {
trueCopy[i][j] = alpha*(trueCopy[i][j]-a[i][j]) + a[i][j]
}
}
dgercomp(t, x, test.x, y, test.y, ans, trueCopy, test.name+" row maj alpha")
}
}
func dgercomp(t *testing.T, x, xCopy, y, yCopy []float64, ans [][]float64, trueAns [][]float64, name string) {
if !dSliceEqual(x, xCopy) {
t.Errorf("case %v: x modified during call to dger", name)
}
if !dSliceEqual(y, yCopy) {
t.Errorf("case %v: x modified during call to dger", name)
}
for i := range ans {
if !dSliceTolEqual(ans[i], trueAns[i]) {
t.Errorf("case %v: answer mismatch. Expected %v, Found %v", name, trueAns, ans)
break
}
}
}

View file

@ -1,83 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
type Dsbmver interface {
Dsbmv(ul blas.Uplo, n, k int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
}
func DsbmvTest(t *testing.T, blasser Dsbmver) {
for i, test := range []struct {
ul blas.Uplo
n int
k int
alpha float64
beta float64
a [][]float64
x []float64
y []float64
ans []float64
}{
{
ul: blas.Upper,
n: 4,
k: 2,
alpha: 2,
beta: 3,
a: [][]float64{
{7, 8, 2, 0},
{0, 8, 2, -3},
{0, 0, 3, 6},
{0, 0, 0, 9},
},
x: []float64{1, 2, 3, 4},
y: []float64{-1, -2, -3, -4},
ans: []float64{55, 30, 69, 84},
},
{
ul: blas.Lower,
n: 4,
k: 2,
alpha: 2,
beta: 3,
a: [][]float64{
{7, 0, 0, 0},
{8, 8, 0, 0},
{2, 2, 3, 0},
{0, -3, 6, 9},
},
x: []float64{1, 2, 3, 4},
y: []float64{-1, -2, -3, -4},
ans: []float64{55, 30, 69, 84},
},
} {
extra := 0
var aFlat []float64
if test.ul == blas.Upper {
aFlat = flattenBanded(test.a, test.k, 0)
} else {
aFlat = flattenBanded(test.a, 0, test.k)
}
incTest := func(incX, incY, extra int) {
xnew := makeIncremented(test.x, incX, extra)
ynew := makeIncremented(test.y, incY, extra)
ans := makeIncremented(test.ans, incY, extra)
blasser.Dsbmv(test.ul, test.n, test.k, test.alpha, aFlat, test.k+1, xnew, incX, test.beta, ynew, incY)
if !dSliceTolEqual(ans, ynew) {
t.Errorf("Case %v: Want %v, got %v", i, ans, ynew)
}
}
incTest(1, 1, extra)
incTest(1, 3, extra)
incTest(1, -3, extra)
incTest(2, 3, extra)
incTest(2, -3, extra)
incTest(3, 2, extra)
incTest(-3, 2, extra)
}
}

View file

@ -1,73 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
type Dspmver interface {
Dspmv(ul blas.Uplo, n int, alpha float64, ap []float64, x []float64, incX int, beta float64, y []float64, incY int)
}
func DspmvTest(t *testing.T, blasser Dspmver) {
for i, test := range []struct {
ul blas.Uplo
n int
a [][]float64
x []float64
y []float64
alpha float64
beta float64
ans []float64
}{
{
ul: blas.Upper,
n: 3,
a: [][]float64{
{5, 6, 7},
{0, 8, 10},
{0, 0, 13},
},
x: []float64{3, 4, 5},
y: []float64{6, 7, 8},
alpha: 2.1,
beta: -3,
ans: []float64{137.4, 189, 240.6},
},
{
ul: blas.Lower,
n: 3,
a: [][]float64{
{5, 0, 0},
{6, 8, 0},
{7, 10, 13},
},
x: []float64{3, 4, 5},
y: []float64{6, 7, 8},
alpha: 2.1,
beta: -3,
ans: []float64{137.4, 189, 240.6},
},
} {
incTest := func(incX, incY, extra int) {
x := makeIncremented(test.x, incX, extra)
y := makeIncremented(test.y, incY, extra)
aFlat := flattenTriangular(test.a, test.ul)
ans := makeIncremented(test.ans, incY, extra)
blasser.Dspmv(test.ul, test.n, test.alpha, aFlat, x, incX, test.beta, y, incY)
if !floats.EqualApprox(ans, y, 1e-14) {
t.Errorf("Case %v, incX=%v, incY=%v: Want %v, got %v.", i, incX, incY, ans, y)
}
}
incTest(1, 1, 0)
incTest(2, 3, 0)
incTest(3, 2, 0)
incTest(-3, 2, 0)
incTest(-2, 4, 0)
incTest(2, -1, 0)
incTest(-3, -4, 3)
}
}

View file

@ -1,71 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
type Dsprer interface {
Dspr(ul blas.Uplo, n int, alpha float64, x []float64, incX int, a []float64)
}
func DsprTest(t *testing.T, blasser Dsprer) {
for i, test := range []struct {
ul blas.Uplo
n int
a [][]float64
x []float64
alpha float64
ans [][]float64
}{
{
ul: blas.Upper,
n: 4,
a: [][]float64{
{10, 2, 0, 1},
{0, 1, 2, 3},
{0, 0, 9, 15},
{0, 0, 0, -6},
},
x: []float64{1, 2, 0, 5},
alpha: 8,
ans: [][]float64{
{18, 18, 0, 41},
{0, 33, 2, 83},
{0, 0, 9, 15},
{0, 0, 0, 194},
},
},
{
ul: blas.Lower,
n: 3,
a: [][]float64{
{10, 2, 0},
{4, 1, 2},
{2, 7, 9},
},
x: []float64{3, 0, 5},
alpha: 8,
ans: [][]float64{
{82, 2, 0},
{4, 1, 2},
{122, 7, 209},
},
},
} {
incTest := func(incX, extra int) {
xnew := makeIncremented(test.x, incX, extra)
aFlat := flattenTriangular(test.a, test.ul)
ans := flattenTriangular(test.ans, test.ul)
blasser.Dspr(test.ul, test.n, test.alpha, xnew, incX, aFlat)
if !dSliceTolEqual(aFlat, ans) {
t.Errorf("Case %v, idx %v: Want %v, got %v.", i, incX, ans, aFlat)
}
}
incTest(1, 3)
incTest(1, 0)
incTest(3, 2)
incTest(-2, 2)
}
}

View file

@ -1,76 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
type Dspr2er interface {
Dspr2(ul blas.Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64)
}
func Dspr2Test(t *testing.T, blasser Dspr2er) {
for i, test := range []struct {
n int
a [][]float64
ul blas.Uplo
x []float64
y []float64
alpha float64
ans [][]float64
}{
{
n: 3,
a: [][]float64{
{7, 2, 4},
{0, 3, 5},
{0, 0, 6},
},
x: []float64{2, 3, 4},
y: []float64{5, 6, 7},
alpha: 2,
ul: blas.Upper,
ans: [][]float64{
{47, 56, 72},
{0, 75, 95},
{0, 0, 118},
},
},
{
n: 3,
a: [][]float64{
{7, 0, 0},
{2, 3, 0},
{4, 5, 6},
},
x: []float64{2, 3, 4},
y: []float64{5, 6, 7},
alpha: 2,
ul: blas.Lower,
ans: [][]float64{
{47, 0, 0},
{56, 75, 0},
{72, 95, 118},
},
},
} {
incTest := func(incX, incY, extra int) {
aFlat := flattenTriangular(test.a, test.ul)
x := makeIncremented(test.x, incX, extra)
y := makeIncremented(test.y, incY, extra)
blasser.Dspr2(test.ul, test.n, test.alpha, x, incX, y, incY, aFlat)
ansFlat := flattenTriangular(test.ans, test.ul)
if !floats.EqualApprox(aFlat, ansFlat, 1e-14) {
t.Errorf("Case %v, incX = %v, incY = %v. Want %v, got %v.", i, incX, incY, ansFlat, aFlat)
}
}
incTest(1, 1, 0)
incTest(-2, 1, 0)
incTest(-2, 3, 0)
incTest(2, -3, 0)
incTest(3, -2, 0)
incTest(-3, -4, 0)
}
}

View file

@ -1,277 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
type Dsymmer interface {
Dsymm(s blas.Side, ul blas.Uplo, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
}
func DsymmTest(t *testing.T, blasser Dsymmer) {
for i, test := range []struct {
m int
n int
side blas.Side
ul blas.Uplo
a [][]float64
b [][]float64
c [][]float64
alpha float64
beta float64
ans [][]float64
}{
{
side: blas.Left,
ul: blas.Upper,
m: 3,
n: 4,
a: [][]float64{
{2, 3, 4},
{0, 6, 7},
{0, 0, 10},
},
b: [][]float64{
{2, 3, 4, 8},
{5, 6, 7, 15},
{8, 9, 10, 20},
},
c: [][]float64{
{8, 12, 2, 1},
{9, 12, 9, 9},
{12, 1, -1, 5},
},
alpha: 2,
beta: 3,
ans: [][]float64{
{126, 156, 144, 285},
{211, 252, 275, 535},
{282, 291, 327, 689},
},
},
{
side: blas.Left,
ul: blas.Upper,
m: 4,
n: 3,
a: [][]float64{
{2, 3, 4, 8},
{0, 6, 7, 9},
{0, 0, 10, 10},
{0, 0, 0, 11},
},
b: [][]float64{
{2, 3, 4},
{5, 6, 7},
{8, 9, 10},
{2, 1, 1},
},
c: [][]float64{
{8, 12, 2},
{9, 12, 9},
{12, 1, -1},
{1, 9, 5},
},
alpha: 2,
beta: 3,
ans: [][]float64{
{158, 172, 160},
{247, 270, 293},
{322, 311, 347},
{329, 385, 427},
},
},
{
side: blas.Left,
ul: blas.Lower,
m: 3,
n: 4,
a: [][]float64{
{2, 0, 0},
{3, 6, 0},
{4, 7, 10},
},
b: [][]float64{
{2, 3, 4, 8},
{5, 6, 7, 15},
{8, 9, 10, 20},
},
c: [][]float64{
{8, 12, 2, 1},
{9, 12, 9, 9},
{12, 1, -1, 5},
},
alpha: 2,
beta: 3,
ans: [][]float64{
{126, 156, 144, 285},
{211, 252, 275, 535},
{282, 291, 327, 689},
},
},
{
side: blas.Left,
ul: blas.Lower,
m: 4,
n: 3,
a: [][]float64{
{2, 0, 0, 0},
{3, 6, 0, 0},
{4, 7, 10, 0},
{8, 9, 10, 11},
},
b: [][]float64{
{2, 3, 4},
{5, 6, 7},
{8, 9, 10},
{2, 1, 1},
},
c: [][]float64{
{8, 12, 2},
{9, 12, 9},
{12, 1, -1},
{1, 9, 5},
},
alpha: 2,
beta: 3,
ans: [][]float64{
{158, 172, 160},
{247, 270, 293},
{322, 311, 347},
{329, 385, 427},
},
},
{
side: blas.Right,
ul: blas.Upper,
m: 3,
n: 4,
a: [][]float64{
{2, 0, 0, 0},
{3, 6, 0, 0},
{4, 7, 10, 0},
{3, 4, 5, 6},
},
b: [][]float64{
{2, 3, 4, 9},
{5, 6, 7, -3},
{8, 9, 10, -2},
},
c: [][]float64{
{8, 12, 2, 10},
{9, 12, 9, 10},
{12, 1, -1, 10},
},
alpha: 2,
beta: 3,
ans: [][]float64{
{32, 72, 86, 138},
{47, 108, 167, -6},
{68, 111, 197, 6},
},
},
{
side: blas.Right,
ul: blas.Upper,
m: 4,
n: 3,
a: [][]float64{
{2, 0, 0},
{3, 6, 0},
{4, 7, 10},
},
b: [][]float64{
{2, 3, 4},
{5, 6, 7},
{8, 9, 10},
{2, 1, 1},
},
c: [][]float64{
{8, 12, 2},
{9, 12, 9},
{12, 1, -1},
{1, 9, 5},
},
alpha: 2,
beta: 3,
ans: [][]float64{
{32, 72, 86},
{47, 108, 167},
{68, 111, 197},
{11, 39, 35},
},
},
{
side: blas.Right,
ul: blas.Lower,
m: 3,
n: 4,
a: [][]float64{
{2, 0, 0, 0},
{3, 6, 0, 0},
{4, 7, 10, 0},
{3, 4, 5, 6},
},
b: [][]float64{
{2, 3, 4, 2},
{5, 6, 7, 1},
{8, 9, 10, 1},
},
c: [][]float64{
{8, 12, 2, 1},
{9, 12, 9, 9},
{12, 1, -1, 5},
},
alpha: 2,
beta: 3,
ans: [][]float64{
{94, 156, 164, 103},
{145, 244, 301, 187},
{208, 307, 397, 247},
},
},
{
side: blas.Right,
ul: blas.Lower,
m: 4,
n: 3,
a: [][]float64{
{2, 0, 0},
{3, 6, 0},
{4, 7, 10},
},
b: [][]float64{
{2, 3, 4},
{5, 6, 7},
{8, 9, 10},
{2, 1, 1},
},
c: [][]float64{
{8, 12, 2},
{9, 12, 9},
{12, 1, -1},
{1, 9, 5},
},
alpha: 2,
beta: 3,
ans: [][]float64{
{82, 140, 144},
{139, 236, 291},
{202, 299, 387},
{25, 65, 65},
},
},
} {
aFlat := flatten(test.a)
bFlat := flatten(test.b)
cFlat := flatten(test.c)
ansFlat := flatten(test.ans)
blasser.Dsymm(test.side, test.ul, test.m, test.n, test.alpha, aFlat, len(test.a[0]), bFlat, test.n, test.beta, cFlat, test.n)
if !floats.EqualApprox(cFlat, ansFlat, 1e-14) {
t.Errorf("Case %v: Want %v, got %v.", i, ansFlat, cFlat)
}
}
}

View file

@ -1,73 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
type Dsymver interface {
Dsymv(ul blas.Uplo, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
}
func DsymvTest(t *testing.T, blasser Dsymver) {
for i, test := range []struct {
ul blas.Uplo
n int
a [][]float64
x []float64
y []float64
alpha float64
beta float64
ans []float64
}{
{
ul: blas.Upper,
n: 3,
a: [][]float64{
{5, 6, 7},
{0, 8, 10},
{0, 0, 13},
},
x: []float64{3, 4, 5},
y: []float64{6, 7, 8},
alpha: 2.1,
beta: -3,
ans: []float64{137.4, 189, 240.6},
},
{
ul: blas.Lower,
n: 3,
a: [][]float64{
{5, 0, 0},
{6, 8, 0},
{7, 10, 13},
},
x: []float64{3, 4, 5},
y: []float64{6, 7, 8},
alpha: 2.1,
beta: -3,
ans: []float64{137.4, 189, 240.6},
},
} {
incTest := func(incX, incY, extra int) {
x := makeIncremented(test.x, incX, extra)
y := makeIncremented(test.y, incY, extra)
aFlat := flatten(test.a)
ans := makeIncremented(test.ans, incY, extra)
blasser.Dsymv(test.ul, test.n, test.alpha, aFlat, test.n, x, incX, test.beta, y, incY)
if !floats.EqualApprox(ans, y, 1e-14) {
t.Errorf("Case %v, incX=%v, incY=%v: Want %v, got %v.", i, incX, incY, ans, y)
}
}
incTest(1, 1, 0)
incTest(2, 3, 0)
incTest(3, 2, 0)
incTest(-3, 2, 0)
incTest(-2, 4, 0)
incTest(2, -1, 0)
incTest(-3, -4, 3)
}
}

View file

@ -1,72 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
type Dsyrer interface {
Dsyr(ul blas.Uplo, n int, alpha float64, x []float64, incX int, a []float64, lda int)
}
func DsyrTest(t *testing.T, blasser Dsyrer) {
for i, test := range []struct {
ul blas.Uplo
n int
a [][]float64
x []float64
alpha float64
ans [][]float64
}{
{
ul: blas.Upper,
n: 4,
a: [][]float64{
{10, 2, 0, 1},
{0, 1, 2, 3},
{0, 0, 9, 15},
{0, 0, 0, -6},
},
x: []float64{1, 2, 0, 5},
alpha: 8,
ans: [][]float64{
{18, 18, 0, 41},
{0, 33, 2, 83},
{0, 0, 9, 15},
{0, 0, 0, 194},
},
},
{
ul: blas.Lower,
n: 3,
a: [][]float64{
{10, 2, 0},
{4, 1, 2},
{2, 7, 9},
},
x: []float64{3, 0, 5},
alpha: 8,
ans: [][]float64{
{82, 2, 0},
{4, 1, 2},
{122, 7, 209},
},
},
} {
incTest := func(incX, extra int) {
xnew := makeIncremented(test.x, incX, extra)
aFlat := flatten(test.a)
ans := flatten(test.ans)
lda := test.n
blasser.Dsyr(test.ul, test.n, test.alpha, xnew, incX, aFlat, lda)
if !dSliceTolEqual(aFlat, ans) {
t.Errorf("Case %v, idx %v: Want %v, got %v.", i, incX, ans, aFlat)
}
}
incTest(1, 3)
incTest(1, 0)
incTest(3, 2)
incTest(-2, 2)
}
}

View file

@ -1,76 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
type Dsyr2er interface {
Dsyr2(ul blas.Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int)
}
func Dsyr2Test(t *testing.T, blasser Dsyr2er) {
for i, test := range []struct {
n int
a [][]float64
ul blas.Uplo
x []float64
y []float64
alpha float64
ans [][]float64
}{
{
n: 3,
a: [][]float64{
{7, 2, 4},
{0, 3, 5},
{0, 0, 6},
},
x: []float64{2, 3, 4},
y: []float64{5, 6, 7},
alpha: 2,
ul: blas.Upper,
ans: [][]float64{
{47, 56, 72},
{0, 75, 95},
{0, 0, 118},
},
},
{
n: 3,
a: [][]float64{
{7, 0, 0},
{2, 3, 0},
{4, 5, 6},
},
x: []float64{2, 3, 4},
y: []float64{5, 6, 7},
alpha: 2,
ul: blas.Lower,
ans: [][]float64{
{47, 0, 0},
{56, 75, 0},
{72, 95, 118},
},
},
} {
incTest := func(incX, incY, extra int) {
aFlat := flatten(test.a)
x := makeIncremented(test.x, incX, extra)
y := makeIncremented(test.y, incY, extra)
blasser.Dsyr2(test.ul, test.n, test.alpha, x, incX, y, incY, aFlat, test.n)
ansFlat := flatten(test.ans)
if !floats.EqualApprox(aFlat, ansFlat, 1e-14) {
t.Errorf("Case %v, incX = %v, incY = %v. Want %v, got %v.", i, incX, incY, ansFlat, aFlat)
}
}
incTest(1, 1, 0)
incTest(-2, 1, 0)
incTest(-2, 3, 0)
incTest(2, -3, 0)
incTest(3, -2, 0)
incTest(-3, -4, 0)
}
}

View file

@ -1,201 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
type Dsyr2ker interface {
Dsyr2k(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
}
func Dsyr2kTest(t *testing.T, blasser Dsyr2ker) {
for i, test := range []struct {
ul blas.Uplo
tA blas.Transpose
n int
k int
alpha float64
a [][]float64
b [][]float64
c [][]float64
beta float64
ans [][]float64
}{
{
ul: blas.Upper,
tA: blas.NoTrans,
n: 3,
k: 2,
alpha: 0,
a: [][]float64{
{1, 2},
{3, 4},
{5, 6},
},
b: [][]float64{
{7, 8},
{9, 10},
{11, 12},
},
c: [][]float64{
{1, 2, 3},
{0, 5, 6},
{0, 0, 9},
},
beta: 2,
ans: [][]float64{
{2, 4, 6},
{0, 10, 12},
{0, 0, 18},
},
},
{
ul: blas.Lower,
tA: blas.NoTrans,
n: 3,
k: 2,
alpha: 0,
a: [][]float64{
{1, 2},
{3, 4},
{5, 6},
},
b: [][]float64{
{7, 8},
{9, 10},
{11, 12},
},
c: [][]float64{
{1, 0, 0},
{2, 3, 0},
{4, 5, 6},
},
beta: 2,
ans: [][]float64{
{2, 0, 0},
{4, 6, 0},
{8, 10, 12},
},
},
{
ul: blas.Upper,
tA: blas.NoTrans,
n: 3,
k: 2,
alpha: 3,
a: [][]float64{
{1, 2},
{3, 4},
{5, 6},
},
b: [][]float64{
{7, 8},
{9, 10},
{11, 12},
},
c: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 6},
},
beta: 2,
ans: [][]float64{
{140, 250, 360},
{0, 410, 568},
{0, 0, 774},
},
},
{
ul: blas.Lower,
tA: blas.NoTrans,
n: 3,
k: 2,
alpha: 3,
a: [][]float64{
{1, 2},
{3, 4},
{5, 6},
},
b: [][]float64{
{7, 8},
{9, 10},
{11, 12},
},
c: [][]float64{
{1, 0, 0},
{2, 4, 0},
{3, 5, 6},
},
beta: 2,
ans: [][]float64{
{140, 0, 0},
{250, 410, 0},
{360, 568, 774},
},
},
{
ul: blas.Upper,
tA: blas.Trans,
n: 3,
k: 2,
alpha: 3,
a: [][]float64{
{1, 3, 5},
{2, 4, 6},
},
b: [][]float64{
{7, 9, 11},
{8, 10, 12},
},
c: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 6},
},
beta: 2,
ans: [][]float64{
{140, 250, 360},
{0, 410, 568},
{0, 0, 774},
},
},
{
ul: blas.Lower,
tA: blas.Trans,
n: 3,
k: 2,
alpha: 3,
a: [][]float64{
{1, 3, 5},
{2, 4, 6},
},
b: [][]float64{
{7, 9, 11},
{8, 10, 12},
},
c: [][]float64{
{1, 0, 0},
{2, 4, 0},
{3, 5, 6},
},
beta: 2,
ans: [][]float64{
{140, 0, 0},
{250, 410, 0},
{360, 568, 774},
},
},
} {
aFlat := flatten(test.a)
bFlat := flatten(test.b)
cFlat := flatten(test.c)
ansFlat := flatten(test.ans)
blasser.Dsyr2k(test.ul, test.tA, test.n, test.k, test.alpha, aFlat, len(test.a[0]), bFlat, len(test.b[0]), test.beta, cFlat, len(test.c[0]))
if !floats.EqualApprox(ansFlat, cFlat, 1e-14) {
t.Errorf("Case %v. Want %v, got %v.", i, ansFlat, cFlat)
}
}
}

View file

@ -1,171 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
type Dsyker interface {
Dsyrk(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float64, a []float64, lda int, beta float64, c []float64, ldc int)
}
func DsyrkTest(t *testing.T, blasser Dsyker) {
for i, test := range []struct {
ul blas.Uplo
tA blas.Transpose
n int
k int
alpha float64
a [][]float64
c [][]float64
beta float64
ans [][]float64
}{
{
ul: blas.Upper,
tA: blas.NoTrans,
n: 3,
k: 2,
alpha: 0,
a: [][]float64{
{1, 2},
{3, 4},
{5, 6},
},
c: [][]float64{
{1, 2, 3},
{0, 5, 6},
{0, 0, 9},
},
beta: 2,
ans: [][]float64{
{2, 4, 6},
{0, 10, 12},
{0, 0, 18},
},
},
{
ul: blas.Lower,
tA: blas.NoTrans,
n: 3,
k: 2,
alpha: 0,
a: [][]float64{
{1, 2},
{3, 4},
{5, 6},
},
c: [][]float64{
{1, 0, 0},
{2, 3, 0},
{4, 5, 6},
},
beta: 2,
ans: [][]float64{
{2, 0, 0},
{4, 6, 0},
{8, 10, 12},
},
},
{
ul: blas.Upper,
tA: blas.NoTrans,
n: 3,
k: 2,
alpha: 3,
a: [][]float64{
{1, 2},
{3, 4},
{5, 6},
},
c: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 6},
},
beta: 2,
ans: [][]float64{
{17, 37, 57},
{0, 83, 127},
{0, 0, 195},
},
},
{
ul: blas.Lower,
tA: blas.NoTrans,
n: 3,
k: 2,
alpha: 3,
a: [][]float64{
{1, 2},
{3, 4},
{5, 6},
},
c: [][]float64{
{1, 0, 0},
{2, 4, 0},
{3, 5, 6},
},
beta: 2,
ans: [][]float64{
{17, 0, 0},
{37, 83, 0},
{57, 127, 195},
},
},
{
ul: blas.Upper,
tA: blas.Trans,
n: 3,
k: 2,
alpha: 3,
a: [][]float64{
{1, 3, 5},
{2, 4, 6},
},
c: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 6},
},
beta: 2,
ans: [][]float64{
{17, 37, 57},
{0, 83, 127},
{0, 0, 195},
},
},
{
ul: blas.Lower,
tA: blas.Trans,
n: 3,
k: 2,
alpha: 3,
a: [][]float64{
{1, 3, 5},
{2, 4, 6},
},
c: [][]float64{
{1, 0, 0},
{2, 4, 0},
{3, 5, 6},
},
beta: 2,
ans: [][]float64{
{17, 0, 0},
{37, 83, 0},
{57, 127, 195},
},
},
} {
aFlat := flatten(test.a)
cFlat := flatten(test.c)
ansFlat := flatten(test.ans)
blasser.Dsyrk(test.ul, test.tA, test.n, test.k, test.alpha, aFlat, len(test.a[0]), test.beta, cFlat, len(test.c[0]))
if !floats.EqualApprox(ansFlat, cFlat, 1e-14) {
t.Errorf("Case %v. Want %v, got %v.", i, ansFlat, cFlat)
}
}
}

View file

@ -1,123 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
type Dtbmver interface {
Dtbmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n, k int, a []float64, lda int, x []float64, incX int)
}
func DtbmvTest(t *testing.T, blasser Dtbmver) {
for i, test := range []struct {
ul blas.Uplo
tA blas.Transpose
d blas.Diag
n int
k int
a [][]float64
x []float64
ans []float64
}{
{
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.Unit,
n: 3,
k: 1,
a: [][]float64{
{1, 2, 0},
{0, 1, 4},
{0, 0, 1},
},
x: []float64{2, 3, 4},
ans: []float64{8, 19, 4},
},
{
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
n: 5,
k: 1,
a: [][]float64{
{1, 3, 0, 0, 0},
{0, 6, 7, 0, 0},
{0, 0, 2, 1, 0},
{0, 0, 0, 12, 3},
{0, 0, 0, 0, -1},
},
x: []float64{1, 2, 3, 4, 5},
ans: []float64{7, 33, 10, 63, -5},
},
{
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.NonUnit,
n: 5,
k: 1,
a: [][]float64{
{7, 0, 0, 0, 0},
{3, 6, 0, 0, 0},
{0, 7, 2, 0, 0},
{0, 0, 1, 12, 0},
{0, 0, 0, 3, -1},
},
x: []float64{1, 2, 3, 4, 5},
ans: []float64{7, 15, 20, 51, 7},
},
{
ul: blas.Upper,
tA: blas.Trans,
d: blas.NonUnit,
n: 5,
k: 2,
a: [][]float64{
{7, 3, 9, 0, 0},
{0, 6, 7, 10, 0},
{0, 0, 2, 1, 11},
{0, 0, 0, 12, 3},
{0, 0, 0, 0, -1},
},
x: []float64{1, 2, 3, 4, 5},
ans: []float64{7, 15, 29, 71, 40},
},
{
ul: blas.Lower,
tA: blas.Trans,
d: blas.NonUnit,
n: 5,
k: 2,
a: [][]float64{
{7, 0, 0, 0, 0},
{3, 6, 0, 0, 0},
{9, 7, 2, 0, 0},
{0, 10, 1, 12, 0},
{0, 0, 11, 3, -1},
},
x: []float64{1, 2, 3, 4, 5},
ans: []float64{40, 73, 65, 63, -5},
},
} {
extra := 0
var aFlat []float64
if test.ul == blas.Upper {
aFlat = flattenBanded(test.a, test.k, 0)
} else {
aFlat = flattenBanded(test.a, 0, test.k)
}
incTest := func(incX, extra int) {
xnew := makeIncremented(test.x, incX, extra)
ans := makeIncremented(test.ans, incX, extra)
lda := test.k + 1
blasser.Dtbmv(test.ul, test.tA, test.d, test.n, test.k, aFlat, lda, xnew, incX)
if !dSliceTolEqual(ans, xnew) {
t.Errorf("Case %v, Inc %v: Want %v, got %v", i, incX, ans, xnew)
}
}
incTest(1, extra)
incTest(3, extra)
incTest(-2, extra)
}
}

View file

@ -1,256 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
type Dtbsver interface {
Dtbsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n, k int, a []float64, lda int, x []float64, incX int)
Dtrsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int)
}
func DtbsvTest(t *testing.T, blasser Dtbsver) {
for i, test := range []struct {
ul blas.Uplo
tA blas.Transpose
d blas.Diag
n, k int
a [][]float64
lda int
x []float64
incX int
ans []float64
}{
{
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
n: 5,
k: 1,
a: [][]float64{
{1, 3, 0, 0, 0},
{0, 6, 7, 0, 0},
{0, 0, 2, 1, 0},
{0, 0, 0, 12, 3},
{0, 0, 0, 0, -1},
},
x: []float64{1, 2, 3, 4, 5},
incX: 1,
ans: []float64{2.479166666666667, -0.493055555555556, 0.708333333333333, 1.583333333333333, -5.000000000000000},
},
{
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
n: 5,
k: 2,
a: [][]float64{
{1, 3, 5, 0, 0},
{0, 6, 7, 5, 0},
{0, 0, 2, 1, 5},
{0, 0, 0, 12, 3},
{0, 0, 0, 0, -1},
},
x: []float64{1, 2, 3, 4, 5},
incX: 1,
ans: []float64{-15.854166666666664, -16.395833333333336, 13.208333333333334, 1.583333333333333, -5.000000000000000},
},
{
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
n: 5,
k: 1,
a: [][]float64{
{1, 3, 0, 0, 0},
{0, 6, 7, 0, 0},
{0, 0, 2, 1, 0},
{0, 0, 0, 12, 3},
{0, 0, 0, 0, -1},
},
x: []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
incX: 2,
ans: []float64{2.479166666666667, -101, -0.493055555555556, -201, 0.708333333333333, -301, 1.583333333333333, -401, -5.000000000000000, -501, -601, -701},
},
{
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
n: 5,
k: 2,
a: [][]float64{
{1, 3, 5, 0, 0},
{0, 6, 7, 5, 0},
{0, 0, 2, 1, 5},
{0, 0, 0, 12, 3},
{0, 0, 0, 0, -1},
},
x: []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
incX: 2,
ans: []float64{-15.854166666666664, -101, -16.395833333333336, -201, 13.208333333333334, -301, 1.583333333333333, -401, -5.000000000000000, -501, -601, -701},
},
{
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.NonUnit,
n: 5,
k: 2,
a: [][]float64{
{1, 0, 0, 0, 0},
{3, 6, 0, 0, 0},
{5, 7, 2, 0, 0},
{0, 5, 1, 12, 0},
{0, 0, 5, 3, -1},
},
x: []float64{1, 2, 3, 4, 5},
incX: 1,
ans: []float64{1, -0.166666666666667, -0.416666666666667, 0.437500000000000, -5.770833333333334},
},
{
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.NonUnit,
n: 5,
k: 2,
a: [][]float64{
{1, 0, 0, 0, 0},
{3, 6, 0, 0, 0},
{5, 7, 2, 0, 0},
{0, 5, 1, 12, 0},
{0, 0, 5, 3, -1},
},
x: []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
incX: 2,
ans: []float64{1, -101, -0.166666666666667, -201, -0.416666666666667, -301, 0.437500000000000, -401, -5.770833333333334, -501, -601, -701},
},
{
ul: blas.Upper,
tA: blas.Trans,
d: blas.NonUnit,
n: 5,
k: 2,
a: [][]float64{
{1, 3, 5, 0, 0},
{0, 6, 7, 5, 0},
{0, 0, 2, 1, 5},
{0, 0, 0, 12, 3},
{0, 0, 0, 0, -1},
},
x: []float64{1, 2, 3, 4, 5},
incX: 1,
ans: []float64{1, -0.166666666666667, -0.416666666666667, 0.437500000000000, -5.770833333333334},
},
{
ul: blas.Upper,
tA: blas.Trans,
d: blas.NonUnit,
n: 5,
k: 2,
a: [][]float64{
{1, 3, 5, 0, 0},
{0, 6, 7, 5, 0},
{0, 0, 2, 1, 5},
{0, 0, 0, 12, 3},
{0, 0, 0, 0, -1},
},
x: []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
incX: 2,
ans: []float64{1, -101, -0.166666666666667, -201, -0.416666666666667, -301, 0.437500000000000, -401, -5.770833333333334, -501, -601, -701},
},
{
ul: blas.Lower,
tA: blas.Trans,
d: blas.NonUnit,
n: 5,
k: 2,
a: [][]float64{
{1, 0, 0, 0, 0},
{3, 6, 0, 0, 0},
{5, 7, 2, 0, 0},
{0, 5, 1, 12, 0},
{0, 0, 5, 3, -1},
},
x: []float64{1, 2, 3, 4, 5},
incX: 1,
ans: []float64{-15.854166666666664, -16.395833333333336, 13.208333333333334, 1.583333333333333, -5.000000000000000},
},
{
ul: blas.Lower,
tA: blas.Trans,
d: blas.NonUnit,
n: 5,
k: 2,
a: [][]float64{
{1, 0, 0, 0, 0},
{3, 6, 0, 0, 0},
{5, 7, 2, 0, 0},
{0, 5, 1, 12, 0},
{0, 0, 5, 3, -1},
},
x: []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
incX: 2,
ans: []float64{-15.854166666666664, -101, -16.395833333333336, -201, 13.208333333333334, -301, 1.583333333333333, -401, -5.000000000000000, -501, -601, -701},
},
} {
var aFlat []float64
if test.ul == blas.Upper {
aFlat = flattenBanded(test.a, test.k, 0)
} else {
aFlat = flattenBanded(test.a, 0, test.k)
}
xCopy := sliceCopy(test.x)
// TODO: Have tests where the banded matrix is constructed explicitly
// to allow testing for lda =! k+1
blasser.Dtbsv(test.ul, test.tA, test.d, test.n, test.k, aFlat, test.k+1, xCopy, test.incX)
if !dSliceTolEqual(test.ans, xCopy) {
t.Errorf("Case %v: Want %v, got %v", i, test.ans, xCopy)
}
}
/*
// TODO: Uncomment when Dtrsv is fixed
// Compare with dense for larger matrices
for _, ul := range [...]blas.Uplo{blas.Upper, blas.Lower} {
for _, tA := range [...]blas.Transpose{blas.NoTrans, blas.Trans} {
for _, n := range [...]int{7, 8, 11} {
for _, d := range [...]blas.Diag{blas.NonUnit, blas.Unit} {
for _, k := range [...]int{0, 1, 3} {
for _, incX := range [...]int{1, 3} {
a := make([][]float64, n)
for i := range a {
a[i] = make([]float64, n)
for j := range a[i] {
a[i][j] = rand.Float64()
}
}
x := make([]float64, n)
for i := range x {
x[i] = rand.Float64()
}
extra := 3
xinc := makeIncremented(x, incX, extra)
bandX := sliceCopy(xinc)
var aFlatBand []float64
if ul == blas.Upper {
aFlatBand = flattenBanded(a, k, 0)
} else {
aFlatBand = flattenBanded(a, 0, k)
}
blasser.Dtbsv(ul, tA, d, n, k, aFlatBand, k+1, bandX, incX)
aFlatDense := flatten(a)
denseX := sliceCopy(xinc)
blasser.Dtrsv(ul, tA, d, n, aFlatDense, n, denseX, incX)
if !dSliceTolEqual(denseX, bandX) {
t.Errorf("Case %v: dense banded mismatch")
}
}
}
}
}
}
}
*/
}

View file

@ -1,129 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
type Dtpmver interface {
Dtpmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, ap []float64, x []float64, incX int)
}
func DtpmvTest(t *testing.T, blasser Dtpmver) {
for i, test := range []struct {
n int
a [][]float64
x []float64
d blas.Diag
ul blas.Uplo
tA blas.Transpose
ans []float64
}{
{
n: 3,
a: [][]float64{
{5, 6, 7},
{0, 9, 10},
{0, 0, 13},
},
x: []float64{3, 4, 5},
d: blas.NonUnit,
ul: blas.Upper,
tA: blas.NoTrans,
ans: []float64{74, 86, 65},
},
{
n: 3,
a: [][]float64{
{5, 6, 7},
{0, 9, 10},
{0, 0, 13},
},
x: []float64{3, 4, 5},
d: blas.Unit,
ul: blas.Upper,
tA: blas.NoTrans,
ans: []float64{62, 54, 5},
},
{
n: 3,
a: [][]float64{
{5, 0, 0},
{6, 9, 0},
{7, 10, 13},
},
x: []float64{3, 4, 5},
d: blas.NonUnit,
ul: blas.Lower,
tA: blas.NoTrans,
ans: []float64{15, 54, 126},
},
{
n: 3,
a: [][]float64{
{1, 0, 0},
{6, 1, 0},
{7, 10, 1},
},
x: []float64{3, 4, 5},
d: blas.Unit,
ul: blas.Lower,
tA: blas.NoTrans,
ans: []float64{3, 22, 66},
},
{
n: 3,
a: [][]float64{
{5, 6, 7},
{0, 9, 10},
{0, 0, 13},
},
x: []float64{3, 4, 5},
d: blas.NonUnit,
ul: blas.Upper,
tA: blas.Trans,
ans: []float64{15, 54, 126},
},
{
n: 3,
a: [][]float64{
{1, 6, 7},
{0, 1, 10},
{0, 0, 1},
},
x: []float64{3, 4, 5},
d: blas.Unit,
ul: blas.Upper,
tA: blas.Trans,
ans: []float64{3, 22, 66},
},
{
n: 3,
a: [][]float64{
{5, 0, 0},
{6, 9, 0},
{7, 10, 13},
},
x: []float64{3, 4, 5},
d: blas.NonUnit,
ul: blas.Lower,
tA: blas.Trans,
ans: []float64{74, 86, 65},
},
} {
incTest := func(incX, extra int) {
aFlat := flattenTriangular(test.a, test.ul)
x := makeIncremented(test.x, incX, extra)
blasser.Dtpmv(test.ul, test.tA, test.d, test.n, aFlat, x, incX)
ans := makeIncremented(test.ans, incX, extra)
if !floats.EqualApprox(x, ans, 1e-14) {
t.Errorf("Case %v, idx %v: Want %v, got %v.", i, incX, ans, x)
}
}
incTest(1, 0)
incTest(-3, 3)
incTest(4, 3)
}
}

View file

@ -1,144 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
type Dtpsver interface {
Dtpsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, ap []float64, x []float64, incX int)
}
func DtpsvTest(t *testing.T, blasser Dtpsver) {
for i, test := range []struct {
n int
a [][]float64
ul blas.Uplo
tA blas.Transpose
d blas.Diag
x []float64
ans []float64
}{
{
n: 3,
a: [][]float64{
{1, 2, 3},
{0, 8, 15},
{0, 0, 8},
},
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
x: []float64{5, 6, 7},
ans: []float64{4.15625, -0.890625, 0.875},
},
{
n: 3,
a: [][]float64{
{1, 2, 3},
{0, 1, 15},
{0, 0, 1},
},
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.Unit,
x: []float64{5, 6, 7},
ans: []float64{182, -99, 7},
},
{
n: 3,
a: [][]float64{
{1, 0, 0},
{2, 8, 0},
{3, 15, 8},
},
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.NonUnit,
x: []float64{5, 6, 7},
ans: []float64{5, -0.5, -0.0625},
},
{
n: 3,
a: [][]float64{
{1, 0, 0},
{2, 8, 0},
{3, 15, 8},
},
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.Unit,
x: []float64{5, 6, 7},
ans: []float64{5, -4, 52},
},
{
n: 3,
a: [][]float64{
{1, 2, 3},
{0, 8, 15},
{0, 0, 8},
},
ul: blas.Upper,
tA: blas.Trans,
d: blas.NonUnit,
x: []float64{5, 6, 7},
ans: []float64{5, -0.5, -0.0625},
},
{
n: 3,
a: [][]float64{
{1, 2, 3},
{0, 8, 15},
{0, 0, 8},
},
ul: blas.Upper,
tA: blas.Trans,
d: blas.Unit,
x: []float64{5, 6, 7},
ans: []float64{5, -4, 52},
},
{
n: 3,
a: [][]float64{
{1, 0, 0},
{2, 8, 0},
{3, 15, 8},
},
ul: blas.Lower,
tA: blas.Trans,
d: blas.NonUnit,
x: []float64{5, 6, 7},
ans: []float64{4.15625, -0.890625, 0.875},
},
{
n: 3,
a: [][]float64{
{1, 0, 0},
{2, 1, 0},
{3, 15, 1},
},
ul: blas.Lower,
tA: blas.Trans,
d: blas.Unit,
x: []float64{5, 6, 7},
ans: []float64{182, -99, 7},
},
} {
incTest := func(incX, extra int) {
aFlat := flattenTriangular(test.a, test.ul)
x := makeIncremented(test.x, incX, extra)
blasser.Dtpsv(test.ul, test.tA, test.d, test.n, aFlat, x, incX)
ans := makeIncremented(test.ans, incX, extra)
if !floats.EqualApprox(x, ans, 1e-14) {
t.Errorf("Case %v, incX = %v: Want %v, got %v.", i, incX, ans, x)
}
}
incTest(1, 0)
incTest(-2, 0)
incTest(3, 0)
incTest(-3, 8)
incTest(4, 2)
}
}

View file

@ -1,806 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
type Dtrmmer interface {
Dtrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int)
}
func DtrmmTest(t *testing.T, blasser Dtrmmer) {
for i, test := range []struct {
s blas.Side
ul blas.Uplo
tA blas.Transpose
d blas.Diag
m int
n int
alpha float64
a [][]float64
b [][]float64
ans [][]float64
}{
{
s: blas.Left,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2, 3, 4},
{0, 5, 6, 7},
{0, 0, 8, 9},
{0, 0, 0, 10},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{320, 340, 360},
{588, 624, 660},
{598, 632, 666},
{380, 400, 420},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2},
{0, 5},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{72, 78, 84},
{130, 140, 150},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.Unit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2, 3, 4},
{0, 5, 6, 7},
{0, 0, 8, 9},
{0, 0, 0, 10},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{320, 340, 360},
{484, 512, 540},
{374, 394, 414},
{38, 40, 42},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.Unit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2},
{0, 5},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{72, 78, 84},
{26, 28, 30},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0, 0, 0},
{2, 5, 0, 0},
{3, 6, 8, 0},
{4, 7, 9, 10},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{20, 22, 24},
{170, 184, 198},
{472, 506, 540},
{930, 990, 1050},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0},
{2, 5},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{20, 22, 24},
{170, 184, 198},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.Unit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0, 0, 0},
{2, 5, 0, 0},
{3, 6, 8, 0},
{4, 7, 9, 10},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{20, 22, 24},
{66, 72, 78},
{248, 268, 288},
{588, 630, 672},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.Unit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0},
{2, 5},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{20, 22, 24},
{66, 72, 78},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.Trans,
d: blas.NonUnit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2, 3, 4},
{0, 5, 6, 7},
{0, 0, 8, 9},
{0, 0, 0, 10},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{20, 22, 24},
{170, 184, 198},
{472, 506, 540},
{930, 990, 1050},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.Trans,
d: blas.NonUnit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2},
{0, 5},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{20, 22, 24},
{170, 184, 198},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.Trans,
d: blas.Unit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2, 3, 4},
{0, 5, 6, 7},
{0, 0, 8, 9},
{0, 0, 0, 10},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{20, 22, 24},
{66, 72, 78},
{248, 268, 288},
{588, 630, 672},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.Trans,
d: blas.Unit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2},
{0, 5},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{20, 22, 24},
{66, 72, 78},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.Trans,
d: blas.NonUnit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0, 0, 0},
{2, 5, 0, 0},
{3, 6, 8, 0},
{4, 7, 9, 10},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{320, 340, 360},
{588, 624, 660},
{598, 632, 666},
{380, 400, 420},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.Trans,
d: blas.NonUnit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0},
{2, 5},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{72, 78, 84},
{130, 140, 150},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.Trans,
d: blas.Unit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0, 0, 0},
{2, 5, 0, 0},
{3, 6, 8, 0},
{4, 7, 9, 10},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{320, 340, 360},
{484, 512, 540},
{374, 394, 414},
{38, 40, 42},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.Trans,
d: blas.Unit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0},
{2, 5},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{72, 78, 84},
{26, 28, 30},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{20, 128, 314},
{26, 164, 398},
{32, 200, 482},
{38, 236, 566},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{20, 128, 314},
{26, 164, 398},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.Unit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{20, 62, 194},
{26, 80, 248},
{32, 98, 302},
{38, 116, 356},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.Unit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{20, 62, 194},
{26, 80, 248},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0, 0},
{2, 4, 0},
{3, 5, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{136, 208, 144},
{172, 262, 180},
{208, 316, 216},
{244, 370, 252},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0, 0},
{2, 4, 0},
{3, 5, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{136, 208, 144},
{172, 262, 180},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.Unit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0, 0},
{2, 4, 0},
{3, 5, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{136, 142, 24},
{172, 178, 30},
{208, 214, 36},
{244, 250, 42},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.Unit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0, 0},
{2, 4, 0},
{3, 5, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{136, 142, 24},
{172, 178, 30},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.Trans,
d: blas.NonUnit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{136, 208, 144},
{172, 262, 180},
{208, 316, 216},
{244, 370, 252},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.Trans,
d: blas.NonUnit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{136, 208, 144},
{172, 262, 180},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.Trans,
d: blas.Unit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{136, 142, 24},
{172, 178, 30},
{208, 214, 36},
{244, 250, 42},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.Trans,
d: blas.Unit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{136, 142, 24},
{172, 178, 30},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.Trans,
d: blas.NonUnit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0, 0},
{2, 4, 0},
{3, 5, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{20, 128, 314},
{26, 164, 398},
{32, 200, 482},
{38, 236, 566},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.Trans,
d: blas.NonUnit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0, 0},
{2, 4, 0},
{3, 5, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{20, 128, 314},
{26, 164, 398},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.Trans,
d: blas.Unit,
m: 4,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0, 0},
{2, 4, 0},
{3, 5, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{20, 62, 194},
{26, 80, 248},
{32, 98, 302},
{38, 116, 356},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.Trans,
d: blas.Unit,
m: 2,
n: 3,
alpha: 2,
a: [][]float64{
{1, 0, 0},
{2, 4, 0},
{3, 5, 6},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{20, 62, 194},
{26, 80, 248},
},
},
} {
aFlat := flatten(test.a)
bFlat := flatten(test.b)
ansFlat := flatten(test.ans)
blasser.Dtrmm(test.s, test.ul, test.tA, test.d, test.m, test.n, test.alpha, aFlat, len(test.a[0]), bFlat, len(test.b[0]))
if !floats.EqualApprox(ansFlat, bFlat, 1e-14) {
t.Errorf("Case %v. Want %v, got %v.", i, ansFlat, bFlat)
}
}
}

View file

@ -1,129 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
type Dtrmver interface {
Dtrmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int)
}
func DtrmvTest(t *testing.T, blasser Dtrmver) {
for i, test := range []struct {
n int
a [][]float64
x []float64
d blas.Diag
ul blas.Uplo
tA blas.Transpose
ans []float64
}{
{
n: 3,
a: [][]float64{
{5, 6, 7},
{0, 9, 10},
{0, 0, 13},
},
x: []float64{3, 4, 5},
d: blas.NonUnit,
ul: blas.Upper,
tA: blas.NoTrans,
ans: []float64{74, 86, 65},
},
{
n: 3,
a: [][]float64{
{5, 6, 7},
{0, 9, 10},
{0, 0, 13},
},
x: []float64{3, 4, 5},
d: blas.Unit,
ul: blas.Upper,
tA: blas.NoTrans,
ans: []float64{62, 54, 5},
},
{
n: 3,
a: [][]float64{
{5, 0, 0},
{6, 9, 0},
{7, 10, 13},
},
x: []float64{3, 4, 5},
d: blas.NonUnit,
ul: blas.Lower,
tA: blas.NoTrans,
ans: []float64{15, 54, 126},
},
{
n: 3,
a: [][]float64{
{1, 0, 0},
{6, 1, 0},
{7, 10, 1},
},
x: []float64{3, 4, 5},
d: blas.Unit,
ul: blas.Lower,
tA: blas.NoTrans,
ans: []float64{3, 22, 66},
},
{
n: 3,
a: [][]float64{
{5, 6, 7},
{0, 9, 10},
{0, 0, 13},
},
x: []float64{3, 4, 5},
d: blas.NonUnit,
ul: blas.Upper,
tA: blas.Trans,
ans: []float64{15, 54, 126},
},
{
n: 3,
a: [][]float64{
{1, 6, 7},
{0, 1, 10},
{0, 0, 1},
},
x: []float64{3, 4, 5},
d: blas.Unit,
ul: blas.Upper,
tA: blas.Trans,
ans: []float64{3, 22, 66},
},
{
n: 3,
a: [][]float64{
{5, 0, 0},
{6, 9, 0},
{7, 10, 13},
},
x: []float64{3, 4, 5},
d: blas.NonUnit,
ul: blas.Lower,
tA: blas.Trans,
ans: []float64{74, 86, 65},
},
} {
incTest := func(incX, extra int) {
aFlat := flatten(test.a)
x := makeIncremented(test.x, incX, extra)
blasser.Dtrmv(test.ul, test.tA, test.d, test.n, aFlat, test.n, x, incX)
ans := makeIncremented(test.ans, incX, extra)
if !floats.EqualApprox(x, ans, 1e-14) {
t.Errorf("Case %v, idx %v: Want %v, got %v.", i, incX, ans, x)
}
}
incTest(1, 3)
incTest(-3, 3)
incTest(4, 3)
}
}

View file

@ -1,811 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
type Dtrsmer interface {
Dtrsm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int,
alpha float64, a []float64, lda int, b []float64, ldb int)
}
func DtrsmTest(t *testing.T, blasser Dtrsmer) {
for i, test := range []struct {
s blas.Side
ul blas.Uplo
tA blas.Transpose
d blas.Diag
m int
n int
alpha float64
a [][]float64
b [][]float64
ans [][]float64
}{
{
s: blas.Left,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 3,
n: 2,
alpha: 2,
a: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 5},
},
b: [][]float64{
{3, 6},
{4, 7},
{5, 8},
},
ans: [][]float64{
{1, 3.4},
{-0.5, -0.5},
{2, 3.2},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.Unit,
m: 3,
n: 2,
alpha: 2,
a: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 5},
},
b: [][]float64{
{3, 6},
{4, 7},
{5, 8},
},
ans: [][]float64{
{60, 96},
{-42, -66},
{10, 16},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 3,
n: 4,
alpha: 2,
a: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 5},
},
b: [][]float64{
{3, 6, 2, 9},
{4, 7, 1, 3},
{5, 8, 9, 10},
},
ans: [][]float64{
{1, 3.4, 1.2, 13},
{-0.5, -0.5, -4, -3.5},
{2, 3.2, 3.6, 4},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.Unit,
m: 3,
n: 4,
alpha: 2,
a: [][]float64{
{1, 2, 3},
{0, 4, 5},
{0, 0, 5},
},
b: [][]float64{
{3, 6, 2, 9},
{4, 7, 1, 3},
{5, 8, 9, 10},
},
ans: [][]float64{
{60, 96, 126, 146},
{-42, -66, -88, -94},
{10, 16, 18, 20},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 3,
n: 2,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 4, 0},
{5, 6, 7},
},
b: [][]float64{
{3, 6},
{4, 7},
{5, 8},
},
ans: [][]float64{
{4.5, 9},
{-0.375, -1.5},
{-0.75, -12.0 / 7},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.Unit,
m: 3,
n: 2,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 4, 0},
{5, 6, 7},
},
b: [][]float64{
{3, 6},
{4, 7},
{5, 8},
},
ans: [][]float64{
{9, 18},
{-15, -33},
{60, 132},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 3,
n: 4,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 4, 0},
{5, 6, 7},
},
b: [][]float64{
{3, 6, 2, 9},
{4, 7, 1, 3},
{5, 8, 9, 10},
},
ans: [][]float64{
{4.5, 9, 3, 13.5},
{-0.375, -1.5, -1.5, -63.0 / 8},
{-0.75, -12.0 / 7, 3, 39.0 / 28},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.Unit,
m: 3,
n: 4,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 4, 0},
{5, 6, 7},
},
b: [][]float64{
{3, 6, 2, 9},
{4, 7, 1, 3},
{5, 8, 9, 10},
},
ans: [][]float64{
{9, 18, 6, 27},
{-15, -33, -15, -72},
{60, 132, 87, 327},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.Trans,
d: blas.NonUnit,
m: 3,
n: 2,
alpha: 3,
a: [][]float64{
{2, 3, 4},
{0, 5, 6},
{0, 0, 7},
},
b: [][]float64{
{3, 6},
{4, 7},
{5, 8},
},
ans: [][]float64{
{4.5, 9},
{-0.30, -1.2},
{-6.0 / 35, -24.0 / 35},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.Trans,
d: blas.Unit,
m: 3,
n: 2,
alpha: 3,
a: [][]float64{
{2, 3, 4},
{0, 5, 6},
{0, 0, 7},
},
b: [][]float64{
{3, 6},
{4, 7},
{5, 8},
},
ans: [][]float64{
{9, 18},
{-15, -33},
{69, 150},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.Trans,
d: blas.NonUnit,
m: 3,
n: 4,
alpha: 3,
a: [][]float64{
{2, 3, 4},
{0, 5, 6},
{0, 0, 7},
},
b: [][]float64{
{3, 6, 6, 7},
{4, 7, 8, 9},
{5, 8, 10, 11},
},
ans: [][]float64{
{4.5, 9, 9, 10.5},
{-0.3, -1.2, -0.6, -0.9},
{-6.0 / 35, -24.0 / 35, -12.0 / 35, -18.0 / 35},
},
},
{
s: blas.Left,
ul: blas.Upper,
tA: blas.Trans,
d: blas.Unit,
m: 3,
n: 4,
alpha: 3,
a: [][]float64{
{2, 3, 4},
{0, 5, 6},
{0, 0, 7},
},
b: [][]float64{
{3, 6, 6, 7},
{4, 7, 8, 9},
{5, 8, 10, 11},
},
ans: [][]float64{
{9, 18, 18, 21},
{-15, -33, -30, -36},
{69, 150, 138, 165},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.Trans,
d: blas.NonUnit,
m: 3,
n: 2,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 4, 0},
{5, 6, 8},
},
b: [][]float64{
{3, 6},
{4, 7},
{5, 8},
},
ans: [][]float64{
{-0.46875, 0.375},
{0.1875, 0.75},
{1.875, 3},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.Trans,
d: blas.Unit,
m: 3,
n: 2,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 4, 0},
{5, 6, 8},
},
b: [][]float64{
{3, 6},
{4, 7},
{5, 8},
},
ans: [][]float64{
{168, 267},
{-78, -123},
{15, 24},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.Trans,
d: blas.NonUnit,
m: 3,
n: 4,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 4, 0},
{5, 6, 8},
},
b: [][]float64{
{3, 6, 2, 3},
{4, 7, 4, 5},
{5, 8, 6, 7},
},
ans: [][]float64{
{-0.46875, 0.375, -2.0625, -1.78125},
{0.1875, 0.75, -0.375, -0.1875},
{1.875, 3, 2.25, 2.625},
},
},
{
s: blas.Left,
ul: blas.Lower,
tA: blas.Trans,
d: blas.Unit,
m: 3,
n: 4,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 4, 0},
{5, 6, 8},
},
b: [][]float64{
{3, 6, 2, 3},
{4, 7, 4, 5},
{5, 8, 6, 7},
},
ans: [][]float64{
{168, 267, 204, 237},
{-78, -123, -96, -111},
{15, 24, 18, 21},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 4,
n: 3,
alpha: 3,
a: [][]float64{
{2, 3, 4},
{0, 5, 6},
{0, 0, 7},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{15, -2.4, -48.0 / 35},
{19.5, -3.3, -66.0 / 35},
{24, -4.2, -2.4},
{28.5, -5.1, -102.0 / 35},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.Unit,
m: 4,
n: 3,
alpha: 3,
a: [][]float64{
{2, 3, 4},
{0, 5, 6},
{0, 0, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{30, -57, 258},
{39, -75, 339},
{48, -93, 420},
{57, -111, 501},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 2,
n: 3,
alpha: 3,
a: [][]float64{
{2, 3, 4},
{0, 5, 6},
{0, 0, 7},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{15, -2.4, -48.0 / 35},
{19.5, -3.3, -66.0 / 35},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.Unit,
m: 2,
n: 3,
alpha: 3,
a: [][]float64{
{2, 3, 4},
{0, 5, 6},
{0, 0, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{30, -57, 258},
{39, -75, 339},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 4,
n: 3,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 5, 0},
{4, 6, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{4.2, 1.2, 4.5},
{5.775, 1.65, 5.625},
{7.35, 2.1, 6.75},
{8.925, 2.55, 7.875},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.Unit,
m: 4,
n: 3,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 5, 0},
{4, 6, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{435, -183, 36},
{543, -228, 45},
{651, -273, 54},
{759, -318, 63},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.NonUnit,
m: 2,
n: 3,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 5, 0},
{4, 6, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{4.2, 1.2, 4.5},
{5.775, 1.65, 5.625},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.Unit,
m: 2,
n: 3,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 5, 0},
{4, 6, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{435, -183, 36},
{543, -228, 45},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.Trans,
d: blas.NonUnit,
m: 4,
n: 3,
alpha: 3,
a: [][]float64{
{2, 3, 4},
{0, 5, 6},
{0, 0, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{4.2, 1.2, 4.5},
{5.775, 1.65, 5.625},
{7.35, 2.1, 6.75},
{8.925, 2.55, 7.875},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.Trans,
d: blas.Unit,
m: 4,
n: 3,
alpha: 3,
a: [][]float64{
{2, 3, 4},
{0, 5, 6},
{0, 0, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{435, -183, 36},
{543, -228, 45},
{651, -273, 54},
{759, -318, 63},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.Trans,
d: blas.NonUnit,
m: 2,
n: 3,
alpha: 3,
a: [][]float64{
{2, 3, 4},
{0, 5, 6},
{0, 0, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{4.2, 1.2, 4.5},
{5.775, 1.65, 5.625},
},
},
{
s: blas.Right,
ul: blas.Upper,
tA: blas.Trans,
d: blas.Unit,
m: 2,
n: 3,
alpha: 3,
a: [][]float64{
{2, 3, 4},
{0, 5, 6},
{0, 0, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{435, -183, 36},
{543, -228, 45},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.Trans,
d: blas.NonUnit,
m: 4,
n: 3,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 5, 0},
{4, 6, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{15, -2.4, -1.2},
{19.5, -3.3, -1.65},
{24, -4.2, -2.1},
{28.5, -5.1, -2.55},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.Trans,
d: blas.Unit,
m: 4,
n: 3,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 5, 0},
{4, 6, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
{16, 17, 18},
{19, 20, 21},
},
ans: [][]float64{
{30, -57, 258},
{39, -75, 339},
{48, -93, 420},
{57, -111, 501},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.Trans,
d: blas.NonUnit,
m: 2,
n: 3,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 5, 0},
{4, 6, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{15, -2.4, -1.2},
{19.5, -3.3, -1.65},
},
},
{
s: blas.Right,
ul: blas.Lower,
tA: blas.Trans,
d: blas.Unit,
m: 2,
n: 3,
alpha: 3,
a: [][]float64{
{2, 0, 0},
{3, 5, 0},
{4, 6, 8},
},
b: [][]float64{
{10, 11, 12},
{13, 14, 15},
},
ans: [][]float64{
{30, -57, 258},
{39, -75, 339},
},
},
} {
aFlat := flatten(test.a)
bFlat := flatten(test.b)
ansFlat := flatten(test.ans)
var lda int
if test.s == blas.Left {
lda = test.m
} else {
lda = test.n
}
blasser.Dtrsm(test.s, test.ul, test.tA, test.d, test.m, test.n, test.alpha, aFlat, lda, bFlat, test.n)
if !floats.EqualApprox(ansFlat, bFlat, 1e-13) {
t.Errorf("Case %v: Want %v, got %v.", i, ansFlat, bFlat)
}
}
}

View file

@ -1,144 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
type Dtrsver interface {
Dtrsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int)
}
func DtrsvTest(t *testing.T, blasser Dtrsver) {
for i, test := range []struct {
n int
a [][]float64
ul blas.Uplo
tA blas.Transpose
d blas.Diag
x []float64
ans []float64
}{
{
n: 3,
a: [][]float64{
{1, 2, 3},
{0, 8, 15},
{0, 0, 8},
},
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.NonUnit,
x: []float64{5, 6, 7},
ans: []float64{4.15625, -0.890625, 0.875},
},
{
n: 3,
a: [][]float64{
{1, 2, 3},
{0, 1, 15},
{0, 0, 1},
},
ul: blas.Upper,
tA: blas.NoTrans,
d: blas.Unit,
x: []float64{5, 6, 7},
ans: []float64{182, -99, 7},
},
{
n: 3,
a: [][]float64{
{1, 0, 0},
{2, 8, 0},
{3, 15, 8},
},
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.NonUnit,
x: []float64{5, 6, 7},
ans: []float64{5, -0.5, -0.0625},
},
{
n: 3,
a: [][]float64{
{1, 0, 0},
{2, 8, 0},
{3, 15, 8},
},
ul: blas.Lower,
tA: blas.NoTrans,
d: blas.Unit,
x: []float64{5, 6, 7},
ans: []float64{5, -4, 52},
},
{
n: 3,
a: [][]float64{
{1, 2, 3},
{0, 8, 15},
{0, 0, 8},
},
ul: blas.Upper,
tA: blas.Trans,
d: blas.NonUnit,
x: []float64{5, 6, 7},
ans: []float64{5, -0.5, -0.0625},
},
{
n: 3,
a: [][]float64{
{1, 2, 3},
{0, 8, 15},
{0, 0, 8},
},
ul: blas.Upper,
tA: blas.Trans,
d: blas.Unit,
x: []float64{5, 6, 7},
ans: []float64{5, -4, 52},
},
{
n: 3,
a: [][]float64{
{1, 0, 0},
{2, 8, 0},
{3, 15, 8},
},
ul: blas.Lower,
tA: blas.Trans,
d: blas.NonUnit,
x: []float64{5, 6, 7},
ans: []float64{4.15625, -0.890625, 0.875},
},
{
n: 3,
a: [][]float64{
{1, 0, 0},
{2, 1, 0},
{3, 15, 1},
},
ul: blas.Lower,
tA: blas.Trans,
d: blas.Unit,
x: []float64{5, 6, 7},
ans: []float64{182, -99, 7},
},
} {
incTest := func(incX, extra int) {
aFlat := flatten(test.a)
x := makeIncremented(test.x, incX, extra)
blasser.Dtrsv(test.ul, test.tA, test.d, test.n, aFlat, test.n, x, incX)
ans := makeIncremented(test.ans, incX, extra)
if !floats.EqualApprox(x, ans, 1e-14) {
t.Errorf("Case %v, incX = %v: Want %v, got %v.", i, incX, ans, x)
}
}
incTest(1, 0)
incTest(-2, 0)
incTest(3, 0)
incTest(-3, 8)
incTest(4, 2)
}
}

View file

@ -1,145 +0,0 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
type Dtxmver interface {
Dtrmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int)
Dtbmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n, k int, a []float64, lda int, x []float64, incX int)
Dtpmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, x []float64, incX int)
}
type vec struct {
data []float64
inc int
}
var cases = []struct {
n, k int
ul blas.Uplo
d blas.Diag
ldab int
tr, tb, tp []float64
ins []vec
solNoTrans []float64
solTrans []float64
}{
{
n: 3,
k: 1,
ul: blas.Upper,
d: blas.NonUnit,
tr: []float64{1, 2, 0, 0, 3, 4, 0, 0, 5},
tb: []float64{1, 2, 3, 4, 5, 0},
ldab: 2,
tp: []float64{1, 2, 0, 3, 4, 5},
ins: []vec{
{[]float64{2, 3, 4}, 1},
{[]float64{2, 1, 3, 1, 4}, 2},
{[]float64{4, 1, 3, 1, 2}, -2},
},
solNoTrans: []float64{8, 25, 20},
solTrans: []float64{2, 13, 32},
},
{
n: 3,
k: 1,
ul: blas.Upper,
d: blas.Unit,
tr: []float64{1, 2, 0, 0, 3, 4, 0, 0, 5},
tb: []float64{1, 2, 3, 4, 5, 0},
ldab: 2,
tp: []float64{1, 2, 0, 3, 4, 5},
ins: []vec{
{[]float64{2, 3, 4}, 1},
{[]float64{2, 1, 3, 1, 4}, 2},
{[]float64{4, 1, 3, 1, 2}, -2},
},
solNoTrans: []float64{8, 19, 4},
solTrans: []float64{2, 7, 16},
},
{
n: 3,
k: 1,
ul: blas.Lower,
d: blas.NonUnit,
tr: []float64{1, 0, 0, 2, 3, 0, 0, 4, 5},
tb: []float64{0, 1, 2, 3, 4, 5},
ldab: 2,
tp: []float64{1, 2, 3, 0, 4, 5},
ins: []vec{
{[]float64{2, 3, 4}, 1},
{[]float64{2, 1, 3, 1, 4}, 2},
{[]float64{4, 1, 3, 1, 2}, -2},
},
solNoTrans: []float64{2, 13, 32},
solTrans: []float64{8, 25, 20},
},
{
n: 3,
k: 1,
ul: blas.Lower,
d: blas.Unit,
tr: []float64{1, 0, 0, 2, 3, 0, 0, 4, 5},
tb: []float64{0, 1, 2, 3, 4, 5},
ldab: 2,
tp: []float64{1, 2, 3, 0, 4, 5},
ins: []vec{
{[]float64{2, 3, 4}, 1},
{[]float64{2, 1, 3, 1, 4}, 2},
{[]float64{4, 1, 3, 1, 2}, -2},
},
solNoTrans: []float64{2, 7, 16},
solTrans: []float64{8, 19, 4},
},
}
func DtxmvTest(t *testing.T, blasser Dtxmver) {
for nc, c := range cases {
for nx, x := range c.ins {
in := make([]float64, len(x.data))
copy(in, x.data)
blasser.Dtrmv(c.ul, blas.NoTrans, c.d, c.n, c.tr, c.n, in, x.inc)
if !dStridedSliceTolEqual(c.n, in, x.inc, c.solNoTrans, 1) {
t.Error("Wrong Dtrmv result for: NoTrans in Case:", nc, "input:", nx)
}
in = make([]float64, len(x.data))
copy(in, x.data)
blasser.Dtrmv(c.ul, blas.Trans, c.d, c.n, c.tr, c.n, in, x.inc)
if !dStridedSliceTolEqual(c.n, in, x.inc, c.solTrans, 1) {
t.Error("Wrong Dtrmv result for: Trans in Case:", nc, "input:", nx)
}
in = make([]float64, len(x.data))
copy(in, x.data)
blasser.Dtbmv(c.ul, blas.NoTrans, c.d, c.n, c.k, c.tb, c.ldab, in, x.inc)
if !dStridedSliceTolEqual(c.n, in, x.inc, c.solNoTrans, 1) {
t.Error("Wrong Dtbmv result for: NoTrans in Case:", nc, "input:", nx)
}
in = make([]float64, len(x.data))
copy(in, x.data)
blasser.Dtbmv(c.ul, blas.Trans, c.d, c.n, c.k, c.tb, c.ldab, in, x.inc)
if !dStridedSliceTolEqual(c.n, in, x.inc, c.solTrans, 1) {
t.Error("Wrong Dtbmv result for: Trans in Case:", nc, "input:", nx)
}
in = make([]float64, len(x.data))
copy(in, x.data)
blasser.Dtpmv(c.ul, blas.NoTrans, c.d, c.n, c.tp, in, x.inc)
if !dStridedSliceTolEqual(c.n, in, x.inc, c.solNoTrans, 1) {
t.Error("Wrong Dtpmv result for: NoTrans in Case:", nc, "input:", nx)
}
in = make([]float64, len(x.data))
copy(in, x.data)
blasser.Dtpmv(c.ul, blas.Trans, c.d, c.n, c.tp, in, x.inc)
if !dStridedSliceTolEqual(c.n, in, x.inc, c.solTrans, 1) {
t.Error("Wrong Dtpmv result for: Trans in Case:", nc, "input:", nx)
}
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,60 +0,0 @@
package testblas
import (
"math/rand"
"testing"
"github.com/gonum/blas"
)
func DgemvBenchmark(b *testing.B, blasser Dgemver, tA blas.Transpose, m, n, incX, incY int) {
var lenX, lenY int
if tA == blas.NoTrans {
lenX = n
lenY = m
} else {
lenX = m
lenY = n
}
xr := make([]float64, lenX)
for i := range xr {
xr[i] = rand.Float64()
}
x := makeIncremented(xr, incX, 0)
yr := make([]float64, lenY)
for i := range yr {
yr[i] = rand.Float64()
}
y := makeIncremented(yr, incY, 0)
a := make([]float64, m*n)
for i := range a {
a[i] = rand.Float64()
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
blasser.Dgemv(tA, m, n, 2, a, n, x, incX, 3, y, incY)
}
}
func DgerBenchmark(b *testing.B, blasser Dgerer, m, n, incX, incY int) {
xr := make([]float64, m)
for i := range xr {
xr[i] = rand.Float64()
}
x := makeIncremented(xr, incX, 0)
yr := make([]float64, n)
for i := range yr {
yr[i] = rand.Float64()
}
y := makeIncremented(yr, incY, 0)
a := make([]float64, m*n)
for i := range a {
a[i] = rand.Float64()
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
blasser.Dger(m, n, 2, x, incX, y, incY, a, n)
}
}

View file

@ -1,22 +0,0 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//+build !amd64 noasm
package asm
// The extra z parameter is needed because of floats.AddScaledTo
func DaxpyUnitary(alpha float64, x, y, z []float64) {
for i, v := range x {
z[i] = alpha * v + y[i]
}
}
func DaxpyInc(alpha float64, x, y []float64, n, incX, incY, ix, iy uintptr) {
for i := 0; i < int(n); i++ {
y[iy] += alpha * x[ix]
ix += incX
iy += incY
}
}

View file

@ -1,12 +0,0 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//+build !noasm
package asm
// The extra z parameter is needed because of floats.AddScaledTo
func DaxpyUnitary(alpha float64, x, y, z []float64)
func DaxpyInc(alpha float64, x, y []float64, n, incX, incY, ix, iy uintptr)

View file

@ -1,140 +0,0 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// Some of the loop unrolling code is copied from:
// http://golang.org/src/math/big/arith_amd64.s
// which is distributed under these terms:
//
// Copyright (c) 2012 The Go Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//+build !noasm
// TODO(fhs): use textflag.h after we drop Go 1.3 support
//#include "textflag.h"
// Don't insert stack check preamble.
#define NOSPLIT 4
// func DaxpyUnitary(alpha float64, x, y, z []float64)
// This function assumes len(y) >= len(x).
TEXT ·DaxpyUnitary(SB),NOSPLIT,$0
MOVHPD alpha+0(FP), X7
MOVLPD alpha+0(FP), X7
MOVQ x_len+16(FP), DI // n = len(x)
MOVQ x+8(FP), R8
MOVQ y+32(FP), R9
MOVQ z+56(FP), R10
MOVQ $0, SI // i = 0
SUBQ $2, DI // n -= 2
JL V1 // if n < 0 goto V1
U1: // n >= 0
// y[i] += alpha * x[i] unrolled 2x.
MOVUPD 0(R8)(SI*8), X0
MOVUPD 0(R9)(SI*8), X1
MULPD X7, X0
ADDPD X0, X1
MOVUPD X1, 0(R10)(SI*8)
ADDQ $2, SI // i += 2
SUBQ $2, DI // n -= 2
JGE U1 // if n >= 0 goto U1
V1:
ADDQ $2, DI // n += 2
JLE E1 // if n <= 0 goto E1
// y[i] += alpha * x[i] for last iteration if n is odd.
MOVSD 0(R8)(SI*8), X0
MOVSD 0(R9)(SI*8), X1
MULSD X7, X0
ADDSD X0, X1
MOVSD X1, 0(R10)(SI*8)
E1:
RET
// func DaxpyInc(alpha float64, x, y []float64, n, incX, incY, ix, iy uintptr)
TEXT ·DaxpyInc(SB),NOSPLIT,$0
MOVHPD alpha+0(FP), X7
MOVLPD alpha+0(FP), X7
MOVQ x+8(FP), R8
MOVQ y+32(FP), R9
MOVQ n+56(FP), CX
MOVQ incX+64(FP), R11
MOVQ incY+72(FP), R12
MOVQ ix+80(FP), SI
MOVQ iy+88(FP), DI
MOVQ SI, AX // nextX = ix
MOVQ DI, BX // nextY = iy
ADDQ R11, AX // nextX += incX
ADDQ R12, BX // nextY += incX
SHLQ $1, R11 // indX *= 2
SHLQ $1, R12 // indY *= 2
SUBQ $2, CX // n -= 2
JL V2 // if n < 0 goto V2
U2: // n >= 0
// y[i] += alpha * x[i] unrolled 2x.
MOVHPD 0(R8)(SI*8), X0
MOVHPD 0(R9)(DI*8), X1
MOVLPD 0(R8)(AX*8), X0
MOVLPD 0(R9)(BX*8), X1
MULPD X7, X0
ADDPD X0, X1
MOVHPD X1, 0(R9)(DI*8)
MOVLPD X1, 0(R9)(BX*8)
ADDQ R11, SI // ix += incX
ADDQ R12, DI // iy += incY
ADDQ R11, AX // nextX += incX
ADDQ R12, BX // nextY += incY
SUBQ $2, CX // n -= 2
JGE U2 // if n >= 0 goto U2
V2:
ADDQ $2, CX // n += 2
JLE E2 // if n <= 0 goto E2
// y[i] += alpha * x[i] for the last iteration if n is odd.
MOVSD 0(R8)(SI*8), X0
MOVSD 0(R9)(DI*8), X1
MULSD X7, X0
ADDSD X0, X1
MOVSD X1, 0(R9)(DI*8)
E2:
RET

View file

@ -1,23 +0,0 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//+build !amd64 noasm
package asm
func DdotUnitary(x []float64, y []float64) (sum float64) {
for i, v := range x {
sum += y[i] * v
}
return
}
func DdotInc(x, y []float64, n, incX, incY, ix, iy uintptr) (sum float64) {
for i := 0; i < int(n); i++ {
sum += y[iy] * x[ix]
ix += incX
iy += incY
}
return
}

View file

@ -1,10 +0,0 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//+build !noasm
package asm
func DdotUnitary(x, y []float64) (sum float64)
func DdotInc(x, y []float64, n, incX, incY, ix, iy uintptr) (sum float64)

View file

@ -1,140 +0,0 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// Some of the loop unrolling code is copied from:
// http://golang.org/src/math/big/arith_amd64.s
// which is distributed under these terms:
//
// Copyright (c) 2012 The Go Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//+build !noasm
// TODO(fhs): use textflag.h after we drop Go 1.3 support
//#include "textflag.h"
// Don't insert stack check preamble.
#define NOSPLIT 4
// func DdotUnitary(x, y []float64) (sum float64)
// This function assumes len(y) >= len(x).
TEXT ·DdotUnitary(SB),NOSPLIT,$0
MOVQ x_len+8(FP), DI // n = len(x)
MOVQ x+0(FP), R8
MOVQ y+24(FP), R9
MOVQ $0, SI // i = 0
MOVSD $(0.0), X7 // sum = 0
SUBQ $2, DI // n -= 2
JL V1 // if n < 0 goto V1
U1: // n >= 0
// sum += x[i] * y[i] unrolled 2x.
MOVUPD 0(R8)(SI*8), X0
MOVUPD 0(R9)(SI*8), X1
MULPD X1, X0
ADDPD X0, X7
ADDQ $2, SI // i += 2
SUBQ $2, DI // n -= 2
JGE U1 // if n >= 0 goto U1
V1: // n > 0
ADDQ $2, DI // n += 2
JLE E1 // if n <= 0 goto E1
// sum += x[i] * y[i] for last iteration if n is odd.
MOVSD 0(R8)(SI*8), X0
MOVSD 0(R9)(SI*8), X1
MULSD X1, X0
ADDSD X0, X7
E1:
// Add the two sums together.
MOVSD X7, X0
UNPCKHPD X7, X7
ADDSD X0, X7
MOVSD X7, sum+48(FP) // return final sum
RET
// func DdotInc(x, y []float64, n, incX, incY, ix, iy uintptr) (sum float64)
TEXT ·DdotInc(SB),NOSPLIT,$0
MOVQ x+0(FP), R8
MOVQ y+24(FP), R9
MOVQ n+48(FP), CX
MOVQ incX+56(FP), R11
MOVQ incY+64(FP), R12
MOVQ ix+72(FP), R13
MOVQ iy+80(FP), R14
MOVSD $(0.0), X7 // sum = 0
LEAQ (R8)(R13*8), SI // p = &x[ix]
LEAQ (R9)(R14*8), DI // q = &y[ix]
SHLQ $3, R11 // incX *= sizeof(float64)
SHLQ $3, R12 // indY *= sizeof(float64)
SUBQ $2, CX // n -= 2
JL V2 // if n < 0 goto V2
U2: // n >= 0
// sum += *p * *q unrolled 2x.
MOVHPD (SI), X0
MOVHPD (DI), X1
ADDQ R11, SI // p += incX
ADDQ R12, DI // q += incY
MOVLPD (SI), X0
MOVLPD (DI), X1
ADDQ R11, SI // p += incX
ADDQ R12, DI // q += incY
MULPD X1, X0
ADDPD X0, X7
SUBQ $2, CX // n -= 2
JGE U2 // if n >= 0 goto U2
V2:
ADDQ $2, CX // n += 2
JLE E2 // if n <= 0 goto E2
// sum += *p * *q for the last iteration if n is odd.
MOVSD (SI), X0
MULSD (DI), X0
ADDSD X0, X7
E2:
// Add the two sums together.
MOVSD X7, X0
UNPCKHPD X7, X7
ADDSD X0, X7
MOVSD X7, sum+88(FP) // return final sum
RET

View file

@ -1,10 +0,0 @@
package mat64
import "github.com/gonum/blas/testblas"
const (
Sm = testblas.SmallMat
Med = testblas.MediumMat
Lg = testblas.LargeMat
Huge = testblas.HugeMat
)

View file

@ -1,16 +0,0 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//+build cblas
package mat64
import (
"github.com/gonum/blas/blas64"
"github.com/gonum/blas/cgo"
)
func init() {
blas64.Use(cgo.Implementation{})
}

View file

@ -1,90 +0,0 @@
// Copyright ©2013 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Based on the CholeskyDecomposition class from Jama 1.0.3.
package mat64
import (
"math"
)
type CholeskyFactor struct {
L *Dense
SPD bool
}
// CholeskyL returns the left Cholesky decomposition of the matrix a and whether
// the matrix is symmetric and positive definite. The returned matrix l is a lower
// triangular matrix such that a = l.l'.
func Cholesky(a *Dense) CholeskyFactor {
// Initialize.
m, n := a.Dims()
spd := m == n
l := NewDense(n, n, nil)
// Main loop.
for j := 0; j < n; j++ {
lRowj := l.RawRowView(j)
var d float64
for k := 0; k < j; k++ {
var s float64
for i, v := range l.RawRowView(k)[:k] {
s += v * lRowj[i]
}
s = (a.at(j, k) - s) / l.at(k, k)
lRowj[k] = s
d += s * s
spd = spd && a.at(k, j) == a.at(j, k)
}
d = a.at(j, j) - d
spd = spd && d > 0
l.set(j, j, math.Sqrt(math.Max(d, 0)))
for k := j + 1; k < n; k++ {
l.set(j, k, 0)
}
}
return CholeskyFactor{L: l, SPD: spd}
}
// CholeskySolve returns a matrix x that solves a.x = b where a = l.l'. The matrix b must
// have the same number of rows as a, and a must be symmetric and positive definite. The
// matrix b is overwritten by the operation.
func (f CholeskyFactor) Solve(b *Dense) (x *Dense) {
if !f.SPD {
panic("mat64: matrix not symmetric positive definite")
}
l := f.L
m, n := l.Dims()
bm, bn := b.Dims()
if m != bm {
panic(ErrShape)
}
nx := bn
x = b
// Solve L*Y = B;
for k := 0; k < n; k++ {
for j := 0; j < nx; j++ {
for i := 0; i < k; i++ {
x.set(k, j, x.at(k, j)-x.at(i, j)*l.at(k, i))
}
x.set(k, j, x.at(k, j)/l.at(k, k))
}
}
// Solve L'*X = Y;
for k := n - 1; k >= 0; k-- {
for j := 0; j < nx; j++ {
for i := k + 1; i < n; i++ {
x.set(k, j, x.at(k, j)-x.at(i, j)*l.at(i, k))
}
x.set(k, j, x.at(k, j)/l.at(k, k))
}
}
return x
}

View file

@ -1,61 +0,0 @@
// Copyright ©2013 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat64
import (
"gopkg.in/check.v1"
)
func (s *S) TestCholesky(c *check.C) {
for _, t := range []struct {
a *Dense
spd bool
}{
{
a: NewDense(3, 3, []float64{
4, 1, 1,
1, 2, 3,
1, 3, 6,
}),
spd: true,
},
} {
cf := Cholesky(t.a)
c.Check(cf.SPD, check.Equals, t.spd)
lt := &Dense{}
lt.TCopy(cf.L)
lc := DenseCopyOf(cf.L)
lc.Mul(lc, lt)
c.Check(lc.EqualsApprox(t.a, 1e-12), check.Equals, true)
x := cf.Solve(eye())
t.a.Mul(t.a, x)
c.Check(t.a.EqualsApprox(eye(), 1e-12), check.Equals, true)
}
}
func (s *S) TestCholeskySolve(c *check.C) {
for _, t := range []struct {
a *Dense
b *Dense
ans *Dense
}{
{
a: NewDense(2, 2, []float64{
1, 0,
0, 1,
}),
b: NewDense(2, 1, []float64{5, 6}),
ans: NewDense(2, 1, []float64{5, 6}),
},
} {
ans := Cholesky(t.a).Solve(t.b)
c.Check(ans.EqualsApprox(t.ans, 1e-12), check.Equals, true)
}
}

View file

@ -1,601 +0,0 @@
// Copyright ©2013 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat64
import (
"bytes"
"encoding/binary"
"github.com/gonum/blas/blas64"
)
var (
matrix *Dense
_ Matrix = matrix
_ Mutable = matrix
_ Vectorer = matrix
_ VectorSetter = matrix
_ Cloner = matrix
_ Viewer = matrix
_ RowViewer = matrix
_ ColViewer = matrix
_ RawRowViewer = matrix
_ Grower = matrix
_ Adder = matrix
_ Suber = matrix
_ Muler = matrix
_ Dotter = matrix
_ ElemMuler = matrix
_ ElemDiver = matrix
_ Exper = matrix
_ Scaler = matrix
_ Applyer = matrix
_ TransposeCopier = matrix
// _ TransposeViewer = matrix
_ Tracer = matrix
_ Normer = matrix
_ Sumer = matrix
_ Uer = matrix
_ Ler = matrix
_ Stacker = matrix
_ Augmenter = matrix
_ Equaler = matrix
_ ApproxEqualer = matrix
_ RawMatrixSetter = matrix
_ RawMatrixer = matrix
_ Reseter = matrix
)
type Dense struct {
mat blas64.General
capRows, capCols int
}
// NewDense initializes and returns a *Dense of size r-by-c.
// Data stores in mat should be row-major, i.e., the (i, j) element
// in matrix should be at (i*c + j)-th position in mat.
// Note that NewDense(0, 0, nil) can be used for undetermined size
// matrix initialization.
func NewDense(r, c int, mat []float64) *Dense {
if mat != nil && r*c != len(mat) {
panic(ErrShape)
}
if mat == nil {
mat = make([]float64, r*c)
}
return &Dense{
mat: blas64.General{
Rows: r,
Cols: c,
Stride: c,
Data: mat,
},
capRows: r,
capCols: c,
}
}
// DenseCopyOf returns a newly allocated copy of the elements of a.
func DenseCopyOf(a Matrix) *Dense {
d := &Dense{}
d.Clone(a)
return d
}
func (m *Dense) SetRawMatrix(b blas64.General) {
m.capRows, m.capCols = b.Rows, b.Cols
m.mat = b
}
func (m *Dense) RawMatrix() blas64.General { return m.mat }
func (m *Dense) isZero() bool {
// It must be the case that m.Dims() returns
// zeros in this case. See comment in Reset().
return m.mat.Stride == 0
}
// Dims returns number of rows and number of columns.
func (m *Dense) Dims() (r, c int) { return m.mat.Rows, m.mat.Cols }
func (m *Dense) Caps() (r, c int) { return m.capRows, m.capCols }
func (m *Dense) Col(dst []float64, j int) []float64 {
if j >= m.mat.Cols || j < 0 {
panic(ErrColAccess)
}
if dst == nil {
dst = make([]float64, m.mat.Rows)
}
dst = dst[:min(len(dst), m.mat.Rows)]
blas64.Copy(len(dst),
blas64.Vector{Inc: m.mat.Stride, Data: m.mat.Data[j:]},
blas64.Vector{Inc: 1, Data: dst},
)
return dst
}
func (m *Dense) ColView(j int) *Vector {
if j >= m.mat.Cols || j < 0 {
panic(ErrColAccess)
}
return &Vector{
mat: blas64.Vector{
Inc: m.mat.Stride,
Data: m.mat.Data[j : m.mat.Rows*m.mat.Stride+j],
},
n: m.mat.Rows,
}
}
func (m *Dense) SetCol(j int, src []float64) int {
if j >= m.mat.Cols || j < 0 {
panic(ErrColAccess)
}
blas64.Copy(min(len(src), m.mat.Rows),
blas64.Vector{Inc: 1, Data: src},
blas64.Vector{Inc: m.mat.Stride, Data: m.mat.Data[j:]},
)
return min(len(src), m.mat.Rows)
}
// Row will copy the i-th row into dst and return it. If
// dst is nil, it will make a new slice for copying/returning.
func (m *Dense) Row(dst []float64, i int) []float64 {
if i >= m.mat.Rows || i < 0 {
panic(ErrRowAccess)
}
if dst == nil {
dst = make([]float64, m.mat.Cols)
}
copy(dst, m.rowView(i))
return dst
}
func (m *Dense) SetRow(i int, src []float64) int {
if i >= m.mat.Rows || i < 0 {
panic(ErrRowAccess)
}
copy(m.rowView(i), src)
return min(len(src), m.mat.Cols)
}
func (m *Dense) RowView(i int) *Vector {
if i >= m.mat.Rows || i < 0 {
panic(ErrRowAccess)
}
return &Vector{
mat: blas64.Vector{
Inc: 1,
Data: m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+m.mat.Cols],
},
n: m.mat.Cols,
}
}
func (m *Dense) RawRowView(i int) []float64 {
if i >= m.mat.Rows || i < 0 {
panic(ErrRowAccess)
}
return m.rowView(i)
}
func (m *Dense) rowView(r int) []float64 {
return m.mat.Data[r*m.mat.Stride : r*m.mat.Stride+m.mat.Cols]
}
func (m *Dense) View(i, j, r, c int) Matrix {
mr, mc := m.Dims()
if i < 0 || i >= mr || j < 0 || j >= mc || r <= 0 || i+r > mr || c <= 0 || j+c > mc {
panic(ErrIndexOutOfRange)
}
t := *m
t.mat.Data = t.mat.Data[i*t.mat.Stride+j : (i+r-1)*t.mat.Stride+(j+c)]
t.mat.Rows = r
t.mat.Cols = c
t.capRows -= i
t.capCols -= j
return &t
}
func (m *Dense) Grow(r, c int) Matrix {
if r < 0 || c < 0 {
panic(ErrIndexOutOfRange)
}
if r == 0 && c == 0 {
return m
}
r += m.mat.Rows
c += m.mat.Cols
var t Dense
switch {
case m.mat.Rows == 0 || m.mat.Cols == 0:
t.mat = blas64.General{
Rows: r,
Cols: c,
Stride: c,
// We zero because we don't know how the matrix will be used.
// In other places, the mat is immediately filled with a result;
// this is not the case here.
Data: useZeroed(m.mat.Data, r*c),
}
case r > m.capRows || c > m.capCols:
cr := max(r, m.capRows)
cc := max(c, m.capCols)
t.mat = blas64.General{
Rows: r,
Cols: c,
Stride: cc,
Data: make([]float64, cr*cc),
}
t.capRows = cr
t.capCols = cc
// Copy the complete matrix over to the new matrix.
// Including elements not currently visible.
r, c, m.mat.Rows, m.mat.Cols = m.mat.Rows, m.mat.Cols, m.capRows, m.capCols
t.Copy(m)
m.mat.Rows, m.mat.Cols = r, c
return &t
default:
t.mat = blas64.General{
Data: m.mat.Data[:(r-1)*m.mat.Stride+c],
Rows: r,
Cols: c,
Stride: m.mat.Stride,
}
}
t.capRows = r
t.capCols = c
return &t
}
func (m *Dense) Reset() {
// No change of Stride, Rows and Cols to 0
// may be made unless all are set to 0.
m.mat.Rows, m.mat.Cols, m.mat.Stride = 0, 0, 0
m.capRows, m.capCols = 0, 0
m.mat.Data = m.mat.Data[:0]
}
func (m *Dense) Clone(a Matrix) {
r, c := a.Dims()
mat := blas64.General{
Rows: r,
Cols: c,
Stride: c,
}
m.capRows, m.capCols = r, c
switch a := a.(type) {
case RawMatrixer:
amat := a.RawMatrix()
mat.Data = make([]float64, r*c)
for i := 0; i < r; i++ {
copy(mat.Data[i*c:(i+1)*c], amat.Data[i*amat.Stride:i*amat.Stride+c])
}
case Vectorer:
mat.Data = use(m.mat.Data, r*c)
for i := 0; i < r; i++ {
a.Row(mat.Data[i*c:(i+1)*c], i)
}
default:
mat.Data = use(m.mat.Data, r*c)
m.mat = mat
for i := 0; i < r; i++ {
for j := 0; j < c; j++ {
m.set(i, j, a.At(i, j))
}
}
return
}
m.mat = mat
}
func (m *Dense) Copy(a Matrix) (r, c int) {
r, c = a.Dims()
r = min(r, m.mat.Rows)
c = min(c, m.mat.Cols)
switch a := a.(type) {
case RawMatrixer:
amat := a.RawMatrix()
for i := 0; i < r; i++ {
copy(m.mat.Data[i*m.mat.Stride:i*m.mat.Stride+c], amat.Data[i*amat.Stride:i*amat.Stride+c])
}
case Vectorer:
for i := 0; i < r; i++ {
a.Row(m.mat.Data[i*m.mat.Stride:i*m.mat.Stride+c], i)
}
default:
for i := 0; i < r; i++ {
for j := 0; j < c; j++ {
m.set(r, c, a.At(r, c))
}
}
}
return r, c
}
func (m *Dense) U(a Matrix) {
ar, ac := a.Dims()
if ar != ac {
panic(ErrSquare)
}
switch {
case m == a:
m.zeroLower()
return
case m.isZero():
m.mat = blas64.General{
Rows: ar,
Cols: ac,
Stride: ac,
Data: use(m.mat.Data, ar*ac),
}
case ar != m.mat.Rows || ac != m.mat.Cols:
panic(ErrShape)
}
if a, ok := a.(RawMatrixer); ok {
amat := a.RawMatrix()
copy(m.mat.Data[:ac], amat.Data[:ac])
for j, ja, jm := 1, amat.Stride, m.mat.Stride; ja < ar*amat.Stride; j, ja, jm = j+1, ja+amat.Stride, jm+m.mat.Stride {
zero(m.mat.Data[jm : jm+j])
copy(m.mat.Data[jm+j:jm+ac], amat.Data[ja+j:ja+ac])
}
return
}
if a, ok := a.(Vectorer); ok {
row := make([]float64, ac)
copy(m.mat.Data[:m.mat.Cols], a.Row(row, 0))
for r := 1; r < ar; r++ {
zero(m.mat.Data[r*m.mat.Stride : r*(m.mat.Stride+1)])
copy(m.mat.Data[r*(m.mat.Stride+1):r*m.mat.Stride+m.mat.Cols], a.Row(row, r))
}
return
}
m.zeroLower()
for r := 0; r < ar; r++ {
for c := r; c < ac; c++ {
m.set(r, c, a.At(r, c))
}
}
}
func (m *Dense) zeroLower() {
for i := 1; i < m.mat.Rows; i++ {
zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+i])
}
}
func (m *Dense) L(a Matrix) {
ar, ac := a.Dims()
if ar != ac {
panic(ErrSquare)
}
switch {
case m == a:
m.zeroUpper()
return
case m.isZero():
m.mat = blas64.General{
Rows: ar,
Cols: ac,
Stride: ac,
Data: use(m.mat.Data, ar*ac),
}
case ar != m.mat.Rows || ac != m.mat.Cols:
panic(ErrShape)
}
if a, ok := a.(RawMatrixer); ok {
amat := a.RawMatrix()
copy(m.mat.Data[:ar], amat.Data[:ar])
for j, ja, jm := 1, amat.Stride, m.mat.Stride; ja < ac*amat.Stride; j, ja, jm = j+1, ja+amat.Stride, jm+m.mat.Stride {
zero(m.mat.Data[jm : jm+j])
copy(m.mat.Data[jm+j:jm+ar], amat.Data[ja+j:ja+ar])
}
return
}
if a, ok := a.(Vectorer); ok {
row := make([]float64, ac)
for r := 0; r < ar; r++ {
a.Row(row[:r+1], r)
m.SetRow(r, row)
}
return
}
m.zeroUpper()
for c := 0; c < ac; c++ {
for r := c; r < ar; r++ {
m.set(r, c, a.At(r, c))
}
}
}
func (m *Dense) zeroUpper() {
for i := 0; i < m.mat.Rows-1; i++ {
zero(m.mat.Data[i*m.mat.Stride+i+1 : (i+1)*m.mat.Stride])
}
}
// TCopy will copy the transpose of a and save it into m.
func (m *Dense) TCopy(a Matrix) {
ar, ac := a.Dims()
var w Dense
if m != a {
w = *m
}
if w.isZero() {
w.mat = blas64.General{
Rows: ac,
Cols: ar,
Data: use(w.mat.Data, ar*ac),
}
w.mat.Stride = ar
} else if ar != m.mat.Cols || ac != m.mat.Rows {
panic(ErrShape)
}
switch a := a.(type) {
case *Dense:
for i := 0; i < ac; i++ {
for j := 0; j < ar; j++ {
w.set(i, j, a.at(j, i))
}
}
default:
for i := 0; i < ac; i++ {
for j := 0; j < ar; j++ {
w.set(i, j, a.At(j, i))
}
}
}
*m = w
}
func (m *Dense) Stack(a, b Matrix) {
ar, ac := a.Dims()
br, bc := b.Dims()
if ac != bc || m == a || m == b {
panic(ErrShape)
}
if m.isZero() {
m.mat = blas64.General{
Rows: ar + br,
Cols: ac,
Stride: ac,
Data: use(m.mat.Data, (ar+br)*ac),
}
} else if ar+br != m.mat.Rows || ac != m.mat.Cols {
panic(ErrShape)
}
m.Copy(a)
w := m.View(ar, 0, br, bc).(*Dense)
w.Copy(b)
}
func (m *Dense) Augment(a, b Matrix) {
ar, ac := a.Dims()
br, bc := b.Dims()
if ar != br || m == a || m == b {
panic(ErrShape)
}
if m.isZero() {
m.mat = blas64.General{
Rows: ar,
Cols: ac + bc,
Stride: ac + bc,
Data: use(m.mat.Data, ar*(ac+bc)),
}
} else if ar != m.mat.Rows || ac+bc != m.mat.Cols {
panic(ErrShape)
}
m.Copy(a)
w := m.View(0, ac, br, bc).(*Dense)
w.Copy(b)
}
// MarshalBinary encodes the receiver into a binary form and returns the result.
//
// Dense is little-endian encoded as follows:
// 0 - 8 number of rows (int64)
// 8 - 16 number of columns (int64)
// 16 - .. matrix data elements (float64)
// [0,0] [0,1] ... [0,ncols-1]
// [1,0] [1,1] ... [1,ncols-1]
// ...
// [nrows-1,0] ... [nrows-1,ncols-1]
func (m Dense) MarshalBinary() ([]byte, error) {
buf := bytes.NewBuffer(make([]byte, 0, m.mat.Rows*m.mat.Cols*sizeFloat64+2*sizeInt64))
err := binary.Write(buf, defaultEndian, int64(m.mat.Rows))
if err != nil {
return nil, err
}
err = binary.Write(buf, defaultEndian, int64(m.mat.Cols))
if err != nil {
return nil, err
}
for i := 0; i < m.mat.Rows; i++ {
for _, v := range m.rowView(i) {
err = binary.Write(buf, defaultEndian, v)
if err != nil {
return nil, err
}
}
}
return buf.Bytes(), err
}
// UnmarshalBinary decodes the binary form into the receiver.
// It panics if the receiver is a non-zero Dense matrix.
//
// See MarshalBinary for the on-disk layout.
func (m *Dense) UnmarshalBinary(data []byte) error {
if !m.isZero() {
panic("mat64: unmarshal into non-zero matrix")
}
buf := bytes.NewReader(data)
var rows int64
err := binary.Read(buf, defaultEndian, &rows)
if err != nil {
return err
}
var cols int64
err = binary.Read(buf, defaultEndian, &cols)
if err != nil {
return err
}
m.mat.Rows = int(rows)
m.mat.Cols = int(cols)
m.mat.Stride = int(cols)
m.capRows = int(rows)
m.capCols = int(cols)
m.mat.Data = use(m.mat.Data, m.mat.Rows*m.mat.Cols)
for i := range m.mat.Data {
err = binary.Read(buf, defaultEndian, &m.mat.Data[i])
if err != nil {
return err
}
}
return err
}

View file

@ -1,954 +0,0 @@
// Copyright ©2013 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat64
import (
"math"
"github.com/gonum/blas"
"github.com/gonum/blas/blas64"
)
func (m *Dense) Min() float64 {
min := m.mat.Data[0]
for k := 0; k < m.mat.Rows; k++ {
for _, v := range m.rowView(k) {
min = math.Min(min, v)
}
}
return min
}
func (m *Dense) Max() float64 {
max := m.mat.Data[0]
for k := 0; k < m.mat.Rows; k++ {
for _, v := range m.rowView(k) {
max = math.Max(max, v)
}
}
return max
}
func (m *Dense) Trace() float64 {
if m.mat.Rows != m.mat.Cols {
panic(ErrSquare)
}
var t float64
for i := 0; i < len(m.mat.Data); i += m.mat.Stride + 1 {
t += m.mat.Data[i]
}
return t
}
var inf = math.Inf(1)
const (
epsilon = 2.2204e-16
small = math.SmallestNonzeroFloat64
)
// Norm calculates general matrix p-norm of m. It currently supports
// p = 1, -1, +Inf, -Inf, 2, -2.
func (m *Dense) Norm(ord float64) float64 {
var n float64
switch {
case ord == 1:
col := make([]float64, m.mat.Rows)
for i := 0; i < m.mat.Cols; i++ {
var s float64
for _, e := range m.Col(col, i) {
s += math.Abs(e)
}
n = math.Max(s, n)
}
case math.IsInf(ord, +1):
row := make([]float64, m.mat.Cols)
for i := 0; i < m.mat.Rows; i++ {
var s float64
for _, e := range m.Row(row, i) {
s += math.Abs(e)
}
n = math.Max(s, n)
}
case ord == -1:
n = math.MaxFloat64
col := make([]float64, m.mat.Rows)
for i := 0; i < m.mat.Cols; i++ {
var s float64
for _, e := range m.Col(col, i) {
s += math.Abs(e)
}
n = math.Min(s, n)
}
case math.IsInf(ord, -1):
n = math.MaxFloat64
row := make([]float64, m.mat.Cols)
for i := 0; i < m.mat.Rows; i++ {
var s float64
for _, e := range m.Row(row, i) {
s += math.Abs(e)
}
n = math.Min(s, n)
}
case ord == 0:
for i := 0; i < len(m.mat.Data); i += m.mat.Stride {
for _, v := range m.mat.Data[i : i+m.mat.Cols] {
n = math.Hypot(n, v)
}
}
return n
case ord == 2, ord == -2:
s := SVD(m, epsilon, small, false, false).Sigma
if ord == 2 {
return s[0]
}
return s[len(s)-1]
default:
panic(ErrNormOrder)
}
return n
}
// Add adds a and b element-wise and saves the result into m.
func (m *Dense) Add(a, b Matrix) {
ar, ac := a.Dims()
br, bc := b.Dims()
if ar != br || ac != bc {
panic(ErrShape)
}
if m.isZero() {
m.mat = blas64.General{
Rows: ar,
Cols: ac,
Stride: ac,
Data: use(m.mat.Data, ar*ac),
}
} else if ar != m.mat.Rows || ac != m.mat.Cols {
panic(ErrShape)
}
if a, ok := a.(RawMatrixer); ok {
if b, ok := b.(RawMatrixer); ok {
amat, bmat := a.RawMatrix(), b.RawMatrix()
for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
for i, v := range amat.Data[ja : ja+ac] {
m.mat.Data[i+jm] = v + bmat.Data[i+jb]
}
}
return
}
}
if a, ok := a.(Vectorer); ok {
if b, ok := b.(Vectorer); ok {
rowa := make([]float64, ac)
rowb := make([]float64, bc)
for r := 0; r < ar; r++ {
a.Row(rowa, r)
for i, v := range b.Row(rowb, r) {
rowa[i] += v
}
copy(m.rowView(r), rowa)
}
return
}
}
for r := 0; r < ar; r++ {
for c := 0; c < ac; c++ {
m.set(r, c, a.At(r, c)+b.At(r, c))
}
}
}
func (m *Dense) Sub(a, b Matrix) {
ar, ac := a.Dims()
br, bc := b.Dims()
if ar != br || ac != bc {
panic(ErrShape)
}
if m.isZero() {
m.mat = blas64.General{
Rows: ar,
Cols: ac,
Stride: ac,
Data: use(m.mat.Data, ar*ac),
}
} else if ar != m.mat.Rows || ac != m.mat.Cols {
panic(ErrShape)
}
if a, ok := a.(RawMatrixer); ok {
if b, ok := b.(RawMatrixer); ok {
amat, bmat := a.RawMatrix(), b.RawMatrix()
for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
for i, v := range amat.Data[ja : ja+ac] {
m.mat.Data[i+jm] = v - bmat.Data[i+jb]
}
}
return
}
}
if a, ok := a.(Vectorer); ok {
if b, ok := b.(Vectorer); ok {
rowa := make([]float64, ac)
rowb := make([]float64, bc)
for r := 0; r < ar; r++ {
a.Row(rowa, r)
for i, v := range b.Row(rowb, r) {
rowa[i] -= v
}
copy(m.rowView(r), rowa)
}
return
}
}
for r := 0; r < ar; r++ {
for c := 0; c < ac; c++ {
m.set(r, c, a.At(r, c)-b.At(r, c))
}
}
}
func (m *Dense) MulElem(a, b Matrix) {
ar, ac := a.Dims()
br, bc := b.Dims()
if ar != br || ac != bc {
panic(ErrShape)
}
if m.isZero() {
m.mat = blas64.General{
Rows: ar,
Cols: ac,
Stride: ac,
Data: use(m.mat.Data, ar*ac),
}
} else if ar != m.mat.Rows || ac != m.mat.Cols {
panic(ErrShape)
}
if a, ok := a.(RawMatrixer); ok {
if b, ok := b.(RawMatrixer); ok {
amat, bmat := a.RawMatrix(), b.RawMatrix()
for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
for i, v := range amat.Data[ja : ja+ac] {
m.mat.Data[i+jm] = v * bmat.Data[i+jb]
}
}
return
}
}
if a, ok := a.(Vectorer); ok {
if b, ok := b.(Vectorer); ok {
rowa := make([]float64, ac)
rowb := make([]float64, bc)
for r := 0; r < ar; r++ {
a.Row(rowa, r)
for i, v := range b.Row(rowb, r) {
rowa[i] *= v
}
copy(m.rowView(r), rowa)
}
return
}
}
for r := 0; r < ar; r++ {
for c := 0; c < ac; c++ {
m.set(r, c, a.At(r, c)*b.At(r, c))
}
}
}
func (m *Dense) DivElem(a, b Matrix) {
ar, ac := a.Dims()
br, bc := b.Dims()
if ar != br || ac != bc {
panic(ErrShape)
}
if m.isZero() {
m.mat = blas64.General{
Rows: ar,
Cols: ac,
Stride: ac,
Data: use(m.mat.Data, ar*ac),
}
} else if ar != m.mat.Rows || ac != m.mat.Cols {
panic(ErrShape)
}
if a, ok := a.(RawMatrixer); ok {
if b, ok := b.(RawMatrixer); ok {
amat, bmat := a.RawMatrix(), b.RawMatrix()
for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
for i, v := range amat.Data[ja : ja+ac] {
m.mat.Data[i+jm] = v / bmat.Data[i+jb]
}
}
return
}
}
if a, ok := a.(Vectorer); ok {
if b, ok := b.(Vectorer); ok {
rowa := make([]float64, ac)
rowb := make([]float64, bc)
for r := 0; r < ar; r++ {
a.Row(rowa, r)
for i, v := range b.Row(rowb, r) {
rowa[i] /= v
}
copy(m.rowView(r), rowa)
}
return
}
}
for r := 0; r < ar; r++ {
for c := 0; c < ac; c++ {
m.set(r, c, a.At(r, c)/b.At(r, c))
}
}
}
func (m *Dense) Dot(b Matrix) float64 {
mr, mc := m.Dims()
br, bc := b.Dims()
if mr != br || mc != bc {
panic(ErrShape)
}
var d float64
if b, ok := b.(RawMatrixer); ok {
bmat := b.RawMatrix()
for jm, jb := 0, 0; jm < mr*m.mat.Stride; jm, jb = jm+m.mat.Stride, jb+bmat.Stride {
for i, v := range m.mat.Data[jm : jm+mc] {
d += v * bmat.Data[i+jb]
}
}
return d
}
if b, ok := b.(Vectorer); ok {
row := make([]float64, bc)
for r := 0; r < br; r++ {
for i, v := range b.Row(row, r) {
d += m.mat.Data[r*m.mat.Stride+i] * v
}
}
return d
}
for r := 0; r < mr; r++ {
for c := 0; c < mc; c++ {
d += m.At(r, c) * b.At(r, c)
}
}
return d
}
// Mul multiplies two matrix and saves the result in m. Note that the
// arguments a or b should be either Matrix or *Dense.
// Therfore, if a or b is of type Dense, you'll need to pass them by address.
// For example: m.Mul(a, &b) when a is *Dense and b is Dense.
func (m *Dense) Mul(a, b Matrix) {
ar, ac := a.Dims()
br, bc := b.Dims()
if ac != br {
panic(ErrShape)
}
var w Dense
if m != a && m != b {
w = *m
}
if w.isZero() {
w.mat = blas64.General{
Rows: ar,
Cols: bc,
Stride: bc,
Data: use(w.mat.Data, ar*bc),
}
} else if ar != w.mat.Rows || bc != w.mat.Cols {
panic(ErrShape)
}
if a, ok := a.(RawMatrixer); ok {
if b, ok := b.(RawMatrixer); ok {
amat, bmat := a.RawMatrix(), b.RawMatrix()
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, amat, bmat, 0, w.mat)
*m = w
return
}
}
if a, ok := a.(Vectorer); ok {
if b, ok := b.(Vectorer); ok {
row := make([]float64, ac)
col := make([]float64, br)
for r := 0; r < ar; r++ {
dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
for c := 0; c < bc; c++ {
dataTmp[c] = blas64.Dot(ac,
blas64.Vector{Inc: 1, Data: a.Row(row, r)},
blas64.Vector{Inc: 1, Data: b.Col(col, c)},
)
}
}
*m = w
return
}
}
row := make([]float64, ac)
for r := 0; r < ar; r++ {
for i := range row {
row[i] = a.At(r, i)
}
for c := 0; c < bc; c++ {
var v float64
for i, e := range row {
v += e * b.At(i, c)
}
w.mat.Data[r*w.mat.Stride+c] = v
}
}
*m = w
}
func (m *Dense) MulTrans(a Matrix, aTrans bool, b Matrix, bTrans bool) {
ar, ac := a.Dims()
if aTrans {
ar, ac = ac, ar
}
br, bc := b.Dims()
if bTrans {
br, bc = bc, br
}
if ac != br {
panic(ErrShape)
}
var w Dense
if m != a && m != b {
w = *m
}
if w.isZero() {
w.mat = blas64.General{
Rows: ar,
Cols: bc,
Stride: bc,
Data: use(w.mat.Data, ar*bc),
}
} else if ar != w.mat.Rows || bc != w.mat.Cols {
panic(ErrShape)
}
if a, ok := a.(RawMatrixer); ok {
if b, ok := b.(RawMatrixer); ok {
amat := a.RawMatrix()
if a == b && aTrans != bTrans {
var op blas.Transpose
if aTrans {
op = blas.Trans
} else {
op = blas.NoTrans
}
blas64.Syrk(op, 1, amat, 0, blas64.Symmetric{N: w.mat.Rows, Stride: w.mat.Stride, Data: w.mat.Data, Uplo: blas.Upper})
// Fill lower matrix with result.
// TODO(kortschak): Investigate whether using blas64.Copy improves the performance of this significantly.
for i := 0; i < w.mat.Rows; i++ {
for j := i + 1; j < w.mat.Cols; j++ {
w.set(j, i, w.at(i, j))
}
}
} else {
var aOp, bOp blas.Transpose
if aTrans {
aOp = blas.Trans
} else {
aOp = blas.NoTrans
}
if bTrans {
bOp = blas.Trans
} else {
bOp = blas.NoTrans
}
bmat := b.RawMatrix()
blas64.Gemm(aOp, bOp, 1, amat, bmat, 0, w.mat)
}
*m = w
return
}
}
if a, ok := a.(Vectorer); ok {
if b, ok := b.(Vectorer); ok {
row := make([]float64, ac)
col := make([]float64, br)
if aTrans {
if bTrans {
for r := 0; r < ar; r++ {
dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
for c := 0; c < bc; c++ {
dataTmp[c] = blas64.Dot(ac,
blas64.Vector{Inc: 1, Data: a.Col(row, r)},
blas64.Vector{Inc: 1, Data: b.Row(col, c)},
)
}
}
*m = w
return
}
// TODO(jonlawlor): determine if (b*a)' is more efficient
for r := 0; r < ar; r++ {
dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
for c := 0; c < bc; c++ {
dataTmp[c] = blas64.Dot(ac,
blas64.Vector{Inc: 1, Data: a.Col(row, r)},
blas64.Vector{Inc: 1, Data: b.Col(col, c)},
)
}
}
*m = w
return
}
if bTrans {
for r := 0; r < ar; r++ {
dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
for c := 0; c < bc; c++ {
dataTmp[c] = blas64.Dot(ac,
blas64.Vector{Inc: 1, Data: a.Row(row, r)},
blas64.Vector{Inc: 1, Data: b.Row(col, c)},
)
}
}
*m = w
return
}
for r := 0; r < ar; r++ {
dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
for c := 0; c < bc; c++ {
dataTmp[c] = blas64.Dot(ac,
blas64.Vector{Inc: 1, Data: a.Row(row, r)},
blas64.Vector{Inc: 1, Data: b.Col(col, c)},
)
}
}
*m = w
return
}
}
row := make([]float64, ac)
if aTrans {
if bTrans {
for r := 0; r < ar; r++ {
dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
for i := range row {
row[i] = a.At(i, r)
}
for c := 0; c < bc; c++ {
var v float64
for i, e := range row {
v += e * b.At(c, i)
}
dataTmp[c] = v
}
}
*m = w
return
}
for r := 0; r < ar; r++ {
dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
for i := range row {
row[i] = a.At(i, r)
}
for c := 0; c < bc; c++ {
var v float64
for i, e := range row {
v += e * b.At(i, c)
}
dataTmp[c] = v
}
}
*m = w
return
}
if bTrans {
for r := 0; r < ar; r++ {
dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
for i := range row {
row[i] = a.At(r, i)
}
for c := 0; c < bc; c++ {
var v float64
for i, e := range row {
v += e * b.At(c, i)
}
dataTmp[c] = v
}
}
*m = w
return
}
for r := 0; r < ar; r++ {
dataTmp := w.mat.Data[r*w.mat.Stride : r*w.mat.Stride+bc]
for i := range row {
row[i] = a.At(r, i)
}
for c := 0; c < bc; c++ {
var v float64
for i, e := range row {
v += e * b.At(i, c)
}
dataTmp[c] = v
}
}
*m = w
}
// Exp uses the scaling and squaring method described in section 3 of
// http://www.cs.cornell.edu/cv/researchpdf/19ways+.pdf.
func (m *Dense) Exp(a Matrix) {
r, c := a.Dims()
if r != c {
panic(ErrShape)
}
switch {
case m.isZero():
m.mat = blas64.General{
Rows: r,
Cols: c,
Stride: c,
Data: use(m.mat.Data, r*r),
}
zero(m.mat.Data)
for i := 0; i < r*r; i += r + 1 {
m.mat.Data[i] = 1
}
case r == m.mat.Rows && c == m.mat.Cols:
for i := 0; i < r; i++ {
zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c])
m.mat.Data[i*m.mat.Stride+i] = 1
}
default:
panic(ErrShape)
}
const (
terms = 10
scaling = 4
)
var small, power Dense
small.Scale(math.Pow(2, -scaling), a)
power.Clone(&small)
var (
tmp = NewDense(r, r, nil)
factI = 1.
)
for i := 1.; i < terms; i++ {
factI *= i
// This is OK to do because power and tmp are
// new Dense values so all rows are contiguous.
// TODO(kortschak) Make this explicit in the NewDense doc comment.
for j, v := range power.mat.Data {
tmp.mat.Data[j] = v / factI
}
m.Add(m, tmp)
if i < terms-1 {
power.Mul(&power, &small)
}
}
for i := 0; i < scaling; i++ {
m.Mul(m, m)
}
}
func (m *Dense) Pow(a Matrix, n int) {
if n < 0 {
panic("matrix: illegal power")
}
r, c := a.Dims()
if r != c {
panic(ErrShape)
}
if m.isZero() {
m.mat = blas64.General{
Rows: r,
Cols: c,
Stride: c,
Data: use(m.mat.Data, r*r),
}
} else if r != m.mat.Rows || c != m.mat.Cols {
panic(ErrShape)
}
// Take possible fast paths.
switch n {
case 0:
for i := 0; i < r; i++ {
zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c])
m.mat.Data[i*m.mat.Stride+i] = 1
}
return
case 1:
m.Copy(a)
return
case 2:
m.Mul(a, a)
return
}
// Perform iterative exponentiation by squaring in work space.
var w, tmp Dense
w.Clone(a)
tmp.Clone(a)
for n--; n > 0; n >>= 1 {
if n&1 != 0 {
w.Mul(&w, &tmp)
}
tmp.Mul(&tmp, &tmp)
}
m.Copy(&w)
}
func (m *Dense) Scale(f float64, a Matrix) {
ar, ac := a.Dims()
if m.isZero() {
m.mat = blas64.General{
Rows: ar,
Cols: ac,
Stride: ac,
Data: use(m.mat.Data, ar*ac),
}
} else if ar != m.mat.Rows || ac != m.mat.Cols {
panic(ErrShape)
}
if a, ok := a.(RawMatrixer); ok {
amat := a.RawMatrix()
for ja, jm := 0, 0; ja < ar*amat.Stride; ja, jm = ja+amat.Stride, jm+m.mat.Stride {
for i, v := range amat.Data[ja : ja+ac] {
m.mat.Data[i+jm] = v * f
}
}
return
}
if a, ok := a.(Vectorer); ok {
row := make([]float64, ac)
for r := 0; r < ar; r++ {
for i, v := range a.Row(row, r) {
row[i] = f * v
}
copy(m.rowView(r), row)
}
return
}
for r := 0; r < ar; r++ {
for c := 0; c < ac; c++ {
m.set(r, c, f*a.At(r, c))
}
}
}
func (m *Dense) Apply(f ApplyFunc, a Matrix) {
ar, ac := a.Dims()
if m.isZero() {
m.mat = blas64.General{
Rows: ar,
Cols: ac,
Stride: ac,
Data: use(m.mat.Data, ar*ac),
}
} else if ar != m.mat.Rows || ac != m.mat.Cols {
panic(ErrShape)
}
if a, ok := a.(RawMatrixer); ok {
amat := a.RawMatrix()
for j, ja, jm := 0, 0, 0; ja < ar*amat.Stride; j, ja, jm = j+1, ja+amat.Stride, jm+m.mat.Stride {
for i, v := range amat.Data[ja : ja+ac] {
m.mat.Data[i+jm] = f(j, i, v)
}
}
return
}
if a, ok := a.(Vectorer); ok {
row := make([]float64, ac)
for r := 0; r < ar; r++ {
for i, v := range a.Row(row, r) {
row[i] = f(r, i, v)
}
copy(m.rowView(r), row)
}
return
}
for r := 0; r < ar; r++ {
for c := 0; c < ac; c++ {
m.set(r, c, f(r, c, a.At(r, c)))
}
}
}
func (m *Dense) Sum() float64 {
l := m.mat.Cols
var s float64
for i := 0; i < len(m.mat.Data); i += m.mat.Stride {
for _, v := range m.mat.Data[i : i+l] {
s += v
}
}
return s
}
func (m *Dense) Equals(b Matrix) bool {
br, bc := b.Dims()
if br != m.mat.Rows || bc != m.mat.Cols {
return false
}
if b, ok := b.(RawMatrixer); ok {
bmat := b.RawMatrix()
for jb, jm := 0, 0; jm < br*m.mat.Stride; jb, jm = jb+bmat.Stride, jm+m.mat.Stride {
for i, v := range m.mat.Data[jm : jm+bc] {
if v != bmat.Data[i+jb] {
return false
}
}
}
return true
}
if b, ok := b.(Vectorer); ok {
rowb := make([]float64, bc)
for r := 0; r < br; r++ {
rowm := m.mat.Data[r*m.mat.Stride : r*m.mat.Stride+m.mat.Cols]
for i, v := range b.Row(rowb, r) {
if rowm[i] != v {
return false
}
}
}
return true
}
for r := 0; r < br; r++ {
for c := 0; c < bc; c++ {
if m.At(r, c) != b.At(r, c) {
return false
}
}
}
return true
}
func (m *Dense) EqualsApprox(b Matrix, epsilon float64) bool {
br, bc := b.Dims()
if br != m.mat.Rows || bc != m.mat.Cols {
return false
}
if b, ok := b.(RawMatrixer); ok {
bmat := b.RawMatrix()
for jb, jm := 0, 0; jm < br*m.mat.Stride; jb, jm = jb+bmat.Stride, jm+m.mat.Stride {
for i, v := range m.mat.Data[jm : jm+bc] {
if math.Abs(v-bmat.Data[i+jb]) > epsilon {
return false
}
}
}
return true
}
if b, ok := b.(Vectorer); ok {
rowb := make([]float64, bc)
for r := 0; r < br; r++ {
rowm := m.mat.Data[r*m.mat.Stride : r*m.mat.Stride+m.mat.Cols]
for i, v := range b.Row(rowb, r) {
if math.Abs(rowm[i]-v) > epsilon {
return false
}
}
}
return true
}
for r := 0; r < br; r++ {
for c := 0; c < bc; c++ {
if math.Abs(m.At(r, c)-b.At(r, c)) > epsilon {
return false
}
}
}
return true
}
// RankOne performs a rank-one update to the matrix b and stores the result
// in the receiver
// m = a + alpha * x * y'
func (m *Dense) RankOne(a Matrix, alpha float64, x, y []float64) {
ar, ac := a.Dims()
var w Dense
if m == a {
w = *m
}
if w.isZero() {
w.mat = blas64.General{
Rows: ar,
Cols: ac,
Stride: ac,
Data: use(w.mat.Data, ar*ac),
}
} else if ar != w.mat.Rows || ac != w.mat.Cols {
panic(ErrShape)
}
// Copy over to the new memory if necessary
if m != a {
w.Copy(a)
}
if len(x) != ar {
panic(ErrShape)
}
if len(y) != ac {
panic(ErrShape)
}
blas64.Ger(alpha, blas64.Vector{Inc: 1, Data: x}, blas64.Vector{Inc: 1, Data: y}, w.mat)
*m = w
return
}

File diff suppressed because it is too large Load diff

View file

@ -1,819 +0,0 @@
// Copyright ©2013 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Based on the EigenvalueDecomposition class from Jama 1.0.3.
package mat64
import (
"math"
)
func symmetric(m *Dense) bool {
n, _ := m.Dims()
for i := 0; i < n; i++ {
for j := 0; j < i; j++ {
if m.at(i, j) != m.at(j, i) {
return false
}
}
}
return true
}
type EigenFactors struct {
V *Dense
d, e []float64
}
// Eigen returns the Eigenvalues and eigenvectors of a square real matrix.
// The matrix a is overwritten during the decomposition. If a is symmetric,
// then a = v*D*v' where the eigenvalue matrix D is diagonal and the
// eigenvector matrix v is orthogonal.
//
// If a is not symmetric, then the eigenvalue matrix D is block diagonal
// with the real eigenvalues in 1-by-1 blocks and any complex eigenvalues,
// lambda + i*mu, in 2-by-2 blocks, [lambda, mu; -mu, lambda]. The
// columns of v represent the eigenvectors in the sense that a*v = v*D,
// i.e. a.v equals v.D. The matrix v may be badly conditioned, or even
// singular, so the validity of the equation a = v*D*inverse(v) depends
// upon the 2-norm condition number of v.
func Eigen(a *Dense, epsilon float64) EigenFactors {
m, n := a.Dims()
if m != n {
panic(ErrSquare)
}
var v *Dense
d := make([]float64, n)
e := make([]float64, n)
if symmetric(a) {
// Tridiagonalize.
v = tred2(a, d, e)
// Diagonalize.
tql2(d, e, v, epsilon)
} else {
// Reduce to Hessenberg form.
var hess *Dense
hess, v = orthes(a)
// Reduce Hessenberg to real Schur form.
hqr2(d, e, hess, v, epsilon)
}
return EigenFactors{v, d, e}
}
// Symmetric Householder reduction to tridiagonal form.
//
// This is derived from the Algol procedures tred2 by
// Bowdler, Martin, Reinsch, and Wilkinson, Handbook for
// Auto. Comp., Vol.ii-Linear Algebra, and the corresponding
// Fortran subroutine in EISPACK.
func tred2(a *Dense, d, e []float64) (v *Dense) {
n := len(d)
v = a
for j := 0; j < n; j++ {
d[j] = v.at(n-1, j)
}
// Householder reduction to tridiagonal form.
for i := n - 1; i > 0; i-- {
// Scale to avoid under/overflow.
var (
scale float64
h float64
)
for k := 0; k < i; k++ {
scale += math.Abs(d[k])
}
if scale == 0 {
e[i] = d[i-1]
for j := 0; j < i; j++ {
d[j] = v.at(i-1, j)
v.set(i, j, 0)
v.set(j, i, 0)
}
} else {
// Generate Householder vector.
for k := 0; k < i; k++ {
d[k] /= scale
h += d[k] * d[k]
}
f := d[i-1]
g := math.Sqrt(h)
if f > 0 {
g = -g
}
e[i] = scale * g
h -= f * g
d[i-1] = f - g
for j := 0; j < i; j++ {
e[j] = 0
}
// Apply similarity transformation to remaining columns.
for j := 0; j < i; j++ {
f = d[j]
v.set(j, i, f)
g = e[j] + v.at(j, j)*f
for k := j + 1; k <= i-1; k++ {
g += v.at(k, j) * d[k]
e[k] += v.at(k, j) * f
}
e[j] = g
}
f = 0
for j := 0; j < i; j++ {
e[j] /= h
f += e[j] * d[j]
}
hh := f / (h + h)
for j := 0; j < i; j++ {
e[j] -= hh * d[j]
}
for j := 0; j < i; j++ {
f = d[j]
g = e[j]
for k := j; k <= i-1; k++ {
v.set(k, j, v.at(k, j)-(f*e[k]+g*d[k]))
}
d[j] = v.at(i-1, j)
v.set(i, j, 0)
}
}
d[i] = h
}
// Accumulate transformations.
for i := 0; i < n-1; i++ {
v.set(n-1, i, v.at(i, i))
v.set(i, i, 1)
h := d[i+1]
if h != 0 {
for k := 0; k <= i; k++ {
d[k] = v.at(k, i+1) / h
}
for j := 0; j <= i; j++ {
var g float64
for k := 0; k <= i; k++ {
g += v.at(k, i+1) * v.at(k, j)
}
for k := 0; k <= i; k++ {
v.set(k, j, v.at(k, j)-g*d[k])
}
}
}
for k := 0; k <= i; k++ {
v.set(k, i+1, 0)
}
}
for j := 0; j < n; j++ {
d[j] = v.at(n-1, j)
v.set(n-1, j, 0)
}
v.set(n-1, n-1, 1)
e[0] = 0
return v
}
// Symmetric tridiagonal QL algorithm.
//
// This is derived from the Algol procedures tql2, by
// Bowdler, Martin, Reinsch, and Wilkinson, Handbook for
// Auto. Comp., Vol.ii-Linear Algebra, and the corresponding
// Fortran subroutine in EISPACK.
func tql2(d, e []float64, v *Dense, epsilon float64) {
n := len(d)
for i := 1; i < n; i++ {
e[i-1] = e[i]
}
e[n-1] = 0
var (
f float64
tst1 float64
)
for l := 0; l < n; l++ {
// Find small subdiagonal element
tst1 = math.Max(tst1, math.Abs(d[l])+math.Abs(e[l]))
m := l
for m < n {
if math.Abs(e[m]) <= epsilon*tst1 {
break
}
m++
}
// If m == l, d[l] is an eigenvalue, otherwise, iterate.
if m > l {
for iter := 0; ; iter++ { // Could check iteration count here.
// Compute implicit shift
g := d[l]
p := (d[l+1] - g) / (2 * e[l])
r := math.Hypot(p, 1)
if p < 0 {
r = -r
}
d[l] = e[l] / (p + r)
d[l+1] = e[l] * (p + r)
dl1 := d[l+1]
h := g - d[l]
for i := l + 2; i < n; i++ {
d[i] -= h
}
f += h
// Implicit QL transformation.
p = d[m]
c := 1.
c2 := c
c3 := c
el1 := e[l+1]
var (
s float64
s2 float64
)
for i := m - 1; i >= l; i-- {
c3 = c2
c2 = c
s2 = s
g = c * e[i]
h = c * p
r = math.Hypot(p, e[i])
e[i+1] = s * r
s = e[i] / r
c = p / r
p = c*d[i] - s*g
d[i+1] = h + s*(c*g+s*d[i])
// Accumulate transformation.
for k := 0; k < n; k++ {
h = v.at(k, i+1)
v.set(k, i+1, s*v.at(k, i)+c*h)
v.set(k, i, c*v.at(k, i)-s*h)
}
}
p = -s * s2 * c3 * el1 * e[l] / dl1
e[l] = s * p
d[l] = c * p
// Check for convergence.
if math.Abs(e[l]) <= epsilon*tst1 {
break
}
}
}
d[l] += f
e[l] = 0
}
// Sort eigenvalues and corresponding vectors.
for i := 0; i < n-1; i++ {
k := i
p := d[i]
for j := i + 1; j < n; j++ {
if d[j] < p {
k = j
p = d[j]
}
}
if k != i {
d[k] = d[i]
d[i] = p
for j := 0; j < n; j++ {
p = v.at(j, i)
v.set(j, i, v.at(j, k))
v.set(j, k, p)
}
}
}
}
// Nonsymmetric reduction to Hessenberg form.
//
// This is derived from the Algol procedures orthes and ortran,
// by Martin and Wilkinson, Handbook for Auto. Comp.,
// Vol.ii-Linear Algebra, and the corresponding
// Fortran subroutines in EISPACK.
func orthes(a *Dense) (hess, v *Dense) {
n, _ := a.Dims()
hess = a
ort := make([]float64, n)
low := 0
high := n - 1
for m := low + 1; m <= high-1; m++ {
// Scale column.
var scale float64
for i := m; i <= high; i++ {
scale += math.Abs(hess.at(i, m-1))
}
if scale != 0 {
// Compute Householder transformation.
var h float64
for i := high; i >= m; i-- {
ort[i] = hess.at(i, m-1) / scale
h += ort[i] * ort[i]
}
g := math.Sqrt(h)
if ort[m] > 0 {
g = -g
}
h -= ort[m] * g
ort[m] -= g
// Apply Householder similarity transformation
// hess = (I-u*u'/h)*hess*(I-u*u')/h)
for j := m; j < n; j++ {
var f float64
for i := high; i >= m; i-- {
f += ort[i] * hess.at(i, j)
}
f /= h
for i := m; i <= high; i++ {
hess.set(i, j, hess.at(i, j)-f*ort[i])
}
}
for i := 0; i <= high; i++ {
var f float64
for j := high; j >= m; j-- {
f += ort[j] * hess.at(i, j)
}
f /= h
for j := m; j <= high; j++ {
hess.set(i, j, hess.at(i, j)-f*ort[j])
}
}
ort[m] *= scale
hess.set(m, m-1, scale*g)
}
}
// Accumulate transformations (Algol's ortran).
v = NewDense(n, n, nil)
for i := 0; i < n; i++ {
for j := 0; j < n; j++ {
if i == j {
v.set(i, j, 1)
} else {
v.set(i, j, 0)
}
}
}
for m := high - 1; m >= low+1; m-- {
if hess.at(m, m-1) != 0 {
for i := m + 1; i <= high; i++ {
ort[i] = hess.at(i, m-1)
}
for j := m; j <= high; j++ {
var g float64
for i := m; i <= high; i++ {
g += ort[i] * v.at(i, j)
}
// Double division avoids possible underflow
g = (g / ort[m]) / hess.at(m, m-1)
for i := m; i <= high; i++ {
v.set(i, j, v.at(i, j)+g*ort[i])
}
}
}
}
return hess, v
}
// Nonsymmetric reduction from Hessenberg to real Schur form.
//
// This is derived from the Algol procedure hqr2,
// by Martin and Wilkinson, Handbook for Auto. Comp.,
// Vol.ii-Linear Algebra, and the corresponding
// Fortran subroutine in EISPACK.
func hqr2(d, e []float64, hess, v *Dense, epsilon float64) {
// Initialize
nn := len(d)
n := nn - 1
low := 0
high := n
var exshift, p, q, r, s, z, t, w, x, y float64
// Store roots isolated by balanc and compute matrix norm
var norm float64
for i := 0; i < nn; i++ {
if i < low || i > high {
d[i] = hess.at(i, i)
e[i] = 0
}
for j := max(i-1, 0); j < nn; j++ {
norm += math.Abs(hess.at(i, j))
}
}
// Outer loop over eigenvalue index
for iter := 0; n >= low; {
// Look for single small sub-diagonal element
l := n
for l > low {
s = math.Abs(hess.at(l-1, l-1)) + math.Abs(hess.at(l, l))
if s == 0 {
s = norm
}
if math.Abs(hess.at(l, l-1)) < epsilon*s {
break
}
l--
}
// Check for convergence
if l == n {
// One root found
hess.set(n, n, hess.at(n, n)+exshift)
d[n] = hess.at(n, n)
e[n] = 0
n--
iter = 0
} else if l == n-1 {
// Two roots found
w = hess.at(n, n-1) * hess.at(n-1, n)
p = (hess.at(n-1, n-1) - hess.at(n, n)) / 2.0
q = p*p + w
z = math.Sqrt(math.Abs(q))
hess.set(n, n, hess.at(n, n)+exshift)
hess.set(n-1, n-1, hess.at(n-1, n-1)+exshift)
x = hess.at(n, n)
// Real pair
if q >= 0 {
if p >= 0 {
z = p + z
} else {
z = p - z
}
d[n-1] = x + z
d[n] = d[n-1]
if z != 0 {
d[n] = x - w/z
}
e[n-1] = 0
e[n] = 0
x = hess.at(n, n-1)
s = math.Abs(x) + math.Abs(z)
p = x / s
q = z / s
r = math.Hypot(p, q)
p /= r
q /= r
// Row modification
for j := n - 1; j < nn; j++ {
z = hess.at(n-1, j)
hess.set(n-1, j, q*z+p*hess.at(n, j))
hess.set(n, j, q*hess.at(n, j)-p*z)
}
// Column modification
for i := 0; i <= n; i++ {
z = hess.at(i, n-1)
hess.set(i, n-1, q*z+p*hess.at(i, n))
hess.set(i, n, q*hess.at(i, n)-p*z)
}
// Accumulate transformations
for i := low; i <= high; i++ {
z = v.at(i, n-1)
v.set(i, n-1, q*z+p*v.at(i, n))
v.set(i, n, q*v.at(i, n)-p*z)
}
} else {
// Complex pair
d[n-1] = x + p
d[n] = x + p
e[n-1] = z
e[n] = -z
}
n -= 2
iter = 0
} else {
// No convergence yet
// Form shift
x = hess.at(n, n)
y = 0
w = 0
if l < n {
y = hess.at(n-1, n-1)
w = hess.at(n, n-1) * hess.at(n-1, n)
}
// Wilkinson's original ad hoc shift
if iter == 10 {
exshift += x
for i := low; i <= n; i++ {
hess.set(i, i, hess.at(i, i)-x)
}
s = math.Abs(hess.at(n, n-1)) + math.Abs(hess.at(n-1, n-2))
x = 0.75 * s
y = x
w = -0.4375 * s * s
}
// MATLAB's new ad hoc shift
if iter == 30 {
s = (y - x) / 2
s = s*s + w
if s > 0 {
s = math.Sqrt(s)
if y < x {
s = -s
}
s = x - w/((y-x)/2+s)
for i := low; i <= n; i++ {
hess.set(i, i, hess.at(i, i)-s)
}
exshift += s
x = 0.964
y = x
w = x
}
}
iter++ // Could check iteration count here.
// Look for two consecutive small sub-diagonal elements
m := n - 2
for m >= l {
z = hess.at(m, m)
r = x - z
s = y - z
p = (r*s-w)/hess.at(m+1, m) + hess.at(m, m+1)
q = hess.at(m+1, m+1) - z - r - s
r = hess.at(m+2, m+1)
s = math.Abs(p) + math.Abs(q) + math.Abs(r)
p /= s
q /= s
r /= s
if m == l {
break
}
if math.Abs(hess.at(m, m-1))*(math.Abs(q)+math.Abs(r)) <
epsilon*(math.Abs(p)*(math.Abs(hess.at(m-1, m-1))+math.Abs(z)+math.Abs(hess.at(m+1, m+1)))) {
break
}
m--
}
for i := m + 2; i <= n; i++ {
hess.set(i, i-2, 0)
if i > m+2 {
hess.set(i, i-3, 0)
}
}
// Double QR step involving rows l:n and columns m:n
for k := m; k <= n-1; k++ {
last := k == n-1
if k != m {
p = hess.at(k, k-1)
q = hess.at(k+1, k-1)
if !last {
r = hess.at(k+2, k-1)
} else {
r = 0
}
x = math.Abs(p) + math.Abs(q) + math.Abs(r)
if x == 0 {
continue
}
p /= x
q /= x
r /= x
}
s = math.Sqrt(p*p + q*q + r*r)
if p < 0 {
s = -s
}
if s != 0 {
if k != m {
hess.set(k, k-1, -s*x)
} else if l != m {
hess.set(k, k-1, -hess.at(k, k-1))
}
p += s
x = p / s
y = q / s
z = r / s
q /= p
r /= p
// Row modification
for j := k; j < nn; j++ {
p = hess.at(k, j) + q*hess.at(k+1, j)
if !last {
p += r * hess.at(k+2, j)
hess.set(k+2, j, hess.at(k+2, j)-p*z)
}
hess.set(k, j, hess.at(k, j)-p*x)
hess.set(k+1, j, hess.at(k+1, j)-p*y)
}
// Column modification
for i := 0; i <= min(n, k+3); i++ {
p = x*hess.at(i, k) + y*hess.at(i, k+1)
if !last {
p += z * hess.at(i, k+2)
hess.set(i, k+2, hess.at(i, k+2)-p*r)
}
hess.set(i, k, hess.at(i, k)-p)
hess.set(i, k+1, hess.at(i, k+1)-p*q)
}
// Accumulate transformations
for i := low; i <= high; i++ {
p = x*v.at(i, k) + y*v.at(i, k+1)
if !last {
p += z * v.at(i, k+2)
v.set(i, k+2, v.at(i, k+2)-p*r)
}
v.set(i, k, v.at(i, k)-p)
v.set(i, k+1, v.at(i, k+1)-p*q)
}
}
}
}
}
// Backsubstitute to find vectors of upper triangular form
if norm == 0 {
return
}
for n = nn - 1; n >= 0; n-- {
p = d[n]
q = e[n]
if q == 0 {
// Real vector
l := n
hess.set(n, n, 1)
for i := n - 1; i >= 0; i-- {
w = hess.at(i, i) - p
r = 0
for j := l; j <= n; j++ {
r += hess.at(i, j) * hess.at(j, n)
}
if e[i] < 0 {
z = w
s = r
} else {
l = i
if e[i] == 0 {
if w != 0 {
hess.set(i, n, -r/w)
} else {
hess.set(i, n, -r/(epsilon*norm))
}
} else {
// Solve real equations
x = hess.at(i, i+1)
y = hess.at(i+1, i)
q = (d[i]-p)*(d[i]-p) + e[i]*e[i]
t = (x*s - z*r) / q
hess.set(i, n, t)
if math.Abs(x) > math.Abs(z) {
hess.set(i+1, n, (-r-w*t)/x)
} else {
hess.set(i+1, n, (-s-y*t)/z)
}
}
// Overflow control
t = math.Abs(hess.at(i, n))
if epsilon*t*t > 1 {
for j := i; j <= n; j++ {
hess.set(j, n, hess.at(j, n)/t)
}
}
}
}
} else if q < 0 {
// Complex vector
l := n - 1
// Last vector component imaginary so matrix is triangular
if math.Abs(hess.at(n, n-1)) > math.Abs(hess.at(n-1, n)) {
hess.set(n-1, n-1, q/hess.at(n, n-1))
hess.set(n-1, n, -(hess.at(n, n)-p)/hess.at(n, n-1))
} else {
c := complex(0, -hess.at(n-1, n)) / complex(hess.at(n-1, n-1)-p, q)
hess.set(n-1, n-1, real(c))
hess.set(n-1, n, imag(c))
}
hess.set(n, n-1, 0)
hess.set(n, n, 1)
for i := n - 2; i >= 0; i-- {
var ra, sa, vr, vi float64
for j := l; j <= n; j++ {
ra += hess.at(i, j) * hess.at(j, n-1)
sa += hess.at(i, j) * hess.at(j, n)
}
w = hess.at(i, i) - p
if e[i] < 0 {
z = w
r = ra
s = sa
} else {
l = i
if e[i] == 0 {
c := complex(-ra, -sa) / complex(w, q)
hess.set(i, n-1, real(c))
hess.set(i, n, imag(c))
} else {
// Solve complex equations
x = hess.at(i, i+1)
y = hess.at(i+1, i)
vr = (d[i]-p)*(d[i]-p) + e[i]*e[i] - q*q
vi = (d[i] - p) * 2 * q
if vr == 0 && vi == 0 {
vr = epsilon * norm * (math.Abs(w) + math.Abs(q) + math.Abs(x) + math.Abs(y) + math.Abs(z))
}
c := complex(x*r-z*ra+q*sa, x*s-z*sa-q*ra) / complex(vr, vi)
hess.set(i, n-1, real(c))
hess.set(i, n, imag(c))
if math.Abs(x) > (math.Abs(z) + math.Abs(q)) {
hess.set(i+1, n-1, (-ra-w*hess.at(i, n-1)+q*hess.at(i, n))/x)
hess.set(i+1, n, (-sa-w*hess.at(i, n)-q*hess.at(i, n-1))/x)
} else {
c := complex(-r-y*hess.at(i, n-1), -s-y*hess.at(i, n)) / complex(z, q)
hess.set(i+1, n-1, real(c))
hess.set(i+1, n, imag(c))
}
}
// Overflow control
t = math.Max(math.Abs(hess.at(i, n-1)), math.Abs(hess.at(i, n)))
if (epsilon*t)*t > 1 {
for j := i; j <= n; j++ {
hess.set(j, n-1, hess.at(j, n-1)/t)
hess.set(j, n, hess.at(j, n)/t)
}
}
}
}
}
}
// Vectors of isolated roots
for i := 0; i < nn; i++ {
if i < low || i > high {
for j := i; j < nn; j++ {
v.set(i, j, hess.at(i, j))
}
}
}
// Back transformation to get eigenvectors of original matrix
for j := nn - 1; j >= low; j-- {
for i := low; i <= high; i++ {
z = 0
for k := low; k <= min(j, high); k++ {
z += v.at(i, k) * hess.at(k, j)
}
v.set(i, j, z)
}
}
}
// D returns the block diagonal eigenvalue matrix from the real and imaginary
// components d and e.
func (f EigenFactors) D() *Dense {
d, e := f.d, f.e
var n int
if n = len(d); n != len(e) {
panic(ErrSquare)
}
dm := NewDense(n, n, nil)
for i := 0; i < n; i++ {
dm.set(i, i, d[i])
if e[i] > 0 {
dm.set(i, i+1, e[i])
} else if e[i] < 0 {
dm.set(i, i-1, e[i])
}
}
return dm
}

View file

@ -1,103 +0,0 @@
// Copyright ©2013 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat64
import (
"math"
"gopkg.in/check.v1"
)
func (s *S) TestEigen(c *check.C) {
for _, t := range []struct {
a *Dense
epsilon float64
e, d []float64
v *Dense
}{
{
a: NewDense(3, 3, []float64{
1, 2, 1,
6, -1, 0,
-1, -2, -1,
}),
epsilon: math.Pow(2, -52.0),
d: []float64{3.0000000000000044, -4.000000000000003, -1.0980273383714707e-16},
e: []float64{0, 0, 0},
v: NewDense(3, 3, []float64{
-0.48507125007266627, 0.41649656391752204, 0.11785113019775795,
-0.7276068751089995, -0.8329931278350428, 0.7071067811865481,
0.48507125007266627, -0.4164965639175216, -1.5320646925708532,
}),
},
{
a: NewDense(3, 3, []float64{
1, 6, -1,
6, -1, -2,
-1, -2, -1,
}),
epsilon: math.Pow(2, -52.0),
d: []float64{-6.240753470718579, -1.3995889142010132, 6.640342384919599},
e: []float64{0, 0, 0},
v: NewDense(3, 3, []float64{
-0.6134279348516111, -0.31411097261113, -0.7245967607083111,
0.7697297716508223, -0.03251534945303795, -0.6375412384185983,
0.17669818159240022, -0.9488293044247931, 0.2617263908869383,
}),
},
{ // Jama pvals
a: NewDense(3, 3, []float64{
4, 1, 1,
1, 2, 3,
1, 3, 6,
}),
epsilon: math.Pow(2, -52.0),
},
{ // Jama evals
a: NewDense(4, 4, []float64{
0, 1, 0, 0,
1, 0, 2e-7, 0,
0, -2e-7, 0, 1,
0, 0, 1, 0,
}),
epsilon: math.Pow(2, -52.0),
},
{ // Jama badeigs
a: NewDense(5, 5, []float64{
0, 0, 0, 0, 0,
0, 0, 0, 0, 1,
0, 0, 0, 1, 0,
1, 1, 0, 0, 1,
1, 0, 1, 0, 1,
}),
epsilon: math.Pow(2, -52.0),
},
} {
ef := Eigen(DenseCopyOf(t.a), t.epsilon)
if t.d != nil {
c.Check(ef.d, check.DeepEquals, t.d)
}
if t.e != nil {
c.Check(ef.e, check.DeepEquals, t.e)
}
if t.v != nil {
c.Check(ef.V.Equals(t.v), check.Equals, true)
}
t.a.Mul(t.a, ef.V)
ef.V.Mul(ef.V, ef.D())
c.Check(t.a.EqualsApprox(ef.V, 1e-12), check.Equals, true)
}
}

View file

@ -1,153 +0,0 @@
// Copyright ©2013 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat64
import (
"fmt"
"strconv"
)
// Format prints a pretty representation of m to the fs io.Writer. The format character c
// specifies the numerical representation of of elements; valid values are those for float64
// specified in the fmt package, with their associated flags. In addition to this, a '#' for
// all valid verbs except 'v' indicates that zero values be represented by the dot character.
// The '#' associated with the 'v' verb formats the matrix with Go syntax representation.
// The printed range of the matrix can be limited by specifying a positive value for margin;
// If margin is greater than zero, only the first and last margin rows/columns of the matrix
// are output.
func Format(m Matrix, margin int, dot byte, fs fmt.State, c rune) {
rows, cols := m.Dims()
var printed int
if margin <= 0 {
printed = rows
if cols > printed {
printed = cols
}
} else {
printed = margin
}
prec, pOk := fs.Precision()
if !pOk {
prec = -1
}
var (
maxWidth int
buf, pad []byte
)
switch c {
case 'v', 'e', 'E', 'f', 'F', 'g', 'G':
// Note that the '#' flag should have been dealt with by the type.
// So %v is treated exactly as %g here.
if c == 'v' {
buf, maxWidth = maxCellWidth(m, 'g', printed, prec)
} else {
buf, maxWidth = maxCellWidth(m, c, printed, prec)
}
default:
fmt.Fprintf(fs, "%%!%c(%T=Dims(%d, %d))", c, m, rows, cols)
return
}
width, _ := fs.Width()
width = max(width, maxWidth)
pad = make([]byte, max(width, 2))
for i := range pad {
pad[i] = ' '
}
if rows > 2*printed || cols > 2*printed {
fmt.Fprintf(fs, "Dims(%d, %d)\n", rows, cols)
}
skipZero := fs.Flag('#')
for i := 0; i < rows; i++ {
var el string
switch {
case rows == 1:
fmt.Fprint(fs, "[")
el = "]"
case i == 0:
fmt.Fprint(fs, "⎡")
el = "⎤\n"
case i < rows-1:
fmt.Fprint(fs, "⎢")
el = "⎥\n"
default:
fmt.Fprint(fs, "⎣")
el = "⎦"
}
for j := 0; j < cols; j++ {
if j >= printed && j < cols-printed {
j = cols - printed - 1
if i == 0 || i == rows-1 {
fmt.Fprint(fs, "... ... ")
} else {
fmt.Fprint(fs, " ")
}
continue
}
v := m.At(i, j)
if v == 0 && skipZero {
buf = buf[:1]
buf[0] = dot
} else {
if c == 'v' {
buf = strconv.AppendFloat(buf[:0], v, 'g', prec, 64)
} else {
buf = strconv.AppendFloat(buf[:0], v, byte(c), prec, 64)
}
}
if fs.Flag('-') {
fs.Write(buf)
fs.Write(pad[:width-len(buf)])
} else {
fs.Write(pad[:width-len(buf)])
fs.Write(buf)
}
if j < cols-1 {
fs.Write(pad[:2])
}
}
fmt.Fprint(fs, el)
if i >= printed-1 && i < rows-printed && 2*printed < rows {
i = rows - printed - 1
fmt.Fprint(fs, " .\n .\n .\n")
continue
}
}
}
func maxCellWidth(m Matrix, c rune, printed, prec int) ([]byte, int) {
var (
buf = make([]byte, 0, 64)
rows, cols = m.Dims()
max int
)
for i := 0; i < rows; i++ {
if i >= printed-1 && i < rows-printed && 2*printed < rows {
i = rows - printed - 1
continue
}
for j := 0; j < cols; j++ {
if j >= printed && j < cols-printed {
continue
}
buf = strconv.AppendFloat(buf, m.At(i, j), byte(c), prec, 64)
if len(buf) > max {
max = len(buf)
}
buf = buf[:0]
}
}
return buf, max
}

View file

@ -1,140 +0,0 @@
// Copyright ©2013 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat64
import (
"fmt"
"math"
"gopkg.in/check.v1"
)
type fm struct {
Matrix
margin int
}
func (m fm) Format(fs fmt.State, c rune) {
if c == 'v' && fs.Flag('#') {
fmt.Fprintf(fs, "%#v", m.Matrix)
return
}
Format(m.Matrix, m.margin, '.', fs, c)
}
func (s *S) TestFormat(c *check.C) {
type rp struct {
format string
output string
}
sqrt := func(_, _ int, v float64) float64 { return math.Sqrt(v) }
for i, test := range []struct {
m fm
rep []rp
}{
// Dense matrix representation
{
fm{Matrix: NewDense(3, 3, []float64{0, 0, 0, 0, 0, 0, 0, 0, 0})},
[]rp{
{"%v", "⎡0 0 0⎤\n⎢0 0 0⎥\n⎣0 0 0⎦"},
{"%#f", "⎡. . .⎤\n⎢. . .⎥\n⎣. . .⎦"},
{"%#v", "&mat64.Dense{mat:blas64.General{Rows:3, Cols:3, Stride:3, Data:[]float64{0, 0, 0, 0, 0, 0, 0, 0, 0}}, capRows:3, capCols:3}"},
{"%s", "%!s(*mat64.Dense=Dims(3, 3))"},
},
},
{
fm{Matrix: NewDense(3, 3, []float64{1, 1, 1, 1, 1, 1, 1, 1, 1})},
[]rp{
{"%v", "⎡1 1 1⎤\n⎢1 1 1⎥\n⎣1 1 1⎦"},
{"%#f", "⎡1 1 1⎤\n⎢1 1 1⎥\n⎣1 1 1⎦"},
{"%#v", "&mat64.Dense{mat:blas64.General{Rows:3, Cols:3, Stride:3, Data:[]float64{1, 1, 1, 1, 1, 1, 1, 1, 1}}, capRows:3, capCols:3}"},
},
},
{
fm{Matrix: NewDense(3, 3, []float64{1, 0, 0, 0, 1, 0, 0, 0, 1})},
[]rp{
{"%v", "⎡1 0 0⎤\n⎢0 1 0⎥\n⎣0 0 1⎦"},
{"%#f", "⎡1 . .⎤\n⎢. 1 .⎥\n⎣. . 1⎦"},
{"%#v", "&mat64.Dense{mat:blas64.General{Rows:3, Cols:3, Stride:3, Data:[]float64{1, 0, 0, 0, 1, 0, 0, 0, 1}}, capRows:3, capCols:3}"},
},
},
{
fm{Matrix: NewDense(2, 3, []float64{1, 2, 3, 4, 5, 6})},
[]rp{
{"%v", "⎡1 2 3⎤\n⎣4 5 6⎦"},
{"%#f", "⎡1 2 3⎤\n⎣4 5 6⎦"},
{"%#v", "&mat64.Dense{mat:blas64.General{Rows:2, Cols:3, Stride:3, Data:[]float64{1, 2, 3, 4, 5, 6}}, capRows:2, capCols:3}"},
},
},
{
fm{Matrix: NewDense(3, 2, []float64{1, 2, 3, 4, 5, 6})},
[]rp{
{"%v", "⎡1 2⎤\n⎢3 4⎥\n⎣5 6⎦"},
{"%#f", "⎡1 2⎤\n⎢3 4⎥\n⎣5 6⎦"},
{"%#v", "&mat64.Dense{mat:blas64.General{Rows:3, Cols:2, Stride:2, Data:[]float64{1, 2, 3, 4, 5, 6}}, capRows:3, capCols:2}"},
},
},
{
func() fm {
m := NewDense(2, 3, []float64{0, 1, 2, 3, 4, 5})
m.Apply(sqrt, m)
return fm{Matrix: m}
}(),
[]rp{
{"%v", "⎡ 0 1 1.4142135623730951⎤\n⎣1.7320508075688772 2 2.23606797749979⎦"},
{"%.2f", "⎡0.00 1.00 1.41⎤\n⎣1.73 2.00 2.24⎦"},
{"%#f", "⎡ . 1 1.4142135623730951⎤\n⎣1.7320508075688772 2 2.23606797749979⎦"},
{"%#v", "&mat64.Dense{mat:blas64.General{Rows:2, Cols:3, Stride:3, Data:[]float64{0, 1, 1.4142135623730951, 1.7320508075688772, 2, 2.23606797749979}}, capRows:2, capCols:3}"},
},
},
{
func() fm {
m := NewDense(3, 2, []float64{0, 1, 2, 3, 4, 5})
m.Apply(sqrt, m)
return fm{Matrix: m}
}(),
[]rp{
{"%v", "⎡ 0 1⎤\n⎢1.4142135623730951 1.7320508075688772⎥\n⎣ 2 2.23606797749979⎦"},
{"%.2f", "⎡0.00 1.00⎤\n⎢1.41 1.73⎥\n⎣2.00 2.24⎦"},
{"%#f", "⎡ . 1⎤\n⎢1.4142135623730951 1.7320508075688772⎥\n⎣ 2 2.23606797749979⎦"},
{"%#v", "&mat64.Dense{mat:blas64.General{Rows:3, Cols:2, Stride:2, Data:[]float64{0, 1, 1.4142135623730951, 1.7320508075688772, 2, 2.23606797749979}}, capRows:3, capCols:2}"},
},
},
{
func() fm {
m := NewDense(1, 10, []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
return fm{Matrix: m, margin: 3}
}(),
[]rp{
{"%v", "Dims(1, 10)\n[ 1 2 3 ... ... 8 9 10]"},
},
},
{
func() fm {
m := NewDense(10, 1, []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
return fm{Matrix: m, margin: 3}
}(),
[]rp{
{"%v", "Dims(10, 1)\n⎡ 1⎤\n⎢ 2⎥\n⎢ 3⎥\n .\n .\n .\n⎢ 8⎥\n⎢ 9⎥\n⎣10⎦"},
},
},
{
func() fm {
m := NewDense(10, 10, nil)
for i := 0; i < 10; i++ {
m.Set(i, i, 1)
}
return fm{Matrix: m, margin: 3}
}(),
[]rp{
{"%v", "Dims(10, 10)\n⎡1 0 0 ... ... 0 0 0⎤\n⎢0 1 0 0 0 0⎥\n⎢0 0 1 0 0 0⎥\n .\n .\n .\n⎢0 0 0 1 0 0⎥\n⎢0 0 0 0 1 0⎥\n⎣0 0 0 ... ... 0 0 1⎦"},
},
},
} {
for _, rp := range test.rep {
c.Check(fmt.Sprintf(rp.format, test.m), check.Equals, rp.output, check.Commentf("Test %d", i))
}
}
}

View file

@ -1,149 +0,0 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file must be kept in sync with index_no_bound_checks.go.
//+build bounds
package mat64
import "github.com/gonum/blas"
func (m *Dense) At(r, c int) float64 {
return m.at(r, c)
}
func (m *Dense) at(r, c int) float64 {
if r >= m.mat.Rows || r < 0 {
panic(ErrRowAccess)
}
if c >= m.mat.Cols || c < 0 {
panic(ErrColAccess)
}
return m.mat.Data[r*m.mat.Stride+c]
}
func (m *Dense) Set(r, c int, v float64) {
m.set(r, c, v)
}
func (m *Dense) set(r, c int, v float64) {
if r >= m.mat.Rows || r < 0 {
panic(ErrRowAccess)
}
if c >= m.mat.Cols || c < 0 {
panic(ErrColAccess)
}
m.mat.Data[r*m.mat.Stride+c] = v
}
func (m *Vector) At(r, c int) float64 {
if c != 0 {
panic(ErrColAccess)
}
return m.at(r)
}
func (m *Vector) at(r int) float64 {
if r < 0 || r >= m.n {
panic(ErrRowAccess)
}
return m.mat.Data[r*m.mat.Inc]
}
func (m *Vector) Set(r, c int, v float64) {
if c != 0 {
panic(ErrColAccess)
}
m.set(r, v)
}
func (m *Vector) set(r int, v float64) {
if r < 0 || r >= m.n {
panic(ErrRowAccess)
}
m.mat.Data[r*m.mat.Inc] = v
}
// At returns the element at row r and column c.
func (t *SymDense) At(r, c int) float64 {
return t.at(r, c)
}
func (t *SymDense) at(r, c int) float64 {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
if r > c {
r, c = c, r
}
return t.mat.Data[r*t.mat.Stride+c]
}
// SetSym sets the elements at (r,c) and (c,r) to the value v.
func (t *SymDense) SetSym(r, c int, v float64) {
t.set(r, c, v)
}
func (t *SymDense) set(r, c int, v float64) {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
if r > c {
r, c = c, r
}
t.mat.Data[r*t.mat.Stride+c] = v
}
// At returns the element at row r and column c.
func (t *Triangular) At(r, c int) float64 {
return t.at(r, c)
}
func (t *Triangular) at(r, c int) float64 {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
if t.mat.Uplo == blas.Upper {
if r > c {
return 0
}
return t.mat.Data[r*t.mat.Stride+c]
}
if r < c {
return 0
}
return t.mat.Data[r*t.mat.Stride+c]
}
// SetTri sets the element of the triangular matrix at row r and column c.
// Set panics if the location is outside the appropriate half of the matrix.
func (t *Triangular) SetTri(r, c int, v float64) {
t.set(r, c, v)
}
func (t *Triangular) set(r, c int, v float64) {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
if t.mat.Uplo == blas.Upper && r > c {
panic("mat64: triangular set out of bounds")
}
if t.mat.Uplo == blas.Lower && r < c {
panic("mat64: triangular set out of bounds")
}
t.mat.Data[r*t.mat.Stride+c] = v
}

View file

@ -1,149 +0,0 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file must be kept in sync with index_bound_checks.go.
//+build !bounds
package mat64
import "github.com/gonum/blas"
func (m *Dense) At(r, c int) float64 {
if r >= m.mat.Rows || r < 0 {
panic(ErrRowAccess)
}
if c >= m.mat.Cols || c < 0 {
panic(ErrColAccess)
}
return m.at(r, c)
}
func (m *Dense) at(r, c int) float64 {
return m.mat.Data[r*m.mat.Stride+c]
}
func (m *Dense) Set(r, c int, v float64) {
if r >= m.mat.Rows || r < 0 {
panic(ErrRowAccess)
}
if c >= m.mat.Cols || c < 0 {
panic(ErrColAccess)
}
m.set(r, c, v)
}
func (m *Dense) set(r, c int, v float64) {
m.mat.Data[r*m.mat.Stride+c] = v
}
func (m *Vector) At(r, c int) float64 {
if r < 0 || r >= m.n {
panic(ErrRowAccess)
}
if c != 0 {
panic(ErrColAccess)
}
return m.at(r)
}
func (m *Vector) at(r int) float64 {
return m.mat.Data[r*m.mat.Inc]
}
func (m *Vector) Set(r, c int, v float64) {
if r < 0 || r >= m.n {
panic(ErrRowAccess)
}
if c != 0 {
panic(ErrColAccess)
}
m.set(r, v)
}
func (m *Vector) set(r int, v float64) {
m.mat.Data[r*m.mat.Inc] = v
}
// At returns the element at row r and column c.
func (t *SymDense) At(r, c int) float64 {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
return t.at(r, c)
}
func (t *SymDense) at(r, c int) float64 {
if r > c {
r, c = c, r
}
return t.mat.Data[r*t.mat.Stride+c]
}
// SetSym sets the elements at (r,c) and (c,r) to the value v.
func (t *SymDense) SetSym(r, c int, v float64) {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
t.set(r, c, v)
}
func (t *SymDense) set(r, c int, v float64) {
if r > c {
r, c = c, r
}
t.mat.Data[r*t.mat.Stride+c] = v
}
// At returns the element at row r and column c.
func (t *Triangular) At(r, c int) float64 {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
return t.at(r, c)
}
func (t *Triangular) at(r, c int) float64 {
if t.mat.Uplo == blas.Upper {
if r > c {
return 0
}
return t.mat.Data[r*t.mat.Stride+c]
}
if r < c {
return 0
}
return t.mat.Data[r*t.mat.Stride+c]
}
// Set sets the element at row r and column c. Set panics if the location is outside
// the appropriate half of the matrix.
func (t *Triangular) SetTri(r, c int, v float64) {
if r >= t.mat.N || r < 0 {
panic(ErrRowAccess)
}
if c >= t.mat.N || c < 0 {
panic(ErrColAccess)
}
if t.mat.Uplo == blas.Upper && r > c {
panic("mat64: triangular set out of bounds")
}
if t.mat.Uplo == blas.Lower && r < c {
panic("mat64: triangular set out of bounds")
}
t.set(r, c, v)
}
func (t *Triangular) set(r, c int, v float64) {
t.mat.Data[r*t.mat.Stride+c] = v
}

View file

@ -1,56 +0,0 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mat64
import "github.com/gonum/internal/asm"
// Inner computes the generalized inner product
// x^T A y
// between vectors x and y with matrix A. This is only a true inner product if
// A is symmetric positive definite, though the operation works for any matrix A.
//
// Inner panics if len(x) != m or len(y) != n when A is an m x n matrix.
func Inner(x []float64, A Matrix, y []float64) float64 {
m, n := A.Dims()
if len(x) != m {
panic(ErrShape)
}
if len(y) != n {
panic(ErrShape)
}
if m == 0 || n == 0 {
return 0
}
var sum float64
switch b := A.(type) {
case RawSymmetricer:
bmat := b.RawSymmetric()
for i, xi := range x {
if xi != 0 {
sum += xi * asm.DdotUnitary(bmat.Data[i*bmat.Stride+i:i*bmat.Stride+n], y[i:])
}
yi := y[i]
if i != n-1 && yi != 0 {
sum += yi * asm.DdotUnitary(bmat.Data[i*bmat.Stride+i+1:i*bmat.Stride+n], x[i+1:])
}
}
case RawMatrixer:
bmat := b.RawMatrix()
for i, xi := range x {
if xi != 0 {
sum += xi * asm.DdotUnitary(bmat.Data[i*bmat.Stride:i*bmat.Stride+n], y)
}
}
default:
for i, xi := range x {
for j, yj := range y {
sum += xi * A.At(i, j) * yj
}
}
}
return sum
}

Some files were not shown because too many files have changed in this diff Show more